TSXu committed on
Commit
6e4c09b
·
1 Parent(s): 96d51ed

Use proper spaces.aoti_capture + aoti_compile + aoti_apply for AOT compilation

Browse files

- Use spaces.aoti_capture to capture real forward pass inputs
- Use spaces.aoti_compile with INDUCTOR_CONFIGS (including triton.cudagraphs)
- Use spaces.aoti_apply to apply compiled model
- Separate compilation (5 min one-time) and generation (2 min)
- First run: AOT compilation takes 3-5 minutes
- Subsequent runs: only 2 minutes for generation

This follows the FLUX-Kontext-fp8 pattern exactly.

Files changed (1) hide show
  1. app.py +79 -151
app.py CHANGED
@@ -101,14 +101,9 @@ print("="*50)
101
  # ============================================================
102
  # AOT Optimization Configuration (from FLUX-Kontext-fp8)
103
  # ============================================================
 
104
  from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
105
 
106
- # Fixed input dimensions for this model
107
- # width=128, height=896 (7 chars × 128)
108
- # After VAE packing: h=448, w=64, img_seq_len = 448*64 = 28672
109
- IMG_SEQ_LEN = 28672 # Fixed: (896/2) * (128/2)
110
- TXT_SEQ_LEN = 512 # Fixed: T5 max_length
111
-
112
  # Inductor configuration for optimal performance
113
  INDUCTOR_CONFIGS = {
114
  'conv_1x1_as_mm': True,
@@ -120,102 +115,9 @@ INDUCTOR_CONFIGS = {
120
  }
121
 
122
 
123
- def create_sample_inputs(device="cuda", dtype=torch.float32):
124
- """
125
- Create sample inputs with fixed dimensions for torch.export.
126
- """
127
- batch_size = 1
128
- hidden_size = 3072 # Flux hidden size
129
- vec_dim = 768 # CLIP vec dim
130
- cond_txt_dim = 896 # Condition text embedding dim
131
-
132
- sample_inputs = {
133
- 'img': torch.randn(batch_size, IMG_SEQ_LEN, 64, device=device, dtype=dtype), # 64 = in_channels
134
- 'img_ids': torch.zeros(batch_size, IMG_SEQ_LEN, 3, device=device, dtype=dtype),
135
- 'txt': torch.randn(batch_size, TXT_SEQ_LEN, 4096, device=device, dtype=dtype), # 4096 = T5 dim
136
- 'txt_ids': torch.zeros(batch_size, TXT_SEQ_LEN, 3, device=device, dtype=dtype),
137
- 'y': torch.randn(batch_size, vec_dim, device=device, dtype=dtype),
138
- 'timesteps': torch.tensor([0.5], device=device, dtype=dtype),
139
- 'timesteps2': torch.tensor([0.5], device=device, dtype=dtype),
140
- 'cond_txt_latent': torch.randn(batch_size, 5, cond_txt_dim, device=device, dtype=dtype), # 5 cond tokens
141
- 'guidance': torch.tensor([3.5], device=device, dtype=dtype),
142
- }
143
- return sample_inputs
144
-
145
-
146
- def apply_aot_optimization(model, device="cuda"):
147
- """
148
- Apply Float8 quantization and AOT compilation with torch.export.
149
- Based on FLUX-Kontext-fp8 optimization pattern.
150
- """
151
- import torch._inductor.config as inductor_config
152
-
153
- # Apply inductor configurations
154
- for key, value in INDUCTOR_CONFIGS.items():
155
- if hasattr(inductor_config, key):
156
- setattr(inductor_config, key, value)
157
-
158
- print("="*50)
159
- print("Starting AOT optimization with fixed input shapes...")
160
- print(f" img_seq_len: {IMG_SEQ_LEN}")
161
- print(f" txt_seq_len: {TXT_SEQ_LEN}")
162
- print("="*50)
163
-
164
- # Step 1: Apply Float8 quantization
165
- print("Applying Float8 quantization...")
166
- quantize_(model, Float8DynamicActivationFloat8WeightConfig())
167
- print("✓ Float8 quantization complete!")
168
-
169
- # Step 2: Create sample inputs for export
170
- print("Creating sample inputs for torch.export...")
171
- sample_inputs = create_sample_inputs(device=device, dtype=torch.float32)
172
-
173
- # Step 3: Export model with fixed shapes (no dynamic dims needed)
174
- print("Exporting model with torch.export (fixed shapes)...")
175
- try:
176
- exported = torch.export.export(
177
- model,
178
- args=(),
179
- kwargs=sample_inputs,
180
- strict=False, # Allow some graph breaks if needed
181
- )
182
- print("✓ Model exported!")
183
-
184
- # Step 4: AOT compile with inductor
185
- print("AOT compiling with torch._inductor.aot_compile...")
186
- compiled_path = torch._inductor.aot_compile(
187
- exported.module(),
188
- args=(),
189
- kwargs=sample_inputs,
190
- options=INDUCTOR_CONFIGS,
191
- )
192
- print(f"✓ AOT compiled to: {compiled_path}")
193
-
194
- # Step 5: Load the compiled model
195
- print("Loading AOT compiled model...")
196
- compiled_model = torch._export.aot_load(compiled_path, device=device)
197
- print("✓ AOT model loaded!")
198
-
199
- return compiled_model
200
-
201
- except Exception as e:
202
- print(f"AOT compilation failed: {e}")
203
- print("Falling back to torch.compile (JIT)...")
204
-
205
- # Fallback to JIT compilation
206
- compiled_model = torch.compile(
207
- model,
208
- mode="max-autotune",
209
- backend="inductor",
210
- fullgraph=False,
211
- )
212
- print("✓ torch.compile (JIT) applied!")
213
- return compiled_model
214
-
215
-
216
  def init_generator():
217
- """Initialize the generator with Float8 + AOT compilation"""
218
- global generator, _cached_model_dir, _is_optimized
219
 
220
  if generator is None:
221
  # Enable CUDA optimizations
@@ -247,21 +149,71 @@ def init_generator():
247
  author_descriptions_path='dataset/calligraphy_styles_en.json',
248
  use_deepspeed=False,
249
  use_4bit_quantization=False,
250
- use_float8_quantization=False, # Apply via AOT below
251
- use_torch_compile=False, # Apply via AOT below
252
  dtype="fp32",
253
  )
254
-
255
- # Apply Float8 quantization + AOT compilation (fixed input shapes)
256
- if not _is_optimized:
257
- print("Applying Float8 + AOT optimizations to transformer...")
258
- generator.model = apply_aot_optimization(generator.model, device="cuda")
259
- _is_optimized = True
260
- print("✓ Transformer optimized with Float8 + AOT compilation!")
261
 
262
  return generator
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  def update_font_choices(author: str):
266
  """
267
  Update available font choices based on selected author
@@ -283,31 +235,13 @@ def parse_font_style(font_style: str) -> str:
283
  return None
284
 
285
 
286
- @spaces.GPU(duration=300) # 5 minutes for first-time compilation
287
- def compile_and_warmup():
288
- """
289
- Compile the model with Float8 + AOT optimization (first time only).
290
- """
291
- print("="*50)
292
- print("First-time compilation starting...")
293
- print("="*50)
294
- gen = init_generator()
295
- # Warmup run to trigger JIT compilation
296
- print("Running warmup generation...")
297
- gen.generate(text="测", font_style="楷", author=None, num_steps=1, seed=42)
298
- print("="*50)
299
- print("Compilation and warmup complete!")
300
- print("="*50)
301
- return gen
302
-
303
-
304
- @spaces.GPU(duration=120) # 2 minutes for normal generation (20s + 25steps * 4s = ~120s)
305
  def run_generation(text, font, author, num_steps, start_seed, num_images):
306
  """
307
- Run generation after model is already compiled.
308
- Duration: ~20s base + 4s per step, for up to 8 images.
309
  """
310
- gen = init_generator() # Returns cached generator
311
 
312
  results = []
313
  seeds_used = []
@@ -337,18 +271,10 @@ def interactive_session(
337
  progress=gr.Progress()
338
  ):
339
  """
340
- Interactive session: compile model once (5 min), then generate images (2 min each).
341
-
342
- Args:
343
- text: Input text (1-7 characters)
344
- author_dropdown: Selected author
345
- font_style: Font style
346
- num_steps: Inference steps
347
- start_seed: Starting seed
348
- num_images: Number of images to generate (each with different seed)
349
 
350
- Yields:
351
- Progress status, gallery of results
352
  """
353
  global _is_optimized
354
 
@@ -366,14 +292,16 @@ def interactive_session(
366
  # Determine author
367
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
368
 
369
- # Step 1: Compile model if needed (first time only, 5 min budget)
370
  if not _is_optimized:
371
- yield "⏳ 首次运行,正在编译优化模型(约3-5分钟)... / First run, compiling...", []
372
- compile_and_warmup()
373
- yield "✅ 模型编译完成!/ Model compiled!", []
374
-
375
- # Step 2: Run generation (2 min budget)
376
- yield f"🎨 开始生成 {num_images} 张图片... / Generating {num_images} images...", []
 
 
377
  progress(0.1, desc="生成中...")
378
 
379
  results, seeds_used = run_generation(
@@ -382,7 +310,7 @@ def interactive_session(
382
 
383
  progress(1.0, desc="完成!")
384
 
385
- # Final yield
386
  if num_images > 1:
387
  final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
388
  else:
 
101
  # ============================================================
102
  # AOT Optimization Configuration (from FLUX-Kontext-fp8)
103
  # ============================================================
104
+ from torch.utils._pytree import tree_map_only
105
  from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
106
 
 
 
 
 
 
 
107
  # Inductor configuration for optimal performance
108
  INDUCTOR_CONFIGS = {
109
  'conv_1x1_as_mm': True,
 
115
  }
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def init_generator():
119
+ """Initialize the generator (without optimization - that's done separately)"""
120
+ global generator, _cached_model_dir
121
 
122
  if generator is None:
123
  # Enable CUDA optimizations
 
149
  author_descriptions_path='dataset/calligraphy_styles_en.json',
150
  use_deepspeed=False,
151
  use_4bit_quantization=False,
152
+ use_float8_quantization=False,
153
+ use_torch_compile=False,
154
  dtype="fp32",
155
  )
 
 
 
 
 
 
 
156
 
157
  return generator
158
 
159
 
160
+ def optimize_transformer_(gen):
161
+ """
162
+ Apply Float8 quantization + AOT compilation using spaces.aoti_capture.
163
+ Based on FLUX-Kontext-fp8 pattern.
164
+ """
165
+ model = gen.model
166
+
167
+ @spaces.GPU(duration=300) # 5 minutes for compilation
168
+ def compile_transformer():
169
+ print("="*50)
170
+ print("Starting AOT compilation with spaces.aoti_capture...")
171
+ print("="*50)
172
+
173
+ # Step 1: Capture model forward during a real generation
174
+ print("Capturing model forward pass...")
175
+ with spaces.aoti_capture(model) as call:
176
+ gen.generate(
177
+ text="测试",
178
+ font_style="楷",
179
+ author=None,
180
+ num_steps=1,
181
+ seed=42,
182
+ )
183
+
184
+ # Step 2: Build dynamic shapes (we use fixed shapes, so set to None)
185
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
186
+
187
+ # Step 3: Apply Float8 quantization
188
+ print("Applying Float8 quantization...")
189
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig())
190
+ print("✓ Float8 quantization complete!")
191
+
192
+ # Step 4: Export model
193
+ print("Exporting model with torch.export...")
194
+ exported = torch.export.export(
195
+ mod=model,
196
+ args=call.args,
197
+ kwargs=call.kwargs,
198
+ dynamic_shapes=dynamic_shapes,
199
+ )
200
+ print("✓ Model exported!")
201
+
202
+ # Step 5: AOT compile
203
+ print("AOT compiling with spaces.aoti_compile...")
204
+ print(f" Inductor configs: {INDUCTOR_CONFIGS}")
205
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
206
+
207
+ # Run compilation and apply the result
208
+ print("="*50)
209
+ print("Running AOT compilation (this takes 3-5 minutes)...")
210
+ print("="*50)
211
+ spaces.aoti_apply(compile_transformer(), model)
212
+ print("="*50)
213
+ print("✓ AOT compilation complete! Model is now optimized.")
214
+ print("="*50)
215
+
216
+
217
  def update_font_choices(author: str):
218
  """
219
  Update available font choices based on selected author
 
235
  return None
236
 
237
 
238
+ @spaces.GPU(duration=120) # 2 minutes for normal generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  def run_generation(text, font, author, num_steps, start_seed, num_images):
240
  """
241
+ Run generation with the optimized model.
242
+ Duration: 20s base + ~4s per step per image.
243
  """
244
+ gen = init_generator()
245
 
246
  results = []
247
  seeds_used = []
 
271
  progress=gr.Progress()
272
  ):
273
  """
274
+ Interactive session with separate compilation and generation phases.
 
 
 
 
 
 
 
 
275
 
276
+ - First time: 5 min for AOT compilation (one-time)
277
+ - After that: 2 min for generation
278
  """
279
  global _is_optimized
280
 
 
292
  # Determine author
293
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
294
 
295
+ # Step 1: AOT Compile if not done yet (5 min, one-time only)
296
  if not _is_optimized:
297
+ yield "⏳ 首次运行,需要编译优化模型(约3-5分钟,仅此一次)...", []
298
+ gen = init_generator()
299
+ optimize_transformer_(gen) # This uses @spaces.GPU(duration=300) internally
300
+ _is_optimized = True
301
+ yield "✅ 模型编译完成!后续生成将会很快。", []
302
+
303
+ # Step 2: Run generation (2 min)
304
+ yield f"🎨 开始生成 {num_images} 张图片...", []
305
  progress(0.1, desc="生成中...")
306
 
307
  results, seeds_used = run_generation(
 
310
 
311
  progress(1.0, desc="完成!")
312
 
313
+ # Final status
314
  if num_images > 1:
315
  final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
316
  else: