Spaces:
Running on Zero
Running on Zero
| # -*- coding: utf-8 -*- | |
| """ | |
| Chinese Calligraphy Generation with Flux Model | |
| Author and font style controllable generation | |
| """ | |
| import os | |
| import json | |
| import torch | |
| from safetensors.torch import load_file as load_safetensors | |
| from optimum.quanto import quantize, freeze, qint4 | |
| from PIL import Image, ImageDraw, ImageFont | |
| from typing import Optional, List, Union, Dict, Any | |
| from einops import rearrange | |
| from pypinyin import lazy_pinyin | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| from src.flux.util import configs, load_ae, load_clip, load_t5 | |
| from src.flux.model import Flux | |
| from src.flux.xflux_pipeline import XFluxSampler | |
# HuggingFace Hub model IDs
# Default repo for the sharded checkpoint download and for the condition-image font download.
HF_MODEL_ID = "TSXu/Unicalli_Pro"
HF_CHECKPOINT_INDEX = "model.safetensors.index.json"  # Sharded safetensors index (filename inside HF_MODEL_ID)
HF_INTERNVL_ID = "OpenGVLab/InternVL3-1B"  # InternVL3-1B repo ID
def download_sharded_safetensors(
    model_id: str = HF_MODEL_ID,
    local_dir: Optional[str] = None,
    force_download: bool = False
) -> str:
    """
    Download a sharded safetensors checkpoint from the HuggingFace Hub.

    The index file (HF_CHECKPOINT_INDEX) is fetched first; every shard named
    in its ``weight_map`` is then fetched with the same repo / local_dir /
    token settings.

    Args:
        model_id: HuggingFace model repository ID
        local_dir: Local directory to save files (optional; None uses the HF cache)
        force_download: Whether to force re-download
    Returns:
        Path to the downloaded index.json file
    Raises:
        Any exception from the download or from reading the index is printed
        and re-raised unchanged.
    """
    print(f"Downloading sharded safetensors from HuggingFace Hub ({model_id})...")
    # Shared settings for the index and every shard; the HF token (if set in
    # the environment) allows access to private repos.
    download_kwargs = dict(
        repo_id=model_id,
        local_dir=local_dir,
        force_download=force_download,
        token=os.environ.get("HF_TOKEN", None),
    )
    try:
        # First download the index file
        index_path = hf_hub_download(filename=HF_CHECKPOINT_INDEX, **download_kwargs)
        print(f"Index downloaded to: {index_path}")
        # Read index to get shard filenames (JSON is UTF-8 regardless of locale)
        with open(index_path, 'r', encoding='utf-8') as f:
            index = json.load(f)
        # Many parameters share a shard, so deduplicate before downloading
        shard_files = sorted(set(index['weight_map'].values()))
        print(f"Downloading {len(shard_files)} shard files...")
        for shard_file in shard_files:
            print(f" Downloading {shard_file}...")
            hf_hub_download(filename=shard_file, **download_kwargs)
        print("All shards downloaded!")
        return index_path
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise
def is_huggingface_repo_id(path: str) -> bool:
    """
    Check if a string looks like a HuggingFace repo ID (e.g., 'namespace/repo_name')
    rather than a local file path.

    A repo ID has exactly two non-empty '/'-separated parts, neither starting
    with '.', and the string must not start with '/', '.' or '~'. A second
    part ending in a checkpoint/config file extension (e.g.
    'ckpt/model.safetensors') is treated as a local file path, not a repo.

    Args:
        path: candidate repo ID or filesystem path
    Returns:
        True if *path* should be treated as a HuggingFace repo ID
    """
    # Absolute, explicitly-relative and home-relative paths are never repo IDs.
    if path.startswith(('/', '.', '~')):
        return False
    parts = path.split('/')
    # HF repo ID should have exactly 2 parts: namespace and repo_name
    if len(parts) != 2 or not all(part and not part.startswith('.') for part in parts):
        return False
    # 'dir/model.safetensors' has the namespace/name shape but is almost
    # certainly a local file that does not exist; let the caller report it
    # as missing instead of attempting a Hub download.
    local_file_suffixes = ('.safetensors', '.json', '.bin', '.pt', '.pth', '.ckpt')
    return not parts[1].lower().endswith(local_file_suffixes)
def ensure_checkpoint_exists(checkpoint_path: str) -> str:
    """
    Resolve *checkpoint_path* to something that exists on local disk.

    Args:
        checkpoint_path: Local path, or a HuggingFace repo ID such as "TSXu/Unicalli_Pro"
    Returns:
        The path unchanged when it already exists locally; otherwise the path
        of the downloaded sharded-safetensors index file.
    Raises:
        FileNotFoundError: if the path does not exist and does not look like a repo ID
    """
    already_local = os.path.exists(checkpoint_path)
    if already_local:
        print(f"Using local checkpoint: {checkpoint_path}")
        return checkpoint_path

    # Guard clause: anything that is neither on disk nor repo-shaped is an error.
    if not is_huggingface_repo_id(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    print(f"Downloading from HuggingFace Hub: {checkpoint_path}")
    return download_sharded_safetensors(model_id=checkpoint_path)
def convert_to_pinyin(text):
    """
    Romanize *text* with pypinyin's ``lazy_pinyin`` and join the syllables
    with single spaces (e.g. a font-style character becomes its pinyin).
    """
    syllables = []
    for entry in lazy_pinyin(text):
        # Should an entry ever be a list of readings, keep only the first one.
        syllables.append(entry[0] if isinstance(entry, list) else entry)
    return ' '.join(syllables)
| class CalligraphyGenerator: | |
| """ | |
| Chinese Calligraphy Generator using Flux model | |
| Attributes: | |
| device: torch device for computation | |
| model_name: name of the flux model (flux-dev or flux-schnell) | |
| font_styles: available font styles for generation | |
| authors: available calligrapher authors | |
| """ | |
| def __init__( | |
| self, | |
| model_name: str = "flux-dev", | |
| device: str = "cuda", | |
| offload: bool = True, | |
| checkpoint_path: Optional[str] = None, | |
| intern_vlm_path: Optional[str] = None, | |
| ref_latent_path: Optional[str] = None, | |
| font_descriptions_path: str = "chirography.json", | |
| author_descriptions_path: str = "calligraphy_styles_en.json", | |
| use_deepspeed: bool = False, | |
| use_4bit_quantization: bool = False, | |
| use_float8_quantization: bool = False, | |
| use_torch_compile: bool = False, | |
| compile_mode: str = "reduce-overhead", | |
| deepspeed_config: Optional[str] = None, | |
| dtype: Optional[str] = None, | |
| preloaded_embedding: Optional[torch.nn.Module] = None, | |
| preloaded_tokenizer: Optional[Any] = None, | |
| ): | |
| """ | |
| Initialize the calligraphy generator | |
| Args: | |
| model_name: flux model name (flux-dev or flux-schnell) | |
| device: device for computation | |
| offload: whether to offload model to CPU when not in use | |
| checkpoint_path: path to model checkpoint if using fine-tuned model | |
| intern_vlm_path: path to InternVLM model for text embedding | |
| ref_latent_path: path to reference latents for recognition mode | |
| font_descriptions_path: path to font style descriptions JSON | |
| author_descriptions_path: path to author style descriptions JSON | |
| use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization | |
| use_4bit_quantization: whether to use 4-bit quantization (quanto/bitsandbytes) | |
| use_float8_quantization: whether to use Float8 quantization (torchao) for faster inference | |
| use_torch_compile: whether to use torch.compile for optimized inference | |
| compile_mode: torch.compile mode - "reduce-overhead", "max-autotune", or "default" | |
| deepspeed_config: path to DeepSpeed config JSON file | |
| dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto | |
| """ | |
| self.device = torch.device(device) | |
| self.model_name = model_name | |
| self.offload = offload | |
| self.is_schnell = model_name == "flux-schnell" | |
| self.use_deepspeed = use_deepspeed | |
| self.deepspeed_config = deepspeed_config | |
| self.use_4bit_quantization = use_4bit_quantization | |
| self.use_float8_quantization = use_float8_quantization | |
| self.use_torch_compile = use_torch_compile | |
| self.compile_mode = compile_mode | |
| self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto | |
| # Load font and author style descriptions | |
| if os.path.exists(font_descriptions_path): | |
| with open(font_descriptions_path, 'r', encoding='utf-8') as f: | |
| self.font_style_des = json.load(f) | |
| else: | |
| raise FileNotFoundError(f"Font descriptions file not found: {font_descriptions_path}") | |
| if os.path.exists(author_descriptions_path): | |
| with open(author_descriptions_path, 'r', encoding='utf-8') as f: | |
| self.author_style = json.load(f) | |
| else: | |
| raise FileNotFoundError(f"Author descriptions file not found: {author_descriptions_path}") | |
| # Load models | |
| print("Loading models...") | |
| # When using DeepSpeed, load text encoders on CPU first to save memory during initialization | |
| # They will be moved to GPU after DeepSpeed initializes the main model | |
| if self.use_deepspeed: | |
| text_encoder_device = "cpu" | |
| elif offload: | |
| text_encoder_device = "cpu" # Will be moved to GPU during inference | |
| else: | |
| text_encoder_device = self.device | |
| self.t5 = load_t5(text_encoder_device, max_length=256 if self.is_schnell else 512) | |
| self.clip = load_clip(text_encoder_device) | |
| self.clip.requires_grad_(False) | |
| # Ensure checkpoint exists (download from HF Hub if needed) | |
| if checkpoint_path: | |
| checkpoint_path = ensure_checkpoint_exists(checkpoint_path) | |
| print(f"Loading model from checkpoint: {checkpoint_path}") | |
| # When using DeepSpeed, don't move to GPU yet - let DeepSpeed handle it | |
| self.model = self._load_model_from_checkpoint( | |
| checkpoint_path, model_name, | |
| offload=offload, | |
| use_deepspeed=self.use_deepspeed | |
| ) | |
| # Initialize DeepSpeed if requested | |
| if self.use_deepspeed: | |
| self.model = self._init_deepspeed(self.model) | |
| else: | |
| # If no checkpoint path provided, download default from HF Hub | |
| print("No checkpoint path provided, downloading from HuggingFace Hub...") | |
| checkpoint_path = download_model_from_hf() | |
| print(f"Loading model from checkpoint: {checkpoint_path}") | |
| self.model = self._load_model_from_checkpoint( | |
| checkpoint_path, model_name, | |
| offload=offload, | |
| use_deepspeed=self.use_deepspeed | |
| ) | |
| if self.use_deepspeed: | |
| self.model = self._init_deepspeed(self.model) | |
| # Note: Float8 quantization and torch.compile optimizations | |
| # are applied externally (e.g., in app.py) for better control | |
| # over the optimization process with ZeroGPU AOT compilation. | |
| # Load VAE | |
| if self.use_deepspeed or offload: | |
| vae_device = "cpu" | |
| else: | |
| vae_device = self.device | |
| self.vae = load_ae(model_name, device=vae_device) | |
| # Move VAE to GPU only if offload (not DeepSpeed) | |
| if offload and not self.use_deepspeed: | |
| self.vae = self.vae.to(self.device) | |
| # After DeepSpeed init, move text encoders to GPU | |
| if self.use_deepspeed: | |
| print("Moving text encoders to GPU...") | |
| self.t5 = self.t5.to(self.device) | |
| self.clip = self.clip.to(self.device) | |
| self.vae = self.vae.to(self.device) | |
| # Load reference latents if provided | |
| self.ref_latent = None | |
| if ref_latent_path and os.path.exists(ref_latent_path): | |
| print(f"Loading reference latents from {ref_latent_path}") | |
| self.ref_latent = torch.load(ref_latent_path, map_location='cpu') | |
| # Create sampler (use preloaded embedding if available) | |
| self.sampler = XFluxSampler( | |
| clip=self.clip, | |
| t5=self.t5, | |
| ae=self.vae, | |
| ref_latent=self.ref_latent, | |
| model=self.model, | |
| device=self.device, | |
| intern_vlm_path=intern_vlm_path, | |
| preloaded_embedding=preloaded_embedding, | |
| preloaded_tokenizer=preloaded_tokenizer, | |
| ) | |
| # Font for generating condition images | |
| project_root = os.path.dirname(os.path.abspath(__file__)) | |
| local_font_path = os.path.join(project_root, "FangZhengKaiTiFanTi-1.ttf") | |
| self.font_path = self._ensure_font_exists(local_font_path) | |
| self.default_font_size = 102 # 128 * 0.8 | |
| def _ensure_font_exists(self, font_path: str) -> str: | |
| """ | |
| Ensure font file exists locally, download from HF Hub if not | |
| Args: | |
| font_path: Local path to font file | |
| Returns: | |
| Path to the local font file | |
| """ | |
| cached_font_path = os.environ.get("UNICALLI_FONT_PATH") | |
| if cached_font_path and os.path.exists(cached_font_path): | |
| return cached_font_path | |
| if os.path.exists(font_path): | |
| return font_path | |
| # Try to download from HF Hub | |
| print(f"Font file not found locally, downloading from HuggingFace Hub...") | |
| hf_token = os.environ.get("HF_TOKEN", None) | |
| try: | |
| font_path = hf_hub_download( | |
| repo_id=HF_MODEL_ID, | |
| filename="FangZhengKaiTiFanTi-1.ttf", | |
| token=hf_token | |
| ) | |
| print(f"Font downloaded to: {font_path}") | |
| return font_path | |
| except Exception as e: | |
| print(f"Warning: Could not download font: {e}") | |
| return font_path # Return original path, may fail later | |
| def _load_model_from_checkpoint(self, checkpoint_path: str, model_name: str, offload: bool, use_deepspeed: bool = False): | |
| """ | |
| Load model from checkpoint without loading flux pretrained weights. | |
| Supports both regular checkpoints and NF4 quantized checkpoints. | |
| Args: | |
| checkpoint_path: Path to your checkpoint file or NF4 model directory | |
| model_name: flux model name (for config) | |
| offload: whether to offload to CPU | |
| use_deepspeed: whether using DeepSpeed (keeps model on CPU) | |
| Returns: | |
| model with loaded checkpoint | |
| """ | |
| print(f"Creating empty flux model structure...") | |
| load_device = "cpu" | |
| # Create model structure without loading pretrained weights (using "meta" device) | |
| with torch.device("meta"): | |
| model = Flux(configs[model_name].params) | |
| # Initialize module embeddings (must be done before loading checkpoint) | |
| print("Initializing module embeddings...") | |
| model.init_module_embeddings(tokens_num=320, cond_txt_channel=896) | |
| # Move model to loading device | |
| print(f"Moving model to {load_device} for loading...") | |
| model = model.to_empty(device=load_device) | |
| # Check if this is an NF4 quantized model | |
| is_nf4 = self._is_nf4_checkpoint(checkpoint_path) | |
| # Load checkpoint | |
| print(f"Loading checkpoint from {checkpoint_path}") | |
| if is_nf4: | |
| print("Detected NF4 quantized model, dequantizing...") | |
| checkpoint = self._load_nf4_checkpoint(checkpoint_path) | |
| else: | |
| checkpoint = self._load_checkpoint_file(checkpoint_path) | |
| # Determine dtype from checkpoint - keep original dtype for efficiency | |
| first_tensor = next(iter(checkpoint.values())) | |
| checkpoint_dtype = first_tensor.dtype | |
| print(f"Checkpoint dtype: {checkpoint_dtype}") | |
| # Check if user forced a specific dtype | |
| forced_dtype = getattr(self, 'forced_dtype', None) | |
| if forced_dtype: | |
| dtype_map = { | |
| "fp16": torch.float16, | |
| "bf16": torch.bfloat16, | |
| "fp32": torch.float32, | |
| "fp8": torch.float8_e4m3fn, | |
| } | |
| if forced_dtype not in dtype_map: | |
| print(f"Warning: Unknown dtype '{forced_dtype}', using auto selection") | |
| forced_dtype = None | |
| else: | |
| target_dtype = dtype_map[forced_dtype] | |
| print(f"Using forced dtype: {target_dtype}") | |
| if checkpoint_dtype != target_dtype: | |
| print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...") | |
| checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()} | |
| if not forced_dtype: | |
| # Note: We trust the original precision (like FP8) if it is provided that way | |
| target_dtype = checkpoint_dtype | |
| print(f"Using auto-detected checkpoint dtype: {target_dtype} for inference loading") | |
| # Load weights into model | |
| model.load_state_dict(checkpoint, strict=False, assign=True) | |
| print(f"Model dtype after loading: {next(model.parameters()).dtype}") | |
| # Store target dtype for inference | |
| self._model_dtype = target_dtype | |
| # Free checkpoint memory | |
| del checkpoint | |
| # Apply bitsandbytes 4-bit quantization if requested | |
| if hasattr(self, 'use_4bit_quantization') and self.use_4bit_quantization: | |
| try: | |
| import bitsandbytes as bnb | |
| print("Applying bitsandbytes NF4 quantization for 4-bit inference...") | |
| model = self._quantize_model_bnb(model) | |
| model._is_quantized = True | |
| print("bitsandbytes NF4 quantization complete!") | |
| except ImportError: | |
| print("bitsandbytes not available, using quanto quantization...") | |
| model = model.float() | |
| quantize(model, weights=qint4) | |
| freeze(model) | |
| model._is_quantized = True | |
| print("quanto 4-bit quantization complete!") | |
| # Move to GPU only if NOT using DeepSpeed | |
| if not use_deepspeed: | |
| if self.device.type != "cpu": | |
| print(f"Moving model to {self.device}...") | |
| model = model.to(self.device) | |
| # Enable optimized attention backends | |
| try: | |
| torch.backends.cuda.enable_flash_sdp(True) | |
| torch.backends.cuda.enable_mem_efficient_sdp(True) | |
| torch.backends.cuda.enable_math_sdp(False) | |
| print("Enabled FlashAttention / Memory-Efficient SDPA backends") | |
| except Exception as e: | |
| print(f"Could not configure SDPA backends: {e}") | |
| return model | |
| def _is_nf4_checkpoint(self, path: str) -> bool: | |
| """Check if path contains an NF4 quantized checkpoint""" | |
| if os.path.isdir(path): | |
| return os.path.exists(os.path.join(path, "quantization_config.json")) | |
| return False | |
    def _load_nf4_checkpoint(self, checkpoint_dir: str) -> dict:
        """
        Load an NF4 quantized checkpoint directory and dequantize it to float tensors.

        Expected directory contents (as read below):
          * ``quantization_config.json`` with optional ``block_size`` (default 64)
            and ``quantized_keys`` (list of parameter names that were quantized);
          * ``model_nf4.safetensors.index.json`` whose ``weight_map`` names the shards;
          * the shard files. A quantized parameter ``K`` is stored as
            ``K.quant_data`` (two 4-bit codes per element), ``K.scales``
            (one scale per block), ``K.shape`` and ``K.pad_len``.

        Args:
            checkpoint_dir: Directory containing NF4 model files
        Returns:
            Dequantized state dict (quantized entries come back as float32)
        """
        from safetensors.torch import load_file as load_safetensors
        # Load quantization config
        config_path = os.path.join(checkpoint_dir, "quantization_config.json")
        with open(config_path, 'r') as f:
            quant_config = json.load(f)
        # A single global block size is used for every tensor; any per-tensor
        # ``.block_size`` entries in the shards are skipped below.
        block_size = quant_config.get("block_size", 64)
        quantized_keys = set(quant_config.get("quantized_keys", []))
        # Load index
        index_path = os.path.join(checkpoint_dir, "model_nf4.safetensors.index.json")
        with open(index_path, 'r') as f:
            index = json.load(f)
        # Load all shards into one flat dict
        shard_files = sorted(set(index['weight_map'].values()))
        print(f"Loading {len(shard_files)} NF4 shards...")
        raw_state = {}
        for shard_file in shard_files:
            shard_path = os.path.join(checkpoint_dir, shard_file)
            print(f" Loading {shard_file}...")
            shard_data = load_safetensors(shard_path)
            raw_state.update(shard_data)
        # NF4 lookup table for dequantization: 16 code values in [-1, 1],
        # indexed by each 4-bit code.
        # NOTE(review): 0.24611230850220 differs from bitsandbytes' NF4 value
        # 0.24611230194568634 in the 9th digit; it must match whatever the
        # exporter used — confirm before "fixing".
        nf4_values = torch.tensor([
            -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
            -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
            0.07958029955625534, 0.16093020141124725, 0.24611230850220, 0.33791524171829224,
            0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0
        ], dtype=torch.float32)
        # Dequantize
        state_dict = {}
        dequant_count = 0
        for key in list(raw_state.keys()):
            if key.endswith('.quant_data'):
                base_key = key.replace('.quant_data', '')
                # A .quant_data entry whose base key is not listed in
                # quantized_keys is dropped (there is no else branch).
                if base_key in quantized_keys:
                    # Dequantize this tensor
                    quant_data = raw_state[f"{base_key}.quant_data"]
                    scales = raw_state[f"{base_key}.scales"]
                    shape = raw_state[f"{base_key}.shape"].tolist()
                    pad_len = raw_state[f"{base_key}.pad_len"].item()
                    # Unpack 4-bit values: each stored element holds two codes,
                    # high nibble first. Presumably quant_data is uint8 and the
                    # exporter packed high-then-low — TODO confirm against the exporter.
                    high = (quant_data >> 4) & 0x0F
                    low = quant_data & 0x0F
                    indices = torch.stack([high, low], dim=-1).flatten().long()
                    # Lookup and reshape
                    values = nf4_values[indices]
                    # Apply scales: one scale per block of block_size values
                    num_blocks = len(scales)
                    values = values[:num_blocks * block_size].reshape(num_blocks, block_size)
                    values = values * scales.float().unsqueeze(1)
                    values = values.flatten()
                    # Remove padding and reshape
                    if pad_len > 0:
                        values = values[:-pad_len]
                    state_dict[base_key] = values.reshape(shape)
                    dequant_count += 1
            elif not any(key.endswith(s) for s in ['.scales', '.shape', '.block_size', '.pad_len']):
                # Non-quantized tensor, keep as-is. Metadata entries (and any
                # genuine parameter whose name ends in one of these suffixes)
                # are skipped.
                state_dict[key] = raw_state[key]
        print(f"Dequantized {dequant_count} tensors")
        return state_dict
| def _quantize_model_bnb(self, model): | |
| """ | |
| Quantize model using bitsandbytes NF4. | |
| Replaces Linear layers with Linear4bit for true 4-bit inference. | |
| """ | |
| import bitsandbytes as bnb | |
| import torch.nn as nn | |
| def replace_linear_with_4bit(module, name=''): | |
| for child_name, child in list(module.named_children()): | |
| full_name = f"{name}.{child_name}" if name else child_name | |
| if isinstance(child, nn.Linear): | |
| # Create 4-bit linear layer | |
| new_layer = bnb.nn.Linear4bit( | |
| child.in_features, | |
| child.out_features, | |
| bias=child.bias is not None, | |
| compute_dtype=torch.bfloat16, | |
| compress_statistics=True, | |
| quant_type='nf4' | |
| ) | |
| # Copy weights (will be quantized when moved to GPU) | |
| new_layer.weight = bnb.nn.Params4bit( | |
| child.weight.data, | |
| requires_grad=False, | |
| quant_type='nf4' | |
| ) | |
| if child.bias is not None: | |
| new_layer.bias = nn.Parameter(child.bias.data) | |
| setattr(module, child_name, new_layer) | |
| else: | |
| replace_linear_with_4bit(child, full_name) | |
| print("Replacing Linear layers with Linear4bit...") | |
| replace_linear_with_4bit(model) | |
| return model | |
| def _init_deepspeed(self, model): | |
| """ | |
| Initialize DeepSpeed for the model with ZeRO-3 inference optimization. | |
| Args: | |
| model: PyTorch model to wrap with DeepSpeed | |
| Returns: | |
| DeepSpeed inference engine | |
| """ | |
| try: | |
| import deepspeed | |
| except ImportError: | |
| raise ImportError("DeepSpeed is not installed. Install it with: pip install deepspeed") | |
| # Load DeepSpeed config | |
| if self.deepspeed_config is None: | |
| self.deepspeed_config = "ds_config_zero2.json" | |
| if not os.path.exists(self.deepspeed_config): | |
| raise FileNotFoundError(f"DeepSpeed config not found: {self.deepspeed_config}") | |
| print(f"Initializing DeepSpeed Inference with config: {self.deepspeed_config}") | |
| # Initialize distributed environment for single GPU if not already initialized | |
| if not torch.distributed.is_initialized(): | |
| import random | |
| # Set environment variables for single-process mode | |
| # Use a random port to avoid conflicts | |
| port = random.randint(29500, 29600) | |
| os.environ['MASTER_ADDR'] = 'localhost' | |
| os.environ['MASTER_PORT'] = str(port) | |
| os.environ['RANK'] = '0' | |
| os.environ['LOCAL_RANK'] = '0' | |
| os.environ['WORLD_SIZE'] = '1' | |
| # Initialize process group | |
| try: | |
| torch.distributed.init_process_group( | |
| backend='nccl', | |
| init_method='env://', | |
| world_size=1, | |
| rank=0 | |
| ) | |
| print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}") | |
| except RuntimeError as e: | |
| if "address already in use" in str(e): | |
| print(f"Port {port} in use, trying again...") | |
| # Try a different port | |
| port = random.randint(29600, 29700) | |
| os.environ['MASTER_PORT'] = str(port) | |
| torch.distributed.init_process_group( | |
| backend='nccl', | |
| init_method='env://', | |
| world_size=1, | |
| rank=0 | |
| ) | |
| print(f"Initialized single-GPU distributed environment for DeepSpeed on port {port}") | |
| else: | |
| raise | |
| # Use DeepSpeed inference API instead of initialize | |
| # This doesn't require an optimizer | |
| with open(self.deepspeed_config) as f: | |
| ds_config = json.load(f) | |
| model_engine = deepspeed.init_inference( | |
| model=model, | |
| mp_size=1, # model parallel size | |
| dtype=torch.float32, # Use float32 for compatibility | |
| replace_with_kernel_inject=False, # Don't replace with DeepSpeed kernels for custom models | |
| ) | |
| print("DeepSpeed Inference initialized successfully") | |
| return model_engine | |
| def _load_checkpoint_file(self, checkpoint_path: str) -> dict: | |
| """ | |
| Load checkpoint file and extract state dict. | |
| Supports: sharded safetensors, single safetensors, .bin/.pt files | |
| Args: | |
| checkpoint_path: Path to checkpoint file or index.json | |
| Returns: | |
| state_dict: model state dictionary | |
| """ | |
| # Check if it's a sharded safetensors (index.json file) | |
| if checkpoint_path.endswith('.index.json'): | |
| print(f"Loading sharded safetensors from index: {checkpoint_path}") | |
| with open(checkpoint_path, 'r') as f: | |
| index = json.load(f) | |
| # Get the directory containing the shards | |
| shard_dir = os.path.dirname(checkpoint_path) | |
| # Get unique shard files | |
| shard_files = sorted(set(index['weight_map'].values())) | |
| print(f"Loading {len(shard_files)} shard files in parallel...") | |
| # Load shards in parallel using ThreadPoolExecutor | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| def load_shard(shard_file): | |
| shard_path = os.path.join(shard_dir, shard_file) | |
| return shard_file, load_safetensors(shard_path) | |
| state_dict = {} | |
| with ThreadPoolExecutor(max_workers=len(shard_files)) as executor: | |
| futures = {executor.submit(load_shard, sf): sf for sf in shard_files} | |
| for future in as_completed(futures): | |
| shard_file, shard_dict = future.result() | |
| print(f" Loaded {shard_file}") | |
| state_dict.update(shard_dict) | |
| print(f"Loaded {len(state_dict)} tensors from sharded safetensors") | |
| return state_dict | |
| # Check if it's a single safetensors file | |
| if checkpoint_path.endswith('.safetensors'): | |
| print(f"Loading safetensors: {checkpoint_path}") | |
| state_dict = load_safetensors(checkpoint_path) | |
| return state_dict | |
| # Check if it's a directory containing checkpoint files | |
| if os.path.isdir(checkpoint_path): | |
| # Look for index.json first (sharded safetensors) | |
| index_path = os.path.join(checkpoint_path, 'model.safetensors.index.json') | |
| if os.path.exists(index_path): | |
| return self._load_checkpoint_file(index_path) | |
| # Look for common checkpoint filenames | |
| possible_files = [ | |
| 'model.safetensors', | |
| 'model.pt', 'model.pth', 'model.bin', | |
| 'checkpoint.pt', 'checkpoint.pth', | |
| 'pytorch_model.bin', 'model_state_dict.pt' | |
| ] | |
| checkpoint_file = None | |
| for filename in possible_files: | |
| full_path = os.path.join(checkpoint_path, filename) | |
| if os.path.exists(full_path): | |
| checkpoint_file = full_path | |
| print(f"Found checkpoint file: {filename}") | |
| break | |
| if checkpoint_file is None: | |
| import glob | |
| # Try safetensors first | |
| st_files = glob.glob(os.path.join(checkpoint_path, "*.safetensors")) | |
| if st_files: | |
| checkpoint_file = st_files[0] | |
| else: | |
| pt_files = glob.glob(os.path.join(checkpoint_path, "*.pt")) + \ | |
| glob.glob(os.path.join(checkpoint_path, "*.pth")) + \ | |
| glob.glob(os.path.join(checkpoint_path, "*.bin")) | |
| if pt_files: | |
| checkpoint_file = pt_files[0] | |
| else: | |
| raise ValueError(f"No checkpoint files found in directory: {checkpoint_path}") | |
| print(f"Found checkpoint file: {os.path.basename(checkpoint_file)}") | |
| checkpoint_path = checkpoint_file | |
| # Recursively call to handle the found file | |
| return self._load_checkpoint_file(checkpoint_path) | |
| # Load .bin or .pt checkpoint | |
| print(f"Loading checkpoint file: {checkpoint_path}") | |
| checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
| # Handle different checkpoint formats | |
| if isinstance(checkpoint, dict): | |
| if 'model' in checkpoint: | |
| state_dict = checkpoint['model'] | |
| elif 'model_state_dict' in checkpoint: | |
| state_dict = checkpoint['model_state_dict'] | |
| elif 'state_dict' in checkpoint: | |
| state_dict = checkpoint['state_dict'] | |
| else: | |
| state_dict = checkpoint | |
| if 'epoch' in checkpoint: | |
| print(f"Checkpoint from epoch: {checkpoint['epoch']}") | |
| if 'global_step' in checkpoint: | |
| print(f"Checkpoint from step: {checkpoint['global_step']}") | |
| if 'loss' in checkpoint: | |
| print(f"Checkpoint loss: {checkpoint['loss']:.4f}") | |
| else: | |
| state_dict = checkpoint | |
| # Remove 'module.' prefix if present | |
| if any(key.startswith('module.') for key in state_dict.keys()): | |
| state_dict = {key.replace('module.', ''): value | |
| for key, value in state_dict.items()} | |
| print("Removed 'module.' prefix from state dict keys") | |
| return state_dict | |
| def text_to_cond_image( | |
| self, | |
| text: str, | |
| img_size: int = 128, | |
| font_scale: float = 0.8, | |
| font_path: Optional[str] = None, | |
| fixed_chars: int = 7 | |
| ) -> Image.Image: | |
| """ | |
| Convert text to condition image - always creates image for fixed_chars characters | |
| Text is arranged from top to bottom. | |
| Args: | |
| text: Chinese text to convert (must be <= fixed_chars characters) | |
| img_size: size of each character block (default 128) | |
| font_scale: scale of font relative to image size (default 0.8) | |
| font_path: path to font file | |
| fixed_chars: fixed number of character slots (default 7) | |
| Returns: | |
| PIL Image with text rendered (always fixed_chars * img_size height) | |
| """ | |
| if len(text) > fixed_chars: | |
| raise ValueError(f"Text must be at most {fixed_chars} characters, got {len(text)}") | |
| if font_path is None: | |
| font_path = self.font_path | |
| # Create font - font size is scaled down from img_size | |
| font_size_scaled = int(font_scale * img_size) | |
| font = ImageFont.truetype(font_path, font_size_scaled) | |
| # Calculate image dimensions - always fixed_chars height | |
| img_width = img_size | |
| img_height = img_size * fixed_chars # Fixed height for 7 characters | |
| # Create white background image | |
| cond_img = Image.new("RGB", (img_width, img_height), (255, 255, 255)) | |
| cond_draw = ImageDraw.Draw(cond_img) | |
| # Draw each character from top to bottom | |
| # Note: font_size for positioning should be img_size, not the scaled font size | |
| for i, char in enumerate(text): | |
| font_space = font_size_scaled * (1 - font_scale) // 2 | |
| # Position based on img_size blocks, not scaled font size | |
| font_position = (font_space, img_size * i + font_space) | |
| cond_draw.text(font_position, char, font=font, fill=(0, 0, 0)) | |
| return cond_img | |
| def build_prompt( | |
| self, | |
| font_style: str = "楷", | |
| author: str = None, | |
| is_traditional: bool = True, | |
| ) -> str: | |
| """ | |
| Build prompt for generation following dataset.py logic | |
| Args: | |
| font_style: font style (楷/草/行) | |
| author: author name (Chinese or None for synthetic) | |
| is_traditional: whether generating traditional calligraphy | |
| Returns: | |
| formatted prompt string | |
| """ | |
| # Validate font style | |
| if font_style not in self.font_style_des: | |
| raise ValueError(f"Font style must be one of: {list(self.font_style_des.keys())}") | |
| # Convert font style to pinyin | |
| font_style_pinyin = convert_to_pinyin(font_style) | |
| # Build prompt based on traditional or synthetic | |
| if is_traditional and author and author in self.author_style: | |
| # Traditional calligraphy with specific author | |
| prompt = f"Traditional Chinese calligraphy works, background: black, font: {font_style_pinyin}, " | |
| prompt += self.font_style_des[font_style] | |
| author_info = self.author_style[author] | |
| prompt += f" author: {author_info}" | |
| else: | |
| # Synthetic calligraphy | |
| prompt = f"Synthetic calligraphy data, background: black, font: {font_style_pinyin}, " | |
| prompt += self.font_style_des[font_style] | |
| return prompt | |
| def generate( | |
| self, | |
| text: str, | |
| font_style: str = "楷", | |
| author: str = None, | |
| width: int = 128, | |
| height: int = None, # Fixed to 7 characters height | |
| num_steps: int = 50, | |
| guidance: float = 3.5, | |
| seed: int = None, | |
| is_traditional: bool = None, | |
| save_path: Optional[str] = None | |
| ) -> tuple[Image.Image, Image.Image]: | |
| """ | |
| Generate calligraphy image from text | |
| Args: | |
| text: Chinese text to generate (1-7 characters) | |
| font_style: font style (楷/草/行) | |
| author: author/calligrapher name from the style list | |
| width: image width (default 128) | |
| height: image height (fixed to 7 * width) | |
| num_steps: number of denoising steps | |
| guidance: guidance scale | |
| seed: random seed for generation | |
| is_traditional: whether generating traditional calligraphy (auto-determined if None) | |
| save_path: optional path to save the generated image | |
| Returns: | |
| tuple of (generated_image, condition_image) | |
| """ | |
| # Fixed number of characters | |
| FIXED_CHARS = 7 | |
| # Validate text - must have 1-7 characters | |
| if len(text) < 1: | |
| raise ValueError(f"Text must have at least 1 character, got empty string") | |
| if len(text) > FIXED_CHARS: | |
| raise ValueError(f"Text must have at most {FIXED_CHARS} characters, got {len(text)}") | |
| if seed is None: | |
| seed = torch.randint(0, 2**32, (1,)).item() | |
| # Fixed height for 7 characters | |
| num_chars = len(text) | |
| height = width * FIXED_CHARS # Always 7 characters height | |
| # Auto-determine traditional vs synthetic | |
| if is_traditional is None: | |
| is_traditional = author is not None and author in self.author_style | |
| # Generate condition image (fixed size for 7 characters) | |
| cond_img = self.text_to_cond_image(text, img_size=width, fixed_chars=FIXED_CHARS) | |
| # Build prompt | |
| prompt = self.build_prompt( | |
| font_style=font_style, | |
| author=author, | |
| is_traditional=is_traditional, | |
| ) | |
| print(f"Generating with prompt: {prompt}") | |
| print(f"Text: {text} ({num_chars} chars), Seed: {seed}") | |
| # Generate image | |
| result_img, recognized_text = self.sampler( | |
| prompt=prompt, | |
| width=width, | |
| height=height, | |
| num_steps=num_steps, | |
| controlnet_image=cond_img, | |
| is_generation=True, | |
| cond_text=text, | |
| required_chars=FIXED_CHARS, # Always 7 characters | |
| seed=seed | |
| ) | |
| # Crop to actual text length if less than FIXED_CHARS | |
| if num_chars < FIXED_CHARS: | |
| actual_height = width * num_chars | |
| # Crop result image (top portion only) | |
| result_img = result_img.crop((0, 0, width, actual_height)) | |
| # Crop condition image as well | |
| cond_img = cond_img.crop((0, 0, width, actual_height)) | |
| # Save if path provided | |
| if save_path: | |
| os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) | |
| result_img.save(save_path) | |
| print(f"Image saved to {save_path}") | |
| return result_img, cond_img | |
| def batch_generate( | |
| self, | |
| texts: List[str], | |
| font_styles: Optional[List[str]] = None, | |
| authors: Optional[List[str]] = None, | |
| output_dir: str = "./outputs", | |
| **kwargs | |
| ) -> List[tuple[Image.Image, Image.Image]]: | |
| """ | |
| Batch generate calligraphy images | |
| Args: | |
| texts: list of texts to generate (1-7 characters each) | |
| font_styles: list of font styles (if None, use default) | |
| authors: list of authors (if None, use synthetic) | |
| output_dir: directory to save outputs | |
| **kwargs: additional arguments for generate() | |
| Returns: | |
| list of (generated_image, condition_image) tuples | |
| """ | |
| os.makedirs(output_dir, exist_ok=True) | |
| results = [] | |
| # Default styles and authors if not provided | |
| if font_styles is None: | |
| font_styles = ["楷"] * len(texts) | |
| if authors is None: | |
| authors = [None] * len(texts) | |
| for i, (text, font, author) in enumerate(zip(texts, font_styles, authors)): | |
| # Clean author name for filename | |
| author_name = author if author else "synthetic" | |
| if author and author in self.author_style: | |
| author_name = convert_to_pinyin(author) | |
| save_path = os.path.join( | |
| output_dir, | |
| f"{text}_{font}_{author_name}_{i}.png" | |
| ) | |
| result_img, cond_img = self.generate( | |
| text=text, | |
| font_style=font, | |
| author=author, | |
| save_path=save_path, | |
| **kwargs | |
| ) | |
| results.append((result_img, cond_img)) | |
| return results | |
| def get_available_authors(self) -> List[str]: | |
| """Get list of available author styles""" | |
| return list(self.author_style.keys()) | |
| def get_available_fonts(self) -> List[str]: | |
| """Get list of available font styles""" | |
| return list(self.font_style_des.keys()) | |
| # Hugging Face Pipeline wrapper | |
class FluxCalligraphyPipeline:
    """Hugging Face style callable pipeline wrapping ``CalligraphyGenerator``.

    Accepts a single text or a list of texts and returns only the generated
    image(s); the condition images produced by the generator are discarded.
    """

    def __init__(
        self,
        model_name: str = "flux-dev",
        device: str = "cuda",
        checkpoint_path: Optional[str] = None,
        **kwargs
    ):
        """Build the underlying ``CalligraphyGenerator``; extra kwargs are forwarded to it."""
        self.generator = CalligraphyGenerator(
            model_name=model_name,
            device=device,
            checkpoint_path=checkpoint_path,
            **kwargs
        )

    def __call__(
        self,
        text: Union[str, List[str]],
        font_style: Union[str, List[str]] = "楷",
        author: Union[str, List[str]] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.5,
        generator: Optional[torch.Generator] = None,
        **kwargs
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate calligraphy images.

        Args:
            text: text or list of texts to generate (1-7 characters each)
            font_style: font style (楷/草/行), or one style per text in batch mode
            author: author name from the style list, or one per text in batch mode
            num_inference_steps: number of denoising steps
            guidance_scale: guidance scale for generation
            generator: torch generator for reproducibility; only its
                ``initial_seed()`` is used, so every item in a batch (and
                every call reusing this generator) gets the same seed
            **kwargs: additional arguments for ``CalligraphyGenerator.generate``

        Returns:
            a single image for a ``str`` text, otherwise a list of images

        Raises:
            ValueError: in batch mode, if a list ``font_style`` or ``author``
                does not have exactly one entry per text.
        """
        # The seed is loop-invariant: it depends only on the generator's initial seed.
        seed = generator.initial_seed() if generator is not None else None

        # Handle single text
        if isinstance(text, str):
            return self._generate_one(
                text, font_style, author, num_inference_steps, guidance_scale, seed, kwargs
            )

        # Handle batch
        texts = list(text)
        font_styles = self._broadcast(font_style, len(texts), "font_style")
        authors = self._broadcast(author, len(texts), "author")
        return [
            self._generate_one(t, f, a, num_inference_steps, guidance_scale, seed, kwargs)
            for t, f, a in zip(texts, font_styles, authors)
        ]

    @staticmethod
    def _broadcast(value, count: int, name: str) -> list:
        """Expand a scalar (str or None) to ``count`` copies, or validate a per-text list."""
        if value is None or isinstance(value, str):
            return [value] * count
        values = list(value)
        # zip() would otherwise silently drop texts on a length mismatch
        if len(values) != count:
            raise ValueError(
                f"{name} must have one entry per text: got {len(values)} for {count} texts"
            )
        return values

    def _generate_one(self, text, font_style, author, num_steps, guidance, seed, extra):
        """Run the wrapped generator for one text and return only the generated image."""
        result, _ = self.generator.generate(
            text=text,
            font_style=font_style,
            author=author,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
            **extra
        )
        return result
if __name__ == "__main__":
    # Example command-line usage
    import argparse
    import sys

    MAX_CHARS = 7  # Must match the fixed canvas size used by CalligraphyGenerator.generate()
    MAX_LISTED_AUTHORS = 20

    parser = argparse.ArgumentParser(description="Generate Chinese calligraphy")
    parser.add_argument("--text", type=str, default="暴富且平安", help="Text to generate (1-7 characters)")
    parser.add_argument("--font", type=str, default="楷", help="Font style (楷/草/行)")
    parser.add_argument("--author", type=str, default=None, help="Author/calligrapher name")
    parser.add_argument("--steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--guidance", type=float, default=3.5, help="Guidance scale")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--output", type=str, default="output.png", help="Output path")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
    parser.add_argument("--list-authors", action="store_true", help="List available authors")
    parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
    parser.add_argument("--float8", action="store_true", help="Use Float8 quantization (torchao) for faster inference")
    parser.add_argument("--compile", action="store_true", help="Use torch.compile for optimized inference")
    parser.add_argument("--compile-mode", type=str, default="max-autotune",
                        choices=["reduce-overhead", "max-autotune", "default"],
                        help="torch.compile mode")
    args = parser.parse_args()

    listing_only = args.list_authors or args.list_fonts

    # Validate text before the slow model load - must have 1-7 characters
    if not listing_only:
        if len(args.text) < 1:
            print("Error: Text must have at least 1 character", file=sys.stderr)
            sys.exit(1)
        if len(args.text) > MAX_CHARS:
            print(f"Error: Text must have at most {MAX_CHARS} characters, got {len(args.text)}",
                  file=sys.stderr)
            sys.exit(1)

    # Initialize generator
    generator = CalligraphyGenerator(
        model_name="flux-dev",
        device=args.device,
        checkpoint_path=args.checkpoint,
    )

    # List available options; done before quantization/compilation, which listing does not need
    if args.list_authors:
        authors = generator.get_available_authors()
        print("Available authors:")
        for author in authors[:MAX_LISTED_AUTHORS]:
            print(f" - {author}")
        # Only mention the remainder when something was actually cut off
        if len(authors) > MAX_LISTED_AUTHORS:
            print(f" ... and {len(authors) - MAX_LISTED_AUTHORS} more")
        sys.exit(0)

    if args.list_fonts:
        print("Available font styles:")
        for font in generator.get_available_fonts():
            print(f" - {font}: {generator.font_style_des[font]}")
        sys.exit(0)

    # Apply optimizations if requested (CLI mode)
    if args.float8 or args.compile:
        import torch._inductor.config as inductor_config

        # Inductor configs from FLUX-Kontext-fp8
        inductor_config.conv_1x1_as_mm = True
        inductor_config.coordinate_descent_tuning = True
        inductor_config.coordinate_descent_check_all_directions = True
        inductor_config.max_autotune = True

        if args.float8:
            # Imported lazily so that --compile alone does not require torchao
            from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

            print("Applying Float8 quantization...")
            quantize_(generator.model, Float8DynamicActivationFloat8WeightConfig())
            print("✓ Float8 quantization complete!")

        if args.compile:
            print(f"Applying torch.compile (mode={args.compile_mode})...")
            generator.model = torch.compile(
                generator.model,
                mode=args.compile_mode,
                backend="inductor",
                dynamic=True,
            )
            print("✓ torch.compile applied!")

    # Generate
    result_img, cond_img = generator.generate(
        text=args.text,
        font_style=args.font,
        author=args.author,
        num_steps=args.steps,
        guidance=args.guidance,
        seed=args.seed,
        save_path=args.output,
    )
    print(f"Generation complete! Saved to {args.output}")