| """ |
| WAN 2.2 Multi-Task Video Generation - 3-Step V2V Pipeline |
| I2V: Lightning 14B (6 steps, FP8+AoT) |
| T2V: Lightning 14B (4 steps, Lightning LoRA + FP8) |
| V2V: 3-Step Pipeline (SAM2 → Composite → VACE) |
| Step 1: SAM2 video segmentation (click points → mask video) |
| Step 2: ImageComposite (original + mask → composite video) |
| Step 3: VACE generation (composite + grown mask + ref image + prompt → final) |
| LoRA: from lkzd7/WAN2.2_LoraSet_NSFW (I2V only) |
| """ |
| import os |
|
|
| import spaces |
| import shutil |
| import subprocess |
| import copy |
| import random |
| import tempfile |
| import warnings |
| import time |
| import gc |
| import uuid |
| from tqdm import tqdm |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from torch.nn import functional as F |
| from PIL import Image, ImageFilter |
|
|
| import gradio as gr |
| from diffusers import ( |
| AutoencoderKLWan, |
| FlowMatchEulerDiscreteScheduler, |
| WanPipeline, |
| SASolverScheduler, |
| DEISMultistepScheduler, |
| DPMSolverMultistepInverseScheduler, |
| UniPCMultistepScheduler, |
| DPMSolverMultistepScheduler, |
| DPMSolverSinglestepScheduler, |
| ) |
| from diffusers.models.transformers.transformer_wan import WanTransformer3DModel |
| from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline |
| from diffusers.pipelines.wan.pipeline_wan_vace import WanVACEPipeline |
| from diffusers.utils.export_utils import export_to_video |
| from diffusers.utils import load_video |
| from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig |
| import aoti |
| import lora_loader |
|
|
| |
| from sam2.sam2_video_predictor import SAM2VideoPredictor |
|
|
# Environment flag read by the Hugging Face tokenizers library; "true" leaves
# its internal parallelism switched on.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Suppress every Python warning for the whole process (this also hides
# deprecation notices from the libraries imported above).
warnings.filterwarnings("ignore")
|
|
def clear_vram():
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    Collecting first drops unreachable Python objects before the cache is
    emptied. Returns ``None``.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
| |
# JavaScript function source kept as a Python string. It looks up the <video>
# element inside the node with id "generated-video" and returns its current
# playback time in seconds, or 0 when no such element exists.
# NOTE(review): presumably wired to a Gradio event through `js=`; that wiring
# is outside this part of the file.
get_timestamp_js = """
function() {
const video = document.querySelector('#generated-video video');
if (video) { return video.currentTime; }
return 0;
}
"""
|
|
def extract_frame(video_path, timestamp):
    """Return the video frame displayed at ``timestamp`` as an RGB array.

    Args:
        video_path: Path of the video file; a falsy value returns ``None``.
        timestamp: Playback position in seconds (anything ``float()`` accepts;
            ``None`` is treated as 0).

    Returns:
        The frame converted from BGR to RGB, or ``None`` when the path is
        empty, the file cannot be opened, or no frame could be decoded.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        # The property id is CAP_PROP_FPS; cv2 has no attribute named CAP_FPS.
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        target_frame_num = int(float(timestamp or 0) * fps)
        # Clamp to the last frame, but only when the container reports a count;
        # a reported count of 0 must not turn the index into -1.
        if total_frames > 0:
            target_frame_num = min(target_frame_num, total_frames - 1)
        target_frame_num = max(target_frame_num, 0)
        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including the early returns.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
# Download and unpack the RIFE archive once; the zip file left in the working
# directory doubles as the "already installed" marker.
if not os.path.exists("RIFEv4.26_0921.zip"):
    print("Downloading RIFE Model...")
    subprocess.run(["wget", "-q", "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip", "-O", "RIFEv4.26_0921.zip"], check=True)
    subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)


# Imported here rather than at the top of the file.
# NOTE(review): presumably because the `train_log` package only exists after
# the archive above has been unpacked — confirm against the zip contents.
from train_log.RIFE_HDv3 import Model
# Device used by interpolate_bits when it uploads frames.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rife_model = Model()
rife_model.load_model("train_log", -1)
rife_model.eval()
|
|
@torch.no_grad()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """Insert RIFE-interpolated frames between consecutive frames.

    Args:
        frames_np: Either a list of H x W x C arrays or a single
            T x H x W x C array.
        multiplier: Frame-rate factor. Values below 2 return the input frames
            unchanged (an array is converted to a list of frames).
        scale: Scale argument forwarded to ``rife_model.inference``; it also
            determines the padding granularity (``max(128, int(128 / scale))``).

    Returns:
        A list of H x W x C float arrays: every input frame followed by
        ``multiplier - 1`` interpolated frames, except after the last one.
        Processed frames pass through a half-precision tensor on ``device``.
        An empty input yields an empty list.
    """
    # Nothing to interpolate; also avoids indexing frames_np[0] on empty input.
    if len(frames_np) == 0:
        return []

    if isinstance(frames_np, list):
        T = len(frames_np)
        H, W, _ = frames_np[0].shape
    else:
        T, H, W, _ = frames_np.shape

    if multiplier < 2:
        return list(frames_np) if isinstance(frames_np, np.ndarray) else frames_np

    n_interp = multiplier - 1
    # Pad height and width up to the next multiple of `tmp` before inference.
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)

    def to_tensor(frame_np):
        # H x W x C array -> padded 1 x C x H x W half tensor on `device`.
        t = torch.from_numpy(frame_np).to(device)
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()

    def from_tensor(tensor):
        # Crop the padding away and go back to an H x W x C float array.
        t = tensor[0, :, :H, :W]
        return t.permute(1, 2, 0).float().cpu().numpy()

    def make_inference(I0, I1, n):
        if rife_model.version >= 3.9:
            # Newer models take the time step directly.
            return [rife_model.inference(I0, I1, (i + 1) * 1. / (n + 1), scale) for i in range(n)]
        # Older models only produce the midpoint, so split recursively.
        middle = rife_model.inference(I0, I1, scale)
        if n == 1:
            return [middle]
        first_half = make_inference(I0, middle, n // 2)
        second_half = make_inference(middle, I1, n // 2)
        return [*first_half, middle, *second_half] if n % 2 else [*first_half, *second_half]

    output_frames = []
    # Pre-bind I0 so the final `del` also works for a single-frame input,
    # where the loop body never runs.
    I0 = None
    I1 = to_tensor(frames_np[0])
    with tqdm(total=T - 1, desc="Interpolating", unit="frame") as pbar:
        for i in range(T - 1):
            I0 = I1
            output_frames.append(from_tensor(I0))
            I1 = to_tensor(frames_np[i + 1])
            for mid in make_inference(I0, I1, n_interp):
                output_frames.append(from_tensor(mid))
            # Advance the bar in steps of 50 pairs; the remainder follows below.
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update((T - 1) % 50)
    output_frames.append(from_tensor(I1))
    del I0, I1
    torch.cuda.empty_cache()
    return output_frames
|
|
| |
# Frame rate used to turn a requested duration into a frame count and as the
# base rate of the exported video.
FIXED_FPS = 16
# Upper bound applied to the frame count in get_num_frames.
MAX_FRAMES_MODEL = 241
# Largest seed value: the maximum of a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max


# Scheduler name -> diffusers scheduler class; run_inference looks the
# selected name up here for the I2V task.
SCHEDULER_MAP = {
    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
    "SASolver": SASolverScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
}


# Negative prompt text; assembled from adjacent string literals, so each
# fragment ends with ", " to keep the comma-separated list intact.
default_negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
    "still image, overall gray, worst quality, low quality, JPEG artifacts, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, "
    "malformed limbs, fused fingers, still frame, messy background, three legs, "
    "many people in background, walking backwards, watermark, text, signature"
)
|
|
| |
# Build the I2V pipeline in bfloat16 at import time and place it on the GPU.
print("Loading I2V Pipeline (Lightning 14B)...")
i2v_pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
).to('cuda')
# Untouched copy of the scheduler; run_inference builds replacement schedulers
# from this copy's config.
i2v_original_scheduler = copy.deepcopy(i2v_pipe.scheduler)


# Text encoder gets int8 weight-only quantization on every GPU.
quantize_(i2v_pipe.text_encoder, Int8WeightOnlyConfig())
# The FP8 branch is taken only for CUDA compute capability 8.9 or newer.
major, minor = torch.cuda.get_device_capability()
supports_fp8 = (major > 8) or (major == 8 and minor >= 9)
if supports_fp8:
    quantize_(i2v_pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
    quantize_(i2v_pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
    # NOTE(review): `aoti` is a project-local module; by its name this loads
    # ahead-of-time compiled transformer blocks — confirm in aoti.py.
    aoti.aoti_blocks_load(i2v_pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
    aoti.aoti_blocks_load(i2v_pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
else:
    # Fallback: int8 weight-only quantization for both transformers.
    quantize_(i2v_pipe.transformer, Int8WeightOnlyConfig())
    quantize_(i2v_pipe.transformer_2, Int8WeightOnlyConfig())
|
|
| |
| |
| |
# Hub repository of the base T2V pipeline (also the source of its VAE).
T2V_MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
# Hub repository and file paths of the two Lightning LoRA files that
# load_t2v_pipeline fuses into `transformer` (HIGH) and `transformer_2` (LOW).
T2V_LORA_REPO = "Kijai/WanVideo_comfy"
T2V_LORA_HIGH = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors"
T2V_LORA_LOW = "LoRAs/Wan22-Lightning/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"
# Lazily created pipeline and its "fully loaded" flag; both are set by
# load_t2v_pipeline.
t2v_pipe = None
t2v_ready = False
|
|
def load_t2v_pipeline():
    """Return the T2V pipeline, building it on first use.

    The I2V pipeline is moved to the CPU and cached CUDA memory is released on
    every call, including when the cached T2V pipeline is reused:
    ``unload_t2v_pipeline`` puts I2V back on the GPU after each T2V run, so it
    has to be evicted again before the next one. ``load_v2v_pipeline`` evicts
    I2V unconditionally in the same way.

    On first use the base pipeline is loaded with BF16 transformers, the two
    Lightning LoRA files are fused into ``transformer`` and ``transformer_2``
    respectively, and model CPU offload is enabled. The module-level
    ``t2v_pipe`` / ``t2v_ready`` are only updated once all of that succeeded,
    so a failed load never leaves a half-built pipeline behind.

    Returns:
        The ready-to-call ``WanPipeline``.
    """
    global t2v_pipe, t2v_ready

    # Make room on the GPU before anything else, on both the cached and the
    # first-load path.
    i2v_pipe.to('cpu')
    clear_vram()

    if t2v_pipe is not None and t2v_ready:
        print("T2V pipeline reused from memory")
        return t2v_pipe

    print("Loading T2V Pipeline (14B + Lightning LoRA) first time...")

    # Only needed on the first load, hence the function-scope import.
    from huggingface_hub import hf_hub_download

    t2v_vae = AutoencoderKLWan.from_pretrained(T2V_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    pipe = WanPipeline.from_pretrained(
        T2V_MODEL_ID,
        transformer=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer',
            torch_dtype=torch.bfloat16,
        ),
        transformer_2=WanTransformer3DModel.from_pretrained(
            'linoyts/Wan2.2-T2V-A14B-Diffusers-BF16',
            subfolder='transformer_2',
            torch_dtype=torch.bfloat16,
        ),
        vae=t2v_vae,
        torch_dtype=torch.bfloat16,
    )

    print("Fusing Lightning LoRA HIGH (transformer)...")
    high_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_HIGH)
    low_path = hf_hub_download(T2V_LORA_REPO, T2V_LORA_LOW)

    # Fuse, then unload the adapter so only the merged weights remain.
    pipe.load_lora_weights(high_path, adapter_name="lightning_high")
    pipe.set_adapters(["lightning_high"], adapter_weights=[1.0])
    pipe.fuse_lora(adapter_names=["lightning_high"], lora_scale=1.0, components=["transformer"])
    pipe.unload_lora_weights()

    print("Fusing Lightning LoRA LOW (transformer_2)...")
    pipe.load_lora_weights(low_path, adapter_name="lightning_low", load_into_transformer_2=True)
    pipe.set_adapters(["lightning_low"], adapter_weights=[1.0])
    pipe.fuse_lora(adapter_names=["lightning_low"], lora_scale=1.0, components=["transformer_2"])
    pipe.unload_lora_weights()

    pipe.enable_model_cpu_offload()

    # Publish only after every step above succeeded.
    t2v_pipe = pipe
    t2v_ready = True
    print("T2V pipeline ready (14B + Lightning + CPU offload)")
    return t2v_pipe
|
|
def unload_t2v_pipeline():
    """Free cached GPU memory after a T2V run and move I2V back onto the GPU.

    The T2V pipeline object itself stays cached in the module-level variable.
    """
    clear_vram()
    i2v_pipe.to('cuda')
    print("I2V restored to GPU")
|
|
| |
|
|
| |
# Cached SAM2 predictor; stays None until get_sam2_predictor is first called.
sam2_predictor = None


def get_sam2_predictor():
    """Return the shared SAM2 video predictor, loading it on the first call.

    Later calls return the same cached object without touching the hub again.
    """
    global sam2_predictor
    if sam2_predictor is not None:
        return sam2_predictor

    print("Loading SAM2.1 hiera-large...")
    sam2_predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2.1-hiera-large")
    print("SAM2 loaded")
    return sam2_predictor
|
|
def extract_first_frame_from_video(video_path):
    """Decode the first frame of ``video_path`` and return it as an RGB PIL image.

    Returns ``None`` when no frame could be read.
    """
    capture = cv2.VideoCapture(video_path)
    ok, bgr = capture.read()
    capture.release()
    if not ok:
        return None
    # OpenCV hands back BGR; PIL expects RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb)
|
|
def video_to_frames_dir(video_path, max_frames=None):
    """Dump the frames of a video as numbered JPEGs into a new temp directory.

    Files are named ``00000.jpg``, ``00001.jpg``, ... in decode order.

    Args:
        video_path: Video to read.
        max_frames: Optional cap on the number of frames written; a falsy
            value means no cap.

    Returns:
        ``(directory, frames_written, fps)``, where ``fps`` falls back to 16
        when the container reports 0. The caller owns the directory.
    """
    out_dir = tempfile.mkdtemp(prefix="sam2_frames_")
    capture = cv2.VideoCapture(video_path)
    fps = capture.get(cv2.CAP_PROP_FPS) or 16

    written = 0
    while True:
        ok, bgr = capture.read()
        cap_reached = bool(max_frames) and written >= max_frames
        if not ok or cap_reached:
            break
        target = os.path.join(out_dir, f"{written:05d}.jpg")
        cv2.imwrite(target, bgr)
        written += 1

    capture.release()
    print(f"Extracted {written} frames to {out_dir} (fps={fps:.1f})")
    return out_dir, written, fps
|
|
@spaces.GPU(duration=120)
def generate_mask_video(video_path, points_json, num_frames_limit=None):
    """Generate a black/white mask video with SAM2 from user-clicked points.

    The points are applied to frame 0 as object id 1 and propagated through
    the clip; every mask is written as one grayscale frame (255 = selected).

    Args:
        video_path: Source video path.
        points_json: JSON list of ``{"x": ..., "y": ..., "label": ...}``
            objects. ``label`` defaults to 1 (include); any other value marks
            an exclude point.
        num_frames_limit: Optional cap on the number of frames processed.

    Returns:
        Path of the generated ``mask_video.mp4`` inside a fresh temp directory.

    Raises:
        gr.Error: If the video or the points are missing or malformed, no
            frame can be read, or SAM2 yields no mask.
    """
    import json

    if not video_path:
        raise gr.Error("请先上传视频 / Upload a video first")
    if not points_json or points_json.strip() == "[]":
        raise gr.Error("请在视频第一帧上点击要编辑的区域 / Click on the area to edit")

    try:
        points_data = json.loads(points_json)
    except json.JSONDecodeError as err:
        raise gr.Error("标记点格式无效 / Invalid points data") from err
    if not points_data:
        raise gr.Error("没有标记点 / No points marked")

    # Include points first, then exclude points, each with its label.
    try:
        pos_points = [[p["x"], p["y"]] for p in points_data if p.get("label", 1) == 1]
        neg_points = [[p["x"], p["y"]] for p in points_data if p.get("label", 1) != 1]
    except (KeyError, TypeError, AttributeError) as err:
        raise gr.Error("标记点格式无效 / Invalid points data") from err

    points_np = np.array(pos_points + neg_points, dtype=np.float32)
    labels_np = np.array([1] * len(pos_points) + [0] * len(neg_points), dtype=np.int32)

    frames_dir, total_frames, fps = video_to_frames_dir(video_path, max_frames=num_frames_limit)
    all_masks = {}
    try:
        if total_frames == 0:
            raise gr.Error("无法读取视频帧 / Could not read any frames from the video")

        predictor = get_sam2_predictor()
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            state = predictor.init_state(video_path=frames_dir)
            predictor.add_new_points_or_box(
                state,
                frame_idx=0,
                obj_id=1,
                points=points_np,
                labels=labels_np,
            )
            for frame_idx, _obj_ids, masks in predictor.propagate_in_video(state):
                # Positive logits of the first object become white pixels.
                all_masks[frame_idx] = (masks[0, 0] > 0.0).cpu().numpy().astype(np.uint8) * 255
    finally:
        # The extracted JPEGs are only needed by SAM2; remove them even when
        # segmentation fails.
        shutil.rmtree(frames_dir, ignore_errors=True)

    if not all_masks:
        raise gr.Error("SAM2 未生成遮罩 / SAM2 produced no masks")

    out_path = os.path.join(tempfile.mkdtemp(), "mask_video.mp4")
    # Any mask gives the frame size; prefer frame 0 when it exists.
    first_mask = all_masks[0] if 0 in all_masks else next(iter(all_masks.values()))
    h, w = first_mask.shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (w, h), isColor=False)
    try:
        for i in range(total_frames):
            if i in all_masks:
                writer.write(all_masks[i])
            else:
                # Fill a gap with the mask of the closest processed frame.
                nearest = min(all_masks, key=lambda k: abs(k - i))
                writer.write(all_masks[nearest])
    finally:
        writer.release()

    print(f"Mask video generated: {out_path} ({total_frames} frames, {w}x{h})")
    return out_path
|
|
| |
def grow_mask_frame(mask_gray, expand_pixels=5, blur=True):
    """Grow the white area of a mask by ``expand_pixels`` pixels.

    Args:
        mask_gray: uint8 H x W array (255 = mask, 0 = background).
        expand_pixels: Dilation radius; 0 or less returns the input untouched.
        blur: When true, the dilated mask is Gaussian-blurred and then
            thresholded at 127 so the result is binary again.

    Returns:
        The expanded mask as a uint8 H x W array.
    """
    if expand_pixels <= 0:
        return mask_gray

    # Dilation and blur share one odd-sized square window.
    side = expand_pixels * 2 + 1
    window = (side, side)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, window)
    dilated = cv2.dilate(mask_gray, ellipse, iterations=1)
    if not blur:
        return dilated

    softened = cv2.GaussianBlur(dilated, window, 0)
    _, binary = cv2.threshold(softened, 127, 255, cv2.THRESH_BINARY)
    return binary
|
|
def grow_mask_video_file(mask_video_path, expand_pixels=5):
    """Run ``grow_mask_frame`` over every frame of a mask video.

    Args:
        mask_video_path: Path of the mask video to read.
        expand_pixels: Growth radius in pixels.

    Returns:
        Path of a new ``grown_mask.mp4`` in a fresh temp directory, or
        ``mask_video_path`` itself when ``expand_pixels`` is 0 or less.
    """
    if expand_pixels <= 0:
        return mask_video_path

    reader = cv2.VideoCapture(mask_video_path)
    fps = reader.get(cv2.CAP_PROP_FPS) or 16
    frame_w = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_h = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "grown_mask.mp4")
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, codec, fps, (frame_w, frame_h), isColor=False)

    processed = 0
    ok, frame = reader.read()
    while ok:
        # Decoded frames may be 3-channel; reduce them to a single channel.
        if len(frame.shape) == 3:
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        else:
            gray = frame
        writer.write(grow_mask_frame(gray, expand_pixels))
        processed += 1
        ok, frame = reader.read()

    reader.release()
    writer.release()
    print(f"GrowMask applied: {processed} frames, expand={expand_pixels}px → {out_path}")
    return out_path
|
|
def composite_video_from_mask(source_video_path, mask_video_path):
    """Paint the masked region of every source frame white.

    For each source frame the matching mask frame is read; pixels whose mask
    value exceeds 127 are set to 255 in all channels, everything else keeps
    the original content. A mask frame of a different size is resized with
    nearest-neighbour interpolation. When the mask video ends before the
    source does, the remaining frames are copied unchanged.

    Returns:
        Path of the written ``composite.mp4`` in a fresh temp directory.
    """
    source = cv2.VideoCapture(source_video_path)
    mask_reader = cv2.VideoCapture(mask_video_path)

    fps = source.get(cv2.CAP_PROP_FPS) or 16
    w = int(source.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(source.get(cv2.CAP_PROP_FRAME_HEIGHT))

    out_path = os.path.join(tempfile.mkdtemp(), "composite.mp4")
    codec = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, codec, fps, (w, h))

    written = 0
    while True:
        have_src, frame = source.read()
        have_mask, mask_frame = mask_reader.read()
        if not have_src:
            break

        if have_mask:
            if mask_frame.shape[:2] != (h, w):
                mask_frame = cv2.resize(mask_frame, (w, h), interpolation=cv2.INTER_NEAREST)
            if len(mask_frame.shape) == 3:
                mask_gray = cv2.cvtColor(mask_frame, cv2.COLOR_BGR2GRAY)
            else:
                mask_gray = mask_frame
        else:
            # Mask video exhausted: an all-black mask leaves the frame as is.
            mask_gray = np.zeros((h, w), dtype=np.uint8)

        painted = frame.copy()
        painted[mask_gray > 127] = 255
        writer.write(painted)
        written += 1

    source.release()
    mask_reader.release()
    writer.release()
    print(f"Composite video: {written} frames → {out_path}")
    return out_path
|
|
| |
# Hub repository of the VACE pipeline used for the V2V task.
VACE_MODEL_ID = "Wan-AI/Wan2.1-VACE-14B-diffusers"
# Lazily created pipeline and its "fully loaded" flag; both are set by
# load_v2v_pipeline.
v2v_pipe = None
v2v_ready = False
|
|
def load_v2v_pipeline():
    """Return the VACE pipeline on the GPU, creating it on the first call.

    I2V is always moved to the CPU first. A cached pipeline is simply moved
    back to the GPU; otherwise the pipeline is loaded, given a UniPC scheduler
    with ``flow_shift=5.0``, quantized (int8 text encoder; FP8 transformer on
    compute capability 8.9+, int8 otherwise) and placed on the GPU.
    """
    global v2v_pipe, v2v_ready

    # Evict I2V before anything touches the GPU.
    i2v_pipe.to('cpu')
    clear_vram()

    already_loaded = v2v_pipe is not None and v2v_ready
    if already_loaded:
        v2v_pipe.to('cuda')
        print("VACE pipeline restored to GPU")
        return v2v_pipe

    print("Loading VACE 14B Pipeline first time (this downloads ~75GB)...")

    vae = AutoencoderKLWan.from_pretrained(VACE_MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
    v2v_pipe = WanVACEPipeline.from_pretrained(
        VACE_MODEL_ID,
        vae=vae,
        torch_dtype=torch.bfloat16,
    )
    v2v_pipe.scheduler = UniPCMultistepScheduler.from_config(v2v_pipe.scheduler.config, flow_shift=5.0)

    quantize_(v2v_pipe.text_encoder, Int8WeightOnlyConfig())
    major, minor = torch.cuda.get_device_capability()
    fp8_capable = major > 8 or (major == 8 and minor >= 9)
    if fp8_capable:
        transformer_config = Float8DynamicActivationFloat8WeightConfig()
    else:
        transformer_config = Int8WeightOnlyConfig()
    quantize_(v2v_pipe.transformer, transformer_config)

    v2v_pipe.to('cuda')

    v2v_ready = True
    print("VACE 14B pipeline ready (quantized, on GPU)")
    return v2v_pipe
|
|
def unload_v2v_pipeline():
    """Park the VACE pipeline on the CPU (if it exists) and bring I2V back to the GPU."""
    # Read-only access to the module-level pipeline; nothing is rebound here.
    pipeline = v2v_pipe
    if pipeline is not None:
        pipeline.to('cpu')
    clear_vram()
    i2v_pipe.to('cuda')
    print("VACE → CPU, I2V → GPU")
|
|
def load_video_frames_and_masks(video_path, mask_path, num_frames, target_h, target_w):
    """Load up to ``num_frames`` source frames and grayscale mask frames.

    The shorter of the two lists is padded by repeating its last element. If
    it is empty, a black mask of the first source frame's size, or a gray
    ``target_w`` x ``target_h`` RGB frame, is used instead.

    Returns:
        ``(src_frames, masks)``: two lists of equal length. Source frames are
        whatever ``load_video`` yields; masks are converted to mode "L".
    """
    src_frames = load_video(video_path)[:num_frames]
    print(f"Loaded {len(src_frames)} source frames (original size: {src_frames[0].size if src_frames else 'N/A'})")

    raw_masks = load_video(mask_path)[:num_frames]
    masks = [frame.convert("L") for frame in raw_masks]
    print(f"Loaded {len(masks)} mask frames")

    # Pad the mask list up to the number of source frames.
    missing_masks = len(src_frames) - len(masks)
    if missing_masks > 0:
        filler = masks[-1] if masks else Image.new("L", src_frames[0].size, 0)
        masks.extend([filler] * missing_masks)

    # Pad the source list up to the number of masks.
    missing_frames = len(masks) - len(src_frames)
    if missing_frames > 0:
        filler = src_frames[-1] if src_frames else Image.new("RGB", (target_w, target_h), (128, 128, 128))
        src_frames.extend([filler] * missing_frames)

    common = min(len(src_frames), len(masks))
    return src_frames[:common], masks[:common]
|
|
| |
def resize_image(image, max_dim=832, min_dim=480, square_dim=640, multiple_of=16):
    """Resize a PIL image into the size range the pipeline is called with.

    Square images become ``square_dim`` x ``square_dim``. Images wider or
    taller than the ``max_dim``:``min_dim`` ratio are centre-cropped to that
    ratio first. Otherwise the longer side becomes ``max_dim`` and the other
    follows the aspect ratio. Both final sides are rounded to a multiple of
    ``multiple_of`` and clamped to ``[min_dim, max_dim]``.
    """
    width, height = image.size
    if width == height:
        return image.resize((square_dim, square_dim), Image.LANCZOS)

    ratio = width / height
    widest_ratio = max_dim / min_dim
    tallest_ratio = min_dim / max_dim

    if ratio > widest_ratio:
        # Too wide: keep full height, crop the width around the centre.
        keep_w = int(round(height * widest_ratio))
        x0 = (width - keep_w) // 2
        image = image.crop((x0, 0, x0 + keep_w, height))
        target_w, target_h = max_dim, min_dim
    elif ratio < tallest_ratio:
        # Too tall: keep full width, crop the height around the centre.
        keep_h = int(round(width / tallest_ratio))
        y0 = (height - keep_h) // 2
        image = image.crop((0, y0, width, y0 + keep_h))
        target_w, target_h = min_dim, max_dim
    elif width > height:
        target_w = max_dim
        target_h = int(round(target_w / ratio))
    else:
        target_h = max_dim
        target_w = int(round(target_h * ratio))

    def snap(value):
        # Round to the grid, then clamp into the allowed range.
        on_grid = round(value / multiple_of) * multiple_of
        return max(min_dim, min(max_dim, on_grid))

    return image.resize((snap(target_w), snap(target_h)), Image.LANCZOS)
|
|
def resize_and_crop_to_match(target_image, reference_image):
    """Scale ``target_image`` to cover ``reference_image`` and centre-crop it.

    Args:
        target_image: PIL image to adapt.
        reference_image: PIL image whose size the result must have.

    Returns:
        A PIL image of exactly ``reference_image.size``, filled entirely with
        content from ``target_image``.
    """
    ref_w, ref_h = reference_image.size
    tgt_w, tgt_h = target_image.size
    scale = max(ref_w / tgt_w, ref_h / tgt_h)
    # int() truncates and the product can land just below the reference size
    # (e.g. 784 * (512 / 784) == 511.99999999999994 -> 511), which would make
    # the crop box reach past the resized image. Clamp so it always covers.
    new_w = max(ref_w, int(tgt_w * scale))
    new_h = max(ref_h, int(tgt_h * scale))
    resized = target_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    left, top = (new_w - ref_w) // 2, (new_h - ref_h) // 2
    return resized.crop((left, top, left + ref_w, top + ref_h))
|
|
def get_num_frames(duration_seconds):
    """Convert a duration in seconds into a frame count.

    The count at ``FIXED_FPS`` is rounded down to the nearest value of the
    form ``4k + 1`` and then clipped to ``[9, MAX_FRAMES_MODEL]``.
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    aligned = 4 * ((requested - 1) // 4) + 1
    return int(np.clip(aligned, 9, MAX_FRAMES_MODEL))
|
|
def extract_video_path(input_video):
    """Pull a file path out of whatever a video input handed over.

    Handles ``None``, plain strings, dicts and objects. For dicts and objects
    the first of ``video``, ``path``, ``name`` that is present wins; a dict
    with none of them gives ``None``, and any other object is stringified.
    """
    if input_video is None:
        return None
    if isinstance(input_video, str):
        return input_video

    candidates = ("video", "path", "name")
    if isinstance(input_video, dict):
        for key in candidates:
            if key in input_video:
                return input_video[key]
        return None

    for attr in candidates:
        if hasattr(input_video, attr):
            return getattr(input_video, attr)
    return str(input_video)
|
|
def extract_first_frame(video_input):
    """Return the first frame of a video input as an RGB PIL image.

    The input is resolved to a path with ``extract_video_path``. ``None`` is
    returned when there is no path, the file does not exist, or no frame can
    be decoded.
    """
    path = extract_video_path(video_input)
    if not path or not os.path.exists(path):
        return None
    # Decoding and BGR -> RGB conversion are identical to the path-based helper.
    return extract_first_frame_from_video(path)
|
|
| |
def _remove_temp_video(path):
    """Best-effort removal of a temp video file and its then-empty parent directory."""
    try:
        if path and os.path.exists(path):
            os.remove(path)
            # rmdir only succeeds on an empty directory, so shared locations
            # are never wiped.
            os.rmdir(os.path.dirname(path))
    except OSError:
        # Cleanup must never fail a finished generation.
        pass


def _run_t2v(prompt, negative_prompt, num_frames, steps, guidance_scale, guidance_scale_2, seed):
    """Run text-to-video and always hand the GPU back to I2V afterwards."""
    t2v_steps = max(int(steps), 4)
    print(f"T2V: steps={t2v_steps}, guidance={guidance_scale}/{guidance_scale_2}, frames={num_frames}")
    try:
        pipe = load_t2v_pipeline()
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=480,
            width=832,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=t2v_steps,
            generator=torch.Generator(device="cpu").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        # Restore I2V even when loading or generation raised.
        unload_t2v_pipeline()


def _run_v2v(input_video, mask_video, prompt, negative_prompt, num_frames, steps,
             guidance_scale, seed, reference_image, grow_pixels):
    """Run the GrowMask -> composite -> VACE steps of the V2V task."""
    print(f"V2V 3-Step Pipeline: input_video type={type(input_video)}, value={input_video}")
    video_path = extract_video_path(input_video)
    if not video_path or not os.path.exists(video_path):
        raise gr.Error("Upload a source video for V2V / V2V请上传原视频")

    mask_path = extract_video_path(mask_video)
    if not mask_path or not os.path.exists(mask_path):
        raise gr.Error("Upload a mask video for V2V / V2V请上传遮罩视频(黑白视频,白色=编辑区域)")

    temp_videos = []
    try:
        grown_mask_path = grow_mask_video_file(mask_path, expand_pixels=int(grow_pixels))
        # With no growth the helper returns the uploaded mask itself; that
        # file belongs to the user and must not be deleted during cleanup.
        if grown_mask_path != mask_path:
            temp_videos.append(grown_mask_path)
        print(f"V2V: GrowMask applied ({grow_pixels}px)")

        composite_path = composite_video_from_mask(video_path, mask_path)
        temp_videos.append(composite_path)
        print(f"V2V: Composite video created")

        target_h, target_w = 480, 832

        src_frames = load_video(composite_path)[:num_frames]
        print(f"Loaded {len(src_frames)} composite frames")
        if not src_frames:
            raise gr.Error("无法读取视频帧 / Could not read any frames from the source video")

        masks = [mf.convert("L") for mf in load_video(grown_mask_path)[:num_frames]]
        print(f"Loaded {len(masks)} grown mask frames")

        # Make both lists the same length by repeating the last element.
        while len(masks) < len(src_frames):
            masks.append(masks[-1] if masks else Image.new("L", src_frames[0].size, 0))
        while len(src_frames) < len(masks):
            src_frames.append(src_frames[-1])

        # Frame count of the form 4k + 1, at least 5; clips shorter than that
        # are padded with their last frame so the count stays valid.
        n = (len(src_frames) - 1) // 4 * 4 + 1
        n = max(n, 5)
        while len(src_frames) < n:
            src_frames.append(src_frames[-1])
            masks.append(masks[-1])
        src_frames = src_frames[:n]
        masks = masks[:n]

        v2v_steps = max(int(steps), 20)
        print(f"V2V VACE: steps={v2v_steps}, guidance={guidance_scale}, frames={len(src_frames)}, ref_image={'yes' if reference_image else 'no'}")

        vace_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            video=src_frames,
            mask=masks,
            height=target_h,
            width=target_w,
            num_frames=len(src_frames),
            guidance_scale=max(float(guidance_scale), 5.0),
            num_inference_steps=v2v_steps,
            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
            output_type="np",
        )
        # Pass the reference image on to VACE instead of silently dropping it.
        if reference_image is not None:
            vace_kwargs["reference_images"] = [reference_image]

        try:
            pipe = load_v2v_pipeline()
            return pipe(**vace_kwargs)
        finally:
            # Restore I2V even when loading or generation raised.
            unload_v2v_pipeline()
    finally:
        for path in temp_videos:
            _remove_temp_video(path)


def _run_i2v(input_image, last_image_input, prompt, negative_prompt, num_frames, steps,
             guidance_scale, guidance_scale_2, seed, scheduler_name, flow_shift, lora_groups):
    """Run image-to-video on the resident I2V pipeline, with optional LoRAs."""
    if input_image is None:
        raise gr.Error("Upload an image / 请上传图片")

    # Swap the scheduler only when a different class was selected.
    scheduler_class = SCHEDULER_MAP.get(scheduler_name)
    if scheduler_class and scheduler_class.__name__ != i2v_pipe.scheduler.config._class_name:
        config = copy.deepcopy(i2v_original_scheduler.config)
        if scheduler_class == FlowMatchEulerDiscreteScheduler:
            config['shift'] = flow_shift
        else:
            config['flow_shift'] = flow_shift
        i2v_pipe.scheduler = scheduler_class.from_config(config)

    lora_loaded = False
    try:
        if lora_groups:
            try:
                for idx, name in enumerate(lora_groups):
                    if name and name != "(None)":
                        lora_loader.load_lora_to_pipe(i2v_pipe, name, adapter_name=f"lora_{idx}")
                        lora_loaded = True
            except Exception as e:
                # A broken LoRA must not block generation; continue without it.
                print(f"LoRA warning: {e}")

        resized_image = resize_image(input_image)
        processed_last = None
        if last_image_input:
            processed_last = resize_and_crop_to_match(last_image_input, resized_image)

        print(f"I2V: size={resized_image.size}, steps={int(steps)}, guidance={guidance_scale}/{guidance_scale_2}")

        return i2v_pipe(
            image=resized_image,
            last_image=processed_last,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(int(seed)),
            output_type="np",
        )
    finally:
        # Never leave request-specific adapters on the shared pipeline.
        if lora_loaded:
            lora_loader.unload_lora(i2v_pipe)


@spaces.GPU(duration=1200)
def run_inference(
    task_type, input_image, input_video, mask_video, prompt, negative_prompt,
    duration_seconds, steps, guidance_scale, guidance_scale_2,
    current_seed, scheduler_name, flow_shift, frame_multiplier,
    quality, last_image_input, lora_groups,
    reference_image=None, grow_pixels=5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video for the selected task and export it to an MP4 file.

    The task is chosen by substring: ``"T2V"`` in ``task_type`` runs
    text-to-video, ``"V2V"`` runs the mask-based VACE edit, anything else runs
    image-to-video. When ``frame_multiplier // FIXED_FPS`` exceeds 1 the
    result is RIFE-interpolated by that factor before export.

    Args:
        task_type: Task label from the UI.
        input_image, last_image_input: First / optional last frame (I2V).
        input_video, mask_video: Source and mask video inputs (V2V).
        prompt, negative_prompt: Text conditioning.
        duration_seconds: Requested length; converted with ``get_num_frames``.
        steps: Inference steps (raised to at least 4 for T2V, 20 for V2V).
        guidance_scale, guidance_scale_2: Guidance values (V2V uses only the
            first, raised to at least 5.0).
        current_seed: Seed for the generator.
        scheduler_name, flow_shift: Scheduler choice for I2V.
        frame_multiplier: Target output frame rate.
        quality: Forwarded to ``export_to_video``.
        lora_groups: LoRA names to load for I2V.
        reference_image: Optional reference image forwarded to VACE (V2V).
        grow_pixels: Mask growth radius for V2V.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, task_id)``.

    Raises:
        gr.Error: When the inputs required by the selected task are missing
            or unreadable.
    """
    clear_vram()
    num_frames = get_num_frames(duration_seconds)
    task_id = str(uuid.uuid4())[:8]
    print(f"Task: {task_id}, type={task_type}, duration={duration_seconds}s, frames={num_frames}")
    start = time.time()

    if "T2V" in task_type:
        result = _run_t2v(
            prompt, negative_prompt, num_frames, steps,
            guidance_scale, guidance_scale_2, current_seed,
        )
    elif "V2V" in task_type:
        result = _run_v2v(
            input_video, mask_video, prompt, negative_prompt, num_frames, steps,
            guidance_scale, current_seed, reference_image, grow_pixels,
        )
    else:
        result = _run_i2v(
            input_image, last_image_input, prompt, negative_prompt, num_frames, steps,
            guidance_scale, guidance_scale_2, current_seed, scheduler_name, flow_shift,
            lora_groups,
        )

    raw_frames = result.frames[0]
    elapsed = time.time() - start
    print(f"Generation took {elapsed:.1f}s ({len(raw_frames)} frames)")

    # Interpolation factor implied by the requested output frame rate.
    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames, multiplier=int(frame_factor))
    else:
        final_frames = list(raw_frames)
    final_fps = FIXED_FPS * max(1, frame_factor)

    # Reserve a persistent temp file name, then let the exporter write to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
    return video_path, task_id
|
|
| |
def generate_video(
    task_type, input_image, input_video, mask_video, prompt,
    lora_groups, duration_seconds, frame_multiplier,
    steps, guidance_scale, guidance_scale_2,
    negative_prompt, quality, seed, randomize_seed,
    scheduler, flow_shift, last_image, display_result,
    reference_image, grow_pixels,
    progress=gr.Progress(track_tqdm=True),
):
    """Validate the prompt, pick the seed and delegate to ``run_inference``.

    Returns:
        ``(video_for_display, video_path, seed_used)``; the first element is
        ``None`` when ``display_result`` is falsy.

    Raises:
        gr.Error: If the prompt is empty or only whitespace.
    """
    has_prompt = bool(prompt) and bool(prompt.strip())
    if not has_prompt:
        raise gr.Error("Enter a prompt / 请输入提示词")

    if randomize_seed:
        current_seed = random.randint(0, MAX_SEED)
    else:
        current_seed = int(seed)

    video_path, task_id = run_inference(
        task_type, input_image, input_video, mask_video, prompt, negative_prompt,
        duration_seconds, steps, guidance_scale, guidance_scale_2,
        current_seed, scheduler, flow_shift, frame_multiplier,
        quality, last_image, lora_groups,
        reference_image=reference_image, grow_pixels=grow_pixels,
    )
    print(f"Done: {task_id}")

    shown = video_path if display_result else None
    return shown, video_path, current_seed
|
|
| |
# Stylesheet passed to gr.Blocks: collapses the element with id
# "hidden-timestamp" to zero size and full transparency while keeping it in
# the page.
CSS = """
#hidden-timestamp { opacity: 0; height: 0; width: 0; margin: 0; padding: 0; overflow: hidden; position: absolute; }
"""
|
|
| with gr.Blocks(theme=gr.themes.Soft(), css=CSS, delete_cache=(3600, 10800)) as demo: |
| gr.Markdown("## WAN 2.2 Multi-Task Video Generation / 多任务视频生成") |
| gr.Markdown("#### I2V (Lightning 6-step) · T2V (Lightning 14B 4-step) · V2V (3-Step: SAM2→Composite→VACE)") |
| gr.Markdown("---") |
|
|
| task_type = gr.Radio( |
| choices=[ |
| "I2V (图生视频 / Image-to-Video)", |
| "T2V (文生视频 / Text-to-Video)", |
| "V2V (视频生视频 / Video-to-Video)", |
| ], |
| value="I2V (图生视频 / Image-to-Video)", |
| label="Task Type / 任务类型", |
| ) |
|
|
| with gr.Row(): |
| with gr.Column(): |
| with gr.Group(): |
| input_image = gr.Image(type="pil", label="Input Image / 输入图片 (I2V)", sources=["upload", "clipboard"]) |
| with gr.Group(): |
| input_video = gr.Video(label="Source Video / 原视频 (V2V)", sources=["upload"], visible=False, interactive=True) |
| with gr.Group(): |
| mask_video = gr.Video(label="Mask Video / 遮罩视频 (V2V, 白色=编辑区域)", sources=["upload"], visible=False, interactive=True) |
| v2v_guide = gr.Markdown( |
| value="""### 📖 V2V 三步流水线 / 3-Step V2V Pipeline |
| |
| **Step 1 — SAM2 分割**: 上传原视频 → 提取第一帧 → 点击标记区域 → 生成遮罩视频 |
| **Step 2 — 自动合成**: 原视频 + 遮罩 → GrowMask扩展边界 + ImageComposite合成(自动完成) |
| **Step 3 — VACE 生成**: 合成视频 + 遮罩 + 参考图 + Prompt → 最终成品视频 |
| |
| 💡 也可跳过 Step 1,直接上传自己的遮罩视频(白色=编辑区域) |
| """, |
| visible=False, |
| ) |
| with gr.Group(visible=False) as v2v_mask_tools: |
| first_frame_display = gr.Image(label="第一帧预览 / First Frame (点击标记区域)", type="pil", interactive=True) |
| points_store = gr.State(value=[]) |
| points_display = gr.Textbox(label="标记点 / Points", value="无标记 / No points", interactive=False) |
| with gr.Row(): |
| point_mode = gr.Radio(choices=["include (编辑)", "exclude (排除)"], value="include (编辑)", label="点击模式") |
| with gr.Row(): |
| extract_frame_btn = gr.Button("📷 提取第一帧 / Extract First Frame", variant="secondary") |
| gen_mask_btn = gr.Button("🎭 生成遮罩 / Generate Mask (SAM2)", variant="primary") |
| clear_points_btn = gr.Button("🗑️ 清除标记 / Clear Points") |
| with gr.Accordion("🖼️ V2V 高级选项 / V2V Advanced", open=True): |
| reference_image = gr.Image(type="pil", label="参考图 / Reference Image (控制编辑区域的目标外观)", sources=["upload", "clipboard"]) |
| grow_pixels_sl = gr.Slider(minimum=0, maximum=30, step=1, value=5, label="GrowMask / 遮罩扩展 (像素)", info="扩展遮罩边界,让编辑区域过渡更自然") |
|
|
| prompt_input = gr.Textbox( |
| label="Prompt / 提示词", value="", |
| placeholder="Describe the video... / 描述你想生成的视频...", lines=3, |
| ) |
| duration_slider = gr.Slider( |
| minimum=0.5, maximum=15, step=0.5, value=3, |
| label="Duration / 时长 (seconds/秒)", |
| info="Max ~15s (241 frames @16fps) / 最大约15秒", |
| ) |
| frame_multi = gr.Dropdown(choices=[16, 32, 64], value=16, label="Output FPS / 输出帧率", info="RIFE interpolation / RIFE插帧") |
|
|
# Collapsed-by-default panel holding the sampling parameters.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below (which widgets sit in which Row, and that lora_dd / display_cb belong
# to the accordion) is reconstructed from statement order - confirm.
with gr.Accordion("⚙️ Advanced Settings / 高级设置", open=False):
    last_image = gr.Image(
        type="pil",
        label="Last Frame / 末帧 (Optional)",
        sources=["upload", "clipboard"],
    )
    negative_prompt_input = gr.Textbox(
        label="Negative Prompt / 负面提示词",
        value=default_negative_prompt,
        lines=3,
    )

    # Step count and quality.
    with gr.Row():
        steps_slider = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=6,
            label="Steps / 步数",
            info="I2V: 4-8 | T2V: 4-8 | V2V: 25-50",
        )
        quality_sl = gr.Slider(
            minimum=1,
            maximum=10,
            step=1,
            value=6,
            label="Quality / 质量",
        )

    # Guidance scales for the high-noise and low-noise stages.
    with gr.Row():
        guidance_h = gr.Slider(
            minimum=0.0,
            maximum=10.0,
            step=0.5,
            value=1.0,
            label="Guidance High / 引导(高噪声)",
        )
        guidance_l = gr.Slider(
            minimum=0.0,
            maximum=10.0,
            step=0.5,
            value=1.0,
            label="Guidance Low / 引导(低噪声)",
        )

    # Scheduler choice (names come from SCHEDULER_MAP) and its flow shift.
    with gr.Row():
        scheduler_dd = gr.Dropdown(
            choices=list(SCHEDULER_MAP.keys()),
            value="UniPCMultistep",
            label="Scheduler / 调度器",
        )
        flow_shift_sl = gr.Slider(
            minimum=0.5,
            maximum=15.0,
            step=0.1,
            value=3.0,
            label="Flow Shift / 流偏移",
        )

    # Seed controls; the generate handler also writes a seed back to seed_sl.
    with gr.Row():
        seed_sl = gr.Slider(
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=42,
            label="Seed / 种子",
        )
        random_seed_cb = gr.Checkbox(label="Random / 随机", value=True)

    lora_dd = gr.Dropdown(
        choices=lora_loader.get_lora_choices(),
        label="LoRA (I2V only / 仅I2V)",
        multiselect=True,
        info="From WAN2.2_LoraSet_NSFW",
    )
    display_cb = gr.Checkbox(label="Display / 显示", value=True)


generate_btn = gr.Button(
    "🎬 Generate / 生成视频",
    variant="primary",
    size="lg",
)
|
|
# Output column: the generated video, a frame-grab control and a download.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below is reconstructed from statement order - confirm.
with gr.Column():
    # elem_id is what get_timestamp_js uses to locate the <video> element.
    video_output = gr.Video(
        label="Generated Video / 生成的视频",
        autoplay=True,
        sources=["upload"],
        show_download_button=True,
        show_share_button=True,
        interactive=False,
        elem_id="generated-video",
    )
    with gr.Row():
        grab_frame_btn = gr.Button("📸 Use Frame / 使用帧", variant="secondary")
        # Hidden field that receives the value produced by get_timestamp_js.
        timestamp_box = gr.Number(
            value=0,
            label="Timestamp",
            visible=False,
            elem_id="hidden-timestamp",
        )
    file_output = gr.File(label="Download / 下载")
|
|
def update_task_ui(task):
    """Build the Gradio updates that adapt the inputs to the selected task.

    Args:
        task: Label of the selected task; it is searched for the substrings
            "T2V" and "V2V". Anything else is treated as the image-to-video
            case.

    Returns:
        An 8-tuple of ``gr.update`` objects, in the order the
        ``task_type.change`` listener lists its outputs: input_image,
        input_video, mask_video, v2v_guide, v2v_mask_tools, steps_slider,
        guidance_h, guidance_l.
    """
    # T2V is tested first, so it wins should a label contain both markers.
    if "T2V" in task:
        show_image, show_v2v = False, False
        steps, guidance_high = 4, 1.0
    elif "V2V" in task:
        show_image, show_v2v = False, True
        steps, guidance_high = 30, 5.0
    else:
        show_image, show_v2v = True, False
        steps, guidance_high = 6, 1.0

    # The four V2V-only components are always shown or hidden together.
    v2v_updates = tuple(gr.update(visible=show_v2v) for _ in range(4))
    return (
        gr.update(visible=show_image),
        *v2v_updates,
        gr.update(value=steps),
        gr.update(value=guidance_high),
        gr.update(value=1.0),  # low-noise guidance is 1.0 for every task
    )
|
|
# Re-run update_task_ui whenever the task selector changes; the outputs list
# must stay in the same order as the tuple that function returns.
task_type.change(
    fn=update_task_ui,
    inputs=[task_type],
    outputs=[
        input_image,
        input_video,
        mask_video,
        v2v_guide,
        v2v_mask_tools,
        steps_slider,
        guidance_h,
        guidance_l,
    ],
)
|
|
| |
def on_extract_first_frame(video):
    """Load the first frame of the uploaded video and reset the markers.

    Args:
        video: Value of the video input; resolved to a file path with
            ``extract_video_path``.

    Returns:
        ``(frame, [], status)`` - the extracted frame, an empty point list
        and the "no points" status text.

    Raises:
        gr.Error: If no existing video file can be resolved, or if
            ``extract_first_frame_from_video`` returns None.
    """
    video_path = extract_video_path(video)
    has_file = bool(video_path) and os.path.exists(video_path)
    if not has_file:
        raise gr.Error("请先上传视频 / Upload video first")

    first_frame = extract_first_frame_from_video(video_path)
    if first_frame is None:
        raise gr.Error("无法提取第一帧 / Failed to extract first frame")

    return first_frame, [], "无标记 / No points"
|
|
def on_click_frame(img, points, mode, evt: gr.SelectData):
    """Record a click on the first-frame preview and redraw all markers.

    Args:
        img: Image currently shown in the preview, or None when no frame has
            been extracted yet. It is copied and drawn on with PIL's
            ImageDraw.
        points: Previously stored markers, each a dict with "x", "y" and
            "label" keys. May be empty or None.
        mode: Point-mode label. A label containing "include" stores label 1;
            any other value (including None) stores label 0.
        evt: Gradio select event; ``evt.index`` supplies the click (x, y).

    Returns:
        ``(image, points, info)`` - the annotated copy of ``img``, the
        updated point list and an "N include, M exclude" summary. When
        ``img`` is None the inputs are returned unchanged with a hint text.
    """
    if img is None:
        return img, points, "请先提取第一帧 / Extract first frame first"

    # Import the submodule explicitly: importing the ``PIL`` package alone
    # does not guarantee that ``PIL.ImageDraw`` has been loaded, so attribute
    # access on the package can fail.
    from PIL import ImageDraw

    x, y = evt.index
    # Normalise the store before measuring it, so a None value cannot crash
    # the debug print below.
    points = list(points) if points else []
    label = 1 if mode and "include" in mode else 0
    points.append({"x": x, "y": y, "label": label})
    print(f"[DEBUG] Click at ({x}, {y}), mode={mode}, total_points={len(points)}")

    # Draw on a copy so the caller's image object is left untouched.
    display_img = img.copy()
    draw = ImageDraw.Draw(display_img)
    radius = 8
    for p in points:
        # (0, 255, 0) marks include points, (255, 0, 0) marks exclude points.
        color = (0, 255, 0) if p["label"] == 1 else (255, 0, 0)
        bounds = [p["x"] - radius, p["y"] - radius, p["x"] + radius, p["y"] + radius]
        draw.ellipse(bounds, fill=color, outline="white", width=2)

    include_count = sum(1 for p in points if p["label"] == 1)
    exclude_count = sum(1 for p in points if p["label"] == 0)
    info = f"{include_count} include, {exclude_count} exclude"
    return display_img, points, info
|
|
def on_clear_points(original_video):
    """Drop all markers and restore the unannotated first frame.

    Args:
        original_video: Value of the video input; resolved to a file path
            with ``extract_video_path``.

    Returns:
        ``(frame, [], status)``. ``frame`` is whatever
        ``extract_first_frame_from_video`` returns for the video, or None
        when no existing video file can be resolved.
    """
    status = "无标记 / No points"
    video_path = extract_video_path(original_video)
    if not video_path or not os.path.exists(video_path):
        return None, [], status
    return extract_first_frame_from_video(video_path), [], status
|
|
def on_generate_mask(video, points):
    """Run mask generation for the uploaded video from the clicked points.

    Args:
        video: Value of the video input; resolved to a file path with
            ``extract_video_path``.
        points: Stored click markers; serialised to JSON before being handed
            to ``generate_mask_video``.

    Returns:
        Whatever ``generate_mask_video`` returns (used as the value of the
        mask video component).

    Raises:
        gr.Error: If no video path can be resolved or no points were marked.
    """
    import json

    print(f"[DEBUG] on_generate_mask called, points type={type(points)}, value={points}")

    video_path = extract_video_path(video)
    if not video_path:
        raise gr.Error("请先上传视频 / Upload video first")
    if not points:
        raise gr.Error("请先在第一帧上点击标记 / Click on first frame to mark areas")

    # generate_mask_video takes the points as a JSON string.
    serialized_points = json.dumps(points)
    return generate_mask_video(video_path, serialized_points)
|
|
# --- Mask tools: extract frame, place / clear markers, run mask generation ---
extract_frame_btn.click(
    fn=on_extract_first_frame,
    inputs=[input_video],
    outputs=[first_frame_display, points_store, points_display],
)
first_frame_display.select(
    fn=on_click_frame,
    inputs=[first_frame_display, points_store, point_mode],
    outputs=[first_frame_display, points_store, points_display],
)
clear_points_btn.click(
    fn=on_clear_points,
    inputs=[input_video],
    outputs=[first_frame_display, points_store, points_display],
)
gen_mask_btn.click(
    fn=on_generate_mask,
    inputs=[input_video, points_store],
    outputs=[mask_video],
)

# --- Main generation; the inputs are passed to generate_video positionally,
# so their order here has to match its parameter order ---
generate_btn.click(
    fn=generate_video,
    inputs=[
        task_type,
        input_image,
        input_video,
        mask_video,
        prompt_input,
        lora_dd,
        duration_slider,
        frame_multi,
        steps_slider,
        guidance_h,
        guidance_l,
        negative_prompt_input,
        quality_sl,
        seed_sl,
        random_seed_cb,
        scheduler_dd,
        flow_shift_sl,
        last_image,
        display_cb,
        reference_image,
        grow_pixels_sl,
    ],
    outputs=[video_output, file_output, seed_sl],
)

# --- Frame grab: the button runs only browser-side JS (fn=None) that fills
# the hidden timestamp box; the box's change event then calls extract_frame
# and puts the result into the image input ---
grab_frame_btn.click(
    fn=None,
    inputs=None,
    outputs=[timestamp_box],
    js=get_timestamp_js,
)
timestamp_box.change(
    fn=extract_frame,
    inputs=[video_output, timestamp_box],
    outputs=[input_image],
)
|
|
if __name__ == "__main__":
    # Enable request queuing first, then start the server with the MCP
    # endpoint turned on and handler errors shown in the UI.
    queued_demo = demo.queue()
    queued_demo.launch(mcp_server=True, show_error=True)
|
|