"""MMDiT: a multimodal Diffusion Transformer with auxiliary projection heads.

The module contains the building blocks (patch embedding, timestep / vector
embedders, adaLN-modulated joint attention blocks) and the top-level
``MMDiT`` model, which returns both the denoising output and a list of
intermediate projections taken after ``encoder_depth`` blocks.
"""

import math
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

try:
    # The code in this module performs its reshapes with native torch ops and
    # does not call einops itself; the names are still imported (when the
    # package is available) so anything expecting them at module scope keeps
    # working.
    from einops import rearrange, repeat
except ImportError:  # pragma: no cover - einops is optional for this module
    rearrange = repeat = None


def prob_mask_like(shape, prob, device):
    """Return a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` are special-cased to deterministic
    all-True / all-False masks.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob


def attention(q, k, v, heads, mask=None):
    """Convenience wrapper around a basic attention operation.

    ``q`` must be ``(batch, tokens, heads * dim_head)``; ``k`` and ``v`` may be
    given either in that layout or already split as
    ``(batch, tokens, heads, dim_head)`` (both are accepted by the ``view``
    below). Returns a ``(batch, tokens, heads * dim_head)`` tensor.
    """
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    ``act_layer`` may be an activation *class* (e.g. ``nn.GELU``) or an already
    constructed module/callable (e.g. ``nn.GELU(approximate="tanh")``).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias=True,
        dtype=None,
        device=None,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
        # A class must be instantiated before it can be applied to a tensor;
        # an instance (or plain callable) is used as given.
        self.act = act_layer() if isinstance(act_layer, type) else act_layer
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


def build_mlp(hidden_size, projector_dim, z_dim, num_tokens=816, out_tokens=416):
    """Build the projection head applied to intermediate token features.

    The leading kernel-size-1 ``Conv1d`` receives the ``(B, L, D)`` token
    tensor and therefore treats the token axis ``L`` as its channel axis,
    mapping ``num_tokens`` tokens to ``out_tokens`` tokens; the linear layers
    then map the feature axis ``hidden_size -> projector_dim -> z_dim``.

    ``num_tokens`` / ``out_tokens`` default to the previously hard-coded
    816 / 416. (Presumably 416 is the token count of the alignment target;
    confirm against the training code.)
    """
    return nn.Sequential(
        nn.Conv1d(in_channels=num_tokens, out_channels=out_tokens, kernel_size=1),
        nn.SiLU(),
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )


def _to_2tuple(value):
    """Return ``value`` as a tuple; a scalar ``v`` becomes ``(v, v)``."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    ``img_size`` and ``patch_size`` may each be an int (square) or a
    ``(height, width)`` pair. ``img_size=None`` leaves ``grid_size`` and
    ``num_patches`` unset.
    """

    def __init__(
        self,
        img_size: Optional[Union[int, Tuple[int, int]]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        flatten: bool = True,
        bias: bool = True,
        strict_img_size: bool = True,
        dynamic_img_pad: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        # Normalise to pairs so the indexing below works for int arguments too.
        self.patch_size = _to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = _to_2tuple(img_size)
            self.grid_size = (
                self.img_size[0] // self.patch_size[0],
                self.img_size[1] // self.patch_size[1],
            )
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # flatten spatial dim and transpose to channels last, kept for bwd compat
        self.flatten = flatten
        # Stored only; neither flag is consulted in forward().
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
            dtype=dtype,
            device=device,
        )

    def forward(self, x):
        B, C, H, W = x.shape  # also enforces a 4-D (N, C, H, W) input
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        return x


def modulate(x, shift, scale):
    """Apply adaLN modulation ``x * (1 + scale) + shift``, broadcast over dim 1.

    ``shift=None`` means "no shift" (scale-only modulation).
    """
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################


def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size_1,
    grid_size_2,
    cls_token=False,
    extra_tokens=0,
    scaling_factor=None,
    offset=None,
):
    """Build a fixed 2D sine/cosine positional embedding.

    grid_size_1, grid_size_2: number of grid rows and columns.
    scaling_factor / offset: optional division / subtraction applied to the
        grid coordinates before encoding.
    return:
        pos_embed: [grid_size_1*grid_size_2, embed_dim], or
        [extra_tokens + grid_size_1*grid_size_2, embed_dim] when ``cls_token``
        is set and ``extra_tokens > 0`` (the extra rows are zeros).
    """
    grid_h = np.arange(grid_size_1, dtype=np.float32)
    grid_w = np.arange(grid_size_2, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset

    grid = grid.reshape([2, 1, grid_size_1, grid_size_2])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode the two coordinate planes of ``grid`` with half of ``embed_dim`` each."""
    assert embed_dim % 2 == 0

    # Half of the dimensions encode grid[0], the other half grid[1].
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: an array of positions to be encoded (flattened to size (M,))
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################


class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations."""

    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad one zero column.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        if torch.is_floating_point(t):
            embedding = embedding.to(dtype=t.dtype)
        return embedding

    def forward(self, t, dtype, **kwargs):
        """Return the MLP embedding of ``t``; extra keyword arguments are ignored."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
        return self.mlp(t_freq)


class VectorEmbedder(nn.Module):
    """Embeds a flat vector of dimension input_dim"""

    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


#################################################################################
#                                 Core DiT Model                                #
#################################################################################


def split_qkv(qkv, head_dim):
    """Split a fused ``(B, L, 3 * H * head_dim)`` projection into q, k, v of ``(B, L, H, head_dim)``."""
    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
    return qkv[0], qkv[1], qkv[2]


def optimized_attention(qkv, num_heads):
    """Run ``attention`` on a ``(q, k, v)`` triple."""
    return attention(qkv[0], qkv[1], qkv[2], num_heads)


class SelfAttention(nn.Module):
    """Multi-head self-attention split into pre-/post-attention halves.

    ``qk_norm`` selects an optional per-head normalisation of q and k:
    ``"rms"``, ``"ln"`` or ``None``. With ``pre_only=True`` no output
    projection is created. ``qk_scale`` and ``rmsnorm`` are accepted but not
    used by this class; ``attn_mode`` is validated and stored only.
    """

    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: Optional[float] = None,
        attn_mode: str = "xformers",
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
        rmsnorm: bool = False,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        if not pre_only:
            self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
        assert attn_mode in self.ATTENTION_MODES
        self.attn_mode = attn_mode
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm == "ln":
            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor):
        """Project ``x`` to q, k, v; q and k are normalised and flattened to ``(B, L, C)``.

        ``v`` is returned unflattened as ``(B, L, H, head_dim)``.
        """
        B, L, C = x.shape  # enforces a 3-D (B, L, C) input
        qkv = self.qkv(x)
        q, k, v = split_qkv(qkv, self.head_dim)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the output projection (not available when ``pre_only``)."""
        assert not self.pre_only
        return self.proj(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (q, k, v) = self.pre_attention(x)
        x = attention(q, k, v, self.num_heads)
        return self.post_attention(x)


class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine: bool = False,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): Whether to learn a per-feature scale.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter (None when not affine).
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            # Start as the identity scale. Leaving this as uninitialised memory
            # would give arbitrary values, because MMDiT.initialize_weights
            # only re-initialises nn.Linear modules.
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = self._norm(x)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        return x


class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
        dtype=None,
        device=None,
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer
                (scaled by 2/3 and rounded up to a multiple of ``multiple_of``).
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
            dtype, device: Optional factory arguments forwarded to the linear layers.

        Attributes:
            w1 (nn.Linear): Gate projection, ``dim -> hidden_dim``.
            w2 (nn.Linear): Output projection, ``hidden_dim -> dim``.
            w3 (nn.Linear): Value projection, ``dim -> hidden_dim``.
        """
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


class DismantledBlock(nn.Module):
    """A DiT block with gated adaptive layer norm (adaLN) conditioning.

    The block is "dismantled" into ``pre_attention`` / ``post_attention`` so a
    caller can run attention jointly over several token streams in between.
    A ``pre_only`` block has no output projection, second norm or MLP.
    """

    ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        dtype=None,
        device=None,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in self.ATTENTION_MODES
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_mode=attn_mode,
            pre_only=pre_only,
            qk_norm=qk_norm,
            rmsnorm=rmsnorm,
            dtype=dtype,
            device=device,
        )
        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            if not swiglu:
                self.mlp = Mlp(
                    in_features=hidden_size,
                    hidden_features=mlp_hidden_dim,
                    act_layer=nn.GELU(approximate="tanh"),
                    dtype=dtype,
                    device=device,
                )
            else:
                self.mlp = SwiGLUFeedForward(
                    dim=hidden_size,
                    hidden_dim=mlp_hidden_dim,
                    multiple_of=256,
                    dtype=dtype,
                    device=device,
                )
        self.scale_mod_only = scale_mod_only
        # Number of modulation vectors produced from the conditioning:
        # (shift, scale, gate) x (attention, mlp), minus the shifts when
        # scale_mod_only, and attention-only when pre_only.
        if not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
        """Return ``(q, k, v)`` and the tensors ``post_attention`` will need.

        For a ``pre_only`` block the second element is ``None``.
        """
        assert x is not None, "pre_attention called with None input"
        if not self.pre_only:
            if not self.scale_mod_only:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
            else:
                shift_msa = None
                shift_mlp = None
                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)

        if not self.scale_mod_only:
            shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
        else:
            shift_msa = None
            scale_msa = self.adaLN_modulation(c)
        qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
        return qkv, None

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        """Add the gated attention output and the gated MLP output to ``x``."""
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        (q, k, v), intermediates = self.pre_attention(x, c)
        attn = attention(q, k, v, self.attn.num_heads)
        return self.post_attention(attn, *intermediates)


def block_mixing(context, x, context_block, x_block, c):
    """Run joint attention over the concatenated context and ``x`` tokens.

    Each stream is projected by its own block, the q/k/v tensors are
    concatenated along the token axis (context first), attention runs once,
    and the result is split back and post-processed per stream. Returns
    ``(context, x)``; ``context`` is ``None`` when ``context_block`` is
    ``pre_only``.
    """
    assert context is not None, "block_mixing called with None context"
    context_qkv, context_intermediates = context_block.pre_attention(context, c)
    x_qkv, x_intermediates = x_block.pre_attention(x, c)

    q, k, v = (
        torch.cat((context_part, x_part), dim=1)
        for context_part, x_part in zip(context_qkv, x_qkv)
    )
    attn = attention(q, k, v, x_block.attn.num_heads)

    context_len = context_qkv[0].shape[1]
    context_attn, x_attn = attn[:, :context_len], attn[:, context_len:]

    if not context_block.pre_only:
        context = context_block.post_attention(context_attn, *context_intermediates)
    else:
        context = None
    x = x_block.post_attention(x_attn, *x_intermediates)
    return context, x


class JointBlock(nn.Module):
    """just a small wrapper to serve as a fsdp unit"""

    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")  # required keyword
        qk_norm = kwargs.pop("qk_norm", None)
        # Only the context stream may be pre_only; the x stream always keeps
        # its output projection and MLP.
        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)
        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)

    def forward(self, *args, **kwargs):
        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)


class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    Output width is ``patch_size[0] * patch_size[1] * out_channels`` unless
    ``total_out_channels`` is given, in which case that value is used.
    """

    def __init__(
        self,
        hidden_size: int,
        patch_size,
        out_channels: int,
        total_out_channels: Optional[int] = None,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = (
            nn.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
            if (total_out_channels is None)
            else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class MMDiT(nn.Module):
    """Diffusion model with a Transformer backbone.

    Notes on arguments, as implemented:
      * ``hidden_size`` is ``64 * depth`` and ``num_heads`` is ``depth``.
      * ``context_embedder_config`` is currently overridden by a hand-coded
        ``torch.nn.Linear(512 -> hidden_size)`` configuration.
      * ``pos_embed_max_size`` and ``num_patches`` are accepted but ignored;
        both are derived from the patch grid of ``input_size`` (which equals
        the previously hard-coded ``(102, 8)`` / 816 for the defaults).
      * ``z_dims`` gives the output width of each projection head applied to
        the tokens after ``encoder_depth`` blocks.
    """

    def __init__(
        self,
        input_size=(204, 16),
        patch_size=(2, 2),
        in_channels: int = 8,
        depth: int = 12,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_config: Optional[Dict] = None,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        dtype=None,
        device=None,
        encoder_depth=4,
        z_dims=(768,),
        projector_dim=2048,
    ):
        super().__init__()
        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
        if input_size is None:
            raise ValueError("MMDiT needs a concrete input_size to build its fixed positional embedding")

        self.dtype = dtype
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset

        # apply magic --> this defines a head_size of 64
        hidden_size = 64 * depth
        num_heads = depth

        self.num_heads = num_heads

        # strict_img_size was always False here (the maximum positional grid
        # was never None), so it is passed explicitly.
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            bias=True,
            strict_img_size=False,
            dtype=dtype,
            device=device,
        )
        # The positional-embedding grid is the patch grid of input_size;
        # (102, 8) for the default arguments.
        self.pos_embed_max_size = self.x_embedder.grid_size

        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)

        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)
        else:
            self.y_embedder = None

        self.context_embedder = nn.Identity()
        # TODO: hand coded -- the constructor argument is overridden here.
        context_embedder_config = {
            "params": {"in_features": 512, "out_features": hidden_size},
            "target": "torch.nn.Linear",
        }
        if context_embedder_config is not None:
            if context_embedder_config["target"] == "torch.nn.Linear":
                self.context_embedder = nn.Linear(**context_embedder_config["params"], dtype=dtype, device=device)

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))

        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        self.register_buffer(
            "pos_embed",
            torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
        )

        # Only the last block's context stream is pre_only: its context output
        # would not be consumed by any later block.
        self.joint_blocks = nn.ModuleList(
            [
                JointBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    attn_mode=attn_mode,
                    pre_only=i == depth - 1,
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                    dtype=dtype,
                    device=device,
                )
                for i in range(depth)
            ]
        )

        self.final_layer = FinalLayer(
            hidden_size, self.x_embedder.patch_size, self.out_channels, dtype=dtype, device=device
        )

        # REPA
        self.encoder_depth = encoder_depth
        self.projectors = nn.ModuleList(
            [build_mlp(hidden_size, projector_dim, z_dim, num_tokens=num_patches) for z_dim in z_dims]
        )

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        grid_size_1, grid_size_2 = self.x_embedder.grid_size
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size_1, grid_size_2)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all ``nn.Linear`` layers, then apply the DiT-specific overrides."""

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        nn.init.normal_(self.context_embedder.weight, std=0.02)

        if self.y_embedder is not None:
            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.joint_blocks:
            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def cropped_pos_embed(self, hw):
        """Return the centre crop of ``pos_embed`` matching an input of spatial size ``hw``.

        ``hw`` is in input (pre-patching) units; the result is
        ``(1, (h // p1) * (w // p2), hidden_size)``.
        """
        assert self.pos_embed_max_size is not None
        p1, p2 = self.x_embedder.patch_size
        max_h, max_w = self.pos_embed_max_size
        h, w = hw
        # patched size
        h = h // p1
        w = w // p2
        assert h <= max_h, (h, max_h)
        assert w <= max_w, (w, max_w)
        top = (max_h - h) // 2
        left = (max_w - w) // 2
        channels = self.pos_embed.shape[-1]
        # "1 (h w) c -> 1 h w c": un-flatten the row-major token axis.
        spatial_pos_embed = self.pos_embed.reshape(1, max_h, max_w, channels)
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        # "1 h w c -> 1 (h w) c"
        return spatial_pos_embed.reshape(1, h * w, channels)

    def unpatchify(self, x, hw=None):
        """
        x: (N, T, patch_size[0] * patch_size[1] * C)
        hw: spatial size (H, W) of the original input; defaults to the
            configured input size.
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p1, p2 = self.x_embedder.patch_size
        if hw is None:
            hw = self.x_embedder.img_size
        h, w = hw
        # patched size
        h = h // p1
        w = w // p2
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p1, p2, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # Rebuild from the actual patch grid rather than the configured image
        # size, so inputs smaller than input_size round-trip as well.
        return x.reshape(shape=(x.shape[0], c, h * p1, w * p2))

    def forward_core_with_concat(
        self,
        x: torch.Tensor,
        c_mod: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        detach: Optional[bool] = False,
    ) -> torch.Tensor:
        """Run the joint blocks and the final layer.

        Returns ``(x, zs)`` where ``zs`` holds one projection per head, taken
        from the ``x`` tokens after block number ``encoder_depth`` (1-based).
        ``zs`` is empty when ``encoder_depth`` does not select any block.
        ``detach`` is accepted for interface compatibility and is not used.
        """
        if self.register_length > 0:
            registers = self.register.expand(x.shape[0], -1, -1)
            context = registers if context is None else torch.cat((registers, context), 1)

        # context is B, L', D
        # x is B, L, D
        B, L, D = x.shape  # enforces a 3-D token tensor

        zs = []
        for i, block in enumerate(self.joint_blocks):
            context, x = block(context, x, c=c_mod)
            if (i + 1) == self.encoder_depth:
                zs = [projector(x) for projector in self.projectors]

        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)
        return x, zs

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        do_guidance=False,
        detach: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: optional tensor passed through ``y_embedder`` and added to the timestep embedding
        context: conditioning tokens fed to ``context_embedder`` (required in practice)
        do_guidance: when True, the first half of the batch keeps its context
            and the second half has it zeroed (conditional / unconditional halves)
        Returns ``(output, zs)``: the (N, out_channels, H, W) prediction and the
        list of intermediate projections.
        """
        hw = x.shape[-2:]
        x = self.x_embedder(x) + self.cropped_pos_embed(hw)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        context = self.context_embedder(context)

        if self.training and not do_guidance:
            # classifier free guidance: randomly zero the context of ~10% of samples
            cond_mask = prob_mask_like((context.shape[0],), prob=1 - 0.1, device=context.device)
            cond_mask = cond_mask.to(context.dtype)
            context = cond_mask.view(-1, 1, 1) * context
        elif do_guidance:
            N = x.shape[0]
            half_bs = N // 2
            cond_mask = torch.cat(
                (
                    torch.ones(half_bs, device=context.device, dtype=context.dtype),
                    torch.zeros(N - half_bs, device=context.device, dtype=context.dtype),
                )
            )
            context = cond_mask.view(-1, 1, 1) * context
        else:
            cond_mask = torch.ones(context.shape[0], device=context.device, dtype=torch.bool)

        if y is not None:
            y = self.y_embedder(y)
            # Drop the vector conditioning for the same samples as the context.
            y = cond_mask.view(-1, 1) * y
            c = c + y  # (N, D)

        x, zs = self.forward_core_with_concat(x, c, context, detach)
        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)
        return x, zs