"""Model loading utilities for Z-Image components.""" import json import os from pathlib import Path import sys from typing import Optional, Union from loguru import logger from safetensors.torch import load_file import torch from transformers import AutoModel, AutoTokenizer from config import ( DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS, DEFAULT_SCHEDULER_SHIFT, DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING, DEFAULT_TRANSFORMER_CAP_FEAT_DIM, DEFAULT_TRANSFORMER_DIM, DEFAULT_TRANSFORMER_F_PATCH_SIZE, DEFAULT_TRANSFORMER_IN_CHANNELS, DEFAULT_TRANSFORMER_N_HEADS, DEFAULT_TRANSFORMER_N_KV_HEADS, DEFAULT_TRANSFORMER_N_LAYERS, DEFAULT_TRANSFORMER_N_REFINER_LAYERS, DEFAULT_TRANSFORMER_NORM_EPS, DEFAULT_TRANSFORMER_PATCH_SIZE, DEFAULT_TRANSFORMER_QK_NORM, DEFAULT_TRANSFORMER_T_SCALE, DEFAULT_VAE_IN_CHANNELS, DEFAULT_VAE_LATENT_CHANNELS, DEFAULT_VAE_NORM_NUM_GROUPS, DEFAULT_VAE_OUT_CHANNELS, DEFAULT_VAE_SCALING_FACTOR, ROPE_AXES_DIMS, ROPE_AXES_LENS, ROPE_THETA, ) from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL from zimage.scheduler import FlowMatchEulerDiscreteScheduler DIFFUSERS_AVAILABLE = False def load_config(config_path: str) -> dict: with open(config_path, "r") as f: return json.load(f) def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict: """Load sharded safetensors from a directory.""" weight_dir = Path(weight_dir) index_files = list(weight_dir.glob("*.safetensors.index.json")) state_dict = {} if index_files: # Load sharded weights with open(index_files[0], "r") as f: index = json.load(f) weight_map = index.get("weight_map", {}) shard_files = set(weight_map.values()) for shard_file in shard_files: shard_path = weight_dir / shard_file shard_state = load_file(str(shard_path), device=str(device)) state_dict.update(shard_state) else: # Load single safetensors file safetensors_files = list(weight_dir.glob("*.safetensors")) if not safetensors_files: raise FileNotFoundError(f"No safetensors files found in {weight_dir}") state_dict = load_file(str(safetensors_files[0]), device=str(device)) # Cast to target dtype if specified if dtype is not None: state_dict = {k: v.to(dtype) if v.dtype != dtype else v for k, v in state_dict.items()} return state_dict def load_from_local_dir( model_dir: Union[str, Path], device: str = "cuda", dtype: torch.dtype = torch.bfloat16, verbose: bool = False, compile: bool = False, ) -> dict: """ Load all Z-Image components from local directory. Args: model_dir: Path to model directory device: Device to load models on dtype: Data type for model weights verbose: Whether to display loading logs compile: Whether to compile transformer and vae with torch.compile Returns: Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler """ model_dir = Path(model_dir) sys.path.insert(0, str(model_dir.parent.parent / "Z-Image" / "src")) from zimage.transformer import ZImageTransformer2DModel if verbose: logger.info(f"Loading Z-Image from: {model_dir}") # DiT if verbose: logger.info("Loading DiT...") transformer_dir = model_dir / "transformer" config = load_config(str(transformer_dir / "config.json")) with torch.device("meta"): transformer = ZImageTransformer2DModel( all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)), all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)), in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS), dim=config.get("dim", DEFAULT_TRANSFORMER_DIM), n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS), n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS), n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS), n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS), norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS), qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM), cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM), rope_theta=config.get("rope_theta", ROPE_THETA), t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE), axes_dims=config.get("axes_dims", ROPE_AXES_DIMS), axes_lens=config.get("axes_lens", ROPE_AXES_LENS), ).to(dtype) # DiT (weights to CPU then move to GPU to optimize memory) state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype) transformer.load_state_dict(state_dict, strict=False, assign=True) del state_dict if verbose: logger.info("Moving DiT to GPU...") transformer = transformer.to(device) if torch.cuda.is_available(): torch.cuda.empty_cache() transformer.eval() # VAE if verbose: logger.info("Loading VAE...") vae_dir = model_dir / "vae" vae_config = load_config(str(vae_dir / "config.json")) vae = LocalAutoencoderKL( in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS), out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS), down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))), up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))), block_out_channels=tuple(vae_config.get("block_out_channels", (64,))), layers_per_block=vae_config.get("layers_per_block", 1), latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS), norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS), scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR), shift_factor=vae_config.get("shift_factor", None), use_quant_conv=vae_config.get("use_quant_conv", True), use_post_quant_conv=vae_config.get("use_post_quant_conv", True), mid_block_add_attention=vae_config.get("mid_block_add_attention", True), ) # VAE (fp32 for better precision) vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu") vae.load_state_dict(vae_state_dict, strict=False) del vae_state_dict vae.to(device=device, dtype=torch.float32) vae.eval() torch.cuda.empty_cache() # Text Encoder if verbose: logger.info("Loading Text Encoder...") text_encoder_dir = model_dir / "text_encoder" text_encoder = AutoModel.from_pretrained( str(text_encoder_dir), # torch_dtype=dtype, # some version use this dtype=dtype, trust_remote_code=True, ) text_encoder.to(device) text_encoder.eval() # Tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" if verbose: logger.info("Loading Tokenizer...") tokenizer_dir = model_dir / "tokenizer" tokenizer = AutoTokenizer.from_pretrained( str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir), trust_remote_code=True, ) # Scheduler if verbose: logger.info("Loading Scheduler...") scheduler_dir = model_dir / "scheduler" scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json")) scheduler = FlowMatchEulerDiscreteScheduler( num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS), shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT), use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING), ) if compile: if verbose: logger.info("Compiling DiT and VAE...") transformer = torch.compile(transformer) vae = torch.compile(vae) if verbose: logger.success("All components loaded successfully") return { "transformer": transformer, "vae": vae, "text_encoder": text_encoder, "tokenizer": tokenizer, "scheduler": scheduler, }