Spaces:
Running on Zero
Split model loading and generation for better progress visibility
Browse files- Added load_and_optimize_model() for step 1 (model loading + FP8)
- Separated run_generation() for step 2 (actual inference)
- Added warning about first-run compilation and retry hint
- Updated progress messages to show current step
app.py
CHANGED
|
@@ -213,23 +213,37 @@ def parse_font_style(font_style: str) -> str:
|
|
| 213 |
return None
|
| 214 |
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
|
| 217 |
-
"""Calculate dynamic GPU duration:
|
| 218 |
-
return
|
| 219 |
|
| 220 |
|
| 221 |
@spaces.GPU(duration=_get_generation_duration)
|
| 222 |
def run_generation(text, font, author, num_steps, start_seed, num_images):
|
| 223 |
"""
|
| 224 |
-
|
| 225 |
-
|
| 226 |
"""
|
| 227 |
gen = init_generator()
|
| 228 |
|
| 229 |
-
#
|
| 230 |
-
logger.info("Applying FP8 quantization to model...")
|
| 231 |
quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
|
| 232 |
-
logger.info("✓ FP8 quantization applied! Running with FA3 + FP8")
|
| 233 |
|
| 234 |
results = []
|
| 235 |
seeds_used = []
|
|
@@ -255,7 +269,8 @@ def interactive_session(
|
|
| 255 |
progress=gr.Progress()
|
| 256 |
):
|
| 257 |
"""
|
| 258 |
-
Interactive session with FA3 +
|
|
|
|
| 259 |
"""
|
| 260 |
# Validate text
|
| 261 |
if len(text) < 1:
|
|
@@ -271,8 +286,14 @@ def interactive_session(
|
|
| 271 |
# Determine author
|
| 272 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 273 |
|
| 274 |
-
#
|
| 275 |
-
yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
progress(0.3, desc="生成中...")
|
| 277 |
|
| 278 |
results, seeds_used = run_generation(
|
|
@@ -368,7 +389,15 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
|
|
| 368 |
with gr.Column(scale=1):
|
| 369 |
# Output section
|
| 370 |
gr.Markdown("### 🖼️ 生成结果 / Generated Results")
|
| 371 |
-
gr.Markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
output_gallery = gr.Gallery(
|
| 374 |
label="生成结果 / Generated Results",
|
|
|
|
| 213 |
return None
|
| 214 |
|
| 215 |
|
| 216 |
+
@spaces.GPU(duration=60)
|
| 217 |
+
def load_and_optimize_model():
|
| 218 |
+
"""
|
| 219 |
+
Step 1: Load model and apply FP8 quantization.
|
| 220 |
+
This triggers torch compilation on first run (may take 1-2 minutes).
|
| 221 |
+
"""
|
| 222 |
+
gen = init_generator()
|
| 223 |
+
|
| 224 |
+
# Apply FP8 quantization (works with FA3)
|
| 225 |
+
logger.info("Applying FP8 quantization to model...")
|
| 226 |
+
quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
|
| 227 |
+
logger.info("✓ FP8 quantization applied! Running with FA3 + FP8")
|
| 228 |
+
|
| 229 |
+
return "ready"
|
| 230 |
+
|
| 231 |
+
|
| 232 |
def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
|
| 233 |
+
"""Calculate dynamic GPU duration: 2s per step per image"""
|
| 234 |
+
return 20 + int(2 * num_steps * num_images)
|
| 235 |
|
| 236 |
|
| 237 |
@spaces.GPU(duration=_get_generation_duration)
|
| 238 |
def run_generation(text, font, author, num_steps, start_seed, num_images):
|
| 239 |
"""
|
| 240 |
+
Step 2: Run actual generation.
|
| 241 |
+
Model should already be loaded from step 1.
|
| 242 |
"""
|
| 243 |
gen = init_generator()
|
| 244 |
|
| 245 |
+
# Re-apply FP8 quantization (ZeroGPU sessions are isolated)
|
|
|
|
| 246 |
quantize_(gen.model, Float8DynamicActivationFloat8WeightConfig())
|
|
|
|
| 247 |
|
| 248 |
results = []
|
| 249 |
seeds_used = []
|
|
|
|
| 269 |
progress=gr.Progress()
|
| 270 |
):
|
| 271 |
"""
|
| 272 |
+
Interactive session with FA3 + FP8.
|
| 273 |
+
Split into load + generate steps for better progress visibility.
|
| 274 |
"""
|
| 275 |
# Validate text
|
| 276 |
if len(text) < 1:
|
|
|
|
| 286 |
# Determine author
|
| 287 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 288 |
|
| 289 |
+
# Step 1: Load and optimize model
|
| 290 |
+
yield "⏳ 加载模型中... (首次可能需要1-2分钟编译) / Loading model... (first run may take 1-2 min to compile)", []
|
| 291 |
+
progress(0.1, desc="加载模型...")
|
| 292 |
+
|
| 293 |
+
load_and_optimize_model()
|
| 294 |
+
|
| 295 |
+
# Step 2: Run generation
|
| 296 |
+
yield f"🎨 生成 {num_images} 张图片中... / Generating {num_images} images...", []
|
| 297 |
progress(0.3, desc="生成中...")
|
| 298 |
|
| 299 |
results, seeds_used = run_generation(
|
|
|
|
| 389 |
with gr.Column(scale=1):
|
| 390 |
# Output section
|
| 391 |
gr.Markdown("### 🖼️ 生成结果 / Generated Results")
|
| 392 |
+
gr.Markdown("""
|
| 393 |
+
⚠️ **首次生成说明 / First Run Note:**
|
| 394 |
+
- 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
|
| 395 |
+
- 如果遇到错误,请**再点一次生成按钮**即可正常运行
|
| 396 |
+
- First generation triggers PyTorch compilation (~1-2 min)
|
| 397 |
+
- If you see an error, just **click generate again** and it will work
|
| 398 |
+
|
| 399 |
+
*点击图片可放大查看 / Click image to enlarge*
|
| 400 |
+
""")
|
| 401 |
|
| 402 |
output_gallery = gr.Gallery(
|
| 403 |
label="生成结果 / Generated Results",
|