LTTEAM committed
Commit 49996f1 · verified · 1 Parent(s): 7bd3681

Update app.py

Files changed (1)
  1. app.py +101 -112
app.py CHANGED
@@ -22,30 +22,22 @@ from inference import (
 from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline
 from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
 
-# --- Configure and download models from your repo ---
-CONFIG_PATH = "configs/ltxv-13b-0.9.7-distilled.yaml"
-with open(CONFIG_PATH, "r") as f:
+# --- Read the config and download models from Hugging Face ---
+CONFIG_YAML = "configs/ltxv-13b-0.9.7-distilled.yaml"
+with open(CONFIG_YAML, "r") as f:
     CFG = yaml.safe_load(f)
 
 HF_REPO = "LTTEAM/VideoAI"
 MODELS_DIR = "downloaded_models"
-Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
+Path(MODELS_DIR).mkdir(exist_ok=True)
 
 print("Downloading models (if not already present)…")
-ckpt_path = hf_hub_download(
-    repo_id=HF_REPO,
-    filename=CFG["checkpoint_path"],
-    local_dir=MODELS_DIR
-)
-CFG["checkpoint_path"] = ckpt_path
-upscaler_path = hf_hub_download(
-    repo_id=HF_REPO,
-    filename=CFG["spatial_upscaler_model_path"],
-    local_dir=MODELS_DIR
-)
-CFG["spatial_upscaler_model_path"] = upscaler_path
+ckpt = hf_hub_download(repo_id=HF_REPO, filename=CFG["checkpoint_path"], local_dir=MODELS_DIR)
+CFG["checkpoint_path"] = ckpt
+upscaler = hf_hub_download(repo_id=HF_REPO, filename=CFG["spatial_upscaler_model_path"], local_dir=MODELS_DIR)
+CFG["spatial_upscaler_model_path"] = upscaler
 
-# --- Initialize the pipeline and upsampler on CPU at startup ---
+# --- Initialize the pipeline and upsampler on CPU ---
 print("Initializing pipeline on CPU…")
 pipeline = create_ltx_video_pipeline(
     ckpt_path=CFG["checkpoint_path"],
@@ -57,71 +49,69 @@ pipeline = create_ltx_video_pipeline(
     prompt_enhancer_image_caption_model_name_or_path=CFG["prompt_enhancer_image_caption_model_name_or_path"],
     prompt_enhancer_llm_model_name_or_path=CFG["prompt_enhancer_llm_model_name_or_path"],
 )
-print("Pipeline ready on CPU.")
+print("Pipeline ready.")
 print("Initializing latent upsampler on CPU…")
 upsampler = create_latent_upsampler(CFG["spatial_upscaler_model_path"], device="cpu")
-print("Latent upsampler ready on CPU.")
+print("Upsampler ready.")
 
 # --- Fixed parameters ---
 FPS = 30.0
-MAX_NUM_FRAMES = 257
+MAX_FRAMES = 257
 MIN_DIM = 256
-TARGET_SIDE = 768
+FIXED_SIDE = 768
 MAX_RES = CFG.get("max_resolution", 1280)
 
-def calculate_new_dimensions(w, h):
+def calc_new_dims(w, h):
     if w==0 or h==0:
-        return TARGET_SIDE, TARGET_SIDE
+        return FIXED_SIDE, FIXED_SIDE
     if w>=h:
-        nh = TARGET_SIDE
-        nw = round((nh * w/h)/32)*32
+        nh = FIXED_SIDE
+        nw = round((nh*w/h)/32)*32
     else:
-        nw = TARGET_SIDE
-        nh = round((nw * h/w)/32)*32
+        nw = FIXED_SIDE
+        nh = round((nw*h/w)/32)*32
     return (
         int(max(MIN_DIM, min(nh, MAX_RES))),
         int(max(MIN_DIM, min(nw, MAX_RES)))
     )
 
-def get_duration(*args, **kwargs):
-    return 75 if kwargs.get("duration_ui",0) > 7 else 60
+def get_duration(*args, duration_ui=0, **kwargs):
+    return 75 if duration_ui > 7 else 60
 
 @spaces.GPU(duration=get_duration)
-def generate(prompt, neg_prompt,
-             img_path, vid_path,
-             height, width,
-             mode_task, duration_ui, frames_to_use,
-             seed, rand_seed, cfg_scale,
-             improve_tex, device_choice,
+def generate(prompt, neg_prompt, img_path, vid_path,
+             height, width, mode, duration_ui, frames_to_use,
+             seed, rand_seed, cfg_scale, improve_tex, device_choice,
              progress=gr.Progress(track_tqdm=True)):
-    # Select device
+
+    # Select the inference device
     dev = "cuda" if device_choice=="GPU" and torch.cuda.is_available() else "cpu"
-    print(f"Running on device: {dev}")
+    print(f"Using device: {dev}")
     pipeline.to(dev)
     upsampler.to(dev)
 
     # Seed
     if rand_seed:
-        seed = random.randint(0, 2**32-1)
+        seed = random.randint(0, 2**32 - 1)
     seed_everething(int(seed))
 
     # Compute the frame count
-    tf = max(1, round(duration_ui*FPS))
+    tf = max(1, round(duration_ui * FPS))
     n8 = round((tf-1)/8)
-    n_frames = max(9, min(n8*8+1, MAX_NUM_FRAMES))
+    n_frames = max(9, min(n8*8+1, MAX_FRAMES))
 
     # Pad dimensions
     h, w = int(height), int(width)
-    h32 = ((h-1)//32+1)*32
-    w32 = ((w-1)//32+1)*32
-    pad = calculate_padding(h, w, h32, w32)
+    h_pad = ((h-1)//32+1)*32
+    w_pad = ((w-1)//32+1)*32
+    pad = calculate_padding(h, w, h_pad, w_pad)
 
     # Prepare shared kwargs
     kwargs = {
         "prompt": prompt,
         "negative_prompt": neg_prompt,
-        "height": h32,
-        "width": w32,
+        "height": h_pad,
+        "width": w_pad,
         "num_frames": n_frames,
         "frame_rate": int(FPS),
         "generator": torch.Generator(device=dev).manual_seed(int(seed)),
@@ -136,75 +126,76 @@ def generate(prompt, neg_prompt,
         "enhance_prompt": False,
     }
     # Skip-layer strategy
-    stg = CFG.get("stg_mode","attention_values").lower()
-    mapping = {
-        "stg_av":SkipLayerStrategy.AttentionValues,
-        "attention_values":SkipLayerStrategy.AttentionValues,
-        "stg_as":SkipLayerStrategy.AttentionSkip,
-        "attention_skip":SkipLayerStrategy.AttentionSkip,
-        "stg_r":SkipLayerStrategy.Residual,
-        "residual":SkipLayerStrategy.Residual,
-        "stg_t":SkipLayerStrategy.TransformerBlock,
-        "transformer_block":SkipLayerStrategy.TransformerBlock,
+    mode_stg = CFG.get("stg_mode","attention_values").lower()
+    stg_map = {
+        "stg_av": SkipLayerStrategy.AttentionValues,
+        "attention_values": SkipLayerStrategy.AttentionValues,
+        "stg_as": SkipLayerStrategy.AttentionSkip,
+        "attention_skip": SkipLayerStrategy.AttentionSkip,
+        "stg_r": SkipLayerStrategy.Residual,
+        "residual": SkipLayerStrategy.Residual,
+        "stg_t": SkipLayerStrategy.TransformerBlock,
+        "transformer_block": SkipLayerStrategy.TransformerBlock,
     }
-    kwargs["skip_layer_strategy"] = mapping.get(stg, SkipLayerStrategy.AttentionValues)
+    kwargs["skip_layer_strategy"] = stg_map.get(mode_stg, SkipLayerStrategy.AttentionValues)
 
     # Conditioning
-    if mode_task=="image-to-video" and img_path:
-        tensor = load_image_to_tensor_with_resize_and_crop(img_path, h, w)
-        tensor = torch.nn.functional.pad(tensor, pad)
-        kwargs["conditioning_items"] = [ConditioningItem(tensor.to(dev),0,1.0)]
-    elif mode_task=="video-to-video" and vid_path:
+    if mode=="image-to-video" and img_path:
+        t = load_image_to_tensor_with_resize_and_crop(img_path, h, w)
+        t = torch.nn.functional.pad(t, pad)
+        kwargs["conditioning_items"] = [ConditioningItem(t.to(dev), 0, 1.0)]
+    elif mode=="video-to-video" and vid_path:
         mi = load_media_file(vid_path, h, w, int(frames_to_use), pad).to(dev)
         kwargs["media_items"] = mi
 
     # Multi-scale or single-pass?
     if improve_tex:
         pipe_ms = LTXMultiScalePipeline(pipeline, upsampler)
         fp = CFG.get("first_pass",{}).copy()
         fp["guidance_scale"] = float(cfg_scale)
-        fp.pop("num_inference_steps",None)
+        fp.pop("num_inference_steps", None)
         sp = CFG.get("second_pass",{}).copy()
         sp["guidance_scale"] = float(cfg_scale)
-        sp.pop("num_inference_steps",None)
+        sp.pop("num_inference_steps", None)
         kwargs.update({
-            "downscale_factor":CFG["downscale_factor"],
-            "first_pass":fp,
-            "second_pass":sp
+            "downscale_factor": CFG["downscale_factor"],
+            "first_pass": fp,
+            "second_pass": sp
         })
-        out = pipe_ms(**kwargs).images
+        images = pipe_ms(**kwargs).images
     else:
         fp0 = CFG.get("first_pass",{})
         kwargs.update({
-            "timesteps":fp0.get("timesteps"),
-            "guidance_scale":float(cfg_scale),
-            "stg_scale":fp0.get("stg_scale"),
-            "rescaling_scale":fp0.get("rescaling_scale"),
-            "skip_block_list":fp0.get("skip_block_list")
+            "timesteps": fp0.get("timesteps"),
+            "guidance_scale": float(cfg_scale),
+            "stg_scale": fp0.get("stg_scale"),
+            "rescaling_scale": fp0.get("rescaling_scale"),
+            "skip_block_list": fp0.get("skip_block_list")
         })
         for k in ["first_pass","second_pass","downscale_factor","num_inference_steps"]:
            kwargs.pop(k, None)
-        out = pipeline(**kwargs).images
+        images = pipeline(**kwargs).images
 
     # Remove padding, save the video
-    l, r, t, b = pad
+    l, r, t_, b = pad
     sh = None if b==0 else -b
     sw = None if r==0 else -r
-    vid_tensor = out[0][:,:,:n_frames,t:sh,l:sw]
-    arr = vid_tensor.permute(1,2,3,0).cpu().numpy()
+    vid_t = images[0][:,:,:n_frames, t_:sh, l:sw]
+    arr = vid_t.permute(1,2,3,0).cpu().numpy()
     arr = (np.clip(arr,0,1)*255).astype(np.uint8)
 
-    tmp = tempfile.mkdtemp()
-    dst = os.path.join(tmp, f"out_{random.randint(0,99999)}.mp4")
-    with imageio.get_writer(dst, fps=int(FPS), macro_block_size=1) as w:
+    out_dir = tempfile.mkdtemp()
+    out_path = os.path.join(out_dir, f"output_{random.randint(0,99999)}.mp4")
+    with imageio.get_writer(out_path, fps=int(FPS), macro_block_size=1) as writer:
         for i in range(arr.shape[0]):
             progress(i/arr.shape[0], desc="Saving video")
-            w.append_data(arr[i])
-    return dst, seed
+            writer.append_data(arr[i])
+
+    return out_path, seed
 
 # --- Gradio UI ---
 css = """
-#col-container {margin:0 auto; max-width:900px;}
+#col-container { margin:0 auto; max-width:900px; }
 """
 with gr.Blocks(css=css) as demo:
     gr.Markdown("## LTX Video 0.9.7 Distilled App")
@@ -215,54 +206,52 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Row():
         with gr.Column():
-            # Select device
-            device = gr.Radio(["CPU","GPU"], label="Run on device", value="CPU")
+            device = gr.Radio(["CPU", "GPU"], label="Run on device", value="CPU")
 
-            # Tabs
             with gr.Tab("Image→Video"):
                 img_in = gr.Image(label="Input image", type="filepath", sources=["upload","clipboard","webcam"])
                 prompt1 = gr.Textbox(label="Description", lines=2, value="The creature moves")
                 btn1 = gr.Button("Generate from image")
 
             with gr.Tab("Text→Video"):
                 prompt2 = gr.Textbox(label="Description", lines=2, value="A dragon flying over a castle")
                 btn2 = gr.Button("Generate from text")
 
             with gr.Tab("Video→Video"):
-                vid_in = gr.Video(label="Input video", type="filepath", sources=["upload","webcam"])
-                frames = gr.Slider(label="Frames to use", minimum=9, maximum=MAX_NUM_FRAMES, step=8, value=9)
+                vid_in = gr.Video(label="Input video", sources=["upload","webcam"])
+                frames = gr.Slider(label="Frames to use", minimum=9, maximum=MAX_FRAMES, step=8, value=9)
                 prompt3 = gr.Textbox(label="Description", lines=2, value="Convert to anime style")
                 btn3 = gr.Button("Generate from video")
 
             duration = gr.Slider(label="Duration (seconds)", minimum=0.3, maximum=8.5, step=0.1, value=2)
             improve = gr.Checkbox(label="Improve detail", value=True)
 
         with gr.Column():
-            out_video = gr.Video(label="Result", interactive=False)
+            out_vid = gr.Video(label="Result", interactive=False)
 
-    # Hidden state controlling mode and seed
+    # Hidden state
     mode_state = gr.State("image-to-video")
     seed_state = gr.State(42)
-    neg_prompt = gr.State("worst quality, inconsistent motion, blurry, jittery, distorted")
-    cfg_scale_st = gr.State(CFG["first_pass"]["guidance_scale"])
-    height_st = gr.State(512)
-    width_st = gr.State(704)
+    neg_state = gr.State("worst quality, inconsistent motion, blurry, jittery, distorted")
+    cfg_state = gr.State(CFG["first_pass"]["guidance_scale"])
+    h_state = gr.State(512)
+    w_state = gr.State(704)
 
     btn1.click(fn=generate,
-               inputs=[prompt1, neg_prompt, img_in, gr.State(""), height_st, width_st,
+               inputs=[prompt1, neg_state, img_in, gr.State(""), h_state, w_state,
                        mode_state, duration, frames, seed_state, gr.State(True),
-                       cfg_scale_st, improve, device],
-               outputs=[out_video, seed_state])
+                       cfg_state, improve, device],
+               outputs=[out_vid, seed_state])
     btn2.click(fn=generate,
-               inputs=[prompt2, neg_prompt, gr.State(""), gr.State(""), height_st, width_st,
+               inputs=[prompt2, neg_state, gr.State(""), gr.State(""), h_state, w_state,
                        mode_state, duration, frames, seed_state, gr.State(True),
-                       cfg_scale_st, improve, device],
-               outputs=[out_video, seed_state])
+                       cfg_state, improve, device],
+               outputs=[out_vid, seed_state])
     btn3.click(fn=generate,
-               inputs=[prompt3, neg_prompt, gr.State(""), vid_in, height_st, width_st,
+               inputs=[prompt3, neg_state, gr.State(""), vid_in, h_state, w_state,
                        mode_state, duration, frames, seed_state, gr.State(True),
-                       cfg_scale_st, improve, device],
-               outputs=[out_video, seed_state])
+                       cfg_state, improve, device],
+               outputs=[out_vid, seed_state])
 
 if __name__ == "__main__":
     demo.queue().launch(debug=True, share=False)
 
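A few notes on the updated logic follow.

calc_new_dims returns a (height, width) pair: the short side is pinned to FIXED_SIDE = 768, the long side is scaled proportionally and snapped to a multiple of 32, and both are clamped to [MIN_DIM, MAX_RES]. A worked example with the constants above; note the helper is defined but never called anywhere in this diff, since the click handlers pass a fixed 512x704 via h_state/w_state:

calc_new_dims(1920, 1080)
# w >= h, so nh = FIXED_SIDE = 768
# nw = round((768 * 1920/1080) / 32) * 32 = round(42.67) * 32 = 1376
# 1376 exceeds MAX_RES = 1280, so the result is (768, 1280)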
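The get_duration rewrite is a genuine bug fix: the old body read kwargs.get("duration_ui", 0), so if the ZeroGPU runtime passes the wrapped function's arguments positionally, duration_ui lands in *args and the callback always returned 60. The new keyword-only parameter covers the keyword case; a more defensive variant (a hypothetical sketch, assuming duration_ui stays the eighth parameter of generate) would cover both call styles:

def get_duration(*args, **kwargs):
    # duration_ui is parameter index 7 of generate(prompt, neg_prompt,
    # img_path, vid_path, height, width, mode, duration_ui, ...)
    duration_ui = kwargs.get("duration_ui", args[7] if len(args) > 7 else 0)
    return 75 if duration_ui > 7 else 60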
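The frame count is snapped to the 8k + 1 form the rest of the code assumes (slider minimum 9, step 8, cap MAX_FRAMES = 257). Worked through at the slider default of 2 seconds:

duration_ui = 2.0                          # slider default
tf = max(1, round(duration_ui * 30.0))     # 60 target frames at FPS = 30
n8 = round((tf - 1) / 8)                   # round(59 / 8) = 7
n_frames = max(9, min(n8 * 8 + 1, 257))    # 57 frames, about 1.9 s of video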
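Spatial dimensions are likewise padded up to the next multiple of 32 before inference. calculate_padding(h, w, h_pad, w_pad) comes from inference; judging by how its result is consumed below (unpacked as left/right/top/bottom and fed to torch.nn.functional.pad), it distributes the extra rows and columns across the four edges, though that is an assumption about a helper not shown in this diff:

h, w = 500, 704
h_pad = ((h - 1) // 32 + 1) * 32   # 512: next multiple of 32 at or above 500
w_pad = ((w - 1) // 32 + 1) * 32   # 704: already a multiple of 32, unchanged
# calculate_padding(500, 704, 512, 704) must account for 12 extra rows;
# how they split between top and bottom is up to the helper (assumption).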
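The None guards used when cropping the padding back off are easy to misread: a zero pad cannot be written as a -0 slice end, because -0 == 0 and the slice comes out empty. A minimal illustration:

x = list(range(10))
b = 0
sh = None if b == 0 else -b
x[2:sh]    # [2, 3, ..., 9]: slices to the true end
x[2:-0]    # []: equivalent to x[2:0]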