mirefall committed on
Commit
4da90bf
·
verified ·
1 Parent(s): 293e322

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +646 -0
app.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spaces
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import copy
7
+ import random
8
+ import tempfile
9
+ import warnings
10
+ import time
11
+ import gc
12
+ import uuid
13
+ from tqdm import tqdm
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ from torch.nn import functional as F
19
+ from PIL import Image
20
+
21
+ import gradio as gr
22
+ from diffusers import (
23
+ FlowMatchEulerDiscreteScheduler,
24
+ SASolverScheduler,
25
+ DEISMultistepScheduler,
26
+ DPMSolverMultistepInverseScheduler,
27
+ UniPCMultistepScheduler,
28
+ DPMSolverMultistepScheduler,
29
+ DPMSolverSinglestepScheduler,
30
+ )
31
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
32
+ from diffusers.utils.export_utils import export_to_video
33
+
34
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
35
+ import aoti
36
+
37
# Silence tokenizer fork warnings and all Python warnings for cleaner Space logs.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
warnings.filterwarnings("ignore")
# True when running on a HF Spaces ZeroGPU host (env var set by the platform).
IS_ZERO_GPU = bool(os.getenv("SPACES_ZERO_GPU"))

# if IS_ZERO_GPU:
#     print("Loading...")
#     subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
44
+
45
# --- FRAME EXTRACTION JS & LOGIC ---

# JS to grab timestamp from the output video. Runs client-side via Gradio's
# `js=` hook and returns the player's current playback position in seconds
# (0 when no video element is present).
get_timestamp_js = """
function() {
    // Select the video element specifically inside the component with id 'generated-video'
    const video = document.querySelector('#generated-video video');

    if (video) {
        console.log("Video found! Time: " + video.currentTime);
        return video.currentTime;
    } else {
        console.log("No video element found.");
        return 0;
    }
}
"""
62
+
63
+
64
def extract_frame(video_path, timestamp):
    """Extract a single frame from *video_path* at *timestamp* seconds.

    Args:
        video_path: Path to a video file, or a falsy value when no video exists.
        timestamp: Playback position in seconds (float or numeric string).

    Returns:
        An RGB numpy array (H, W, 3) suitable for a Gradio Image component,
        or None if the video is missing, unreadable, or reports no frames.
    """
    # Safety check: if no video is present
    if not video_path:
        return None

    print(f"Extracting frame at timestamp: {timestamp}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    try:
        # Map the timestamp to a frame index.
        fps = cap.get(cv2.CAP_PROP_FPS)
        target_frame_num = int(float(timestamp) * fps)

        # A corrupt/empty video reports 0 frames; the old clamp would then
        # produce index -1. Bail out instead, and clamp the index into the
        # valid [0, total_frames - 1] range (also guards negative timestamps).
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None
        target_frame_num = min(max(target_frame_num, 0), total_frames - 1)

        # Seek and grab the frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        # Always release the capture, even if decoding raises.
        cap.release()

    if ret:
        # Convert from BGR (OpenCV) to RGB (Gradio)
        # Gradio Image component handles Numpy array -> PIL conversion automatically
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    return None
96
+
97
+ # --- END FRAME EXTRACTION LOGIC ---
98
+
99
+
100
def clear_vram():
    """Run a full garbage-collection pass, then release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()
103
+
104
+
105
# RIFE -- optical-flow frame-interpolation model used to raise output FPS.
# Download and unpack the pretrained weights once per container, then load.
if not os.path.exists("RIFEv4.26_0921.zip"):
    print("Downloading RIFE Model...")
    subprocess.run([
        "wget", "-q",
        "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
        "-O", "RIFEv4.26_0921.zip"
    ], check=True)
    subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)

# sys.path.append(os.getcwd())

# train_log/ is produced by extracting the zip above.
from train_log.RIFE_HDv3 import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rife_model = Model()
# NOTE(review): -1 appears to select the default/latest checkpoint -- confirm
# against the RIFE_HDv3 loader.
rife_model.load_model("train_log", -1)
rife_model.eval()
122
+
123
+
124
@torch.no_grad()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """
    RIFE frame interpolation maintaining Numpy Float 0-1 format.

    Args:
        frames_np: Numpy Array (Time, Height, Width, Channels) - Float32 [0.0, 1.0],
            or a list of (H, W, C) float32 arrays.
        multiplier: int (2, 4, 8); values < 2 disable interpolation.
        scale: RIFE inference scale; also controls padding granularity.
    Returns:
        List of Numpy Arrays (Height, Width, Channels) - Float32 [0.0, 1.0]
    """

    # Handle input shape
    if isinstance(frames_np, list):
        T = len(frames_np)
        H, W, C = frames_np[0].shape
    else:
        T, H, W, C = frames_np.shape

    # 1. No Interpolation Case: just convert to a list of 3D arrays.
    if multiplier < 2:
        if isinstance(frames_np, np.ndarray):
            return list(frames_np)
        return frames_np

    # Fewer than two frames: nothing to interpolate between. (Previously this
    # fell through to the loop body and crashed with a NameError on cleanup.)
    if T < 2:
        return [np.asarray(f, dtype=np.float32) for f in frames_np]

    n_interp = multiplier - 1

    # Pre-calc padding: RIFE needs dimensions divisible by 128/scale.
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)

    def to_tensor(frame_np):
        # Numpy (H, W, C) float32 [0,1] -> padded half tensor (1, C, H, W) on device.
        t = torch.from_numpy(frame_np).to(device)
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()

    def from_tensor(tensor):
        # Crop padding, (1, C, H, W) -> numpy (H, W, C) float32 [0,1].
        t = tensor[0, :, :H, :W]
        t = t.permute(1, 2, 0)
        return t.float().cpu().numpy()

    def make_inference(I0, I1, n):
        # RIFE >= 3.9 accepts an explicit timestep; older versions are driven
        # by recursive midpoint bisection.
        if rife_model.version >= 3.9:
            return [
                rife_model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale)
                for i in range(n)
            ]
        middle = rife_model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n=n // 2)
        second_half = make_inference(middle, I1, n=n // 2)
        if n % 2:
            return [*first_half, middle, *second_half]
        return [*first_half, *second_half]

    output_frames = []
    mid_tensors = []

    # Load first frame into GPU.
    I1 = to_tensor(frames_np[0])
    I0 = I1
    total_steps = T - 1

    with tqdm(total=total_steps, desc="Interpolating", unit="frame") as pbar:
        for i in range(total_steps):
            I0 = I1
            # Add original frame to output.
            output_frames.append(from_tensor(I0))

            # Load next frame.
            I1 = to_tensor(frames_np[i + 1])

            # Generate and append the intermediate frames.
            mid_tensors = make_inference(I0, I1, n_interp)
            for mid in mid_tensors:
                output_frames.append(from_tensor(mid))

            # Batch progress updates to reduce tqdm overhead.
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update(total_steps % 50)

    # Add the very last frame.
    output_frames.append(from_tensor(I1))

    # Cleanup GPU working tensors.
    del I0, I1, mid_tensors
    torch.cuda.empty_cache()

    return output_frames
230
+
231
+
232
# WAN

# NOTE(review): MODEL_ID is not referenced below (the pipeline loads a
# different repo); kept for documentation purposes -- confirm before removing.
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")

# Output geometry limits (pixels); final dims are snapped to MULTIPLE_OF.
MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max

# Native generation frame rate and the supported frame-count window.
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 160

# Duration slider bounds derived from the frame limits.
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# UI scheduler name -> diffusers scheduler class.
SCHEDULER_MAP = {
    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
    "SASolver": SASolverScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
}
259
+
260
# Load the Wan 2.2 image-to-video pipeline in bf16 directly onto the GPU.
pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
).to('cuda')
# Keep a pristine copy of the scheduler so per-request swaps can be undone.
original_scheduler = copy.deepcopy(pipe.scheduler)

# if os.path.exists(CACHE_DIR):
#     shutil.rmtree(CACHE_DIR)
#     print("Deleted Hugging Face cache.")
# else:
#     print("No hub cache found.")

# Quantize: int8 weight-only for the text encoder; fp8 dynamic
# activation/weight for both transformer stages (high- and low-noise).
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())

# Load ahead-of-time compiled transformer blocks matching the fp8 quantization.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')

# pipe.vae.enable_slicing()
# pipe.vae.enable_tiling()

default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
284
+
285
+
286
def resize_image(image: Image.Image) -> Image.Image:
    """
    Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
    """
    w, h = image.size
    # Perfect squares go straight to the fixed square resolution.
    if w == h:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    ratio = w / h
    max_ratio = MAX_DIM / MIN_DIM
    min_ratio = MIN_DIM / MAX_DIM

    source = image
    if ratio > max_ratio:
        # Too wide: center-crop horizontally down to the widest legal ratio.
        target_w, target_h = MAX_DIM, MIN_DIM
        crop_w = int(round(h * max_ratio))
        x0 = (w - crop_w) // 2
        source = image.crop((x0, 0, x0 + crop_w, h))
    elif ratio < min_ratio:
        # Too tall: center-crop vertically down to the tallest legal ratio.
        target_w, target_h = MIN_DIM, MAX_DIM
        crop_h = int(round(w / min_ratio))
        y0 = (h - crop_h) // 2
        source = image.crop((0, y0, w, y0 + crop_h))
    elif w > h:
        # Landscape within limits: pin the long side to MAX_DIM.
        target_w = MAX_DIM
        target_h = int(round(target_w / ratio))
    else:
        # Portrait within limits: pin the long side to MAX_DIM.
        target_h = MAX_DIM
        target_w = int(round(target_h * ratio))

    # Snap to the model's required multiple, then clamp into the legal range.
    out_w = min(MAX_DIM, max(MIN_DIM, round(target_w / MULTIPLE_OF) * MULTIPLE_OF))
    out_h = min(MAX_DIM, max(MIN_DIM, round(target_h / MULTIPLE_OF) * MULTIPLE_OF))
    return source.resize((out_w, out_h), Image.LANCZOS)
322
+
323
+
324
def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
    ref_w, ref_h = reference_image.size
    src_w, src_h = target_image.size
    # Scale up just enough that both dimensions cover the reference frame.
    factor = max(ref_w / src_w, ref_h / src_h)
    scaled = target_image.resize(
        (int(src_w * factor), int(src_h * factor)), Image.Resampling.LANCZOS
    )
    # Center-crop away the overflow on each axis.
    x0 = (scaled.width - ref_w) // 2
    y0 = (scaled.height - ref_h) // 2
    return scaled.crop((x0, y0, x0 + ref_w, y0 + ref_h))
333
+
334
+
335
def get_num_frames(duration_seconds: float):
    """Convert a duration in seconds to a model-acceptable frame count.

    The requested frame count is clamped to the supported window, then one
    extra frame is added (the model consumes N+1 frames).
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    clamped = min(MAX_FRAMES_MODEL, max(MIN_FRAMES_MODEL, requested))
    return 1 + clamped
341
+
342
+
343
def get_inference_duration(
    resized_image,
    processed_last_image,
    prompt,
    steps,
    negative_prompt,
    num_frames,
    guidance_scale,
    guidance_scale_2,
    current_seed,
    scheduler_name,
    flow_shift,
    frame_multiplier,
    quality,
    duration_seconds,
    progress
):
    """Estimate the GPU-seconds this request needs (fed to @spaces.GPU)."""
    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15
    width, height = resized_image.size

    # Work scales super-linearly with total pixel-frames relative to the
    # calibrated baseline.
    load = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    per_step = BASE_STEP_DURATION * load ** 1.5
    gen_time = int(steps) * per_step

    # Classifier-free guidance (> 1) roughly doubles the transformer passes.
    if guidance_scale > 1:
        gen_time *= 1.8

    # Budget for RIFE interpolation of the additional output frames.
    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        extra_frames = (num_frames * frame_factor) - num_frames
        gen_time += extra_frames * 0.02

    # Flat overhead for setup, VAE decode and export.
    return 15 + gen_time
377
+
378
+
379
@spaces.GPU(duration=get_inference_duration)
def run_inference(
    resized_image,
    processed_last_image,
    prompt,
    steps,
    negative_prompt,
    num_frames,
    guidance_scale,
    guidance_scale_2,
    current_seed,
    scheduler_name,
    flow_shift,
    frame_multiplier,
    quality,
    duration_seconds,
    progress=gr.Progress(track_tqdm=True),
):
    """GPU worker: run the Wan I2V pipeline, optionally RIFE-interpolate, export mp4.

    Returns a (video_path, task_name) tuple, where task_name is a short id
    used only for log correlation.
    """
    # Rebuild the scheduler only when the requested class or flow_shift
    # differs from the active configuration.
    # NOTE(review): the "shift" fallback in .get() can never equal a float
    # flow_shift, so configs lacking the key always trigger a rebuild --
    # possibly intentional; verify.
    scheduler_class = SCHEDULER_MAP.get(scheduler_name)
    if scheduler_class.__name__ != pipe.scheduler.config._class_name or flow_shift != pipe.scheduler.config.get("flow_shift", "shift"):
        config = copy.deepcopy(original_scheduler.config)
        if scheduler_class == FlowMatchEulerDiscreteScheduler:
            # FlowMatchEuler names this parameter 'shift' rather than 'flow_shift'.
            config['shift'] = flow_shift
        else:
            config['flow_shift'] = flow_shift
        pipe.scheduler = scheduler_class.from_config(config)

    clear_vram()

    # Short random id for log correlation across the request.
    task_name = str(uuid.uuid4())[:8]
    print(f"Task: {task_name}, {duration_seconds}, {resized_image.size}, FM={frame_multiplier}")
    # NOTE(review): `start` is assigned (here and below) but never read --
    # leftover timing scaffolding.
    start = time.time()
    result = pipe(
        image=resized_image,
        last_image=processed_last_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=resized_image.height,
        width=resized_image.width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
        output_type="np"
    )

    raw_frames_np = result.frames[0]  # Returns (T, H, W, C) float32
    # Restore the pristine scheduler for the next request.
    pipe.scheduler = original_scheduler

    # frame_multiplier is an absolute target FPS; derive the integer multiplier.
    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        start = time.time()
        # Move RIFE to the GPU and cast to fp16 before interpolating.
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
    else:
        final_frames = list(raw_frames_np)

    final_fps = FIXED_FPS * int(frame_factor)

    # Only the temp file's name is needed; export_to_video writes to the path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name

    start = time.time()
    # tqdm here only surfaces a coarse progress signal through gr.Progress.
    with tqdm(total=3, desc="Rendering Media", unit="clip") as pbar:
        pbar.update(2)
        export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
        pbar.update(1)

    return video_path, task_name
450
+
451
+
452
def generate_video(
    input_image,
    last_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    quality=5,
    scheduler="UniPCMultistep",
    flow_shift=6.0,
    frame_multiplier=16,
    video_component=True,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
    This function takes an input image and generates a video animation based on the provided
    prompt and parameters. It uses an FP8 quantized Wan 2.2 14B Image-to-Video model with
    Lightning LoRA for fast generation in 4-8 steps.
    Args:
        input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
        last_image (PIL.Image, optional): The optional last image for the video.
        prompt (str): Text prompt describing the desired animation or motion.
        steps (int, optional): Number of inference steps. More steps = higher quality but slower.
            Defaults to 4. Range: 1-30.
        negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
            Defaults to default_negative_prompt (contains unwanted visual artifacts).
        duration_seconds (float, optional): Duration of the generated video in seconds.
            Defaults to MAX_DURATION. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and
            MAX_FRAMES_MODEL/FIXED_FPS.
        guidance_scale (float, optional): Controls adherence to the prompt (high-noise stage).
            Higher values = more adherence. Defaults to 1.0. Range: 0.0-20.0.
        guidance_scale_2 (float, optional): Controls adherence to the prompt (low-noise stage).
            Higher values = more adherence. Defaults to 1.0. Range: 0.0-20.0.
        seed (int, optional): Random seed for reproducible results. Defaults to 42.
            Range: 0 to MAX_SEED (2147483647).
        randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
            Defaults to False.
        quality (float, optional): Video output quality. Default is 5. Uses variable bit rate.
            Highest quality is 10, lowest is 1.
        scheduler (str, optional): The name of the scheduler to use for inference. Defaults to "UniPCMultistep".
        flow_shift (float, optional): The flow shift value for compatible schedulers. Defaults to 6.0.
        frame_multiplier (int, optional): Target output FPS for the frame-rate enhancer.
        video_component (bool, optional): Show video player in output.
            Defaults to True.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
    Returns:
        tuple: A tuple containing:
            - video_path (str): Path for the video component (None when video_component is False).
            - video_path (str): Path for the file download component. Attempt to avoid reconversion in video component.
            - current_seed (int): The seed used for generation.
    Raises:
        gr.Error: If input_image is None (no image uploaded).
    Note:
        - Frame count is calculated as duration_seconds * FIXED_FPS (16), plus one
        - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
        - The function uses GPU acceleration via the @spaces.GPU decorator
        - Generation time varies based on steps and duration (see get_inference_duration)
    """

    if input_image is None:
        raise gr.Error("Please upload an input image.")

    num_frames = get_num_frames(duration_seconds)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    resized_image = resize_image(input_image)

    # Conform the optional last frame to the resized first frame's geometry.
    processed_last_image = None
    if last_image is not None:
        processed_last_image = resize_and_crop_to_match(last_image, resized_image)

    video_path, task_n = run_inference(
        resized_image,
        processed_last_image,
        prompt,
        steps,
        negative_prompt,
        num_frames,
        guidance_scale,
        guidance_scale_2,
        current_seed,
        scheduler,
        flow_shift,
        frame_multiplier,
        quality,
        duration_seconds,
        progress,
    )
    print(f"GPU complete: {task_n}")

    # First slot feeds the (optional) video player; second slot always feeds
    # the file-download component.
    return (video_path if video_component else None), video_path, current_seed
546
+
547
+
548
# Hide the timestamp Number component visually while keeping it in the DOM,
# so the JS bridge can still write the playback position into it.
CSS = """
#hidden-timestamp {
    opacity: 0;
    height: 0px;
    width: 0px;
    margin: 0px;
    padding: 0px;
    overflow: hidden;
    position: absolute;
    pointer-events: none;
}
"""


with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo:
    gr.Markdown("## WAMU V2 - Wan 2.2 I2V (14B) 🐢🐢")
    gr.Markdown("#### ℹ️ **A Note on Performance:** This version prioritizes a straightforward setup over maximum speed, so performance may vary.")
    gr.Markdown('Try the alternative version: [WAMU space](https://huggingface.co/spaces/r3gm/wan2-2-fp8da-aoti-preview2)')
    gr.Markdown("Run Wan 2.2 in just 4-8 steps, fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU.")

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            input_image_component = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
            duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
            # Absolute output FPS; multiples of FIXED_FPS enable RIFE interpolation.
            frame_multi = gr.Dropdown(
                choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4],
                value=FIXED_FPS,
                label="Video Fluidity (Frames per Second)",
                info="Extra frames will be generated using flow estimation, which estimates motion between frames to make the video smoother."
            )
            with gr.Accordion("Advanced Settings", open=False):
                last_image_component = gr.Image(type="pil", label="Last Image (Optional)", sources=["upload", "clipboard"])
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
                quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality", info="If set to 10, the generated video may be too large and won't play in the Gradio preview.")
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
                scheduler_dropdown = gr.Dropdown(
                    label="Scheduler",
                    choices=list(SCHEDULER_MAP.keys()),
                    value="UniPCMultistep",
                    info="Select a custom scheduler."
                )
                flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
                play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)

            generate_button = gr.Button("Generate Video", variant="primary")

        # Right column: outputs and the frame-grab helper.
        with gr.Column():
            # ASSIGNED elem_id="generated-video" so JS can find it
            video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video")

            # --- Frame Grabbing UI ---
            with gr.Row():
                grab_frame_btn = gr.Button("📸 Use Current Frame as Input", variant="secondary")
                # Hidden via CSS (#hidden-timestamp); receives the playback
                # position from the JS bridge.
                timestamp_box = gr.Number(value=0, label="Timestamp", visible=True, elem_id="hidden-timestamp")
            # -------------------------

            file_output = gr.File(label="Download Video")

    # Ordered to match generate_video's positional parameters.
    ui_inputs = [
        input_image_component, last_image_component, prompt_input, steps_slider,
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
        quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi,
        play_result_video
    ]

    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=[video_output, file_output, seed_input]
    )

    # --- Frame Grabbing Events ---
    # 1. Click button -> JS runs -> puts time in hidden number box
    grab_frame_btn.click(
        fn=None,
        inputs=None,
        outputs=[timestamp_box],
        js=get_timestamp_js
    )

    # 2. Hidden number box changes -> Python runs -> puts frame in Input Image
    timestamp_box.change(
        fn=extract_frame,
        inputs=[video_output, timestamp_box],
        outputs=[input_image_component]
    )

if __name__ == "__main__":
    # mcp_server exposes generate_video as an MCP tool (docstring is the spec).
    demo.queue().launch(
        mcp_server=True,
        ssr_mode=False,
        show_error=True,
    )