Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import torch | |
| import torch.distributed as dist | |
| from modules.models.bucket import BucketGroup | |
| from modules.models.mmdit.dit import Transformer3DModel | |
| from modules.models.mmdit.text_encoder import load_text_encoder | |
| from modules.models.mmdit.vae import WanxVAE | |
| from modules.models.pipeline import Pipeline | |
| from modules.models.scheduler import FlowMatchDiscreteScheduler | |
| from modules.utils.fsdp_load import maybe_load_fsdp_model, pt_weights_iterator, safetensors_weights_iterator | |
| from modules.utils.logging import get_logger | |
| from modules.utils.constants import PRECISION_TO_TYPE | |
| from modules.utils.utils import build_from_config | |
def load_pipeline(cfg, dit, device: torch.device):
    """Assemble the generation Pipeline from config: VAE, text encoder, scheduler, DiT.

    Args:
        cfg: Config holding the arch configs and precision fields read below.
        dit: Already-constructed transformer module to plug into the pipeline.
        device: Target device the pipeline is moved to.

    Returns:
        A Pipeline instance placed on ``device``.
    """
    # VAE: built at the configured precision, directly on the target device.
    vae = build_from_config(
        cfg.vae_arch_config,
        torch_dtype=PRECISION_TO_TYPE[cfg.vae_precision],
        device=device,
    )
    if getattr(cfg.vae_arch_config, "enable_feature_caching", False):
        vae.enable_feature_caching()

    # Text encoder: the factory returns a (tokenizer, encoder) pair.
    tokenizer, text_encoder = build_from_config(
        cfg.text_encoder_arch_config,
        torch_dtype=PRECISION_TO_TYPE[cfg.text_encoder_precision],
        device=device,
    )

    # Scheduler takes no device/dtype factory kwargs.
    scheduler = build_from_config(cfg.scheduler_arch_config)

    return Pipeline(
        vae=vae,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        transformer=dit,
        scheduler=scheduler,
        args=cfg,
    ).to(device)
def load_dit(cfg, device: torch.device) -> torch.nn.Module:
    """Load DiT model with FSDP support.

    Builds the transformer from ``cfg.dit_arch_config``, optionally reads a
    checkpoint (a directory of ``*.safetensors`` files or a single ``.pt``
    file), and passes the model through ``maybe_load_fsdp_model`` for
    FSDP/HSDP wrapping.

    Args:
        cfg: Config providing ``dit_ckpt``, ``dit_ckpt_type`` ('safetensor'
            or 'pt'), ``dit_precision``, ``dit_arch_config``, and the
            FSDP-related fields forwarded to ``maybe_load_fsdp_model``.
        device: Target device for single-process placement.

    Returns:
        The (possibly FSDP-wrapped) model in ``eval()`` mode.

    Raises:
        ValueError: If no safetensors files are found, or ``dit_ckpt_type``
            is not one of 'safetensor' / 'pt'.
    """
    logger = get_logger()
    state_dict = None
    if cfg.dit_ckpt is not None:
        logger.info(
            f"Loading model from: {cfg.dit_ckpt}, type: {cfg.dit_ckpt_type}")
        if cfg.dit_ckpt_type == "safetensor":
            # Find all safetensors files (dit_ckpt is a directory here).
            safetensors_files = glob.glob(
                os.path.join(str(cfg.dit_ckpt), "*.safetensors"))
            if not safetensors_files:
                raise ValueError(
                    f"No safetensors files found in {cfg.dit_ckpt}")
            state_dict = dict(
                safetensors_weights_iterator(safetensors_files))
        elif cfg.dit_ckpt_type == "pt":
            # dit_ckpt is a single .pt file; unwrap a {"model": ...}
            # container if the checkpoint stores one.
            pt_files = [cfg.dit_ckpt]
            state_dict = dict(pt_weights_iterator(pt_files))
            if "model" in state_dict:
                state_dict = state_dict["model"]
        else:
            raise ValueError(
                f"Unknown dit_ckpt_type: {cfg.dit_ckpt_type}, must be 'safetensor' or 'pt'")

    dtype = PRECISION_TO_TYPE[cfg.dit_precision]
    model_kwargs = {'dtype': dtype, 'device': device, 'args': cfg}
    model = build_from_config(cfg.dit_arch_config, **model_kwargs)

    if not dist.is_initialized() or dist.get_world_size() == 1:
        # Debug mode: single process — place the model and apply the
        # checkpoint eagerly.
        # NOTE(review): in the multi-process path the state_dict built above
        # is never applied here; presumably maybe_load_fsdp_model performs
        # the sharded load — confirm, otherwise distributed runs start from
        # random weights.
        model.to(device=device)
        if state_dict is not None:
            # filter unused params
            load_state_dict = {}
            for k, v in state_dict.items():
                # Inflate img_in when the checkpoint has fewer input channels
                # than the current architecture; extra channels stay zero.
                if k == "img_in.weight" and model.img_in.weight.shape != v.shape:
                    logger.info(
                        f"Inflate {k} from {v.shape} to {model.img_in.weight.shape}")
                    v_new = v.new_zeros(model.img_in.weight.shape)
                    # assumes img_in.weight is 5-D (Conv3d-style) — TODO confirm
                    v_new[:, :v.shape[1], :, :, :] = v
                    v = v_new
                load_state_dict[k] = v
            model.load_state_dict(load_state_dict, strict=True)

    model = maybe_load_fsdp_model(
        model=model,
        hsdp_shard_dim=cfg.hsdp_shard_dim,
        reshard_after_forward=cfg.reshard_after_forward,
        param_dtype=dtype,
        reduce_dtype=torch.float32,
        output_dtype=None,
        cpu_offload=cfg.cpu_offload,
        fsdp_inference=cfg.use_fsdp_inference,
        training_mode=cfg.training_mode,
        pin_cpu_memory=cfg.pin_cpu_memory,
    )

    # Log model info
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Instantiate model with {total_params / 1e9:.2f}B parameters")

    # Ensure consistent dtype
    param_dtypes = {param.dtype for param in model.parameters()}
    if len(param_dtypes) > 1:
        logger.warning(
            f"Model has mixed dtypes: {param_dtypes}. Converting to {dtype}")
        model = model.to(dtype)

    return model.eval()
# Public API of this module. load_dit is defined here alongside
# load_pipeline and is exported for consistency with its sibling loader.
__all__ = [
    "BucketGroup",
    "FlowMatchDiscreteScheduler",
    "Pipeline",
    "Transformer3DModel",
    "WanxVAE",
    "load_dit",
    "load_pipeline",
    "load_text_encoder",
]