r3gm committed on
Commit
58d027b
·
verified ·
1 Parent(s): a7b0a4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -8
app.py CHANGED
@@ -15,6 +15,8 @@ from tqdm import tqdm
15
  import cv2
16
  import numpy as np
17
  import torch
 
 
18
  from torch.nn import functional as F
19
  from PIL import Image
20
 
@@ -231,9 +233,30 @@ def interpolate_bits(frames_np, multiplier=2, scale=1.0):
231
 
232
  # WAN
233
 
234
- MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
 
 
 
235
  CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  MAX_DIM = 832
238
  MIN_DIM = 480
239
  SQUARE_DIM = 640
@@ -258,11 +281,43 @@ SCHEDULER_MAP = {
258
  }
259
 
260
  pipe = WanImageToVideoPipeline.from_pretrained(
261
- "TestOrganizationPleaseIgnore/WAMU_v1_WAN2.2_I2V_LIGHTNING",
262
  torch_dtype=torch.bfloat16,
263
  ).to('cuda')
264
  original_scheduler = copy.deepcopy(pipe.scheduler)
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  if os.path.exists(CACHE_DIR):
267
  shutil.rmtree(CACHE_DIR)
268
  print("Deleted Hugging Face cache.")
@@ -270,8 +325,11 @@ else:
270
  print("No hub cache found.")
271
 
272
  quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
 
273
  quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
 
274
  quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 
275
 
276
  aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
277
  aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
@@ -283,6 +341,12 @@ default_prompt_i2v = "make this image come alive, cinematic motion, smooth anima
283
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
284
 
285
 
 
 
 
 
 
 
286
  def resize_image(image: Image.Image) -> Image.Image:
287
  width, height = image.size
288
  if width == height:
@@ -359,7 +423,7 @@ def get_inference_duration(
359
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
360
  step_duration = BASE_STEP_DURATION * factor ** 1.5
361
  gen_time = int(steps) * step_duration
362
- print(gen_time)
363
  if guidance_scale > 1:
364
  gen_time = gen_time * 1.8
365
 
@@ -367,10 +431,8 @@ def get_inference_duration(
367
  if frame_factor > 1:
368
  total_out_frames = (num_frames * frame_factor) - num_frames
369
  inter_time = (total_out_frames * 0.02)
370
- print(inter_time)
371
  gen_time += inter_time
372
 
373
- print("Time GPU", gen_time + 10)
374
  return 10 + gen_time
375
 
376
 
@@ -562,7 +624,7 @@ CSS = """
562
 
563
 
564
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
565
- gr.Markdown("## WAMU - Wan 2.2 I2V (14B) 🐢")
566
  gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
567
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
568
 
@@ -594,8 +656,11 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
594
  )
595
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
596
  play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
597
- org_name = "TestOrganizationPleaseIgnore"
598
- gr.Markdown(f"[ZeroGPU help, tips and troubleshooting](https://huggingface.co/datasets/{org_name}/help/blob/main/gpu_help.md)")
 
 
 
599
 
600
  generate_button = gr.Button("Generate Video", variant="primary")
601
 
 
15
  import cv2
16
  import numpy as np
17
  import torch
18
+ import torch._dynamo
19
+ from huggingface_hub import list_models
20
  from torch.nn import functional as F
21
  from PIL import Image
22
 
 
233
 
234
# WAN

ORG_NAME = "TestOrganizationPleaseIgnore"
# MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# Model selection: honor the REPO_ID env var when set (non-empty); otherwise
# pick a random WanImageToVideoPipeline-compatible repo published by ORG_NAME.
# `.id` replaces the deprecated `ModelInfo.modelId` attribute of huggingface_hub.
MODEL_ID = os.getenv("REPO_ID") or random.choice(
    list(list_models(author=ORG_NAME, filter='diffusers:WanImageToVideoPipeline'))
).id
CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")

# LoRA adapters fused into the pipeline at startup. Each entry provides the
# Hub repo id, the weight filenames for the high-/low-noise transformers
# ("high_tr"/"low_tr"), and the scale applied when fusing each one.
LORA_MODELS = [
    # {
    #     "repo_id": "exampleuser/example_lora_1",
    #     "high_tr": "example_lora_1_high.safetensors",
    #     "low_tr": "example_lora_1_low.safetensors",
    #     "high_scale": 0.5,
    #     "low_scale": 0.5
    # },
    # {
    #     "repo_id": "exampleuser/example_lora_2",
    #     "high_tr": "subfolder/example_lora_2_high.safetensors",
    #     "low_tr": "subfolder/example_lora_2_low.safetensors",
    #     "high_scale": 0.4,
    #     "low_scale": 0.4
    # },
]

MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
 
281
  }
282
 
283
# Build the Wan image-to-video pipeline from the selected repo in bfloat16,
# then move it onto the GPU.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to('cuda')
# Keep a pristine copy of the scheduler so it can be restored after any
# per-request scheduler swap.
original_scheduler = copy.deepcopy(pipe.scheduler)
288
 
289
def _adapter_stem(weight_name):
    """Derive an adapter-name stem from a LoRA weight filename: drop any
    subfolder prefix and everything from the first '.' onward."""
    return weight_name.split(".")[0].split("/")[-1]


# Load, fuse and release every configured LoRA pair. "high" weights target
# `transformer` and "low" weights target `transformer_2`. A failure is logged
# and skipped (best-effort) so one bad LoRA does not abort startup.
for i, lora in enumerate(LORA_MODELS):
    name_high_tr = _adapter_stem(lora["high_tr"]) + "Hh"
    name_low_tr = _adapter_stem(lora["low_tr"]) + "Ll"

    try:
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["high_tr"],
            adapter_name=name_high_tr
        )

        # Route the low-noise weights into the second transformer.
        kwargs_lora = {"load_into_transformer_2": True}
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["low_tr"],
            adapter_name=name_low_tr,
            **kwargs_lora
        )

        pipe.set_adapters([name_high_tr, name_low_tr], adapter_weights=[1.0, 1.0])

        # Fuse each adapter into its component at the configured scale, then
        # drop the adapter bookkeeping so the weights stay baked in.
        pipe.fuse_lora(adapter_names=[name_high_tr], lora_scale=lora["high_scale"], components=["transformer"])
        pipe.fuse_lora(adapter_names=[name_low_tr], lora_scale=lora["low_scale"], components=["transformer_2"])

        pipe.unload_lora_weights()

        print(f"Applied: {lora['high_tr']}, hs={lora['high_scale']}/ls={lora['low_scale']}, {i+1}/{len(LORA_MODELS)}")
    except Exception as e:
        # Report and clean up any partially-loaded adapters before continuing.
        print("Error:", str(e))
        print("Failed LoRA:", name_high_tr)
        pipe.unload_lora_weights()
320
+
321
  if os.path.exists(CACHE_DIR):
322
  shutil.rmtree(CACHE_DIR)
323
  print("Deleted Hugging Face cache.")
 
325
  print("No hub cache found.")
326
 
327
# Quantize the large submodules in place: int8 weight-only for the text
# encoder, fp8 dynamic-activation/fp8-weight for both transformers.
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
# NOTE(review): the resets presumably clear torch.compile/dynamo caches so no
# stale compiled graphs outlive each quantization pass — confirm intent.
torch._dynamo.reset()
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
torch._dynamo.reset()
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
torch._dynamo.reset()

# Load ahead-of-time compiled transformer blocks from the Hub; the 'fp8da'
# variant matches the Float8 dynamic-activation quantization applied above.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
341
  default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
342
 
343
 
344
def model_title(model_id=None):
    """Build the Markdown heading naming the model this Space is running.

    Args:
        model_id: Hub repo id ("org/name"). Defaults to the module-level
            ``MODEL_ID`` when omitted, preserving the original zero-argument
            call.

    Returns:
        A Markdown H2 string linking to the model page on the Hub, with
        underscores in the repo name rendered as spaces.
    """
    if model_id is None:
        model_id = MODEL_ID
    # Display name: last path component, underscores shown as spaces.
    repo_name = model_id.split('/')[-1].replace("_", " ")
    url = f"https://huggingface.co/{model_id}"
    return f"## This space is currently running [{repo_name}]({url}) 🐢"
348
+
349
+
350
  def resize_image(image: Image.Image) -> Image.Image:
351
  width, height = image.size
352
  if width == height:
 
423
  factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
424
  step_duration = BASE_STEP_DURATION * factor ** 1.5
425
  gen_time = int(steps) * step_duration
426
+
427
  if guidance_scale > 1:
428
  gen_time = gen_time * 1.8
429
 
 
431
  if frame_factor > 1:
432
  total_out_frames = (num_frames * frame_factor) - num_frames
433
  inter_time = (total_out_frames * 0.02)
 
434
  gen_time += inter_time
435
 
 
436
  return 10 + gen_time
437
 
438
 
 
624
 
625
 
626
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
627
+ gr.Markdown(model_title())
628
  gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
629
  gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU")
630
 
 
656
  )
657
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
658
  play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
659
+ gr.Markdown(f"[ZeroGPU help, tips and troubleshooting](https://huggingface.co/datasets/{ORG_NAME}/help/blob/main/gpu_help.md)")
660
+ gr.Markdown( # TestOrganizationPleaseIgnore/wamu-tools
661
+ "To use a different model, **duplicate this Space** first, then change the `REPO_ID` environment variable. "
662
+ "[See compatible models here](https://huggingface.co/models?other=diffusers:WanImageToVideoPipeline&sort=trending&search=WAN2.2_I2V_LIGHTNING)."
663
+ )
664
 
665
  generate_button = gr.Button("Generate Video", variant="primary")
666