TSXu committed on
Commit
f1e9349
·
1 Parent(s): 7c94f94

Fix nested @spaces.GPU and torchao version compatibility

Browse files

1. Remove nested @spaces.GPU - all AOT compilation now in single function
2. Install nightly torchao compatible with torch 2.10.0 at startup
3. All CUDA operations now in @spaces.GPU decorated functions:
- compile_model_first_time(): 300s - AOT compilation
- run_generation(): 120s - normal generation

Files changed (1) hide show
  1. app.py +58 -67
app.py CHANGED
@@ -4,6 +4,17 @@ Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
  With Float8 quantization and AOT compilation for faster inference
5
  """
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  # IMPORTANT: import spaces first before any CUDA-related packages
8
  import spaces
9
 
@@ -12,7 +23,6 @@ import json
12
  import csv
13
  import time
14
  import torch
15
- import os
16
 
17
  # Load author and font mappings from CSV
18
  def load_author_fonts_from_csv(csv_path):
@@ -178,85 +188,66 @@ def parse_font_style(font_style: str) -> str:
178
  return None
179
 
180
 
181
- def aot_compile_transformer(gen):
182
- """
183
- AOT compile the transformer using spaces.aoti_capture/compile/apply.
184
- Exactly following FLUX-Kontext-fp8 pattern.
185
- """
186
- model = gen.model
187
-
188
- @spaces.GPU(duration=300) # 5 minutes for AOT compilation
189
- def compile_transformer():
190
- print("="*50)
191
- print("Starting AOT compilation (FLUX-Kontext-fp8 pattern)...")
192
- print("="*50)
193
-
194
- # Step 1: Capture model forward during a real inference
195
- print("Step 1: Capturing model forward pass with spaces.aoti_capture...")
196
- with spaces.aoti_capture(model) as call:
197
- # Run a sample generation to capture the forward call
198
- gen.generate(
199
- text="测试",
200
- font_style="楷",
201
- author=None,
202
- num_steps=1,
203
- seed=42,
204
- )
205
- print("✓ Forward pass captured!")
206
-
207
- # Step 2: Build dynamic shapes (None = fixed shapes)
208
- print("Step 2: Building dynamic shapes...")
209
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
210
- print("✓ Dynamic shapes built!")
211
-
212
- # Step 3: Apply Float8 quantization
213
- print("Step 3: Applying Float8 quantization...")
214
- quantize_(model, Float8DynamicActivationFloat8WeightConfig())
215
- print("✓ Float8 quantization complete!")
216
-
217
- # Step 4: Export model with torch.export
218
- print("Step 4: Exporting model with torch.export...")
219
- exported = torch.export.export(
220
- mod=model,
221
- args=call.args,
222
- kwargs=call.kwargs,
223
- dynamic_shapes=dynamic_shapes,
224
- )
225
- print("✓ Model exported!")
226
-
227
- # Step 5: AOT compile with spaces.aoti_compile
228
- print("Step 5: AOT compiling with spaces.aoti_compile...")
229
- print(f" Inductor configs: {INDUCTOR_CONFIGS}")
230
- compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
231
- print("✓ AOT compilation complete!")
232
-
233
- return compiled
234
-
235
- # Run compilation and apply the result
236
- print("Running AOT compilation...")
237
- spaces.aoti_apply(compile_transformer(), model)
238
- print("="*50)
239
- print("✓ AOT compiled model applied!")
240
- print("="*50)
241
-
242
-
243
  @spaces.GPU(duration=300) # 5 minutes for first-time AOT compilation
244
  def compile_model_first_time():
245
  """
246
  First-time: Load model and run AOT compilation.
247
- Returns the optimized generator.
248
  """
249
  global _is_optimized, generator
250
 
251
  print("="*50)
252
- print("First-time run: Loading model and running AOT compilation...")
253
  print("="*50)
254
 
255
  # Load model
256
  gen = init_generator()
 
257
 
258
- # AOT compile the transformer
259
- aot_compile_transformer(gen)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  _is_optimized = True
262
  print("="*50)
 
4
  With Float8 quantization and AOT compilation for faster inference
5
  """
6
 
7
+ # Install compatible torchao version for the current torch (following FLUX-Kontext-fp8 pattern)
8
+ import os
9
+ import subprocess
10
+ print("Installing compatible torchao version...")
11
+ subprocess.run([
12
+ "pip", "install", "--upgrade", "--pre",
13
+ "--extra-index-url", "https://download.pytorch.org/whl/nightly/cu126",
14
+ "torchao"
15
+ ], capture_output=True)
16
+ print("torchao installation complete!")
17
+
18
  # IMPORTANT: import spaces first before any CUDA-related packages
19
  import spaces
20
 
 
23
  import csv
24
  import time
25
  import torch
 
26
 
27
  # Load author and font mappings from CSV
28
  def load_author_fonts_from_csv(csv_path):
 
188
  return None
189
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  @spaces.GPU(duration=300) # 5 minutes for first-time AOT compilation
192
  def compile_model_first_time():
193
  """
194
  First-time: Load model and run AOT compilation.
195
+ Exactly following FLUX-Kontext-fp8 pattern.
196
  """
197
  global _is_optimized, generator
198
 
199
  print("="*50)
200
+ print("First-time run: Loading model and AOT compiling...")
201
  print("="*50)
202
 
203
  # Load model
204
  gen = init_generator()
205
+ model = gen.model
206
 
207
+ # ========== AOT Compilation (FLUX-Kontext-fp8 pattern) ==========
208
+
209
+ # Step 1: Capture model forward during a real inference
210
+ print("Step 1: Capturing model forward pass with spaces.aoti_capture...")
211
+ with spaces.aoti_capture(model) as call:
212
+ gen.generate(
213
+ text="测试",
214
+ font_style="楷",
215
+ author=None,
216
+ num_steps=1,
217
+ seed=42,
218
+ )
219
+ print("✓ Forward pass captured!")
220
+
221
+ # Step 2: Build dynamic shapes (None = fixed shapes)
222
+ print("Step 2: Building dynamic shapes...")
223
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
224
+ print("✓ Dynamic shapes built!")
225
+
226
+ # Step 3: Apply Float8 quantization
227
+ print("Step 3: Applying Float8 quantization...")
228
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig())
229
+ print("✓ Float8 quantization complete!")
230
+
231
+ # Step 4: Export model with torch.export
232
+ print("Step 4: Exporting model with torch.export...")
233
+ exported = torch.export.export(
234
+ mod=model,
235
+ args=call.args,
236
+ kwargs=call.kwargs,
237
+ dynamic_shapes=dynamic_shapes,
238
+ )
239
+ print("✓ Model exported!")
240
+
241
+ # Step 5: AOT compile with spaces.aoti_compile
242
+ print("Step 5: AOT compiling with spaces.aoti_compile...")
243
+ print(f" Inductor configs: {INDUCTOR_CONFIGS}")
244
+ compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
245
+ print("✓ AOT compilation complete!")
246
+
247
+ # Step 6: Apply compiled model
248
+ print("Step 6: Applying compiled model...")
249
+ spaces.aoti_apply(compiled, model)
250
+ print("✓ AOT compiled model applied!")
251
 
252
  _is_optimized = True
253
  print("="*50)