# Hugging Face Space status: Runtime error
| import spaces | |
| import torch | |
| from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline | |
| from diffusers.models.transformers.transformer_wan import WanTransformer3DModel | |
| from diffusers.utils.export_utils import export_to_video | |
| import gradio as gr | |
| import tempfile | |
| import numpy as np | |
| from PIL import Image | |
| import random | |
| import gc | |
| from torchao.quantization import quantize_ | |
| from torchao.quantization import Float8DynamicActivationFloat8WeightConfig | |
| from torchao.quantization import Int8WeightOnlyConfig | |
| import aoti | |
# -------------------- Configuration --------------------
# Hugging Face repo of the base image-to-video pipeline.
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

# Resolution limits used when resizing the input image (pixels).
MAX_DIM, MIN_DIM = 832, 480
SQUARE_DIM = 640   # side length used for square inputs
MULTIPLE_OF = 16   # both output sides are snapped to a multiple of this

# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Timing. The frame bounds are given in frames at FIXED_FPS and converted to
# the duration range (seconds, one decimal place) exposed in the UI.
FIXED_FPS = 16
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 80
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# Device the NSFW detector and its inputs are moved to.
device = "cuda"
# -------------------- NSFW detector loading --------------------
# Best-effort: on any failure both handles are set to None; detect_frames_nsfw
# and the generation code check for that and skip NSFW screening.
try:
    print("Loading NSFW detector...")
    from transformers import AutoModelForImageClassification, AutoProcessor

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
except Exception as e:
    print(f"Failed to load NSFW detector: {e}")
    nsfw_model = nsfw_processor = None
else:
    print("NSFW detector loaded successfully.")
# ----------------------------------------------------------------
class GenerationError(Exception):
    """Raised when a video generation request cannot be completed."""
def detect_image_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """
    Return True if the NSFW detector classifies ``image`` as NSFW.

    Args:
        image: Image to classify. PIL images are converted to RGB first so
            RGBA, palette and grayscale inputs are accepted.
        threshold: The image is flagged when its NSFW probability is strictly
            greater than this value.

    Returns:
        True if flagged. Always False when the detector failed to load at
        startup (screening is best-effort, consistent with detect_frames_nsfw).
    """
    # Without this guard, calling the function after a failed load raised
    # TypeError ('NoneType' object is not callable).
    if nsfw_model is None or nsfw_processor is None:
        return False
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = nsfw_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Resolve the "nsfw" class from the model config rather than trusting a
    # fixed position; fall back to index 1 if the label is not found.
    id2label = getattr(nsfw_model.config, "id2label", None) or {}
    nsfw_index = next(
        (int(idx) for idx, name in id2label.items() if str(name).lower() == "nsfw"),
        1,
    )
    return probs[0][nsfw_index].item() > threshold
def _frame_to_pil(frame) -> Image.Image:
    """Convert one video frame (PIL image, torch tensor or numpy array) to an RGB PIL image."""
    if isinstance(frame, Image.Image):
        return frame.convert("RGB")

    if torch.is_tensor(frame):
        # Move to CPU as float32 (numpy cannot represent every torch dtype,
        # e.g. bfloat16).
        arr = frame.detach().cpu().float()
        # Channel-first [3, H, W] -> channel-last [H, W, 3], as PIL expects.
        if arr.ndim == 3 and arr.shape[0] == 3:
            arr = arr.permute(1, 2, 0)
        arr = arr.numpy()
    else:
        arr = np.asarray(frame)

    # Only float data can be normalized; integer frames are already 0-255.
    # (Deciding by max() alone would wrongly brighten dark uint8 frames.)
    if np.issubdtype(arr.dtype, np.floating):
        # Treat float frames whose max is <= ~1 as normalized to [0, 1].
        if arr.size and arr.max() <= 1.05:
            arr = arr * 255.0
        arr = np.rint(arr)
    # Clip before casting so out-of-range values saturate instead of wrapping.
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr).convert("RGB")


def detect_frames_nsfw(frames, sample_n=5, threshold=0.5):
    """
    Check generated video frames for NSFW content by classifying a sample.

    :param frames: sequence of frames (PIL images, torch tensors or numpy
        arrays), indexable by integer
    :param sample_n: number of frames to sample, evenly spaced over the video
    :param threshold: NSFW probability above which a frame is flagged
    :return: True as soon as any sampled frame is flagged; False if none is,
        if there are no frames to sample, or if the detector is not loaded
    """
    if nsfw_model is None or nsfw_processor is None:
        return False
    num_frames = len(frames)
    if num_frames == 0 or sample_n <= 0:
        return False
    # Evenly spaced sample (e.g. 5 out of 49 frames) so we don't classify every
    # frame; np.unique drops duplicate indices when sample_n > num_frames.
    sample_indices = np.unique(np.linspace(0, num_frames - 1, sample_n, dtype=int))
    return any(
        detect_image_nsfw(_frame_to_pil(frames[idx]), threshold)
        for idx in sample_indices
    )
# -------------------- Wan2.2 I2V pipeline --------------------
# Base pipeline from MODEL_ID, with both transformers taken from the bf16 repo
# cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers and loaded directly onto the GPU.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained(
        "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ),
    transformer_2=WanTransformer3DModel.from_pretrained(
        "cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers",
        subfolder="transformer_2",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ),
    torch_dtype=torch.bfloat16,
).to("cuda")

# The same lightx2v LoRA file (a cfg/step-distill LoRA, per its file name) is
# loaded twice: as adapter "lightx2v", and again as "lightx2v_2" with
# load_into_transformer_2=True.
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightx2v",
)
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightx2v_2",
    load_into_transformer_2=True,
)
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1.0, 1.0])

# Fuse the adapters into the weights (scale 3.0 on `transformer`, 1.0 on
# `transformer_2`), then unload the LoRA adapters.
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3.0, components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1.0, components=["transformer_2"])
pipe.unload_lora_weights()

# Quantization: int8 weight-only for the text encoder; fp8 dynamic-activation /
# fp8-weight for both transformers.
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())

# Load the "fp8da" variant of the transformer blocks from zerogpu-aoti/Wan2
# (presumably ahead-of-time compiled blocks matching the quantization above;
# must run after quantize_ — confirm against the aoti helper).
aoti.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/Wan2", variant="fp8da")
aoti.aoti_blocks_load(pipe.transformer_2, "zerogpu-aoti/Wan2", variant="fp8da")
# Default positive prompt shown in the UI.
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"

# Default negative prompt. English meaning of the terms, in order:
# garish colors, overexposed, static, blurry details, subtitles, style, artwork,
# painting, frame, still, overall grayish, worst quality, low quality, JPEG
# compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands,
# poorly drawn face, deformed, disfigured, malformed limbs, fused fingers,
# motionless frame, cluttered background, three legs, many people in the
# background, walking backwards.
default_negative_prompt = ", ".join(
    (
        "色调艳丽", "过曝", "静态", "细节模糊不清", "字幕", "风格", "作品",
        "画作", "画面", "静止", "整体发灰", "最差质量", "低质量", "JPEG压缩残留",
        "丑陋的", "残缺的", "多余的手指", "画得不好的手部", "画得不好的脸部",
        "畸形的", "毁容的", "形态畸形的肢体", "手指融合", "静止不动的画面",
        "杂乱的背景", "三条腿", "背景人很多", "倒着走",
    )
)
def resize_image(image: Image.Image) -> Image.Image:
    """
    Fit ``image`` into the resolution range accepted by the model.

    - Square inputs are resized to ``SQUARE_DIM`` x ``SQUARE_DIM``.
    - Inputs wider than ``MAX_DIM:MIN_DIM`` (or taller than ``MIN_DIM:MAX_DIM``)
      are first center-cropped to exactly that ratio and targeted at the
      matching extreme resolution.
    - Otherwise the longer side targets ``MAX_DIM`` and the shorter side
      follows the aspect ratio.

    Both sides are finally rounded to a multiple of ``MULTIPLE_OF`` and clamped
    to ``[MIN_DIM, MAX_DIM]``.
    """
    width, height = image.size
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    ratio = width / height
    widest = MAX_DIM / MIN_DIM
    tallest = MIN_DIM / MAX_DIM

    source = image
    if ratio > widest:
        # Too wide: keep the full height, trim both sides equally.
        kept_w = int(round(height * widest))
        x0 = (width - kept_w) // 2
        source = image.crop((x0, 0, x0 + kept_w, height))
        target_w, target_h = MAX_DIM, MIN_DIM
    elif ratio < tallest:
        # Too tall: keep the full width, trim top and bottom equally.
        kept_h = int(round(width / tallest))
        y0 = (height - kept_h) // 2
        source = image.crop((0, y0, width, y0 + kept_h))
        target_w, target_h = MIN_DIM, MAX_DIM
    elif width > height:
        target_w = MAX_DIM
        target_h = int(round(target_w / ratio))
    else:
        target_h = MAX_DIM
        target_w = int(round(target_h * ratio))

    def _snap(side):
        # Nearest multiple of MULTIPLE_OF, then clamp into [MIN_DIM, MAX_DIM].
        return max(MIN_DIM, min(MAX_DIM, round(side / MULTIPLE_OF) * MULTIPLE_OF))

    return source.resize((_snap(target_w), _snap(target_h)), Image.LANCZOS)
def get_num_frames(duration_seconds: float):
    """
    Convert a requested duration in seconds into the pipeline frame count.

    The duration is converted to frames at FIXED_FPS, rounded, clamped to
    [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL], and one extra frame is added
    (9..81 with the current constants; the +1 presumably follows Wan's
    4k+1 frame-count convention — TODO confirm).
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    bounded = min(max(requested, MIN_FRAMES_MODEL), MAX_FRAMES_MODEL)
    return bounded + 1
def get_duration(
    input_image,
    prompt,
    steps,
    negative_prompt,
    duration_seconds,
    guidance_scale,
    guidance_scale_2,
    seed,
    randomize_seed,
):
    """
    Estimate how many seconds one generation request will take.

    Takes the same arguments as ``_generate_video``, which suggests it is meant
    as a ``spaces.GPU(duration=...)`` callback. Only ``input_image``, ``steps``
    and ``duration_seconds`` are used. The per-step cost is scaled from a
    reference of 15 s per step for an 81-frame 832x624 video. It grows with the
    1.5 power of (frames x width x height), and 10 s of fixed overhead is added.
    """
    # No image: _generate_video rejects the request right away, so a short
    # estimate is enough.
    if input_image is None:
        return 10

    reference_volume = 81 * 832 * 624
    reference_step_seconds = 15

    out_w, out_h = resize_image(input_image).size
    volume = get_num_frames(duration_seconds) * out_w * out_h
    per_step = reference_step_seconds * (volume / reference_volume) ** 1.5
    return 10 + int(steps) * per_step
# Progress reporter used by _generate_video and its per-step callback.
progress = gr.Progress()


# ZeroGPU: a GPU is attached only while this function runs, for the number of
# seconds returned by get_duration (called with the same arguments). The
# decorator has no effect outside ZeroGPU hardware.
@spaces.GPU(duration=get_duration)
def _generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """
    Run one image-to-video generation.

    Args:
        input_image: PIL image from the UI (required).
        prompt: Text prompt passed to the pipeline.
        steps: Number of denoising steps (``num_inference_steps``).
        negative_prompt: Negative prompt passed to the pipeline.
        duration_seconds: Requested length, converted by ``get_num_frames``.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        guidance_scale_2: Passed to the pipeline as ``guidance_scale_2``.
        seed: Seed for the CUDA generator; ignored when ``randomize_seed``.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        ``(info, video_path, seed)``. On success ``info`` is
        ``{"status": "success"}``. On any failure (missing image, NSFW input or
        output, pipeline error) it is ``{"error": msg, "status": "failed"}``
        and the other two items are None. ``generate_video`` turns the failure
        into a ``gr.Error``.
    """
    import traceback

    progress(0, desc="Starting")
    try:
        if input_image is None:
            # Not gr.Error: its str() is repr(message), so the message would
            # reach the user wrapped in quotes once generate_video re-raises it.
            raise GenerationError("Please upload an input image.")

        nsfw_ready = nsfw_model is not None and nsfw_processor is not None

        # Screen the input image before spending GPU time on it.
        if nsfw_ready and detect_image_nsfw(input_image):
            raise GenerationError(
                "The input image contains NSFW content and cannot be used. "
                "Please use a different image and try again."
            )

        num_steps = int(steps)
        num_frames = get_num_frames(duration_seconds)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        resized_image = resize_image(input_image)

        def callback_fn(pipeline, step, timestep, callback_kwargs):
            # Invoked by the pipeline at the end of each step; `step` is 0-based.
            print(f"[Step {step}] Timestep: {timestep}")
            progress(
                (step + 1.0) / num_steps,
                desc=f"Video generating, {step + 1}/{num_steps} steps",
            )
            return callback_kwargs

        output_frames_list = pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=num_steps,
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
            callback_on_step_end=callback_fn,
        ).frames[0]

        # Screen the output BEFORE creating the temp file, so a rejected video
        # does not leave an orphaned .mp4 on disk.
        if nsfw_ready:
            # One sample every FIXED_FPS / 2 frames, i.e. about two per second.
            sample_count = max(1, int(len(output_frames_list) // (FIXED_FPS / 2)))
            if detect_frames_nsfw(output_frames_list, sample_count, 0.6):
                raise GenerationError(
                    "Generated video contains NSFW content and cannot be displayed. "
                    "Please modify the input image or prompt and try again."
                )

        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

        progress(1, desc="Complete")
        return {"status": "success"}, video_path, current_seed
    except Exception as e:
        # Keep the full traceback in the server logs; the UI only gets the message.
        traceback.print_exc()
        return {"error": str(e), "status": "failed"}, None, None
    finally:
        # Drop unreachable Python objects first so the CUDA cache can be released.
        gc.collect()
        torch.cuda.empty_cache()
def generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """
    Gradio handler: run ``_generate_video`` and adapt its result for the UI.

    Returns ``(video_path, seed)`` on success. When ``_generate_video`` reports
    ``status == "failed"``, raises ``gr.Error`` carrying its error message.
    """
    info, video_path, current_seed = _generate_video(
        input_image,
        prompt,
        steps,
        negative_prompt,
        duration_seconds,
        guidance_scale,
        guidance_scale_2,
        seed,
        randomize_seed,
    )
    if info["status"] != "failed":
        return video_path, current_seed
    raise gr.Error(info["error"])
# Page styling: center the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 1200px;
}
"""

# Header texts rendered as Markdown at the top of the page.
title = "# AI Video Maker"
description = (
    "Bring your imagination to motion with AI Video Maker. "
    "Simply upload a photo and provide a prompt to generate smooth, "
    "high-quality animations for your content creation. "
)
note = (
    "*Note: This demo has a daily usage cap. If you have reached the limit "
    "or need faster rendering, please visit "
    "[AI Image to Video](https://www.imgtovideo.ai/) "
    "to continue animating without interruptions.*"
)
# -------------------- Gradio UI --------------------
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost; confirm where ui_inputs / click / gr.Examples belong.
with gr.Blocks(css=css).queue() as demo:
    # Outer column, styled by the #col-container CSS rule.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(note)
        with gr.Row():
            # Left column: input image, prompt and generation controls.
            with gr.Column():
                gr.Markdown("### Input")
                input_image_component = gr.Image(type="pil", label="Input Image")
                prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                    steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
                    guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
                    guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
                generate_button = gr.Button("Generate Video", variant="primary")
            # Right column: generated video.
            with gr.Column():
                gr.Markdown("### Output")
                video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
        # Order must match the positional parameters of generate_video.
        ui_inputs = [
            input_image_component, prompt_input, steps_slider,
            negative_prompt_input, duration_seconds_input,
            guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox
        ]
        # The seed actually used is written back into the seed slider.
        generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
        # Clickable examples. They only fill image, prompt and steps, so the
        # remaining generate_video parameters take their defaults. Results are
        # cached with cache_mode="lazy".
        gr.Examples(
            examples=[
                [
                    "wan_i2v_input.JPG",
                    "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat's face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
                    4,
                ],
                [
                    "wan22_input_2.jpg",
                    "A sleek lunar vehicle glides into view from left to right, kicking up moon dust as astronauts in white spacesuits hop aboard with characteristic lunar bouncing movements. In the distant background, a VTOL craft descends straight down and lands silently on the surface. Throughout the entire scene, ethereal aurora borealis ribbons dance across the star-filled sky, casting shimmering curtains of green, blue, and purple light that bathe the lunar landscape in an otherworldly, magical glow.",
                    4,
                ],
                [
                    "kill_bill.jpeg",
                    "Uma Thurman's character, Beatrix Kiddo, holds her razor-sharp katana blade steady in the cinematic lighting. Suddenly, the polished steel begins to soften and distort, like heated metal starting to lose its structural integrity. The blade's perfect edge slowly warps and droops, molten steel beginning to flow downward in silvery rivulets while maintaining its metallic sheen. The transformation starts subtly at first - a slight bend in the blade - then accelerates as the metal becomes increasingly fluid. The camera holds steady on her face as her piercing eyes gradually narrow, not with lethal focus, but with confusion and growing alarm as she watches her weapon dissolve before her eyes. Her breathing quickens slightly as she witnesses this impossible transformation. The melting intensifies, the katana's perfect form becoming increasingly abstract, dripping like liquid mercury from her grip. Molten droplets fall to the ground with soft metallic impacts. Her expression shifts from calm readiness to bewilderment and concern as her legendary instrument of vengeance literally liquefies in her hands, leaving her defenseless and disoriented.",
                    6,
                ],
            ],
            inputs=[input_image_component, prompt_input, steps_slider],
            outputs=[video_output, seed_input],
            fn=generate_video,
            cache_examples=True,
            cache_mode="lazy"
        )


# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()