import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
import gc
import tempfile
import random
import gradio as gr
from PIL import Image
# Core Diffusers imports
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
# =========================================================
# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
# =========================================================
# The custom classes (GQAAttention, SparseMoEFFN) stay the same, since they are
# architectural modifications to the base model's logic.
class GQAAttention(nn.Module):
    """Grouped-Query Attention wrapper: all query heads are kept, but keys and
    values are reduced to num_heads // 4 shared KV heads."""

    def __init__(self, original_attn):
        super().__init__()
        # Assumes the wrapped module exposes num_heads/head_dim and q/k/v/o
        # projection layers; adjust the names for your diffusers version.
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim
        self.num_kv_heads = max(1, self.num_heads // 4)
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj

    def forward(self, x, freqs_cis=None):
        batch, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # The original k/v projections output num_heads * head_dim features;
        # keep only the first num_kv_heads * head_dim channels before reshaping
        # (viewing the full projection as num_kv_heads heads would not fit).
        kv_dim = self.num_kv_heads * self.head_dim
        k = self.k_proj(x)[..., :kv_dim].view(batch, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x)[..., :kv_dim].view(batch, seq_len, self.num_kv_heads, self.head_dim)
        # Broadcast each shared KV head to the query heads in its group
        k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
        v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)
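
# A minimal, self-contained sanity check for the GQA wrapper above. This is a
# hypothetical helper (never invoked by the app): _DummyAttn and its sizes are
# made up purely to confirm the wrapper preserves the (batch, seq, dim) shape.
def _gqa_smoke_test():
    class _DummyAttn(nn.Module):
        def __init__(self, dim=256, heads=8):
            super().__init__()
            self.num_heads = heads
            self.head_dim = dim // heads
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)
            self.o_proj = nn.Linear(dim, dim)

    gqa = GQAAttention(_DummyAttn())
    x = torch.randn(1, 10, 256)
    assert gqa(x).shape == x.shape  # 8 query heads share 2 KV heads internally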
class SparseMoEFFN(nn.Module):
    """Sparse Mixture-of-Experts FFN: 8 experts, top-2 routing per token.
    Note: the experts are freshly initialized, not distilled from the
    original FFN weights."""

    def __init__(self, original_ffn):
        super().__init__()
        # Read the model width from the first Linear layer inside the original
        # FFN (robust to the exact submodule naming in diffusers).
        in_dim = next(m for m in original_ffn.modules() if isinstance(m, nn.Linear)).in_features
        self.router = nn.Linear(in_dim, 8)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, in_dim * 2), nn.SiLU(), nn.Linear(in_dim * 2, in_dim))
            for _ in range(8)
        ])
        self.top_k = 2

    def forward(self, x):
        batch, seq, dim = x.shape
        flat_x = x.view(-1, dim)
        logits = self.router(flat_x)
        weights, selected_experts = torch.topk(logits, self.top_k)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # token_idx: which tokens routed to expert i; slot_idx: whether it
            # was their top-1 or top-2 pick (selects the matching gate weight).
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel():
                out[token_idx] += expert(flat_x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return out.view(batch, seq, dim)
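
# A minimal, self-contained sanity check for the MoE layer above; hypothetical
# helper (never invoked by the app). _DummyFFN stands in for the original
# diffusers FFN so the width-detection logic in __init__ can be exercised.
def _moe_smoke_test():
    class _DummyFFN(nn.Module):
        def __init__(self, dim=64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
            )

    moe = SparseMoEFFN(_DummyFFN())
    x = torch.randn(2, 5, 64)
    assert moe(x).shape == x.shape  # each token mixes its top-2 of 8 experts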
# =========================================================
# 2. CONFIGURATION & PATCHING
# =========================================================
# CORRECT MODEL ID: Wan 2.1 I2V 14B is the standard for Image-to-Video
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
HF_TOKEN = os.environ.get("HF_TOKEN")
def patch_model(pipe):
    print("🛠️ Patching Transformer with GQA and MoE...")
    # Attribute names ('attn', 'ffn') depend on the installed diffusers
    # version; blocks that don't expose them are simply left unpatched.
    for i, block in enumerate(pipe.transformer.blocks):
        if hasattr(block, 'attn'):
            block.attn = GQAAttention(block.attn)
        # Replace the FFN in every other block to keep some dense layers
        if hasattr(block, 'ffn') and i % 2 == 0:
            block.ffn = SparseMoEFFN(block.ffn)
    return pipe
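
# Hypothetical helper (not called by the app): after patch_model(), count how
# many transformer blocks actually received GQA attention / MoE FFNs. Useful
# because blocks lacking the expected attribute names are silently skipped.
def report_patches(pipe):
    n_gqa = sum(isinstance(getattr(b, "attn", None), GQAAttention) for b in pipe.transformer.blocks)
    n_moe = sum(isinstance(getattr(b, "ffn", None), SparseMoEFFN) for b in pipe.transformer.blocks)
    print(f"Patched blocks -> GQA: {n_gqa}, MoE: {n_moe}")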
# =========================================================
# 3. GENERATION
# =========================================================
@spaces.GPU(duration=600)
def generate_20s_video(image_path, prompt, duration, steps):
    if not HF_TOKEN:
        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")
    print(f"⏳ Loading Model: {MODEL_ID}")
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    # Apply architecture modifications
    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()
    img = Image.open(image_path).convert("RGB")
    img = img.resize((832, 480))  # Native Wan 2.1 480P resolution
    # Wan requires a frame count of the form 4n + 1
    num_frames = int(duration * 16)
    num_frames = ((num_frames - 1) // 4) * 4 + 1
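    # Worked example: duration=20 s at 16 fps -> 320 raw frames, and
    # ((320 - 1) // 4) * 4 + 1 = 317 frames actually generated (~19.8 s).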
    with torch.inference_mode():
        output = pipe(
            image=img,
            prompt=prompt + ", high quality, cinematically consistent",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=5.0,
            generator=torch.Generator("cuda").manual_seed(42)
        )
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        video_path = f.name
    export_to_video(output.frames[0], video_path, fps=16)
    return video_path
# Gradio Setup
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input Image")
            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
            stp = gr.Slider(10, 30, value=20, label="Steps")
            btn = gr.Button("Generate 20s Video")
        with gr.Column():
            vid = gr.Video()
    btn.click(generate_20s_video, [img, txt, dur, stp], vid)

demo.queue().launch()