import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
import gc
import tempfile
import random
import gradio as gr
from PIL import Image

# Core Diffusers imports
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
# =========================================================
# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
# =========================================================
# The custom classes (GQAAttention, SparseMoEFFN) stay the same, as they are
# architectural modifications to the base model's logic.
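# Background: grouped-query attention (GQA) lets several query heads share one
# key/value head, shrinking the K/V tensors that dominate attention memory at long
# sequence lengths. A sparse mixture-of-experts (MoE) FFN routes each token to a
# small top-k subset of expert MLPs, so only a fraction of the FFN parameters are
# active per token.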
class GQAAttention(nn.Module):
    def __init__(self, original_attn):
        super().__init__()
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim
        self.num_kv_heads = max(1, self.num_heads // 4)
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj
    def forward(self, x, freqs_cis=None):
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # The reused projections still emit num_heads K/V heads, so they cannot be
        # viewed directly into num_kv_heads. Instead, mean-pool each contiguous group
        # of heads down to one shared K/V head (the usual MHA -> GQA conversion).
        group = self.num_heads // self.num_kv_heads
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, group, self.head_dim).mean(dim=3)
        # Broadcast the shared K/V heads back so every query head in a group sees them.
        k = k.repeat_interleave(group, dim=2)
        v = v.repeat_interleave(group, dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)
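# Note: because K/V are expanded back to num_heads before SDPA, attention compute is
# unchanged here; the grouping mainly models the GQA layout. On PyTorch 2.5+,
# F.scaled_dot_product_attention accepts enable_gqa=True, which can consume the
# pooled (un-expanded) K/V directly and skip the repeat_interleave.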
class SparseMoEFFN(nn.Module):
    def __init__(self, original_ffn):
        super().__init__()
        in_dim = original_ffn.ffn[0].in_features
        self.router = nn.Linear(in_dim, 8)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, in_dim * 2), nn.SiLU(), nn.Linear(in_dim * 2, in_dim))
            for _ in range(8)
        ])
        self.top_k = 2
    def forward(self, x):
        batch, seq, dim = x.shape
        flat_x = x.view(-1, dim)
        logits = self.router(flat_x)
        weights, selected_experts = torch.topk(logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Gather the tokens routed to this expert together with the top-k slot in
            # which they selected it, so each token is scaled by its own routing weight.
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() > 0:
                out[token_idx] += expert(flat_x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return out.view(batch, seq, dim)
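# Caveat: the router and expert weights above are freshly initialized rather than
# loaded from the checkpoint, so blocks patched with SparseMoEFFN will not reproduce
# the pretrained FFN's outputs without further fine-tuning.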
# =========================================================
# 2. CONFIGURATION & PATCHING
# =========================================================
# Model ID: Wan 2.1 I2V 14B (480P) is the standard Diffusers image-to-video checkpoint.
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
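# Wan-AI also publishes a 720P variant of this checkpoint (Wan2.1-I2V-14B-720P-Diffusers);
# swapping it in would also require changing the 832x480 resize used in the generation step.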
HF_TOKEN = os.environ.get("HF_TOKEN")

def patch_model(pipe):
    print("🛠️ Patching Transformer with GQA and MoE...")
    for i, block in enumerate(pipe.transformer.blocks):
        if hasattr(block, 'attn'):
            block.attn = GQAAttention(block.attn)
        if hasattr(block, 'ffn') and i % 2 == 0:
            block.ffn = SparseMoEFFN(block.ffn)
    return pipe
# =========================================================
# 3. GENERATION
# =========================================================
# @spaces.GPU requests a GPU for this call when running on a ZeroGPU Space (the reason
# `spaces` is imported above); a longer slot can be requested via @spaces.GPU(duration=...).
@spaces.GPU
def generate_20s_video(image_path, prompt, duration, steps):
    if not HF_TOKEN:
        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")
    print(f"⏳ Loading Model: {MODEL_ID}")
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    # Apply architecture modifications
    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()
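    # Model CPU offload keeps only the active submodule on the GPU, and VAE tiling
    # decodes the video latents in patches; both trade speed for lower peak VRAM.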
    img = Image.open(image_path).convert("RGB")
    img = img.resize((832, 480))  # 832x480 is the default resolution of the 480P checkpoint
    # Wan expects frame counts of the form 4n + 1
    num_frames = int(duration * 16)
    num_frames = ((num_frames - 1) // 4) * 4 + 1
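    # Example: duration=20 -> 20 * 16 = 320 raw frames -> snapped down to 317 (= 4*79 + 1).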
    with torch.inference_mode():
        output = pipe(
            image=img,
            prompt=prompt + ", high quality, cinematically consistent",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=5.0,
            generator=torch.Generator("cuda").manual_seed(42)
        )
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        video_path = f.name
    export_to_video(output.frames[0], video_path, fps=16)
    return video_path
# Gradio Setup
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input Image")
            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
            stp = gr.Slider(10, 30, value=20, label="Steps")
            btn = gr.Button("Generate 20s Video")
        with gr.Column():
            vid = gr.Video()
    btn.click(generate_20s_video, [img, txt, dur, stp], vid)

demo.queue().launch()