import os
import glob
import torch
import torch.distributed as dist
from modules.models.bucket import BucketGroup
from modules.models.mmdit.dit import Transformer3DModel
from modules.models.mmdit.text_encoder import load_text_encoder
from modules.models.mmdit.vae import WanxVAE
from modules.models.pipeline import Pipeline
from modules.models.scheduler import FlowMatchDiscreteScheduler
from modules.utils.fsdp_load import maybe_load_fsdp_model, pt_weights_iterator, safetensors_weights_iterator
from modules.utils.logging import get_logger
from modules.utils.constants import PRECISION_TO_TYPE
from modules.utils.utils import build_from_config
def load_pipeline(cfg, dit, device: torch.device):
    """Assemble the full generation Pipeline around an already-built DiT.

    Builds the VAE, tokenizer/text-encoder pair, and scheduler from their
    respective config sections, wires them together with the supplied
    transformer into a Pipeline, and moves the result to `device`.
    """
    # VAE, instantiated at the configured precision directly on the target device.
    vae = build_from_config(
        cfg.vae_arch_config,
        torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision],
        device=device,
    )
    if getattr(cfg.vae_arch_config, "enable_feature_caching", False):
        vae.enable_feature_caching()

    # Text encoder — the factory returns both tokenizer and encoder.
    tokenizer, text_encoder = build_from_config(
        cfg.text_encoder_arch_config,
        torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision],
        device=device,
    )

    # The noise scheduler takes no extra factory kwargs.
    scheduler = build_from_config(cfg.scheduler_arch_config)

    return Pipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=dit,
        scheduler=scheduler,
        args=cfg,
    ).to(device)
def load_dit(cfg, device: torch.device) -> torch.nn.Module:
    """Build the DiT transformer, optionally load a checkpoint, and wrap it for FSDP.

    Args:
        cfg: Config object providing ``dit_ckpt`` / ``dit_ckpt_type``
            (``"safetensor"`` directory or a single ``"pt"`` file),
            ``dit_precision``, ``dit_arch_config``, and the FSDP knobs
            forwarded to ``maybe_load_fsdp_model``.
        device: Target device for model construction.

    Returns:
        The (possibly FSDP-wrapped) model, set to eval mode.

    Raises:
        ValueError: If no safetensors files are found, or ``dit_ckpt_type``
            is neither ``"safetensor"`` nor ``"pt"``.
    """
    logger = get_logger()
    # Checkpoint weights, if any; stays None when cfg.dit_ckpt is unset.
    state_dict = None
    if cfg.dit_ckpt is not None:
        logger.info(
            f"Loading model from: {cfg.dit_ckpt}, type: {cfg.dit_ckpt_type}")
        if cfg.dit_ckpt_type == "safetensor":
            # "safetensor" mode: cfg.dit_ckpt is a directory of shard files.
            safetensors_files = glob.glob(
                os.path.join(str(cfg.dit_ckpt), "*.safetensors"))
            if not safetensors_files:
                raise ValueError(
                    f"No safetensors files found in {cfg.dit_ckpt}")
            state_dict = dict(
                safetensors_weights_iterator(safetensors_files))
        elif cfg.dit_ckpt_type == "pt":
            # "pt" mode: cfg.dit_ckpt is one torch-saved file.
            pt_files = [cfg.dit_ckpt]
            state_dict = dict(pt_weights_iterator(pt_files))
            # Unwrap trainer-style checkpoints that nest weights under "model".
            if "model" in state_dict:
                state_dict = state_dict["model"]
        else:
            raise ValueError(
                f"Unknown dit_ckpt_type: {cfg.dit_ckpt_type}, must be 'safetensor' or 'pt'")
    dtype = PRECISION_TO_TYPE[cfg.dit_precision]
    model_kwargs = {'dtype': dtype, 'device': device, 'args': cfg}
    model = build_from_config(cfg.dit_arch_config, **model_kwargs)
    # Single-process path (no dist, or world size 1): load weights eagerly
    # before any FSDP wrapping.
    # NOTE(review): in the multi-rank path, `state_dict` read above is never
    # applied here — presumably maybe_load_fsdp_model handles sharded
    # loading; confirm, otherwise distributed runs start from random init.
    if not dist.is_initialized() or dist.get_world_size() == 1:
        # Debug mode
        model.to(device=device)
        if state_dict is not None:
            # Copy the state dict, inflating img_in.weight when the
            # checkpoint has fewer input channels than the model: the extra
            # channels are zero-initialized, existing ones copied verbatim.
            # (Only dim 1 may differ; other dims must already match.)
            load_state_dict = {}
            for k, v in state_dict.items():
                if k == "img_in.weight" and model.img_in.weight.shape != v.shape:
                    logger.info(
                        f"Inflate {k} from {v.shape} to {model.img_in.weight.shape}")
                    v_new = v.new_zeros(model.img_in.weight.shape)
                    v_new[:, :v.shape[1], :, :, :] = v
                    v = v_new
                load_state_dict[k] = v
            model.load_state_dict(load_state_dict, strict=True)
    # Wrap with FSDP (or return unchanged — decided inside the helper).
    # Gradients reduce in fp32; params use the configured precision.
    model = maybe_load_fsdp_model(
        model=model,
        hsdp_shard_dim=cfg.hsdp_shard_dim,
        reshard_after_forward=cfg.reshard_after_forward,
        param_dtype=dtype,
        reduce_dtype=torch.float32,
        output_dtype=None,
        cpu_offload=cfg.cpu_offload,
        fsdp_inference=cfg.use_fsdp_inference,
        training_mode=cfg.training_mode,
        pin_cpu_memory=cfg.pin_cpu_memory,
    )
    # Log model info
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Instantiate model with {total_params / 1e9:.2f}B parameters")
    # Ensure consistent dtype across all parameters; mixed dtypes can appear
    # when a checkpoint was saved at a different precision.
    param_dtypes = {param.dtype for param in model.parameters()}
    if len(param_dtypes) > 1:
        logger.warning(
            f"Model has mixed dtypes: {param_dtypes}. Converting to {dtype}")
        model = model.to(dtype)
    return model.eval()
__all__ = [
"BucketGroup",
"FlowMatchDiscreteScheduler",
"Pipeline",
"Transformer3DModel",
"WanxVAE",
"load_pipeline",
"load_text_encoder",
]