TSXu committed on
Commit
7c94f94
·
1 Parent(s): 6e4c09b

Implement proper AOT compilation with spaces.aoti_capture/compile/apply

Browse files

Exactly following FLUX-Kontext-fp8 pattern:
1. spaces.aoti_capture(model) - capture forward pass during real inference
2. tree_map_only to build dynamic shapes
3. quantize_(model, Float8DynamicActivationFloat8WeightConfig())
4. torch.export.export(model, args, kwargs, dynamic_shapes)
5. spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
6. spaces.aoti_apply(compiled, model)

Separate functions:
- compile_model_first_time(): 300s for AOT compilation (one-time)
- run_generation(): 120s for normal generation

Files changed (1) hide show
  1. app.py +76 -52
app.py CHANGED
@@ -157,22 +157,44 @@ def init_generator():
157
  return generator
158
 
159
 
160
- def optimize_transformer_(gen):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  """
162
- Apply Float8 quantization + AOT compilation using spaces.aoti_capture.
163
- Based on FLUX-Kontext-fp8 pattern.
164
  """
165
  model = gen.model
166
 
167
- @spaces.GPU(duration=300) # 5 minutes for compilation
168
  def compile_transformer():
169
  print("="*50)
170
- print("Starting AOT compilation with spaces.aoti_capture...")
171
  print("="*50)
172
 
173
- # Step 1: Capture model forward during a real generation
174
- print("Capturing model forward pass...")
175
  with spaces.aoti_capture(model) as call:
 
176
  gen.generate(
177
  text="测试",
178
  font_style="楷",
@@ -180,17 +202,20 @@ def optimize_transformer_(gen):
180
  num_steps=1,
181
  seed=42,
182
  )
 
183
 
184
- # Step 2: Build dynamic shapes (we use fixed shapes, so set to None)
 
185
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
 
186
 
187
  # Step 3: Apply Float8 quantization
188
- print("Applying Float8 quantization...")
189
  quantize_(model, Float8DynamicActivationFloat8WeightConfig())
190
  print("✓ Float8 quantization complete!")
191
 
192
- # Step 4: Export model
193
- print("Exporting model with torch.export...")
194
  exported = torch.export.export(
195
  mod=model,
196
  args=call.args,
@@ -199,61 +224,62 @@ def optimize_transformer_(gen):
199
  )
200
  print("✓ Model exported!")
201
 
202
- # Step 5: AOT compile
203
- print("AOT compiling with spaces.aoti_compile...")
204
  print(f" Inductor configs: {INDUCTOR_CONFIGS}")
205
- return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
 
 
 
206
 
207
  # Run compilation and apply the result
208
- print("="*50)
209
- print("Running AOT compilation (this takes 3-5 minutes)...")
210
- print("="*50)
211
  spaces.aoti_apply(compile_transformer(), model)
212
  print("="*50)
213
- print("✓ AOT compilation complete! Model is now optimized.")
214
  print("="*50)
215
 
216
 
217
- def update_font_choices(author: str):
 
218
  """
219
- Update available font choices based on selected author
 
220
  """
221
- if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
222
- choices = list(FONT_STYLE_NAMES.values())
223
- else:
224
- available_fonts = AUTHOR_FONTS[author]
225
- choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
226
 
227
- return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
228
-
229
-
230
- def parse_font_style(font_style: str) -> str:
231
- """Extract font key from display name"""
232
- for font_key, font_display in FONT_STYLE_NAMES.items():
233
- if font_display == font_style:
234
- return font_key
235
- return None
 
 
 
 
 
 
 
236
 
237
 
238
  @spaces.GPU(duration=120) # 2 minutes for normal generation
239
  def run_generation(text, font, author, num_steps, start_seed, num_images):
240
  """
241
- Run generation with the optimized model.
242
- Duration: 20s base + ~4s per step per image.
243
  """
244
- gen = init_generator()
245
 
246
  results = []
247
  seeds_used = []
248
-
249
  for i in range(num_images):
250
  current_seed = start_seed + i
251
  result_img, cond_img = gen.generate(
252
- text=text,
253
- font_style=font,
254
- author=author,
255
- num_steps=num_steps,
256
- seed=current_seed,
257
  )
258
  results.append((result_img, f"Seed: {current_seed}"))
259
  seeds_used.append(current_seed)
@@ -271,8 +297,7 @@ def interactive_session(
271
  progress=gr.Progress()
272
  ):
273
  """
274
- Interactive session with separate compilation and generation phases.
275
-
276
  - First time: 5 min for AOT compilation (one-time)
277
  - After that: 2 min for generation
278
  """
@@ -292,17 +317,16 @@ def interactive_session(
292
  # Determine author
293
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
294
 
295
- # Step 1: AOT Compile if not done yet (5 min, one-time only)
296
  if not _is_optimized:
297
- yield "⏳ 首次运行,需要编译优化模型(约3-5分钟,仅此一次)...", []
298
- gen = init_generator()
299
- optimize_transformer_(gen) # This uses @spaces.GPU(duration=300) internally
300
- _is_optimized = True
301
- yield "✅ 模型编译完成!后续生成将会很快。", []
302
 
303
  # Step 2: Run generation (2 min)
304
  yield f"🎨 开始生成 {num_images} 张图片...", []
305
- progress(0.1, desc="生成中...")
306
 
307
  results, seeds_used = run_generation(
308
  text, font, author, num_steps, start_seed, num_images
 
157
  return generator
158
 
159
 
160
+ def update_font_choices(author: str):
161
+ """
162
+ Update available font choices based on selected author
163
+ """
164
+ if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
165
+ choices = list(FONT_STYLE_NAMES.values())
166
+ else:
167
+ available_fonts = AUTHOR_FONTS[author]
168
+ choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
169
+
170
+ return gr.Dropdown(choices=choices, value=choices[0] if choices else None)
171
+
172
+
173
+ def parse_font_style(font_style: str) -> str:
174
+ """Extract font key from display name"""
175
+ for font_key, font_display in FONT_STYLE_NAMES.items():
176
+ if font_display == font_style:
177
+ return font_key
178
+ return None
179
+
180
+
181
+ def aot_compile_transformer(gen):
182
  """
183
+ AOT compile the transformer using spaces.aoti_capture/compile/apply.
184
+ Exactly following FLUX-Kontext-fp8 pattern.
185
  """
186
  model = gen.model
187
 
188
+ @spaces.GPU(duration=300) # 5 minutes for AOT compilation
189
  def compile_transformer():
190
  print("="*50)
191
+ print("Starting AOT compilation (FLUX-Kontext-fp8 pattern)...")
192
  print("="*50)
193
 
194
+ # Step 1: Capture model forward during a real inference
195
+ print("Step 1: Capturing model forward pass with spaces.aoti_capture...")
196
  with spaces.aoti_capture(model) as call:
197
+ # Run a sample generation to capture the forward call
198
  gen.generate(
199
  text="测试",
200
  font_style="楷",
 
202
  num_steps=1,
203
  seed=42,
204
  )
205
+ print("✓ Forward pass captured!")
206
 
207
+ # Step 2: Build dynamic shapes (None = fixed shapes)
208
+ print("Step 2: Building dynamic shapes...")
209
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
210
+ print("✓ Dynamic shapes built!")
211
 
212
  # Step 3: Apply Float8 quantization
213
+ print("Step 3: Applying Float8 quantization...")
214
  quantize_(model, Float8DynamicActivationFloat8WeightConfig())
215
  print("✓ Float8 quantization complete!")
216
 
217
+ # Step 4: Export model with torch.export
218
+ print("Step 4: Exporting model with torch.export...")
219
  exported = torch.export.export(
220
  mod=model,
221
  args=call.args,
 
224
  )
225
  print("✓ Model exported!")
226
 
227
+ # Step 5: AOT compile with spaces.aoti_compile
228
+ print("Step 5: AOT compiling with spaces.aoti_compile...")
229
  print(f" Inductor configs: {INDUCTOR_CONFIGS}")
230
+ compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
231
+ print("✓ AOT compilation complete!")
232
+
233
+ return compiled
234
 
235
  # Run compilation and apply the result
236
+ print("Running AOT compilation...")
 
 
237
  spaces.aoti_apply(compile_transformer(), model)
238
  print("="*50)
239
+ print("✓ AOT compiled model applied!")
240
  print("="*50)
241
 
242
 
243
+ @spaces.GPU(duration=300) # 5 minutes for first-time AOT compilation
244
+ def compile_model_first_time():
245
  """
246
+ First-time: Load model and run AOT compilation.
247
+ Returns the optimized generator.
248
  """
249
+ global _is_optimized, generator
 
 
 
 
250
 
251
+ print("="*50)
252
+ print("First-time run: Loading model and running AOT compilation...")
253
+ print("="*50)
254
+
255
+ # Load model
256
+ gen = init_generator()
257
+
258
+ # AOT compile the transformer
259
+ aot_compile_transformer(gen)
260
+
261
+ _is_optimized = True
262
+ print("="*50)
263
+ print("✓ Model loaded and AOT compiled!")
264
+ print("="*50)
265
+
266
+ return gen
267
 
268
 
269
  @spaces.GPU(duration=120) # 2 minutes for normal generation
270
  def run_generation(text, font, author, num_steps, start_seed, num_images):
271
  """
272
+ Run generation with the AOT-compiled model.
 
273
  """
274
+ gen = init_generator() # Returns the already-compiled generator
275
 
276
  results = []
277
  seeds_used = []
 
278
  for i in range(num_images):
279
  current_seed = start_seed + i
280
  result_img, cond_img = gen.generate(
281
+ text=text, font_style=font, author=author,
282
+ num_steps=num_steps, seed=current_seed,
 
 
 
283
  )
284
  results.append((result_img, f"Seed: {current_seed}"))
285
  seeds_used.append(current_seed)
 
297
  progress=gr.Progress()
298
  ):
299
  """
300
+ Interactive session:
 
301
  - First time: 5 min for AOT compilation (one-time)
302
  - After that: 2 min for generation
303
  """
 
317
  # Determine author
318
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
319
 
320
+ # Step 1: AOT compile if not done yet (5 min, one-time)
321
  if not _is_optimized:
322
+ yield "⏳ 首次运行,AOT编译优化模型(约3-5分钟,仅此一次)...", []
323
+ progress(0.1, desc="AOT编译中...")
324
+ compile_model_first_time()
325
+ yield "✅ AOT编译完成!", []
 
326
 
327
  # Step 2: Run generation (2 min)
328
  yield f"🎨 开始生成 {num_images} 张图片...", []
329
+ progress(0.5, desc="生成中...")
330
 
331
  results, seeds_used = run_generation(
332
  text, font, author, num_steps, start_seed, num_images