import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
import tempfile
import gradio as gr
from PIL import Image

# Core Diffusers imports
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video

# =========================================================
# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
# =========================================================
# The custom classes (GQAAttention, SparseMoEFFN) stay the same as
# they are architectural modifications to the base model's logic.

class GQAAttention(nn.Module):
    """Grouped-Query Attention: shares each key/value head across a
    group of query heads (here, groups of 4)."""

    def __init__(self, original_attn):
        super().__init__()
        # NOTE: these attribute names assume a LLaMA-style attention
        # module (num_heads / head_dim / q_proj / ...). diffusers'
        # generic Attention class exposes `heads` and `to_q`/`to_k`/
        # `to_v`/`to_out[0]` instead, so adapt these lookups to what
        # the installed Wan transformer blocks actually provide.
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim
        self.num_kv_heads = max(1, self.num_heads // 4)
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj

    def forward(self, x, freqs_cis=None):
        # freqs_cis (rotary embeddings) is accepted for interface
        # compatibility but not applied here.
        batch, seq_len, _ = x.shape
        group = self.num_heads // self.num_kv_heads
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # The pretrained k/v projections still emit num_heads heads, so
        # convert MHA -> GQA by mean-pooling the heads within each group
        # (the conversion recipe from the GQA paper). Viewing the full
        # projection directly as num_kv_heads heads would be a shape error.
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        # Broadcast each shared KV head back to its group of query heads
        # so shapes match for SDPA (shape-compatible GQA; the compute
        # cost of full attention is unchanged).
        k = k.repeat_interleave(group, dim=2)
        v = v.repeat_interleave(group, dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)


class SparseMoEFFN(nn.Module):
    """Sparse Mixture-of-Experts FFN: a linear router sends each token
    to its top-2 of 8 experts. NOTE: the router and experts are freshly
    initialized (the pretrained FFN weights are discarded), so output
    quality will suffer without fine-tuning."""

    def __init__(self, original_ffn):
        super().__init__()
        # Assumes the wrapped module stores its layers under `.ffn`;
        # diffusers' FeedForward keeps them under `.net` instead.
        in_dim = original_ffn.ffn[0].in_features
        self.router = nn.Linear(in_dim, 8)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_dim, in_dim * 2),
                nn.SiLU(),
                nn.Linear(in_dim * 2, in_dim),
            )
            for _ in range(8)
        ])
        self.top_k = 2

    def forward(self, x):
        batch, seq, dim = x.shape
        flat_x = x.view(-1, dim)
        logits = self.router(flat_x)
        weights, selected_experts = torch.topk(logits, self.top_k)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Find the (token, slot) pairs routed to expert i and scale
            # the expert output by that slot's routing weight. Indexing
            # the weight by slot fixes the original `[:, :1]` bug, which
            # always applied the top-1 weight regardless of slot.
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() > 0:
                out[token_idx] += expert(flat_x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return out.view(batch, seq, dim)


# =========================================================
# 2. CONFIGURATION & PATCHING
# =========================================================
# Wan 2.1 I2V 14B (480P) is the standard Diffusers checkpoint for
# image-to-video generation.
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")


def patch_model(pipe):
    print("🛠️ Patching Transformer with GQA and MoE...")
    for i, block in enumerate(pipe.transformer.blocks):
        # Attribute names vary with the diffusers version (Wan blocks
        # expose self-attention as `attn1`, for example); the hasattr
        # guards mean unmatched blocks are silently left unpatched.
        if hasattr(block, 'attn'):
            block.attn = GQAAttention(block.attn)
        # Convert every second FFN to MoE to bound the parameter count.
        if hasattr(block, 'ffn') and i % 2 == 0:
            block.ffn = SparseMoEFFN(block.ffn)
    return pipe


# =========================================================
# 3. GENERATION
# =========================================================
@spaces.GPU(duration=600)
def generate_20s_video(image_path, prompt, duration, steps):
    if not HF_TOKEN:
        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")
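
    # NOTE: the pipeline is reloaded on every request. That keeps the
    # handler self-contained for ZeroGPU-style Spaces, but it is slow;
    # on a machine with a persistent GPU you could hoist the
    # from_pretrained call below to module scope and reuse the pipeline.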
    print(f"⏳ Loading Model: {MODEL_ID}")
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN,
    )

    # Apply architecture modifications
    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    img = Image.open(image_path).convert("RGB")
    img = img.resize((832, 480))  # Wan's 480P resolution (~16:9)

    # Wan requires num_frames ≡ 1 (mod 4), i.e. 4n + 1. At 16 fps:
    # duration=20s -> 320 raw frames -> ((320 - 1) // 4) * 4 + 1 = 317.
    num_frames = int(duration * 16)
    num_frames = ((num_frames - 1) // 4) * 4 + 1

    with torch.inference_mode():
        output = pipe(
            image=img,
            prompt=prompt + ", high quality, cinematically consistent",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=5.0,
            generator=torch.Generator("cuda").manual_seed(42),
        )

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        video_path = f.name
    export_to_video(output.frames[0], video_path, fps=16)
    return video_path


# Gradio Setup
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input Image")
            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
            stp = gr.Slider(10, 30, value=20, label="Steps")
            btn = gr.Button("Generate 20s Video")
        with gr.Column():
            vid = gr.Video()

    btn.click(generate_20s_video, [img, txt, dur, stp], vid)

demo.queue().launch()
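
# Optional local smoke test (a sketch, not part of the app). Assumes a
# CUDA GPU, HF_TOKEN set in the environment, and an image at ./test.jpg
# (hypothetical path). Comment out demo.queue().launch() above, then:
#
#   video_path = generate_20s_video("test.jpg", "a red fox in the snow", 5, 10)
#   print("Saved to:", video_path)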