from __future__ import annotations

import json
from pathlib import Path

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# NOTE: Diffusers dynamic module loader only copies directly-referenced relative imports.
# These guarded imports are intentionally never executed, but they force dependent files
# (and their siblings) to be copied into the dynamic module cache.
if False:  # pragma: no cover
    from .model import BitDance_B as _BD_B_STD
    from .model import BitDance_H as _BD_H_STD
    from .model import BitDance_L as _BD_L_STD
    from .model_parallel import BitDance_B as _BD_B_PAR
    from .model_parallel import BitDance_H as _BD_H_PAR
    from .model_parallel import BitDance_L as _BD_L_PAR
    from .diff_head import DiffHead as _DiffHead
    from .diff_head_parallel import DiffHead as _DiffHeadParallel
    from .layers import TransformerBlock as _TB
    from .layers_parallel import TransformerBlock as _TBP
    from .qae import VQModel as _VQ
    from .gfq import GFQ as _GFQ
    from .sampling import euler_maruyama as _EM
    from .sampling_parallel import euler_maruyama as _EMP
    from .utils import patchify_raster as _PR


class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
    """Diffusers ``ModelMixin``/``ConfigMixin`` wrapper around a BitDance runtime model.

    The actual network is stored in ``self.runtime_model``. It is built from either
    the ``.model`` (standard) or ``.model_parallel`` implementation, depending on
    ``runtime_impl`` and ``parallel_num``. This class only adapts construction,
    checkpoint loading, sampling and ``forward`` to the diffusers interfaces.
    """

    @register_to_config
    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
        """Build the underlying BitDance runtime model.

        Args:
            architecture: Variant name: ``"BitDance-B"``, ``"BitDance-L"`` or
                ``"BitDance-H"``.
            parallel_num: Forwarded to the parallel implementation. Any value > 1
                selects ``.model_parallel``.
            resolution: Forwarded unchanged to the runtime constructor.
            down_size: Forwarded unchanged to the runtime constructor.
            latent_dim: Forwarded unchanged to the runtime constructor.
            num_classes: Forwarded unchanged to the runtime constructor.
            runtime_impl: ``"model_parallel.py"`` forces the parallel
                implementation. Any other value uses ``.model`` unless
                ``parallel_num > 1``.
            parallel_mode: Forwarded only to the parallel implementation.
            time_schedule: Forwarded unchanged to the runtime constructor.
            time_shift: Forwarded unchanged to the runtime constructor.
            p_std: Forwarded as ``P_std``.
            p_mean: Forwarded as ``P_mean``.

        Raises:
            KeyError: If ``architecture`` is not one of the known variants.
        """
        super().__init__()

        # Constructor arguments shared by both implementations. The fixed values
        # (checkpointing, dropout, perturbation, empty trained_vae) are pinned
        # here; presumably they only matter for training -- TODO confirm.
        runtime_kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            patch_size=1,
            latent_dim=latent_dim,
            diff_batch_mul=4,
            cls_token_num=64,
            num_classes=num_classes,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )

        # Imports are deferred so only the selected implementation is loaded.
        # The `from .x import ...` form is kept (rather than `from . import x`)
        # so the diffusers relative-import scanner can still see it.
        if runtime_impl == "model_parallel.py" or parallel_num > 1:
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L

            # Only the parallel implementation takes these arguments.
            runtime_kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

        ctors = {
            "BitDance-B": BitDance_B,
            "BitDance-L": BitDance_L,
            "BitDance-H": BitDance_H,
        }
        try:
            ctor = ctors[architecture]
        except KeyError:
            raise KeyError(
                f"Unknown BitDance architecture {architecture!r}; "
                f"expected one of {sorted(ctors)}"
            ) from None

        self.runtime_model = ctor(**runtime_kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs):
        """Load a model from a local checkpoint directory.

        Reads ``config.json`` and ``diffusion_pytorch_model.safetensors`` from
        ``pretrained_model_name_or_path``, joined with ``subfolder`` when one is
        given. Hub repository ids are not resolved; the directory must exist
        locally.

        Only two keyword arguments are honoured:

        * ``subfolder``: a sub-directory holding the checkpoint files.
        * ``torch_dtype``: if it is a ``torch.dtype``, the loaded model is cast
          to it.

        Every other positional or keyword argument that
        ``ModelMixin.from_pretrained`` accepts (e.g. ``device_map``,
        ``variant``) is accepted for compatibility and ignored.

        Returns:
            The model with weights loaded (``strict=True``), in eval mode.
        """
        del args  # Accepted only for signature compatibility with ModelMixin.
        subfolder = kwargs.pop("subfolder", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        model_dir = Path(pretrained_model_name_or_path)
        if subfolder:
            model_dir = model_dir / subfolder

        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )

        # Weights are stored for the inner runtime model, not for this wrapper,
        # so they are loaded into `runtime_model` directly.
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)

        if isinstance(torch_dtype, torch.dtype):
            model = model.to(dtype=torch_dtype)
        model.eval()
        return model

    @torch.no_grad()
    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
        """Generate samples by delegating to ``runtime_model.sample`` without gradients.

        Args:
            class_ids: Passed as ``cond``. Presumably a 1-D tensor of ImageNet
                class indices -- TODO confirm against the runtime model.
            sample_steps: Number of sampling steps, forwarded unchanged.
            cfg_scale: Forwarded unchanged. Presumably the classifier-free
                guidance scale.
            chunk_size: Forwarded unchanged. Its semantics (including 0) are
                defined by the runtime model.

        Returns:
            Whatever ``runtime_model.sample`` returns.
        """
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
        """Pass all arguments straight through to ``runtime_model``."""
        return self.runtime_model(*args, **kwargs)