XiangpengYang commited on
Commit
d5bd9f4
·
1 Parent(s): fa9d7c4

dmd lora 4step

Browse files
Files changed (1) hide show
  1. app.py +43 -6
app.py CHANGED
@@ -99,7 +99,7 @@ def load_video_frames(video_path: str, source_frames: int):
99
  return input_video, original_height, original_width
100
 
101
  class VideoCoF_Controller(Wan_Controller):
102
- @spaces.GPU(duration=2000)
103
  @timer
104
  def generate(
105
  self,
@@ -141,6 +141,8 @@ class VideoCoF_Controller(Wan_Controller):
141
  repeat_rope_checkbox=True,
142
  fps=10,
143
  is_api=False,
 
 
144
  ):
145
  self.clear_cache()
146
  print(f"VideoCoF Generation started.")
@@ -168,10 +170,21 @@ class VideoCoF_Controller(Wan_Controller):
168
  self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
169
 
170
  # LoRA merging
 
171
  if self.lora_model_path != "none":
172
- print(f"Merge Lora.")
173
  self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
174
 
 
 
 
 
 
 
 
 
 
 
175
  # Seed
176
  if int(seed_textbox) != -1 and seed_textbox != "":
177
  torch.manual_seed(int(seed_textbox))
@@ -232,12 +245,23 @@ class VideoCoF_Controller(Wan_Controller):
232
 
233
  except Exception as e:
234
  print(f"Error: {e}")
 
 
 
 
 
235
  if self.lora_model_path != "none":
 
236
  self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
237
  return gr.update(), gr.update(), f"Error: {str(e)}"
238
 
239
- # Unmerge LoRA
 
 
 
 
240
  if self.lora_model_path != "none":
 
241
  self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
242
 
243
  # Save output
@@ -278,9 +302,15 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
278
  from huggingface_hub import snapshot_download, hf_hub_download
279
  print("Downloading Wan2.1-T2V-14B weights...")
280
  snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")
281
- print("Downloading VideoCoF weights...")
282
  os.makedirs("models/Personalized_Model", exist_ok=True)
 
 
283
  hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename="videocof.safetensors", local_dir="models/Personalized_Model")
 
 
 
 
284
  except Exception as e:
285
  print(f"Warning: Failed to pre-download weights: {e}")
286
 
@@ -302,12 +332,17 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
302
  with gr.Column():
303
  sampler_dropdown, sample_step_slider = create_samplers(controller)
304
 
 
 
 
305
  # Custom VideoCoF Params
306
  with gr.Group():
307
  gr.Markdown("### VideoCoF Parameters")
308
  source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
309
  reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
310
  repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
 
 
311
 
312
  # Use custom height/width creation to hide/customize
313
  resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
@@ -338,6 +373,7 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
338
  # Set default seed to 0
339
  cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
340
  seed_textbox.value = "0"
 
341
 
342
  generate_button = gr.Button(value="Generate", variant='primary')
343
 
@@ -366,7 +402,7 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
366
  cfg_scale_slider,
367
  start_image,
368
  end_image,
369
- validation_video,
370
  validation_video_mask,
371
  control_video,
372
  denoise_strength,
@@ -382,7 +418,8 @@ def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype):
382
  # New inputs
383
  source_frames_slider,
384
  reasoning_frames_slider,
385
- repeat_rope_checkbox
 
386
  ],
387
  outputs=[result_image, result_video, infer_progress]
388
  )
 
99
  return input_video, original_height, original_width
100
 
101
  class VideoCoF_Controller(Wan_Controller):
102
+ @spaces.GPU(duration=300)
103
  @timer
104
  def generate(
105
  self,
 
141
  repeat_rope_checkbox=True,
142
  fps=10,
143
  is_api=False,
144
+ # New arg for acceleration
145
+ enable_acceleration=False,
146
  ):
147
  self.clear_cache()
148
  print(f"VideoCoF Generation started.")
 
170
  self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config)
171
 
172
  # LoRA merging
173
+ # 1. Merge VideoCoF LoRA
174
  if self.lora_model_path != "none":
175
+ print(f"Merge VideoCoF Lora: {self.lora_model_path}")
176
  self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
177
 
178
+ # 2. Merge Acceleration LoRA (FusionX) if enabled
179
+ acc_lora_path = os.path.join(self.personalized_model_dir, "Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors")
180
+ if enable_acceleration:
181
+ if os.path.exists(acc_lora_path):
182
+ print(f"Merge Acceleration LoRA: {acc_lora_path}")
183
+ # FusionX LoRA generally uses multiplier 1.0
184
+ self.pipeline = merge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
185
+ else:
186
+ print(f"Warning: Acceleration LoRA not found at {acc_lora_path}")
187
+
188
  # Seed
189
  if int(seed_textbox) != -1 and seed_textbox != "":
190
  torch.manual_seed(int(seed_textbox))
 
245
 
246
  except Exception as e:
247
  print(f"Error: {e}")
248
+ # Unmerge in case of error (LIFO order)
249
+ if enable_acceleration and os.path.exists(acc_lora_path):
250
+ print("Unmerging Acceleration LoRA (due to error)")
251
+ self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
252
+
253
  if self.lora_model_path != "none":
254
+ print("Unmerging VideoCoF LoRA (due to error)")
255
  self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
256
  return gr.update(), gr.update(), f"Error: {str(e)}"
257
 
258
+ # Unmerge LoRAs (LIFO order)
259
+ if enable_acceleration and os.path.exists(acc_lora_path):
260
+ print("Unmerging Acceleration LoRA")
261
+ self.pipeline = unmerge_lora(self.pipeline, acc_lora_path, multiplier=1.0)
262
+
263
  if self.lora_model_path != "none":
264
+ print("Unmerging VideoCoF LoRA")
265
  self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
266
 
267
  # Save output
 
302
  from huggingface_hub import snapshot_download, hf_hub_download
303
  print("Downloading Wan2.1-T2V-14B weights...")
304
  snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir="Wan-AI/Wan2.1-T2V-14B")
305
+
306
  os.makedirs("models/Personalized_Model", exist_ok=True)
307
+
308
+ print("Downloading VideoCoF weights...")
309
  hf_hub_download(repo_id="XiangpengYang/VideoCoF", filename="videocof.safetensors", local_dir="models/Personalized_Model")
310
+
311
+ print("Downloading FusionX Acceleration LoRA...")
312
+ hf_hub_download(repo_id="MonsterMMORPG/Wan_GGUF", filename="Wan2.1_Text_to_Video_14B_FusionX_LoRA.safetensors", local_dir="models/Personalized_Model")
313
+
314
  except Exception as e:
315
  print(f"Warning: Failed to pre-download weights: {e}")
316
 
 
332
  with gr.Column():
333
  sampler_dropdown, sample_step_slider = create_samplers(controller)
334
 
335
+ # Default steps lowered to 4 for acceleration
336
+ sample_step_slider.value = 4
337
+
338
  # Custom VideoCoF Params
339
  with gr.Group():
340
  gr.Markdown("### VideoCoF Parameters")
341
  source_frames_slider = gr.Slider(label="Source Frames", minimum=1, maximum=100, value=33, step=1)
342
  reasoning_frames_slider = gr.Slider(label="Reasoning Frames", minimum=1, maximum=20, value=4, step=1)
343
  repeat_rope_checkbox = gr.Checkbox(label="Repeat RoPE", value=True)
344
+ # Add Acceleration Checkbox
345
+ enable_acceleration = gr.Checkbox(label="Enable 4-step Acceleration (FusionX LoRA)", value=True)
346
 
347
  # Use custom height/width creation to hide/customize
348
  resize_method, width_slider, height_slider, base_resolution = create_height_width_english(
 
373
  # Set default seed to 0
374
  cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(True)
375
  seed_textbox.value = "0"
376
+ cfg_scale_slider.value = 1.0
377
 
378
  generate_button = gr.Button(value="Generate", variant='primary')
379
 
 
402
  cfg_scale_slider,
403
  start_image,
404
  end_image,
405
+ validation_video,
406
  validation_video_mask,
407
  control_video,
408
  denoise_strength,
 
418
  # New inputs
419
  source_frames_slider,
420
  reasoning_frames_slider,
421
+ repeat_rope_checkbox,
422
+ enable_acceleration
423
  ],
424
  outputs=[result_image, result_video, infer_progress]
425
  )