| | import argparse |
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| | from torch.utils.checkpoint import checkpoint |
| |
|
| | from .diff_head import DiffHead |
| | from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d |
| | from .qae import VQModel |
| |
|
def get_model_args():
    """Build the argparse parser exposing every BitDance model/training option.

    Returns:
        argparse.ArgumentParser: a parser (not yet parsed) whose defaults
        mirror the keyword arguments accepted by :func:`create_model`.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword options) — declared as data so the options
    # read as a table; registration order matches the original definition.
    arg_specs = [
        ("--model", dict(type=str, choices=list(BitDance_models.keys()), default="BitDance-L")),
        ("--image-size", dict(type=int, choices=[256, 512], default=256)),
        ("--down-size", dict(type=int, default=16, choices=[16])),
        ("--patch-size", dict(type=int, default=1, choices=[1, 2, 4])),
        ("--num-classes", dict(type=int, default=1000)),
        ("--cls-token-num", dict(type=int, default=64)),
        ("--latent-dim", dict(type=int, default=16)),
        ("--diff-batch-mul", dict(type=int, default=4)),
        ("--grad-checkpointing", dict(action="store_true")),
        ("--trained-vae", dict(type=str, default="")),
        ("--drop-rate", dict(type=float, default=0.0)),
        ("--perturb-schedule", dict(type=str, default="constant")),
        ("--perturb-rate", dict(type=float, default=0.0)),
        ("--perturb-rate-max", dict(type=float, default=0.3)),
        ("--time-schedule", dict(type=str, default="logit_normal")),
        ("--time-shift", dict(type=float, default=1.0)),
        ("--P-std", dict(type=float, default=1.0)),
        ("--P-mean", dict(type=float, default=0.0)),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser
| |
|
| |
|
def create_model(args, device):
    """Instantiate the BitDance variant named by ``args.model``.

    Args:
        args: parsed namespace from :func:`get_model_args`.
        device: target device for the model.

    Returns:
        The constructed model, moved to ``device`` in channels-last memory
        format.
    """
    builder = BitDance_models[args.model]
    # Translate CLI names (dashes) to constructor keyword arguments.
    model_kwargs = dict(
        resolution=args.image_size,
        down_size=args.down_size,
        patch_size=args.patch_size,
        latent_dim=args.latent_dim,
        diff_batch_mul=args.diff_batch_mul,
        cls_token_num=args.cls_token_num,
        num_classes=args.num_classes,
        grad_checkpointing=args.grad_checkpointing,
        trained_vae=args.trained_vae,
        drop_rate=args.drop_rate,
        perturb_schedule=args.perturb_schedule,
        perturb_rate=args.perturb_rate,
        perturb_rate_max=args.perturb_rate_max,
        time_schedule=args.time_schedule,
        time_shift=args.time_shift,
        P_std=args.P_std,
        P_mean=args.P_mean,
    )
    model = builder(**model_kwargs)
    return model.to(device, memory_format=torch.channels_last)
| |
|
class MLPConnector(nn.Module):
    """Gated (SwiGLU-style) MLP projecting ``in_dim`` features to ``dim``.

    A single fused linear layer produces a gate branch and a value branch;
    the SiLU-activated gate multiplies the value, which is then projected
    down to ``dim`` and passed through dropout.
    """

    def __init__(self, in_dim, dim, dropout_p=0.0):
        super().__init__()
        hidden_dim = int(dim * 1.5)
        # One matmul yields both branches (each of width hidden_dim).
        self.w1 = nn.Linear(in_dim, hidden_dim * 2, bias=True)
        self.w2 = nn.Linear(hidden_dim, dim, bias=True)
        self.ffn_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        gate, value = torch.chunk(self.w1(x), 2, dim=-1)
        hidden = F.silu(gate) * value
        return self.ffn_dropout(self.w2(hidden))
| |
|
def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Randomly negate elements of ``tensor``.

    Each element draws its own flip threshold uniformly from ``[0, p_max)``
    and is multiplied by -1 if a second independent uniform draw falls below
    it; in expectation an element is flipped with probability ``p_max / 2``.

    Args:
        tensor: input tensor of any shape.
        p_max: upper bound on the per-element flip probability, in [0, 1].

    Returns:
        A new tensor with a random subset of elements sign-flipped.

    Raises:
        ValueError: if ``p_max`` lies outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= p_max <= 1.0:
        raise ValueError(f"p_max must be in [0.0, 1.0] range, but got: {p_max}")
    draw = torch.rand_like(tensor)
    threshold = p_max * torch.rand_like(tensor)
    # where() yields float multipliers; cast back so the product keeps dtype.
    sign = torch.where(draw < threshold, -1.0, 1.0).to(tensor.dtype)
    return tensor * sign
| |
|
class BitDance(nn.Module):
    """Autoregressive image generator over VQ-VAE latent tokens.

    A causal transformer consumes class-conditioning tokens followed by
    patchified VAE latents and, for each position, a diffusion head
    (``DiffHead``) models the next latent patch conditioned on the
    transformer's hidden state.  The VQ-VAE is frozen; only the transformer
    and diffusion head are trained.  Sampling runs token-by-token with a KV
    cache and optional classifier-free guidance (CFG).
    """

    def __init__(
        self,
        dim,
        n_layer,
        n_head,
        diff_layers,
        diff_dim,
        diff_adanln_layers,
        latent_dim,
        down_size,
        patch_size,
        resolution,
        diff_batch_mul,
        grad_checkpointing=False,
        cls_token_num=16,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        trained_vae: str = "",
        drop_rate: float = 0.0,
        perturb_schedule: str = "constant",
        perturb_rate: float = 0.0,
        perturb_rate_max: float = 0.3,
        time_schedule: str = 'logit_normal',
        time_shift: float = 1.,
        P_std: float = 1.,
        P_mean: float = 0.,
    ):
        """Build the transformer backbone, frozen VAE, and diffusion head.

        Args:
            dim: transformer hidden size.
            n_layer / n_head: transformer depth and attention heads.
            diff_layers / diff_dim / diff_adanln_layers: DiffHead geometry.
            latent_dim: channel count of the VAE latent.
            down_size: VAE spatial downsampling factor (resolution / latent side).
            patch_size: how many latent pixels per side are grouped into one token.
            resolution: input image side length in pixels.
            diff_batch_mul: how many times each token is replicated for the
                diffusion-head loss (more noise samples per token).
            grad_checkpointing: recompute transformer blocks in backward to save memory.
            cls_token_num: number of conditioning tokens the class embedding expands to.
            num_classes: label vocabulary size; index ``num_classes`` is the
                null/unconditional label used for CFG.
            class_dropout_prob: probability of replacing a label with the null
                label during training (enables CFG at sampling time).
            trained_vae: checkpoint path consumed by :meth:`load_vae_weight`.
            drop_rate: dropout for the input projection and residual paths.
            perturb_schedule / perturb_rate / perturb_rate_max: training-time
                sign-flip augmentation of the latent inputs.
            time_schedule / time_shift / P_std / P_mean: diffusion-time
                sampling parameters forwarded to DiffHead.
        """
        super().__init__()

        self.n_layer = n_layer
        self.resolution = resolution
        self.down_size = down_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num
        self.class_dropout_prob = class_dropout_prob
        self.latent_dim = latent_dim
        self.trained_vae = trained_vae
        self.perturb_schedule = perturb_schedule
        self.perturb_rate = perturb_rate
        self.perturb_rate_max = perturb_rate_max

        # Fixed VQ-VAE architecture; with 4 ch_mult doublings this matches a
        # 16x spatial downsample (consistent with the down_size=16 CLI choice).
        ddconfig = {
            "double_z": False,
            "z_channels": latent_dim,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 256,
            "ch_mult": [1, 1, 2, 2, 4],
            "num_res_blocks": 4,
        }
        num_codebooks = 4

        self.vae = VQModel(ddconfig, num_codebooks)
        self.grad_checkpointing = grad_checkpointing

        # One extra embedding row (num_classes + 1) is the null label for CFG.
        # Each label expands to cls_token_num conditioning tokens of width dim.
        self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
        self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
        self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        # Token grid after VAE downsampling and patch grouping.
        self.h, self.w = resolution // (down_size * patch_size), resolution // (down_size * patch_size)
        self.total_tokens = self.h * self.w + self.cls_token_num

        self.layers = torch.nn.ModuleList()
        for layer_id in range(n_layer):
            self.layers.append(
                TransformerBlock(
                    dim,
                    n_head,
                    resid_dropout_p=drop_rate,
                    causal=True,
                )
            )

        self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        # Learned per-position embedding added to the hidden state that
        # conditions the diffusion head (one row per image token).
        self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
        self.head = DiffHead(
            ch_target=latent_dim * self.patch_size * self.patch_size,
            ch_cond=dim,
            ch_latent=diff_dim,
            depth_latent=diff_layers,
            depth_adanln=diff_adanln_layers,
            grad_checkpointing=grad_checkpointing,
            time_shift=time_shift,
            time_schedule=time_schedule,
            P_std=P_std,
            P_mean=P_mean,
        )
        self.diff_batch_mul = diff_batch_mul

        # 2-D rotary frequencies for the image token grid.
        patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))

        # [:-1]: the input sequence is one shorter than cls_token_num + h*w
        # because the last latent token is never fed back in (see forward's
        # x[:, :-1, :] shift) — presumably the precomputed table covers the
        # full length; TODO confirm against precompute_freqs_cis_2d.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis_2d(
                patch_2d_pos,
                dim // n_head,
                10000,
                cls_token_num=self.cls_token_num,
            )[:-1],
            persistent=False,
        )
        self.freeze_vae()

        self.initialize_weights()

    def load_vae_weight(self):
        """Load frozen VAE weights from ``self.trained_vae`` (non-strict).

        NOTE(review): torch.load without weights_only unpickles arbitrary
        objects — only use checkpoints from a trusted source.
        """
        state = torch.load(
            self.trained_vae,
            map_location="cpu",
        )
        missing_keys, unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
        print(f"loading vae, missing_keys: {missing_keys}")
        del state

    def non_decay_keys(self):
        """Parameter-name prefixes to exclude from weight decay.

        Presumably consumed by optimizer setup elsewhere — verify at caller.
        """
        return ["proj_in", "cls_embedding"]

    def freeze_module(self, module: nn.Module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def freeze_vae(self):
        """Freeze the VAE and keep it permanently in eval mode."""
        self.freeze_module(self.vae)
        self.vae.eval()

    def initialize_weights(self):
        """Apply the default init everywhere, then DiffHead's specialized init."""
        self.apply(self.__init_weights)
        # DiffHead re-initializes itself (overriding the generic pass above).
        self.head.initialize_weights()

    def __init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def drop_label(self, class_id):
        """Randomly replace labels with the null label (training only).

        Enables classifier-free guidance at sampling time.
        """
        if self.class_dropout_prob > 0.0 and self.training:
            is_drop = (
                torch.rand(class_id.shape, device=class_id.device)
                < self.class_dropout_prob
            )
            class_id = torch.where(is_drop, self.num_classes, class_id)
        return class_id

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """(bsz, c, h, w) latent -> (bsz, h_*w_, c*p*p) token sequence."""
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`patchify`: token sequence -> (bsz, c, H, W) latent."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.latent_dim
        h_, w_ = self.h, self.w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x

    def forward(
        self,
        images,
        class_id,
        cached=False
    ):
        """Training forward pass; returns the diffusion-head loss.

        Args:
            images: raw images, or pre-encoded VAE latents when ``cached``.
            class_id: integer class labels, shape (bsz,).
            cached: skip VAE encoding and treat ``images`` as latents.
        """
        if cached:
            vae_latent = images
        else:
            vae_latent, _, _, _ = self.vae.encode(images)

        vae_latent = self.patchify(vae_latent)
        # Detach: the VAE is frozen and the latents serve as fixed targets.
        x = vae_latent.clone().detach()
        if self.training:
            # Sign-flip augmentation of the inputs only (targets stay clean).
            # Latents appear to be sign-valued (cf. torch.sign in head_sample)
            # — TODO confirm against VQModel.
            if self.perturb_schedule == "constant":
                x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
            else:
                raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")
        # Teacher-forced shift: drop the last token so position i attends only
        # to tokens < i; the class tokens prepended below supply the prefix.
        x = self.proj_in(x[:, :-1, :])
        class_id = self.drop_label(class_id)
        bsz = x.shape[0]
        c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
        x = torch.cat([c, x], dim=1)
        x = self.emb_norm(x)

        if self.grad_checkpointing and self.training:
            for layer in self.layers:
                block = partial(layer.forward, freqs_cis=self.freqs_cis)
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x, self.freqs_cis)

        # Keep exactly the h*w hidden states that predict image tokens
        # (the last class token's state predicts image token 0).
        x = x[:, -self.h * self.w :, :]
        x = self.norm(x)
        x = x + self.pos_for_diff.weight

        target = vae_latent.clone().detach()
        x = x.view(-1, x.shape[-1])
        target = target.view(-1, target.shape[-1])

        # Replicate each (condition, target) pair so the head sees
        # diff_batch_mul independent diffusion-time/noise samples per token.
        x = x.repeat(self.diff_batch_mul, 1)
        target = target.repeat(self.diff_batch_mul, 1)
        loss = self.head(target, x)

        return loss

    def enable_kv_cache(self, bsz):
        """Allocate per-layer KV caches sized for the full token sequence."""
        for layer in self.layers:
            layer.attention.enable_kv_cache(bsz, self.total_tokens)

    @torch.compile()
    def forward_model(self, x, start_pos, end_pos):
        """One incremental transformer step over positions [start_pos, end_pos).

        Relies on the KV cache enabled by :meth:`enable_kv_cache`; compiled
        for sampling-loop speed.
        """
        x = self.emb_norm(x)
        for layer in self.layers:
            x = layer.forward_onestep(
                x, self.freqs_cis[start_pos:end_pos,], start_pos, end_pos
            )
        x = self.norm(x)
        return x

    def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
        """Sample one latent token from the diffusion head.

        Args:
            x: hidden state(s) for the current position, (bsz, 1, dim).
            diff_pos: index of the image token being generated.
            sample_steps: number of diffusion sampling steps.
            cfg_scale: target CFG strength (<= 1.0 disables guidance).
            cfg_schedule: "constant", or "linear" ramp from 1.0 to cfg_scale
                across the token sequence.
        """
        x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :]
        x = x.view(-1, x.shape[-1])
        seq_len = self.h * self.w
        if cfg_scale > 1.0:
            if cfg_schedule == "constant":
                cfg_iter = cfg_scale
            elif cfg_schedule == "linear":
                start = 1.0
                cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
            else:
                raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
        else:
            cfg_iter = 1.0
        pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
        pred = pred.view(-1, 1, pred.shape[-1])
        # Quantize the continuous prediction to its sign before feeding back.
        pred = torch.sign(pred)
        return pred

    @torch.no_grad()
    def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
        """Autoregressively generate images for the given class labels.

        Args:
            cond: class labels, shape (bsz,).
            sample_steps: diffusion steps per token.
            cfg_scale: classifier-free guidance strength (> 1.0 enables CFG,
                doubling the effective batch with null-label rows).
            cfg_schedule: see :meth:`head_sample`.
            chunk_size: if > 0, decode latents in CPU-offloaded chunks.

        Returns:
            Reconstructed images from the VAE decoder.
        """
        self.eval()
        if cfg_scale > 1.0:
            # Stack [conditional | unconditional] rows for CFG.
            cond_null = torch.ones_like(cond) * self.num_classes
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        bsz = cond_combined.shape[0]
        # Number of "real" samples (the unconditional half is discarded).
        act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
        self.enable_kv_cache(bsz)

        c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
        last_pred = None
        all_preds = []
        for i in range(self.h * self.w):
            if i == 0:
                # Prime the cache with all class tokens at once.
                x = self.forward_model(c, 0, self.cls_token_num)
            else:
                # Feed the previously sampled token through one cached step.
                x = self.proj_in(last_pred)
                x = self.forward_model(
                    x, i + self.cls_token_num - 1, i + self.cls_token_num
                )
            last_pred = self.head_sample(
                x[:, -1:, :],
                i,
                sample_steps,
                cfg_scale,
                cfg_schedule,
            )
            all_preds.append(last_pred)

        # Concatenate along the sequence axis; drop the unconditional half.
        x = torch.cat(all_preds, dim=-2)[:act_bsz]
        if x.dim() == 3:
            x = self.unpatchify(x)
        if chunk_size > 0:
            recon = self.decode_in_chunks(x, chunk_size)
        else:
            recon = self.vae.decode(x)
        return recon

    def decode_in_chunks(self, latent_tensor, chunk_size=64):
        """Decode latents ``chunk_size`` at a time, offloading results to CPU.

        Bounds peak GPU memory for large batches; returns a CPU tensor.
        """
        total_bsz = latent_tensor.shape[0]
        recon_chunks_on_cpu = []
        with torch.no_grad():
            for i in range(0, total_bsz, chunk_size):
                end_idx = min(i + chunk_size, total_bsz)
                latent_chunk = latent_tensor[i:end_idx]
                recon_chunk = self.vae.decode(latent_chunk)
                recon_chunks_on_cpu.append(recon_chunk.cpu())
        return torch.cat(recon_chunks_on_cpu, dim=0)

    def get_fsdp_wrap_module_list(self):
        """Modules FSDP should wrap individually (one per transformer block)."""
        return list(self.layers)
| |
|
def BitDance_H(**kwargs):
    """BitDance-Huge: 40-layer, 1280-dim backbone with a 12-layer diffusion head."""
    config = dict(
        n_layer=40,
        n_head=20,
        dim=1280,
        diff_layers=12,
        diff_dim=1280,
        diff_adanln_layers=3,
    )
    return BitDance(**config, **kwargs)
| |
|
| |
|
def BitDance_L(**kwargs):
    """BitDance-Large: 32-layer, 1024-dim backbone with an 8-layer diffusion head."""
    config = dict(
        n_layer=32,
        n_head=16,
        dim=1024,
        diff_layers=8,
        diff_dim=1024,
        diff_adanln_layers=2,
    )
    return BitDance(**config, **kwargs)
| |
|
| |
|
def BitDance_B(**kwargs):
    """BitDance-Base: 24-layer, 768-dim backbone with a 6-layer diffusion head."""
    config = dict(
        n_layer=24,
        n_head=12,
        dim=768,
        diff_layers=6,
        diff_dim=768,
        diff_adanln_layers=2,
    )
    return BitDance(**config, **kwargs)
| |
|
| |
|
# Registry mapping CLI model names (see get_model_args --model) to constructors.
BitDance_models = {
    "BitDance-B": BitDance_B,
    "BitDance-L": BitDance_L,
    "BitDance-H": BitDance_H,
}
| |
|