# -*- coding: utf-8 -*-
"""
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
With FA3 + FP8 quantization for faster inference
"""
import os
import sys

print("=" * 50, flush=True)
print("APP.PY STARTING - DEBUG VERSION", flush=True)
print("=" * 50, flush=True)

import logging
from datetime import datetime
from typing import Optional

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
print("Logging setup complete", flush=True)

# IMPORTANT: import spaces first before any CUDA-related packages
import spaces
import gradio as gr
import json
import csv
import time
import torch


# Load author and font mappings from CSV
def load_author_fonts_from_csv(csv_path):
    """
    Load author and their available fonts from CSV file.

    Filters out authors that only support 隶 or 篆 fonts.

    Args:
        csv_path: path to a UTF-8 CSV with columns '书法家' (calligrapher)
            and '字体类型' (font types, '|'-separated).

    Returns:
        dict mapping author name to a list of supported font style keys.
    """
    author_fonts = {}
    excluded_fonts = {'隶', '篆'}  # Fonts we don't support
    with open(csv_path, 'r', encoding='utf-8') as csv_file:
        reader = csv.DictReader(csv_file)
        for row in reader:
            author = row['书法家']
            fonts = row['字体类型'].split('|')  # Split multiple fonts by |
            # Filter out unsupported fonts (隶 and 篆)
            supported_fonts = [font for font in fonts if font not in excluded_fonts]
            # Only include author if they have at least one supported font
            if supported_fonts:
                author_fonts[author] = supported_fonts
    return author_fonts


# Load author-font mappings
AUTHOR_FONTS = load_author_fonts_from_csv('dataset/author_fonts_summary.csv')

# Available authors (sorted)
AUTHOR_LIST = sorted(AUTHOR_FONTS.keys())

# Font style display names (only supported styles)
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)"
}

# Load author descriptions if available; the file is optional, so only
# missing-file / bad-JSON errors are tolerated (no bare except).
try:
    with open('dataset/calligraphy_styles_en.json', 'r', encoding='utf-8') as f:
        author_styles = json.load(f)
except (OSError, json.JSONDecodeError):
    author_styles = {}

# Global generator instance (lazily created inside the GPU worker)
generator = None
_cached_model_dir = None

# ============================================================
# Pre-download and pre-load model files at startup (no GPU needed)
# ============================================================
_preloaded_embedding = None
_preloaded_tokenizer = None
_cached_t5_dir = None
_cached_clip_dir = None
_cached_font_path = None


def preload_model_files():
    """Pre-download model files to cache at startup (no GPU needed).

    Each download is best-effort: a failure only prints a warning so the
    Space still starts and the files can be fetched lazily later.

    Returns:
        Local directory of the main model snapshot, or None if the
        download failed.
    """
    global _preloaded_embedding, _preloaded_tokenizer, _cached_t5_dir, _cached_clip_dir, _cached_font_path
    from huggingface_hub import snapshot_download, hf_hub_download

    hf_token = os.environ.get("HF_TOKEN", None)
    print("Pre-downloading model files to cache...")

    # 1. Main model (Unicalli_Pro) - include all tokenizer files
    try:
        local_dir = snapshot_download(
            repo_id="TSXu/UniCalli_pro_dmd2_K4",
            token=hf_token
        )
        print(f"✓ Unicalli_Pro cached at: {local_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
        local_dir = None

    # 2. T5 text encoder
    try:
        _cached_t5_dir = snapshot_download(
            "xlabs-ai/xflux_text_encoders",
            token=hf_token
        )
        os.environ["XFLUX_TEXT_ENCODER_PATH"] = _cached_t5_dir
        print(f"✓ T5 text encoder cached at: {_cached_t5_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download T5: {e}")

    # 3. CLIP text encoder
    try:
        _cached_clip_dir = snapshot_download(
            "openai/clip-vit-large-patch14",
            token=hf_token
        )
        os.environ["XFLUX_CLIP_ENCODER_PATH"] = _cached_clip_dir
        print(f"✓ CLIP text encoder cached at: {_cached_clip_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download CLIP: {e}")

    # 4. VAE (ae.safetensors from FLUX.1-dev)
    try:
        hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors", token=hf_token)
        print("✓ VAE cached")
    except Exception as e:
        print(f"Warning: Could not pre-download VAE: {e}")

    # 5. Font file used for condition image rendering
    try:
        _cached_font_path = hf_hub_download(
            repo_id="TSXu/Unicalli_Pro",
            filename="FangZhengKaiTiFanTi-1.ttf",
            token=hf_token,
        )
        os.environ["UNICALLI_FONT_PATH"] = _cached_font_path
        print(f"✓ Font cached at: {_cached_font_path}")
    except Exception as e:
        print(f"Warning: Could not pre-download font: {e}")

    # 6. Flash Attention 3 kernel package (large) pre-cache
    try:
        from kernels import get_kernel
        get_kernel("kernels-community/vllm-flash-attn3")
        print("✓ Flash Attention 3 kernel cached")
    except Exception as e:
        print(f"Warning: Could not pre-cache Flash Attention 3 kernel: {e}")

    # 7. Pre-load InternVL embedding to CPU memory (saves ~5s in GPU session)
    if local_dir:
        try:
            intern_vlm_path = os.path.join(local_dir, "internvl_embedding")
            _preloaded_embedding, _preloaded_tokenizer = _preload_embedding(intern_vlm_path)
            print("✓ InternVL embedding pre-loaded to CPU")
        except Exception as e:
            print(f"Warning: Could not pre-load embedding: {e}")

    return local_dir


def _preload_embedding(intern_vlm_path):
    """Pre-load embedding and tokenizer to CPU memory.

    Args:
        intern_vlm_path: directory containing ``embedding.safetensors``
            and ``embedding_config.json``.

    Returns:
        Tuple ``(embed_tokens, tokenizer)``; ``(None, None)`` when the
        expected files are missing.
    """
    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoTokenizer

    embedding_file = os.path.join(intern_vlm_path, "embedding.safetensors")
    config_file = os.path.join(intern_vlm_path, "embedding_config.json")

    if not os.path.exists(embedding_file) or not os.path.exists(config_file):
        print(f"  Embedding files not found at {intern_vlm_path}")
        return None, None

    print(f"  Loading embedding from: {intern_vlm_path}")
    with open(config_file, 'r') as f:
        config = json.load(f)

    # Create embedding layer on CPU
    embed_tokens = torch.nn.Embedding(
        num_embeddings=config["num_embeddings"],
        embedding_dim=config["embedding_dim"],
        padding_idx=config.get("padding_idx", None)
    )

    # Load weights (inference only: eval mode, gradients disabled)
    state_dict = load_safetensors(embedding_file)
    embed_tokens.load_state_dict(state_dict)
    embed_tokens.eval()
    embed_tokens.requires_grad_(False)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        intern_vlm_path,
        trust_remote_code=True,
        use_fast=False
    )
    print(f"  Loaded: {config['num_embeddings']} x {config['embedding_dim']}")
    return embed_tokens, tokenizer


print("=" * 50)
print("Starting model pre-download and pre-load...")
_cached_model_dir = preload_model_files()
print("=" * 50)


def init_generator():
    """Initialize the generator (without optimization - that's done separately).

    Constructs the CalligraphyGenerator from the pre-cached model snapshot
    and the pre-loaded embedding/tokenizer. Idempotent: reuses the module
    global ``generator`` when it already exists.

    Returns:
        The (possibly newly created) global CalligraphyGenerator.
    """
    global generator, _cached_model_dir, _preloaded_embedding, _preloaded_tokenizer
    if generator is None:
        intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
        checkpoint_path = _cached_model_dir
        print(f"Using pre-cached model from: {_cached_model_dir}")
        print(f"Using pre-loaded embedding: {_preloaded_embedding is not None}")
        from inference import CalligraphyGenerator
        generator = CalligraphyGenerator(
            model_name="flux-dev",
            device="cpu",  # Changed to cpu to prevent CUDA init before ZeroGPU fork!
            offload=False,  # Set to False to let ZeroGPU manage CUDA memory directly instead of manual CPU thrashing
            intern_vlm_path=intern_vlm_path,
            checkpoint_path=checkpoint_path,
            font_descriptions_path='dataset/chirography.json',
            author_descriptions_path='dataset/calligraphy_styles_en.json',
            use_deepspeed=False,
            use_4bit_quantization=False,
            use_float8_quantization=False,
            use_torch_compile=False,
            dtype="bf16",  # Revert to bf16 to let FA3 run and avoid Float8 unimplemented errors
            preloaded_embedding=_preloaded_embedding,
            preloaded_tokenizer=_preloaded_tokenizer,
        )
    return generator


def update_font_choices(author: str):
    """
    Update available font choices based on selected author.

    Returns a Dropdown update whose choices are the author's supported
    styles (all styles for the synthetic option or unknown authors).
    """
    if author == "None (Synthetic / 合成风格)" or author not in AUTHOR_FONTS:
        choices = list(FONT_STYLE_NAMES.values())
    else:
        available_fonts = AUTHOR_FONTS[author]
        choices = [FONT_STYLE_NAMES[font] for font in available_fonts if font in FONT_STYLE_NAMES]
    return gr.Dropdown(choices=choices, value=choices[0] if choices else None)


def parse_font_style(font_style: str) -> Optional[str]:
    """Extract font key from display name.

    Returns the single-character font key (e.g. '行') for a display label
    such as '行 (Running Script)', or None if the label is unknown.
    (Return annotation fixed: the original claimed ``str`` but the
    fall-through path returns None.)
    """
    for font_key, font_display in FONT_STYLE_NAMES.items():
        if font_display == font_style:
            return font_key
    return None


# IMPORTANT:
# Do NOT initialize generator globally at import time in ZeroGPU Spaces.
# Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects.

def _get_generation_duration(text, pairs, num_steps, start_seed, num_images, progress=None):
    """Calculate dynamic GPU duration: 30s base + 3s per image per style pair."""
    num_pairs = len(pairs) if pairs else 1
    return 30 + int(3 * num_images * num_pairs)


@spaces.GPU(duration=_get_generation_duration)
def run_generation(text, pairs, num_steps, start_seed, num_images, progress=gr.Progress()):
    """
    Load model, apply FP8 quantization, and generate images.
    All in one GPU session to avoid redundant loading.

    Args:
        text: the characters to render (validated by the caller).
        pairs: list of (author, font_key) tuples; the author may be the
            literal dropdown string "None (Synthetic / 合成风格)".
        num_steps: diffusion steps (caller forces 4 for DMD).
        start_seed: first seed; image i of each pair uses start_seed + i.
        num_images: images generated per (author, font) pair.
        progress: Gradio progress callback.

    Returns:
        (results, seeds_used) where results is a list of
        (PIL image, label) tuples suitable for gr.Gallery.
    """
    progress(0.25, desc="准备 GPU 环境 / Preparing GPU runtime...")

    # Enable CUDA optimizations inside the worker (best-effort: older
    # torch builds may lack some of these toggles).
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception:
        pass

    # Step 1: Load model (lazily, inside the GPU worker)
    progress(0.35, desc="检查模型状态 / Checking model state...")
    global generator
    if generator is None:
        logger.info("Initializing generator lazily inside GPU worker...")
        progress(0.45, desc="首次初始化模型 / First-time model initialization...")
        generator = init_generator()
        progress(0.65, desc="模型初始化完成 / Model initialization complete")
    else:
        progress(0.55, desc="复用已初始化模型 / Reusing initialized model")
        logger.info("Using initialized generator in ZeroGPU worker.")
    gen = generator

    # ZeroGPU automatically maps these to the acquired GPU during execution.
    # We must also correctly update internal Python attributes so
    # runtime-generated latents go to GPU.
    target_device = torch.device("cuda")
    progress(0.72, desc="迁移模型到 GPU / Moving model to GPU...")
    gen.device = target_device
    if hasattr(gen, "sampler") and gen.sampler is not None:
        gen.sampler.device = target_device
    gen.model.to(target_device)
    gen.clip.to(target_device)
    gen.t5.to(target_device)
    gen.vae.to(target_device)
    progress(0.82, desc="模型就绪,开始生成 / Model ready, starting generation...")

    # Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
    logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")

    # Step 3: Generate images
    total_gens = len(pairs) * num_images
    logger.info(f"Generating {total_gens} images across {len(pairs)} styles...")
    results = []
    seeds_used = []
    gen_idx = 0
    for author, font in pairs:
        for i in range(num_images):
            gen_idx += 1
            loop_progress = 0.82 + (gen_idx / max(total_gens, 1)) * 0.16
            progress(loop_progress, desc=f"生成第 {gen_idx}/{total_gens} 张 / Generating {gen_idx}/{total_gens}")
            current_seed = start_seed + i
            # The synthetic dropdown choice maps to author=None for the model.
            cond_author = author if author != "None (Synthetic / 合成风格)" else None
            result_img, cond_img = gen.generate(
                text=text,
                font_style=font,
                author=cond_author,
                num_steps=num_steps,
                seed=current_seed,
                guidance=1.0,
            )
            # BUG FIX: the original checked `if author`, which is always
            # truthy (the synthetic choice is a non-empty string), so the
            # gallery label never said "Synthetic". Key off cond_author.
            author_label = cond_author if cond_author else "Synthetic"
            label = f"{author_label} - {font} (Seed: {current_seed})"
            results.append((result_img, label))
            seeds_used.append(current_seed)
            logger.info(f" Generated image {gen_idx}/{total_gens}")

    progress(1.0, desc="生成完成 / Generation complete")
    return results, seeds_used


def interactive_session(
    text: str,
    a1, f1, a2, f2, a3, f3, a4, f4,
    num_steps: int,
    start_seed: int,
    num_images: int,
    progress=gr.Progress()
):
    """
    Interactive session with FA3 + FP8.

    Generator function: yields (status_message, gallery_items) tuples so
    the UI shows live progress, then the final results.

    Raises:
        gr.Error: on empty/too-long text or when no valid
            (calligrapher, font) combination is selected.
    """
    # Validate text
    if len(text) < 1:
        raise gr.Error("文本不能为空 / Text cannot be empty")
    if len(text) > 7:
        raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")

    # Collect only fully specified combinations with a known font label.
    raw_pairs = [(a1, f1), (a2, f2), (a3, f3), (a4, f4)]
    pairs = []
    for a, f_style in raw_pairs:
        if a and f_style:
            parsed_font = parse_font_style(f_style)
            if parsed_font is not None:
                pairs.append((a, parsed_font))
    if not pairs:
        raise gr.Error("请至少选择一项书法家和字体组合 / Please select at least one combination")

    # Run generation (includes model loading + FP8 quantization + generation)
    yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
    progress(0.05, desc="校验输入参数 / Validating inputs...")
    yield "⏳ 输入已通过校验,等待 GPU 分配... / Inputs validated, waiting for GPU allocation...", []
    progress(0.15, desc="等待 GPU 资源 / Waiting for GPU allocation...")

    # Hardcode num_steps to 4 for DMD distillation
    num_steps = 4
    yield "⏳ 已分配 GPU,正在初始化与生成... / GPU allocated, initializing and generating...", []
    progress(0.22, desc="进入生成阶段 / Entering generation stage...")
    results, seeds_used = run_generation(
        text, pairs, num_steps, start_seed, num_images, progress
    )
    progress(1.0, desc="完成!")

    # Final status
    total_imgs = len(results)
    if total_imgs > 1:
        final_status = f"✅ 全部完成!共 {total_imgs} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
    else:
        final_status = f"✅ 完成!Seed: {seeds_used[0]}"
    yield final_status, results


# Create Gradio interface
with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生成器", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🖌️ UniCalli - 中国书法生成器 / Chinese Calligraphy Generator

    **A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**

    Generate beautiful Chinese calligraphy in various styles and by different historical masters.
    用不同历史书法大师的风格生成精美的中国书法。

    🌐 [Project Page](https://envision-research.github.io/UniCalli/) | 📄 [Paper](https://arxiv.org/abs/2510.13745) | 💻 [Code](https://github.com/Envision-Research/UniCalli)

    **注意 / Note**: 支持1-7个汉字输入 / Supports 1-7 Chinese characters.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📝 输入设置 / Input Settings")
            text_input = gr.Textbox(
                label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
                placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 相见时难别亦难",
                value="相见时难别亦难",
                max_lines=1
            )

            gr.Markdown("### 👤 书法家与字体组合 / Calligraphers & Fonts")
            combo_groups = []
            author_dropdowns = []
            font_dropdowns = []

            # Default selection: 文征明 in running script when available.
            initial_author = "文征明"
            initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
            initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
            if initial_author == "文征明" and "行" in initial_fonts:
                default_font = FONT_STYLE_NAMES["行"]
            else:
                default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"

            # Number of currently visible combination groups (1-4).
            active_count_state = gr.State(value=1)
            for i in range(4):
                with gr.Group(visible=(i == 0)) as group:
                    gr.Markdown(f"**组合 {i+1} / Combination {i+1}**")
                    with gr.Row():
                        a_drop = gr.Dropdown(
                            label=f"书法家 / Calligrapher",
                            choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
                            value="文征明" if i == 0 else None,
                        )
                        f_drop = gr.Dropdown(
                            label=f"字体风格 / Font Style",
                            choices=initial_font_choices if i == 0 else list(FONT_STYLE_NAMES.values()),
                            value=default_font if i == 0 else None,
                        )
                    author_dropdowns.append(a_drop)
                    font_dropdowns.append(f_drop)
                combo_groups.append(group)

            add_combo_btn = gr.Button("➕ 添加书法家与风格组合 / Add Combination", size="sm")

            gr.Markdown("### ⚙️ 生成设置 / Generation Settings")
            num_steps = gr.Slider(
                label="生成步数 / Inference Steps (Fixed to 4 for DMD)",
                minimum=4,
                maximum=4,
                value=4,
                step=1,
                interactive=False,
                info="使用 DMD 蒸馏模型,强制 4 步生成 / Uses DMD distilled model, forced to 4 steps"
            )
            start_seed = gr.Number(
                label="起始种子 / Start Seed",
                value=42,
                precision=0
            )
            num_images = gr.Slider(
                label="生成数量 / Number of Images",
                minimum=1,
                maximum=8,
                value=1,
                step=1
            )

            generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🖼️ 生成结果 / Generated Results")
            gr.Markdown("""
            ⚠️ **首次生成说明 / First Run Note:**
            - 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
            - 如果遇到错误,请**再点一次生成按钮**即可正常运行
            - First generation triggers PyTorch compilation (~1-2 min)
            - If you see an error, just **click generate again** and it will work

            *点击图片可放大查看 / Click image to enlarge*
            """)
            output_gallery = gr.Gallery(
                label="生成结果 / Generated Results",
                show_label=False,
                columns=2,
                rows=2,
                height=550,
                object_fit="contain",
                allow_preview=True
            )
            status_text = gr.Textbox(
                label="状态 / Status",
                value="准备就绪 / Ready",
                interactive=False
            )

    # Author info section
    with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
        author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
        for author in AUTHOR_LIST[:30]:
            # BUG FIX: joining with " | " injected unescaped pipes into the
            # markdown table cells and broke the column layout; use an
            # enumeration comma instead. (Unused desc/desc_short removed.)
            fonts = "、".join(AUTHOR_FONTS[author])
            author_info_md += f"| **{author}** | {fonts} |\n"
        if len(AUTHOR_LIST) > 30:
            author_info_md += f"\n*... 还有 {len(AUTHOR_LIST) - 30} 位书法家 / {len(AUTHOR_LIST) - 30} more calligraphers*"
        gr.Markdown(author_info_md)

    # Event handlers: keep each font dropdown in sync with its author.
    for i in range(4):
        author_dropdowns[i].change(
            fn=update_font_choices,
            inputs=[author_dropdowns[i]],
            outputs=[font_dropdowns[i]]
        )

    def add_combo(current_count):
        """Reveal the next combination group (max 4) and disable the add
        button once all four are visible."""
        new_count = min(current_count + 1, 4)
        updates = []
        for i in range(4):
            updates.append(gr.update(visible=(i < new_count)))
        updates.append(gr.update(interactive=(new_count < 4)))
        return [new_count] + updates

    add_combo_btn.click(
        fn=add_combo,
        inputs=[active_count_state],
        outputs=[active_count_state] + combo_groups + [add_combo_btn]
    )

    # Prepare inputs list for the interactive session
    session_inputs = [text_input]
    for i in range(4):
        session_inputs.extend([author_dropdowns[i], font_dropdowns[i]])
    session_inputs.extend([num_steps, start_seed, num_images])

    # Generate button - uses streaming for live updates
    generate_btn.click(
        fn=interactive_session,
        inputs=session_inputs,
        outputs=[status_text, output_gallery]
    )

    # Examples
    gr.Markdown("### 📋 示例 / Examples")
    gr.Examples(
        examples=[
            # [text, author1, font1, author2, font2, author3, font3, author4, font4, num_steps, start_seed, num_images]
            ["相见时难别亦难", "文征明", "行 (Running Script)", None, None, None, None, None, None, 4, 1024, 1],
            ["天道酬勤", "王羲之", "草 (Cursive Script)", "黄庭坚", "行 (Running Script)", None, None, None, None, 4, 42, 1],
            ["厚德载物", "赵孟頫", "楷 (Regular Script)", None, None, None, None, None, None, 4, 123, 1],
        ],
        inputs=session_inputs,
    )

demo.launch()