XiangpengYang committed
Commit 97a5f5d · 1 Parent(s): c7eaeb0
Files changed (4)
  1. app.py +14 -6
  2. inference.py +22 -1
  3. videox_fun/ui/ui.py +2 -2
  4. videox_fun/utils/lora_utils.py +18 -40
app.py CHANGED
```diff
@@ -79,6 +79,7 @@ def load_video_frames(video_path: str, source_frames: int):
                 pil_frame = Image.fromarray(frame)
                 if original_height is None:
                     original_width, original_height = pil_frame.size
+                    print(f"Original video dimensions: {original_width}x{original_height}")
                 frames.append(pil_frame)
             except IndexError:
                 break
@@ -92,6 +93,9 @@ def load_video_frames(video_path: str, source_frames: int):
         w, h = (original_width, original_height) if original_width else (832, 480)
         frames.append(Image.new('RGB', (w, h), (0, 0, 0)))
 
+    assert len(frames) == source_frames, f"Loaded {len(frames)} frames, expected {source_frames}"
+    print(f"Loaded {source_frames} source frames")
+
     input_video = torch.from_numpy(np.array(frames))
     input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0).float()
     input_video = input_video * (2.0 / 255.0) - 1.0
@@ -143,6 +147,8 @@ class VideoCoF_Controller(Wan_Controller):
         # Ensure model is on CUDA inside the zero-gpu decorated function
         if torch.cuda.is_available():
             self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
         # If pipeline is not on cuda, move it (if possible, but usually accelerate handles this or it's handled by parts)
         # However, Wan_Controller logic might rely on `self.device`.
         # We explicitly set `self.device` to cuda here.
@@ -166,7 +172,7 @@ class VideoCoF_Controller(Wan_Controller):
         # 1. Merge VideoCoF LoRA
         if self.lora_model_path != "none":
             print(f"Merge VideoCoF Lora: {self.lora_model_path}")
-            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
         # 2. Merge Acceleration LoRA (FusionX) if enabled
         acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
@@ -174,7 +180,7 @@ class VideoCoF_Controller(Wan_Controller):
         if os.path.exists(acc_lora_path):
             print(f"Merge Acceleration LoRA: {acc_lora_path}")
             # FusionX LoRA generally uses multiplier 1.0
-            self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+            self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
         else:
             print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
 
@@ -217,6 +223,7 @@ class VideoCoF_Controller(Wan_Controller):
         print(f"Input video dimensions: {w}x{h}")
 
         print(f"Running pipeline with frames={length_slider}, source={source_frames_slider}, reasoning={reasoning_frames_slider}")
+        shift = 3
 
         sample = self.pipeline(
             video=input_video_tensor,
@@ -230,6 +237,7 @@ class VideoCoF_Controller(Wan_Controller):
             generator=generator,
             guidance_scale=cfg_scale_slider,
             num_inference_steps=sample_step_slider,
+            shift=shift,
             repeat_rope=repeat_rope_checkbox,
             cot=True,
         ).videos
@@ -241,21 +249,21 @@ class VideoCoF_Controller(Wan_Controller):
             # Unmerge in case of error (LIFO order)
             if enable_acceleration and os.path.exists(acc_lora_path):
                 print("Unmerging Acceleration LoRA (due to error)")
-                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+                self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
             if self.lora_model_path != "none":
                 print("Unmerging VideoCoF LoRA (due to error)")
-                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
             return gr.update(), gr.update(), f"Error: {str(e)}"
 
         # Unmerge LoRAs (LIFO order)
         if enable_acceleration and os.path.exists(acc_lora_path):
             print("Unmerging Acceleration LoRA")
-            self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
+            self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0, device=self.device)
 
         if self.lora_model_path != "none":
             print("Unmerging VideoCoF LoRA")
-            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
+            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider, device=self.device)
 
         # Save output
         save_sample_path = self.save_outputs(
```
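The merge → run → unmerge bracketing above has to undo LoRAs in reverse order on both the success and error paths. A minimal sketch of the same pattern as a context manager, so the LIFO unmerge runs even when the pipeline raises; `merged_loras` and `loras` are illustrative names, only `merge_lora`/`unmerge_lora` come from the repository:

```python
# Sketch only: wraps the commit's merge/unmerge discipline in a context manager.
from contextlib import contextmanager

from videox_fun.utils.lora_utils import merge_lora, unmerge_lora

@contextmanager
def merged_loras(pipeline, loras, device="cuda"):
    """loras: list of (path, multiplier) pairs, merged in declaration order."""
    merged = []
    try:
        for path, mult in loras:
            pipeline = merge_lora(pipeline, path, multiplier=mult, device=device)
            merged.append((path, mult))
        yield pipeline
    finally:
        # Undo in LIFO order, mirroring the explicit unmerge calls above.
        for path, mult in reversed(merged):
            pipeline = unmerge_lora(pipeline, path, multiplier=mult, device=device)
```

Usage would mirror the controller: `with merged_loras(self.pipeline, [(self.lora_model_path, lora_alpha_slider), (acc_lora_path, 1.0)]) as pipe: ...`.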
inference.py CHANGED
```diff
@@ -89,6 +89,9 @@ def parse_args():
     parser.add_argument("--output_dir", type=str, required=True, help="Output directory for generated videos")
     parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducible generation")
     parser.add_argument("--videocof_path", type=str, default=None, help="Path to videocof weight checkpoint")
+    parser.add_argument("--lora_path", type=str, default=None, help="Path to LoRA checkpoint")
+    parser.add_argument("--enable_acceleration_lora", action="store_true", help="Enable loading the acceleration (FusionX) LoRA")
+    parser.add_argument("--acceleration_lora_path", type=str, default=None, help="Optional path to acceleration LoRA; defaults to FusionX under model directory")
     parser.add_argument("--num_frames", type=int, default=65, help="Total number of frames (input + generated)")
     parser.add_argument("--source_frames", type=int, default=33, help="Number of source frames; default 33")
     parser.add_argument("--reasoning_frames", type=int, default=4, help="Grounding frames in the middle segment (pixel-space)")
@@ -320,7 +323,25 @@ def main():
     else:
         pipeline.to(device=device)
 
-    # LoRA
+    # Acceleration LoRA (FusionX) mirrors app.py behavior
+    if args.enable_acceleration_lora:
+        default_acc_path = os.path.join(model_name, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
+        acc_lora_path = args.acceleration_lora_path or default_acc_path
+        if os.path.exists(acc_lora_path):
+            print(f"[GPU {rank}] Merge Acceleration LoRA: {acc_lora_path}")
+            pipeline = merge_lora(pipeline, acc_lora_path, multiplier=1.0, device=device)
+        else:
+            print(f"[GPU {rank}] Warning: Acceleration LoRA not found at {acc_lora_path}")
+
+    # Custom LoRA
+    if args.lora_path is not None:
+        if os.path.exists(args.lora_path):
+            print(f"[GPU {rank}] Loading custom LoRA: {args.lora_path}")
+            pipeline = merge_lora(pipeline, args.lora_path, lora_weight, device=device)
+        else:
+            print(f"[GPU {rank}] Warning: Provided lora_path not found: {args.lora_path}")
+
+    # VideoCoF LoRA
     if args.videocof_path is not None:
         pipeline = merge_lora(pipeline, args.videocof_path, lora_weight, device=device)
         print(f"[GPU {rank}] Loaded LoRA from {args.videocof_path}")
```
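For reference, a self-contained sketch of how the new flags resolve the acceleration LoRA path; it mirrors the argparse additions and the `or`-based fallback in the diff, with `model_name` as a placeholder directory rather than a value from the repository:

```python
# Sketch of the new flag handling; simulates CLI input via parse_args(list).
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--lora_path", type=str, default=None)
parser.add_argument("--enable_acceleration_lora", action="store_true")
parser.add_argument("--acceleration_lora_path", type=str, default=None)
args = parser.parse_args(["--enable_acceleration_lora"])  # simulated CLI input

model_name = "models/Wan2.1-T2V-14B"  # placeholder model directory
if args.enable_acceleration_lora:
    default_acc_path = os.path.join(model_name, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
    # With no explicit --acceleration_lora_path, fall back to the FusionX file.
    acc_lora_path = args.acceleration_lora_path or default_acc_path
    print(acc_lora_path)
```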
videox_fun/ui/ui.py CHANGED
```diff
@@ -194,10 +194,10 @@ def create_prompts(
     negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value=negative_prompt)
     return prompt_textbox, negative_prompt_textbox
 
-def create_samplers(controller, maximum_step=100):
+def create_samplers(controller, maximum_step=50):
     with gr.Row():
         sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0])
-        sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=1, maximum=maximum_step, step=1)
+        sample_step_slider = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=maximum_step, step=1)
 
     return sampler_dropdown, sample_step_slider
```
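A standalone sketch of the updated sampler controls, assuming a stand-in `scheduler_dict` in place of `controller.scheduler_dict`; the lower 4-step default presumably matches the few-step sampling enabled by the FusionX acceleration LoRA elsewhere in this commit:

```python
# Sketch only: previews the new slider defaults outside the controller.
import gradio as gr

scheduler_dict = {"Flow": None, "Euler": None}  # illustrative choices only

with gr.Blocks() as demo:
    with gr.Row():
        sampler_dropdown = gr.Dropdown(
            label="Sampling method",
            choices=list(scheduler_dict.keys()),
            value=list(scheduler_dict.keys())[0],
        )
        # Few-step default; users can still raise it up to maximum_step=50.
        sample_step_slider = gr.Slider(
            label="Sampling steps", value=4, minimum=1, maximum=50, step=1
        )

# demo.launch()  # uncomment to preview locally
```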
 
videox_fun/utils/lora_utils.py CHANGED
```diff
@@ -389,28 +389,9 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
-        key = key.replace("text_embedding.", "text_embedding_")
-        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
-        key = key.replace(".lora_A.weight", ".lora_down.weight")
-        key = key.replace(".lora_B.weight", ".lora_up.weight")
-
-        if key.endswith(".lora_down.weight"):
-            layer = key[:-len(".lora_down.weight")]
-            elem = "lora_down.weight"
-        elif key.endswith(".lora_up.weight"):
-            layer = key[:-len(".lora_up.weight")]
-            elem = "lora_up.weight"
-        elif key.endswith(".alpha"):
-            layer = key[:-len(".alpha")]
-            elem = "alpha"
-        else:
-            continue
-
-        if layer.endswith("."):
-            layer = layer[:-1]
-
+        layer, elem = key.split('.', 1)
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
@@ -484,10 +465,20 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3
         if error_flag:
             continue
 
+        # Some resolved modules (e.g., container blocks/norm-only) may not have a weight parameter.
+        if not hasattr(curr_layer, "weight"):
+            # Skip incompatible / non-leaf modules
+            continue
+
         origin_dtype = curr_layer.weight.data.dtype
         origin_device = curr_layer.weight.data.device
 
         curr_layer = curr_layer.to(device, dtype)
+        # Some checkpoints (e.g., norm-only entries) may not contain both weights.
+        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
+            # Skip incompatible layer instead of raising KeyError
+            curr_layer = curr_layer.to(origin_device, origin_dtype)
+            continue
         weight_up = elems['lora_up.weight'].to(device, dtype)
         weight_down = elems['lora_down.weight'].to(device, dtype)
 
@@ -529,28 +520,9 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
         key = key.replace(".self_attn.", "_self_attn_")
         key = key.replace(".cross_attn.", "_cross_attn_")
         key = key.replace(".ffn.", "_ffn_")
-        key = key.replace("text_embedding.", "text_embedding_")
-        key = key.replace("time_embedding.", "time_embedding_")
         key = key.replace(".lora_A.default.", ".lora_down.")
         key = key.replace(".lora_B.default.", ".lora_up.")
-        key = key.replace(".lora_A.weight", ".lora_down.weight")
-        key = key.replace(".lora_B.weight", ".lora_up.weight")
-
-        if key.endswith(".lora_down.weight"):
-            layer = key[:-len(".lora_down.weight")]
-            elem = "lora_down.weight"
-        elif key.endswith(".lora_up.weight"):
-            layer = key[:-len(".lora_up.weight")]
-            elem = "lora_up.weight"
-        elif key.endswith(".alpha"):
-            layer = key[:-len(".alpha")]
-            elem = "alpha"
-        else:
-            continue
-
-        if layer.endswith("."):
-            layer = layer[:-1]
-
+        layer, elem = key.split('.', 1)
         updates[layer][elem] = value
 
     sequential_cpu_offload_flag = False
@@ -617,10 +589,16 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl
         if error_flag:
             continue
 
+        if not hasattr(curr_layer, "weight"):
+            continue
+
         origin_dtype = curr_layer.weight.data.dtype
         origin_device = curr_layer.weight.data.device
 
         curr_layer = curr_layer.to(device, dtype)
+        if 'lora_up.weight' not in elems or 'lora_down.weight' not in elems:
+            curr_layer = curr_layer.to(origin_device, origin_dtype)
+            continue
         weight_up = elems['lora_up.weight'].to(device, dtype)
         weight_down = elems['lora_down.weight'].to(device, dtype)
```
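The rewritten key handling drops the suffix-matching fallbacks and assumes every remaining key splits at its first dot into a flattened layer name and an element (`lora_down.weight`, `lora_up.weight`, or `alpha`). A runnable sketch with illustrative keys, assuming earlier replaces not shown in this hunk have already flattened prefixes such as `blocks.0` to `blocks_0`:

```python
# Sketch of the simplified key parsing; sample keys are illustrative only.
from collections import defaultdict

state_dict = {
    "blocks_0.self_attn.q.lora_A.default.weight": "A (rank x in_features)",
    "blocks_0.self_attn.q.lora_B.default.weight": "B (out_features x rank)",
}

updates = defaultdict(dict)
for key, value in state_dict.items():
    key = key.replace(".self_attn.", "_self_attn_")
    key = key.replace(".cross_attn.", "_cross_attn_")
    key = key.replace(".ffn.", "_ffn_")
    key = key.replace(".lora_A.default.", ".lora_down.")
    key = key.replace(".lora_B.default.", ".lora_up.")
    # First remaining dot separates the flattened layer name from the element
    # ("lora_down.weight", "lora_up.weight", or "alpha").
    layer, elem = key.split('.', 1)
    updates[layer][elem] = value

print(dict(updates))
# {'blocks_0_self_attn_q': {'lora_down.weight': ..., 'lora_up.weight': ...}}
```

The new `hasattr(curr_layer, "weight")` and `'lora_up.weight' in elems` guards then let norm-only or container entries fall through instead of raising AttributeError/KeyError during the fold-in, which conventionally computes W += multiplier * (alpha / rank) * (up @ down) on merge and subtracts the same quantity on unmerge.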