Spaces:
Runtime error
Runtime error
| """ | |
| WAN 2.2 Multi-Task Video Generation - 3-Step V2V Pipeline | |
| I2V: Lightning 14B (6 steps, FP8+AoT) | |
| T2V: Lightning 14B (4 steps, Lightning LoRA + FP8) | |
| V2V: 3-Step Pipeline (SAM2 → Composite → VACE) | |
| Step 1: SAM2 video segmentation (click points → mask video) | |
| Step 2: ImageComposite (original + mask → composite video) | |
| Step 3: VACE generation (composite + grown mask + ref image + prompt → final) | |
| LoRA: from lkzd7/WAN2.2_LoraSet_NSFW (I2V only) | |
| """ | |
| import os | |
| import spaces | |
| import shutil | |
| import subprocess | |
| import copy | |
| import random | |
| import tempfile | |
| import warnings | |
| import time | |
| import gc | |
| import uuid | |
| from tqdm import tqdm | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.nn import functional as F | |
| from PIL import Image, ImageFilter | |
| import gradio as gr | |
| from diffusers import ( | |
| AutoencoderKLWan, | |
| FlowMatchEulerDiscreteScheduler, | |
| WanPipeline, | |
| SASolverScheduler, | |
| DEISMultistepScheduler, | |
| DPMSolverMultistepInverseScheduler, | |
| UniPCMultistepScheduler, | |
| DPMSolverMultistepScheduler, | |
| DPMSolverSinglestepScheduler, | |
| ) | |
| from diffusers.models.transformers.transformer_wan import WanTransformer3DModel | |
| from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline | |
| from diffusers.pipelines.wan.pipeline_wan_vace import WanVACEPipeline | |
| from diffusers.utils.export_utils import export_to_video | |
| from diffusers.utils import load_video | |
| from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig | |
| import aoti | |
| import lora_loader | |
| # SAM2 for video mask generation | |
| from sam2.sam2_video_predictor import SAM2VideoPredictor | |
# Set the TOKENIZERS_PARALLELISM environment variable for the tokenizer library.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
def clear_vram():
    """Run the Python garbage collector, then empty PyTorch's CUDA cache."""
    # The collector runs before the CUDA allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
| # ============ RIFE ============ | |
# JavaScript snippet that returns the current playback position (seconds) of
# the <video> inside the element with id "generated-video", or 0 when no such
# video element exists.
get_timestamp_js = """
function() {
const video = document.querySelector('#generated-video video');
if (video) { return video.currentTime; }
return 0;
}
"""
def extract_frame(video_path, timestamp):
    """Return the frame of ``video_path`` at ``timestamp`` seconds as an RGB array.

    Args:
        video_path: Path of the video file; a falsy value yields ``None``.
        timestamp: Playback position in seconds (anything ``float()`` accepts;
            ``None`` is treated as 0).

    Returns:
        The decoded frame converted from BGR to RGB, or ``None`` when the video
        cannot be opened or no frame can be read.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        # The capture properties are named CAP_PROP_*; the shorter names used
        # previously do not match the ones used elsewhere in this file.
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        target_frame_num = int(float(timestamp or 0) * fps)
        # Clamp into [0, total_frames - 1]; skip the upper clamp when the
        # container reports no frame count so the index never becomes -1.
        if total_frames > 0:
            target_frame_num = min(target_frame_num, total_frames - 1)
        target_frame_num = max(target_frame_num, 0)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture even if seeking or decoding raised.
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Download and unpack the RIFE v4.26 archive when it is not already present in
# the working directory.
# NOTE(review): indentation is lost in this view; both subprocess calls are
# assumed to belong to the `if` body — confirm against the real file.
if not os.path.exists("RIFEv4.26_0921.zip"):
    print("Downloading RIFE Model...")
    subprocess.run(["wget", "-q", "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip", "-O", "RIFEv4.26_0921.zip"], check=True)
    subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)
# Imported only after the download step above.
# NOTE(review): presumably the archive provides the train_log package — confirm.
from train_log.RIFE_HDv3 import Model
# Device used by interpolate_bits when it moves frames into tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rife_model = Model()
rife_model.load_model("train_log", -1)
rife_model.eval()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """Insert RIFE-interpolated frames between consecutive frames.

    Args:
        frames_np: Frames as a list of H x W x C arrays or a single
            T x H x W x C array.
        multiplier: Output-to-input frame-rate ratio; ``multiplier - 1``
            frames are generated between each pair. Values below 2 disable
            interpolation.
        scale: Scale argument forwarded to ``rife_model.inference``; it also
            determines the padding granularity.

    Returns:
        A list of frames. With ``multiplier < 2`` the input is returned
        unchanged (an array is converted to a list). With fewer than two
        frames there is nothing to interpolate, so the frames are returned
        as a list.
    """
    if multiplier < 2:
        return list(frames_np) if isinstance(frames_np, np.ndarray) else frames_np

    T = len(frames_np)
    if T < 2:
        # Nothing to interpolate between; returning early also avoids indexing
        # an empty sequence and deleting a tensor that was never created.
        return list(frames_np)

    H, W = frames_np[0].shape[:2]
    n_interp = multiplier - 1

    # Pad height and width up to a multiple of `tmp`.
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)

    def to_tensor(frame_np):
        # H x W x C array -> padded 1 x C x H x W half-precision tensor on `device`.
        t = torch.from_numpy(frame_np).to(device)
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()

    def from_tensor(tensor):
        # Drop the padding and return an H x W x C float32 array on the CPU.
        t = tensor[0, :, :H, :W]
        return t.permute(1, 2, 0).float().cpu().numpy()

    def make_inference(I0, I1, n):
        if rife_model.version >= 3.9:
            # These versions take the intermediate timestep directly.
            return [rife_model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale) for i in range(n)]
        # Otherwise only the midpoint is requested, so recurse on both halves.
        middle = rife_model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n // 2)
        second_half = make_inference(middle, I1, n // 2)
        return [*first_half, middle, *second_half] if n % 2 else [*first_half, *second_half]

    output_frames = []
    I1 = to_tensor(frames_np[0])
    with tqdm(total=T - 1, desc="Interpolating", unit="frame") as pbar:
        for i in range(T - 1):
            I0 = I1
            output_frames.append(from_tensor(I0))
            I1 = to_tensor(frames_np[i + 1])
            for mid in make_inference(I0, I1, n_interp):
                output_frames.append(from_tensor(mid))
            # The progress bar is advanced in batches of 50 frames.
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update((T - 1) % 50)
    output_frames.append(from_tensor(I1))
    del I0, I1
    torch.cuda.empty_cache()
    return output_frames
| # ============ Config ============ | |
# Frame rate used to convert durations to frame counts (get_num_frames) and to
# derive the exported video's fps (run_inference).
FIXED_FPS = 16
MAX_FRAMES_MODEL = 241 # ~15s@16fps, requires more VRAM/time
# Largest seed value drawn when the seed is randomised (generate_video).
MAX_SEED = np.iinfo(np.int32).max
# Scheduler name -> diffusers scheduler class; looked up by name in run_inference.
SCHEDULER_MAP = {
    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
    "SASolver": SASolverScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
}
# Negative-prompt text. It is not referenced in the visible part of the file;
# presumably it is the UI default — TODO confirm.
default_negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
    "malformed limbs, fused fingers, still frame, messy background, three legs, "
    "many people in background, walking backwards, watermark, text, signature"
)
| # ============ Load I2V Pipeline (Lightning, AoT compiled) ============ | |
# Load the I2V pipeline in bfloat16 and place it on the GPU at import time.
print("Loading I2V Pipeline (Lightning 14B)...")
i2v_pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
).to('cuda')
# Keep a pristine copy of the scheduler; run_inference builds replacement
# schedulers from this copy's config.
i2v_original_scheduler = copy.deepcopy(i2v_pipe.scheduler)
quantize_(i2v_pipe.text_encoder, Int8WeightOnlyConfig())
# FP8 quantisation is used only for CUDA compute capability 8.9 or newer.
major, minor = torch.cuda.get_device_capability()
supports_fp8 = (major > 8) or (major == 8 and minor >= 9)
if supports_fp8:
    quantize_(i2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
    quantize_(i2v_pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
    # Load the 'fp8da' variant of the ahead-of-time compiled blocks into both transformers.
    aoti.aoti_blocks_load(i2v_pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
    aoti.aoti_blocks_load(i2v_pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
else:
    # Fallback: int8 weight-only quantisation, without the AoT blocks.
    quantize_(i2v_pipe.transformer, Int8WeightOnlyConfig())
    quantize_(i2v_pipe.transformer_2, Int8WeightOnlyConfig())
| # ============ T2V Pipeline (on-demand, 14B + Wan22 Lightning LoRA) ============ | |
| # Use T2V-A14B + Wan22 Lightning LoRA (separate HIGH/LOW for dual transformer) | |
| # Load on-demand with CPU offload to avoid OOM alongside I2V | |
T2V_MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
T2V_LORA_REPO = "Kijai/WanVideo_comfy"
T2V_LORA_HIGH = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors"
T2V_LORA_LOW = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"
# Cached T2V pipeline and a flag set once it has been fully built.
t2v_pipe = None
t2v_ready = False


def load_t2v_pipeline():
    """Return the T2V pipeline, building it on first use.

    The I2V pipeline is moved to the CPU on every call, including calls that
    reuse the cached pipeline. ``unload_t2v_pipeline`` puts I2V back on the GPU
    after each run, so it has to be moved off again before the next one.

    Returns:
        The ``WanPipeline`` with both Lightning LoRAs fused and model CPU
        offload enabled.
    """
    global t2v_pipe, t2v_ready

    # Make room on the GPU before anything else (previously skipped on reuse).
    i2v_pipe.to('cpu')
    clear_vram()

    if t2v_pipe is not None and t2v_ready:
        print("T2V pipeline reused from memory")
        return t2v_pipe

    print("Loading T2V Pipeline (14B + Lightning LoRA) first time...")
    # The VAE is loaded in float32; the transformers and the rest in bfloat16.
    t2v_vae = AutoencoderKLWan.from_pretrained(T2V_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    t2v_pipe = WanPipeline.from_pretrained(
        T2V_MODEL_ID,
        transformer=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer',
            torch_dtype=torch.bfloat16,
        ),
        transformer_2=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer_2',
            torch_dtype=torch.bfloat16,
        ),
        vae=t2v_vae,
        torch_dtype=torch.bfloat16,
    )

    # Load and fuse Lightning LoRAs (HIGH for transformer, LOW for transformer_2)
    print("Fusing Lightning LoRA HIGH (transformer)...")
    from huggingface_hub import hf_hub_download

    high_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_HIGH)
    low_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_LOW)

    # HIGH LoRA: load, fuse into `transformer`, then drop the adapter weights.
    t2v_pipe.load_lora_weights(high_path, adapter_name="lightning_high")
    t2v_pipe.set_adapters(["lightning_high"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_high"], lora_scale=1.0, components=["transformer"])
    t2v_pipe.unload_lora_weights()

    # LOW LoRA: same procedure for `transformer_2`.
    print("Fusing Lightning LoRA LOW (transformer_2)...")
    t2v_pipe.load_lora_weights(low_path, adapter_name="lightning_low", load_into_transformer_2=True)
    t2v_pipe.set_adapters(["lightning_low"], adapter_weights=[1.0])
    t2v_pipe.fuse_lora(adapter_names=["lightning_low"], lora_scale=1.0, components=["transformer_2"])
    t2v_pipe.unload_lora_weights()

    # Use model CPU offload rather than moving the whole pipeline to the GPU.
    t2v_pipe.enable_model_cpu_offload()
    t2v_ready = True
    print("T2V pipeline ready (14B + Lightning + CPU offload)")
    return t2v_pipe
def unload_t2v_pipeline():
    """Clear cached VRAM and move the I2V pipeline back onto the GPU.

    Called after a T2V run; the T2V pipeline object itself is left untouched.
    """
    clear_vram()
    i2v_pipe.to("cuda")
    print("I2V restored to GPU")
| # Keep cache for on-demand T2V loading | |
| # ============ SAM2 Video Segmentation ============ | |
# Module-wide SAM2 predictor; stays None until get_sam2_predictor() is first called.
sam2_predictor = None


def get_sam2_predictor():
    """Return the shared SAM2 video predictor, loading it on the first call."""
    global sam2_predictor
    if sam2_predictor is not None:
        return sam2_predictor
    print("Loading SAM2.1 hiera-large...")
    sam2_predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2.1-hiera-large")
    print("SAM2 loaded")
    return sam2_predictor
def extract_first_frame_from_video(video_path):
    """Decode the first frame of ``video_path`` as an RGB PIL image.

    Returns ``None`` when no frame could be read.
    """
    capture = cv2.VideoCapture(video_path)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        return None
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
def video_to_frames_dir(video_path, max_frames=None):
    """Dump the frames of a video as numbered JPEGs into a fresh temp directory.

    Args:
        video_path: Video file to read.
        max_frames: Optional cap on the number of frames written; a falsy
            value means no cap.

    Returns:
        ``(directory, frame_count, fps)``; ``fps`` falls back to 16 when the
        capture reports a falsy frame rate.
    """
    frames_dir = tempfile.mkdtemp(prefix="sam2_frames_")
    capture = cv2.VideoCapture(video_path)
    fps = capture.get(cv2.CAP_PROP_FPS) or 16

    written = 0
    ok, frame = capture.read()
    # Stop at end of stream, or once the optional cap has been reached.
    while ok and not (max_frames and written >= max_frames):
        target = os.path.join(frames_dir, f"{written:05d}.jpg")
        cv2.imwrite(target, frame)
        written += 1
        ok, frame = capture.read()
    capture.release()

    print(f"Extracted {written} frames to {frames_dir} (fps={fps:.1f})")
    return frames_dir, written, fps
def generate_mask_video(video_path, points_json, num_frames_limit=None):
    """Generate a mask video with SAM2 from points marked on the first frame.

    Args:
        video_path: Source video file.
        points_json: JSON string holding a list of ``{"x", "y", "label"}``
            dicts (an already-decoded list is accepted too). ``label`` 1, the
            default, marks a point to include; any other value marks a point
            to exclude.
        num_frames_limit: Optional cap on the number of frames processed.

    Returns:
        Path of the single-channel mask video (255 = masked).

    Raises:
        gr.Error: On a missing video, missing or malformed points, an
            unreadable video, or when SAM2 yields no mask.
    """
    import json

    if not video_path:
        raise gr.Error("请先上传视频 / Upload a video first")
    if not points_json or (isinstance(points_json, str) and points_json.strip() == "[]"):
        raise gr.Error("请在视频第一帧上点击要编辑的区域 / Click on the area to edit")
    if isinstance(points_json, str):
        try:
            points_data = json.loads(points_json)
        except json.JSONDecodeError as exc:
            raise gr.Error("标记点格式无效 / Invalid points data") from exc
    else:
        points_data = points_json
    if not points_data:
        raise gr.Error("没有标记点 / No points marked")

    # Positive points first, then negative ones, with matching labels.
    pos_points = [[p["x"], p["y"]] for p in points_data if p.get("label", 1) == 1]
    neg_points = [[p["x"], p["y"]] for p in points_data if p.get("label", 1) != 1]
    points_np = np.array(pos_points + neg_points, dtype=np.float32)
    labels_np = np.array([1] * len(pos_points) + [0] * len(neg_points), dtype=np.int32)

    frames_dir, total_frames, fps = video_to_frames_dir(video_path, max_frames=num_frames_limit)
    try:
        if total_frames == 0:
            raise gr.Error("无法读取视频帧 / Could not read any frames from the video")

        predictor = get_sam2_predictor()
        all_masks = {}
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = predictor.init_state(video_path=frames_dir)
            # All points are attached to frame 0 as a single object.
            predictor.add_new_points_or_box(
                state,
                frame_idx=0,
                obj_id=1,
                points=points_np,
                labels=labels_np,
            )
            for frame_idx, obj_ids, masks in predictor.propagate_in_video(state):
                # masks shape: (num_objects, 1, H, W); positive values are foreground.
                all_masks[frame_idx] = (masks[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255

        if not all_masks:
            raise gr.Error("SAM2 未生成任何遮罩 / SAM2 produced no masks")

        out_path = os.path.join(tempfile.mkdtemp(), "mask_video.mp4")
        # Take the frame size from the earliest mask rather than assuming
        # that frame 0 is present.
        h, w = all_masks[min(all_masks)].shape
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h), isColor=False)
        try:
            for i in range(total_frames):
                frame_mask = all_masks.get(i)
                if frame_mask is None:
                    # Fill a gap with the mask of the nearest frame that has one.
                    nearest = min(all_masks, key=lambda k: abs(k - i))
                    frame_mask = all_masks[nearest]
                writer.write(frame_mask)
        finally:
            writer.release()
    finally:
        # Remove the extracted frames whether or not segmentation succeeded.
        shutil.rmtree(frames_dir, ignore_errors=True)

    print(f"Mask video generated: {out_path} ({total_frames} frames, {w}x{h})")
    return out_path
| # ============ Step 2: GrowMask + ImageComposite (from sam2.1_optimized workflow) ============ | |
def grow_mask_frame(mask_gray, expand_pixels=5, blur=True):
    """Dilate a mask by ``expand_pixels`` using an elliptical kernel.

    Args:
        mask_gray: uint8 H x W array (255 = mask, 0 = background).
        expand_pixels: Dilation radius; a value <= 0 returns the input as is.
        blur: When true, the dilated mask is Gaussian-blurred and then
            thresholded at 127, so the result is binary again.

    Returns:
        The expanded mask as a uint8 H x W array.
    """
    if expand_pixels <= 0:
        return mask_gray

    side = expand_pixels * 2 + 1  # kernel side length, always odd
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    result = cv2.dilate(mask_gray, ellipse, iterations=1)
    if not blur:
        return result

    smoothed = cv2.GaussianBlur(result, (side, side), 0)
    _, result = cv2.threshold(smoothed, 127, 255, cv2.THRESH_BINARY)
    return result
def grow_mask_video_file(mask_video_path, expand_pixels=5):
    """Run ``grow_mask_frame`` over every frame of a mask video.

    Returns the path of the new single-channel video, or ``mask_video_path``
    itself when ``expand_pixels`` is <= 0.
    """
    if expand_pixels <= 0:
        return mask_video_path

    reader = cv2.VideoCapture(mask_video_path)
    fps = reader.get(cv2.CAP_PROP_FPS) or 16
    width = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "grown_mask.mp4")
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    sink = cv2.VideoWriter(out_path, codec, fps, (width, height), isColor=False)

    count = 0
    ok, frame = reader.read()
    while ok:
        # Frames with three dimensions are converted from BGR to grey first.
        if len(frame.shape) == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        sink.write(grow_mask_frame(frame, expand_pixels))
        count += 1
        ok, frame = reader.read()

    reader.release()
    sink.release()
    print(f"GrowMask applied: {count} frames, expand={expand_pixels}px → {out_path}")
    return out_path
def composite_video_from_mask(source_video_path, mask_video_path):
    """Paint the masked region of every source frame white.

    For each source frame, pixels whose mask value exceeds 127 are set to 255
    and all other pixels keep their original colour. Once the mask video runs
    out of frames, an all-black mask is used, so the remaining frames are
    written unchanged.

    Returns:
        Path of the composite video.
    """
    src_cap = cv2.VideoCapture(source_video_path)
    mask_cap = cv2.VideoCapture(mask_video_path)
    fps = src_cap.get(cv2.CAP_PROP_FPS) or 16
    w = int(src_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(src_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "composite.mp4")
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, codec, fps, (w, h))

    count = 0
    while True:
        # Both streams advance together, one frame per iteration.
        have_src, src_frame = src_cap.read()
        have_mask, mask_frame = mask_cap.read()
        if not have_src:
            break

        if have_mask:
            # Bring the mask to the source resolution before comparing pixels.
            if mask_frame.shape[:2] != (h, w):
                mask_frame = cv2.resize(mask_frame, (w, h), interpolation=cv2.INTER_NEAREST)
            if len(mask_frame.shape) == 3:
                mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
            else:
                mask_gray = mask_frame
        else:
            mask_gray = np.zeros((h, w), dtype=np.uint8)

        composite = src_frame.copy()
        composite[mask_gray > 127] = 255
        writer.write(composite)
        count += 1

    src_cap.release()
    mask_cap.release()
    writer.release()
    print(f"Composite video: {count} frames → {out_path}")
    return out_path
| # ============ V2V Pipeline (VACE 14B, on-demand) ============ | |
VACE_MODEL_ID = "Wan-AI/Wan2.1-VACE-14B-diffusers"
# Cached VACE pipeline and a flag set once it has been built and quantised.
v2v_pipe = None
v2v_ready = False


def load_v2v_pipeline():
    """Return the VACE pipeline on the GPU, building it on first use.

    The I2V pipeline is moved to the CPU on every call, before the cache check.
    """
    global v2v_pipe, v2v_ready

    # Free the GPU first, whether or not VACE is already cached.
    i2v_pipe.to("cpu")
    clear_vram()

    cached = v2v_pipe is not None and v2v_ready
    if cached:
        v2v_pipe.to("cuda")
        print("VACE pipeline restored to GPU")
        return v2v_pipe

    print("Loading VACE 14B Pipeline first time (this downloads ~75GB)...")
    # The VAE is loaded in float32, everything else in bfloat16.
    vae = AutoencoderKLWan.from_pretrained(
        VACE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32
    )
    v2v_pipe = WanVACEPipeline.from_pretrained(
        VACE_MODEL_ID,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    v2v_pipe.scheduler = UniPCMultistepScheduler.from_config(
        v2v_pipe.scheduler.config, flow_shift=5.0
    )

    # Quantise before moving to the GPU: int8 weights for the text encoder,
    # FP8 for the transformer on compute capability >= 8.9, int8 otherwise.
    quantize_(v2v_pipe.text_encoder, Int8WeightOnlyConfig())
    major, minor = torch.cuda.get_device_capability()
    fp8_capable = major > 8 or (major == 8 and minor >= 9)
    if fp8_capable:
        transformer_config = Float8DynamicActivationFloat8WeightConfig()
    else:
        transformer_config = Int8WeightOnlyConfig()
    quantize_(v2v_pipe.transformer, transformer_config)

    v2v_pipe.to("cuda")
    v2v_ready = True
    print("VACE 14B pipeline ready (quantized, on GPU)")
    return v2v_pipe
def unload_v2v_pipeline():
    """Move the VACE pipeline (if loaded) to the CPU and put I2V back on the GPU."""
    global v2v_pipe
    loaded = v2v_pipe is not None
    if loaded:
        v2v_pipe.to("cpu")
    clear_vram()
    i2v_pipe.to("cuda")
    print("VACE → CPU, I2V → GPU")
def load_video_frames_and_masks(video_path, mask_path, num_frames, target_h, target_w):
    """Load up to ``num_frames`` source frames and mask frames of equal count.

    Mask frames are converted to mode "L" and are not resized. The shorter
    sequence is padded by repeating its last element. ``target_w`` and
    ``target_h`` are used only for the grey placeholder frame created when
    there is no source frame to repeat.

    Returns:
        ``(src_frames, masks)`` — two lists of the same length.
    """
    src_frames = load_video(video_path)[:num_frames]
    original_size = src_frames[0].size if src_frames else 'N/A'
    print(f"Loaded {len(src_frames)} source frames (original size: {original_size})")

    masks = [frame.convert("L") for frame in load_video(mask_path)[:num_frames]]
    print(f"Loaded {len(masks)} mask frames")

    # Extend the masks first, then the source frames, until both are equally long.
    while len(masks) < len(src_frames):
        if masks:
            masks.append(masks[-1])
        else:
            masks.append(Image.new("L", src_frames[0].size, 0))
    while len(src_frames) < len(masks):
        if src_frames:
            src_frames.append(src_frames[-1])
        else:
            src_frames.append(Image.new("RGB", (target_w, target_h), (128, 128, 128)))

    shared = min(len(src_frames), len(masks))
    return src_frames[:shared], masks[:shared]
| # ============ Utils ============ | |
def resize_image(image, max_dim=832, min_dim=480, square_dim=640, multiple_of=16):
    """Resize a PIL image into the supported size range.

    Square images become ``square_dim`` x ``square_dim``. Images whose aspect
    ratio lies outside ``[min_dim / max_dim, max_dim / min_dim]`` are
    centre-cropped to the nearest allowed ratio. Otherwise the longer side is
    set to ``max_dim``. Both final sides are rounded to a multiple of
    ``multiple_of`` and clamped to ``[min_dim, max_dim]``.
    """
    src_w, src_h = image.size
    if src_w == src_h:
        return image.resize((square_dim, square_dim), Image.LANCZOS)

    ratio = src_w / src_h
    widest_ratio = max_dim / min_dim
    tallest_ratio = min_dim / max_dim

    if ratio > widest_ratio:
        # Too wide: keep the full height and crop the width around the centre.
        keep_w = int(round(src_h * widest_ratio))
        x0 = (src_w - keep_w) // 2
        image = image.crop((x0, 0, x0 + keep_w, src_h))
        target_w, target_h = max_dim, min_dim
    elif ratio < tallest_ratio:
        # Too tall: keep the full width and crop the height around the centre.
        keep_h = int(round(src_w / tallest_ratio))
        y0 = (src_h - keep_h) // 2
        image = image.crop((0, y0, src_w, y0 + keep_h))
        target_w, target_h = min_dim, max_dim
    elif src_w > src_h:
        target_w = max_dim
        target_h = int(round(target_w / ratio))
    else:
        target_h = max_dim
        target_w = int(round(target_h * ratio))

    def snap(value):
        # Round to the nearest multiple, then clamp into [min_dim, max_dim].
        return max(min_dim, min(max_dim, round(value / multiple_of) * multiple_of))

    return image.resize((snap(target_w), snap(target_h)), Image.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
    """Scale ``target_image`` to cover ``reference_image``'s size, then centre-crop to it."""
    ref_w, ref_h = reference_image.size
    src_w, src_h = target_image.size

    # The larger of the two ratios makes the scaled image cover the reference.
    factor = max(ref_w / src_w, ref_h / src_h)
    scaled_w = int(src_w * factor)
    scaled_h = int(src_h * factor)
    scaled = target_image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

    x0 = (scaled_w - ref_w) // 2
    y0 = (scaled_h - ref_h) // 2
    return scaled.crop((x0, y0, x0 + ref_w, y0 + ref_h))
def get_num_frames(duration_seconds):
    """Convert a duration in seconds into a frame count of the form 4k + 1.

    The count is rounded down to the nearest 4k + 1 and then clamped to
    ``[9, MAX_FRAMES_MODEL]``.
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    aligned = 4 * ((requested - 1) // 4) + 1
    return int(np.clip(aligned, 9, MAX_FRAMES_MODEL))
def extract_video_path(input_video):
    """Extract a file path from the various shapes a video input may take.

    Accepts ``None``, a plain path string, a dict, or an object. For dicts and
    objects the keys/attributes ``video``, ``path`` and ``name`` are tried in
    that order. Anything else is converted with ``str()``.
    """
    lookup_order = ("video", "path", "name")

    if input_video is None:
        return None
    if isinstance(input_video, str):
        return input_video

    if isinstance(input_video, dict):
        # The first key present wins, even when its value is None.
        for key in lookup_order:
            if key in input_video:
                return input_video[key]
        return None

    # Objects exposing one of the same names as an attribute.
    for attr in lookup_order:
        if hasattr(input_video, attr):
            return getattr(input_video, attr)
    return str(input_video)
def extract_first_frame(video_input):
    """Return the first frame of a video input as an RGB PIL image.

    ``video_input`` may be anything ``extract_video_path`` understands.
    Returns ``None`` when no existing file path can be resolved or no frame
    can be read.
    """
    path = extract_video_path(video_input)
    if not (path and os.path.exists(path)):
        return None

    capture = cv2.VideoCapture(path)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        return None
    return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
| # ============ Inference ============ | |
def _cleanup_temp_video(path, protected=()):
    """Best-effort removal of a temp video and its parent directory if then empty."""
    if not path or path in protected or not os.path.exists(path):
        return
    try:
        os.remove(path)
        # rmdir only succeeds on an empty directory, so nothing else is lost.
        os.rmdir(os.path.dirname(path))
    except OSError as exc:
        print(f"Temp cleanup skipped for {path}: {exc}")


def _run_t2v(prompt, negative_prompt, num_frames, steps, guidance_scale, guidance_scale_2, seed):
    """Run text-to-video generation; I2V is restored to the GPU even on failure."""
    t2v_steps = max(int(steps), 4)
    print(f"T2V: steps={t2v_steps}, guidance={guidance_scale}/{guidance_scale_2}, frames={num_frames}")
    try:
        pipe = load_t2v_pipeline()
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=t2v_steps,
            generator=torch.Generator(device="cpu").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        # load_t2v_pipeline moves I2V to the CPU; always undo that.
        unload_t2v_pipeline()


def _run_v2v(input_video, mask_video, reference_image, grow_pixels, prompt,
             negative_prompt, num_frames, steps, guidance_scale, seed):
    """Run the V2V steps: grow mask, composite, then VACE generation."""
    print(f"V2V 3-Step Pipeline: input_video type={type(input_video)}, value={input_video}")
    video_path = extract_video_path(input_video)
    if not video_path or not os.path.exists(video_path):
        raise gr.Error("Upload a source video for V2V / V2V请上传原视频")
    mask_path = extract_video_path(mask_video)
    if not mask_path or not os.path.exists(mask_path):
        raise gr.Error("Upload a mask video for V2V / V2V请上传遮罩视频(黑白视频,白色=编辑区域)")

    target_h, target_w = 480, 832
    grown_mask_path = None
    composite_path = None
    gpu_handed_over = False
    try:
        # Step 2a: expand the mask boundaries.
        grown_mask_path = grow_mask_video_file(mask_path, expand_pixels=int(grow_pixels))
        print(f"V2V: GrowMask applied ({grow_pixels}px)")
        # Step 2b: composite the original video with the (un-grown) mask.
        composite_path = composite_video_from_mask(video_path, mask_path)
        print(f"V2V: Composite video created")

        # Step 3 inputs: composite frames as control video, grown mask as mask.
        src_frames = load_video(composite_path)[:num_frames]
        print(f"Loaded {len(src_frames)} composite frames")
        if not src_frames:
            raise gr.Error("Could not read any frames from the source video / 无法读取原视频帧")
        masks = [mf.convert("L") for mf in load_video(grown_mask_path)[:num_frames]]
        print(f"Loaded {len(masks)} grown mask frames")

        # Pad the shorter sequence by repeating its last element. A negative
        # repeat count yields an empty list, so at most one line has an effect.
        fill_mask = masks[-1] if masks else Image.new("L", src_frames[0].size, 0)
        masks.extend([fill_mask] * (len(src_frames) - len(masks)))
        src_frames.extend([src_frames[-1]] * (len(masks) - len(src_frames)))

        # Ensure num_frames satisfies (n-1) % 4 == 0 with at least 5 frames;
        # clips shorter than that are padded with their last frame.
        n = max((len(src_frames) - 1) // 4 * 4 + 1, 5)
        if len(src_frames) < n:
            missing = n - len(src_frames)
            src_frames.extend([src_frames[-1]] * missing)
            masks.extend([masks[-1]] * missing)
        src_frames = src_frames[:n]
        masks = masks[:n]

        # load_v2v_pipeline moves I2V off the GPU as its first action, so
        # record the hand-over before calling it.
        gpu_handed_over = True
        pipe = load_v2v_pipeline()
        v2v_steps = max(int(steps), 20)
        has_reference = reference_image is not None
        print(f"V2V VACE: steps={v2v_steps}, guidance={guidance_scale}, frames={len(src_frames)}, ref_image={'yes' if has_reference else 'no'}")

        vace_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            video=src_frames,
            mask=masks,
            height=target_h,
            width=target_w,
            num_frames=len(src_frames),
            guidance_scale=max(float(guidance_scale), 5.0),
            num_inference_steps=v2v_steps,
            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
            output_type="np",
        )
        if has_reference:
            # Previously the reference image was accepted but never forwarded.
            vace_kwargs["reference_images"] = [reference_image]
        return pipe(**vace_kwargs)
    finally:
        if gpu_handed_over:
            unload_v2v_pipeline()
        # With grow_pixels <= 0 the "grown" path IS the uploaded mask, and it
        # must not be deleted; the source video is protected as well.
        for temp_path in (grown_mask_path, composite_path):
            _cleanup_temp_video(temp_path, protected=(mask_path, video_path))


def _run_i2v(input_image, last_image_input, lora_groups, scheduler_name, flow_shift,
             prompt, negative_prompt, num_frames, steps, guidance_scale,
             guidance_scale_2, seed):
    """Run image-to-video generation on the resident I2V pipeline."""
    if input_image is None:
        raise gr.Error("Upload an image / 请上传图片")

    # Swap the scheduler only when a different scheduler class is requested.
    scheduler_class = SCHEDULER_MAP.get(scheduler_name)
    if scheduler_class and scheduler_class.__name__ != i2v_pipe.scheduler.config._class_name:
        config = copy.deepcopy(i2v_original_scheduler.config)
        if scheduler_class == FlowMatchEulerDiscreteScheduler:
            config['shift'] = flow_shift
        else:
            config['flow_shift'] = flow_shift
        i2v_pipe.scheduler = scheduler_class.from_config(config)

    lora_loaded = False
    try:
        if lora_groups:
            # LoRA loading is best-effort: a failure is logged and generation
            # proceeds with whatever was loaded so far.
            try:
                for idx, name in enumerate(lora_groups):
                    if name and name != "(None)":
                        lora_loader.load_lora_to_pipe(i2v_pipe, name, adapter_name=f"lora_{idx}")
                        lora_loaded = True
            except Exception as e:
                print(f"LoRA warning: {e}")

        resized_image = resize_image(input_image)
        processed_last = None
        if last_image_input:
            processed_last = resize_and_crop_to_match(last_image_input, resized_image)
        print(f"I2V: size={resized_image.size}, steps={int(steps)}, guidance={guidance_scale}/{guidance_scale_2}")
        return i2v_pipe(
            image=resized_image,
            last_image=processed_last,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        # Unload adapters even if generation failed, so they do not carry over
        # into the next request.
        if lora_loaded:
            lora_loader.unload_lora(i2v_pipe)


def run_inference(
    task_type, input_image, input_video, mask_video, prompt, negative_prompt,
    duration_seconds, steps, guidance_scale, guidance_scale_2,
    current_seed, scheduler_name, flow_shift, frame_multiplier,
    quality, last_image_input, lora_groups,
    reference_image=None, grow_pixels=5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video for the selected task and export it to an MP4 file.

    ``task_type`` selects the branch by substring: "T2V", "V2V", otherwise
    I2V. The generated frames are optionally interpolated with RIFE when
    ``frame_multiplier // FIXED_FPS`` exceeds 1, then written with
    ``export_to_video``.

    Returns:
        ``(video_path, task_id)``: the exported file and a short task id.

    Raises:
        gr.Error: When the inputs required by the selected task are missing.
    """
    clear_vram()
    num_frames = get_num_frames(duration_seconds)
    task_id = str(uuid.uuid4())[:8]
    print(f"Task: {task_id}, type={task_type}, duration={duration_seconds}s, frames={num_frames}")
    start = time.time()

    if "T2V" in task_type:
        result = _run_t2v(
            prompt, negative_prompt, num_frames, steps,
            guidance_scale, guidance_scale_2, current_seed,
        )
    elif "V2V" in task_type:
        result = _run_v2v(
            input_video, mask_video, reference_image, grow_pixels, prompt,
            negative_prompt, num_frames, steps, guidance_scale, current_seed,
        )
    else:
        result = _run_i2v(
            input_image, last_image_input, lora_groups, scheduler_name, flow_shift,
            prompt, negative_prompt, num_frames, steps, guidance_scale,
            guidance_scale_2, current_seed,
        )

    raw_frames = result.frames[0]
    elapsed = time.time() - start
    print(f"Generation took {elapsed:.1f}s ({len(raw_frames)} frames)")

    # frame_multiplier is divided by the base fps to get the interpolation factor.
    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames, multiplier=int(frame_factor))
    else:
        final_frames = list(raw_frames)
    final_fps = FIXED_FPS * max(1, frame_factor)

    # delete=False keeps the file on disk after the handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
    return video_path, task_id
| # ============ Generate ============ | |
def generate_video(
    task_type, input_image, input_video, mask_video, prompt,
    lora_groups, duration_seconds, frame_multiplier,
    steps, guidance_scale, guidance_scale_2,
    negative_prompt, quality, seed, randomize_seed,
    scheduler, flow_shift, last_image, display_result,
    reference_image, grow_pixels,
    progress=gr.Progress(track_tqdm=True),
):
    """Validate the prompt, pick the seed and delegate to ``run_inference``.

    Returns:
        ``(shown_video, video_path, seed_used)`` where ``shown_video`` is the
        video path when ``display_result`` is truthy and ``None`` otherwise.

    Raises:
        gr.Error: When the prompt is empty or whitespace only.
    """
    has_prompt = bool(prompt) and bool(prompt.strip())
    if not has_prompt:
        raise gr.Error("Enter a prompt / 请输入提示词")

    if randomize_seed:
        current_seed = random.randint(0, MAX_SEED)
    else:
        current_seed = int(seed)

    video_path, task_id = run_inference(
        task_type, input_image, input_video, mask_video, prompt, negative_prompt,
        duration_seconds, steps, guidance_scale, guidance_scale_2,
        current_seed, scheduler, flow_shift, frame_multiplier,
        quality, last_image, lora_groups,
        reference_image=reference_image, grow_pixels=grow_pixels,
    )
    print(f"Done: {task_id}")

    shown_video = video_path if display_result else None
    return shown_video, video_path, current_seed
| # ============ UI ============ | |
| CSS = """ | |
| #hidden-timestamp { opacity: 0; height: 0; width: 0; margin: 0; padding: 0; overflow: hidden; position: absolute; } | |
| """ | |
# Top-level Gradio layout. Every named component created here is consumed by
# the event wiring further down (task_type.change, generate_btn.click, ...).
# NOTE(review): this copy of the file has lost its indentation, so the exact
# nesting of the Row/Column/Group/Accordion contexts cannot be read off here.
# NOTE(review): delete_cache=(3600, 10800) is presumably (cleanup frequency,
# file age) in seconds - confirm against the Gradio Blocks documentation.
| with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo: | |
| gr.Markdown("## WAN 2.2 Multi-Task Video Generation / 多任务视频生成") | |
| gr.Markdown("#### I2V (Lightning 6-step) · T2V (Lightning 14B 4-step) · V2V (3-Step: SAM2→Composite→VACE)") | |
| gr.Markdown("---") | |
# Task selector, defaulting to I2V. update_task_ui reacts to changes of this
# value by testing for the substrings "T2V" / "V2V" in the selected choice.
| task_type = gr.Radio( | |
| choices=[ | |
| "I2V (图生视频 / Image-to-Video)", | |
| "T2V (文生视频 / Text-to-Video)", | |
| "V2V (视频生视频 / Video-to-Video)", | |
| ], | |
| value="I2V (图生视频 / Image-to-Video)", | |
| label="Task Type / 任务类型", | |
| ) | |
# Inputs. The image input is visible initially; the two video inputs start
# hidden (visible=False) and are toggled by update_task_ui.
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Group(): | |
| input_image = gr.Image(type="pil", label="Input Image / 输入图片 (I2V)", sources=["upload", "clipboard"]) | |
| with gr.Group(): | |
| input_video = gr.Video(label="Source Video / 原视频 (V2V)", sources=["upload"], visible=False, interactive=True) | |
| with gr.Group(): | |
| mask_video = gr.Video(label="Mask Video / 遮罩视频 (V2V, 白色=编辑区域)", sources=["upload"], visible=False, interactive=True) | |
# V2V help text, hidden initially.
| v2v_guide = gr.Markdown( | |
| value="""### 📖 V2V 三步流水线 / 3-Step V2V Pipeline | |
| **Step 1 — SAM2 分割**: 上传原视频 → 提取第一帧 → 点击标记区域 → 生成遮罩视频 | |
| **Step 2 — 自动合成**: 原视频 + 遮罩 → GrowMask扩展边界 + ImageComposite合成(自动完成) | |
| **Step 3 — VACE 生成**: 合成视频 + 遮罩 + 参考图 + Prompt → 最终成品视频 | |
| 💡 也可跳过 Step 1,直接上传自己的遮罩视频(白色=编辑区域) | |
| """, | |
| visible=False, | |
| ) | |
# Mask tools (group hidden initially): a first-frame preview whose select
# event records clicks, the clicked points kept in a gr.State list, a
# read-only status textbox, and the buttons that drive the callbacks below.
| with gr.Group(visible=False) as v2v_mask_tools: | |
| first_frame_display = gr.Image(label="第一帧预览 / First Frame (点击标记区域)", type="pil", interactive=True) | |
| points_store = gr.State(value=[]) | |
| points_display = gr.Textbox(label="标记点 / Points", value="无标记 / No points", interactive=False) | |
| with gr.Row(): | |
| point_mode = gr.Radio(choices=["include (编辑)", "exclude (排除)"], value="include (编辑)", label="点击模式") | |
| with gr.Row(): | |
| extract_frame_btn = gr.Button("📷 提取第一帧 / Extract First Frame", variant="secondary") | |
| gen_mask_btn = gr.Button("🎭 生成遮罩 / Generate Mask (SAM2)", variant="primary") | |
| clear_points_btn = gr.Button("🗑️ 清除标记 / Clear Points") | |
# Extras forwarded to generate_video as reference_image / grow_pixels.
| with gr.Accordion("🖼️ V2V 高级选项 / V2V Advanced", open=True): | |
| reference_image = gr.Image(type="pil", label="参考图 / Reference Image (控制编辑区域的目标外观)", sources=["upload", "clipboard"]) | |
| grow_pixels_sl = gr.Slider(minimum=0, maximum=30, step=1, value=5, label="GrowMask / 遮罩扩展 (像素)", info="扩展遮罩边界,让编辑区域过渡更自然") | |
# Prompt, clip duration and frame_multi; frame_multi is forwarded to
# generate_video as its frame_multiplier argument.
| prompt_input = gr.Textbox( | |
| label="Prompt / 提示词", value="", | |
| placeholder="Describe the video... / 描述你想生成的视频...", lines=3, | |
| ) | |
| duration_slider = gr.Slider( | |
| minimum=0.5, maximum=15, step=0.5, value=3, | |
| label="Duration / 时长 (seconds/秒)", | |
| info="Max ~15s (241 frames @16fps) / 最大约15秒", | |
| ) | |
| frame_multi = gr.Dropdown(choices=[16, 32, 64], value=16, label="Output FPS / 输出帧率", info="RIFE interpolation / RIFE插帧") | |
# Advanced settings. The steps/guidance defaults below equal the I2V presets;
# update_task_ui overwrites steps_slider, guidance_h and guidance_l whenever
# the task type changes.
| with gr.Accordion("⚙️ Advanced Settings / 高级设置", open=False): | |
| last_image = gr.Image(type="pil", label="Last Frame / 末帧 (Optional)", sources=["upload", "clipboard"]) | |
| negative_prompt_input = gr.Textbox(label="Negative Prompt / 负面提示词", value=default_negative_prompt, lines=3) | |
| with gr.Row(): | |
| steps_slider = gr.Slider(minimum=1, maximum=50, step=1, value=6, label="Steps / 步数", info="I2V: 4-8 | T2V: 4-8 | V2V: 25-50") | |
| quality_sl = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Quality / 质量") | |
| with gr.Row(): | |
| guidance_h = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance High / 引导(高噪声)") | |
| guidance_l = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Low / 引导(低噪声)") | |
| with gr.Row(): | |
| scheduler_dd = gr.Dropdown(choices=list(SCHEDULER_MAP.keys()), value="UniPCMultistep", label="Scheduler / 调度器") | |
| flow_shift_sl = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift / 流偏移") | |
| with gr.Row(): | |
| seed_sl = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed / 种子") | |
| random_seed_cb = gr.Checkbox(label="Random / 随机", value=True) | |
| lora_dd = gr.Dropdown(choices=lora_loader.get_lora_choices(), label="LoRA (I2V only / 仅I2V)", multiselect=True, info="From WAN2.2_LoraSet_NSFW") | |
| display_cb = gr.Checkbox(label="Display / 显示", value=True) | |
| generate_btn = gr.Button("🎬 Generate / 生成视频", variant="primary", size="lg") | |
# Output side: generated video, frame-grab button, timestamp field and the
# downloadable file.
| with gr.Column(): | |
| video_output = gr.Video(label="Generated Video / 生成的视频", autoplay=True, sources=["upload"], show_download_button=True, show_share_button=True, interactive=False, elem_id="generated-video") | |
| with gr.Row(): | |
| grab_frame_btn = gr.Button("📸 Use Frame / 使用帧", variant="secondary") | |
# elem_id matches the "#hidden-timestamp" rule in CSS; the component is also
# created with visible=False.
| timestamp_box = gr.Number(value=0, label="Timestamp", visible=False, elem_id="hidden-timestamp") | |
| file_output = gr.File(label="Download / 下载") | |
def update_task_ui(task):
    """Return component updates for the selected task type.

    The result is an 8-tuple of ``gr.update`` objects: five visibility
    updates (input_image, input_video, mask_video, v2v_guide,
    v2v_mask_tools, in the order used by ``task_type.change``) followed by
    three value presets (steps, guidance high, guidance low).

    "T2V" is tested before "V2V"; any other choice is treated as I2V.
    """
    if "T2V" in task:
        visibility = (False, False, False, False, False)
        presets = (4, 1.0, 1.0)
    elif "V2V" in task:
        visibility = (False, True, True, True, True)
        presets = (30, 5.0, 1.0)
    else:
        visibility = (True, False, False, False, False)
        presets = (6, 1.0, 1.0)

    shown = tuple(gr.update(visible=flag) for flag in visibility)
    values = tuple(gr.update(value=preset) for preset in presets)
    return shown + values
# The outputs list must stay in the same order as the 8-tuple returned by
# update_task_ui: five visibility updates, then steps / guidance high /
# guidance low values.
| task_type.change(update_task_ui, inputs=[task_type], outputs=[input_image, input_video, mask_video, v2v_guide, v2v_mask_tools, steps_slider, guidance_h, guidance_l]) | |
| # V2V mask generation callbacks | |
def on_extract_first_frame(video):
    """Load the first frame of the uploaded video and reset the click state.

    Returns:
        ``(frame, [], "无标记 / No points")`` - the extracted frame, an empty
        points list and the default status text.

    Raises:
        gr.Error: if no video path can be resolved, the file does not
            exist, or the frame extraction helper returns ``None``.
    """
    video_file = extract_video_path(video)
    if not (video_file and os.path.exists(video_file)):
        raise gr.Error("请先上传视频 / Upload video first")

    first_frame = extract_first_frame_from_video(video_file)
    if first_frame is None:
        raise gr.Error("无法提取第一帧 / Failed to extract first frame")

    return first_frame, [], "无标记 / No points"
def on_click_frame(img, points, mode, evt: gr.SelectData):
    """Record a click on the first-frame preview and redraw all markers.

    Args:
        img: The image currently shown in the preview, or ``None``.
        points: Previously stored points (list of ``{"x", "y", "label"}``
            dicts); ``None`` or an empty list is treated as no points.
        mode: The point-mode choice; a value containing ``"include"``
            yields label 1, anything else label 0.
        evt: Select event whose ``index`` holds the clicked ``(x, y)``.

    Returns:
        ``(image, points, status)``. When ``img`` is ``None`` the inputs are
        returned unchanged with a hint message; otherwise a copy of ``img``
        with every point drawn on it, the extended points list and an
        "N include, M exclude" summary.
    """
    if img is None:
        return img, points, "请先提取第一帧 / Extract first frame first"

    # Explicit import: the file only pulls Image/ImageFilter from PIL, and
    # attribute access on the bare package works only if some other module
    # has already imported the ImageDraw submodule.
    from PIL import ImageDraw

    x, y = evt.index
    label = 1 if "include" in mode else 0

    # Normalise before any use: copy so the stored state is never mutated
    # in place, and tolerate None so the debug print below cannot fail.
    points = list(points) if points else []
    points.append({"x": x, "y": y, "label": label})
    print(f"[DEBUG] Click at ({x}, {y}), mode={mode}, total_points={len(points)}")

    # Draw every point on a copy: green for label 1, red for label 0.
    display_img = img.copy()
    draw = ImageDraw.Draw(display_img)
    radius = 8
    for p in points:
        color = (0, 255, 0) if p["label"] == 1 else (255, 0, 0)
        draw.ellipse(
            [p["x"] - radius, p["y"] - radius, p["x"] + radius, p["y"] + radius],
            fill=color, outline="white", width=2,
        )

    n_include = sum(1 for p in points if p["label"] == 1)
    n_exclude = sum(1 for p in points if p["label"] == 0)
    info = f"{n_include} include, {n_exclude} exclude"
    return display_img, points, info
def on_clear_points(original_video):
    """Drop all stored points and restore an unmarked first frame.

    Returns:
        ``(frame, [], "无标记 / No points")`` where ``frame`` is re-extracted
        from the video when its path resolves to an existing file, and
        ``None`` otherwise.
    """
    video_file = extract_video_path(original_video)
    frame = None
    if video_file and os.path.exists(video_file):
        frame = extract_first_frame_from_video(video_file)
    return frame, [], "无标记 / No points"
def on_generate_mask(video, points):
    """Generate a mask video from the source video and the clicked points.

    Args:
        video: Value of the source-video component.
        points: Stored list of click points; serialised to JSON before
            being handed to ``generate_mask_video``.

    Returns:
        Whatever ``generate_mask_video`` returns (used as the value of the
        mask-video component).

    Raises:
        gr.Error: if no usable video file is available or no points have
            been marked.
    """
    import json

    print(f"[DEBUG] on_generate_mask called, points type={type(points)}, value={points}")
    vpath = extract_video_path(video)
    # Same guard as on_extract_first_frame: a path that no longer exists on
    # disk is reported as a missing upload instead of failing later inside
    # generate_mask_video.
    if not vpath or not os.path.exists(vpath):
        raise gr.Error("请先上传视频 / Upload video first")
    if not points:
        raise gr.Error("请先在第一帧上点击标记 / Click on first frame to mark areas")
    return generate_mask_video(vpath, json.dumps(points))
# Mask-tool wiring. The extract, click and clear handlers each return
# (image, points list, status text), matching the
# [first_frame_display, points_store, points_display] outputs.
| extract_frame_btn.click(fn=on_extract_first_frame, inputs=[input_video], outputs=[first_frame_display, points_store, points_display]) | |
| first_frame_display.select(fn=on_click_frame, inputs=[first_frame_display, points_store, point_mode], outputs=[first_frame_display, points_store, points_display]) | |
| clear_points_btn.click(fn=on_clear_points, inputs=[input_video], outputs=[first_frame_display, points_store, points_display]) | |
# The value returned by on_generate_mask is written into the mask_video input.
| gen_mask_btn.click(fn=on_generate_mask, inputs=[input_video, points_store], outputs=[mask_video]) | |
# The inputs list must follow generate_video's positional parameter order
# (its trailing progress parameter is not listed). The three outputs receive
# the preview, the file path and the seed that was used.
| generate_btn.click( | |
| fn=generate_video, | |
| inputs=[task_type, input_image, input_video, mask_video, prompt_input, lora_dd, duration_slider, frame_multi, | |
| steps_slider, guidance_h, guidance_l, negative_prompt_input, quality_sl, seed_sl, random_seed_cb, | |
| scheduler_dd, flow_shift_sl, last_image, display_cb, | |
| reference_image, grow_pixels_sl], | |
| outputs=[video_output, file_output, seed_sl], | |
| ) | |
# JS-only handler (fn=None): the script in get_timestamp_js writes its result
# into timestamp_box; the resulting change event calls extract_frame with the
# generated video and that timestamp, and its output replaces input_image.
# NOTE(review): get_timestamp_js and extract_frame are defined outside this
# section - presumably the script reads the player's current time; confirm.
| grab_frame_btn.click(fn=None, inputs=None, outputs=[timestamp_box], js=get_timestamp_js) | |
| timestamp_box.change(fn=extract_frame, inputs=[video_output, timestamp_box], outputs=[input_image]) | |
# Start the app only when the file is executed as a script: queue() is called
# first, then launch() with mcp_server=True and show_error=True.
| if __name__ == "__main__": | |
| demo.queue().launch(mcp_server=True, show_error=True) | |