|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from safetensors.torch import load_file as load_safetensors |
|
|
|
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): `if False:` means none of these imports ever execute at runtime,
# and the underscore-prefixed aliases are not referenced anywhere in the code
# visible in this file. Presumably they exist only so that tools which scan the
# source text for relative imports (e.g. a remote-code / custom-pipeline loader
# that downloads sibling modules) discover every module the runtime needs —
# confirm before removing any line here.
if False:
    from .model import BitDance_B as _BD_B_STD
    from .model import BitDance_H as _BD_H_STD
    from .model import BitDance_L as _BD_L_STD
    from .model_parallel import BitDance_B as _BD_B_PAR
    from .model_parallel import BitDance_H as _BD_H_PAR
    from .model_parallel import BitDance_L as _BD_L_PAR
    from .diff_head import DiffHead as _DiffHead
    from .diff_head_parallel import DiffHead as _DiffHeadParallel
    from .layers import TransformerBlock as _TB
    from .layers_parallel import TransformerBlock as _TBP
    from .qae import VQModel as _VQ
    from .gfq import GFQ as _GFQ
    from .sampling import euler_maruyama as _EM
    from .sampling_parallel import euler_maruyama as _EMP
    from .utils import patchify_raster as _PR
|
|
|
|
|
|
|
|
class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
    """Diffusers-style wrapper around a class-conditional BitDance model.

    The real network is stored in ``self.runtime_model``. It is built by one
    of the ``BitDance_B`` / ``BitDance_L`` / ``BitDance_H`` constructors,
    imported either from the sibling ``.model`` module or from
    ``.model_parallel`` (see ``__init__`` for the selection rule).
    ``sample`` and ``forward`` simply delegate to that runtime model.
    """

    @register_to_config
    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
        """Build the wrapped BitDance runtime model.

        Args:
            architecture: Variant name; one of ``"BitDance-B"``,
                ``"BitDance-L"`` or ``"BitDance-H"``.
            parallel_num: Forwarded to the ``.model_parallel`` constructor.
                Any value greater than 1 selects that implementation.
            resolution: Forwarded unchanged to the constructor.
            down_size: Forwarded unchanged to the constructor.
            latent_dim: Forwarded unchanged to the constructor.
            num_classes: Forwarded unchanged to the constructor.
            runtime_impl: The value ``"model_parallel.py"`` forces the
                ``.model_parallel`` implementation even when
                ``parallel_num <= 1``.
            parallel_mode: Forwarded only to the ``.model_parallel``
                constructor.
            time_schedule: Forwarded unchanged to the constructor.
            time_shift: Forwarded unchanged to the constructor.
            p_std: Forwarded to the constructor as ``P_std``.
            p_mean: Forwarded to the constructor as ``P_mean``.

        Raises:
            ValueError: If ``architecture`` is not one of the supported
                variant names.
        """
        super().__init__()

        use_parallel = runtime_impl == "model_parallel.py" or parallel_num > 1

        # Imported lazily so that only the selected implementation is loaded.
        if use_parallel:
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

        ctors = {
            "BitDance-B": BitDance_B,
            "BitDance-L": BitDance_L,
            "BitDance-H": BitDance_H,
        }
        if architecture not in ctors:
            raise ValueError(
                f"Unknown BitDance architecture {architecture!r}; "
                f"expected one of {sorted(ctors)}."
            )

        kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            latent_dim=latent_dim,
            num_classes=num_classes,
            # NOTE(review): the values below are fixed rather than configurable.
            # Presumably they are training-time options pinned for inference
            # (e.g. trained_vae="" meaning no VAE is loaded inside the runtime
            # model). Confirm against the constructors in .model / .model_parallel.
            patch_size=1,
            diff_batch_mul=4,
            cls_token_num=64,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )
        if use_parallel:
            kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)

        self.runtime_model = ctors[architecture](**kwargs)

    # Prefix that ModelMixin.save_pretrained puts on every weight key, because
    # it serialises this wrapper's state_dict() rather than the runtime model's.
    _RUNTIME_PREFIX = "runtime_model."

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load a model from a local directory.

        The directory (or its ``subfolder``, when that keyword is given) must
        contain ``config.json`` and ``diffusion_pytorch_model.safetensors``.
        Unlike ``ModelMixin.from_pretrained``, only local paths are supported.
        ``subfolder`` is the only extra argument that is honoured. All other
        positional and keyword arguments (e.g. ``torch_dtype``) are accepted
        for API compatibility and ignored.

        Two weight key layouts are accepted:

        - keys that name the runtime model's parameters directly;
        - keys carrying the ``runtime_model.`` prefix, as written by
          ``save_pretrained`` for this wrapper.

        Args:
            pretrained_model_name_or_path: Path to the model directory.

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            KeyError: If a required field is missing from ``config.json``.
            RuntimeError: From ``load_state_dict`` when the weights do not
                match the runtime model exactly (``strict=True``).
        """
        subfolder = kwargs.get("subfolder")
        del args, kwargs  # accepted for API compatibility; otherwise unused

        model_dir = Path(pretrained_model_name_or_path)
        if subfolder:
            model_dir = model_dir / subfolder

        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )

        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        prefix = cls._RUNTIME_PREFIX
        # Strip the wrapper prefix only when every key has it, so checkpoints
        # that already use bare runtime-model keys load exactly as before.
        if state and all(key.startswith(prefix) for key in state):
            state = {key[len(prefix):]: value for key, value in state.items()}
        model.runtime_model.load_state_dict(state, strict=True)

        model.eval()
        return model

    @torch.no_grad()
    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
        """Generate samples by delegating to ``runtime_model.sample``.

        Gradient tracking is disabled for the whole call.

        Args:
            class_ids: Passed to the runtime model as ``cond``. Presumably a
                tensor of class indices; confirm against the runtime model.
            sample_steps: Forwarded unchanged. Its semantics are defined by
                the runtime model.
            cfg_scale: Forwarded unchanged. Its semantics are defined by the
                runtime model.
            chunk_size: Forwarded unchanged. Its semantics are defined by the
                runtime model.

        Returns:
            Whatever ``runtime_model.sample`` returns (annotated as a tensor).
        """
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
        """Call the wrapped runtime model with the given arguments unchanged."""
        return self.runtime_model(*args, **kwargs)
|
|
|