from __future__ import annotations

import json
from pathlib import Path

import torch
from safetensors.torch import load_file as load_safetensors

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# NOTE: The Diffusers dynamic module loader only copies files that are directly
# referenced via relative imports. These guarded imports are intentionally never
# executed, but they force the dependent files (and their siblings) to be copied
# into the dynamic module cache.
if False:  # pragma: no cover
    from .model import BitDance_B as _BD_B_STD
    from .model import BitDance_H as _BD_H_STD
    from .model import BitDance_L as _BD_L_STD
    from .model_parallel import BitDance_B as _BD_B_PAR
    from .model_parallel import BitDance_H as _BD_H_PAR
    from .model_parallel import BitDance_L as _BD_L_PAR
    from .diff_head import DiffHead as _DiffHead
    from .diff_head_parallel import DiffHead as _DiffHeadParallel
    from .layers import TransformerBlock as _TB
    from .layers_parallel import TransformerBlock as _TBP
    from .qae import VQModel as _VQ
    from .gfq import GFQ as _GFQ
    from .sampling import euler_maruyama as _EM
    from .sampling_parallel import euler_maruyama as _EMP
    from .utils import patchify_raster as _PR


class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
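    """Diffusers ``ModelMixin``/``ConfigMixin`` wrapper around the BitDance ImageNet transformer.

    Instantiates either the standard or the parallel BitDance implementation
    from the registered config and delegates ``forward`` and ``sample`` to the
    wrapped runtime model.
    """
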
    @register_to_config
    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
        super().__init__()

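        # Constructor arguments shared by the standard and parallel BitDance
        # implementations. Fields not exposed through this wrapper's config
        # are pinned to fixed values.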
        kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            patch_size=1,
            latent_dim=latent_dim,
            diff_batch_mul=4,
            cls_token_num=64,
            num_classes=num_classes,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )

        # Select the parallel implementation when the checkpoint targets it
        # (or when more than one parallel stream is requested); otherwise use
        # the standard single-stream model.
        if runtime_impl == "model_parallel.py" or parallel_num > 1:
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L

            kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

        ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}
        self.runtime_model = ctors[architecture](**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
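        """Instantiate the wrapper from a local checkpoint directory.

        Reads ``config.json`` and ``diffusion_pytorch_model.safetensors`` from
        ``pretrained_model_name_or_path``; extra diffusers loading kwargs
        (e.g. ``torch_dtype``) are not supported and are discarded.
        """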
        del args, kwargs
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    @torch.no_grad()
    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
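        """Class-conditional sampling, delegated to the wrapped runtime model."""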
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
        return self.runtime_model(*args, **kwargs)
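

# Minimal usage sketch (an illustration, not part of the exported module):
# load a local checkpoint directory and draw class-conditional samples.
# The checkpoint path and class ids below are placeholders, not real assets.
if __name__ == "__main__":
    model = BitDanceImageNetTransformer.from_pretrained("./bitdance-checkpoint")
    class_ids = torch.tensor([207, 360, 387])  # arbitrary ImageNet class ids
    samples = model.sample(class_ids, sample_steps=100, cfg_scale=4.6)
    # The shape and space of the output depend on the wrapped runtime model.
    print(samples.shape)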