TSXu committed on
Commit
3974489
·
1 Parent(s): 624511f

Remove AoT compilation from main (saved to feature/aot-compilation branch)

Browse files
Files changed (2) hide show
  1. app.py +0 -201
  2. requirements.txt +0 -3
app.py CHANGED
@@ -151,172 +151,6 @@ def init_generator():
151
  return generator
152
 
153
 
154
- # ============== AoT Compilation ==============
155
- COMPILED_MODEL_REPO = "TSXu/Unicalli_Pro" # Where to save/load compiled model
156
- COMPILED_MODEL_FILENAME = "compiled_flux_h200.pt2" # Compiled model filename (fp32)
157
- COMPILED_MODEL_FP8_FILENAME = "compiled_flux_h200_fp8.pt2" # FP8 quantized version
158
-
159
- def check_compiled_model_exists():
160
- """Check if compiled model exists on HuggingFace Hub"""
161
- import os
162
- from huggingface_hub import hf_hub_url, get_hf_file_metadata
163
-
164
- try:
165
- hf_token = os.environ.get("HF_TOKEN", None)
166
- url = hf_hub_url(COMPILED_MODEL_REPO, COMPILED_MODEL_FILENAME)
167
- metadata = get_hf_file_metadata(url, token=hf_token)
168
- print(f"✓ Found compiled model on Hub: {COMPILED_MODEL_FILENAME} ({metadata.size / 1e9:.2f} GB)")
169
- return True
170
- except Exception as e:
171
- print(f"Compiled model not found on Hub: {e}")
172
- return False
173
-
174
-
175
- @spaces.GPU(duration=300) # 5 minutes for compilation
176
- def compile_and_upload_model(use_fp8: bool = False):
177
- """
178
- Compile model with AoT and upload to HuggingFace Hub.
179
- This only needs to be done once!
180
-
181
- Args:
182
- use_fp8: Whether to use FP8 quantization (faster, less memory, H200 only)
183
- """
184
- import torch
185
- import os
186
- from huggingface_hub import HfApi
187
-
188
- global generator
189
-
190
- # Initialize generator if not already done
191
- if generator is None:
192
- generator = init_generator()
193
-
194
- model = generator.model
195
-
196
- filename = COMPILED_MODEL_FP8_FILENAME if use_fp8 else COMPILED_MODEL_FILENAME
197
-
198
- print("="*50)
199
- print(f"Starting AoT Compilation {'with FP8' if use_fp8 else '(FP32)'}...")
200
- print("This may take 5-10 minutes...")
201
- print("="*50)
202
-
203
- # Configure Inductor for optimal performance
204
- torch._inductor.config.conv_1x1_as_mm = True
205
- torch._inductor.config.coordinate_descent_tuning = True
206
-
207
- try:
208
- # Step 0: Apply FP8 quantization if requested
209
- if use_fp8:
210
- print("Step 0: Applying FP8 quantization...")
211
- try:
212
- from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
213
- quantize_(model, Float8DynamicActivationFloat8WeightConfig())
214
- print(" ✓ FP8 quantization applied!")
215
- except Exception as e:
216
- print(f" ⚠️ FP8 quantization failed: {e}")
217
- print(" Falling back to FP32...")
218
- use_fp8 = False
219
- filename = COMPILED_MODEL_FILENAME
220
-
221
- # Step 1: Capture example inputs
222
- print("Step 1/4: Capturing example inputs...")
223
- with spaces.aoti_capture(model) as call:
224
- # Run minimal inference to capture inputs
225
- generator.generate(
226
- text="新年快乐发大财",
227
- font_style="楷",
228
- author=None,
229
- num_steps=1,
230
- seed=42,
231
- )
232
- print(f" Captured {len(call.args)} args, {len(call.kwargs)} kwargs")
233
-
234
- # Step 2: Export model
235
- print("Step 2/4: Exporting model graph...")
236
- try:
237
- exported = torch.export.export(
238
- model,
239
- args=call.args,
240
- kwargs=call.kwargs,
241
- strict=False, # Allow non-strict tracing
242
- )
243
- print(" Export complete!")
244
- except Exception as export_error:
245
- print(f" ❌ Export failed: {export_error}")
246
- if use_fp8:
247
- print(" FP8 may be incompatible with torch.export on this model.")
248
- print(" Try unchecking FP8 and compiling with FP32 instead.")
249
- return f"❌ Export failed: {export_error}\n\nTry disabling FP8 quantization."
250
-
251
- # Step 3: Compile with AOTInductor
252
- print("Step 3/4: Compiling with AOTInductor...")
253
- compiled = spaces.aoti_compile(exported)
254
- print(" Compilation complete!")
255
-
256
- # Step 4: Upload to Hub
257
- print("Step 4/4: Uploading to HuggingFace Hub...")
258
- local_path = f"/tmp/{filename}"
259
- torch.save(compiled, local_path)
260
-
261
- hf_token = os.environ.get("HF_TOKEN", None)
262
- api = HfApi()
263
- api.upload_file(
264
- path_or_fileobj=local_path,
265
- path_in_repo=filename,
266
- repo_id=COMPILED_MODEL_REPO,
267
- repo_type="model",
268
- token=hf_token,
269
- )
270
- print(f"✓ Uploaded compiled model to {COMPILED_MODEL_REPO}/{filename}")
271
-
272
- # Apply compiled model
273
- spaces.aoti_apply(compiled, model)
274
- print("✓ Applied compiled model!")
275
-
276
- mode = "FP8" if use_fp8 else "FP32"
277
- return f"✅ Compilation ({mode}) and upload successful!"
278
-
279
- except Exception as e:
280
- import traceback
281
- traceback.print_exc()
282
- return f"❌ Compilation failed: {e}"
283
-
284
-
285
- @spaces.GPU(duration=60) # 1 minute for loading
286
- def load_and_apply_compiled_model(use_fp8: bool = False):
287
- """Load compiled model from Hub and apply to generator"""
288
- import torch
289
- import os
290
- from huggingface_hub import hf_hub_download
291
-
292
- global generator
293
-
294
- if generator is None:
295
- generator = init_generator()
296
-
297
- filename = COMPILED_MODEL_FP8_FILENAME if use_fp8 else COMPILED_MODEL_FILENAME
298
-
299
- try:
300
- hf_token = os.environ.get("HF_TOKEN", None)
301
- print(f"Downloading {'FP8' if use_fp8 else 'FP32'} compiled model...")
302
-
303
- local_path = hf_hub_download(
304
- repo_id=COMPILED_MODEL_REPO,
305
- filename=filename,
306
- token=hf_token
307
- )
308
-
309
- compiled = torch.load(local_path)
310
- spaces.aoti_apply(compiled, generator.model)
311
-
312
- mode = "FP8" if use_fp8 else "FP32"
313
- print(f"✓ Applied pre-compiled model ({mode})!")
314
- return f"✅ Loaded and applied {mode} compiled model!"
315
- except Exception as e:
316
- print(f"Failed to load compiled model: {e}")
317
- return f"❌ Failed to load compiled model: {e}"
318
-
319
-
320
  def update_font_choices(author: str):
321
  """
322
  Update available font choices based on selected author
@@ -514,41 +348,6 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
514
  interactive=False
515
  )
516
 
517
- # Admin section for AoT compilation
518
- with gr.Accordion("🔧 管理员工具 / Admin Tools (AoT Compilation)", open=False):
519
- gr.Markdown("""
520
- **AoT 编译** 可以将模型预编译以加速推理 (~1.5x)。
521
-
522
- - **首次使用**: 点击"编译并上传",等待 5-10 分钟
523
- - **后续使用**: 点击"加载已编译模型"
524
-
525
- **FP8 量化** (仅 H200): 更快的推理 + 更少显存,推荐在 ZeroGPU 上使用!
526
- """)
527
-
528
- fp8_checkbox = gr.Checkbox(
529
- label="使用 FP8 量化 / Use FP8 Quantization (推荐/Recommended for H200)",
530
- value=True,
531
- info="FP8 提供更快推理和更低显存,仅在 H200 上支持"
532
- )
533
-
534
- with gr.Row():
535
- compile_btn = gr.Button("🔨 编译并上传 / Compile & Upload", variant="secondary")
536
- load_compiled_btn = gr.Button("📥 加载已编译模型 / Load Compiled", variant="secondary")
537
-
538
- compile_status = gr.Textbox(label="编译状态 / Compilation Status", value="", interactive=False)
539
-
540
- compile_btn.click(
541
- fn=compile_and_upload_model,
542
- inputs=[fp8_checkbox],
543
- outputs=[compile_status]
544
- )
545
-
546
- load_compiled_btn.click(
547
- fn=load_and_apply_compiled_model,
548
- inputs=[fp8_checkbox],
549
- outputs=[compile_status]
550
- )
551
-
552
  # Author info section
553
  with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
554
  author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
 
151
  return generator
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def update_font_choices(author: str):
155
  """
156
  Update available font choices based on selected author
 
348
  interactive=False
349
  )
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  # Author info section
352
  with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
353
  author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
requirements.txt CHANGED
@@ -14,9 +14,6 @@ timm
14
  sentencepiece
15
  diffusers
16
 
17
- # AoT Compilation & Optimization
18
- torchao # FP8 quantization for H200
19
- kernels # Pre-built kernels including FlashAttention-3
20
 
21
  # Data processing
22
  datasets
 
14
  sentencepiece
15
  diffusers
16
 
 
 
 
17
 
18
  # Data processing
19
  datasets