import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import spaces
import tempfile
import gradio as gr
from PIL import Image

# Core Diffusers imports
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video

# =========================================================
# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
# =========================================================

# The custom classes below (GQAAttention, SparseMoEFFN) are architectural
# modifications layered on top of the base model's attention and FFN logic.

class GQAAttention(nn.Module):
    def __init__(self, original_attn):
        super().__init__()
        self.num_heads = original_attn.num_heads
        self.head_dim = original_attn.head_dim
        self.num_kv_heads = max(1, self.num_heads // 4) 
        self.q_proj = original_attn.q_proj
        self.k_proj = original_attn.k_proj
        self.v_proj = original_attn.v_proj
        self.o_proj = original_attn.o_proj

    def forward(self, x, freqs_cis=None):
        # freqs_cis is accepted for call compatibility; rotary embeddings are not re-applied here.
        batch, seq_len, _ = x.shape
        group_size = self.num_heads // self.num_kv_heads
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        # The reused k/v projections still emit num_heads * head_dim features, so reshape to
        # (num_kv_heads, group_size, head_dim) and mean-pool each group into one shared K/V head.
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, group_size, self.head_dim).mean(dim=3)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, group_size, self.head_dim).mean(dim=3)
        # Broadcast the shared K/V heads back across their query-head groups for SDPA.
        k = k.repeat_interleave(group_size, dim=2)
        v = v.repeat_interleave(group_size, dim=2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn_output = F.scaled_dot_product_attention(q, k, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(attn_output)
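
# A minimal, optional shape check for the GQA wrapper (never called by the app).
# It assumes a hypothetical toy attention module exposing num_heads, head_dim and
# q_proj/k_proj/v_proj/o_proj, mirroring what GQAAttention expects from the base model.
def _gqa_shape_check():
    class _ToyAttn(nn.Module):
        def __init__(self, dim=64, num_heads=8):
            super().__init__()
            self.num_heads, self.head_dim = num_heads, dim // num_heads
            self.q_proj, self.k_proj = nn.Linear(dim, dim), nn.Linear(dim, dim)
            self.v_proj, self.o_proj = nn.Linear(dim, dim), nn.Linear(dim, dim)

    gqa = GQAAttention(_ToyAttn())
    out = gqa(torch.randn(2, 16, 64))
    assert out.shape == (2, 16, 64)  # 8 query heads share 2 K/V heads (group size 4)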

class SparseMoEFFN(nn.Module):
    def __init__(self, original_ffn):
        super().__init__()
        # Assumes the base FFN exposes its first linear layer at original_ffn.ffn[0];
        # adjust the attribute path if the installed diffusers version structures it differently.
        in_dim = original_ffn.ffn[0].in_features
        # Note: the router and the 8 experts below are freshly initialized and replace the
        # pretrained FFN weights, so output quality depends on them being trained/fine-tuned.
        self.router = nn.Linear(in_dim, 8)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(in_dim, in_dim * 2), nn.SiLU(), nn.Linear(in_dim * 2, in_dim))
            for _ in range(8)
        ])
        self.top_k = 2

    def forward(self, x):
        batch, seq, dim = x.shape
        flat_x = x.view(-1, dim)
        logits = self.router(flat_x)
        # Route each token to its top-k experts and normalize their routing weights.
        weights, selected_experts = torch.topk(logits, self.top_k)
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            # Tokens that routed to expert i, and the top-k slot it occupies for each of them,
            # so every expert output is scaled by its own routing weight (not just the top-1 weight).
            token_idx, slot_idx = torch.where(selected_experts == i)
            if token_idx.numel() > 0:
                out[token_idx] += expert(flat_x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return out.view(batch, seq, dim)
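
# A minimal, optional routing check for the sparse-MoE wrapper (never called by the app).
# It assumes a hypothetical toy FFN whose first linear layer is reachable at .ffn[0],
# matching the attribute path SparseMoEFFN.__init__ expects.
def _moe_routing_check():
    class _ToyFFN(nn.Module):
        def __init__(self, dim=32):
            super().__init__()
            self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim))

    moe = SparseMoEFFN(_ToyFFN())
    out = moe(torch.randn(2, 16, 32))
    assert out.shape == (2, 16, 32)  # each token mixes exactly its top-2 experts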

# =========================================================
# 2. CONFIGURATION & PATCHING
# =========================================================

# CORRECT MODEL ID: Wan 2.1 I2V 14B is the standard for Image-to-Video
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" 
HF_TOKEN = os.environ.get("HF_TOKEN")

def patch_model(pipe):
    print("🛠️ Patching Transformer with GQA and MoE...")
    # The attribute names below ('attn', 'ffn', and the q/k/v/o projections inside them) are
    # assumptions about the transformer block layout; blocks that don't expose them are
    # skipped silently by the hasattr guards, so verify against the installed diffusers version.
    for i, block in enumerate(pipe.transformer.blocks):
        if hasattr(block, 'attn'):
            block.attn = GQAAttention(block.attn)
        # Replace the FFN with a sparse MoE on every other block only.
        if hasattr(block, 'ffn') and i % 2 == 0:
            block.ffn = SparseMoEFFN(block.ffn)
    return pipe
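
# Optional helper for sanity-checking how many blocks actually got patched, since the
# hasattr guards above skip blocks silently when the attribute names don't line up.
def count_patched_blocks(pipe):
    gqa = sum(isinstance(b.attn, GQAAttention) for b in pipe.transformer.blocks if hasattr(b, "attn"))
    moe = sum(isinstance(b.ffn, SparseMoEFFN) for b in pipe.transformer.blocks if hasattr(b, "ffn"))
    return gqa, moe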

# =========================================================
# 3. GENERATION
# =========================================================

@spaces.GPU(duration=600)
def generate_20s_video(image_path, prompt, duration, steps):
    if not HF_TOKEN:
        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")

    print(f"⏳ Loading Model: {MODEL_ID}")
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        token=HF_TOKEN
    )
    
    # Apply architecture modifications
    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload() 
    pipe.vae.enable_tiling() 
    
    img = Image.open(image_path).convert("RGB")
    img = img.resize((832, 480)) # Standard Wan 2.1 480P resolution (~16:9)
    
    # Wan formula for frames: 4n + 1
    num_frames = int(duration * 16)
    num_frames = ((num_frames - 1) // 4) * 4 + 1
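    # Worked example: duration=20 at 16 fps gives 320 raw frames, snapped to 317 (4*79 + 1).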
        
    with torch.inference_mode():
        output = pipe(
            image=img,
            prompt=prompt + ", high quality, cinematically consistent",
            num_frames=num_frames,
            num_inference_steps=steps,
            guidance_scale=5.0,
            generator=torch.Generator("cuda").manual_seed(42)
        )

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
        video_path = f.name
    
    export_to_video(output.frames[0], video_path, fps=16)
    return video_path

# Gradio Setup
with gr.Blocks() as demo:
    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
    with gr.Row():
        with gr.Column():
            img = gr.Image(type="filepath", label="Input Image")
            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
            stp = gr.Slider(10, 30, value=20, label="Steps")
            btn = gr.Button("Generate 20s Video")
        with gr.Column():
            vid = gr.Video()

    btn.click(generate_20s_video, [img, txt, dur, stp], vid)

demo.queue().launch()