TSXu committed on
Commit
6e8caef
·
1 Parent(s): a8c4850

Split model loading and generation for better progress visibility

Browse files

- Added load_and_optimize_model() for step 1 (model loading + FP8)
- Separated run_generation() for step 2 (actual inference)
- Added warning about first-run compilation and retry hint
- Updated progress messages to show current step

Files changed (1) hide show
  1. app.py +40 -11
app.py CHANGED
@@ -213,23 +213,37 @@ def parse_font_style(font_style: str) -> str:
213
  return None
214
 
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
217
- """Calculate dynamic GPU duration: 20s loading + 1.5s per step per image"""
218
- return 40 + int(2 * num_steps * num_images)
219
 
220
 
221
  @spaces.GPU(duration=_get_generation_duration)
222
  def run_generation(text, font, author, num_steps, start_seed, num_images):
223
  """
224
- Run generation with FA3 + FP8 quantization.
225
- No AOT cache - quantization applied fresh each session.
226
  """
227
  gen = init_generator()
228
 
229
- # Apply FP8 quantization (works with FA3)
230
- logger.info("Applying FP8 quantization to model...")
231
  quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
232
- logger.info("✓ FP8 quantization applied! Running with FA3 + FP8")
233
 
234
  results = []
235
  seeds_used = []
@@ -255,7 +269,8 @@ def interactive_session(
255
  progress=gr.Progress()
256
  ):
257
  """
258
- Interactive session with FA3 + bf16.
 
259
  """
260
  # Validate text
261
  if len(text) < 1:
@@ -271,8 +286,14 @@ def interactive_session(
271
  # Determine author
272
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
273
 
274
- # Run generation with FA3 + FP8
275
- yield f"🎨 生成 {num_images} 张图片 (FA3 + FP8)...", []
 
 
 
 
 
 
276
  progress(0.3, desc="生成中...")
277
 
278
  results, seeds_used = run_generation(
@@ -368,7 +389,15 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
368
  with gr.Column(scale=1):
369
  # Output section
370
  gr.Markdown("### 🖼️ 生成结果 / Generated Results")
371
- gr.Markdown("*点击图片可放大查看 / Click image to enlarge*")
 
 
 
 
 
 
 
 
372
 
373
  output_gallery = gr.Gallery(
374
  label="生成结果 / Generated Results",
 
213
  return None
214
 
215
 
216
+ @spaces.GPU(duration=60)
217
+ def load_and_optimize_model():
218
+ """
219
+ Step 1: Load model and apply FP8 quantization.
220
+ This triggers torch compilation on first run (may take 1-2 minutes).
221
+ """
222
+ gen = init_generator()
223
+
224
+ # Apply FP8 quantization (works with FA3)
225
+ logger.info("Applying FP8 quantization to model...")
226
+ quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
227
+ logger.info("✓ FP8 quantization applied! Running with FA3 + FP8")
228
+
229
+ return "ready"
230
+
231
+
232
  def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
233
+ """Calculate dynamic GPU duration: 20s base + 2s per step per image"""
234
+ return 20 + int(2 * num_steps * num_images)
235
 
236
 
237
  @spaces.GPU(duration=_get_generation_duration)
238
  def run_generation(text, font, author, num_steps, start_seed, num_images):
239
  """
240
+ Step 2: Run actual generation.
241
+ Model should already be loaded from step 1.
242
  """
243
  gen = init_generator()
244
 
245
+ # Re-apply FP8 quantization (ZeroGPU sessions are isolated)
 
246
  quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
 
247
 
248
  results = []
249
  seeds_used = []
 
269
  progress=gr.Progress()
270
  ):
271
  """
272
+ Interactive session with FA3 + FP8.
273
+ Split into load + generate steps for better progress visibility.
274
  """
275
  # Validate text
276
  if len(text) < 1:
 
286
  # Determine author
287
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
288
 
289
+ # Step 1: Load and optimize model
290
+ yield " 加载模型中... (首次可能需要1-2分钟编译) / Loading model... (first run may take 1-2 min to compile)", []
291
+ progress(0.1, desc="加载模型...")
292
+
293
+ load_and_optimize_model()
294
+
295
+ # Step 2: Run generation
296
+ yield f"🎨 生成 {num_images} 张图片中... / Generating {num_images} images...", []
297
  progress(0.3, desc="生成中...")
298
 
299
  results, seeds_used = run_generation(
 
389
  with gr.Column(scale=1):
390
  # Output section
391
  gr.Markdown("### 🖼️ 生成结果 / Generated Results")
392
+ gr.Markdown("""
393
+ ⚠️ **首次生成说明 / First Run Note:**
394
+ - 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
395
+ - 如果遇到错误,请**再点一次生成按钮**即可正常运行
396
+ - First generation triggers PyTorch compilation (~1-2 min)
397
+ - If you see an error, just **click generate again** and it will work
398
+
399
+ *点击图片可放大查看 / Click image to enlarge*
400
+ """)
401
 
402
  output_gallery = gr.Gallery(
403
  label="生成结果 / Generated Results",