TSXu commited on
Commit
aa36e12
·
1 Parent(s): 8c4267f

Load compiled graph in each GPU session

Browse files

ZeroGPU runs each @spaces.GPU call in a separate session, so the
compiled graph needs to be loaded at the start of each inference.

- run_generation now loads compiled graph from Hub at start
- Compilation only happens once (when no cached graph exists)
- Each inference downloads and applies the cached .pt2 file

Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -415,12 +415,23 @@ def compile_model_first_time():
415
  return None
416
 
417
 
418
- @spaces.GPU(duration=120) # 2 minutes for normal generation
419
  def run_generation(text, font, author, num_steps, start_seed, num_images):
420
  """
421
  Run generation with the AOT-compiled model.
 
422
  """
423
- gen = init_generator() # Returns the already-compiled generator
 
 
 
 
 
 
 
 
 
 
424
 
425
  results = []
426
  seeds_used = []
@@ -466,19 +477,18 @@ def interactive_session(
466
  # Determine author
467
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
468
 
469
- # Step 1: Load compiled graph (cached) or compile (first time)
470
  if not _is_optimized:
471
- if _check_compiled_graph_exists():
472
- yield "⏳ 加载已缓存的编译模型...", []
473
- else:
474
  yield "⏳ 首次运行,编译优化模型(约5-10分钟,仅此一次)...", []
475
- progress(0.1, desc="加载/编译中...")
476
- compile_model_first_time()
477
- yield "✅ 模型加载完成!", []
478
-
479
- # Step 2: Run generation (2 min)
480
- yield f"🎨 开始生成 {num_images} 张图片...", []
481
- progress(0.5, desc="生成...")
 
482
 
483
  results, seeds_used = run_generation(
484
  text, font, author, num_steps, start_seed, num_images
 
415
  return None
416
 
417
 
418
+ @spaces.GPU(duration=180) # 3 minutes for generation (includes loading compiled graph)
419
  def run_generation(text, font, author, num_steps, start_seed, num_images):
420
  """
421
  Run generation with the AOT-compiled model.
422
+ Each GPU session loads the cached compiled graph from Hub.
423
  """
424
+ gen = init_generator()
425
+ model = gen.model
426
+
427
+ # Load compiled graph from Hub (fast, ~30s download + load)
428
+ # This is needed because each @spaces.GPU call is a new session
429
+ if _check_compiled_graph_exists():
430
+ logger.info("Loading cached compiled graph for this session...")
431
+ _load_compiled_graph(model)
432
+ logger.info("✓ Compiled graph loaded!")
433
+ else:
434
+ logger.warning("No compiled graph found on Hub - running unoptimized")
435
 
436
  results = []
437
  seeds_used = []
 
477
  # Determine author
478
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
479
 
480
+ # Step 1: First-time compile (only if no cached graph exists)
481
  if not _is_optimized:
482
+ if not _check_compiled_graph_exists():
 
 
483
  yield "⏳ 首次运行,编译优化模型(约5-10分钟,仅此一次)...", []
484
+ progress(0.1, desc="编译中...")
485
+ compile_model_first_time()
486
+ yield "✅ 编译完成并已上传缓存!", []
487
+ _is_optimized = True # Mark as done (cached graph exists)
488
+
489
+ # Step 2: Run generation (includes loading compiled graph + inference)
490
+ yield f"🎨 加载编译模型并生成 {num_images} 张图片...", []
491
+ progress(0.3, desc="加载编译模型...")
492
 
493
  results, seeds_used = run_generation(
494
  text, font, author, num_steps, start_seed, num_images