TSXu commited on
Commit
8c4267f
·
1 Parent(s): b21384d

Add compiled graph caching to HF Hub

Browse files

- Check Hub for cached compiled_graph.pt2 on startup
- If exists, load using ZeroGPUCompiledModel (fast ~30s)
- If not, compile and upload to Hub for future use
- Based on zerogpu-aoti/Flux-Compiled-Graph pattern

Files changed (1) hide show
  1. app.py +116 -34
app.py CHANGED
@@ -152,6 +152,79 @@ INDUCTOR_CONFIGS = {
152
  'triton.cudagraphs': True,
153
  }
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  def init_generator():
157
  """Initialize the generator (without optimization - that's done separately)"""
@@ -216,16 +289,16 @@ def parse_font_style(font_style: str) -> str:
216
  return None
217
 
218
 
219
- @spaces.GPU(duration=900) # don't change this duration!!!
220
  def compile_model_first_time():
221
  """
222
- First-time: Load model and run AOT compilation.
223
- With Float8 quantization for faster inference.
224
  """
225
  global _is_optimized, generator
226
 
227
  logger.info("="*50)
228
- logger.info("First-time run: Loading model and AOT compiling...")
229
  logger.info("="*50)
230
 
231
  try:
@@ -233,10 +306,25 @@ def compile_model_first_time():
233
  gen = init_generator()
234
  model = gen.model
235
 
236
- # ========== AOT Compilation (fp32 only - testing export without quantization) ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  # Step 1: Capture model forward during a real inference
239
- logger.info("Step 1: Capturing model forward pass with spaces.aoti_capture...")
240
  with spaces.aoti_capture(model) as call:
241
  gen.generate(
242
  text="测试长度等于七",
@@ -256,33 +344,22 @@ def compile_model_first_time():
256
  else:
257
  logger.info(f" {k}: {type(v).__name__} = {v}")
258
 
259
- # Step 2: Build dynamic_shapes (FLUX-Kontext-fp8 pattern: all static)
260
- # tree_map_only maps all tensors/bools to None = static shape
261
  # Step 2: Build dynamic_shapes (all static)
262
- # For non-tensor types (like float ip_scale), we must use None
263
  logger.info("Step 2: Building static shapes...")
264
  dynamic_shapes = {}
265
  for k, v in call.kwargs.items():
266
- if isinstance(v, torch.Tensor):
267
- dynamic_shapes[k] = None # Static shape for tensors
268
- else:
269
- dynamic_shapes[k] = None # Must be None for non-tensor types
270
  logger.info(f" dynamic_shapes: {dynamic_shapes}")
271
  logger.info("✓ Static shapes configured!")
272
 
273
- # Step 3: Disable gradients on model (required for AOT export)
274
  logger.info("Step 3: Disabling gradients on model...")
275
  model.eval()
276
  model.requires_grad_(False)
277
  logger.info("✓ Model in eval mode with gradients disabled!")
278
 
279
- # Step 4: Float8 quantization DISABLED (causes issues on some GPU types)
280
- # logger.info("Step 4: Applying Float8 quantization...")
281
- # quantize_(model, Float8DynamicActivationFloat8WeightConfig())
282
- # logger.info("✓ Float8 quantization complete!")
283
-
284
- # Step 4 (was Step 5): Detach inputs (requires_grad causes issues with AOT export)
285
- logger.info("Step 5: Detaching inputs to avoid gradient issues...")
286
  detached_args = tuple(
287
  a.detach() if isinstance(a, torch.Tensor) else a for a in call.args
288
  )
@@ -292,8 +369,8 @@ def compile_model_first_time():
292
  }
293
  logger.info("✓ Inputs detached!")
294
 
295
- # Step 6: Export model with torch.export.export (not draft_export)
296
- logger.info("Step 6: Exporting model with torch.export.export...")
297
  exported = torch.export.export(
298
  mod=model,
299
  args=detached_args,
@@ -302,12 +379,16 @@ def compile_model_first_time():
302
  )
303
  logger.info("✓ Model exported!")
304
 
305
- # Step 7: AOT compile with spaces.aoti_compile
306
- logger.info("Step 7: AOT compiling with spaces.aoti_compile...")
307
  logger.info(f" Inductor configs: {INDUCTOR_CONFIGS}")
308
  compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
309
  logger.info("✓ AOT compilation complete!")
310
 
 
 
 
 
311
  # Step 8: Apply compiled model
312
  logger.info("Step 8: Applying compiled model...")
313
  spaces.aoti_apply(compiled, model)
@@ -315,6 +396,9 @@ def compile_model_first_time():
315
 
316
  _is_optimized = True
317
  logger.info("="*50)
 
 
 
318
  except Exception as e:
319
  logger.error("="*50)
320
  logger.error("AOT COMPILATION FAILED!")
@@ -322,17 +406,12 @@ def compile_model_first_time():
322
  logger.error(f"Exception: {e}")
323
  logger.error("Full traceback:")
324
  logger.error(traceback.format_exc())
325
- # Save full error to file
326
  with open("aot_error.log", "w") as f:
327
  f.write(f"Exception: {e}\n\n")
328
  f.write(traceback.format_exc())
329
  raise
330
 
331
- logger.info("✓ Model loaded and AOT compiled!")
332
- logger.info("="*50)
333
-
334
  # NOTE: Don't return gen - causes pickle error in ZeroGPU multiprocessing
335
- # Generator is stored in global variable and accessed via init_generator()
336
  return None
337
 
338
 
@@ -387,12 +466,15 @@ def interactive_session(
387
  # Determine author
388
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
389
 
390
- # Step 1: AOT compile if not done yet (5 min, one-time)
391
  if not _is_optimized:
392
- yield "⏳ 首次运行,AOT编译优化模型(约3-5分钟,仅此一次)...", []
393
- progress(0.1, desc="AOT编译...")
 
 
 
394
  compile_model_first_time()
395
- yield "✅ AOT编译完成!", []
396
 
397
  # Step 2: Run generation (2 min)
398
  yield f"🎨 开始生成 {num_images} 张图片...", []
 
152
  'triton.cudagraphs': True,
153
  }
154
 
155
+ # ============================================================
156
+ # AOT Compiled Graph Caching (save to / load from HF Hub)
157
+ # ============================================================
158
+ HF_CACHE_REPO = "TSXu/Unicalli_Pro"
159
+ HF_CACHE_FILENAME = "compiled_graph.pt2"
160
+
161
+
162
+ def _check_compiled_graph_exists():
163
+ """Check if compiled graph exists on HF Hub"""
164
+ from huggingface_hub import hf_hub_download, HfApi
165
+ try:
166
+ api = HfApi()
167
+ files = api.list_repo_files(HF_CACHE_REPO)
168
+ return HF_CACHE_FILENAME in files
169
+ except Exception as e:
170
+ logger.info(f"Could not check Hub for compiled graph: {e}")
171
+ return False
172
+
173
+
174
+ def _load_compiled_graph(model):
175
+ """Load compiled graph from HF Hub using ZeroGPU internals"""
176
+ from huggingface_hub import hf_hub_download
177
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, drain_module_parameters
178
+
179
+ logger.info(f"Downloading compiled graph from {HF_CACHE_REPO}/{HF_CACHE_FILENAME}...")
180
+ compiled_graph_file = hf_hub_download(HF_CACHE_REPO, HF_CACHE_FILENAME)
181
+ logger.info(f"✓ Downloaded to: {compiled_graph_file}")
182
+
183
+ logger.info("Loading compiled graph into model...")
184
+ state_dict = model.state_dict()
185
+ zerogpu_weights = ZeroGPUWeights({name: weight for name, weight in state_dict.items()})
186
+ compiled = ZeroGPUCompiledModel(compiled_graph_file, zerogpu_weights)
187
+
188
+ # Replace forward method
189
+ setattr(model, "forward", compiled)
190
+ drain_module_parameters(model)
191
+ logger.info("✓ Compiled graph loaded and applied!")
192
+ return True
193
+
194
+
195
+ def _upload_compiled_graph(compiled):
196
+ """Upload compiled graph to HF Hub"""
197
+ from huggingface_hub import upload_file
198
+ import tempfile
199
+
200
+ hf_token = os.environ.get("HF_TOKEN")
201
+ if not hf_token:
202
+ logger.warning("HF_TOKEN not set, cannot upload compiled graph")
203
+ return False
204
+
205
+ logger.info(f"Uploading compiled graph to {HF_CACHE_REPO}/{HF_CACHE_FILENAME}...")
206
+
207
+ # Save archive to temp file
208
+ with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
209
+ f.write(compiled.archive_file.getvalue())
210
+ temp_path = f.name
211
+
212
+ try:
213
+ upload_file(
214
+ path_or_fileobj=temp_path,
215
+ path_in_repo=HF_CACHE_FILENAME,
216
+ repo_id=HF_CACHE_REPO,
217
+ token=hf_token,
218
+ commit_message="Upload AOT compiled graph",
219
+ )
220
+ logger.info("✓ Compiled graph uploaded to Hub!")
221
+ return True
222
+ except Exception as e:
223
+ logger.error(f"Failed to upload compiled graph: {e}")
224
+ return False
225
+ finally:
226
+ os.unlink(temp_path)
227
+
228
 
229
  def init_generator():
230
  """Initialize the generator (without optimization - that's done separately)"""
 
289
  return None
290
 
291
 
292
+ @spaces.GPU(duration=900) # 15 min for compilation (if needed)
293
  def compile_model_first_time():
294
  """
295
+ First-time: Load model and either load cached compiled graph or compile from scratch.
296
+ Compiled graph is cached on HF Hub for fast subsequent cold starts.
297
  """
298
  global _is_optimized, generator
299
 
300
  logger.info("="*50)
301
+ logger.info("First-time run: Loading model...")
302
  logger.info("="*50)
303
 
304
  try:
 
306
  gen = init_generator()
307
  model = gen.model
308
 
309
+ # Check if compiled graph exists on Hub
310
+ logger.info("Checking for cached compiled graph on HF Hub...")
311
+ if _check_compiled_graph_exists():
312
+ logger.info("="*50)
313
+ logger.info("Found cached compiled graph! Loading from Hub...")
314
+ logger.info("="*50)
315
+ _load_compiled_graph(model)
316
+ _is_optimized = True
317
+ logger.info("✓ Model loaded with cached compiled graph!")
318
+ logger.info("="*50)
319
+ return None
320
+
321
+ # No cached graph, compile from scratch
322
+ logger.info("="*50)
323
+ logger.info("No cached graph found. Compiling from scratch...")
324
+ logger.info("="*50)
325
 
326
  # Step 1: Capture model forward during a real inference
327
+ logger.info("Step 1: Capturing model forward pass...")
328
  with spaces.aoti_capture(model) as call:
329
  gen.generate(
330
  text="测试长度等于七",
 
344
  else:
345
  logger.info(f" {k}: {type(v).__name__} = {v}")
346
 
 
 
347
  # Step 2: Build dynamic_shapes (all static)
 
348
  logger.info("Step 2: Building static shapes...")
349
  dynamic_shapes = {}
350
  for k, v in call.kwargs.items():
351
+ dynamic_shapes[k] = None # Static shape for all
 
 
 
352
  logger.info(f" dynamic_shapes: {dynamic_shapes}")
353
  logger.info("✓ Static shapes configured!")
354
 
355
+ # Step 3: Disable gradients on model
356
  logger.info("Step 3: Disabling gradients on model...")
357
  model.eval()
358
  model.requires_grad_(False)
359
  logger.info("✓ Model in eval mode with gradients disabled!")
360
 
361
+ # Step 4: Detach inputs
362
+ logger.info("Step 4: Detaching inputs...")
 
 
 
 
 
363
  detached_args = tuple(
364
  a.detach() if isinstance(a, torch.Tensor) else a for a in call.args
365
  )
 
369
  }
370
  logger.info("✓ Inputs detached!")
371
 
372
+ # Step 5: Export model
373
+ logger.info("Step 5: Exporting model with torch.export.export...")
374
  exported = torch.export.export(
375
  mod=model,
376
  args=detached_args,
 
379
  )
380
  logger.info("✓ Model exported!")
381
 
382
+ # Step 6: AOT compile
383
+ logger.info("Step 6: AOT compiling with spaces.aoti_compile...")
384
  logger.info(f" Inductor configs: {INDUCTOR_CONFIGS}")
385
  compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
386
  logger.info("✓ AOT compilation complete!")
387
 
388
+ # Step 7: Upload compiled graph to Hub
389
+ logger.info("Step 7: Uploading compiled graph to Hub...")
390
+ _upload_compiled_graph(compiled)
391
+
392
  # Step 8: Apply compiled model
393
  logger.info("Step 8: Applying compiled model...")
394
  spaces.aoti_apply(compiled, model)
 
396
 
397
  _is_optimized = True
398
  logger.info("="*50)
399
+ logger.info("✓ Model compiled and cached to Hub!")
400
+ logger.info("="*50)
401
+
402
  except Exception as e:
403
  logger.error("="*50)
404
  logger.error("AOT COMPILATION FAILED!")
 
406
  logger.error(f"Exception: {e}")
407
  logger.error("Full traceback:")
408
  logger.error(traceback.format_exc())
 
409
  with open("aot_error.log", "w") as f:
410
  f.write(f"Exception: {e}\n\n")
411
  f.write(traceback.format_exc())
412
  raise
413
 
 
 
 
414
  # NOTE: Don't return gen - causes pickle error in ZeroGPU multiprocessing
 
415
  return None
416
 
417
 
 
466
  # Determine author
467
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
468
 
469
+ # Step 1: Load compiled graph (cached) or compile (first time)
470
  if not _is_optimized:
471
+ if _check_compiled_graph_exists():
472
+ yield "⏳ 加载已缓存的编译模型...", []
473
+ else:
474
+ yield "⏳ 首次运行,编译优化模型(约5-10分钟,仅此一次)...", []
475
+ progress(0.1, desc="加载/编译中...")
476
  compile_model_first_time()
477
+ yield "✅ 模型加载完成!", []
478
 
479
  # Step 2: Run generation (2 min)
480
  yield f"🎨 开始生成 {num_images} 张图片...", []