Spaces:
Running on Zero
Running on Zero
# -*- coding: utf-8 -*-
"""
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
With FA3 + FP8 quantization for faster inference
"""
import os
import sys

# Early startup banner so the Space build logs show the script actually began
# executing (flush=True so the lines survive an early crash).
print("=" * 50, flush=True)
print("APP.PY STARTING - DEBUG VERSION", flush=True)
print("=" * 50, flush=True)

import logging
from datetime import datetime

# Setup logging: timestamped INFO records to stdout so they show in Space logs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
print("Logging setup complete", flush=True)

# IMPORTANT: import spaces first before any CUDA-related packages
# (presumably so ZeroGPU can patch CUDA initialization before torch sees it —
# TODO confirm against the HF Spaces ZeroGPU docs).
import spaces
import gradio as gr
import json
import csv
import time
import torch
| # Load author and font mappings from CSV | |
def load_author_fonts_from_csv(csv_path):
    """
    Load calligraphers and their available font styles from a CSV file.

    Rows whose only styles are the unsupported 隶 (clerical) or 篆 (seal)
    scripts are dropped entirely.

    Args:
        csv_path: Path to a UTF-8 CSV with columns 书法家 and 字体类型
                  (the latter holds '|'-separated style characters).

    Returns:
        dict mapping calligrapher name -> list of supported style characters.
    """
    unsupported = {'隶', '篆'}  # Styles the model does not support
    mapping = {}
    with open(csv_path, 'r', encoding='utf-8') as csv_file:
        for record in csv.DictReader(csv_file):
            styles = [s for s in record['字体类型'].split('|') if s not in unsupported]
            # Skip calligraphers left with no supported style at all
            if styles:
                mapping[record['书法家']] = styles
    return mapping
# Load author-font mappings
AUTHOR_FONTS = load_author_fonts_from_csv('dataset/author_fonts_summary.csv')

# Available authors (sorted)
AUTHOR_LIST = sorted(AUTHOR_FONTS.keys())

# Font style display names (only supported styles)
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)"
}

# Load author descriptions if available. This is optional decoration for the
# UI, so a missing, unreadable, or malformed file degrades to an empty dict —
# but we no longer use a bare `except:`, which would also swallow
# KeyboardInterrupt/SystemExit. JSONDecodeError is a ValueError subclass.
try:
    with open('dataset/calligraphy_styles_en.json', 'r', encoding='utf-8') as f:
        author_styles = json.load(f)
except (OSError, ValueError):
    author_styles = {}

# Global generator instance (created lazily inside the GPU worker)
generator = None
_cached_model_dir = None

# ============================================================
# Pre-download and pre-load model files at startup (no GPU needed)
# ============================================================
# CPU-side artifacts cached at import time so the GPU session starts faster.
_preloaded_embedding = None
_preloaded_tokenizer = None
_cached_t5_dir = None
_cached_clip_dir = None
_cached_font_path = None
def preload_model_files():
    """Pre-download model files to cache at startup (no GPU needed).

    Runs at import time, before any ZeroGPU session exists, so the GPU worker
    later finds every artifact already in the local Hugging Face cache. Each
    step is best-effort: a failure prints a warning and the app continues.

    Returns:
        Local snapshot directory of the main Unicalli_Pro model repo, or
        None when that download failed.
    """
    global _preloaded_embedding, _preloaded_tokenizer, _cached_t5_dir, _cached_clip_dir, _cached_font_path
    from huggingface_hub import snapshot_download, hf_hub_download

    # Token is optional; anonymous access works for public repos.
    hf_token = os.environ.get("HF_TOKEN", None)
    print("Pre-downloading model files to cache...")

    # 1. Main model (Unicalli_Pro) - include all tokenizer files
    try:
        local_dir = snapshot_download(
            repo_id="TSXu/UniCalli_pro_dmd2_K4",
            token=hf_token
        )
        print(f"✓ Unicalli_Pro cached at: {local_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
        local_dir = None

    # 2. T5 text encoder — env var tells downstream loading code where it is
    try:
        _cached_t5_dir = snapshot_download(
            "xlabs-ai/xflux_text_encoders",
            token=hf_token
        )
        os.environ["XFLUX_TEXT_ENCODER_PATH"] = _cached_t5_dir
        print(f"✓ T5 text encoder cached at: {_cached_t5_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download T5: {e}")

    # 3. CLIP text encoder — same env-var handoff pattern as T5
    try:
        _cached_clip_dir = snapshot_download(
            "openai/clip-vit-large-patch14",
            token=hf_token
        )
        os.environ["XFLUX_CLIP_ENCODER_PATH"] = _cached_clip_dir
        print(f"✓ CLIP text encoder cached at: {_cached_clip_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download CLIP: {e}")

    # 4. VAE (ae.safetensors from FLUX.1-dev) — cached only, path resolved later
    try:
        hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors", token=hf_token)
        print("✓ VAE cached")
    except Exception as e:
        print(f"Warning: Could not pre-download VAE: {e}")

    # 5. Font file used for condition image rendering
    try:
        _cached_font_path = hf_hub_download(
            repo_id="TSXu/Unicalli_Pro",
            filename="FangZhengKaiTiFanTi-1.ttf",
            token=hf_token,
        )
        os.environ["UNICALLI_FONT_PATH"] = _cached_font_path
        print(f"✓ Font cached at: {_cached_font_path}")
    except Exception as e:
        print(f"Warning: Could not pre-download font: {e}")

    # 6. Flash Attention 3 kernel package (large) pre-cache
    try:
        from kernels import get_kernel
        get_kernel("kernels-community/vllm-flash-attn3")
        print("✓ Flash Attention 3 kernel cached")
    except Exception as e:
        print(f"Warning: Could not pre-cache Flash Attention 3 kernel: {e}")

    # 7. Pre-load InternVL embedding to CPU memory (saves ~5s in GPU session)
    if local_dir:
        try:
            intern_vlm_path = os.path.join(local_dir, "internvl_embedding")
            _preloaded_embedding, _preloaded_tokenizer = _preload_embedding(intern_vlm_path)
            print("✓ InternVL embedding pre-loaded to CPU")
        except Exception as e:
            print(f"Warning: Could not pre-load embedding: {e}")

    return local_dir
def _preload_embedding(intern_vlm_path):
    """
    Load the InternVL token-embedding table and its tokenizer into CPU memory.

    Args:
        intern_vlm_path: Directory expected to contain embedding.safetensors
                         and embedding_config.json plus tokenizer files.

    Returns:
        (embedding, tokenizer) on success, or (None, None) when the required
        files are missing.
    """
    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoTokenizer

    weights_path = os.path.join(intern_vlm_path, "embedding.safetensors")
    cfg_path = os.path.join(intern_vlm_path, "embedding_config.json")
    if not (os.path.exists(weights_path) and os.path.exists(cfg_path)):
        print(f" Embedding files not found at {intern_vlm_path}")
        return None, None

    print(f" Loading embedding from: {intern_vlm_path}")
    with open(cfg_path, 'r') as cfg_file:
        cfg = json.load(cfg_file)

    # Build the embedding layer on CPU, then fill it with the saved weights.
    table = torch.nn.Embedding(
        num_embeddings=cfg["num_embeddings"],
        embedding_dim=cfg["embedding_dim"],
        padding_idx=cfg.get("padding_idx", None),
    )
    table.load_state_dict(load_safetensors(weights_path))
    # Inference-only: no grads, eval mode.
    table.eval()
    table.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained(
        intern_vlm_path, trust_remote_code=True, use_fast=False
    )
    print(f" Loaded: {cfg['num_embeddings']} x {cfg['embedding_dim']}")
    return table, tokenizer
# Run the pre-download/pre-load pass once at import time. This is CPU-only,
# so it is safe to execute before the ZeroGPU worker fork.
print("="*50)
print("Starting model pre-download and pre-load...")
_cached_model_dir = preload_model_files()
print("="*50)
def init_generator():
    """Initialize the global CalligraphyGenerator singleton on CPU.

    Everything is deliberately loaded on CPU with no quantization or
    compilation: in ZeroGPU Spaces, CUDA must not be touched before the
    @spaces.GPU worker runs; the model is moved to the GPU later in
    run_generation.

    Returns:
        The (possibly newly created) CalligraphyGenerator instance.

    Raises:
        RuntimeError: if the model snapshot was not pre-downloaded at startup.
    """
    global generator, _cached_model_dir, _preloaded_embedding, _preloaded_tokenizer
    if generator is None:
        # Fail fast with a clear message instead of the opaque TypeError that
        # os.path.join(None, ...) would raise when pre-download failed.
        if _cached_model_dir is None:
            raise RuntimeError(
                "Model snapshot was not pre-downloaded at startup; "
                "cannot initialize the generator."
            )
        intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
        checkpoint_path = _cached_model_dir
        print(f"Using pre-cached model from: {_cached_model_dir}")
        print(f"Using pre-loaded embedding: {_preloaded_embedding is not None}")
        from inference import CalligraphyGenerator
        generator = CalligraphyGenerator(
            model_name="flux-dev",
            device="cpu",  # Changed to cpu to prevent CUDA init before ZeroGPU fork!
            offload=False,  # Set to False to let ZeroGPU manage CUDA memory directly instead of manual CPU thrashing
            intern_vlm_path=intern_vlm_path,
            checkpoint_path=checkpoint_path,
            font_descriptions_path='dataset/chirography.json',
            author_descriptions_path='dataset/calligraphy_styles_en.json',
            use_deepspeed=False,
            use_4bit_quantization=False,
            use_float8_quantization=False,
            use_torch_compile=False,
            dtype="bf16",  # Revert to bf16 to let FA3 run and avoid Float8 unimplemented errors
            preloaded_embedding=_preloaded_embedding,
            preloaded_tokenizer=_preloaded_tokenizer,
        )
    return generator
def update_font_choices(author: str):
    """
    Refresh the font dropdown so it only offers styles the selected
    calligrapher supports; the synthetic/unknown choice unlocks every style.
    """
    if author in AUTHOR_FONTS and author != "None (Synthetic / 合成风格)":
        choices = [FONT_STYLE_NAMES[s] for s in AUTHOR_FONTS[author] if s in FONT_STYLE_NAMES]
    else:
        choices = list(FONT_STYLE_NAMES.values())
    default = choices[0] if choices else None
    return gr.Dropdown(choices=choices, value=default)
def parse_font_style(font_style: str) -> str:
    """Map a display label back to its single-character font key.

    Returns None when the label matches no known style.
    """
    matches = (key for key, display in FONT_STYLE_NAMES.items() if display == font_style)
    return next(matches, None)
| # IMPORTANT: | |
| # Do NOT initialize generator globally at import time in ZeroGPU Spaces. | |
| # Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects. | |
| def _get_generation_duration(text, pairs, num_steps, start_seed, num_images, progress=None): | |
| """Calculate dynamic GPU duration: 24s base + 3s per image""" | |
| num_pairs = len(pairs) if pairs else 1 | |
| return 30 + int(3 * num_images * num_pairs) | |
# NOTE(review): this looks like it is meant to run inside a @spaces.GPU
# worker (possibly with duration=_get_generation_duration); no decorator is
# visible here — confirm it wasn't lost.
def run_generation(text, pairs, num_steps, start_seed, num_images, progress=gr.Progress()):
    """
    Load model, apply FP8 quantization, and generate images.
    All in one GPU session to avoid redundant loading.

    Args:
        text: Characters to render (validated by the caller).
        pairs: List of (author, font_key) tuples to generate for.
        num_steps: Diffusion steps per image.
        start_seed: Base seed; image i of each pair uses start_seed + i.
        num_images: Images to generate per (author, font) pair.
        progress: Gradio progress reporter.

    Returns:
        (results, seeds_used): results is a list of (image, label) tuples for
        the gallery; seeds_used lists the seed of each generated image.
    """
    progress(0.25, desc="准备 GPU 环境 / Preparing GPU runtime...")
    # Enable CUDA optimizations inside the worker; best-effort since some
    # toggles may not exist on every torch build.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception:
        pass
    # Step 1: Load model (lazily, first call only; later calls reuse it)
    progress(0.35, desc="检查模型状态 / Checking model state...")
    global generator
    if generator is None:
        logger.info("Initializing generator lazily inside GPU worker...")
        progress(0.45, desc="首次初始化模型 / First-time model initialization...")
        generator = init_generator()
        progress(0.65, desc="模型初始化完成 / Model initialization complete")
    else:
        progress(0.55, desc="复用已初始化模型 / Reusing initialized model")
        logger.info("Using initialized generator in ZeroGPU worker.")
    gen = generator
    # ZeroGPU automatically maps these to the acquired GPU during execution.
    # We must also correctly update internal Python attributes so runtime-generated latents go to GPU.
    target_device = torch.device("cuda")
    progress(0.72, desc="迁移模型到 GPU / Moving model to GPU...")
    gen.device = target_device
    if hasattr(gen, "sampler") and gen.sampler is not None:
        gen.sampler.device = target_device
    gen.model.to(target_device)
    gen.clip.to(target_device)
    gen.t5.to(target_device)
    gen.vae.to(target_device)
    progress(0.82, desc="模型就绪,开始生成 / Model ready, starting generation...")
    # Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
    logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
    # Step 3: Generate images. Progress advances linearly from 0.82 to 0.98
    # across all images.
    total_gens = len(pairs) * num_images
    logger.info(f"Generating {total_gens} images across {len(pairs)} styles...")
    results = []
    seeds_used = []
    gen_idx = 0
    for author, font in pairs:
        for i in range(num_images):
            gen_idx += 1
            loop_progress = 0.82 + (gen_idx / max(total_gens, 1)) * 0.16
            progress(loop_progress, desc=f"生成第 {gen_idx}/{total_gens} 张 / Generating {gen_idx}/{total_gens}")
            # Seed depends only on the image index, so every (author, font)
            # pair reuses the same seed sequence.
            current_seed = start_seed + i
            # The synthetic sentinel means "no specific calligrapher".
            cond_author = author if author != "None (Synthetic / 合成风格)" else None
            result_img, cond_img = gen.generate(
                text=text, font_style=font, author=cond_author,
                num_steps=num_steps, seed=current_seed,
                guidance=1.0,
            )
            author_label = author if author else "Synthetic"
            label = f"{author_label} - {font} (Seed: {current_seed})"
            results.append((result_img, label))
            seeds_used.append(current_seed)
            logger.info(f" Generated image {gen_idx}/{total_gens}")
    progress(1.0, desc="生成完成 / Generation complete")
    return results, seeds_used
def interactive_session(
    text: str,
    a1, f1, a2, f2, a3, f3, a4, f4,
    num_steps: int,
    start_seed: int,
    num_images: int,
    progress=gr.Progress()
):
    """
    Streaming handler for the generate button (FA3 + FP8 path): validates the
    inputs, collects the selected (calligrapher, font) pairs, then runs the
    full generation pipeline, yielding status updates along the way.
    """
    # --- Input validation ---
    if len(text) < 1:
        raise gr.Error("文本不能为空 / Text cannot be empty")
    if len(text) > 7:
        raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")

    # Keep only fully-specified combinations whose font label is recognized.
    pairs = []
    for author, style_label in ((a1, f1), (a2, f2), (a3, f3), (a4, f4)):
        if not (author and style_label):
            continue
        font_key = parse_font_style(style_label)
        if font_key is not None:
            pairs.append((author, font_key))
    if not pairs:
        raise gr.Error("请至少选择一项书法家和字体组合 / Please select at least one combination")

    # --- Staged status updates before entering the GPU session ---
    yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
    progress(0.05, desc="校验输入参数 / Validating inputs...")
    yield "⏳ 输入已通过校验,等待 GPU 分配... / Inputs validated, waiting for GPU allocation...", []
    progress(0.15, desc="等待 GPU 资源 / Waiting for GPU allocation...")

    # The DMD-distilled model always samples in exactly 4 steps.
    num_steps = 4
    yield "⏳ 已分配 GPU,正在初始化与生成... / GPU allocated, initializing and generating...", []
    progress(0.22, desc="进入生成阶段 / Entering generation stage...")

    # Model loading + generation happen inside one GPU session.
    results, seeds_used = run_generation(
        text, pairs, num_steps, start_seed, num_images, progress
    )
    progress(1.0, desc="完成!")

    # --- Final status line ---
    if len(results) > 1:
        final_status = f"✅ 全部完成!共 {len(results)} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
    else:
        final_status = f"✅ 完成!Seed: {seeds_used[0]}"
    yield final_status, results
# Create Gradio interface
with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生成器", theme=gr.themes.Soft()) as demo:
    # --- Header: project title and links ---
    gr.Markdown("""
    # 🖌️ UniCalli - 中国书法生成器 / Chinese Calligraphy Generator

    **A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**

    Generate beautiful Chinese calligraphy in various styles and by different historical masters.
    用不同历史书法大师的风格生成精美的中国书法。

    🌐 [Project Page](https://envision-research.github.io/UniCalli/) | 📄 [Paper](https://arxiv.org/abs/2510.13745) | 💻 [Code](https://github.com/Envision-Research/UniCalli)

    **注意 / Note**: 支持1-7个汉字输入 / Supports 1-7 Chinese characters.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📝 输入设置 / Input Settings")
            text_input = gr.Textbox(
                label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
                placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 相见时难别亦难",
                value="相见时难别亦难",
                max_lines=1
            )
            gr.Markdown("### 👤 书法家与字体组合 / Calligraphers & Fonts")
            # Up to 4 (calligrapher, font) combination rows; only the first is
            # visible initially and the rest are revealed by add_combo_btn.
            combo_groups = []
            author_dropdowns = []
            font_dropdowns = []
            # Default selection: 文征明 with 行 (Running Script) when available.
            initial_author = "文征明"
            initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
            initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
            if initial_author == "文征明" and "行" in initial_fonts:
                default_font = FONT_STYLE_NAMES["行"]
            else:
                default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"
            # Tracks how many combination rows are currently visible (1-4).
            active_count_state = gr.State(value=1)
            for i in range(4):
                with gr.Group(visible=(i == 0)) as group:
                    gr.Markdown(f"**组合 {i+1} / Combination {i+1}**")
                    with gr.Row():
                        a_drop = gr.Dropdown(
                            label=f"书法家 / Calligrapher",
                            choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
                            value="文征明" if i == 0 else None,
                        )
                        f_drop = gr.Dropdown(
                            label=f"字体风格 / Font Style",
                            choices=initial_font_choices if i == 0 else list(FONT_STYLE_NAMES.values()),
                            value=default_font if i == 0 else None,
                        )
                author_dropdowns.append(a_drop)
                font_dropdowns.append(f_drop)
                combo_groups.append(group)
            add_combo_btn = gr.Button("➕ 添加书法家与风格组合 / Add Combination", size="sm")
            gr.Markdown("### ⚙️ 生成设置 / Generation Settings")
            # Locked at 4 steps: the DMD-distilled model only supports 4-step
            # sampling (interactive_session also hardcodes this).
            num_steps = gr.Slider(
                label="生成步数 / Inference Steps (Fixed to 4 for DMD)",
                minimum=4,
                maximum=4,
                value=4,
                step=1,
                interactive=False,
                info="使用 DMD 蒸馏模型,强制 4 步生成 / Uses DMD distilled model, forced to 4 steps"
            )
            start_seed = gr.Number(
                label="起始种子 / Start Seed",
                value=42,
                precision=0
            )
            num_images = gr.Slider(
                label="生成数量 / Number of Images",
                minimum=1,
                maximum=8,
                value=1,
                step=1
            )
            generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🖼️ 生成结果 / Generated Results")
            gr.Markdown("""
            ⚠️ **首次生成说明 / First Run Note:**
            - 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
            - 如果遇到错误,请**再点一次生成按钮**即可正常运行
            - First generation triggers PyTorch compilation (~1-2 min)
            - If you see an error, just **click generate again** and it will work

            *点击图片可放大查看 / Click image to enlarge*
            """)
            output_gallery = gr.Gallery(
                label="生成结果 / Generated Results",
                show_label=False,
                columns=2,
                rows=2,
                height=550,
                object_fit="contain",
                allow_preview=True
            )
            status_text = gr.Textbox(
                label="状态 / Status",
                value="准备就绪 / Ready",
                interactive=False
            )
    # Author info section: collapsible markdown table of the first 30 authors.
    with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
        author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
        for author in AUTHOR_LIST[:30]:
            fonts = " | ".join(AUTHOR_FONTS[author])
            desc = author_styles.get(author, "")
            # NOTE(review): desc_short is computed but never used in the table
            # row below — presumably a description column was planned; confirm.
            desc_short = desc[:50] + "..." if len(desc) > 50 else desc
            author_info_md += f"| **{author}** | {fonts} |\n"
        if len(AUTHOR_LIST) > 30:
            author_info_md += f"\n*... 还有 {len(AUTHOR_LIST) - 30} 位书法家 / {len(AUTHOR_LIST) - 30} more calligraphers*"
        gr.Markdown(author_info_md)
    # Event handlers: each calligrapher dropdown restricts its paired font
    # dropdown to the styles that calligrapher supports.
    for i in range(4):
        author_dropdowns[i].change(
            fn=update_font_choices,
            inputs=[author_dropdowns[i]],
            outputs=[font_dropdowns[i]]
        )
    def add_combo(current_count):
        # Reveal one more combination row (max 4) and disable the button once
        # all four are visible. Returns: new count, 4 group visibility
        # updates, then the button interactivity update — matching the
        # outputs list of the click handler below.
        new_count = min(current_count + 1, 4)
        updates = []
        for i in range(4):
            updates.append(gr.update(visible=(i < new_count)))
        updates.append(gr.update(interactive=(new_count < 4)))
        return [new_count] + updates
    add_combo_btn.click(
        fn=add_combo,
        inputs=[active_count_state],
        outputs=[active_count_state] + combo_groups + [add_combo_btn]
    )
    # Prepare inputs list for the interactive session:
    # [text, a1, f1, a2, f2, a3, f3, a4, f4, num_steps, start_seed, num_images]
    session_inputs = [text_input]
    for i in range(4):
        session_inputs.extend([author_dropdowns[i], font_dropdowns[i]])
    session_inputs.extend([num_steps, start_seed, num_images])
    # Generate button - uses streaming for live updates
    generate_btn.click(
        fn=interactive_session,
        inputs=session_inputs,
        outputs=[status_text, output_gallery]
    )
    # Examples
    gr.Markdown("### 📋 示例 / Examples")
    gr.Examples(
        examples=[
            # [text, author1, font1, author2, font2, author3, font3, author4, font4, num_steps, start_seed, num_images]
            ["相见时难别亦难", "文征明", "行 (Running Script)", None, None, None, None, None, None, 4, 1024, 1],
            ["天道酬勤", "王羲之", "草 (Cursive Script)", "黄庭坚", "行 (Running Script)", None, None, None, None, 4, 42, 1],
            ["厚德载物", "赵孟頫", "楷 (Regular Script)", None, None, None, None, None, None, 4, 123, 1],
        ],
        inputs=session_inputs,
    )

demo.launch()