| | from __future__ import annotations |
| |
|
| | import json |
| | from pathlib import Path |
| |
|
| | import torch |
| | from safetensors.torch import load_file as load_safetensors |
| |
|
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| | |
| | |
| | |
| | if False: |
| | from .model import BitDance_B as _BD_B_STD |
| | from .model import BitDance_H as _BD_H_STD |
| | from .model import BitDance_L as _BD_L_STD |
| | from .model_parallel import BitDance_B as _BD_B_PAR |
| | from .model_parallel import BitDance_H as _BD_H_PAR |
| | from .model_parallel import BitDance_L as _BD_L_PAR |
| | from .diff_head import DiffHead as _DiffHead |
| | from .diff_head_parallel import DiffHead as _DiffHeadParallel |
| | from .layers import TransformerBlock as _TB |
| | from .layers_parallel import TransformerBlock as _TBP |
| | from .qae import VQModel as _VQ |
| | from .gfq import GFQ as _GFQ |
| | from .sampling import euler_maruyama as _EM |
| | from .sampling_parallel import euler_maruyama as _EMP |
| | from .utils import patchify_raster as _PR |
| |
|
| |
|
class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
    """Diffusers ``ModelMixin`` / ``ConfigMixin`` wrapper around a BitDance runtime model.

    The actual network is stored as ``self.runtime_model``. It is built from the
    sibling ``.model`` module, or from ``.model_parallel`` when the parallel
    implementation is requested. This class only adds config registration,
    local checkpoint loading, and thin ``sample`` / ``forward`` pass-throughs.
    """

    @register_to_config
    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
        """Construct the runtime BitDance model described by the arguments.

        Args:
            architecture: Model variant; one of ``"BitDance-B"``, ``"BitDance-L"``,
                ``"BitDance-H"``.
            parallel_num: Any value > 1 selects the ``.model_parallel``
                implementation (regardless of ``runtime_impl``), which also
                receives this value.
            resolution: Forwarded unchanged to the runtime constructor.
            down_size: Forwarded unchanged to the runtime constructor.
            latent_dim: Forwarded unchanged to the runtime constructor.
            num_classes: Forwarded unchanged to the runtime constructor.
            runtime_impl: ``"model_parallel.py"`` selects ``.model_parallel``;
                any other value selects ``.model`` unless ``parallel_num > 1``.
            parallel_mode: Passed only to the ``.model_parallel`` implementation.
            time_schedule: Forwarded unchanged to the runtime constructor.
            time_shift: Forwarded unchanged to the runtime constructor.
            p_std: Forwarded to the runtime constructor as ``P_std``.
            p_mean: Forwarded to the runtime constructor as ``P_mean``.

        Raises:
            KeyError: If ``architecture`` is not a supported variant name.
        """
        super().__init__()

        # Constructor arguments that are not part of the saved config are pinned
        # here. The dropout / perturbation / checkpointing values appear to switch
        # off training-time behaviour -- NOTE(review): confirm against the runtime model.
        kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            patch_size=1,
            latent_dim=latent_dim,
            diff_batch_mul=4,
            cls_token_num=64,
            num_classes=num_classes,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )

        if runtime_impl == "model_parallel.py" or parallel_num > 1:
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L

            # These two arguments are passed only to the parallel implementation.
            kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

        # Both branches bind the same three names, so one lookup table suffices.
        ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}
        if architecture not in ctors:
            raise KeyError(
                f"Unknown BitDance architecture {architecture!r}; "
                f"expected one of {sorted(ctors)}"
            )
        self.runtime_model = ctors[architecture](**kwargs)

    @staticmethod
    def _strip_runtime_prefix(state):
        """Return *state* keyed relative to ``runtime_model``.

        The wrapper stores the network as ``self.runtime_model``, so its own
        ``state_dict()`` (and therefore ``ModelMixin.save_pretrained``) prefixes
        every key with ``runtime_model.``. Checkpoints keyed directly against the
        runtime model have no prefix. Keys are rewritten only when *every* key
        carries the prefix, so unprefixed checkpoints pass through untouched.
        """
        prefix = "runtime_model."
        if state and all(key.startswith(prefix) for key in state):
            return {key[len(prefix):]: value for key, value in state.items()}
        return state

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load the model from a local directory.

        The directory must contain ``config.json`` and
        ``diffusion_pytorch_model.safetensors``. Unlike
        ``ModelMixin.from_pretrained``, this reads local files only. Extra
        positional arguments and every keyword argument except ``subfolder`` are
        ignored; for example, ``torch_dtype`` and ``device_map`` have no effect.

        Args:
            pretrained_model_name_or_path: Local directory path.
            subfolder (keyword, optional): Subdirectory of that path holding the
                files, as passed by diffusers pipelines (e.g. ``"transformer"``).

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
            KeyError: If a required config key is missing.
        """
        del args  # Positional extras are not supported and are ignored.
        model_dir = Path(pretrained_model_name_or_path)
        subfolder = kwargs.get("subfolder")
        if subfolder:
            model_dir = model_dir / subfolder

        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )

        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        # Accept both runtime-relative keys and keys written by save_pretrained().
        state = cls._strip_runtime_prefix(state)
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    @torch.no_grad()
    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
        """Run ``runtime_model.sample`` with gradients disabled.

        Args:
            class_ids: Conditioning passed as ``cond``. Presumably integer class
                labels of shape ``(batch,)``; TODO confirm against the runtime model.
            sample_steps: Forwarded unchanged.
            cfg_scale: Forwarded unchanged. The name suggests a classifier-free
                guidance scale; confirm.
            chunk_size: Forwarded unchanged. The meaning of ``0`` is defined by
                the runtime model.

        Returns:
            Whatever ``runtime_model.sample`` returns.
        """
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
        """Delegate the call unchanged to ``self.runtime_model``."""
        return self.runtime_model(*args, **kwargs)
| |
|