# Initial commit with Xet-tracked image assets (fcfea15)
import os
import glob
import torch
import torch.distributed as dist
from modules.models.bucket import BucketGroup
from modules.models.mmdit.dit import Transformer3DModel
from modules.models.mmdit.text_encoder import load_text_encoder
from modules.models.mmdit.vae import WanxVAE
from modules.models.pipeline import Pipeline
from modules.models.scheduler import FlowMatchDiscreteScheduler
from modules.utils.fsdp_load import maybe_load_fsdp_model, pt_weights_iterator, safetensors_weights_iterator
from modules.utils.logging import get_logger
from modules.utils.constants import PRECISION_TO_TYPE
from modules.utils.utils import build_from_config
def load_pipeline(cfg, dit, device: torch.device):
    """Assemble the full generation Pipeline around an already-built DiT.

    Builds the VAE, tokenizer/text encoder, and scheduler from their
    respective config sections, then wraps them (plus *dit*) in a
    ``Pipeline`` moved to *device*.

    Args:
        cfg: experiment config providing ``*_arch_config`` sections and
            precision strings.
        dit: the transformer (DiT) module, already constructed/loaded.
        device: target device for every sub-module.

    Returns:
        The assembled ``Pipeline`` on *device*.
    """
    # VAE — built at its own precision, directly on the target device.
    vae_kwargs = {
        'torch_dtype': PRECISION_TO_TYPE[cfg.vae_precision],
        "device": device,
    }
    vae = build_from_config(cfg.vae_arch_config, **vae_kwargs)
    if getattr(cfg.vae_arch_config, "enable_feature_caching", False):
        vae.enable_feature_caching()

    # Text encoder — the factory returns a (tokenizer, encoder) pair.
    text_kwargs = {
        'torch_dtype': PRECISION_TO_TYPE[cfg.text_encoder_precision],
        "device": device,
    }
    tokenizer, text_encoder = build_from_config(
        cfg.text_encoder_arch_config, **text_kwargs)

    # Scheduler needs no device/dtype kwargs.
    scheduler = build_from_config(cfg.scheduler_arch_config)

    return Pipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=dit,
        scheduler=scheduler,
        args=cfg,
    ).to(device)
def _load_checkpoint_state_dict(cfg, logger):
    """Read the DiT checkpoint named by ``cfg.dit_ckpt`` into a plain dict.

    Returns:
        The flat ``{name: tensor}`` state dict, or ``None`` when no
        checkpoint is configured.

    Raises:
        ValueError: if ``cfg.dit_ckpt_type`` is unknown, or the safetensors
            directory contains no ``*.safetensors`` files.
    """
    if cfg.dit_ckpt is None:
        return None
    logger.info(
        f"Loading model from: {cfg.dit_ckpt}, type: {cfg.dit_ckpt_type}")
    if cfg.dit_ckpt_type == "safetensor":
        # A safetensors checkpoint is a directory of *.safetensors shards.
        safetensors_files = glob.glob(
            os.path.join(str(cfg.dit_ckpt), "*.safetensors"))
        if not safetensors_files:
            raise ValueError(
                f"No safetensors files found in {cfg.dit_ckpt}")
        return dict(safetensors_weights_iterator(safetensors_files))
    if cfg.dit_ckpt_type == "pt":
        state_dict = dict(pt_weights_iterator([cfg.dit_ckpt]))
        # Some .pt checkpoints nest the weights under a top-level "model" key.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        return state_dict
    raise ValueError(
        f"Unknown dit_ckpt_type: {cfg.dit_ckpt_type}, must be 'safetensor' or 'pt'")


def _remap_state_dict(model, state_dict, logger):
    """Adapt *state_dict* to *model*, zero-inflating ``img_in.weight``.

    When the checkpoint's ``img_in.weight`` has fewer input channels than
    the model expects, the extra channels are zero-initialized and the
    original channels copied in.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key == "img_in.weight" and model.img_in.weight.shape != value.shape:
            logger.info(
                f"Inflate {key} from {value.shape} to {model.img_in.weight.shape}")
            inflated = value.new_zeros(model.img_in.weight.shape)
            # Copies along dim 1 only; assumes a 5-D weight where just the
            # input-channel dim grew — TODO confirm against checkpoint layout.
            inflated[:, :value.shape[1], :, :, :] = value
            value = inflated
        remapped[key] = value
    return remapped


def load_dit(cfg, device: torch.device) -> torch.nn.Module:
    """Load DiT model with FSDP support.

    Reads the checkpoint (if any), builds the transformer from config,
    loads weights directly in single-process mode, wraps the model via
    ``maybe_load_fsdp_model``, and normalizes parameter dtype.

    Args:
        cfg: experiment config (checkpoint path/type, precision, arch
            config, and FSDP settings).
        device: target device for the model.

    Returns:
        The model in ``eval()`` mode.

    Raises:
        ValueError: on an unknown checkpoint type or an empty safetensors
            directory.
    """
    logger = get_logger()
    state_dict = _load_checkpoint_state_dict(cfg, logger)

    dtype = PRECISION_TO_TYPE[cfg.dit_precision]
    model = build_from_config(
        cfg.dit_arch_config, dtype=dtype, device=device, args=cfg)

    if not dist.is_initialized() or dist.get_world_size() == 1:
        # Debug mode: single process — load weights here, before wrapping.
        model.to(device=device)
        if state_dict is not None:
            model.load_state_dict(
                _remap_state_dict(model, state_dict, logger), strict=True)
    # NOTE(review): in the distributed path the state_dict read above is
    # never applied in this function — presumably maybe_load_fsdp_model (or a
    # later resume step) loads weights shard-wise; confirm before relying on it.

    model = maybe_load_fsdp_model(
        model=model,
        hsdp_shard_dim=cfg.hsdp_shard_dim,
        reshard_after_forward=cfg.reshard_after_forward,
        param_dtype=dtype,
        reduce_dtype=torch.float32,
        output_dtype=None,
        cpu_offload=cfg.cpu_offload,
        fsdp_inference=cfg.use_fsdp_inference,
        training_mode=cfg.training_mode,
        pin_cpu_memory=cfg.pin_cpu_memory,
    )

    # Log model info
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Instantiate model with {total_params / 1e9:.2f}B parameters")

    # Ensure consistent dtype across all parameters (mixed dtypes can appear
    # when the checkpoint and the factory disagree).
    param_dtypes = {param.dtype for param in model.parameters()}
    if len(param_dtypes) > 1:
        logger.warning(
            f"Model has mixed dtypes: {param_dtypes}. Converting to {dtype}")
        model = model.to(dtype)

    return model.eval()
__all__ = [
"BucketGroup",
"FlowMatchDiscreteScheduler",
"Pipeline",
"Transformer3DModel",
"WanxVAE",
"load_pipeline",
"load_text_encoder",
]