# Wan2.2_Model / app.py
# Source: Hugging Face Space by Sushantkas — "Update app.py" (commit 89d2e74, verified).
# (Header kept as comments so the module parses.)
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
model_id = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
print(f"Using video Model: {model_id}")
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
torch_dtype=dtype
)
pipe.to(device)
print(f"Model Loaded in {device}")
pipe.vae.enable_tiling()
# ================================
# Image Preparation
# ================================
def prepare_vertical_image(pipe, image, base_width=384, base_height=672):
mod_value = (
pipe.vae_scale_factor_spatial *
pipe.transformer.config.patch_size[1]
)
final_width = (base_width // mod_value) * mod_value
final_height = (base_height // mod_value) * mod_value
resized_image = image.resize((final_width, final_height))
return resized_image, final_width, final_height
# ================================
# Video Generation
# ================================
@spaces.GPU(size="xlarge",duration=180)
def generate_video(input_image, prompt, negative_prompt, progress=gr.Progress(track_tqdm=True)):
if input_image is None:
return None
image, width, height = prepare_vertical_image(pipe, input_image)
print(f"Generating vertical video {width}x{height}")
video_frames = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=161, # FIXED
guidance_scale=5.0,
num_inference_steps=15
).frames[0]
output_path = "vertical_output.mp4"
export_to_video(video_frames, output_path, fps=16)
return output_path
# ================================
# Gradio UI
# ================================
with gr.Blocks(title="Wan 2.2 Vertical I2V") as demo:
gr.Markdown("# 🎬 Wan 2.2 Image → Video Generator")
gr.Markdown("Generate **10-second Vertical (9:16) AI Videos**")
with gr.Row():
# LEFT SIDE (INPUTS)
with gr.Column(scale=1):
input_image = gr.Image(
type="pil",
label="Upload Image"
)
prompt = gr.Textbox(
label="Prompt",
placeholder="Describe motion, camera movement..."
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value="blurry, low quality, distorted, static"
)
generate_btn = gr.Button("Generate Video", variant="primary")
# RIGHT SIDE (OUTPUT)
with gr.Column(scale=1):
output_video = gr.Video(
label="Generated Video"
)
generate_btn.click(
generate_video,
inputs=[input_image, prompt, negative_prompt],
outputs=output_video
)
demo.launch(server_name="0.0.0.0", server_port=7860)