Tianshuo-Xu committed on
Commit
76d0cdd
·
1 Parent(s): b4ee924

Update UI to support up to 4 calligrapher/style pairs

Browse files
Files changed (1) hide show
  1. app.py +79 -69
app.py CHANGED
@@ -273,13 +273,14 @@ def parse_font_style(font_style: str) -> str:
273
  # Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects.
274
 
275
 
276
- def _get_generation_duration(text, font, author, num_steps, start_seed, num_images, progress=None):
277
  """Calculate dynamic GPU duration: 24s base + 3s per image"""
278
- return 24 + int(3 * num_images)
 
279
 
280
 
281
  @spaces.GPU(duration=_get_generation_duration)
282
- def run_generation(text, font, author, num_steps, start_seed, num_images, progress=gr.Progress()):
283
  """
284
  Load model, apply FP8 quantization, and generate images.
285
  All in one GPU session to avoid redundant loading.
@@ -328,21 +329,32 @@ def run_generation(text, font, author, num_steps, start_seed, num_images, progre
328
  logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
329
 
330
  # Step 3: Generate images
331
- logger.info(f"Generating {num_images} images...")
 
332
  results = []
333
  seeds_used = []
334
- for i in range(num_images):
335
- loop_progress = 0.82 + ((i + 1) / max(num_images, 1)) * 0.16
336
- progress(loop_progress, desc=f"生成第 {i+1}/{num_images} 张 / Generating {i+1}/{num_images}")
337
- current_seed = start_seed + i
338
- result_img, cond_img = gen.generate(
339
- text=text, font_style=font, author=author,
340
- num_steps=num_steps, seed=current_seed,
341
- guidance=1.0,
342
- )
343
- results.append((result_img, f"Seed: {current_seed}"))
344
- seeds_used.append(current_seed)
345
- logger.info(f" Generated image {i+1}/{num_images}")
 
 
 
 
 
 
 
 
 
 
346
 
347
  progress(1.0, desc="生成完成 / Generation complete")
348
  return results, seeds_used
@@ -350,8 +362,7 @@ def run_generation(text, font, author, num_steps, start_seed, num_images, progre
350
 
351
  def interactive_session(
352
  text: str,
353
- author_dropdown: str,
354
- font_style: str,
355
  num_steps: int,
356
  start_seed: int,
357
  num_images: int,
@@ -366,13 +377,16 @@ def interactive_session(
366
  if len(text) > 7:
367
  raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
368
 
369
- # Parse font style
370
- font = parse_font_style(font_style)
371
- if font is None:
372
- raise gr.Error(f"无法识别的字体风格 / Unknown font style: {font_style}")
373
-
374
- # Determine author
375
- author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
 
 
 
376
 
377
  # Run generation (includes model loading + FP8 quantization + generation)
378
  yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
@@ -387,14 +401,15 @@ def interactive_session(
387
  progress(0.22, desc="进入生成阶段 / Entering generation stage...")
388
 
389
  results, seeds_used = run_generation(
390
- text, font, author, num_steps, start_seed, num_images, progress
391
  )
392
 
393
  progress(1.0, desc="完成!")
394
 
395
  # Final status
396
- if num_images > 1:
397
- final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
 
398
  else:
399
  final_status = f"✅ 完成!Seed: {seeds_used[0]}"
400
  yield final_status, results
@@ -428,31 +443,35 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
428
  max_lines=1
429
  )
430
 
431
- gr.Markdown("### 👤 书法家选择 / Calligrapher Selection")
432
 
433
- author_dropdown = gr.Dropdown(
434
- label="1. 选择书法家 / Select Calligrapher",
435
- choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
436
- value="文征明",
437
- info="先选择历史书法家 / Choose a historical calligrapher first"
438
- )
439
 
440
- # Get initial fonts for default author (文征明)
441
  initial_author = "文征明"
442
  initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
443
  initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
444
- # Default to first available font for the author, prefer "行" for 文征明
445
  if initial_author == "文征明" and "行" in initial_fonts:
446
  default_font = FONT_STYLE_NAMES["行"]
447
  else:
448
  default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"
449
 
450
- font_style = gr.Dropdown(
451
- label="2. 选择字体风格 / Select Font Style",
452
- choices=initial_font_choices,
453
- value=default_font,
454
- info="根据所选书法家显示可用字体 / Shows available fonts for selected calligrapher"
455
- )
 
 
 
 
 
 
 
 
 
 
456
 
457
  gr.Markdown("### ⚙️ 生成设置 / Generation Settings")
458
 
@@ -524,23 +543,23 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
524
  gr.Markdown(author_info_md)
525
 
526
  # Event handlers
527
- author_dropdown.change(
528
- fn=update_font_choices,
529
- inputs=[author_dropdown],
530
- outputs=[font_style]
531
- )
 
 
 
 
 
 
 
532
 
533
  # Generate button - uses streaming for live updates
534
  generate_btn.click(
535
  fn=interactive_session,
536
- inputs=[
537
- text_input,
538
- author_dropdown,
539
- font_style,
540
- num_steps,
541
- start_seed,
542
- num_images,
543
- ],
544
  outputs=[status_text, output_gallery]
545
  )
546
 
@@ -548,20 +567,11 @@ with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生
548
  gr.Markdown("### 📋 示例 / Examples")
549
  gr.Examples(
550
  examples=[
551
- ["相见时难别亦难", "文征明", "行 (Running Script)", 4, 1024, 1],
552
- ["天道酬勤", "王羲之", "草 (Cursive Script)", 4, 42, 1],
553
- ["厚德载物", "赵孟頫", "楷 (Regular Script)", 4, 123, 1],
554
- ["海内存知己", "黄庭坚", "行 (Running Script)", 4, 456, 1],
555
- ["宁静致远", "None (Synthetic / 合成风格)", "楷 (Regular Script)", 4, 789, 1],
556
- ],
557
- inputs=[
558
- text_input,
559
- author_dropdown,
560
- font_style,
561
- num_steps,
562
- start_seed,
563
- num_images,
564
  ],
 
565
  )
566
 
567
 
 
273
  # Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects.
274
 
275
 
276
+ def _get_generation_duration(text, pairs, num_steps, start_seed, num_images, progress=None):
277
  """Calculate dynamic GPU duration: 24s base + 3s per image per calligrapher/style pair (one pair is assumed when pairs is empty or None)"""
278
+ num_pairs = len(pairs) if pairs else 1
279
+ return 24 + int(3 * num_images * num_pairs)
280
 
281
 
282
  @spaces.GPU(duration=_get_generation_duration)
283
+ def run_generation(text, pairs, num_steps, start_seed, num_images, progress=gr.Progress()):
284
  """
285
  Load model, apply FP8 quantization, and generate images.
286
  All in one GPU session to avoid redundant loading.
 
329
  logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
330
 
331
  # Step 3: Generate images
332
+ total_gens = len(pairs) * num_images
333
+ logger.info(f"Generating {total_gens} images across {len(pairs)} styles...")
334
  results = []
335
  seeds_used = []
336
+
337
+ gen_idx = 0
338
+ for author, font in pairs:
339
+ for i in range(num_images):
340
+ gen_idx += 1
341
+ loop_progress = 0.82 + (gen_idx / max(total_gens, 1)) * 0.16
342
+ progress(loop_progress, desc=f"生成第 {gen_idx}/{total_gens} 张 / Generating {gen_idx}/{total_gens}")
343
+ current_seed = start_seed + i
344
+
345
+ cond_author = author if author != "None (Synthetic / 合成风格)" else None
346
+
347
+ result_img, cond_img = gen.generate(
348
+ text=text, font_style=font, author=cond_author,
349
+ num_steps=num_steps, seed=current_seed,
350
+ guidance=1.0,
351
+ )
352
+
353
+ author_label = author if author else "Synthetic"
354
+ label = f"{author_label} - {font} (Seed: {current_seed})"
355
+ results.append((result_img, label))
356
+ seeds_used.append(current_seed)
357
+ logger.info(f" Generated image {gen_idx}/{total_gens}")
358
 
359
  progress(1.0, desc="生成完成 / Generation complete")
360
  return results, seeds_used
 
362
 
363
  def interactive_session(
364
  text: str,
365
+ a1, f1, a2, f2, a3, f3, a4, f4,
 
366
  num_steps: int,
367
  start_seed: int,
368
  num_images: int,
 
377
  if len(text) > 7:
378
  raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
379
 
380
+ raw_pairs = [(a1, f1), (a2, f2), (a3, f3), (a4, f4)]
381
+ pairs = []
382
+ for a, f_style in raw_pairs:
383
+ if a and f_style:
384
+ parsed_font = parse_font_style(f_style)
385
+ if parsed_font is not None:
386
+ pairs.append((a, parsed_font))
387
+
388
+ if not pairs:
389
+ raise gr.Error("请至少选择一项书法家和字体组合 / Please select at least one combination")
390
 
391
  # Run generation (includes model loading + FP8 quantization + generation)
392
  yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
 
401
  progress(0.22, desc="进入生成阶段 / Entering generation stage...")
402
 
403
  results, seeds_used = run_generation(
404
+ text, pairs, num_steps, start_seed, num_images, progress
405
  )
406
 
407
  progress(1.0, desc="完成!")
408
 
409
  # Final status
410
+ total_imgs = len(results)
411
+ if total_imgs > 1:
412
+ final_status = f"✅ 全部完成!共 {total_imgs} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
413
  else:
414
  final_status = f"✅ 完成!Seed: {seeds_used[0]}"
415
  yield final_status, results
 
443
  max_lines=1
444
  )
445
 
446
+ gr.Markdown("### 👤 书法家与字体组合 / Calligraphers & Fonts (最多4组 / Up to 4)")
447
 
448
+ author_dropdowns = []
449
+ font_dropdowns = []
 
 
 
 
450
 
 
451
  initial_author = "文征明"
452
  initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
453
  initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
 
454
  if initial_author == "文征明" and "行" in initial_fonts:
455
  default_font = FONT_STYLE_NAMES["行"]
456
  else:
457
  default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"
458
 
459
+ for i in range(4):
460
+ with gr.Group():
461
+ gr.Markdown(f"**组合 {i+1} / Combination {i+1}** (不填则忽略 / Leave blank to ignore)")
462
+ with gr.Row():
463
+ a_drop = gr.Dropdown(
464
+ label=f"书法家 / Calligrapher",
465
+ choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
466
+ value="文征明" if i == 0 else None,
467
+ )
468
+ f_drop = gr.Dropdown(
469
+ label=f"字体风格 / Font Style",
470
+ choices=initial_font_choices if i == 0 else list(FONT_STYLE_NAMES.values()),
471
+ value=default_font if i == 0 else None,
472
+ )
473
+ author_dropdowns.append(a_drop)
474
+ font_dropdowns.append(f_drop)
475
 
476
  gr.Markdown("### ⚙️ 生成设置 / Generation Settings")
477
 
 
543
  gr.Markdown(author_info_md)
544
 
545
  # Event handlers
546
+ for i in range(4):
547
+ author_dropdowns[i].change(
548
+ fn=update_font_choices,
549
+ inputs=[author_dropdowns[i]],
550
+ outputs=[font_dropdowns[i]]
551
+ )
552
+
553
+ # Prepare inputs list for the interactive session
554
+ session_inputs = [text_input]
555
+ for i in range(4):
556
+ session_inputs.extend([author_dropdowns[i], font_dropdowns[i]])
557
+ session_inputs.extend([num_steps, start_seed, num_images])
558
 
559
  # Generate button - uses streaming for live updates
560
  generate_btn.click(
561
  fn=interactive_session,
562
+ inputs=session_inputs,
 
 
 
 
 
 
 
563
  outputs=[status_text, output_gallery]
564
  )
565
 
 
567
  gr.Markdown("### 📋 示例 / Examples")
568
  gr.Examples(
569
  examples=[
570
+ ["相见时难别亦难", "文征明", "行 (Running Script)", None, None, None, None, None, None, 4, 1024, 1],
571
+ ["天道酬勤", "王羲之", "草 (Cursive Script)", "黄庭坚", "行 (Running Script)", None, None, None, None, 4, 42, 1],
572
+ ["厚德载物", "赵孟頫", "楷 (Regular Script)", None, None, None, None, None, None, 4, 123, 1],
 
 
 
 
 
 
 
 
 
 
573
  ],
574
+ inputs=session_inputs,
575
  )
576
 
577