TSXu committed on
Commit
46eccdb
·
1 Parent(s): 8d5a72c

Refactor AOT compilation to follow FLUX-Kontext-fp8 pattern exactly

Browse files

- Use tree_map_only for static dynamic_shapes
- Quantize before export (FLUX-Kontext-fp8 order)
- Use torch.export.export instead of draft_export
- Add comprehensive logging to file

Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +113 -56
.gitignore CHANGED
@@ -13,3 +13,4 @@ build/
13
  *.pth
14
  *.ckpt
15
  *.safetensors
 
 
13
  *.pth
14
  *.ckpt
15
  *.safetensors
16
+ *.log
app.py CHANGED
@@ -7,8 +7,40 @@ With Float8 quantization and AOT compilation for faster inference
7
  # Install compatible torch 2.8 + torchvision 0.23 + torchao + spaces (for AOT compilation)
8
  # spaces.aoti_capture requires PyTorch 2.8+
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch>=2.8,<2.9" "torchvision>=0.23,<0.24" torchao spaces')
11
- print("torch 2.8 + torchvision 0.23 + torchao + spaces installation complete!")
12
 
13
  # IMPORTANT: import spaces first before any CUDA-related packages
14
  import spaces
@@ -191,63 +223,88 @@ def compile_model_first_time():
191
  """
192
  global _is_optimized, generator
193
 
194
- print("="*50)
195
- print("First-time run: Loading model and AOT compiling...")
196
- print("="*50)
197
 
198
- # Load model
199
- gen = init_generator()
200
- model = gen.model
201
-
202
- # ========== AOT Compilation (FLUX-Kontext-fp8 pattern) ==========
203
-
204
- # Step 1: Capture model forward during a real inference
205
- print("Step 1: Capturing model forward pass with spaces.aoti_capture...")
206
- with spaces.aoti_capture(model) as call:
207
- gen.generate(
208
- text="ζ΅‹θ―•",
209
- font_style="ζ₯·",
210
- author=None,
211
- num_steps=1,
212
- seed=42,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  )
214
- print("βœ“ Forward pass captured!")
215
-
216
- # Step 2: Build dynamic shapes (None = fixed shapes)
217
- print("Step 2: Building dynamic shapes...")
218
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
219
- print("βœ“ Dynamic shapes built!")
220
-
221
- # Step 3: Apply Float8 quantization
222
- print("Step 3: Applying Float8 quantization...")
223
- quantize_(model, Float8DynamicActivationFloat8WeightConfig())
224
- print("βœ“ Float8 quantization complete!")
225
-
226
- # Step 4: Export model with torch.export
227
- print("Step 4: Exporting model with torch.export...")
228
- exported = torch.export.export(
229
- mod=model,
230
- args=call.args,
231
- kwargs=call.kwargs,
232
- dynamic_shapes=dynamic_shapes,
233
- )
234
- print("βœ“ Model exported!")
235
-
236
- # Step 5: AOT compile with spaces.aoti_compile
237
- print("Step 5: AOT compiling with spaces.aoti_compile...")
238
- print(f" Inductor configs: {INDUCTOR_CONFIGS}")
239
- compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
240
- print("βœ“ AOT compilation complete!")
241
-
242
- # Step 6: Apply compiled model
243
- print("Step 6: Applying compiled model...")
244
- spaces.aoti_apply(compiled, model)
245
- print("βœ“ AOT compiled model applied!")
246
-
247
- _is_optimized = True
248
- print("="*50)
249
- print("βœ“ Model loaded and AOT compiled!")
250
- print("="*50)
251
 
252
  return gen
253
 
 
7
  # Install compatible torch 2.8 + torchvision 0.23 + torchao + spaces (for AOT compilation)
8
  # spaces.aoti_capture requires PyTorch 2.8+
9
  import os
10
+ import sys
11
+ import logging
12
+ import traceback
13
+ from datetime import datetime
14
+
15
+ # Setup logging to file
16
+ LOG_FILE = "aot_compile.log"
17
+ logging.basicConfig(
18
+ level=logging.DEBUG,
19
+ format='%(asctime)s [%(levelname)s] %(message)s',
20
+ handlers=[
21
+ logging.FileHandler(LOG_FILE, mode='w', encoding='utf-8'),
22
+ logging.StreamHandler(sys.stdout)
23
+ ]
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # Also redirect print to log
28
+ class LoggingPrinter:
29
+ def __init__(self, logger, original_stdout):
30
+ self.logger = logger
31
+ self.original_stdout = original_stdout
32
+ def write(self, message):
33
+ if message.strip():
34
+ self.logger.info(message.strip())
35
+ self.original_stdout.write(message)
36
+ def flush(self):
37
+ self.original_stdout.flush()
38
+
39
+ # Keep original stdout for gradio
40
+ _original_stdout = sys.stdout
41
+
42
  os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch>=2.8,<2.9" "torchvision>=0.23,<0.24" torchao spaces')
43
+ logger.info("torch 2.8 + torchvision 0.23 + torchao + spaces installation complete!")
44
 
45
  # IMPORTANT: import spaces first before any CUDA-related packages
46
  import spaces
 
223
  """
224
  global _is_optimized, generator
225
 
226
+ logger.info("="*50)
227
+ logger.info("First-time run: Loading model and AOT compiling...")
228
+ logger.info("="*50)
229
 
230
+ try:
231
+ # Load model
232
+ gen = init_generator()
233
+ model = gen.model
234
+
235
+ # ========== AOT Compilation (FLUX-Kontext-fp8 pattern exactly) ==========
236
+
237
+ # Step 1: Capture model forward during a real inference
238
+ logger.info("Step 1: Capturing model forward pass with spaces.aoti_capture...")
239
+ with spaces.aoti_capture(model) as call:
240
+ gen.generate(
241
+ text="ζ΅‹θ―•",
242
+ font_style="ζ₯·",
243
+ author=None,
244
+ num_steps=1,
245
+ seed=42,
246
+ )
247
+ logger.info("βœ“ Forward pass captured!")
248
+
249
+ # Log call info
250
+ logger.info(f" call.args types: {[type(a).__name__ for a in call.args]}")
251
+ logger.info(f" call.kwargs keys: {list(call.kwargs.keys())}")
252
+ for k, v in call.kwargs.items():
253
+ if hasattr(v, 'shape'):
254
+ logger.info(f" {k}: tensor shape={v.shape}, dtype={v.dtype}")
255
+ else:
256
+ logger.info(f" {k}: {type(v).__name__} = {v}")
257
+
258
+ # Step 2: Build dynamic_shapes (FLUX-Kontext-fp8 pattern: all static)
259
+ # tree_map_only maps all tensors/bools to None = static shape
260
+ logger.info("Step 2: Building static shapes (FLUX-Kontext-fp8 pattern)...")
261
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
262
+ logger.info(f" dynamic_shapes keys: {list(dynamic_shapes.keys()) if dynamic_shapes else 'None'}")
263
+ logger.info("βœ“ Static shapes configured!")
264
+
265
+ # Step 3: Apply Float8 quantization BEFORE export (FLUX-Kontext-fp8 pattern)
266
+ logger.info("Step 3: Applying Float8 quantization...")
267
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig())
268
+ logger.info("βœ“ Float8 quantization complete!")
269
+
270
+ # Step 4: Export model with torch.export.export (not draft_export)
271
+ logger.info("Step 4: Exporting model with torch.export.export...")
272
+ exported = torch.export.export(
273
+ mod=model,
274
+ args=call.args,
275
+ kwargs=call.kwargs,
276
+ dynamic_shapes=dynamic_shapes,
277
  )
278
+ logger.info("βœ“ Model exported!")
279
+
280
+ # Step 5: AOT compile with spaces.aoti_compile
281
+ logger.info("Step 5: AOT compiling with spaces.aoti_compile...")
282
+ logger.info(f" Inductor configs: {INDUCTOR_CONFIGS}")
283
+ compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
284
+ logger.info("βœ“ AOT compilation complete!")
285
+
286
+ # Step 6: Apply compiled model
287
+ logger.info("Step 6: Applying compiled model...")
288
+ spaces.aoti_apply(compiled, model)
289
+ logger.info("βœ“ AOT compiled model applied!")
290
+
291
+ _is_optimized = True
292
+ logger.info("="*50)
293
+ except Exception as e:
294
+ logger.error("="*50)
295
+ logger.error("AOT COMPILATION FAILED!")
296
+ logger.error("="*50)
297
+ logger.error(f"Exception: {e}")
298
+ logger.error("Full traceback:")
299
+ logger.error(traceback.format_exc())
300
+ # Save full error to file
301
+ with open("aot_error.log", "w") as f:
302
+ f.write(f"Exception: {e}\n\n")
303
+ f.write(traceback.format_exc())
304
+ raise
305
+
306
+ logger.info("βœ“ Model loaded and AOT compiled!")
307
+ logger.info("="*50)
 
 
 
 
 
 
 
308
 
309
  return gen
310