import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
import gradio as gr
from PIL import Image
# Core Diffusers imports
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
# =========================================================
# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
# =========================================================
# GQAAttention shares one K/V head across each group of query heads to cut
# attention memory; SparseMoEFFN routes each token through its top-2 of 8
# small experts instead of one dense feed-forward block.
class GQAAttention(nn.Module):
    """Grouped-query attention: every query head is kept, but each group of
    4 query heads shares a single K/V head. Assumes the wrapped module exposes
    num_heads, head_dim and q_proj/k_proj/v_proj/o_proj (blocks that do not
    are skipped by patch_model below)."""
    def __init__(self, original_attn):
        super().__init__()
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim
        self.num_kv_heads = max(1, self.num_heads // 4)
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj

    def forward(self, x, freqs_cis=None):  # freqs_cis kept for signature parity; unused
        batch, seq_len, _ = x.shape
        group = self.num_heads // self.num_kv_heads
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # The pretrained K/V projections still emit num_heads heads, so mean-pool
        # each group down to its shared K/V head (viewing the full projection as
        # num_kv_heads heads directly would be a shape error).
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        # Broadcast each shared K/V head back out to its group of query heads.
        k = k.repeat_interleave(group, dim=2)
        v = v.repeat_interleave(group, dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)
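
# A minimal shape check for the wrapper above, meant to be run manually (it is
# not called at import time). The toy module and its sizes are illustrative
# assumptions, not values taken from the Wan transformer.
def _gqa_shape_check(dim=64, num_heads=8):
    class _ToyAttn(nn.Module):
        def __init__(self):
            super().__init__()
            self.num_heads, self.head_dim = num_heads, dim // num_heads
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)
            self.o_proj = nn.Linear(dim, dim)

    gqa = GQAAttention(_ToyAttn())
    out = gqa(torch.randn(2, 16, dim))
    assert out.shape == (2, 16, dim)  # drop-in shape compatibility
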
class SparseMoEFFN(nn.Module):
    """Sparsely-activated mixture-of-experts FFN: each token is routed to its
    top-2 of 8 experts. Note that the router and experts are freshly
    initialized here rather than distilled from the original FFN weights."""
    def __init__(self, original_ffn):
        super().__init__()
        # Infer the model width from the first Linear in the wrapped FFN; the
        # exact attribute layout varies across diffusers versions.
        in_dim = next(m for m in original_ffn.modules() if isinstance(m, nn.Linear)).in_features
        self.router = nn.Linear(in_dim, 8)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, in_dim * 2), nn.SiLU(), nn.Linear(in_dim * 2, in_dim))
            for _ in range(8)
        ])
        self.top_k = 2

    def forward(self, x):
        batch, seq, dim = x.shape
        flat_x = x.view(-1, dim)
        logits = self.router(flat_x)
        weights, selected_experts = torch.topk(logits, self.top_k)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Find the tokens routed to expert i and which top-k slot it fills,
            # so each token is scaled by its own gate weight for this expert.
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() > 0:
                out[token_idx] += expert(flat_x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return out.view(batch, seq, dim)
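
# A quick routing sanity check for the block above, run manually. The dense
# stand-in FFN is an illustrative assumption, not a Wan module.
def _moe_routing_check(dim=32):
    moe = SparseMoEFFN(nn.Sequential(nn.Linear(dim, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim)))
    out = moe(torch.randn(2, 10, dim))
    assert out.shape == (2, 10, dim)  # drop-in shape compatibility
    # Per token, the top_k gate weights form a proper distribution (sum to 1).
    w, _ = torch.topk(moe.router(torch.randn(5, dim)), moe.top_k)
    assert torch.allclose(F.softmax(w, dim=-1).sum(dim=-1), torch.ones(5))
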
# =========================================================
# 2. CONFIGURATION & PATCHING
# =========================================================
# Wan 2.1 I2V 14B (480P) in the Diffusers layout expected by WanImageToVideoPipeline
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")
def patch_model(pipe):
    print("🛠️ Patching Transformer with GQA and MoE...")
    dtype = next(pipe.transformer.parameters()).dtype
    for i, block in enumerate(pipe.transformer.blocks):
        # hasattr guards mean blocks without these attribute names are skipped.
        if hasattr(block, 'attn'):
            block.attn = GQAAttention(block.attn)
        # Replace the FFN on every other block only, to bound the added parameters.
        if hasattr(block, 'ffn') and i % 2 == 0:
            # The new router/experts are created in fp32; match the model dtype.
            block.ffn = SparseMoEFFN(block.ffn).to(dtype=dtype)
    return pipe
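
# An optional audit helper (a sketch, run manually after patch_model). Because
# patching is guarded by hasattr, it silently becomes a no-op when the installed
# diffusers version uses different attribute names; this makes that visible.
def _count_patched(pipe):
    blocks = pipe.transformer.blocks
    n_gqa = sum(isinstance(getattr(b, 'attn', None), GQAAttention) for b in blocks)
    n_moe = sum(isinstance(getattr(b, 'ffn', None), SparseMoEFFN) for b in blocks)
    print(f"GQA-patched attention: {n_gqa}/{len(blocks)} | MoE FFNs: {n_moe}/{len(blocks)}")
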
# =========================================================
# 3. GENERATION
# =========================================================
@spaces.GPU(duration=600)
def generate_20s_video(image_path, prompt, duration, steps):
    if not HF_TOKEN:
        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")
    print(f"⏳ Loading Model: {MODEL_ID}")
    # Loaded inside the GPU-scoped call; weights are cached on disk after the
    # first download, so later calls skip the network fetch.
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    # Apply the architecture modifications before any device placement.
    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()
    img = Image.open(image_path).convert("RGB")
    img = img.resize((832, 480))  # native resolution of the 480P checkpoint
    # Wan expects frame counts of the form 4n + 1; round down to the nearest valid count.
    num_frames = int(duration * 16)  # 16 fps, matching export_to_video below
    num_frames = ((num_frames - 1) // 4) * 4 + 1
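    # Worked example: duration=20 s -> 20 * 16 = 320 frames, then
    # ((320 - 1) // 4) * 4 + 1 = 79 * 4 + 1 = 317 frames (~19.8 s at 16 fps).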
    with torch.inference_mode():
        output = pipe(
            image=img,
            prompt=prompt + ", high quality, cinematically consistent",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=5.0,
            generator=torch.Generator("cuda").manual_seed(42)  # fixed seed for repeatable output
        )
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        video_path = f.name
    export_to_video(output.frames[0], video_path, fps=16)
    return video_path
# Gradio Setup
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input Image")
            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
            stp = gr.Slider(10, 30, value=20, label="Steps")
            btn = gr.Button("Generate 20s Video")
        with gr.Column():
            vid = gr.Video()
    btn.click(generate_20s_video, [img, txt, dur, stp], vid)

demo.queue().launch()