Tianshuo-Xu committed on
Commit
39d3dc3
·
1 Parent(s): 5a8be65

improve gradio progress stages and percentages

Browse files
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -254,11 +254,13 @@ def _get_generation_duration(text, font, author, num_steps, start_seed, num_imag
254
 
255
 
256
  @spaces.GPU(duration=_get_generation_duration)
257
- def run_generation(text, font, author, num_steps, start_seed, num_images):
258
  """
259
  Load model, apply FP8 quantization, and generate images.
260
  All in one GPU session to avoid redundant loading.
261
  """
 
 
262
  # Enable CUDA optimizations inside the worker
263
  try:
264
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -271,16 +273,22 @@ def run_generation(text, font, author, num_steps, start_seed, num_images):
271
  pass
272
 
273
  # Step 1: Load model
 
274
  global generator
275
  if generator is None:
276
  logger.info("Initializing generator lazily inside GPU worker...")
 
277
  generator = init_generator()
 
 
 
278
 
279
  logger.info("Using initialized generator in ZeroGPU worker.")
280
  gen = generator
281
  # ZeroGPU automatically maps these to the acquired GPU during execution.
282
  # We must also correctly update internal Python attributes so runtime-generated latents go to GPU.
283
  target_device = torch.device("cuda")
 
284
  gen.device = target_device
285
  if hasattr(gen, "sampler") and gen.sampler is not None:
286
  gen.sampler.device = target_device
@@ -289,6 +297,7 @@ def run_generation(text, font, author, num_steps, start_seed, num_images):
289
  gen.clip.to(target_device)
290
  gen.t5.to(target_device)
291
  gen.vae.to(target_device)
 
292
 
293
  # Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
294
  logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
@@ -298,6 +307,8 @@ def run_generation(text, font, author, num_steps, start_seed, num_images):
298
  results = []
299
  seeds_used = []
300
  for i in range(num_images):
 
 
301
  current_seed = start_seed + i
302
  result_img, cond_img = gen.generate(
303
  text=text, font_style=font, author=author,
@@ -308,6 +319,7 @@ def run_generation(text, font, author, num_steps, start_seed, num_images):
308
  seeds_used.append(current_seed)
309
  logger.info(f" Generated image {i+1}/{num_images}")
310
 
 
311
  return results, seeds_used
312
 
313
 
@@ -338,14 +350,19 @@ def interactive_session(
338
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
339
 
340
  # Run generation (includes model loading + FP8 quantization + generation)
341
- yield "⏳ 加载模型并生成中... (首次需要1-2分钟编译) / Loading & generating... (first run ~1-2 min)", []
342
- progress(0.2, desc="处理中...")
 
 
343
 
344
  # Hardcode num_steps to 4 for DMD distillation
345
  num_steps = 4
346
 
 
 
 
347
  results, seeds_used = run_generation(
348
- text, font, author, num_steps, start_seed, num_images
349
  )
350
 
351
  progress(1.0, desc="完成!")
 
254
 
255
 
256
  @spaces.GPU(duration=_get_generation_duration)
257
+ def run_generation(text, font, author, num_steps, start_seed, num_images, progress=gr.Progress()):
258
  """
259
  Load model, apply FP8 quantization, and generate images.
260
  All in one GPU session to avoid redundant loading.
261
  """
262
+ progress(0.25, desc="准备 GPU 环境 / Preparing GPU runtime...")
263
+
264
  # Enable CUDA optimizations inside the worker
265
  try:
266
  torch.backends.cuda.matmul.allow_tf32 = True
 
273
  pass
274
 
275
  # Step 1: Load model
276
+ progress(0.35, desc="检查模型状态 / Checking model state...")
277
  global generator
278
  if generator is None:
279
  logger.info("Initializing generator lazily inside GPU worker...")
280
+ progress(0.45, desc="首次初始化模型 / First-time model initialization...")
281
  generator = init_generator()
282
+ progress(0.65, desc="模型初始化完成 / Model initialization complete")
283
+ else:
284
+ progress(0.55, desc="复用已初始化模型 / Reusing initialized model")
285
 
286
  logger.info("Using initialized generator in ZeroGPU worker.")
287
  gen = generator
288
  # ZeroGPU automatically maps these to the acquired GPU during execution.
289
  # We must also correctly update internal Python attributes so runtime-generated latents go to GPU.
290
  target_device = torch.device("cuda")
291
+ progress(0.72, desc="迁移模型到 GPU / Moving model to GPU...")
292
  gen.device = target_device
293
  if hasattr(gen, "sampler") and gen.sampler is not None:
294
  gen.sampler.device = target_device
 
297
  gen.clip.to(target_device)
298
  gen.t5.to(target_device)
299
  gen.vae.to(target_device)
300
+ progress(0.82, desc="模型就绪,开始生成 / Model ready, starting generation...")
301
 
302
  # Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
303
  logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
 
307
  results = []
308
  seeds_used = []
309
  for i in range(num_images):
310
+ loop_progress = 0.82 + ((i + 1) / max(num_images, 1)) * 0.16
311
+ progress(loop_progress, desc=f"生成第 {i+1}/{num_images} 张 / Generating {i+1}/{num_images}")
312
  current_seed = start_seed + i
313
  result_img, cond_img = gen.generate(
314
  text=text, font_style=font, author=author,
 
319
  seeds_used.append(current_seed)
320
  logger.info(f" Generated image {i+1}/{num_images}")
321
 
322
+ progress(1.0, desc="生成完成 / Generation complete")
323
  return results, seeds_used
324
 
325
 
 
350
  author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
351
 
352
  # Run generation (includes model loading + FP8 quantization + generation)
353
+ yield "⏳ 队列:准备任务... / Queued: preparing task...", []
354
+ progress(0.05, desc="校验输入参数 / Validating inputs...")
355
+ yield "⏳ 输入已通过校验,等待 GPU 分配... / Inputs validated, waiting for GPU allocation...", []
356
+ progress(0.15, desc="等待 GPU 资源 / Waiting for GPU allocation...")
357
 
358
  # Hardcode num_steps to 4 for DMD distillation
359
  num_steps = 4
360
 
361
+ yield "⏳ 已分配 GPU,正在初始化与生成... / GPU allocated, initializing and generating...", []
362
+ progress(0.22, desc="进入生成阶段 / Entering generation stage...")
363
+
364
  results, seeds_used = run_generation(
365
+ text, font, author, num_steps, start_seed, num_images, progress
366
  )
367
 
368
  progress(1.0, desc="完成!")