sdmklgdfmkl / app.py
matthewkram's picture
Update app.py
c20db4f verified
raw
history blame
8.63 kB
import os
import shutil
import subprocess
import sys
import tempfile
import time
import traceback
import uuid

import cv2
import gradio as gr
import numpy as np
import torch
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
class WanAnimateApp:
    """Gradio backend that renders a short clip with Stable Video Diffusion.

    This is a local stand-in ("SVD proxy") for the Wan2.2-Animate service.
    The Stable Video Diffusion img2vid pipeline is conditioned on the
    reference image only, so the template video and the selected mode are
    accepted for interface compatibility but do not influence the output.
    """

    MODEL_NAME = "stabilityai/stable-video-diffusion-img2vid-xt"
    # (width, height) the reference image is resized to and generated at.
    FRAME_SIZE = (576, 320)
    # Playback rate passed to ffmpeg when encoding the MP4.
    OUTPUT_FPS = 7
    # Fixed seed so repeated requests with the same inputs match.
    SEED = 42

    def __init__(self):
        """Load the SVD pipeline onto CUDA (float16) if present, else CPU (float32)."""
        self.device, dtype = self._resolve_device()
        self.pipe = StableVideoDiffusionPipeline.from_pretrained(
            self.MODEL_NAME,
            torch_dtype=dtype,
            # The fp16 checkpoint is downloaded either way; weights are
            # loaded as `torch_dtype`, i.e. float32 on CPU.
            variant="fp16",
        )
        self.pipe.to(self.device)

    @staticmethod
    def _resolve_device():
        """Return ``(device, dtype)``: ``("cuda", float16)`` when CUDA is available, else ``("cpu", float32)``."""
        # Half precision only pays off (and is only well supported) on GPU,
        # so the dtype must follow the device the model actually runs on.
        if torch.cuda.is_available():
            return "cuda", torch.float16
        return "cpu", torch.float32

    @staticmethod
    def _quality_params(model):
        """Return ``(num_frames, num_inference_steps)`` for a quality preset name."""
        if model == "wan-pro":
            return 25, 25
        return 14, 15

    @classmethod
    def _load_reference(cls, ref_img):
        """Return ``ref_img`` (PIL image or openable path) as an RGB image of ``FRAME_SIZE``."""
        if isinstance(ref_img, Image.Image):
            return ref_img.convert("RGB").resize(cls.FRAME_SIZE)
        # Close the file handle once the converted copy has been made.
        with Image.open(ref_img) as opened:
            return opened.convert("RGB").resize(cls.FRAME_SIZE)

    def _write_mp4(self, frames):
        """Encode PIL ``frames`` to an H.264 MP4 in the system temp dir and return its path.

        Raises:
            RuntimeError: if ffmpeg exits with a non-zero status.
            FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        """
        output_path = os.path.join(
            tempfile.gettempdir(), f"output_{uuid.uuid4()}.mp4"
        )
        # The PNG frames are only needed while ffmpeg runs; the context
        # manager deletes them even when saving or encoding raises.
        with tempfile.TemporaryDirectory() as frame_dir:
            for i, frame in enumerate(frames):
                frame.save(os.path.join(frame_dir, f"frame_{i:04d}.png"))
            result = subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-framerate", str(self.OUTPUT_FPS),
                    "-i", os.path.join(frame_dir, "frame_%04d.png"),
                    "-c:v", "libx264",
                    "-pix_fmt", "yuv420p",
                    output_path,
                ],
                capture_output=True,
                text=True,
                errors="replace",
            )
        if result.returncode != 0:
            # Keep only the tail of ffmpeg's log; it ends with the actual error.
            raise RuntimeError(f"ffmpeg failed: {result.stderr.strip()[-500:]}")
        return output_path

    def predict(
        self,
        ref_img,
        video,
        model_id,
        model,
    ):
        """Generate a clip from ``ref_img`` and return ``(video_path, status)``.

        Args:
            ref_img: Reference image, either a ``PIL.Image.Image`` or a path
                accepted by ``PIL.Image.open``.
            video: Template video. It must be provided (non-``None``) but is
                otherwise unused, because the SVD pipeline is image-only.
            model_id: Mode selected in the UI (``"wan2.2-animate-move"`` or
                ``"wan2.2-animate-mix"``). Unused for the same reason.
            model: Quality preset. ``"wan-pro"`` selects 25 frames and 25
                denoising steps; any other value selects 14 frames and 15 steps.

        Returns:
            ``(path_to_mp4, "SUCCEEDED")`` on success. ``(None, message)``
            when an input is missing or any step of generation fails.
        """
        if ref_img is None or video is None:
            return None, "Upload both image and video."
        try:
            ref_image = self._load_reference(ref_img)
            num_frames, num_steps = self._quality_params(model)
            width, height = self.FRAME_SIZE
            # Noise comes from a CPU generator, so it does not depend on the
            # device the pipeline runs on.
            generator = torch.Generator(device="cpu").manual_seed(self.SEED)
            frames = self.pipe(
                ref_image,
                # Generate at the size the reference was resized to. Otherwise
                # the pipeline uses its own default size (1024x576 per the
                # diffusers docs) and upsamples the downscaled reference.
                height=height,
                width=width,
                num_inference_steps=num_steps,
                num_frames=num_frames,
                generator=generator,
                # Decode 2 frames at a time to keep VAE memory low.
                decode_chunk_size=2,
            ).frames[0]
            return self._write_mp4(frames), "SUCCEEDED"
        except Exception as e:
            # UI boundary: keep the server alive and surface the error,
            # but keep the traceback in the server log for debugging.
            traceback.print_exc()
            return None, f"Failed: {str(e)}"
# Page header: title, affiliation and links to the upstream Wan2.2 project.
HEADER_HTML = """
<div style="padding: 2rem; text-align: center; max-width: 1200px; margin: 0 auto; font-family: Arial, sans-serif;">
<h1 style="font-size: 2.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
Wan2.2-Animate: Unified Character Animation and Replacement with Holistic Replication
</h1>
<h3 style="font-size: 1.5rem; font-weight: bold; margin-bottom: 0.5rem; color: #333;">
Local version without API (SVD Proxy)
</h3>
<div style="font-size: 1.25rem; margin-bottom: 1.5rem; color: #555;">
Tongyi Lab, Alibaba
</div>
<div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 1rem; margin-bottom: 1.5rem;">
<a href="https://arxiv.org/abs/2509.14055" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
<span style="margin-right: 0.5rem;">📄</span>Paper
</a>
<a href="https://github.com/Wan-Video/Wan2.2" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
<span style="margin-right: 0.5rem;">💻</span>GitHub
</a>
<a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" target="_blank" style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500;">
<span style="margin-right: 0.5rem;">🤗</span>HF Model
</a>
</div>
</div>
"""

# Collapsible usage notes. The mode section appears once (it was duplicated),
# and the quality presets describe what this local proxy actually runs.
USAGE_HTML = """
<details>
<summary>‼️Usage (使用说明)</summary>
Wan-Animate supports two modes:
<ul>
<li>Move Mode: Use the movements extracted from the input video to drive the character in the input image</li>
<li>Mix Mode: Use the character in the input image to replace the character in the input video</li>
</ul>
Note: this local SVD proxy generates the clip from the reference image alone; the template video and the selected mode do not affect the output.
Currently, the following restrictions apply to inputs:
<ul>
<li>Video file size: Less than 200MB</li>
<li>Video resolution: The shorter side must be greater than 200, and the longer side must be less than 2048</li>
<li>Video duration: 2s to 30s</li>
<li>Video aspect ratio: 1:3 to 3:1</li>
<li>Video formats: mp4, avi, mov</li>
<li>Image file size: Less than 5MB</li>
<li>Image resolution: The shorter side must be greater than 200, and the longer side must be less than 4096</li>
<li>Image formats: jpg, png, jpeg, webp, bmp</li>
</ul>
Currently, the inference quality has two variants. You can use our open-source code for more flexible configuration.
<ul>
<li>wan-pro: 25 frames, 25 denoising steps (hosted Wan service: 25fps, 720p)</li>
<li>wan-std: 14 frames, 15 denoising steps (hosted Wan service: 15fps, 720p)</li>
</ul>
</details>
"""


def _build_demo(app):
    """Lay out the Gradio UI and wire its inputs/outputs to ``app.predict``."""
    with gr.Blocks(title="Wan2.2-Animate (Local No API)") as demo:
        gr.HTML(HEADER_HTML)
        gr.HTML(USAGE_HTML)
        with gr.Row():
            with gr.Column():
                ref_img = gr.Image(
                    label="Reference Image(参考图像)",
                    # Hand predict an in-memory PIL image instead of a file
                    # path (the original author used this to avoid FileNotFound).
                    type="pil",
                    sources=["upload"],
                )
                video = gr.Video(
                    label="Template Video(模版视频)",
                    sources=["upload"],
                )
                with gr.Row():
                    model_id = gr.Dropdown(
                        label="Mode(模式)",
                        choices=["wan2.2-animate-move", "wan2.2-animate-mix"],
                        value="wan2.2-animate-move",
                        info="",
                    )
                    model = gr.Dropdown(
                        label="Inference Quality",
                        choices=["wan-pro", "wan-std"],
                        value="wan-pro",
                    )
                run_button = gr.Button("Generate Video(生成视频)")
            with gr.Column():
                output_video = gr.Video(label="Output Video(输出视频)")
                output_status = gr.Textbox(label="Status(状态)")

        inputs = [ref_img, video, model_id, model]
        outputs = [output_video, output_status]
        run_button.click(fn=app.predict, inputs=inputs, outputs=outputs)
        gr.Examples(
            examples=[["examples/mov/1/1.jpeg", "examples/mov/1/1.mp4", "wan2.2-animate-move", "wan-pro"]],
            inputs=inputs,
            outputs=outputs,
            fn=app.predict,
            cache_examples=False,
        )
    return demo


def start_app(server_name="0.0.0.0", server_port=7860):
    """Load the model, build the UI and serve it (blocks until the server stops).

    Args:
        server_name: Interface to bind; the default listens on all interfaces.
        server_port: TCP port to serve on.
    """
    app = WanAnimateApp()
    demo = _build_demo(app)
    # One shared pipeline backs every request, so handle them one at a time.
    demo.queue(default_concurrency_limit=1)
    demo.launch(
        server_name=server_name,
        server_port=server_port,
    )
if __name__ == "__main__":
    # Script entry point: start the server only when run directly,
    # never as a side effect of importing this module.
    start_app()