import spaces
import torch
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc

from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

import aoti


MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16

MAX_SEED = np.iinfo(np.int32).max

FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 80

MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

device = "cuda"

# -------------------- NSFW detector loading --------------------
# Best effort: if loading fails, both handles are None and every NSFW check
# below becomes a no-op that reports "not NSFW".
try:
    print("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    print("NSFW detector loaded successfully.")
except Exception as e:
    print(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
# ----------------------------------------------------------------


class GenerationError(Exception):
    """Custom exception for generation errors"""
    pass


def _nsfw_label_index() -> int:
    """Return the classifier output index of the "nsfw" label.

    Looks the label up in the model config's ``label2id`` (case-insensitive)
    and falls back to index 1 if no such label is declared.
    """
    config = getattr(nsfw_model, "config", None)
    label2id = getattr(config, "label2id", None) or {}
    for label, idx in label2id.items():
        if str(label).lower() == "nsfw":
            return int(idx)
    return 1


def detect_image_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if the detector's NSFW probability for ``image`` exceeds ``threshold``.

    :param image: PIL image to classify (any mode; converted to RGB first)
    :param threshold: probability above which the image is flagged
    :return: True if flagged. Returns False when the detector failed to load.
    """
    if nsfw_model is None or nsfw_processor is None:
        return False
    # Normalize to 3 channels: the processor may not handle RGBA / L / P input.
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    nsfw_score = probs[0][_nsfw_label_index()].item()
    return nsfw_score > threshold


def _frame_to_pil(frame) -> Image.Image:
    """Convert one video frame (PIL image, torch tensor or numpy array) to an RGB PIL image.

    - A 3-D tensor whose first axis has size 3 is treated as [C, H, W] and
      permuted to [H, W, C].
    - Floating-point data whose maximum is <= 1.05 is treated as [0, 1] and
      scaled to [0, 255].
    - Values are clipped to [0, 255] before the uint8 cast, so out-of-range
      values saturate instead of wrapping around.
    """
    if isinstance(frame, Image.Image):
        return frame.convert("RGB")

    if torch.is_tensor(frame):
        arr = frame.detach().cpu().float()
        if arr.ndim == 3 and arr.shape[0] == 3:
            arr = arr.permute(1, 2, 0)
        arr = arr.numpy()
    else:
        arr = np.asarray(frame)

    if np.issubdtype(arr.dtype, np.floating):
        if arr.size and arr.max() <= 1.05:
            arr = arr * 255.0
        arr = np.rint(arr)
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr).convert("RGB")


def detect_frames_nsfw(frames, sample_n=5, threshold=0.5):
    """Check an evenly spaced sample of video frames for NSFW content.

    :param frames: indexable sequence of frames (PIL images, torch tensors or
        numpy arrays), e.g. the ``.frames[0]`` output of the pipeline
    :param sample_n: number of frames to sample uniformly across the clip, so
        that not every frame has to be classified
    :param threshold: NSFW probability above which a frame is flagged
    :return: True as soon as any sampled frame is flagged. Returns False if none
        are flagged, if ``frames`` is empty or ``sample_n <= 0``, or if the
        detector is unavailable.
    """
    if nsfw_model is None or nsfw_processor is None:
        return False

    num_frames = len(frames)
    if num_frames == 0 or sample_n <= 0:
        return False

    # Capping the sample count at num_frames keeps the indices distinct
    # (spacing >= 1), so no frame is classified twice.
    sample_indices = np.linspace(0, num_frames - 1, min(sample_n, num_frames), dtype=int)
    return any(
        detect_image_nsfw(_frame_to_pil(frames[int(idx)]), threshold)
        for idx in sample_indices
    )


def _load_transformer(subfolder: str) -> WanTransformer3DModel:
    """Load one bf16 Wan2.2 transformer from the given repo subfolder onto ``device``."""
    return WanTransformer3DModel.from_pretrained(
        "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers",
        subfolder=subfolder,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )


pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=_load_transformer("transformer"),
    transformer_2=_load_transformer("transformer_2"),
    torch_dtype=torch.bfloat16,
).to(device)

# The same LoRA file (a CFG/step-distillation LoRA, per its filename) is
# loaded once per transformer, under separate adapter names.
LIGHTX2V_REPO = "Kijai/WanVideo_comfy"
LIGHTX2V_WEIGHTS = "Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors"

pipe.load_lora_weights(
    LIGHTX2V_REPO,
    weight_name=LIGHTX2V_WEIGHTS,
    adapter_name="lightx2v",
)
pipe.load_lora_weights(
    LIGHTX2V_REPO,
    weight_name=LIGHTX2V_WEIGHTS,
    adapter_name="lightx2v_2",
    load_into_transformer_2=True,
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipe.unload_lora_weights()

# Quantization: int8 weight-only for the text encoder, fp8 dynamic-activation +
# fp8 weights for both transformers.
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())

# Load precompiled transformer blocks for the fp8 variant
# (presumably AOTInductor artifacts, given the module name — confirm).
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')


default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"


def resize_image(image: Image.Image) -> Image.Image:
    """
    Resize an image to fit the model's resolution constraints, preserving the
    aspect ratio as much as possible.

    - Square images become SQUARE_DIM x SQUARE_DIM.
    - Images wider than MAX_DIM:MIN_DIM (or taller than MIN_DIM:MAX_DIM) are
      center-cropped to that extreme ratio first.
    - Otherwise the longer side is set to MAX_DIM and the shorter side scaled.

    Both output sides are then rounded to a multiple of MULTIPLE_OF and clamped
    to [MIN_DIM, MAX_DIM].
    """
    width, height = image.size

    # Handle square case
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    aspect_ratio = width / height
    MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
    MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM

    image_to_resize = image

    if aspect_ratio > MAX_ASPECT_RATIO:
        # Very wide image -> crop width to fit 832x480 aspect ratio
        target_w, target_h = MAX_DIM, MIN_DIM
        crop_width = int(round(height * MAX_ASPECT_RATIO))
        left = (width - crop_width) // 2
        image_to_resize = image.crop((left, 0, left + crop_width, height))
    elif aspect_ratio < MIN_ASPECT_RATIO:
        # Very tall image -> crop height to fit 480x832 aspect ratio
        target_w, target_h = MIN_DIM, MAX_DIM
        crop_height = int(round(width / MIN_ASPECT_RATIO))
        top = (height - crop_height) // 2
        image_to_resize = image.crop((0, top, width, top + crop_height))
    else:
        if width > height:  # Landscape
            target_w = MAX_DIM
            target_h = int(round(target_w / aspect_ratio))
        else:  # Portrait
            target_h = MAX_DIM
            target_w = int(round(target_h * aspect_ratio))

    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF

    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))

    return image_to_resize.resize((final_w, final_h), Image.LANCZOS)


def get_num_frames(duration_seconds: float) -> int:
    """Frame count for a duration: 1 + round(duration * FIXED_FPS), clamped to the model range."""
    return 1 + int(np.clip(
        int(round(duration_seconds * FIXED_FPS)),
        MIN_FRAMES_MODEL,
        MAX_FRAMES_MODEL,
    ))


def get_duration(
    input_image,
    prompt,
    steps,
    negative_prompt,
    duration_seconds,
    guidance_scale,
    guidance_scale_2,
    seed,
    randomize_seed,
):
    """Estimate the GPU time (seconds) to request for a `_generate_video` call.

    Takes the same arguments as `_generate_video`. The per-step cost scales
    with (frames * width * height) relative to an 81x832x624 baseline, raised
    to the power 1.5.
    """
    # No image: return a small estimate, since _generate_video fails fast.
    if input_image is None:
        return 10

    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15
    width, height = resize_image(input_image).size
    frames = get_num_frames(duration_seconds)
    factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * factor ** 1.5
    return 10 + int(steps) * step_duration


# NOTE(review): Gradio normally injects progress trackers via a `gr.Progress()`
# default argument on the event handler; a module-level instance may not report
# progress — confirm it works under ZeroGPU.
progress = gr.Progress()


@spaces.GPU(duration=get_duration)
def _generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """Run the image-to-video pipeline on the GPU.

    Failures never propagate; they are returned as a status dict so the
    non-GPU wrapper can turn them into a `gr.Error`.

    Returns:
        ({"status": "success"}, video_path, seed_used) on success, or
        ({"error": message, "status": "failed"}, None, None) on failure.
    """
    progress(0, desc="Starting")

    def callback_fn(_pipe, step, timestep, callback_kwargs):
        # Invoked by the pipeline at the end of each denoising step.
        print(f"[Step {step}] Timestep: {timestep}")
        progress((step + 1.0) / num_steps, desc=f"Video generating, {step + 1}/{num_steps} steps")
        return callback_kwargs

    try:
        if input_image is None:
            # GenerationError (not gr.Error): str(gr.Error(msg)) is repr(msg),
            # which would show the message wrapped in quotes.
            raise GenerationError("Please upload an input image.")

        # Slider values may arrive as floats; the pipeline and progress text need an int.
        num_steps = int(steps)

        # NSFW check on the input image
        if nsfw_model is not None and nsfw_processor is not None and detect_image_nsfw(input_image):
            raise GenerationError(
                "The input image contains NSFW content and cannot be used. "
                "Please use a different image and try again."
            )

        num_frames = get_num_frames(duration_seconds)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        resized_image = resize_image(input_image)
        if resized_image.mode != "RGB":
            # Feed the pipeline a 3-channel image regardless of the upload's mode.
            resized_image = resized_image.convert("RGB")

        output_frames_list = pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=num_steps,
            generator=torch.Generator(device=device).manual_seed(current_seed),
            callback_on_step_end=callback_fn,
        ).frames[0]

        # NSFW check on the generated frames
        if nsfw_model is not None and nsfw_processor is not None:
            # Roughly one sampled frame per half second of video (at least one).
            sample_count = max(1, int(len(output_frames_list) // (FIXED_FPS / 2)))
            if detect_frames_nsfw(output_frames_list, sample_count, 0.6):
                raise GenerationError(
                    "Generated video contains NSFW content and cannot be displayed. "
                    "Please modify the input image or prompt and try again."
                )

        # Create the delete=False temp file only once the output is accepted,
        # so rejected generations don't leave empty files on disk.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

        progress(1, desc="Complete")
        return {"status": "success"}, video_path, current_seed

    except GenerationError as e:
        # Expected, user-facing failure: pass the message through as-is.
        return {"error": str(e), "status": "failed"}, None, None
    except Exception as e:
        # Unexpected failure: keep the traceback in the logs instead of losing it.
        import traceback
        traceback.print_exc()
        return {"error": str(e), "status": "failed"}, None, None
    finally:
        # Collect Python garbage first so freed tensors can actually be returned
        # to the CUDA allocator by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()


def generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """UI handler: call the GPU function and convert failures into `gr.Error`.

    Returns:
        (video_path, seed_used), matching the `[video_output, seed_input]` outputs.

    Raises:
        gr.Error: with the failure message when generation failed.
    """
    info, video_path, current_seed = _generate_video(
        input_image,
        prompt,
        steps,
        negative_prompt,
        duration_seconds,
        guidance_scale,
        guidance_scale_2,
        seed,
        randomize_seed,
    )

    if info["status"] == "failed":
        raise gr.Error(info["error"])

    # Return the generated video and the seed that was used
    return video_path, current_seed


css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

title = "# AI Video Maker"
description = "Bring your imagination to motion with AI Video Maker. Simply upload a photo and provide a prompt to generate smooth, high-quality animations for your content creation. "
note = (
    "*Note: This demo has a daily usage cap. If you have reached the limit or need faster rendering, "
    "please visit [AI Image to Video](https://www.imgtovideo.ai/) to continue animating without interruptions.*"
)

with gr.Blocks(css=css).queue() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(note)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                input_image_component = gr.Image(type="pil", label="Input Image")
                prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
                duration_seconds_input = gr.Slider(
                    minimum=MIN_DURATION,
                    maximum=MAX_DURATION,
                    step=0.1,
                    value=3.5,
                    label="Duration (seconds)",
                    info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                    steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
                    guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
                    guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")

                generate_button = gr.Button("Generate Video", variant="primary")

            with gr.Column():
                gr.Markdown("### Output")
                video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)

        ui_inputs = [
            input_image_component,
            prompt_input,
            steps_slider,
            negative_prompt_input,
            duration_seconds_input,
            guidance_scale_input,
            guidance_scale_2_input,
            seed_input,
            randomize_seed_checkbox,
        ]
        generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

        gr.Examples(
            examples=[
                [
                    "wan_i2v_input.JPG",
                    "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat's face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
                    4,
                ],
                [
                    "wan22_input_2.jpg",
                    "A sleek lunar vehicle glides into view from left to right, kicking up moon dust as astronauts in white spacesuits hop aboard with characteristic lunar bouncing movements. In the distant background, a VTOL craft descends straight down and lands silently on the surface. Throughout the entire scene, ethereal aurora borealis ribbons dance across the star-filled sky, casting shimmering curtains of green, blue, and purple light that bathe the lunar landscape in an otherworldly, magical glow.",
                    4,
                ],
                [
                    "kill_bill.jpeg",
                    "Uma Thurman's character, Beatrix Kiddo, holds her razor-sharp katana blade steady in the cinematic lighting. Suddenly, the polished steel begins to soften and distort, like heated metal starting to lose its structural integrity. The blade's perfect edge slowly warps and droops, molten steel beginning to flow downward in silvery rivulets while maintaining its metallic sheen. The transformation starts subtly at first - a slight bend in the blade - then accelerates as the metal becomes increasingly fluid. The camera holds steady on her face as her piercing eyes gradually narrow, not with lethal focus, but with confusion and growing alarm as she watches her weapon dissolve before her eyes. Her breathing quickens slightly as she witnesses this impossible transformation. The melting intensifies, the katana's perfect form becoming increasingly abstract, dripping like liquid mercury from her grip. Molten droplets fall to the ground with soft metallic impacts. Her expression shifts from calm readiness to bewilderment and concern as her legendary instrument of vengeance literally liquefies in her hands, leaving her defenseless and disoriented.",
                    6,
                ],
            ],
            inputs=[input_image_component, prompt_input, steps_slider],
            outputs=[video_output, seed_input],
            fn=generate_video,
            cache_examples=True,
            cache_mode="lazy",
        )

if __name__ == "__main__":
    demo.launch()