UniCalli_Dev / inference.py
Tianshuo-Xu
precache fa3 kernel and font before gpu task
743a20a
# -*- coding: utf-8 -*-
"""
Chinese Calligraphy Generation with Flux Model
Author and font style controllable generation
"""
import os
import json
import torch
from safetensors.torch import load_file as load_safetensors
from optimum.quanto import quantize, freeze, qint4
from PIL import Image, ImageDraw, ImageFont
from typing import Optional, List, Union, Dict, Any
from einops import rearrange
from pypinyin import lazy_pinyin
from huggingface_hub import hf_hub_download, snapshot_download
from src.flux.util import configs, load_ae, load_clip, load_t5
from src.flux.model import Flux
from src.flux.xflux_pipeline import XFluxSampler
# HuggingFace Hub model IDs
HF_MODEL_ID = "TSXu/Unicalli_Pro"
HF_CHECKPOINT_INDEX = "model.safetensors.index.json" # Sharded safetensors index
HF_INTERNVL_ID = "OpenGVLab/InternVL3-1B"
def download_sharded_safetensors(
    model_id: str = HF_MODEL_ID,
    local_dir: str = None,
    force_download: bool = False
) -> str:
    """
    Fetch a sharded safetensors checkpoint from the HuggingFace Hub.

    The index file is fetched first; every shard file named in its
    ``weight_map`` is then fetched one at a time, in sorted order.

    Args:
        model_id: HuggingFace model repository ID
        local_dir: Local directory to save files (optional)
        force_download: Whether to force re-download
    Returns:
        Path to the downloaded index.json file
    Raises:
        Exception: any error raised while downloading or reading the index
            is printed and then re-raised unchanged
    """
    print(f"Downloading sharded safetensors from HuggingFace Hub ({model_id})...")
    # The same repo/auth settings apply to the index and to every shard.
    # The token (needed for private repos) comes from the environment.
    common_kwargs = dict(
        repo_id=model_id,
        local_dir=local_dir,
        force_download=force_download,
        token=os.environ.get("HF_TOKEN", None),
    )
    try:
        index_path = hf_hub_download(filename=HF_CHECKPOINT_INDEX, **common_kwargs)
        print(f"Index downloaded to: {index_path}")
        with open(index_path, 'r') as index_file:
            index = json.load(index_file)
        # Many tensors map to the same shard, so de-duplicate the file names
        shard_names = sorted(set(index['weight_map'].values()))
        print(f"Downloading {len(shard_names)} shard files...")
        for shard_name in shard_names:
            print(f" Downloading {shard_name}...")
            hf_hub_download(filename=shard_name, **common_kwargs)
        print("All shards downloaded!")
        return index_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise
def is_huggingface_repo_id(path: str) -> bool:
    """
    Tell whether a string looks like a HuggingFace repo ID rather than a
    local file path.

    A repo ID has the form ``namespace/repo_name``: exactly one ``/``, two
    non-empty parts, and neither part starting with ``.``. Strings that start
    with ``/``, ``.`` or ``~`` are treated as local paths.
    """
    # Absolute, relative-dot and home-relative prefixes mark a local path
    if path.startswith(('/', '.', '~')):
        return False
    segments = path.split('/')
    if len(segments) != 2:
        return False
    return all(segment and not segment.startswith('.') for segment in segments)
def ensure_checkpoint_exists(checkpoint_path: str) -> str:
    """
    Resolve a checkpoint reference to something that exists on disk.

    Args:
        checkpoint_path: Local path or HF model ID
    Returns:
        The given path when it exists locally; otherwise the index file path
        returned by download_sharded_safetensors() for a repo-ID-like string
    Raises:
        FileNotFoundError: the path does not exist and does not look like a
            HuggingFace repo ID
    """
    if not os.path.exists(checkpoint_path):
        # Not on disk: only a "namespace/repo_name" style string can be fetched
        if not is_huggingface_repo_id(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        print(f"Downloading from HuggingFace Hub: {checkpoint_path}")
        return download_sharded_safetensors(model_id=checkpoint_path)
    print(f"Using local checkpoint: {checkpoint_path}")
    return checkpoint_path
def convert_to_pinyin(text):
    """Return the lazy_pinyin() entries for *text* joined by single spaces."""
    syllables = []
    for entry in lazy_pinyin(text):
        # When an entry is a list, only its first element is used
        syllables.append(entry[0] if isinstance(entry, list) else entry)
    return ' '.join(syllables)
class CalligraphyGenerator:
    """
    Chinese Calligraphy Generator using Flux model
    Attributes:
        device: torch device for computation
        model_name: name of the flux model (flux-dev or flux-schnell)
        font_style_des: font style descriptions loaded from JSON, keyed by font style
        author_style: author style descriptions loaded from JSON, keyed by author name
    """
def __init__(
    self,
    model_name: str = "flux-dev",
    device: str = "cuda",
    offload: bool = True,
    checkpoint_path: Optional[str] = None,
    intern_vlm_path: Optional[str] = None,
    ref_latent_path: Optional[str] = None,
    font_descriptions_path: str = "chirography.json",
    author_descriptions_path: str = "calligraphy_styles_en.json",
    use_deepspeed: bool = False,
    use_4bit_quantization: bool = False,
    use_float8_quantization: bool = False,
    use_torch_compile: bool = False,
    compile_mode: str = "reduce-overhead",
    deepspeed_config: Optional[str] = None,
    dtype: Optional[str] = None,
    preloaded_embedding: Optional[torch.nn.Module] = None,
    preloaded_tokenizer: Optional[Any] = None,
):
    """
    Initialize the calligraphy generator
    Args:
        model_name: flux model name (flux-dev or flux-schnell)
        device: device for computation
        offload: whether to offload model to CPU when not in use
        checkpoint_path: local checkpoint path or HF repo ID; when omitted,
            the default sharded checkpoint is downloaded from the HF Hub
        intern_vlm_path: path to InternVLM model for text embedding
        ref_latent_path: path to reference latents for recognition mode
        font_descriptions_path: path to font style descriptions JSON
        author_descriptions_path: path to author style descriptions JSON
        use_deepspeed: whether to wrap the model with DeepSpeed inference
        use_4bit_quantization: whether to use 4-bit quantization (quanto/bitsandbytes)
        use_float8_quantization: stored only; Float8 quantization is applied externally
        use_torch_compile: stored only; torch.compile is applied externally
        compile_mode: torch.compile mode - "reduce-overhead", "max-autotune", or "default"
        deepspeed_config: path to DeepSpeed config JSON file
        dtype: force specific dtype for inference: "fp16", "bf16", "fp32",
            "fp8", or None for auto
        preloaded_embedding: forwarded unchanged to XFluxSampler
        preloaded_tokenizer: forwarded unchanged to XFluxSampler
    Raises:
        FileNotFoundError: a description JSON file or the checkpoint is missing
    """
    self.device = torch.device(device)
    self.model_name = model_name
    self.offload = offload
    self.is_schnell = model_name == "flux-schnell"
    self.use_deepspeed = use_deepspeed
    self.deepspeed_config = deepspeed_config
    self.use_4bit_quantization = use_4bit_quantization
    self.use_float8_quantization = use_float8_quantization
    self.use_torch_compile = use_torch_compile
    self.compile_mode = compile_mode
    self.forced_dtype = dtype  # "fp16", "bf16", "fp32", or None for auto

    # Load font and author style descriptions
    if not os.path.exists(font_descriptions_path):
        raise FileNotFoundError(f"Font descriptions file not found: {font_descriptions_path}")
    with open(font_descriptions_path, 'r', encoding='utf-8') as f:
        self.font_style_des = json.load(f)
    if not os.path.exists(author_descriptions_path):
        raise FileNotFoundError(f"Author descriptions file not found: {author_descriptions_path}")
    with open(author_descriptions_path, 'r', encoding='utf-8') as f:
        self.author_style = json.load(f)

    # Load models
    print("Loading models...")
    # With DeepSpeed or offload the text encoders start on CPU; under
    # DeepSpeed they are moved to the GPU further down, after engine init.
    if self.use_deepspeed or offload:
        text_encoder_device = "cpu"
    else:
        text_encoder_device = self.device
    self.t5 = load_t5(text_encoder_device, max_length=256 if self.is_schnell else 512)
    self.clip = load_clip(text_encoder_device)
    self.clip.requires_grad_(False)

    # Resolve the checkpoint: an explicit path/repo ID, or the default
    # sharded checkpoint from the HF Hub when none was given.
    if checkpoint_path:
        checkpoint_path = ensure_checkpoint_exists(checkpoint_path)
    else:
        print("No checkpoint path provided, downloading from HuggingFace Hub...")
        checkpoint_path = download_sharded_safetensors()
    print(f"Loading model from checkpoint: {checkpoint_path}")
    # When using DeepSpeed the loader leaves the model off the GPU
    self.model = self._load_model_from_checkpoint(
        checkpoint_path, model_name,
        offload=offload,
        use_deepspeed=self.use_deepspeed
    )
    if self.use_deepspeed:
        self.model = self._init_deepspeed(self.model)
    # Note: Float8 quantization and torch.compile optimizations
    # are applied externally (e.g., in app.py) for better control
    # over the optimization process with ZeroGPU AOT compilation.

    # Load VAE
    if self.use_deepspeed or offload:
        vae_device = "cpu"
    else:
        vae_device = self.device
    self.vae = load_ae(model_name, device=vae_device)
    # In plain offload mode the VAE is moved to the target device right away
    if offload and not self.use_deepspeed:
        self.vae = self.vae.to(self.device)
    # After DeepSpeed init, move text encoders and VAE to the target device
    if self.use_deepspeed:
        print("Moving text encoders to GPU...")
        self.t5 = self.t5.to(self.device)
        self.clip = self.clip.to(self.device)
        self.vae = self.vae.to(self.device)

    # Load reference latents if provided
    self.ref_latent = None
    if ref_latent_path and os.path.exists(ref_latent_path):
        print(f"Loading reference latents from {ref_latent_path}")
        self.ref_latent = torch.load(ref_latent_path, map_location='cpu')

    # Create sampler (use preloaded embedding if available)
    self.sampler = XFluxSampler(
        clip=self.clip,
        t5=self.t5,
        ae=self.vae,
        ref_latent=self.ref_latent,
        model=self.model,
        device=self.device,
        intern_vlm_path=intern_vlm_path,
        preloaded_embedding=preloaded_embedding,
        preloaded_tokenizer=preloaded_tokenizer,
    )

    # Font for generating condition images
    project_root = os.path.dirname(os.path.abspath(__file__))
    local_font_path = os.path.join(project_root, "FangZhengKaiTiFanTi-1.ttf")
    self.font_path = self._ensure_font_exists(local_font_path)
    self.default_font_size = 102  # 128 * 0.8
def _ensure_font_exists(self, font_path: str) -> str:
"""
Ensure font file exists locally, download from HF Hub if not
Args:
font_path: Local path to font file
Returns:
Path to the local font file
"""
cached_font_path = os.environ.get("UNICALLI_FONT_PATH")
if cached_font_path and os.path.exists(cached_font_path):
return cached_font_path
if os.path.exists(font_path):
return font_path
# Try to download from HF Hub
print(f"Font file not found locally, downloading from HuggingFace Hub...")
hf_token = os.environ.get("HF_TOKEN", None)
try:
font_path = hf_hub_download(
repo_id=HF_MODEL_ID,
filename="FangZhengKaiTiFanTi-1.ttf",
token=hf_token
)
print(f"Font downloaded to: {font_path}")
return font_path
except Exception as e:
print(f"Warning: Could not download font: {e}")
return font_path # Return original path, may fail later
def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str, offload: bool, use_deepspeed: bool = False):
    """
    Load model from checkpoint without loading flux pretrained weights.
    Supports both regular checkpoints and NF4 quantized checkpoints.
    Args:
        checkpoint_path: Path to your checkpoint file or NF4 model directory
        model_name: flux model name (for config)
        offload: accepted for interface compatibility; not read by this method
        use_deepspeed: whether using DeepSpeed (model is not moved to self.device)
    Returns:
        model with loaded checkpoint
    Raises:
        ValueError: the checkpoint contains no tensors
    """
    print("Creating empty flux model structure...")
    load_device = "cpu"
    # Build the module tree on the "meta" device so no weights are allocated
    with torch.device("meta"):
        model = Flux(configs[model_name].params)
    # Initialize module embeddings (must be done before loading checkpoint)
    print("Initializing module embeddings...")
    model.init_module_embeddings(tokens_num=320, cond_txt_channel=896)
    # Materialize (uninitialized) storage on the loading device
    print(f"Moving model to {load_device} for loading...")
    model = model.to_empty(device=load_device)

    # Load checkpoint, dequantizing first when it is an NF4 directory
    print(f"Loading checkpoint from {checkpoint_path}")
    if self._is_nf4_checkpoint(checkpoint_path):
        print("Detected NF4 quantized model, dequantizing...")
        checkpoint = self._load_nf4_checkpoint(checkpoint_path)
    else:
        checkpoint = self._load_checkpoint_file(checkpoint_path)
    if not checkpoint:
        # Without this guard the dtype probe below fails with a bare StopIteration
        raise ValueError(f"Checkpoint contains no tensors: {checkpoint_path}")

    # The first tensor's dtype is taken as the checkpoint dtype
    checkpoint_dtype = next(iter(checkpoint.values())).dtype
    print(f"Checkpoint dtype: {checkpoint_dtype}")

    # Check if user forced a specific dtype
    forced_dtype = getattr(self, 'forced_dtype', None)
    target_dtype = checkpoint_dtype
    if forced_dtype:
        dtype_map = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        # Look float8 up lazily so torch builds without it still work for
        # the other dtypes
        fp8_dtype = getattr(torch, "float8_e4m3fn", None)
        if fp8_dtype is not None:
            dtype_map["fp8"] = fp8_dtype
        if forced_dtype not in dtype_map:
            print(f"Warning: Unknown dtype '{forced_dtype}', using auto selection")
            forced_dtype = None
        else:
            target_dtype = dtype_map[forced_dtype]
            print(f"Using forced dtype: {target_dtype}")
            if checkpoint_dtype != target_dtype:
                print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...")
                checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
    if not forced_dtype:
        # Note: We trust the original precision (like FP8) if it is provided that way
        target_dtype = checkpoint_dtype
        print(f"Using auto-detected checkpoint dtype: {target_dtype} for inference loading")

    # assign=True swaps the checkpoint tensors in instead of copying into
    # the uninitialized storage; strict=False tolerates key mismatches
    model.load_state_dict(checkpoint, strict=False, assign=True)
    print(f"Model dtype after loading: {next(model.parameters()).dtype}")
    # Store target dtype for inference
    self._model_dtype = target_dtype
    # Free checkpoint memory
    del checkpoint

    # Apply 4-bit quantization if requested: bitsandbytes when importable,
    # quanto otherwise
    if getattr(self, 'use_4bit_quantization', False):
        try:
            import bitsandbytes  # noqa: F401 -- availability probe only
            print("Applying bitsandbytes NF4 quantization for 4-bit inference...")
            model = self._quantize_model_bnb(model)
            model._is_quantized = True
            print("bitsandbytes NF4 quantization complete!")
        except ImportError:
            print("bitsandbytes not available, using quanto quantization...")
            model = model.float()
            quantize(model, weights=qint4)
            freeze(model)
            model._is_quantized = True
            print("quanto 4-bit quantization complete!")

    # Move to GPU only if NOT using DeepSpeed
    if not use_deepspeed:
        if self.device.type != "cpu":
            print(f"Moving model to {self.device}...")
            model = model.to(self.device)
        # Enable optimized attention backends (math fallback is disabled)
        try:
            torch.backends.cuda.enable_flash_sdp(True)
            torch.backends.cuda.enable_mem_efficient_sdp(True)
            torch.backends.cuda.enable_math_sdp(False)
            print("Enabled FlashAttention / Memory-Efficient SDPA backends")
        except Exception as e:
            print(f"Could not configure SDPA backends: {e}")
    return model
def _is_nf4_checkpoint(self, path: str) -> bool:
"""Check if path contains an NF4 quantized checkpoint"""
if os.path.isdir(path):
return os.path.exists(os.path.join(path, "quantization_config.json"))
return False
def _load_nf4_checkpoint(self, checkpoint_dir: str) -> dict:
    """
    Load NF4 quantized checkpoint and dequantize to float tensors.

    Each quantized tensor is stored as four entries sharing a base key:
    ``<base>.quant_data`` (two 4-bit codes per element), ``<base>.scales``
    (one scale per block), ``<base>.shape`` and ``<base>.pad_len``.

    Args:
        checkpoint_dir: Directory containing NF4 model files
    Returns:
        Dequantized state dict
    Raises:
        KeyError: a quantized tensor is missing one of its companion entries
    """
    # Load quantization config
    config_path = os.path.join(checkpoint_dir, "quantization_config.json")
    with open(config_path, 'r') as f:
        quant_config = json.load(f)
    block_size = quant_config.get("block_size", 64)
    quantized_keys = set(quant_config.get("quantized_keys", []))

    # Load index
    index_path = os.path.join(checkpoint_dir, "model_nf4.safetensors.index.json")
    with open(index_path, 'r') as f:
        index = json.load(f)

    # Load all shards into one flat dict
    shard_files = sorted(set(index['weight_map'].values()))
    print(f"Loading {len(shard_files)} NF4 shards...")
    raw_state = {}
    for shard_file in shard_files:
        print(f" Loading {shard_file}...")
        raw_state.update(load_safetensors(os.path.join(checkpoint_dir, shard_file)))

    # NF4 lookup table for dequantization (4-bit code -> float value)
    nf4_values = torch.tensor([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230850220, 0.33791524171829224,
        0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
    ], dtype=torch.float32)

    quant_suffix = '.quant_data'
    metadata_suffixes = ('.scales', '.shape', '.block_size', '.pad_len')
    state_dict = {}
    dequant_count = 0
    for key, tensor in raw_state.items():
        if key.endswith(quant_suffix):
            # Strip only the trailing suffix; str.replace would also remove
            # an identical substring elsewhere in the key
            base_key = key[:-len(quant_suffix)]
            if base_key not in quantized_keys:
                # Not listed in the config: skipped, as before
                continue
            scales = raw_state[f"{base_key}.scales"]
            shape = raw_state[f"{base_key}.shape"].tolist()
            pad_len = raw_state[f"{base_key}.pad_len"].item()
            # Unpack 4-bit values: high nibble first, then low nibble
            high = (tensor >> 4) & 0x0F
            low = tensor & 0x0F
            indices = torch.stack([high, low], dim=-1).flatten().long()
            # Lookup, then scale block by block
            values = nf4_values[indices]
            num_blocks = len(scales)
            values = values[:num_blocks * block_size].reshape(num_blocks, block_size)
            values = (values * scales.float().unsqueeze(1)).flatten()
            # Remove padding and reshape
            if pad_len > 0:
                values = values[:-pad_len]
            state_dict[base_key] = values.reshape(shape)
            dequant_count += 1
        elif not key.endswith(metadata_suffixes):
            # Non-quantized tensor, keep as-is
            state_dict[key] = tensor
    print(f"Dequantized {dequant_count} tensors")
    return state_dict
def _quantize_model_bnb(self, model):
    """
    Swap nn.Linear children for bitsandbytes Linear4bit (NF4) layers.

    The module tree is walked from *model* downwards; every direct child
    that is an nn.Linear is replaced in place, every other child is
    descended into. Returns the same model object.
    """
    import bitsandbytes as bnb
    import torch.nn as nn

    def build_4bit_layer(linear):
        # Same geometry as the source layer, NF4 storage, bf16 compute
        has_bias = linear.bias is not None
        layer = bnb.nn.Linear4bit(
            linear.in_features,
            linear.out_features,
            bias=has_bias,
            compute_dtype=torch.bfloat16,
            compress_statistics=True,
            quant_type='nf4'
        )
        # Copy weights (will be quantized when moved to GPU)
        layer.weight = bnb.nn.Params4bit(
            linear.weight.data,
            requires_grad=False,
            quant_type='nf4'
        )
        if has_bias:
            layer.bias = nn.Parameter(linear.bias.data)
        return layer

    print("Replacing Linear layers with Linear4bit...")
    pending = [model]
    while pending:
        parent = pending.pop()
        # Snapshot the children because setattr mutates the parent
        for child_name, child in list(parent.named_children()):
            if isinstance(child, nn.Linear):
                setattr(parent, child_name, build_4bit_layer(child))
            else:
                pending.append(child)
    return model
def _init_deepspeed(self, model):
    """
    Wrap the model with ``deepspeed.init_inference``.

    A single-process NCCL process group is created first when
    torch.distributed is not yet initialized. The config file must exist
    and be valid JSON, but its contents are not passed to DeepSpeed here.

    Args:
        model: PyTorch model to wrap with DeepSpeed
    Returns:
        DeepSpeed inference engine
    Raises:
        ImportError: DeepSpeed is not installed
        FileNotFoundError: the DeepSpeed config file does not exist
        RuntimeError: the process group could not be initialized
    """
    try:
        import deepspeed
    except ImportError as err:
        raise ImportError(
            "DeepSpeed is not installed. Install it with: pip install deepspeed"
        ) from err

    # Resolve DeepSpeed config path
    if self.deepspeed_config is None:
        self.deepspeed_config = "ds_config_zero2.json"
    if not os.path.exists(self.deepspeed_config):
        raise FileNotFoundError(f"DeepSpeed config not found: {self.deepspeed_config}")
    print(f"Initializing DeepSpeed Inference with config: {self.deepspeed_config}")

    # Initialize distributed environment for single GPU if not already initialized
    if not torch.distributed.is_initialized():
        import random
        # Set environment variables for single-process mode
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['RANK'] = '0'
        os.environ['LOCAL_RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        # Random ports avoid clashes; one retry from a second range is
        # allowed when the first port turns out to be taken
        port_ranges = ((29500, 29600), (29600, 29700))
        last_attempt = len(port_ranges) - 1
        for attempt, (low, high) in enumerate(port_ranges):
            port = random.randint(low, high)
            os.environ['MASTER_PORT'] = str(port)
            try:
                torch.distributed.init_process_group(
                    backend='nccl',
                    init_method='env://',
                    world_size=1,
                    rank=0
                )
            except RuntimeError as e:
                if attempt == last_attempt or "address already in use" not in str(e):
                    raise
                print(f"Port {port} in use, trying again...")
            else:
                print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}")
                break

    # Parse the config so a malformed file fails early; init_inference
    # below does not receive it
    with open(self.deepspeed_config) as f:
        json.load(f)

    # Use DeepSpeed inference API instead of initialize
    # This doesn't require an optimizer
    model_engine = deepspeed.init_inference(
        model=model,
        mp_size=1,  # model parallel size
        dtype=torch.float32,  # Use float32 for compatibility
        replace_with_kernel_inject=False,  # Don't replace with DeepSpeed kernels for custom models
    )
    print("DeepSpeed Inference initialized successfully")
    return model_engine
def _load_checkpoint_file(self, checkpoint_path: str) -> dict:
"""
Load checkpoint file and extract state dict.
Supports: sharded safetensors, single safetensors, .bin/.pt files
Args:
checkpoint_path: Path to checkpoint file or index.json
Returns:
state_dict: model state dictionary
"""
# Check if it's a sharded safetensors (index.json file)
if checkpoint_path.endswith('.index.json'):
print(f"Loading sharded safetensors from index: {checkpoint_path}")
with open(checkpoint_path, 'r') as f:
index = json.load(f)
# Get the directory containing the shards
shard_dir = os.path.dirname(checkpoint_path)
# Get unique shard files
shard_files = sorted(set(index['weight_map'].values()))
print(f"Loading {len(shard_files)} shard files in parallel...")
# Load shards in parallel using ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
def load_shard(shard_file):
shard_path = os.path.join(shard_dir, shard_file)
return shard_file, load_safetensors(shard_path)
state_dict = {}
with ThreadPoolExecutor(max_workers=len(shard_files)) as executor:
futures = {executor.submit(load_shard, sf): sf for sf in shard_files}
for future in as_completed(futures):
shard_file, shard_dict = future.result()
print(f" Loaded {shard_file}")
state_dict.update(shard_dict)
print(f"Loaded {len(state_dict)} tensors from sharded safetensors")
return state_dict
# Check if it's a single safetensors file
if checkpoint_path.endswith('.safetensors'):
print(f"Loading safetensors: {checkpoint_path}")
state_dict = load_safetensors(checkpoint_path)
return state_dict
# Check if it's a directory containing checkpoint files
if os.path.isdir(checkpoint_path):
# Look for index.json first (sharded safetensors)
index_path = os.path.join(checkpoint_path, 'model.safetensors.index.json')
if os.path.exists(index_path):
return self._load_checkpoint_file(index_path)
# Look for common checkpoint filenames
possible_files = [
'model.safetensors',
'model.pt', 'model.pth', 'model.bin',
'checkpoint.pt', 'checkpoint.pth',
'pytorch_model.bin', 'model_state_dict.pt'
]
checkpoint_file = None
for filename in possible_files:
full_path = os.path.join(checkpoint_path, filename)
if os.path.exists(full_path):
checkpoint_file = full_path
print(f"Found checkpoint file: {filename}")
break
if checkpoint_file is None:
import glob
# Try safetensors first
st_files = glob.glob(os.path.join(checkpoint_path, "*.safetensors"))
if st_files:
checkpoint_file = st_files[0]
else:
pt_files = glob.glob(os.path.join(checkpoint_path, "*.pt")) + \
glob.glob(os.path.join(checkpoint_path, "*.pth")) + \
glob.glob(os.path.join(checkpoint_path, "*.bin"))
if pt_files:
checkpoint_file = pt_files[0]
else:
raise ValueError(f"No checkpoint files found in directory: {checkpoint_path}")
print(f"Found checkpoint file: {os.path.basename(checkpoint_file)}")
checkpoint_path = checkpoint_file
# Recursively call to handle the found file
return self._load_checkpoint_file(checkpoint_path)
# Load .bin or .pt checkpoint
print(f"Loading checkpoint file: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Handle different checkpoint formats
if isinstance(checkpoint, dict):
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if 'epoch' in checkpoint:
print(f"Checkpoint from epoch: {checkpoint['epoch']}")
if 'global_step' in checkpoint:
print(f"Checkpoint from step: {checkpoint['global_step']}")
if 'loss' in checkpoint:
print(f"Checkpoint loss: {checkpoint['loss']:.4f}")
else:
state_dict = checkpoint
# Remove 'module.' prefix if present
if any(key.startswith('module.') for key in state_dict.keys()):
state_dict = {key.replace('module.', ''): value
for key, value in state_dict.items()}
print("Removed 'module.' prefix from state dict keys")
return state_dict
def text_to_cond_image(
    self,
    text: str,
    img_size: int = 128,
    font_scale: float = 0.8,
    font_path: Optional[str] = None,
    fixed_chars: int = 7
) -> Image.Image:
    """
    Render text as a vertical condition image with a fixed number of slots.

    Characters are drawn top to bottom in black on a white canvas, one per
    ``img_size`` x ``img_size`` slot; unused slots stay blank.

    Args:
        text: Chinese text to convert (must be <= fixed_chars characters)
        img_size: size of each character block (default 128)
        font_scale: scale of font relative to image size (default 0.8)
        font_path: path to font file (defaults to self.font_path)
        fixed_chars: fixed number of character slots (default 7)
    Returns:
        PIL Image of size (img_size, img_size * fixed_chars)
    Raises:
        ValueError: text is longer than fixed_chars
    """
    if len(text) > fixed_chars:
        raise ValueError(f"Text must be at most {fixed_chars} characters, got {len(text)}")
    # Glyph size is the slot size scaled down by font_scale
    glyph_px = int(font_scale * img_size)
    font = ImageFont.truetype(self.font_path if font_path is None else font_path, glyph_px)
    # White canvas: one slot wide, fixed_chars slots tall
    canvas = Image.new("RGB", (img_size, img_size * fixed_chars), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)
    # Margin inside each slot; identical for every character
    margin = glyph_px * (1 - font_scale) // 2
    for slot, char in enumerate(text):
        # Slots advance by img_size, not by the scaled glyph size
        pen.text((margin, img_size * slot + margin), char, font=font, fill=(0, 0, 0))
    return canvas
def build_prompt(
    self,
    font_style: str = "楷",
    author: str = None,
    is_traditional: bool = True,
) -> str:
    """
    Compose the text prompt for one generation request.

    Args:
        font_style: font style; must be a key of self.font_style_des
        author: author name (a key of self.author_style) or None
        is_traditional: whether generating traditional calligraphy
    Returns:
        The "Traditional ..." prompt including the author description when
        is_traditional is set and the author is known; otherwise the
        "Synthetic ..." prompt
    Raises:
        ValueError: font_style is not a known font style
    """
    if font_style not in self.font_style_des:
        raise ValueError(f"Font style must be one of: {list(self.font_style_des.keys())}")
    font_style_pinyin = convert_to_pinyin(font_style)
    style_description = self.font_style_des[font_style]
    # The author clause is only used for traditional works by a known author
    use_author = is_traditional and author and author in self.author_style
    if not use_author:
        return (
            f"Synthetic calligraphy data, background: black, font: {font_style_pinyin}, "
            + style_description
        )
    return (
        f"Traditional Chinese calligraphy works, background: black, font: {font_style_pinyin}, "
        + style_description
        + f" author: {self.author_style[author]}"
    )
@torch.no_grad()
def generate(
    self,
    text: str,
    font_style: str = "楷",
    author: str = None,
    width: int = 128,
    height: int = None,  # Fixed to 7 characters height
    num_steps: int = 50,
    guidance: float = 3.5,
    seed: int = None,
    is_traditional: bool = None,
    save_path: Optional[str] = None
) -> tuple[Image.Image, Image.Image]:
    """
    Generate one calligraphy image for a short text.

    The sampler always runs on a 7-slot canvas; when the text is shorter,
    both returned images are cropped to the slots actually used.

    Args:
        text: Chinese text to generate (1-7 characters)
        font_style: font style (楷/草/行)
        author: author/calligrapher name from the style list
        width: image width (default 128)
        height: ignored; the height is always set to 7 * width
        num_steps: number of denoising steps
        guidance: guidance scale (accepted, but not passed to the sampler
            call in this method)
        seed: random seed for generation (drawn at random when None)
        is_traditional: whether generating traditional calligraphy
            (auto-determined from the author when None)
        save_path: optional path to save the generated image
    Returns:
        tuple of (generated_image, condition_image)
    Raises:
        ValueError: text is empty or longer than 7 characters
    """
    FIXED_CHARS = 7
    num_chars = len(text)
    if num_chars < 1:
        raise ValueError("Text must have at least 1 character, got empty string")
    if num_chars > FIXED_CHARS:
        raise ValueError(f"Text must have at most {FIXED_CHARS} characters, got {num_chars}")
    if seed is None:
        seed = torch.randint(0, 2**32, (1,)).item()
    # The canvas is always tall enough for FIXED_CHARS characters
    height = width * FIXED_CHARS
    if is_traditional is None:
        # A known author implies a traditional work
        is_traditional = author is not None and author in self.author_style

    cond_img = self.text_to_cond_image(text, img_size=width, fixed_chars=FIXED_CHARS)
    prompt = self.build_prompt(
        font_style=font_style,
        author=author,
        is_traditional=is_traditional,
    )
    print(f"Generating with prompt: {prompt}")
    print(f"Text: {text} ({num_chars} chars), Seed: {seed}")

    # The second return value of the sampler is not used here
    result_img, _recognized_text = self.sampler(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        controlnet_image=cond_img,
        is_generation=True,
        cond_text=text,
        required_chars=FIXED_CHARS,  # Always 7 characters
        seed=seed
    )

    if num_chars < FIXED_CHARS:
        # Keep only the top slots that actually hold characters
        crop_box = (0, 0, width, width * num_chars)
        result_img = result_img.crop(crop_box)
        cond_img = cond_img.crop(crop_box)

    if save_path:
        target_dir = os.path.dirname(save_path) or "."
        os.makedirs(target_dir, exist_ok=True)
        result_img.save(save_path)
        print(f"Image saved to {save_path}")
    return result_img, cond_img
def batch_generate(
    self,
    texts: List[str],
    font_styles: Optional[List[str]] = None,
    authors: Optional[List[str]] = None,
    output_dir: str = "./outputs",
    **kwargs
) -> List[tuple[Image.Image, Image.Image]]:
    """
    Generate and save one image per entry of *texts*.

    The three lists are walked in lockstep with zip(), so iteration stops
    at the shortest one. Each image is saved in output_dir as
    ``{text}_{font}_{author_name}_{index}.png``.

    Args:
        texts: list of texts to generate (1-7 characters each)
        font_styles: list of font styles (if None, "楷" for every text)
        authors: list of authors (if None, synthetic for every text)
        output_dir: directory to save outputs (created if missing)
        **kwargs: additional arguments for generate()
    Returns:
        list of (generated_image, condition_image) tuples
    """
    os.makedirs(output_dir, exist_ok=True)
    if font_styles is None:
        font_styles = ["楷"] * len(texts)
    if authors is None:
        authors = [None] * len(texts)

    results = []
    for index, (text, font, author) in enumerate(zip(texts, font_styles, authors)):
        # File-name label: pinyin for known authors, the raw name for
        # unknown ones, "synthetic" when no author is given
        if author and author in self.author_style:
            author_label = convert_to_pinyin(author)
        elif author:
            author_label = author
        else:
            author_label = "synthetic"
        generated, condition = self.generate(
            text=text,
            font_style=font,
            author=author,
            save_path=os.path.join(output_dir, f"{text}_{font}_{author_label}_{index}.png"),
            **kwargs
        )
        results.append((generated, condition))
    return results
def get_available_authors(self) -> List[str]:
    """Return the author names (keys of self.author_style) as a new list."""
    return [*self.author_style.keys()]
def get_available_fonts(self) -> List[str]:
    """Return the font style names (keys of self.font_style_des) as a new list."""
    return [*self.font_style_des.keys()]
# Hugging Face Pipeline wrapper
class FluxCalligraphyPipeline:
    """Pipeline-style callable wrapper around CalligraphyGenerator."""

    def __init__(
        self,
        model_name: str = "flux-dev",
        device: str = "cuda",
        checkpoint_path: Optional[str] = None,
        **kwargs
    ):
        """Build the underlying generator; extra kwargs are passed to it."""
        self.generator = CalligraphyGenerator(
            model_name=model_name,
            device=device,
            checkpoint_path=checkpoint_path,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        font_style: Union[str, List[str]] = "楷",
        author: Union[str, List[str]] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.5,
        generator: Optional[torch.Generator] = None,
        **kwargs
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate calligraphy images
        Args:
            text: text or list of texts to generate (1-7 characters each)
            font_style: font style(s) (楷/草/行); a single string is
                repeated for every text in a batch
            author: author name(s) from the style list; a single string or
                None is repeated for every text in a batch
            num_inference_steps: number of denoising steps
            guidance_scale: guidance scale for generation
            generator: torch generator for reproducibility; only its
                initial seed is used, for every image
            **kwargs: additional arguments for CalligraphyGenerator.generate()
        Returns:
            one image for a string input, a list of images for a list input
        """
        def run_one(single_text, single_font, single_author):
            # The generator contributes its initial seed and nothing else
            seed = generator.initial_seed() if generator is not None else None
            image, _ = self.generator.generate(
                text=single_text,
                font_style=single_font,
                author=single_author,
                num_steps=num_inference_steps,
                guidance=guidance_scale,
                seed=seed,
                **kwargs
            )
            return image

        # Handle single text
        if isinstance(text, str):
            return run_one(text, font_style, author)

        # Handle batch: broadcast scalar settings to one entry per text
        if isinstance(font_style, str):
            font_style = [font_style] * len(text)
        if author is None or isinstance(author, str):
            author = [author] * len(text)
        return [run_one(t, f, a) for t, f, a in zip(text, font_style, author)]
if __name__ == "__main__":
    # Example usage: command-line entry point for a single generation
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Generate Chinese calligraphy")
    parser.add_argument("--text", type=str, default="暴富且平安", help="Text to generate (1-7 characters)")
    parser.add_argument("--font", type=str, default="楷", help="Font style (楷/草/行)")
    parser.add_argument("--author", type=str, default=None, help="Author/calligrapher name")
    parser.add_argument("--steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--output", type=str, default="output.png", help="Output path")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
    parser.add_argument("--list-authors", action="store_true", help="List available authors")
    parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
    parser.add_argument("--float8", action="store_true", help="Use Float8 quantization (torchao) for faster inference")
    parser.add_argument("--compile", action="store_true", help="Use torch.compile for optimized inference")
    parser.add_argument("--compile-mode", type=str, default="max-autotune",
                        choices=["reduce-overhead", "max-autotune", "default"],
                        help="torch.compile mode")
    args = parser.parse_args()

    # Validate text (1-7 characters) before the expensive model load.
    # Listing modes never generate, so they skip this check.
    if not (args.list_authors or args.list_fonts):
        if len(args.text) < 1:
            print("Error: Text must have at least 1 character")
            sys.exit(1)
        if len(args.text) > 7:
            print(f"Error: Text must have at most 7 characters, got {len(args.text)}")
            sys.exit(1)

    # Initialize generator
    generator = CalligraphyGenerator(
        model_name="flux-dev",
        device=args.device,
        checkpoint_path=args.checkpoint,
    )

    # List available options
    if args.list_authors:
        available_authors = generator.get_available_authors()
        print("Available authors:")
        for author in available_authors[:20]:  # Show first 20
            print(f" - {author}")
        # Only mention a remainder when there actually is one
        remaining = len(available_authors) - 20
        if remaining > 0:
            print(f" ... and {remaining} more")
        sys.exit(0)
    if args.list_fonts:
        print("Available font styles:")
        for font in generator.get_available_fonts():
            print(f" - {font}: {generator.font_style_des[font]}")
        sys.exit(0)

    # Apply optimizations if requested (CLI mode)
    if args.float8 or args.compile:
        import torch._inductor.config as inductor_config
        # Inductor configs from FLUX-Kontext-fp8
        inductor_config.conv_1x1_as_mm = True
        inductor_config.coordinate_descent_tuning = True
        inductor_config.coordinate_descent_check_all_directions = True
        inductor_config.max_autotune = True
        if args.float8:
            # torchao is only needed for Float8, so --compile alone does
            # not require it to be installed
            from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
            print("Applying Float8 quantization...")
            quantize_(generator.model, Float8DynamicActivationFloat8WeightConfig())
            print("✓ Float8 quantization complete!")
        if args.compile:
            print(f"Applying torch.compile (mode={args.compile_mode})...")
            generator.model = torch.compile(
                generator.model,
                mode=args.compile_mode,
                backend="inductor",
                dynamic=True,
            )
            print("✓ torch.compile applied!")

    # Generate
    result_img, cond_img = generator.generate(
        text=args.text,
        font_style=args.font,
        author=args.author,
        num_steps=args.steps,
        seed=args.seed,
        save_path=args.output
    )
    print(f"Generation complete! Saved to {args.output}")