tester343 committed
Commit 85ab54a · verified · 1 Parent(s): 878b63c

Update app.py

Files changed (1):
  1. app.py +37 -97
app.py CHANGED
@@ -14,22 +14,18 @@ from diffusers import WanImageToVideoPipeline
 from diffusers.utils import export_to_video
 
 # =========================================================
-# 1. ARCHITECTURAL UPGRADE COMPONENTS (GQA & MoE)
+# 1. ARCHITECTURAL UPGRADES (GQA & MoE)
 # =========================================================
 
+# The custom classes (GQAAttention, SparseMoEFFN) stay the same as
+# they are architectural modifications to the base model's logic.
+
 class GQAAttention(nn.Module):
-    """
-    GROUPED QUERY ATTENTION (GQA)
-    Adjusts the dense attention to a grouped structure (Mistral-style).
-    Reduces KV-cache by 4x, critical for 20s+ (321 frames) generation.
-    """
     def __init__(self, original_attn):
         super().__init__()
-        # Extract parameters from the original Wan attention layer
         self.num_heads = original_attn.num_heads
         self.head_dim = original_attn.head_dim
-        self.num_kv_heads = self.num_heads // 4  # 4:1 Ratio for GQA
-
+        self.num_kv_heads = max(1, self.num_heads // 4)
         self.q_proj = original_attn.q_proj
         self.k_proj = original_attn.k_proj
         self.v_proj = original_attn.v_proj
@@ -37,170 +33,114 @@ class GQAAttention(nn.Module):
 
     def forward(self, x, freqs_cis=None):
         batch, seq_len, _ = x.shape
-
         q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
         k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
         v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
-
-        # Apply RoPE (Rotary Position Embeddings)
-        # We reuse Wan's native freqs_cis to ensure spatial/temporal logic stays intact
-
-        # Expand K/V for multi-head attention
         k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
         v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
-
         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
-
-        # Efficient scaled dot product attention
         attn_output = F.scaled_dot_product_attention(q, k, v)
         attn_output = attn_output.transpose(1, 2).reshape(batch, seq_len, -1)
-
         return self.o_proj(attn_output)
 
 class SparseMoEFFN(nn.Module):
-    """
-    MIXTURE OF EXPERTS (MoE)
-    Replaces the standard dense Feed-Forward Network.
-    Routes video tokens to specialized experts (Textures vs. Motion).
-    """
     def __init__(self, original_ffn):
         super().__init__()
         in_dim = original_ffn.ffn[0].in_features
-        self.router = nn.Linear(in_dim, 8)  # 8 Experts
+        self.router = nn.Linear(in_dim, 8)
         self.experts = nn.ModuleList([
-            nn.Sequential(
-                nn.Linear(in_dim, in_dim * 2),
-                nn.SiLU(),
-                nn.Linear(in_dim * 2, in_dim)
-            ) for _ in range(8)
+            nn.Sequential(nn.Linear(in_dim, in_dim * 2), nn.SiLU(), nn.Linear(in_dim * 2, in_dim))
+            for _ in range(8)
         ])
         self.top_k = 2
 
     def forward(self, x):
         batch, seq, dim = x.shape
         flat_x = x.view(-1, dim)
-
-        # Gate tokens to top-2 experts
         logits = self.router(flat_x)
         weights, selected_experts = torch.topk(logits, self.top_k)
         weights = F.softmax(weights, dim=-1)
-
         out = torch.zeros_like(flat_x)
         for i, expert in enumerate(self.experts):
             mask = (selected_experts == i).any(dim=-1)
             if mask.any():
-                # Apply expert weight
-                expert_out = expert(flat_x[mask])
-                out[mask] += expert_out * weights[mask][:, :1]
-
+                out[mask] += expert(flat_x[mask]) * weights[mask][:, :1]
         return out.view(batch, seq, dim)
 
 # =========================================================
-# 2. MODEL PATCHING & LOADING
+# 2. CONFIGURATION & PATCHING
 # =========================================================
 
-MODEL_ID = "Wan-AI/Wan2.1-I2V-1.3B-480P-Diffusers"
+# CORRECT MODEL ID: Wan 2.1 I2V 14B is the standard for Image-to-Video
+MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-def patch_wan_model(pipe):
-    """Injects GQA and MoE into the Transformer architecture"""
-    print("🛠️ Patching Wan Transformer: Injecting GQA and MoE...")
+def patch_model(pipe):
+    print("🛠️ Patching Transformer with GQA and MoE...")
     for i, block in enumerate(pipe.transformer.blocks):
-        # Patch Attention -> GQA
         if hasattr(block, 'attn'):
             block.attn = GQAAttention(block.attn)
-        # Patch FFN -> MoE (Only in every 2nd block to keep compute efficient)
         if hasattr(block, 'ffn') and i % 2 == 0:
             block.ffn = SparseMoEFFN(block.ffn)
     return pipe
 
 # =========================================================
-# 3. GENERATION ENGINE
+# 3. GENERATION
 # =========================================================
 
 @spaces.GPU(duration=600)
-def generate_long_video(image_path, prompt, duration, steps):
-    if not image_path:
-        raise gr.Error("Please upload an image.")
+def generate_20s_video(image_path, prompt, duration, steps):
+    if not HF_TOKEN:
+        raise gr.Error("HF_TOKEN missing. Please set it in your environment variables.")
 
-    print("⏳ Initializing Model...")
+    print(f"⏳ Loading Model: {MODEL_ID}")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         torch_dtype=torch.bfloat16,
         token=HF_TOKEN
     )
 
-    # Apply architectural adjustments
-    pipe = patch_wan_model(pipe)
-
-    # Optimization for 20s+ generation
+    # Apply architecture modifications
+    pipe = patch_model(pipe)
    pipe.enable_model_cpu_offload()
-    pipe.vae.enable_tiling()  # Mandatory for 300+ frames
+    pipe.vae.enable_tiling()
 
-    # Resize input image
     img = Image.open(image_path).convert("RGB")
-    img = img.resize((832, 480))  # Optimized for 16:9
+    img = img.resize((832, 480))  # Maintain 16:9 for 480P
 
-    # Calculate frames: 20 seconds @ 16 FPS = 321 frames (Wan 4n+1 rule)
+    # Wan formula for frames: 4n + 1
     num_frames = int(duration * 16)
-    if (num_frames - 1) % 4 != 0:
-        num_frames += (4 - ((num_frames - 1) % 4))
+    num_frames = ((num_frames - 1) // 4) * 4 + 1
 
-    print(f"🎬 Generation Start: {duration}s | {num_frames} frames")
-
     with torch.inference_mode():
         output = pipe(
             image=img,
-            prompt=prompt + ", cinematic, high resolution, consistent motion, masterpiece",
-            negative_prompt="static, blurry, shaky, low quality, morphing, jittery",
+            prompt=prompt + ", high quality, cinematically consistent",
             num_frames=num_frames,
             num_inference_steps=steps,
             guidance_scale=5.0,
-            generator=torch.Generator("cuda").manual_seed(random.randint(0, 10**6))
+            generator=torch.Generator("cuda").manual_seed(42)
         )
 
-    # Export to video
     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f:
         video_path = f.name
 
     export_to_video(output.frames[0], video_path, fps=16)
-
-    # Memory Management
-    del pipe
-    gc.collect()
-    torch.cuda.empty_cache()
-
     return video_path
 
-# =========================================================
-# 4. GRADIO INTERFACE (FIXED THEME)
-# =========================================================
-
+# Gradio Setup
 with gr.Blocks() as demo:
-    gr.HTML("<h1 style='text-align:center;'>Wan 1.3B: MoE + GQA Hybrid</h1>")
-    gr.Markdown("Architecture: **Grouped Query Attention** for 20s stability + **MoE** for Mistral-style efficiency.")
-
+    gr.Markdown("# 🎬 Optimized Wan 2.1 (GQA + MoE)")
     with gr.Row():
         with gr.Column():
-            img_input = gr.Image(type="filepath", label="Input Image")
-            prompt_input = gr.Textbox(
-                label="Motion Prompt",
-                value="A cinematic tracking shot of a white tiger running through a futuristic neon city at night"
-            )
-            with gr.Row():
-                duration_slider = gr.Slider(5, 25, value=20, step=5, label="Duration (Seconds)")
-                steps_slider = gr.Slider(10, 30, value=20, step=1, label="Quality Steps")
-
-            btn = gr.Button("🚀 Generate 20s Video", variant="primary")
-
+            img = gr.Image(type="filepath", label="Input Image")
+            txt = gr.Textbox(label="Prompt", value="A futuristic city with flying cars at sunset")
+            dur = gr.Slider(5, 20, value=20, label="Duration (Seconds)")
+            stp = gr.Slider(10, 30, value=20, label="Steps")
+            btn = gr.Button("Generate 20s Video")
         with gr.Column():
-            video_output = gr.Video(label="Final Generated Video (20s+)")
+            vid = gr.Video()
 
-    btn.click(
-        fn=generate_long_video,
-        inputs=[img_input, prompt_input, duration_slider, steps_slider],
-        outputs=video_output
-    )
+    btn.click(generate_20s_video, [img, txt, dur, stp], vid)
 
-if __name__ == "__main__":
-    demo.queue().launch()
+demo.queue().launch()
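
For readers skimming the diff, here is a minimal standalone sketch of the two mechanics the new code relies on: the grouped K/V head expansion used in GQAAttention.forward, and the 4n + 1 frame rounding. The tensor sizes and head counts below are illustrative assumptions for the sketch, not Wan's actual configuration.

import torch
import torch.nn.functional as F

# Toy sizes for illustration only (hypothetical, not Wan's real config).
batch, seq_len, num_heads, head_dim = 1, 10, 8, 16
num_kv_heads = max(1, num_heads // 4)  # 4:1 grouping, as in the patch

q = torch.randn(batch, seq_len, num_heads, head_dim)
k = torch.randn(batch, seq_len, num_kv_heads, head_dim)
v = torch.randn(batch, seq_len, num_kv_heads, head_dim)

# Each K/V head is shared by num_heads // num_kv_heads query heads.
k = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (batch, num_heads, seq_len, head_dim)

# The 4n + 1 frame constraint, rounded down as in the new code:
duration, fps = 20, 16
num_frames = int(duration * fps)              # 320
num_frames = ((num_frames - 1) // 4) * 4 + 1  # 317

One behavioral difference worth noting: the removed branch rounded up to the next valid count (20 s × 16 FPS = 320 → 321 frames), while the new expression rounds down (320 → 317), so a requested 20-second clip renders four frames shorter.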