# -*- coding: utf-8 -*-
"""
Chinese Calligraphy Generation with Flux Model
Author and font style controllable generation
"""
import os
import json
import torch
from safetensors.torch import load_file as load_safetensors
from optimum.quanto import quantize, freeze, qint4
from PIL import Image, ImageDraw, ImageFont
from typing import Optional, List, Union, Dict, Any
from einops import rearrange
from pypinyin import lazy_pinyin
from huggingface_hub import hf_hub_download, snapshot_download

from src.flux.util import configs, load_ae, load_clip, load_t5
from src.flux.model import Flux
from src.flux.xflux_pipeline import XFluxSampler

# HuggingFace Hub model IDs
HF_MODEL_ID = "TSXu/Unicalli_Pro"
HF_CHECKPOINT_INDEX = "model.safetensors.index.json"  # Sharded safetensors index
HF_INTERNVL_ID = "OpenGVLab/InternVL3-1B"


def download_sharded_safetensors(
    model_id: str = HF_MODEL_ID,
    local_dir: Optional[str] = None,
    force_download: bool = False
) -> str:
    """
    Download a sharded safetensors model from the HuggingFace Hub.

    Downloads the index file first, reads the shard list from its
    ``weight_map``, then downloads each shard.

    Args:
        model_id: HuggingFace model repository ID
        local_dir: Local directory to save files (optional)
        force_download: Whether to force re-download

    Returns:
        Path to the downloaded index.json file

    Raises:
        Exception: re-raises any download/parse error after logging it.
    """
    print(f"Downloading sharded safetensors from HuggingFace Hub ({model_id})...")

    # Get HF token from environment for private repos
    hf_token = os.environ.get("HF_TOKEN", None)

    try:
        # First download the index file
        index_path = hf_hub_download(
            repo_id=model_id,
            filename=HF_CHECKPOINT_INDEX,
            local_dir=local_dir,
            force_download=force_download,
            token=hf_token
        )
        print(f"Index downloaded to: {index_path}")

        # Read index to get shard filenames
        with open(index_path, 'r') as f:
            index = json.load(f)

        # Get unique shard files (many tensors map to the same shard)
        shard_files = set(index['weight_map'].values())
        print(f"Downloading {len(shard_files)} shard files...")

        # Download all shards (sorted for deterministic, readable progress)
        for shard_file in sorted(shard_files):
            print(f"  Downloading {shard_file}...")
            hf_hub_download(
                repo_id=model_id,
                filename=shard_file,
                local_dir=local_dir,
                force_download=force_download,
                token=hf_token
            )

        print(f"All shards downloaded!")
        return index_path

    except Exception as e:
        print(f"Error downloading model: {e}")
        raise


def is_huggingface_repo_id(path: str) -> bool:
    """
    Check if a string looks like a HuggingFace repo ID (e.g., 'namespace/repo_name')
    NOT a local file path.
    """
    # HF repo IDs have format: namespace/repo_name (exactly one /)
    # Local paths typically have multiple / or start with / or .
    if path.startswith('/') or path.startswith('.') or path.startswith('~'):
        return False
    parts = path.split('/')
    # HF repo ID should have exactly 2 non-empty parts, neither hidden-file-like
    if len(parts) == 2 and all(part and not part.startswith('.') for part in parts):
        return True
    return False


def ensure_checkpoint_exists(checkpoint_path: str) -> str:
    """
    Ensure checkpoint exists locally, download from HF Hub if not.

    Args:
        checkpoint_path: Local path or HF model ID

    Returns:
        Path to the local checkpoint/index file

    Raises:
        FileNotFoundError: if the path neither exists locally nor looks
            like a HuggingFace repo ID.
    """
    # If it's a local path and exists, return it
    if os.path.exists(checkpoint_path):
        print(f"Using local checkpoint: {checkpoint_path}")
        return checkpoint_path

    # If it looks like a HuggingFace repo ID (e.g., "TSXu/Unicalli_Pro")
    if is_huggingface_repo_id(checkpoint_path):
        print(f"Downloading from HuggingFace Hub: {checkpoint_path}")
        return download_sharded_safetensors(model_id=checkpoint_path)

    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")


def convert_to_pinyin(text):
    """Convert Chinese text to space-separated pinyin (e.g. '楷' -> 'kai')."""
    # lazy_pinyin normally yields strings; the isinstance guard is defensive
    # in case an element comes back as a list of candidates.
    return ' '.join([item[0] if isinstance(item, list) else item
                     for item in lazy_pinyin(text)])


class CalligraphyGenerator:
    """
    Chinese Calligraphy Generator using Flux model

    Attributes:
        device: torch device for computation
        model_name: name of the flux model (flux-dev or flux-schnell)
        font_styles: available font styles for generation
        authors: available calligrapher authors
    """

    def __init__(
        self,
        model_name: str = "flux-dev",
        device: str = "cuda",
        offload: bool = True,
        checkpoint_path: Optional[str] = None,
        intern_vlm_path: Optional[str] = None,
        ref_latent_path: Optional[str] = None,
        font_descriptions_path: str = "chirography.json",
        author_descriptions_path: str = "calligraphy_styles_en.json",
        use_deepspeed: bool = False,
        use_4bit_quantization: bool = False,
        use_float8_quantization: bool = False,
        use_torch_compile: bool = False,
        compile_mode: str = "reduce-overhead",
        deepspeed_config: Optional[str] = None,
        dtype: Optional[str] = None,
        preloaded_embedding: Optional[torch.nn.Module] = None,
        preloaded_tokenizer: Optional[Any] = None,
    ):
        """
        Initialize the calligraphy generator.

        Args:
            model_name: flux model name (flux-dev or flux-schnell)
            device: device for computation
            offload: whether to offload model to CPU when not in use
            checkpoint_path: path to model checkpoint if using fine-tuned model
            intern_vlm_path: path to InternVLM model for text embedding
            ref_latent_path: path to reference latents for recognition mode
            font_descriptions_path: path to font style descriptions JSON
            author_descriptions_path: path to author style descriptions JSON
            use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
            use_4bit_quantization: whether to use 4-bit quantization (quanto/bitsandbytes)
            use_float8_quantization: whether to use Float8 quantization (torchao) for faster inference
            use_torch_compile: whether to use torch.compile for optimized inference
            compile_mode: torch.compile mode - "reduce-overhead", "max-autotune", or "default"
            deepspeed_config: path to DeepSpeed config JSON file
            dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto

        Raises:
            FileNotFoundError: if the font/author description JSON files are missing.
        """
        self.device = torch.device(device)
        self.model_name = model_name
        self.offload = offload
        self.is_schnell = model_name == "flux-schnell"
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = deepspeed_config
        self.use_4bit_quantization = use_4bit_quantization
        self.use_float8_quantization = use_float8_quantization
        self.use_torch_compile = use_torch_compile
        self.compile_mode = compile_mode
        self.forced_dtype = dtype  # "fp16", "bf16", "fp32", or None for auto

        # Load font and author style descriptions
        if os.path.exists(font_descriptions_path):
            with open(font_descriptions_path, 'r', encoding='utf-8') as f:
                self.font_style_des = json.load(f)
        else:
            raise FileNotFoundError(f"Font descriptions file not found: {font_descriptions_path}")

        if os.path.exists(author_descriptions_path):
            with open(author_descriptions_path, 'r', encoding='utf-8') as f:
                self.author_style = json.load(f)
        else:
            raise FileNotFoundError(f"Author descriptions file not found: {author_descriptions_path}")

        # Load models
        print("Loading models...")

        # When using DeepSpeed, load text encoders on CPU first to save memory during initialization
        # They will be moved to GPU after DeepSpeed initializes the main model
        if self.use_deepspeed:
            text_encoder_device = "cpu"
        elif offload:
            text_encoder_device = "cpu"  # Will be moved to GPU during inference
        else:
            text_encoder_device = self.device

        self.t5 = load_t5(text_encoder_device, max_length=256 if self.is_schnell else 512)
        self.clip = load_clip(text_encoder_device)
        self.clip.requires_grad_(False)

        # Ensure checkpoint exists (download from HF Hub if needed)
        if checkpoint_path:
            checkpoint_path = ensure_checkpoint_exists(checkpoint_path)
            print(f"Loading model from checkpoint: {checkpoint_path}")
            # When using DeepSpeed, don't move to GPU yet - let DeepSpeed handle it
            self.model = self._load_model_from_checkpoint(
                checkpoint_path, model_name, offload=offload,
                use_deepspeed=self.use_deepspeed
            )
            # Initialize DeepSpeed if requested
            if self.use_deepspeed:
                self.model = self._init_deepspeed(self.model)
        else:
            # If no checkpoint path provided, download default from HF Hub.
            # BUGFIX: previously called undefined download_model_from_hf();
            # download_sharded_safetensors() (defaults to HF_MODEL_ID) is the
            # function that actually exists for this purpose.
            print("No checkpoint path provided, downloading from HuggingFace Hub...")
            checkpoint_path = download_sharded_safetensors()
            print(f"Loading model from checkpoint: {checkpoint_path}")
            self.model = self._load_model_from_checkpoint(
                checkpoint_path, model_name, offload=offload,
                use_deepspeed=self.use_deepspeed
            )
            if self.use_deepspeed:
                self.model = self._init_deepspeed(self.model)

        # Note: Float8 quantization and torch.compile optimizations
        # are applied externally (e.g., in app.py) for better control
        # over the optimization process with ZeroGPU AOT compilation.

        # Load VAE
        if self.use_deepspeed or offload:
            vae_device = "cpu"
        else:
            vae_device = self.device
        self.vae = load_ae(model_name, device=vae_device)

        # Move VAE to GPU only if offload (not DeepSpeed).
        # NOTE(review): moving to GPU when offload=True looks contradictory to
        # the offload intent, but it is the original behavior — confirm before
        # changing.
        if offload and not self.use_deepspeed:
            self.vae = self.vae.to(self.device)

        # After DeepSpeed init, move text encoders to GPU
        if self.use_deepspeed:
            print("Moving text encoders to GPU...")
            self.t5 = self.t5.to(self.device)
            self.clip = self.clip.to(self.device)
            self.vae = self.vae.to(self.device)

        # Load reference latents if provided (used by recognition mode)
        self.ref_latent = None
        if ref_latent_path and os.path.exists(ref_latent_path):
            print(f"Loading reference latents from {ref_latent_path}")
            self.ref_latent = torch.load(ref_latent_path, map_location='cpu')

        # Create sampler (use preloaded embedding if available)
        self.sampler = XFluxSampler(
            clip=self.clip,
            t5=self.t5,
            ae=self.vae,
            ref_latent=self.ref_latent,
            model=self.model,
            device=self.device,
            intern_vlm_path=intern_vlm_path,
            preloaded_embedding=preloaded_embedding,
            preloaded_tokenizer=preloaded_tokenizer,
        )

        # Font for generating condition images
        project_root = os.path.dirname(os.path.abspath(__file__))
        local_font_path = os.path.join(project_root, "FangZhengKaiTiFanTi-1.ttf")
        self.font_path = self._ensure_font_exists(local_font_path)
        self.default_font_size = 102  # 128 * 0.8

    def _ensure_font_exists(self, font_path: str) -> str:
        """
        Ensure font file exists locally, download from HF Hub if not.

        Lookup order: UNICALLI_FONT_PATH env var -> the given local path ->
        download from the HF model repo.

        Args:
            font_path: Local path to font file

        Returns:
            Path to the local font file (may be the original, possibly
            missing, path if the download fails — rendering will then fail
            later at ImageFont.truetype).
        """
        cached_font_path = os.environ.get("UNICALLI_FONT_PATH")
        if cached_font_path and os.path.exists(cached_font_path):
            return cached_font_path

        if os.path.exists(font_path):
            return font_path

        # Try to download from HF Hub
        print(f"Font file not found locally, downloading from HuggingFace Hub...")
        hf_token = os.environ.get("HF_TOKEN", None)
        try:
            font_path = hf_hub_download(
                repo_id=HF_MODEL_ID,
                filename="FangZhengKaiTiFanTi-1.ttf",
                token=hf_token
            )
            print(f"Font downloaded to: {font_path}")
            return font_path
        except Exception as e:
            print(f"Warning: Could not download font: {e}")
            return font_path  # Return original path, may fail later

    def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str,
                                    offload: bool, use_deepspeed: bool = False):
        """
        Load model from checkpoint without loading flux pretrained weights.
        Supports both regular checkpoints and NF4 quantized checkpoints.

        Args:
            checkpoint_path: Path to your checkpoint file or NF4 model directory
            model_name: flux model name (for config)
            offload: whether to offload to CPU
            use_deepspeed: whether using DeepSpeed (keeps model on CPU)

        Returns:
            model with loaded checkpoint
        """
        print(f"Creating empty flux model structure...")
        load_device = "cpu"

        # Create model structure without loading pretrained weights (using "meta" device)
        with torch.device("meta"):
            model = Flux(configs[model_name].params)

        # Initialize module embeddings (must be done before loading checkpoint)
        print("Initializing module embeddings...")
        model.init_module_embeddings(tokens_num=320, cond_txt_channel=896)

        # Move model to loading device (to_empty allocates real storage for meta tensors)
        print(f"Moving model to {load_device} for loading...")
        model = model.to_empty(device=load_device)

        # Check if this is an NF4 quantized model
        is_nf4 = self._is_nf4_checkpoint(checkpoint_path)

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}")
        if is_nf4:
            print("Detected NF4 quantized model, dequantizing...")
            checkpoint = self._load_nf4_checkpoint(checkpoint_path)
        else:
            checkpoint = self._load_checkpoint_file(checkpoint_path)

        # Determine dtype from checkpoint - keep original dtype for efficiency.
        # Assumes all tensors in the checkpoint share the first tensor's dtype.
        first_tensor = next(iter(checkpoint.values()))
        checkpoint_dtype = first_tensor.dtype
        print(f"Checkpoint dtype: {checkpoint_dtype}")

        # Check if user forced a specific dtype
        forced_dtype = getattr(self, 'forced_dtype', None)
        if forced_dtype:
            dtype_map = {
                "fp16": torch.float16,
                "bf16": torch.bfloat16,
                "fp32": torch.float32,
                "fp8": torch.float8_e4m3fn,
            }
            if forced_dtype not in dtype_map:
                print(f"Warning: Unknown dtype '{forced_dtype}', using auto selection")
                forced_dtype = None
            else:
                target_dtype = dtype_map[forced_dtype]
                print(f"Using forced dtype: {target_dtype}")
                if checkpoint_dtype != target_dtype:
                    print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...")
                    checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}

        if not forced_dtype:
            # Note: We trust the original precision (like FP8) if it is provided that way
            target_dtype = checkpoint_dtype
            print(f"Using auto-detected checkpoint dtype: {target_dtype} for inference loading")

        # Load weights into model (assign=True keeps the checkpoint tensors
        # instead of copying into the to_empty()-allocated storage)
        model.load_state_dict(checkpoint, strict=False, assign=True)
        print(f"Model dtype after loading: {next(model.parameters()).dtype}")

        # Store target dtype for inference
        self._model_dtype = target_dtype

        # Free checkpoint memory
        del checkpoint

        # Apply bitsandbytes 4-bit quantization if requested
        if hasattr(self, 'use_4bit_quantization') and self.use_4bit_quantization:
            try:
                import bitsandbytes as bnb
                print("Applying bitsandbytes NF4 quantization for 4-bit inference...")
                model = self._quantize_model_bnb(model)
                model._is_quantized = True
                print("bitsandbytes NF4 quantization complete!")
            except ImportError:
                # Fallback: optimum-quanto int4 quantization on a float copy
                print("bitsandbytes not available, using quanto quantization...")
                model = model.float()
                quantize(model, weights=qint4)
                freeze(model)
                model._is_quantized = True
                print("quanto 4-bit quantization complete!")

        # Move to GPU only if NOT using DeepSpeed
        if not use_deepspeed:
            if self.device.type != "cpu":
                print(f"Moving model to {self.device}...")
                model = model.to(self.device)

            # Enable optimized attention backends
            try:
                torch.backends.cuda.enable_flash_sdp(True)
                torch.backends.cuda.enable_mem_efficient_sdp(True)
                torch.backends.cuda.enable_math_sdp(False)
                print("Enabled FlashAttention / Memory-Efficient SDPA backends")
            except Exception as e:
                print(f"Could not configure SDPA backends: {e}")

        return model

    def _is_nf4_checkpoint(self, path: str) -> bool:
        """Check if path contains an NF4 quantized checkpoint
        (a directory with a quantization_config.json)."""
        if os.path.isdir(path):
            return os.path.exists(os.path.join(path, "quantization_config.json"))
        return False

    def _load_nf4_checkpoint(self, checkpoint_dir: str) -> dict:
        """
        Load NF4 quantized checkpoint and dequantize to float tensors.

        Each quantized tensor is stored as packed 4-bit indices
        (``.quant_data``), per-block ``.scales``, its original ``.shape``
        and a ``.pad_len`` for trailing padding.

        Args:
            checkpoint_dir: Directory containing NF4 model files

        Returns:
            Dequantized state dict
        """
        # Load quantization config
        config_path = os.path.join(checkpoint_dir, "quantization_config.json")
        with open(config_path, 'r') as f:
            quant_config = json.load(f)

        block_size = quant_config.get("block_size", 64)
        quantized_keys = set(quant_config.get("quantized_keys", []))

        # Load index
        index_path = os.path.join(checkpoint_dir, "model_nf4.safetensors.index.json")
        with open(index_path, 'r') as f:
            index = json.load(f)

        # Load all shards
        shard_files = sorted(set(index['weight_map'].values()))
        print(f"Loading {len(shard_files)} NF4 shards...")

        raw_state = {}
        for shard_file in shard_files:
            shard_path = os.path.join(checkpoint_dir, shard_file)
            print(f"  Loading {shard_file}...")
            shard_data = load_safetensors(shard_path)
            raw_state.update(shard_data)

        # NF4 lookup table for dequantization (the 16 quantiles from the
        # QLoRA paper / bitsandbytes NF4 data type).
        # BUGFIX: index 10 was 0.24611230850220; the canonical NF4 value is
        # 0.24611230194568634 — dequantization must use the exact table the
        # quantizer used or those weights are silently corrupted.
        nf4_values = torch.tensor([
            -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
            -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
            0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
            0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
        ], dtype=torch.float32)

        # Dequantize
        state_dict = {}
        dequant_count = 0

        for key in list(raw_state.keys()):
            if key.endswith('.quant_data'):
                base_key = key.replace('.quant_data', '')
                if base_key in quantized_keys:
                    # Dequantize this tensor
                    quant_data = raw_state[f"{base_key}.quant_data"]
                    scales = raw_state[f"{base_key}.scales"]
                    shape = raw_state[f"{base_key}.shape"].tolist()
                    pad_len = raw_state[f"{base_key}.pad_len"].item()

                    # Unpack 4-bit values (two indices per stored byte)
                    high = (quant_data >> 4) & 0x0F
                    low = quant_data & 0x0F
                    indices = torch.stack([high, low], dim=-1).flatten().long()

                    # Lookup and reshape
                    values = nf4_values[indices]

                    # Apply per-block scales
                    num_blocks = len(scales)
                    values = values[:num_blocks * block_size].reshape(num_blocks, block_size)
                    values = values * scales.float().unsqueeze(1)
                    values = values.flatten()

                    # Remove padding and reshape
                    if pad_len > 0:
                        values = values[:-pad_len]
                    state_dict[base_key] = values.reshape(shape)
                    dequant_count += 1
            elif not any(key.endswith(s) for s in ['.scales', '.shape', '.block_size', '.pad_len']):
                # Non-quantized tensor, keep as-is
                state_dict[key] = raw_state[key]

        print(f"Dequantized {dequant_count} tensors")
        return state_dict

    def _quantize_model_bnb(self, model):
        """
        Quantize model using bitsandbytes NF4.
        Replaces Linear layers with Linear4bit for true 4-bit inference.
        """
        import bitsandbytes as bnb
        import torch.nn as nn

        def replace_linear_with_4bit(module, name=''):
            for child_name, child in list(module.named_children()):
                full_name = f"{name}.{child_name}" if name else child_name
                if isinstance(child, nn.Linear):
                    # Create 4-bit linear layer
                    new_layer = bnb.nn.Linear4bit(
                        child.in_features,
                        child.out_features,
                        bias=child.bias is not None,
                        compute_dtype=torch.bfloat16,
                        compress_statistics=True,
                        quant_type='nf4'
                    )
                    # Copy weights (will be quantized when moved to GPU)
                    new_layer.weight = bnb.nn.Params4bit(
                        child.weight.data,
                        requires_grad=False,
                        quant_type='nf4'
                    )
                    if child.bias is not None:
                        new_layer.bias = nn.Parameter(child.bias.data)
                    setattr(module, child_name, new_layer)
                else:
                    # Recurse into containers
                    replace_linear_with_4bit(child, full_name)

        print("Replacing Linear layers with Linear4bit...")
        replace_linear_with_4bit(model)
        return model

    def _init_deepspeed(self, model):
        """
        Initialize DeepSpeed for the model with ZeRO-3 inference optimization.

        Args:
            model: PyTorch model to wrap with DeepSpeed

        Returns:
            DeepSpeed inference engine

        Raises:
            ImportError: if deepspeed is not installed.
            FileNotFoundError: if the DeepSpeed config file is missing.
        """
        try:
            import deepspeed
        except ImportError:
            raise ImportError("DeepSpeed is not installed. Install it with: pip install deepspeed")

        # Load DeepSpeed config
        if self.deepspeed_config is None:
            self.deepspeed_config = "ds_config_zero2.json"

        if not os.path.exists(self.deepspeed_config):
            raise FileNotFoundError(f"DeepSpeed config not found: {self.deepspeed_config}")

        print(f"Initializing DeepSpeed Inference with config: {self.deepspeed_config}")

        # Initialize distributed environment for single GPU if not already initialized
        if not torch.distributed.is_initialized():
            import random
            # Set environment variables for single-process mode
            # Use a random port to avoid conflicts
            port = random.randint(29500, 29600)
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = str(port)
            os.environ['RANK'] = '0'
            os.environ['LOCAL_RANK'] = '0'
            os.environ['WORLD_SIZE'] = '1'

            # Initialize process group
            try:
                torch.distributed.init_process_group(
                    backend='nccl',
                    init_method='env://',
                    world_size=1,
                    rank=0
                )
                print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}")
            except RuntimeError as e:
                if "address already in use" in str(e):
                    print(f"Port {port} in use, trying again...")
                    # Try a different port
                    port = random.randint(29600, 29700)
                    os.environ['MASTER_PORT'] = str(port)
                    torch.distributed.init_process_group(
                        backend='nccl',
                        init_method='env://',
                        world_size=1,
                        rank=0
                    )
                    print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}")
                else:
                    raise

        # Use DeepSpeed inference API instead of initialize
        # This doesn't require an optimizer.
        # NOTE(review): ds_config is parsed (validating the JSON) but not
        # passed to init_inference — confirm whether it should be.
        with open(self.deepspeed_config) as f:
            ds_config = json.load(f)

        model_engine = deepspeed.init_inference(
            model=model,
            mp_size=1,  # model parallel size
            dtype=torch.float32,  # Use float32 for compatibility
            replace_with_kernel_inject=False,  # Don't replace with DeepSpeed kernels for custom models
        )

        print("DeepSpeed Inference initialized successfully")
        return model_engine

    def _load_checkpoint_file(self, checkpoint_path: str) -> dict:
        """
        Load checkpoint file and extract state dict.
        Supports: sharded safetensors, single safetensors, .bin/.pt files

        Args:
            checkpoint_path: Path to checkpoint file or index.json

        Returns:
            state_dict: model state dictionary
        """
        # Check if it's a sharded safetensors (index.json file)
        if checkpoint_path.endswith('.index.json'):
            print(f"Loading sharded safetensors from index: {checkpoint_path}")
            with open(checkpoint_path, 'r') as f:
                index = json.load(f)

            # Get the directory containing the shards
            shard_dir = os.path.dirname(checkpoint_path)

            # Get unique shard files
            shard_files = sorted(set(index['weight_map'].values()))
            print(f"Loading {len(shard_files)} shard files in parallel...")

            # Load shards in parallel using ThreadPoolExecutor
            from concurrent.futures import ThreadPoolExecutor, as_completed

            def load_shard(shard_file):
                shard_path = os.path.join(shard_dir, shard_file)
                return shard_file, load_safetensors(shard_path)

            state_dict = {}
            with ThreadPoolExecutor(max_workers=len(shard_files)) as executor:
                futures = {executor.submit(load_shard, sf): sf for sf in shard_files}
                for future in as_completed(futures):
                    shard_file, shard_dict = future.result()
                    print(f"  Loaded {shard_file}")
                    state_dict.update(shard_dict)

            print(f"Loaded {len(state_dict)} tensors from sharded safetensors")
            return state_dict

        # Check if it's a single safetensors file
        if checkpoint_path.endswith('.safetensors'):
            print(f"Loading safetensors: {checkpoint_path}")
            state_dict = load_safetensors(checkpoint_path)
            return state_dict

        # Check if it's a directory containing checkpoint files
        if os.path.isdir(checkpoint_path):
            # Look for index.json first (sharded safetensors)
            index_path = os.path.join(checkpoint_path, 'model.safetensors.index.json')
            if os.path.exists(index_path):
                return self._load_checkpoint_file(index_path)

            # Look for common checkpoint filenames
            possible_files = [
                'model.safetensors', 'model.pt', 'model.pth', 'model.bin',
                'checkpoint.pt', 'checkpoint.pth', 'pytorch_model.bin',
                'model_state_dict.pt'
            ]
            checkpoint_file = None
            for filename in possible_files:
                full_path = os.path.join(checkpoint_path, filename)
                if os.path.exists(full_path):
                    checkpoint_file = full_path
                    # BUGFIX: message previously printed a literal placeholder
                    # instead of the matched filename.
                    print(f"Found checkpoint file: {filename}")
                    break

            if checkpoint_file is None:
                import glob
                # Try safetensors first
                st_files = glob.glob(os.path.join(checkpoint_path, "*.safetensors"))
                if st_files:
                    checkpoint_file = st_files[0]
                else:
                    pt_files = glob.glob(os.path.join(checkpoint_path, "*.pt")) + \
                               glob.glob(os.path.join(checkpoint_path, "*.pth")) + \
                               glob.glob(os.path.join(checkpoint_path, "*.bin"))
                    if pt_files:
                        checkpoint_file = pt_files[0]
                    else:
                        raise ValueError(f"No checkpoint files found in directory: {checkpoint_path}")
                print(f"Found checkpoint file: {os.path.basename(checkpoint_file)}")

            checkpoint_path = checkpoint_file
            # Recursively call to handle the found file
            return self._load_checkpoint_file(checkpoint_path)

        # Load .bin or .pt checkpoint
        print(f"Loading checkpoint file: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if 'model' in checkpoint:
                state_dict = checkpoint['model']
            elif 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            if 'epoch' in checkpoint:
                print(f"Checkpoint from epoch: {checkpoint['epoch']}")
            if 'global_step' in checkpoint:
                print(f"Checkpoint from step: {checkpoint['global_step']}")
            if 'loss' in checkpoint:
                print(f"Checkpoint loss: {checkpoint['loss']:.4f}")
        else:
            state_dict = checkpoint

        # Remove 'module.' prefix if present (DataParallel/DDP checkpoints)
        if any(key.startswith('module.') for key in state_dict.keys()):
            state_dict = {key.replace('module.', ''): value
                          for key, value in state_dict.items()}
            print("Removed 'module.' prefix from state dict keys")

        return state_dict

    def text_to_cond_image(
        self,
        text: str,
        img_size: int = 128,
        font_scale: float = 0.8,
        font_path: Optional[str] = None,
        fixed_chars: int = 7
    ) -> Image.Image:
        """
        Convert text to condition image - always creates image for fixed_chars characters.
        Text is arranged from top to bottom.

        Args:
            text: Chinese text to convert (must be <= fixed_chars characters)
            img_size: size of each character block (default 128)
            font_scale: scale of font relative to image size (default 0.8)
            font_path: path to font file
            fixed_chars: fixed number of character slots (default 7)

        Returns:
            PIL Image with text rendered (always fixed_chars * img_size height)

        Raises:
            ValueError: if text is longer than fixed_chars.
        """
        if len(text) > fixed_chars:
            raise ValueError(f"Text must be at most {fixed_chars} characters, got {len(text)}")

        if font_path is None:
            font_path = self.font_path

        # Create font - font size is scaled down from img_size
        font_size_scaled = int(font_scale * img_size)
        font = ImageFont.truetype(font_path, font_size_scaled)

        # Calculate image dimensions - always fixed_chars height
        img_width = img_size
        img_height = img_size * fixed_chars  # Fixed height for 7 characters

        # Create white background image
        cond_img = Image.new("RGB", (img_width, img_height), (255, 255, 255))
        cond_draw = ImageDraw.Draw(cond_img)

        # Draw each character from top to bottom
        # Note: font_size for positioning should be img_size, not the scaled font size
        for i, char in enumerate(text):
            # NOTE(review): margin is derived from the scaled font size; if a
            # centered glyph in the img_size cell is intended, img_size might
            # be the right base — confirm against training data rendering.
            font_space = font_size_scaled * (1 - font_scale) // 2
            # Position based on img_size blocks, not scaled font size
            font_position = (font_space, img_size * i + font_space)
            cond_draw.text(font_position, char, font=font, fill=(0, 0, 0))

        return cond_img

    def build_prompt(
        self,
        font_style: str = "楷",
        author: str = None,
        is_traditional: bool = True,
    ) -> str:
        """
        Build prompt for generation following dataset.py logic.

        Args:
            font_style: font style (楷/草/行)
            author: author name (Chinese or None for synthetic)
            is_traditional: whether generating traditional calligraphy

        Returns:
            formatted prompt string

        Raises:
            ValueError: if font_style is not a known style.
        """
        # Validate font style
        if font_style not in self.font_style_des:
            raise ValueError(f"Font style must be one of: {list(self.font_style_des.keys())}")

        # Convert font style to pinyin
        font_style_pinyin = convert_to_pinyin(font_style)

        # Build prompt based on traditional or synthetic
        if is_traditional and author and author in self.author_style:
            # Traditional calligraphy with specific author
            prompt = f"Traditional Chinese calligraphy works, background: black, font: {font_style_pinyin}, "
            prompt += self.font_style_des[font_style]
            author_info = self.author_style[author]
            prompt += f" author: {author_info}"
        else:
            # Synthetic calligraphy
            prompt = f"Synthetic calligraphy data, background: black, font: {font_style_pinyin}, "
            prompt += self.font_style_des[font_style]

        return prompt

    @torch.no_grad()
    def generate(
        self,
        text: str,
        font_style: str = "楷",
        author: str = None,
        width: int = 128,
        height: int = None,  # Fixed to 7 characters height
        num_steps: int = 50,
        guidance: float = 3.5,
        seed: int = None,
        is_traditional: bool = None,
        save_path: Optional[str] = None
    ) -> tuple[Image.Image, Image.Image]:
        """
        Generate calligraphy image from text.

        Args:
            text: Chinese text to generate (1-7 characters)
            font_style: font style (楷/草/行)
            author: author/calligrapher name from the style list
            width: image width (default 128)
            height: image height (fixed to 7 * width)
            num_steps: number of denoising steps
            guidance: guidance scale
            seed: random seed for generation
            is_traditional: whether generating traditional calligraphy (auto-determined if None)
            save_path: optional path to save the generated image

        Returns:
            tuple of (generated_image, condition_image)

        Raises:
            ValueError: if text is empty or longer than 7 characters.
        """
        # Fixed number of characters
        FIXED_CHARS = 7

        # Validate text - must have 1-7 characters
        if len(text) < 1:
            raise ValueError(f"Text must have at least 1 character, got empty string")
        if len(text) > FIXED_CHARS:
            raise ValueError(f"Text must have at most {FIXED_CHARS} characters, got {len(text)}")

        if seed is None:
            seed = torch.randint(0, 2**32, (1,)).item()

        # Fixed height for 7 characters
        num_chars = len(text)
        height = width * FIXED_CHARS  # Always 7 characters height

        # Auto-determine traditional vs synthetic
        if is_traditional is None:
            is_traditional = author is not None and author in self.author_style

        # Generate condition image (fixed size for 7 characters)
        cond_img = self.text_to_cond_image(text, img_size=width, fixed_chars=FIXED_CHARS)

        # Build prompt
        prompt = self.build_prompt(
            font_style=font_style,
            author=author,
            is_traditional=is_traditional,
        )

        print(f"Generating with prompt: {prompt}")
        print(f"Text: {text} ({num_chars} chars), Seed: {seed}")

        # Generate image.
        # NOTE(review): `guidance` is not forwarded to the sampler here —
        # confirm whether XFluxSampler should receive it.
        result_img, recognized_text = self.sampler(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            controlnet_image=cond_img,
            is_generation=True,
            cond_text=text,
            required_chars=FIXED_CHARS,  # Always 7 characters
            seed=seed
        )

        # Crop to actual text length if less than FIXED_CHARS
        if num_chars < FIXED_CHARS:
            actual_height = width * num_chars
            # Crop result image (top portion only)
            result_img = result_img.crop((0, 0, width, actual_height))
            # Crop condition image as well
            cond_img = cond_img.crop((0, 0, width, actual_height))

        # Save if path provided
        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            result_img.save(save_path)
            print(f"Image saved to {save_path}")

        return result_img, cond_img

    def batch_generate(
        self,
        texts: List[str],
        font_styles: Optional[List[str]] = None,
        authors: Optional[List[str]] = None,
        output_dir: str = "./outputs",
        **kwargs
    ) -> List[tuple[Image.Image, Image.Image]]:
        """
        Batch generate calligraphy images.

        Args:
            texts: list of texts to generate (1-7 characters each)
            font_styles: list of font styles (if None, use default)
            authors: list of authors (if None, use synthetic)
            output_dir: directory to save outputs
            **kwargs: additional arguments for generate()

        Returns:
            list of (generated_image, condition_image) tuples
        """
        os.makedirs(output_dir, exist_ok=True)

        results = []

        # Default styles and authors if not provided
        if font_styles is None:
            font_styles = ["楷"] * len(texts)
        if authors is None:
            authors = [None] * len(texts)

        for i, (text, font, author) in enumerate(zip(texts, font_styles, authors)):
            # Clean author name for filename
            author_name = author if author else "synthetic"
            if author and author in self.author_style:
                author_name = convert_to_pinyin(author)

            save_path = os.path.join(
                output_dir,
                f"{text}_{font}_{author_name}_{i}.png"
            )

            result_img, cond_img = self.generate(
                text=text,
                font_style=font,
                author=author,
                save_path=save_path,
                **kwargs
            )
            results.append((result_img, cond_img))

        return results

    def get_available_authors(self) -> List[str]:
        """Get list of available author styles"""
        return list(self.author_style.keys())

    def get_available_fonts(self) -> List[str]:
        """Get list of available font styles"""
        return list(self.font_style_des.keys())


# Hugging Face Pipeline wrapper
class FluxCalligraphyPipeline:
    """Hugging Face compatible pipeline for calligraphy generation"""

    def __init__(
        self,
        model_name: str = "flux-dev",
        device: str = "cuda",
        checkpoint_path: Optional[str] = None,
        **kwargs
    ):
        """Initialize the pipeline"""
        self.generator = CalligraphyGenerator(
            model_name=model_name,
            device=device,
            checkpoint_path=checkpoint_path,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        font_style: Union[str, List[str]] = "楷",
        author: Union[str, List[str]] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.5,
        generator: Optional[torch.Generator] = None,
        **kwargs
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate calligraphy images.

        Args:
            text: text or list of texts to generate (1-7 characters each)
            font_style: font style(s) (楷/草/行)
            author: author name(s) from the style list
            num_inference_steps: number of denoising steps
            guidance_scale: guidance scale for generation
            generator: torch generator for reproducibility

        Returns:
            generated image(s)
        """
        # Handle single text
        if isinstance(text, str):
            seed = None
            if generator is not None:
                seed = generator.initial_seed()
            result, _ = self.generator.generate(
                text=text,
                font_style=font_style,
                author=author,
                num_steps=num_inference_steps,
                guidance=guidance_scale,
                seed=seed,
                **kwargs
            )
            return result

        # Handle batch
        else:
            if isinstance(font_style, str):
                font_style = [font_style] * len(text)
            if isinstance(author, str) or author is None:
                author = [author] * len(text)

            results = []
            for t, f, a in zip(text, font_style, author):
                seed = None
                if generator is not None:
                    seed = generator.initial_seed()
                result, _ = self.generator.generate(
                    text=t,
                    font_style=f,
                    author=a,
                    num_steps=num_inference_steps,
                    guidance=guidance_scale,
                    seed=seed,
                    **kwargs
                )
                results.append(result)
            return results


if __name__ == "__main__":
    # Example usage
    import argparse

    parser = argparse.ArgumentParser(description="Generate Chinese calligraphy")
    parser.add_argument("--text", type=str, default="暴富且平安",
                        help="Text to generate (1-7 characters)")
    parser.add_argument("--font", type=str, default="楷",
                        help="Font style (楷/草/行)")
    parser.add_argument("--author", type=str, default=None,
                        help="Author/calligrapher name")
    parser.add_argument("--steps", type=int, default=50,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed")
    parser.add_argument("--output", type=str, default="output.png",
                        help="Output path")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to use")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Checkpoint path")
    parser.add_argument("--list-authors", action="store_true",
                        help="List available authors")
    parser.add_argument("--list-fonts", action="store_true",
                        help="List available font styles")
    parser.add_argument("--float8", action="store_true",
                        help="Use Float8 quantization (torchao) for faster inference")
    parser.add_argument("--compile", action="store_true",
                        help="Use torch.compile for optimized inference")
    parser.add_argument("--compile-mode", type=str, default="max-autotune",
                        choices=["reduce-overhead", "max-autotune", "default"],
                        help="torch.compile mode")

    args = parser.parse_args()

    # Initialize generator
    generator = CalligraphyGenerator(
        model_name="flux-dev",
        device=args.device,
        checkpoint_path=args.checkpoint,
    )

    # Apply optimizations if requested (CLI mode)
    if args.float8 or args.compile:
        from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
        import torch._inductor.config as inductor_config

        # Inductor configs from FLUX-Kontext-fp8
        inductor_config.conv_1x1_as_mm = True
        inductor_config.coordinate_descent_tuning = True
        inductor_config.coordinate_descent_check_all_directions = True
        inductor_config.max_autotune = True

        if args.float8:
            print("Applying Float8 quantization...")
            quantize_(generator.model, Float8DynamicActivationFloat8WeightConfig())
            print("✓ Float8 quantization complete!")

        if args.compile:
            print(f"Applying torch.compile (mode={args.compile_mode})...")
            generator.model = torch.compile(
                generator.model,
                mode=args.compile_mode,
                backend="inductor",
                dynamic=True,
            )
            print("✓ torch.compile applied!")

    # List available options
    if args.list_authors:
        print("Available authors:")
        for author in generator.get_available_authors()[:20]:  # Show first 20
            print(f"  - {author}")
        print(f"  ... and {len(generator.get_available_authors()) - 20} more")
        exit(0)

    if args.list_fonts:
        print("Available font styles:")
        for font in generator.get_available_fonts():
            print(f"  - {font}: {generator.font_style_des[font]}")
        exit(0)

    # Validate text - must have 1-7 characters
    if len(args.text) < 1:
        print(f"Error: Text must have at least 1 character")
        exit(1)
    if len(args.text) > 7:
        print(f"Error: Text must have at most 7 characters, got {len(args.text)}")
        exit(1)

    # Generate
    result_img, cond_img = generator.generate(
        text=args.text,
        font_style=args.font,
        author=args.author,
        num_steps=args.steps,
        seed=args.seed,
        save_path=args.output
    )

    print(f"Generation complete! Saved to {args.output}")