r3gm commited on
Commit
330958f
·
verified ·
1 Parent(s): 13df71f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -14
app.py CHANGED
@@ -35,6 +35,60 @@ import aoti
35
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
36
  warnings.filterwarnings("ignore")
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # RIFE
39
  if not os.path.exists("RIFEv4.26_0921.zip"):
40
  print("Downloading RIFE Model...")
@@ -347,23 +401,22 @@ def run_inference(
347
  raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
348
  pipe.scheduler = original_scheduler
349
 
350
- start = time.time()
351
  if frame_multiplier > 1:
 
352
  print(f"Processing frames (RIFE Multiplier: {frame_multiplier}x)...")
353
  final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_multiplier))
 
354
  else:
355
  final_frames = list(raw_frames_np)
356
- print("Interpolation time passed:", time.time() - start)
357
 
358
  final_fps = FIXED_FPS * int(frame_multiplier)
359
 
360
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
361
  video_path = tmpfile.name
362
 
363
- print(f"Exporting video at {final_fps} FPS...")
364
  start = time.time()
365
  export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
366
- print("Export time passed:", time.time() - start)
367
 
368
  return video_path
369
 
@@ -382,6 +435,7 @@ def generate_video(
382
  scheduler="UniPCMultistep",
383
  flow_shift=6.0,
384
  frame_multiplier=1,
 
385
  progress=gr.Progress(track_tqdm=True),
386
  ):
387
  """
@@ -412,6 +466,8 @@ def generate_video(
412
  scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
413
  flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
414
  frame_multiplier (int, optional): The int value for fps enhancer
 
 
415
  progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
416
  Returns:
417
  tuple: A tuple containing:
@@ -457,12 +513,26 @@ def generate_video(
457
  )
458
  print("GPU complete")
459
 
460
- return video_path, video_path, current_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
 
463
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
464
- gr.Markdown("# WAMU - Wan 2.2 I2V (14B)")
465
- gr.Markdown("## ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
466
  gr.Markdown("run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
467
 
468
  with gr.Row():
@@ -475,16 +545,16 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
475
  choices=[1, 2, 4, 8],
476
  value=1,
477
  label="Frame Rate Enhancer (Interpolation)",
478
- info="2 = Double FPS (e.g. 16 -> 32). Higher multipliers create more intermediate frames."
479
  )
480
  with gr.Accordion("Advanced Settings", open=False):
481
  last_image_component = gr.Image(type="pil", label="Last Image (Optional)")
482
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
483
- quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality")
484
  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
485
  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
486
- guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
487
- guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
488
  scheduler_dropdown = gr.Dropdown(
489
  label="Scheduler",
490
  choices=list(SCHEDULER_MAP.keys()),
@@ -492,18 +562,29 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
492
  info="Select a custom scheduler."
493
  )
494
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
 
 
495
 
496
  generate_button = gr.Button("Generate Video", variant="primary")
497
 
498
  with gr.Column():
499
- video_output = gr.Video(label="Generated Video", autoplay=True)
 
 
 
 
 
 
 
 
500
  file_output = gr.File(label="Download Video")
501
 
502
  ui_inputs = [
503
  input_image_component, last_image_component, prompt_input, steps_slider,
504
  negative_prompt_input, duration_seconds_input,
505
  guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
506
- quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi
 
507
  ]
508
 
509
  generate_button.click(
@@ -511,6 +592,25 @@ with gr.Blocks(delete_cache=(3600, 10800)) as demo:
511
  inputs=ui_inputs,
512
  outputs=[video_output, file_output, seed_input]
513
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  if __name__ == "__main__":
516
- demo.queue().launch(mcp_server=True)
 
 
 
 
35
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
36
  warnings.filterwarnings("ignore")
37
 
38
# --- FRAME EXTRACTION JS & LOGIC ---

# Browser-side snippet run by Gradio's `js=` hook: reads the current playback
# position (in seconds) from the <video> element rendered inside the component
# with elem_id "generated-video". Returns 0 when no video is present so the
# downstream Python handler always receives a number.
get_timestamp_js = """
function() {
    // Select the video element specifically inside the component with id 'generated-video'
    const video = document.querySelector('#generated-video video');

    if (video) {
        console.log("Video found! Time: " + video.currentTime);
        return video.currentTime;
    } else {
        console.log("No video element found.");
        return 0;
    }
}
"""
55
+
56
def extract_frame(video_path, timestamp):
    """Extract a single RGB frame from *video_path* at *timestamp* seconds.

    Args:
        video_path (str | None): Path to the generated video file; falsy when
            no video has been produced yet.
        timestamp (float | str): Playback position in seconds, as reported by
            the browser-side JS (coerced with ``float``).

    Returns:
        numpy.ndarray | None: An (H, W, 3) RGB frame suitable for a Gradio
        Image component (Gradio converts the array to PIL automatically),
        or ``None`` when the frame cannot be read.
    """
    # Safety check: nothing to extract yet.
    if not video_path:
        return None

    print(f"Extracting frame at timestamp: {timestamp}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    try:
        # Map the timestamp to a frame index. Guard against a zero/negative
        # FPS report from broken container metadata (would map everything
        # to frame 0 or produce nonsense).
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            return None
        target_frame_num = int(float(timestamp) * fps)

        # Clamp into [0, total_frames - 1]: the original subtracted 1 from a
        # possibly-zero frame count, yielding a -1 seek target at the end of
        # (or on an unreadable) video.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            target_frame_num = min(target_frame_num, total_frames - 1)
        target_frame_num = max(target_frame_num, 0)

        # Seek and decode the single frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path — the original leaked the handle
        # if anything raised between open and release.
        cap.release()

    if ret:
        # Convert from BGR (OpenCV) to RGB (Gradio).
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    return None
88
+
89
+ # --- END FRAME EXTRACTION LOGIC ---
90
+
91
+
92
  # RIFE
93
  if not os.path.exists("RIFEv4.26_0921.zip"):
94
  print("Downloading RIFE Model...")
 
401
  raw_frames_np = result.frames[0] # Returns (T, H, W, C) float32
402
  pipe.scheduler = original_scheduler
403
 
 
404
  if frame_multiplier > 1:
405
+ start = time.time()
406
  print(f"Processing frames (RIFE Multiplier: {frame_multiplier}x)...")
407
  final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_multiplier))
408
+ print("Interpolation time passed:", time.time() - start)
409
  else:
410
  final_frames = list(raw_frames_np)
 
411
 
412
  final_fps = FIXED_FPS * int(frame_multiplier)
413
 
414
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
415
  video_path = tmpfile.name
416
 
 
417
  start = time.time()
418
  export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
419
+ print(f"Export time passed, {final_fps} FPS:", time.time() - start)
420
 
421
  return video_path
422
 
 
435
  scheduler="UniPCMultistep",
436
  flow_shift=6.0,
437
  frame_multiplier=1,
438
+ video_component=True,
439
  progress=gr.Progress(track_tqdm=True),
440
  ):
441
  """
 
466
  scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
467
  flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
468
  frame_multiplier (int, optional): The int value for fps enhancer
469
+ video_component(bool, optional): Show video player in output.
470
+ Defaults to True.
471
  progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
472
  Returns:
473
  tuple: A tuple containing:
 
513
  )
514
  print("GPU complete")
515
 
516
+ return (video_path if video_component else None), video_path, current_seed
517
+
518
+
519
# Stylesheet that visually hides the timestamp Number component (elem_id
# "hidden-timestamp") while keeping it in the DOM, so the JS click handler
# can still write the video's current time into it.
CSS = """
#hidden-timestamp {
    opacity: 0;
    height: 0px;
    width: 0px;
    margin: 0px;
    padding: 0px;
    overflow: hidden;
    position: absolute;
    pointer-events: none;
}
"""
531
 
532
 
533
  with gr.Blocks(delete_cache=(3600, 10800)) as demo:
534
+ gr.Markdown("## WAMU - Wan 2.2 I2V (14B)")
535
+ gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
536
  gr.Markdown("run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
537
 
538
  with gr.Row():
 
545
  choices=[1, 2, 4, 8],
546
  value=1,
547
  label="Frame Rate Enhancer (Interpolation)",
548
+ info="Increases video fluidity. Example: 2x converts 16 FPS -> 32 FPS."
549
  )
550
  with gr.Accordion("Advanced Settings", open=False):
551
  last_image_component = gr.Image(type="pil", label="Last Image (Optional)")
552
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
553
+ quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality", info="If set to 10, the generated video may be too large and won't play in the Gradio preview.")
554
  seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
555
  randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
556
+ guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
557
+ guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
558
  scheduler_dropdown = gr.Dropdown(
559
  label="Scheduler",
560
  choices=list(SCHEDULER_MAP.keys()),
 
562
  info="Select a custom scheduler."
563
  )
564
  flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
565
+ play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)
566
+ gr.Markdown("[ZeroGPU Help, Tips, and Troubleshooting](https://huggingface.co/datasets/TestOrganizationPleaseIgnore/help/blob/main/gpu_help.md)")
567
 
568
  generate_button = gr.Button("Generate Video", variant="primary")
569
 
570
  with gr.Column():
571
+ # ASSIGNED elem_id="generated-video" so JS can find it
572
+ video_output = gr.Video(label="Generated Video", autoplay=True, elem_id="generated-video")
573
+
574
+ # --- Frame Grabbing UI ---
575
+ with gr.Row():
576
+ grab_frame_btn = gr.Button("📸 Use Current Frame as Input", variant="secondary")
577
+ timestamp_box = gr.Number(value=0, label="Timestamp", visible=True, elem_id="hidden-timestamp")
578
+ # -------------------------
579
+
580
  file_output = gr.File(label="Download Video")
581
 
582
  ui_inputs = [
583
  input_image_component, last_image_component, prompt_input, steps_slider,
584
  negative_prompt_input, duration_seconds_input,
585
  guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
586
+ quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi,
587
+ play_result_video
588
  ]
589
 
590
  generate_button.click(
 
592
  inputs=ui_inputs,
593
  outputs=[video_output, file_output, seed_input]
594
  )
595
+
596
+ # --- Frame Grabbing Events ---
597
+ # 1. Click button -> JS runs -> puts time in hidden number box
598
+ grab_frame_btn.click(
599
+ fn=None,
600
+ inputs=None,
601
+ outputs=[timestamp_box],
602
+ js=get_timestamp_js
603
+ )
604
+
605
+ # 2. Hidden number box changes -> Python runs -> puts frame in Input Image
606
+ timestamp_box.change(
607
+ fn=extract_frame,
608
+ inputs=[video_output, timestamp_box],
609
+ outputs=[input_image_component]
610
+ )
611
 
612
if __name__ == "__main__":
    # NOTE(review): `Blocks.launch()` has no `css` parameter in Gradio —
    # passing `css=CSS` here raises TypeError at startup. Custom CSS must be
    # supplied at construction time instead: `gr.Blocks(css=CSS, ...)`.
    # Confirm against the pinned Gradio version, then move CSS to the
    # `gr.Blocks(...)` call above.
    demo.queue().launch(mcp_server=True)