| """Model loading utilities for Z-Image components."""
|
|
|
| import json
|
| import os
|
| from pathlib import Path
|
| import sys
|
| from typing import Optional, Union
|
|
|
| from loguru import logger
|
| from safetensors.torch import load_file
|
| import torch
|
| from transformers import AutoModel, AutoTokenizer
|
|
|
| from config import (
|
| DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS,
|
| DEFAULT_SCHEDULER_SHIFT,
|
| DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING,
|
| DEFAULT_TRANSFORMER_CAP_FEAT_DIM,
|
| DEFAULT_TRANSFORMER_DIM,
|
| DEFAULT_TRANSFORMER_F_PATCH_SIZE,
|
| DEFAULT_TRANSFORMER_IN_CHANNELS,
|
| DEFAULT_TRANSFORMER_N_HEADS,
|
| DEFAULT_TRANSFORMER_N_KV_HEADS,
|
| DEFAULT_TRANSFORMER_N_LAYERS,
|
| DEFAULT_TRANSFORMER_N_REFINER_LAYERS,
|
| DEFAULT_TRANSFORMER_NORM_EPS,
|
| DEFAULT_TRANSFORMER_PATCH_SIZE,
|
| DEFAULT_TRANSFORMER_QK_NORM,
|
| DEFAULT_TRANSFORMER_T_SCALE,
|
| DEFAULT_VAE_IN_CHANNELS,
|
| DEFAULT_VAE_LATENT_CHANNELS,
|
| DEFAULT_VAE_NORM_NUM_GROUPS,
|
| DEFAULT_VAE_OUT_CHANNELS,
|
| DEFAULT_VAE_SCALING_FACTOR,
|
| ROPE_AXES_DIMS,
|
| ROPE_AXES_LENS,
|
| ROPE_THETA,
|
| )
|
| from zimage.autoencoder import AutoencoderKL as LocalAutoencoderKL
|
| from zimage.scheduler import FlowMatchEulerDiscreteScheduler
|
|
|
# Feature flag for an optional diffusers-backed code path; hard-coded off here.
# NOTE(review): nothing in this file reads this flag — confirm it is consumed
# elsewhere (or remove it).
DIFFUSERS_AVAILABLE = False
|
|
|
|
|
def load_config(config_path: str) -> dict:
    """Load a JSON configuration file and return its contents as a dict.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The parsed JSON object.

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: model configs are UTF-8, and the platform default
    # encoding is not guaranteed to be UTF-8 (e.g. on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
def load_sharded_safetensors(weight_dir: Path, device: str = "cuda", dtype: Optional[torch.dtype] = None) -> dict:
    """Load (possibly sharded) safetensors weights from a directory.

    If a ``*.safetensors.index.json`` file is present, every shard referenced
    by its ``weight_map`` is loaded and merged; otherwise the first
    ``*.safetensors`` file (in sorted order) is loaded directly.

    Args:
        weight_dir: Directory containing the weight files.
        device: Device to load tensors onto (passed to ``load_file``).
        dtype: Optional dtype to cast every tensor to after loading.

    Returns:
        A flat mapping of tensor name -> tensor.

    Raises:
        FileNotFoundError: If no safetensors files exist in ``weight_dir``,
            or if a shard listed in the index is missing.
    """
    weight_dir = Path(weight_dir)
    # sorted() makes file selection deterministic; Path.glob order is
    # filesystem-dependent.
    index_files = sorted(weight_dir.glob("*.safetensors.index.json"))

    state_dict: dict = {}
    if index_files:
        # Sharded checkpoint: the index maps each tensor name to its shard.
        with open(index_files[0], "r", encoding="utf-8") as f:
            index = json.load(f)
        weight_map = index.get("weight_map", {})
        shard_files = set(weight_map.values())
        for shard_file in sorted(shard_files):
            shard_path = weight_dir / shard_file
            # Fail with a clear message instead of whatever load_file raises
            # for a missing path.
            if not shard_path.exists():
                raise FileNotFoundError(f"Shard listed in index is missing: {shard_path}")
            state_dict.update(load_file(str(shard_path), device=str(device)))
    else:
        # Single-file checkpoint.
        safetensors_files = sorted(weight_dir.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError(f"No safetensors files found in {weight_dir}")
        state_dict = load_file(str(safetensors_files[0]), device=str(device))

    if dtype is not None:
        # Tensor.to(dtype) already returns the tensor itself when the dtype
        # matches, so no explicit same-dtype check is needed.
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

    return state_dict
|
|
|
|
|
def _load_transformer(model_dir: Path, device: str, dtype: torch.dtype, verbose: bool):
    """Build the DiT transformer on the meta device and load its weights."""
    # The transformer class lives in a sibling Z-Image checkout; make its
    # src directory importable before the deferred import below.
    sys.path.insert(0, str(model_dir.parent.parent / "Z-Image" / "src"))
    from zimage.transformer import ZImageTransformer2DModel

    if verbose:
        logger.info("Loading DiT...")
    transformer_dir = model_dir / "transformer"
    config = load_config(str(transformer_dir / "config.json"))

    # Instantiate on the meta device: no real memory is allocated until the
    # checkpoint tensors are assigned in load_state_dict below.
    with torch.device("meta"):
        transformer = ZImageTransformer2DModel(
            all_patch_size=tuple(config.get("all_patch_size", DEFAULT_TRANSFORMER_PATCH_SIZE)),
            all_f_patch_size=tuple(config.get("all_f_patch_size", DEFAULT_TRANSFORMER_F_PATCH_SIZE)),
            in_channels=config.get("in_channels", DEFAULT_TRANSFORMER_IN_CHANNELS),
            dim=config.get("dim", DEFAULT_TRANSFORMER_DIM),
            n_layers=config.get("n_layers", DEFAULT_TRANSFORMER_N_LAYERS),
            n_refiner_layers=config.get("n_refiner_layers", DEFAULT_TRANSFORMER_N_REFINER_LAYERS),
            n_heads=config.get("n_heads", DEFAULT_TRANSFORMER_N_HEADS),
            n_kv_heads=config.get("n_kv_heads", DEFAULT_TRANSFORMER_N_KV_HEADS),
            norm_eps=config.get("norm_eps", DEFAULT_TRANSFORMER_NORM_EPS),
            qk_norm=config.get("qk_norm", DEFAULT_TRANSFORMER_QK_NORM),
            cap_feat_dim=config.get("cap_feat_dim", DEFAULT_TRANSFORMER_CAP_FEAT_DIM),
            rope_theta=config.get("rope_theta", ROPE_THETA),
            t_scale=config.get("t_scale", DEFAULT_TRANSFORMER_T_SCALE),
            axes_dims=config.get("axes_dims", ROPE_AXES_DIMS),
            axes_lens=config.get("axes_lens", ROPE_AXES_LENS),
        ).to(dtype)

    state_dict = load_sharded_safetensors(transformer_dir, device="cpu", dtype=dtype)
    # assign=True replaces the meta tensors with the loaded ones instead of
    # copying into them (required for meta-initialized modules).
    # NOTE(review): strict=False silently tolerates missing/unexpected keys —
    # confirm the checkpoint is known to match the architecture.
    transformer.load_state_dict(state_dict, strict=False, assign=True)
    del state_dict

    if verbose:
        logger.info("Moving DiT to GPU...")
    transformer = transformer.to(device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    transformer.eval()
    return transformer


def _load_vae(model_dir: Path, device: str, verbose: bool):
    """Build the VAE from its config and load its weights (kept in float32)."""
    if verbose:
        logger.info("Loading VAE...")
    vae_dir = model_dir / "vae"
    vae_config = load_config(str(vae_dir / "config.json"))

    vae = LocalAutoencoderKL(
        in_channels=vae_config.get("in_channels", DEFAULT_VAE_IN_CHANNELS),
        out_channels=vae_config.get("out_channels", DEFAULT_VAE_OUT_CHANNELS),
        down_block_types=tuple(vae_config.get("down_block_types", ("DownEncoderBlock2D",))),
        up_block_types=tuple(vae_config.get("up_block_types", ("UpDecoderBlock2D",))),
        block_out_channels=tuple(vae_config.get("block_out_channels", (64,))),
        layers_per_block=vae_config.get("layers_per_block", 1),
        latent_channels=vae_config.get("latent_channels", DEFAULT_VAE_LATENT_CHANNELS),
        norm_num_groups=vae_config.get("norm_num_groups", DEFAULT_VAE_NORM_NUM_GROUPS),
        scaling_factor=vae_config.get("scaling_factor", DEFAULT_VAE_SCALING_FACTOR),
        shift_factor=vae_config.get("shift_factor", None),
        use_quant_conv=vae_config.get("use_quant_conv", True),
        use_post_quant_conv=vae_config.get("use_post_quant_conv", True),
        mid_block_add_attention=vae_config.get("mid_block_add_attention", True),
    )

    vae_state_dict = load_sharded_safetensors(vae_dir, device="cpu")
    # NOTE(review): strict=False here as well — confirm intentional.
    vae.load_state_dict(vae_state_dict, strict=False)
    del vae_state_dict
    vae.to(device=device, dtype=torch.float32)
    vae.eval()
    # Guarded for CPU-only environments, consistent with the transformer path.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return vae


def _load_text_encoder(model_dir: Path, device: str, dtype: torch.dtype, verbose: bool):
    """Load the text encoder via transformers AutoModel (trust_remote_code)."""
    if verbose:
        logger.info("Loading Text Encoder...")
    text_encoder_dir = model_dir / "text_encoder"
    text_encoder = AutoModel.from_pretrained(
        str(text_encoder_dir),
        dtype=dtype,
        trust_remote_code=True,
    )
    text_encoder.to(device)
    text_encoder.eval()
    return text_encoder


def _load_tokenizer(model_dir: Path, verbose: bool):
    """Load the tokenizer, falling back to the text-encoder directory."""
    # Silence the tokenizers fork/parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if verbose:
        logger.info("Loading Tokenizer...")
    tokenizer_dir = model_dir / "tokenizer"
    text_encoder_dir = model_dir / "text_encoder"
    return AutoTokenizer.from_pretrained(
        str(tokenizer_dir) if tokenizer_dir.exists() else str(text_encoder_dir),
        trust_remote_code=True,
    )


def _load_scheduler(model_dir: Path, verbose: bool):
    """Build the flow-match Euler scheduler from its config file."""
    if verbose:
        logger.info("Loading Scheduler...")
    scheduler_dir = model_dir / "scheduler"
    scheduler_config = load_config(str(scheduler_dir / "scheduler_config.json"))
    return FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=scheduler_config.get("num_train_timesteps", DEFAULT_SCHEDULER_NUM_TRAIN_TIMESTEPS),
        shift=scheduler_config.get("shift", DEFAULT_SCHEDULER_SHIFT),
        use_dynamic_shifting=scheduler_config.get("use_dynamic_shifting", DEFAULT_SCHEDULER_USE_DYNAMIC_SHIFTING),
    )


def load_from_local_dir(
    model_dir: Union[str, Path],
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    verbose: bool = False,
    compile: bool = False,  # keeps the public name for callers; shadows builtins.compile
) -> dict:
    """
    Load all Z-Image components from local directory.

    Args:
        model_dir: Path to model directory
        device: Device to load models on
        dtype: Data type for model weights (the VAE is always kept in float32)
        verbose: Whether to display loading logs
        compile: Whether to compile transformer and vae with torch.compile

    Returns:
        Dictionary containing transformer, vae, text_encoder, tokenizer, and scheduler
    """
    model_dir = Path(model_dir)

    if verbose:
        logger.info(f"Loading Z-Image from: {model_dir}")

    transformer = _load_transformer(model_dir, device, dtype, verbose)
    vae = _load_vae(model_dir, device, verbose)
    text_encoder = _load_text_encoder(model_dir, device, dtype, verbose)
    tokenizer = _load_tokenizer(model_dir, verbose)
    scheduler = _load_scheduler(model_dir, verbose)

    if compile:
        if verbose:
            logger.info("Compiling DiT and VAE...")
        transformer = torch.compile(transformer)
        vae = torch.compile(vae)

    if verbose:
        logger.success("All components loaded successfully")

    return {
        "transformer": transformer,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
    }
|
|
|