File size: 23,027 Bytes
5c86cdc
 
 
a895d85
5c86cdc
 
f1e9349
46eccdb
2a1fcfe
 
 
 
 
46eccdb
 
 
a895d85
46eccdb
a895d85
46eccdb
a895d85
46eccdb
 
2a1fcfe
f1e9349
9d5c8cc
 
 
c2ad4cd
5c86cdc
 
89e2699
a1f5b88
5c86cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ad4cd
 
5c86cdc
 
 
 
 
 
c2ad4cd
a1f5b88
5c86cdc
a1f5b88
c2ad4cd
a1f5b88
4c08c35
a1f5b88
4c08c35
 
e51b773
 
743a20a
4c08c35
a76178d
 
743a20a
4ea2a04
a76178d
 
 
 
4c08c35
a76178d
 
46f36ce
a76178d
 
4ea2a04
a76178d
4ea2a04
 
 
ce4bbb3
4ea2a04
e51b773
ce4bbb3
0dde832
 
e51b773
 
4ea2a04
 
 
 
 
e51b773
0dde832
 
 
e51b773
 
4ea2a04
 
 
 
 
 
 
 
 
743a20a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ea2a04
743a20a
4c08c35
 
 
 
 
 
 
 
4ea2a04
a76178d
4c08c35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a76178d
4c08c35
a76178d
 
 
c2ad4cd
5c86cdc
6e4c09b
4c08c35
5c86cdc
 
a1f5b88
 
 
4c08c35
c322e84
9d5c8cc
414150e
5c86cdc
 
1ac7d4b
c49775d
a1f5b88
 
5c86cdc
 
 
a1f5b88
450e581
6e4c09b
450e581
4c08c35
 
5c86cdc
a1f5b88
5c86cdc
c2ad4cd
 
7c94f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a8be65
 
 
c49775d
 
76d0cdd
b4ee924
76d0cdd
7d06fb9
d9765a4
 
 
76d0cdd
96d51ed
4c08c35
 
96d51ed
39d3dc3
 
1ac7d4b
 
 
 
 
 
 
 
 
 
 
4c08c35
39d3dc3
5a8be65
 
 
39d3dc3
5a8be65
39d3dc3
 
 
5a8be65
 
c49775d
1ac7d4b
46e86e6
 
39d3dc3
46e86e6
 
 
 
 
 
 
 
39d3dc3
aa36e12
450e581
 
96d51ed
4c08c35
76d0cdd
 
96d51ed
 
76d0cdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96d51ed
39d3dc3
96d51ed
 
 
89e2699
5c86cdc
76d0cdd
5c86cdc
89e2699
 
 
5c86cdc
 
6e8caef
5c86cdc
89e2699
5c86cdc
 
 
 
 
76d0cdd
 
 
 
 
 
 
 
 
 
5c86cdc
4c08c35
39d3dc3
 
 
 
89e2699
bd3a1d0
 
 
39d3dc3
 
 
96d51ed
76d0cdd
96d51ed
d3ccd4b
96d51ed
d3ccd4b
6e4c09b
76d0cdd
 
 
d3ccd4b
89e2699
 
c2ad4cd
5c86cdc
 
7d06fb9
5c86cdc
 
 
33b9eec
 
5c86cdc
 
 
 
33b9eec
 
5c86cdc
 
 
 
 
 
 
 
 
 
5dc8b08
 
5c86cdc
 
 
375cc31
5c86cdc
375cc31
76d0cdd
 
5c86cdc
5dc8b08
5c86cdc
 
2292172
 
 
 
5c86cdc
375cc31
 
76d0cdd
375cc31
 
76d0cdd
 
 
 
 
 
 
 
 
 
 
 
 
375cc31
 
 
 
5c86cdc
 
 
 
bd3a1d0
 
 
 
5c86cdc
bd3a1d0
 
5c86cdc
 
89e2699
 
 
 
 
5c86cdc
89e2699
 
d3ccd4b
89e2699
d3ccd4b
89e2699
d3ccd4b
 
89e2699
5c86cdc
7d06fb9
5c86cdc
89e2699
6e8caef
 
 
 
 
 
 
 
 
5c86cdc
d3ccd4b
 
 
 
7d06fb9
 
d3ccd4b
 
 
5c86cdc
89e2699
 
 
5c86cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76d0cdd
 
 
 
 
 
375cc31
 
 
 
 
 
 
 
 
 
 
 
 
 
76d0cdd
 
 
 
 
 
5c86cdc
89e2699
5c86cdc
89e2699
76d0cdd
89e2699
5c86cdc
 
 
 
 
 
f57a02d
76d0cdd
 
 
c2ad4cd
76d0cdd
c2ad4cd
 
5c86cdc
7789c9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
# -*- coding: utf-8 -*-
"""
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
Uses Flash Attention 3 with bf16 weights for faster inference (FP8 quantization is currently disabled)
"""

import os
import sys

# Loud startup banner so the Space logs clearly show when the app (re)starts.
_BANNER = "=" * 50
print(_BANNER, flush=True)
print("APP.PY STARTING - DEBUG VERSION", flush=True)
print(_BANNER, flush=True)

import logging
from datetime import datetime

# Send INFO-and-above records to stdout with a timestamped, level-tagged format.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
print("Logging setup complete", flush=True)

# IMPORTANT: import spaces first before any CUDA-related packages
import spaces

import gradio as gr
import json
import csv
import time
import torch

# Load author and font mappings from CSV
def load_author_fonts_from_csv(csv_path, excluded_fonts=frozenset({'隶', '篆'})):
    """
    Load each calligrapher and the font styles they support from a CSV file.

    The CSV must have a ``书法家`` (author) column and a ``字体类型`` (font
    type) column whose value lists one or more styles separated by ``|``
    (e.g. ``楷|行``). Styles listed in ``excluded_fonts`` (by default 隶 and
    篆, which this demo does not support) are dropped. Surrounding whitespace
    and empty entries are ignored. Authors left with no supported style, or
    with an empty name, are omitted entirely.

    Args:
        csv_path: Path to the UTF-8 encoded CSV file.
        excluded_fonts: Collection of style names to filter out.

    Returns:
        dict mapping author name to the list of supported font styles, in
        the order they appear in the CSV. A later row for the same author
        replaces an earlier one.
    """
    author_fonts = {}

    # newline='' is what the csv module expects, so quoted fields that contain
    # line breaks are parsed correctly.
    with open(csv_path, 'r', encoding='utf-8', newline='') as csv_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            author = (row['书法家'] or '').strip()
            # A short row yields None for a missing trailing column; treat it as empty.
            raw_fonts = row.get('字体类型') or ''
            stripped_fonts = (font.strip() for font in raw_fonts.split('|'))
            supported_fonts = [
                font for font in stripped_fonts
                if font and font not in excluded_fonts
            ]

            # Only include the author if they have at least one supported font
            if author and supported_fonts:
                author_fonts[author] = supported_fonts

    return author_fonts

# Author -> supported font styles, read once at import time.
AUTHOR_FONTS = load_author_fonts_from_csv('dataset/author_fonts_summary.csv')

# Author names for the dropdowns, sorted by Unicode code point.
AUTHOR_LIST = sorted(AUTHOR_FONTS)

# Dropdown display label for each supported style character.
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)",
}

# Load author descriptions if available. This is best effort: the UI still works
# without them, so any read/parse problem degrades to an empty mapping.
try:
    with open('dataset/calligraphy_styles_en.json', 'r', encoding='utf-8') as styles_file:
        author_styles = json.load(styles_file)
except (OSError, ValueError) as e:
    # OSError: missing or unreadable file. ValueError: invalid JSON or bad
    # encoding (json.JSONDecodeError and UnicodeDecodeError both subclass it).
    print(f"Warning: Could not load author descriptions: {e}")
    author_styles = {}
else:
    # Consumers call author_styles.get(...), so anything but a JSON object is unusable.
    if not isinstance(author_styles, dict):
        print("Warning: author descriptions file is not a JSON object; ignoring it")
        author_styles = {}

# Global generator instance, created lazily inside the GPU worker
generator = None
_cached_model_dir = None

# ============================================================
# Pre-download and pre-load model files at startup (no GPU needed)
# ============================================================
# Populated by preload_model_files(); None means "not available".
_preloaded_embedding = _preloaded_tokenizer = None
_cached_t5_dir = _cached_clip_dir = _cached_font_path = None

def preload_model_files():
    """
    Download every model artifact into the local Hugging Face cache at startup.

    Runs without touching the GPU. Each artifact is fetched independently:
    a failure is reported as a warning, and the remaining downloads still run.

    Side effects:
    * sets the ``XFLUX_TEXT_ENCODER_PATH``, ``XFLUX_CLIP_ENCODER_PATH`` and
      ``UNICALLI_FONT_PATH`` environment variables for the artifacts that
      downloaded successfully;
    * fills the module-level ``_cached_t5_dir``, ``_cached_clip_dir`` and
      ``_cached_font_path`` globals;
    * when the main model is available, fills ``_preloaded_embedding`` and
      ``_preloaded_tokenizer`` with the CPU-resident InternVL embedding.

    Returns:
        Local snapshot directory of the main UniCalli model, or None if it
        could not be downloaded.
    """
    global _preloaded_embedding, _preloaded_tokenizer, _cached_t5_dir, _cached_clip_dir, _cached_font_path
    from huggingface_hub import snapshot_download, hf_hub_download

    hf_token = os.environ.get("HF_TOKEN")
    print("Pre-downloading model files to cache...")

    # 1. Main model (Unicalli_Pro), full snapshot including all tokenizer files
    try:
        local_dir = snapshot_download(
            repo_id="TSXu/UniCalli_pro_dmd2_K4",
            token=hf_token
        )
        print(f"✓ Unicalli_Pro cached at: {local_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
        local_dir = None

    # 2. T5 text encoder
    try:
        _cached_t5_dir = snapshot_download(
            "xlabs-ai/xflux_text_encoders",
            token=hf_token
        )
        os.environ["XFLUX_TEXT_ENCODER_PATH"] = _cached_t5_dir
        print(f"✓ T5 text encoder cached at: {_cached_t5_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download T5: {e}")

    # 3. CLIP text encoder
    try:
        _cached_clip_dir = snapshot_download(
            "openai/clip-vit-large-patch14",
            token=hf_token
        )
        os.environ["XFLUX_CLIP_ENCODER_PATH"] = _cached_clip_dir
        print(f"✓ CLIP text encoder cached at: {_cached_clip_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download CLIP: {e}")

    # 4. VAE (ae.safetensors from FLUX.1-dev)
    try:
        hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors", token=hf_token)
        print("✓ VAE cached")
    except Exception as e:
        print(f"Warning: Could not pre-download VAE: {e}")

    # 5. Font file used for condition image rendering
    try:
        _cached_font_path = hf_hub_download(
            repo_id="TSXu/Unicalli_Pro",
            filename="FangZhengKaiTiFanTi-1.ttf",
            token=hf_token,
        )
        os.environ["UNICALLI_FONT_PATH"] = _cached_font_path
        print(f"✓ Font cached at: {_cached_font_path}")
    except Exception as e:
        print(f"Warning: Could not pre-download font: {e}")

    # 6. Flash Attention 3 kernel package (large) pre-cache
    try:
        from kernels import get_kernel
        get_kernel("kernels-community/vllm-flash-attn3")
        print("✓ Flash Attention 3 kernel cached")
    except Exception as e:
        print(f"Warning: Could not pre-cache Flash Attention 3 kernel: {e}")

    # 7. Pre-load InternVL embedding to CPU memory (saves ~5s in GPU session)
    if local_dir:
        try:
            intern_vlm_path = os.path.join(local_dir, "internvl_embedding")
            _preloaded_embedding, _preloaded_tokenizer = _preload_embedding(intern_vlm_path)
            # _preload_embedding returns (None, None) when its files are missing,
            # so only report success when something was actually loaded.
            if _preloaded_embedding is not None:
                print("✓ InternVL embedding pre-loaded to CPU")
            else:
                print("Warning: InternVL embedding files missing; skipping CPU pre-load")
        except Exception as e:
            print(f"Warning: Could not pre-load embedding: {e}")

    return local_dir


def _preload_embedding(intern_vlm_path):
    """Pre-load embedding and tokenizer to CPU memory"""
    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoTokenizer
    
    embedding_file = os.path.join(intern_vlm_path, "embedding.safetensors")
    config_file = os.path.join(intern_vlm_path, "embedding_config.json")
    
    if not os.path.exists(embedding_file) or not os.path.exists(config_file):
        print(f"  Embedding files not found at {intern_vlm_path}")
        return None, None
    
    print(f"  Loading embedding from: {intern_vlm_path}")
    
    with open(config_file, 'r') as f:
        config = json.load(f)
    
    # Create embedding layer on CPU
    embed_tokens = torch.nn.Embedding(
        num_embeddings=config["num_embeddings"],
        embedding_dim=config["embedding_dim"],
        padding_idx=config.get("padding_idx", None)
    )
    
    # Load weights
    state_dict = load_safetensors(embedding_file)
    embed_tokens.load_state_dict(state_dict)
    embed_tokens.eval()
    embed_tokens.requires_grad_(False)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        intern_vlm_path, trust_remote_code=True, use_fast=False
    )
    
    print(f"  Loaded: {config['num_embeddings']} x {config['embedding_dim']}")
    return embed_tokens, tokenizer


_SEPARATOR_LINE = "=" * 50
print(_SEPARATOR_LINE)
print("Starting model pre-download and pre-load...")
# Runs at import time, on CPU, before any @spaces.GPU worker is started.
_cached_model_dir = preload_model_files()
print(_SEPARATOR_LINE)


def init_generator():
    """
    Create the module-level ``CalligraphyGenerator`` on first use and return it.

    The generator is built on CPU from the snapshot downloaded by
    ``preload_model_files()``. It reuses the pre-loaded InternVL embedding
    and tokenizer when they are available. Moving it to CUDA is left to the
    caller. Later calls return the existing instance.

    Raises:
        RuntimeError: if the main model snapshot was not downloaded at startup.
    """
    global generator

    if generator is None:
        # preload_model_files() returns None when the main download failed;
        # fail with a clear message instead of a TypeError from os.path.join.
        if not _cached_model_dir:
            raise RuntimeError(
                "UniCalli model files are unavailable: the checkpoint could not be "
                "downloaded at startup (see the startup warnings in the logs)."
            )

        intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
        checkpoint_path = _cached_model_dir
        print(f"Using pre-cached model from: {_cached_model_dir}")
        print(f"Using pre-loaded embedding: {_preloaded_embedding is not None}")

        from inference import CalligraphyGenerator

        generator = CalligraphyGenerator(
            model_name="flux-dev",
            device="cpu",  # Stay on CPU so CUDA is not initialized before the ZeroGPU fork
            offload=False,  # Let ZeroGPU manage CUDA memory instead of manual CPU offloading
            intern_vlm_path=intern_vlm_path,
            checkpoint_path=checkpoint_path,
            font_descriptions_path='dataset/chirography.json',
            author_descriptions_path='dataset/calligraphy_styles_en.json',
            use_deepspeed=False,
            use_4bit_quantization=False,
            use_float8_quantization=False,
            use_torch_compile=False,
            dtype="bf16",  # bf16 lets FA3 run and avoids Float8 "unimplemented" errors
            preloaded_embedding=_preloaded_embedding,
            preloaded_tokenizer=_preloaded_tokenizer,
        )

    return generator


def update_font_choices(author: str):
    """
    Return a refreshed font-style dropdown for the selected calligrapher.

    The choices depend on the selection:
    * the synthetic option, an empty selection, or an author missing from
      ``AUTHOR_FONTS`` offers every supported style;
    * otherwise only that author's styles that the demo supports are offered.

    The first choice becomes the selected value, or None if there are none.
    """
    is_known_author = author != "None (Synthetic / 合成风格)" and author in AUTHOR_FONTS
    if is_known_author:
        choices = [
            FONT_STYLE_NAMES[style_key]
            for style_key in AUTHOR_FONTS[author]
            if style_key in FONT_STYLE_NAMES
        ]
    else:
        choices = list(FONT_STYLE_NAMES.values())

    selected = choices[0] if choices else None
    return gr.Dropdown(choices=choices, value=selected)


def parse_font_style(font_style: str) -> str:
    """Map a dropdown display label back to its style character, or None if unrecognised."""
    return next(
        (style_key for style_key, label in FONT_STYLE_NAMES.items() if label == font_style),
        None,
    )


# IMPORTANT:
# Do NOT initialize generator globally at import time in ZeroGPU Spaces.
# Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects.


def _get_generation_duration(text, pairs, num_steps, start_seed, num_images, progress=None):
    """Calculate dynamic GPU duration: 24s base + 3s per image"""
    num_pairs = len(pairs) if pairs else 1
    return 30 + int(3 * num_images * num_pairs)


def _enable_cuda_fast_paths():
    """Best-effort switch-on of TF32 matmuls, cuDNN autotuning and fused SDPA kernels."""
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        # Disable the unfused (math) SDPA fallback.
        torch.backends.cuda.enable_math_sdp(False)
    except Exception as exc:
        # Best effort: an unavailable toggle must not block generation.
        logger.warning(f"Could not enable all CUDA fast paths: {exc}")


def _move_generator_to(gen, device):
    """Point ``gen`` (and its sampler, if any) at ``device`` and move its sub-models there."""
    gen.device = device
    if getattr(gen, "sampler", None) is not None:
        gen.sampler.device = device
    for module in (gen.model, gen.clip, gen.t5, gen.vae):
        module.to(device)


@spaces.GPU(duration=_get_generation_duration)
def run_generation(text, pairs, num_steps, start_seed, num_images, progress=gr.Progress()):
    """
    Generate calligraphy images inside a ZeroGPU worker.

    Lazily builds the generator on first use and moves it to CUDA. It then
    renders ``num_images`` images for every ``(author, font_key)`` pair, with
    image ``i`` of each pair using seed ``start_seed + i``. Weights run in
    bf16; no dynamic quantization is applied.

    Args:
        text: Characters to render.
        pairs: List of ``(author, font_key)`` tuples. An author equal to the
            synthetic placeholder generates without author conditioning.
        num_steps: Number of sampling steps.
        start_seed: Seed of the first image of each pair.
        num_images: Images per pair.
        progress: Gradio progress tracker.

    Returns:
        ``(results, seeds_used)``:
        * ``results`` is a list of ``(image, label)`` gallery items;
        * ``seeds_used`` holds the seed of each image, in the same order.
    """
    global generator

    progress(0.25, desc="准备 GPU 环境 / Preparing GPU runtime...")
    _enable_cuda_fast_paths()

    # Step 1: Load model
    progress(0.35, desc="检查模型状态 / Checking model state...")
    if generator is None:
        logger.info("Initializing generator lazily inside GPU worker...")
        progress(0.45, desc="首次初始化模型 / First-time model initialization...")
        generator = init_generator()
        progress(0.65, desc="模型初始化完成 / Model initialization complete")
    else:
        progress(0.55, desc="复用已初始化模型 / Reusing initialized model")

    logger.info("Using initialized generator in ZeroGPU worker.")
    gen = generator
    # ZeroGPU maps these modules to the acquired GPU during execution. The
    # generator's own device attributes must also be updated so latents it
    # creates at runtime land on the GPU.
    progress(0.72, desc="迁移模型到 GPU / Moving model to GPU...")
    _move_generator_to(gen, torch.device("cuda"))
    progress(0.82, desc="模型就绪,开始生成 / Model ready, starting generation...")

    # Step 2: weights were loaded in bf16, so no quantization wrapper is applied.
    logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")

    # Step 3: Generate images
    num_images = int(num_images)  # defensive: range() below needs an int
    total_gens = len(pairs) * num_images
    logger.info(f"Generating {total_gens} images across {len(pairs)} styles...")
    results = []
    seeds_used = []

    gen_idx = 0
    for author, font in pairs:
        # The synthetic placeholder means "no author conditioning".
        cond_author = author if author != "None (Synthetic / 合成风格)" else None
        for i in range(num_images):
            gen_idx += 1
            loop_progress = 0.82 + (gen_idx / max(total_gens, 1)) * 0.16
            progress(loop_progress, desc=f"生成第 {gen_idx}/{total_gens} 张 / Generating {gen_idx}/{total_gens}")
            current_seed = start_seed + i

            result_img, _ = gen.generate(
                text=text, font_style=font, author=cond_author,
                num_steps=num_steps, seed=current_seed,
                guidance=1.0,
            )

            # Label synthetic-style images "Synthetic" instead of the raw placeholder.
            author_label = cond_author if cond_author else "Synthetic"
            label = f"{author_label} - {font} (Seed: {current_seed})"
            results.append((result_img, label))
            seeds_used.append(current_seed)
            logger.info(f"  Generated image {gen_idx}/{total_gens}")

    progress(1.0, desc="生成完成 / Generation complete")
    return results, seeds_used


def interactive_session(
    text: str,
    a1, f1, a2, f2, a3, f3, a4, f4,
    num_steps: int,
    start_seed: int,
    num_images: int,
    progress=gr.Progress()
):
    """
    Validate UI inputs, run generation on the GPU and stream status updates.

    This is a generator. It yields ``(status_message, gallery_items)`` tuples
    so the status box updates while the job is queued and running, and it
    finishes with the generated images. ``num_steps`` is ignored and forced
    to 4 because the DMD-distilled checkpoint is sampled with 4 steps.

    Args:
        text: Characters to render (1-7 after trimming surrounding whitespace).
        a1..a4, f1..f4: Calligrapher and font-style display label per combination.
            Incomplete combinations are skipped.
        num_steps: Ignored (see above).
        start_seed: Seed of the first image of each combination.
        num_images: Images per combination.
        progress: Gradio progress tracker.

    Raises:
        gr.Error: in any of these cases:
            * the text is empty or longer than 7 characters;
            * the seed is missing;
            * no complete calligrapher/font combination is selected.
    """
    # Validate text; surrounding whitespace is not part of the calligraphy.
    text = (text or "").strip()
    if len(text) < 1:
        raise gr.Error("文本不能为空 / Text cannot be empty")
    if len(text) > 7:
        raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")

    # gr.Number delivers None when the field has been cleared.
    if start_seed is None:
        raise gr.Error("种子不能为空 / Seed cannot be empty")
    start_seed = int(start_seed)
    num_images = max(1, int(num_images))

    raw_pairs = [(a1, f1), (a2, f2), (a3, f3), (a4, f4)]
    pairs = []
    for author, font_label in raw_pairs:
        if not (author and font_label):
            continue
        font_key = parse_font_style(font_label)
        if font_key is not None:
            pairs.append((author, font_key))

    if not pairs:
        raise gr.Error("请至少选择一项书法家和字体组合 / Please select at least one combination")

    # Run generation (includes model loading + generation)
    yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
    progress(0.05, desc="校验输入参数 / Validating inputs...")
    yield "⏳ 输入已通过校验,等待 GPU 分配... / Inputs validated, waiting for GPU allocation...", []
    progress(0.15, desc="等待 GPU 资源 / Waiting for GPU allocation...")

    # Hardcode num_steps to 4 for DMD distillation
    num_steps = 4

    yield "⏳ 已分配 GPU,正在初始化与生成... / GPU allocated, initializing and generating...", []
    progress(0.22, desc="进入生成阶段 / Entering generation stage...")

    results, seeds_used = run_generation(
        text, pairs, num_steps, start_seed, num_images, progress
    )

    progress(1.0, desc="完成!")

    # Final status. Seeds repeat across combinations, so report the min-max span.
    total_imgs = len(results)
    if total_imgs > 1:
        lowest_seed, highest_seed = min(seeds_used), max(seeds_used)
        if lowest_seed != highest_seed:
            seed_span = f"{lowest_seed}-{highest_seed}"
        else:
            seed_span = f"{lowest_seed}"
        final_status = f"✅ 全部完成!共 {total_imgs} 张 (Seeds: {seed_span})"
    else:
        final_status = f"✅ 完成!Seed: {seeds_used[0]}"
    yield final_status, results


def _build_author_info_markdown(max_rows=30):
    """
    Return a Markdown table of the first ``max_rows`` calligraphers and their fonts.

    Font styles are joined with "、" rather than "|", because a bare pipe would
    split the cell and break the two-column table. A trailing note counts the
    authors that were left out.
    """
    lines = [
        "| 书法家 / Calligrapher | 可用字体 / Available Fonts |",
        "|--------|----------|",
    ]
    for author in AUTHOR_LIST[:max_rows]:
        fonts = "、".join(AUTHOR_FONTS[author])
        lines.append(f"| **{author}** | {fonts} |")
    markdown = "\n".join(lines) + "\n"

    remaining = len(AUTHOR_LIST) - max_rows
    if remaining > 0:
        markdown += f"\n*... 还有 {remaining} 位书法家 / {remaining} more calligraphers*"
    return markdown


# Create Gradio interface
with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生成器", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🖌️ UniCalli - 中国书法生成器 / Chinese Calligraphy Generator
    
    **A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**
    
    Generate beautiful Chinese calligraphy in various styles and by different historical masters.
    
    用不同历史书法大师的风格生成精美的中国书法。
    
    🌐 [Project Page](https://envision-research.github.io/UniCalli/) | 📄 [Paper](https://arxiv.org/abs/2510.13745) | 💻 [Code](https://github.com/Envision-Research/UniCalli)
    
    **注意 / Note**: 支持1-7个汉字输入 / Supports 1-7 Chinese characters.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📝 输入设置 / Input Settings")

            text_input = gr.Textbox(
                label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
                placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 相见时难别亦难",
                value="相见时难别亦难",
                max_lines=1
            )

            gr.Markdown("### 👤 书法家与字体组合 / Calligraphers & Fonts")

            combo_groups = []
            author_dropdowns = []
            font_dropdowns = []

            # Combination 1 starts pre-filled with this author.
            initial_author = "文征明"
            initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
            initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
            # Prefer running script (行) for the initial author when it is available.
            if "行" in initial_fonts:
                default_font = FONT_STYLE_NAMES["行"]
            else:
                default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"

            # Number of combination groups currently visible.
            active_count_state = gr.State(value=1)

            # Four combination slots; only the first is visible initially.
            for i in range(4):
                with gr.Group(visible=(i == 0)) as group:
                    gr.Markdown(f"**组合 {i+1} / Combination {i+1}**")
                    with gr.Row():
                        a_drop = gr.Dropdown(
                            label="书法家 / Calligrapher",
                            choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
                            value=initial_author if i == 0 else None,
                        )
                        f_drop = gr.Dropdown(
                            label="字体风格 / Font Style",
                            choices=initial_font_choices if i == 0 else list(FONT_STYLE_NAMES.values()),
                            value=default_font if i == 0 else None,
                        )
                        author_dropdowns.append(a_drop)
                        font_dropdowns.append(f_drop)
                combo_groups.append(group)

            add_combo_btn = gr.Button("➕ 添加书法家与风格组合 / Add Combination", size="sm")

            gr.Markdown("### ⚙️ 生成设置 / Generation Settings")

            num_steps = gr.Slider(
                label="生成步数 / Inference Steps (Fixed to 4 for DMD)",
                minimum=4,
                maximum=4,
                value=4,
                step=1,
                interactive=False,
                info="使用 DMD 蒸馏模型,强制 4 步生成 / Uses DMD distilled model, forced to 4 steps"
            )

            start_seed = gr.Number(
                label="起始种子 / Start Seed",
                value=42,
                precision=0
            )

            num_images = gr.Slider(
                label="生成数量 / Number of Images",
                minimum=1,
                maximum=8,
                value=1,
                step=1
            )

            generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🖼️ 生成结果 / Generated Results")
            gr.Markdown("""
⚠️ **首次生成说明 / First Run Note:**
- 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
- 如果遇到错误,请**再点一次生成按钮**即可正常运行
- First generation triggers PyTorch compilation (~1-2 min)
- If you see an error, just **click generate again** and it will work

*点击图片可放大查看 / Click image to enlarge*
""")

            output_gallery = gr.Gallery(
                label="生成结果 / Generated Results",
                show_label=False,
                columns=2,
                rows=2,
                height=550,
                object_fit="contain",
                allow_preview=True
            )

            status_text = gr.Textbox(
                label="状态 / Status",
                value="准备就绪 / Ready",
                interactive=False
            )

    # Author info section
    author_count = len(AUTHOR_LIST)
    with gr.Accordion(f"📚 可用书法家列表 / Available Calligraphers(共 {author_count} 位 / {author_count} total)", open=False):
        gr.Markdown(_build_author_info_markdown())

    # Event handlers: changing a calligrapher refreshes that slot's font choices.
    for author_dropdown, font_dropdown in zip(author_dropdowns, font_dropdowns):
        author_dropdown.change(
            fn=update_font_choices,
            inputs=[author_dropdown],
            outputs=[font_dropdown]
        )

    def add_combo(current_count):
        """Reveal one more combination group; disable the button once all 4 are shown."""
        new_count = min(current_count + 1, 4)
        group_updates = [gr.update(visible=(i < new_count)) for i in range(4)]
        button_update = gr.update(interactive=(new_count < 4))
        return [new_count] + group_updates + [button_update]

    add_combo_btn.click(
        fn=add_combo,
        inputs=[active_count_state],
        outputs=[active_count_state] + combo_groups + [add_combo_btn]
    )

    # Inputs for interactive_session, in its parameter order
    session_inputs = [text_input]
    for author_dropdown, font_dropdown in zip(author_dropdowns, font_dropdowns):
        session_inputs.extend([author_dropdown, font_dropdown])
    session_inputs.extend([num_steps, start_seed, num_images])

    # Generate button; interactive_session yields, so status updates stream live
    generate_btn.click(
        fn=interactive_session,
        inputs=session_inputs,
        outputs=[status_text, output_gallery]
    )

    # Examples
    gr.Markdown("### 📋 示例 / Examples")
    gr.Examples(
        examples=[
            # [text, author1, font1, author2, font2, author3, font3, author4, font4, num_steps, start_seed, num_images]
            ["相见时难别亦难", "文征明", "行 (Running Script)", None, None, None, None, None, None, 4, 1024, 1],
            ["天道酬勤", "王羲之", "草 (Cursive Script)", "黄庭坚", "行 (Running Script)", None, None, None, None, 4, 42, 1],
            ["厚德载物", "赵孟頫", "楷 (Regular Script)", None, None, None, None, None, None, 4, 123, 1],
        ],
        inputs=session_inputs,
    )


demo.launch()