# UniCalli_Dev / app.py
# Author: Tianshuo-Xu
# Commit: Fix empty gallery by restoring stable Gradio gallery layout (7d06fb9)
# -*- coding: utf-8 -*-
"""
Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
With FA3 + FP8 quantization for faster inference
"""
import os
import sys

# Early, flushed prints so startup progress is visible in the Space logs
# even before logging is configured.
print("=" * 50, flush=True)
print("APP.PY STARTING - DEBUG VERSION", flush=True)
print("=" * 50, flush=True)

import logging
from datetime import datetime

# Setup logging: route everything to stdout so HF Spaces captures it.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
print("Logging setup complete", flush=True)

# IMPORTANT: import spaces first before any CUDA-related packages
# (ZeroGPU requires no CUDA initialization before the worker fork).
import spaces
import gradio as gr
import json
import csv
import time
import torch
# Load author and font mappings from CSV
def load_author_fonts_from_csv(csv_path):
    """
    Load each calligrapher and their available font styles from a CSV file.

    Rows whose only styles are unsupported (隶 / 篆) are dropped entirely.

    Args:
        csv_path: path to a UTF-8 CSV with columns 书法家 (author) and
            字体类型 ('|'-separated font styles).

    Returns:
        dict mapping author name -> list of supported font style keys.
    """
    unsupported = {'隶', '篆'}  # script types the model cannot render
    mapping = {}
    with open(csv_path, 'r', encoding='utf-8') as handle:
        for record in csv.DictReader(handle):
            candidates = record['字体类型'].split('|')
            usable = [style for style in candidates if style not in unsupported]
            # Skip authors left with no supported style at all.
            if usable:
                mapping[record['书法家']] = usable
    return mapping
# Load author-font mappings
AUTHOR_FONTS = load_author_fonts_from_csv('dataset/author_fonts_summary.csv')

# Available authors (sorted)
AUTHOR_LIST = sorted(AUTHOR_FONTS.keys())

# Font style display names (only supported styles)
FONT_STYLE_NAMES = {
    "楷": "楷 (Regular Script)",
    "行": "行 (Running Script)",
    "草": "草 (Cursive Script)"
}

# Load author descriptions if available; fall back to an empty mapping when
# the file is missing or malformed. Catch specific exceptions instead of a
# bare `except:` so KeyboardInterrupt/SystemExit are not swallowed at import.
try:
    with open('dataset/calligraphy_styles_en.json', 'r', encoding='utf-8') as f:
        author_styles = json.load(f)
except (OSError, json.JSONDecodeError):
    author_styles = {}

# Global generator instance (created lazily inside the GPU worker)
generator = None
_cached_model_dir = None

# ============================================================
# Pre-download and pre-load model files at startup (no GPU needed)
# ============================================================
_preloaded_embedding = None
_preloaded_tokenizer = None
_cached_t5_dir = None
_cached_clip_dir = None
_cached_font_path = None
def preload_model_files():
    """Pre-download model files to cache at startup (no GPU needed).

    Each download is best-effort: failures print a warning and the function
    continues, so the app can still start (generation may fail later).
    Side effects, in order: populates the HF cache, exports
    XFLUX_TEXT_ENCODER_PATH / XFLUX_CLIP_ENCODER_PATH / UNICALLI_FONT_PATH,
    and fills the _preloaded_* / _cached_* module globals.

    Returns:
        The local snapshot directory of the main model, or None if its
        download failed.
    """
    global _preloaded_embedding, _preloaded_tokenizer, _cached_t5_dir, _cached_clip_dir, _cached_font_path
    from huggingface_hub import snapshot_download, hf_hub_download
    hf_token = os.environ.get("HF_TOKEN", None)
    print("Pre-downloading model files to cache...")
    # 1. Main model (Unicalli_Pro) - include all tokenizer files
    try:
        local_dir = snapshot_download(
            repo_id="TSXu/UniCalli_pro_dmd2_K4",
            token=hf_token
        )
        print(f"✓ Unicalli_Pro cached at: {local_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
        local_dir = None
    # 2. T5 text encoder
    try:
        _cached_t5_dir = snapshot_download(
            "xlabs-ai/xflux_text_encoders",
            token=hf_token
        )
        # Env var must be set only after a successful download; presumably
        # read by the inference package when building the T5 encoder.
        os.environ["XFLUX_TEXT_ENCODER_PATH"] = _cached_t5_dir
        print(f"✓ T5 text encoder cached at: {_cached_t5_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download T5: {e}")
    # 3. CLIP text encoder
    try:
        _cached_clip_dir = snapshot_download(
            "openai/clip-vit-large-patch14",
            token=hf_token
        )
        os.environ["XFLUX_CLIP_ENCODER_PATH"] = _cached_clip_dir
        print(f"✓ CLIP text encoder cached at: {_cached_clip_dir}")
    except Exception as e:
        print(f"Warning: Could not pre-download CLIP: {e}")
    # 4. VAE (ae.safetensors from FLUX.1-dev)
    try:
        hf_hub_download("black-forest-labs/FLUX.1-dev", "ae.safetensors", token=hf_token)
        print("✓ VAE cached")
    except Exception as e:
        print(f"Warning: Could not pre-download VAE: {e}")
    # 5. Font file used for condition image rendering
    try:
        _cached_font_path = hf_hub_download(
            repo_id="TSXu/Unicalli_Pro",
            filename="FangZhengKaiTiFanTi-1.ttf",
            token=hf_token,
        )
        os.environ["UNICALLI_FONT_PATH"] = _cached_font_path
        print(f"✓ Font cached at: {_cached_font_path}")
    except Exception as e:
        print(f"Warning: Could not pre-download font: {e}")
    # 6. Flash Attention 3 kernel package (large) pre-cache
    try:
        from kernels import get_kernel
        get_kernel("kernels-community/vllm-flash-attn3")
        print("✓ Flash Attention 3 kernel cached")
    except Exception as e:
        print(f"Warning: Could not pre-cache Flash Attention 3 kernel: {e}")
    # 7. Pre-load InternVL embedding to CPU memory (saves ~5s in GPU session)
    if local_dir:
        try:
            intern_vlm_path = os.path.join(local_dir, "internvl_embedding")
            _preloaded_embedding, _preloaded_tokenizer = _preload_embedding(intern_vlm_path)
            print("✓ InternVL embedding pre-loaded to CPU")
        except Exception as e:
            print(f"Warning: Could not pre-load embedding: {e}")
    return local_dir
def _preload_embedding(intern_vlm_path):
    """Load the InternVL embedding table and its tokenizer into CPU memory.

    Returns (embedding, tokenizer), or (None, None) when the expected
    files are not present under *intern_vlm_path*.
    """
    from safetensors.torch import load_file as load_safetensors
    from transformers import AutoTokenizer

    weights_path = os.path.join(intern_vlm_path, "embedding.safetensors")
    cfg_path = os.path.join(intern_vlm_path, "embedding_config.json")
    if not (os.path.exists(weights_path) and os.path.exists(cfg_path)):
        print(f" Embedding files not found at {intern_vlm_path}")
        return None, None

    print(f" Loading embedding from: {intern_vlm_path}")
    with open(cfg_path, 'r') as cfg_file:
        cfg = json.load(cfg_file)

    # Build the embedding table on CPU, then restore the saved weights and
    # freeze it (inference only).
    table = torch.nn.Embedding(
        num_embeddings=cfg["num_embeddings"],
        embedding_dim=cfg["embedding_dim"],
        padding_idx=cfg.get("padding_idx", None)
    )
    table.load_state_dict(load_safetensors(weights_path))
    table.eval()
    table.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained(
        intern_vlm_path, trust_remote_code=True, use_fast=False
    )
    print(f" Loaded: {cfg['num_embeddings']} x {cfg['embedding_dim']}")
    return table, tokenizer
print("="*50)
print("Starting model pre-download and pre-load...")
# Run all downloads/pre-loads at import time, before any ZeroGPU worker forks.
_cached_model_dir = preload_model_files()
print("="*50)
def init_generator():
    """Initialize the generator (without optimization - that's done separately).

    Lazily builds the global CalligraphyGenerator on CPU so that no CUDA
    work happens before the ZeroGPU worker forks. Subsequent calls return
    the cached instance.

    Returns:
        The module-global CalligraphyGenerator instance.

    Raises:
        RuntimeError: if the model snapshot was not downloaded at startup.
    """
    global generator, _cached_model_dir, _preloaded_embedding, _preloaded_tokenizer
    if generator is None:
        if _cached_model_dir is None:
            # snapshot_download failed during preload; fail with a clear
            # message instead of an opaque TypeError from
            # os.path.join(None, ...).
            raise RuntimeError(
                "Model snapshot was not downloaded at startup; "
                "cannot initialize the generator."
            )
        intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
        checkpoint_path = _cached_model_dir
        print(f"Using pre-cached model from: {_cached_model_dir}")
        print(f"Using pre-loaded embedding: {_preloaded_embedding is not None}")
        from inference import CalligraphyGenerator
        generator = CalligraphyGenerator(
            model_name="flux-dev",
            device="cpu",  # Changed to cpu to prevent CUDA init before ZeroGPU fork!
            offload=False,  # Set to False to let ZeroGPU manage CUDA memory directly instead of manual CPU thrashing
            intern_vlm_path=intern_vlm_path,
            checkpoint_path=checkpoint_path,
            font_descriptions_path='dataset/chirography.json',
            author_descriptions_path='dataset/calligraphy_styles_en.json',
            use_deepspeed=False,
            use_4bit_quantization=False,
            use_float8_quantization=False,
            use_torch_compile=False,
            dtype="bf16",  # Revert to bf16 to let FA3 run and avoid Float8 unimplemented errors
            preloaded_embedding=_preloaded_embedding,
            preloaded_tokenizer=_preloaded_tokenizer,
        )
    return generator
def update_font_choices(author: str):
    """
    Update available font choices based on selected author.

    Known authors get only their own supported styles; the synthetic
    option (or an unknown author) gets every display style.
    """
    synthetic = "None (Synthetic / 合成风格)"
    if author != synthetic and author in AUTHOR_FONTS:
        choices = [
            FONT_STYLE_NAMES[key]
            for key in AUTHOR_FONTS[author]
            if key in FONT_STYLE_NAMES
        ]
    else:
        choices = list(FONT_STYLE_NAMES.values())
    default = choices[0] if choices else None
    return gr.Dropdown(choices=choices, value=default)
def parse_font_style(font_style: str) -> str:
    """Extract font key from display name; None when nothing matches."""
    matches = (
        key for key, shown in FONT_STYLE_NAMES.items() if shown == font_style
    )
    return next(matches, None)
# IMPORTANT:
# Do NOT initialize generator globally at import time in ZeroGPU Spaces.
# Keep it lazy inside the @spaces.GPU worker to avoid any pre-fork CUDA side effects.
def _get_generation_duration(text, pairs, num_steps, start_seed, num_images, progress=None):
"""Calculate dynamic GPU duration: 24s base + 3s per image"""
num_pairs = len(pairs) if pairs else 1
return 30 + int(3 * num_images * num_pairs)
@spaces.GPU(duration=_get_generation_duration)
def run_generation(text, pairs, num_steps, start_seed, num_images, progress=gr.Progress()):
    """
    Load model, apply FP8 quantization, and generate images.
    All in one GPU session to avoid redundant loading.

    Args:
        text: characters to render (validated upstream to 1-7 chars).
        pairs: list of (author, font_key) combinations.
        num_steps: diffusion steps (caller forces 4 for DMD).
        start_seed: image i of each pair uses seed start_seed + i.
        num_images: images generated per pair.
        progress: gradio progress reporter.

    Returns:
        (results, seeds_used): gallery-ready (image, label) tuples and the
        seeds actually used.
    """
    progress(0.25, desc="准备 GPU 环境 / Preparing GPU runtime...")
    # Enable CUDA optimizations inside the worker; best-effort only.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)
    except Exception:
        pass
    # Step 1: Load model (lazy-initialized on first GPU session)
    progress(0.35, desc="检查模型状态 / Checking model state...")
    global generator
    if generator is None:
        logger.info("Initializing generator lazily inside GPU worker...")
        progress(0.45, desc="首次初始化模型 / First-time model initialization...")
        generator = init_generator()
        progress(0.65, desc="模型初始化完成 / Model initialization complete")
    else:
        progress(0.55, desc="复用已初始化模型 / Reusing initialized model")
        logger.info("Using initialized generator in ZeroGPU worker.")
    gen = generator
    # ZeroGPU automatically maps these to the acquired GPU during execution.
    # We must also correctly update internal Python attributes so runtime-generated latents go to GPU.
    target_device = torch.device("cuda")
    progress(0.72, desc="迁移模型到 GPU / Moving model to GPU...")
    gen.device = target_device
    if hasattr(gen, "sampler") and gen.sampler is not None:
        gen.sampler.device = target_device
    gen.model.to(target_device)
    gen.clip.to(target_device)
    gen.t5.to(target_device)
    gen.vae.to(target_device)
    progress(0.82, desc="模型就绪,开始生成 / Model ready, starting generation...")
    # Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
    logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
    # Step 3: Generate images
    total_gens = len(pairs) * num_images
    logger.info(f"Generating {total_gens} images across {len(pairs)} styles...")
    results = []
    seeds_used = []
    gen_idx = 0
    for author, font in pairs:
        for i in range(num_images):
            gen_idx += 1
            loop_progress = 0.82 + (gen_idx / max(total_gens, 1)) * 0.16
            progress(loop_progress, desc=f"生成第 {gen_idx}/{total_gens} 张 / Generating {gen_idx}/{total_gens}")
            current_seed = start_seed + i
            cond_author = author if author != "None (Synthetic / 合成风格)" else None
            result_img, cond_img = gen.generate(
                text=text, font_style=font, author=cond_author,
                num_steps=num_steps, seed=current_seed,
                guidance=1.0,
            )
            # BUGFIX: label by the effective conditioning author. `author` is
            # always the truthy dropdown string here, so the old check
            # `author if author else "Synthetic"` never fired and the
            # synthetic option showed its raw dropdown text in the label.
            author_label = cond_author if cond_author else "Synthetic"
            label = f"{author_label} - {font} (Seed: {current_seed})"
            results.append((result_img, label))
            seeds_used.append(current_seed)
            logger.info(f" Generated image {gen_idx}/{total_gens}")
    progress(1.0, desc="生成完成 / Generation complete")
    return results, seeds_used
def interactive_session(
    text: str,
    a1, f1, a2, f2, a3, f3, a4, f4,
    num_steps: int,
    start_seed: int,
    num_images: int,
    progress=gr.Progress()
):
    """
    Interactive session with FA3 + FP8.

    Generator that streams (status, gallery) tuples: validation, queueing,
    then final results. Raises gr.Error on invalid input.
    """
    # Validate text. Truthiness check also rejects None (possible when the
    # textbox is cleared) instead of raising TypeError on len(None).
    if not text:
        raise gr.Error("文本不能为空 / Text cannot be empty")
    if len(text) > 7:
        raise gr.Error(f"文本最多7个字符 / Text must be at most 7 characters. Current: {len(text)}")
    # Keep only fully-specified (author, font) combinations whose display
    # name maps back to a known font key.
    raw_pairs = [(a1, f1), (a2, f2), (a3, f3), (a4, f4)]
    pairs = []
    for a, f_style in raw_pairs:
        if a and f_style:
            parsed_font = parse_font_style(f_style)
            if parsed_font is not None:
                pairs.append((a, parsed_font))
    if not pairs:
        raise gr.Error("请至少选择一项书法家和字体组合 / Please select at least one combination")
    # Run generation (includes model loading + FP8 quantization + generation)
    yield "⏳ 队列中:准备任务... / Queued: preparing task...", []
    progress(0.05, desc="校验输入参数 / Validating inputs...")
    yield "⏳ 输入已通过校验,等待 GPU 分配... / Inputs validated, waiting for GPU allocation...", []
    progress(0.15, desc="等待 GPU 资源 / Waiting for GPU allocation...")
    # Hardcode num_steps to 4 for DMD distillation
    num_steps = 4
    yield "⏳ 已分配 GPU,正在初始化与生成... / GPU allocated, initializing and generating...", []
    progress(0.22, desc="进入生成阶段 / Entering generation stage...")
    results, seeds_used = run_generation(
        text, pairs, num_steps, start_seed, num_images, progress
    )
    progress(1.0, desc="完成!")
    # Final status
    total_imgs = len(results)
    if total_imgs > 1:
        final_status = f"✅ 全部完成!共 {total_imgs} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
    else:
        final_status = f"✅ 完成!Seed: {seeds_used[0]}"
    yield final_status, results
# Create Gradio interface
with gr.Blocks(title="UniCalli - Chinese Calligraphy Generator / 中国书法生成器", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🖌️ UniCalli - 中国书法生成器 / Chinese Calligraphy Generator
**A Unified Diffusion Framework for Column-Level Generation and Recognition of Chinese Calligraphy**
Generate beautiful Chinese calligraphy in various styles and by different historical masters.
用不同历史书法大师的风格生成精美的中国书法。
🌐 [Project Page](https://envision-research.github.io/UniCalli/) | 📄 [Paper](https://arxiv.org/abs/2510.13745) | 💻 [Code](https://github.com/Envision-Research/UniCalli)
**注意 / Note**: 支持1-7个汉字输入 / Supports 1-7 Chinese characters.
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📝 输入设置 / Input Settings")
            text_input = gr.Textbox(
                label="输入文本 / Input Text (1-7个字符 / 1-7 characters)",
                placeholder="请输入1-7个汉字 / Enter 1-7 Chinese characters, e.g.: 相见时难别亦难",
                value="相见时难别亦难",
                max_lines=1
            )
            gr.Markdown("### 👤 书法家与字体组合 / Calligraphers & Fonts")
            # Up to 4 (calligrapher, font) combination groups; only the first
            # is visible initially, the rest are revealed via the add button.
            combo_groups = []
            author_dropdowns = []
            font_dropdowns = []
            initial_author = "文征明"
            initial_fonts = AUTHOR_FONTS.get(initial_author, ["楷", "草", "行"])
            initial_font_choices = [FONT_STYLE_NAMES[f] for f in initial_fonts if f in FONT_STYLE_NAMES]
            # Prefer Running Script for the default author when available.
            if initial_author == "文征明" and "行" in initial_fonts:
                default_font = FONT_STYLE_NAMES["行"]
            else:
                default_font = initial_font_choices[0] if initial_font_choices else "草 (Cursive Script)"
            # Number of currently visible combination groups.
            active_count_state = gr.State(value=1)
            for i in range(4):
                with gr.Group(visible=(i == 0)) as group:
                    gr.Markdown(f"**组合 {i+1} / Combination {i+1}**")
                    with gr.Row():
                        a_drop = gr.Dropdown(
                            label=f"书法家 / Calligrapher",
                            choices=["None (Synthetic / 合成风格)"] + AUTHOR_LIST,
                            value="文征明" if i == 0 else None,
                        )
                        f_drop = gr.Dropdown(
                            label=f"字体风格 / Font Style",
                            choices=initial_font_choices if i == 0 else list(FONT_STYLE_NAMES.values()),
                            value=default_font if i == 0 else None,
                        )
                    author_dropdowns.append(a_drop)
                    font_dropdowns.append(f_drop)
                combo_groups.append(group)
            add_combo_btn = gr.Button("➕ 添加书法家与风格组合 / Add Combination", size="sm")
            gr.Markdown("### ⚙️ 生成设置 / Generation Settings")
            num_steps = gr.Slider(
                label="生成步数 / Inference Steps (Fixed to 4 for DMD)",
                minimum=4,
                maximum=4,
                value=4,
                step=1,
                interactive=False,
                info="使用 DMD 蒸馏模型,强制 4 步生成 / Uses DMD distilled model, forced to 4 steps"
            )
            start_seed = gr.Number(
                label="起始种子 / Start Seed",
                value=42,
                precision=0
            )
            num_images = gr.Slider(
                label="生成数量 / Number of Images",
                minimum=1,
                maximum=8,
                value=1,
                step=1
            )
            generate_btn = gr.Button("🎨 开始生成 / Start Generation", variant="primary", size="lg")
        with gr.Column(scale=1):
            # Output section
            gr.Markdown("### 🖼️ 生成结果 / Generated Results")
            gr.Markdown("""
⚠️ **首次生成说明 / First Run Note:**
- 第一次生成会触发 PyTorch 编译,可能需要 1-2 分钟
- 如果遇到错误,请**再点一次生成按钮**即可正常运行
- First generation triggers PyTorch compilation (~1-2 min)
- If you see an error, just **click generate again** and it will work
*点击图片可放大查看 / Click image to enlarge*
""")
            output_gallery = gr.Gallery(
                label="生成结果 / Generated Results",
                show_label=False,
                columns=2,
                rows=2,
                height=550,
                object_fit="contain",
                allow_preview=True
            )
            status_text = gr.Textbox(
                label="状态 / Status",
                value="准备就绪 / Ready",
                interactive=False
            )
    # Author info section
    with gr.Accordion("📚 可用书法家列表 / Available Calligraphers(共 {} 位 / {} total)".format(len(AUTHOR_LIST), len(AUTHOR_LIST)), open=False):
        author_info_md = "| 书法家 / Calligrapher | 可用字体 / Available Fonts |\n|--------|----------|\n"
        for author in AUTHOR_LIST[:30]:
            fonts = " | ".join(AUTHOR_FONTS[author])
            desc = author_styles.get(author, "")
            # NOTE(review): desc_short is computed but never used in the table
            # below — looks like a leftover from an earlier layout; confirm.
            desc_short = desc[:50] + "..." if len(desc) > 50 else desc
            author_info_md += f"| **{author}** | {fonts} |\n"
        if len(AUTHOR_LIST) > 30:
            author_info_md += f"\n*... 还有 {len(AUTHOR_LIST) - 30} 位书法家 / {len(AUTHOR_LIST) - 30} more calligraphers*"
        gr.Markdown(author_info_md)
    # Event handlers: refresh font choices whenever an author changes.
    for i in range(4):
        author_dropdowns[i].change(
            fn=update_font_choices,
            inputs=[author_dropdowns[i]],
            outputs=[font_dropdowns[i]]
        )

    def add_combo(current_count):
        # Reveal one more combination group (max 4); disable the add button
        # once all four groups are visible.
        new_count = min(current_count + 1, 4)
        updates = []
        for i in range(4):
            updates.append(gr.update(visible=(i < new_count)))
        updates.append(gr.update(interactive=(new_count < 4)))
        return [new_count] + updates

    add_combo_btn.click(
        fn=add_combo,
        inputs=[active_count_state],
        outputs=[active_count_state] + combo_groups + [add_combo_btn]
    )
    # Prepare inputs list for the interactive session
    session_inputs = [text_input]
    for i in range(4):
        session_inputs.extend([author_dropdowns[i], font_dropdowns[i]])
    session_inputs.extend([num_steps, start_seed, num_images])
    # Generate button - uses streaming for live updates
    generate_btn.click(
        fn=interactive_session,
        inputs=session_inputs,
        outputs=[status_text, output_gallery]
    )
    # Examples
    gr.Markdown("### 📋 示例 / Examples")
    gr.Examples(
        examples=[
            # [text, author1, font1, author2, font2, author3, font3, author4, font4, num_steps, start_seed, num_images]
            ["相见时难别亦难", "文征明", "行 (Running Script)", None, None, None, None, None, None, 4, 1024, 1],
            ["天道酬勤", "王羲之", "草 (Cursive Script)", "黄庭坚", "行 (Running Script)", None, None, None, None, 4, 42, 1],
            ["厚德载物", "赵孟頫", "楷 (Regular Script)", None, None, None, None, None, None, 4, 123, 1],
        ],
        inputs=session_inputs,
    )

demo.launch()