Spaces:
Running on Zero
Running on Zero
| import math | |
| import torch | |
| from typing import Optional | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class MLP(torch.nn.Module):
    """Two-layer feed-forward network.

    Pipeline: fc1 -> act -> dropout -> norm -> fc2 -> dropout.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable that builds the activation module.
        norm_layer: optional callable taking the hidden width and returning a
            normalization module; ``nn.Identity`` is used when it is ``None``.
        bias: whether both linear layers carry a bias term.
        drop: dropout probability used after the activation and after fc2.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
    ):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features

        # Creation order matches execution order: fc1, act, drop1, norm, fc2, drop2.
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply the six stages in sequence and return the result."""
        stages = (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x
class Attention(torch.nn.Module):
    """Multi-head self-attention with optional per-head q/k normalization.

    ``x`` is projected to queries/keys/values of width ``num_heads * head_dim``,
    split into heads, attended with ``F.scaled_dot_product_attention`` and
    projected back to ``dim``.

    Args:
        dim: channel size of the input and the output.
        num_heads: number of attention heads.
        head_dim: channel size of each head (independent of ``dim``).
        qkv_bias: whether the q/k/v projections carry a bias.
        qk_norm: if True, apply ``norm_layer(head_dim)`` to q and k per head.
        attn_drop: dropout probability on attention weights (training only).
        proj_drop: dropout probability after the output projection.
        norm_layer: callable that builds the q/k normalization module from a size.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        # Same value F.scaled_dot_product_attention applies by default
        # (1 / sqrt(q.size(-1)) == 1 / sqrt(head_dim)); it is not passed explicitly.
        self.scale = head_dim ** -0.5
        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        # Applied to q/k after the head split, i.e. on tensors of shape (b, nh, t, head_dim).
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        # Only its probability is used: it is forwarded as SDPA's dropout_p.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.proj = nn.Linear(self.inner_dim, dim)

    def to_heads(self, ts: torch.Tensor) -> torch.Tensor:
        """Split the channel axis into heads: (b, t, nh * hd) -> (b, nh, t, hd)."""
        b, t, c = ts.shape
        # (b, t, nh, hd)
        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
        return ts.transpose(1, 2)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Args:
            x(torch.Tensor): shape (b, t, c)
            attn_mask(torch.Tensor, optional): shape (b, t, t), in any form
                ``F.scaled_dot_product_attention`` accepts (boolean, True = may
                attend, or additive float). ``None`` (the default) applies no mask.
        Returns:
            torch.Tensor of shape (b, t, dim).
        """
        b, t, c = x.shape
        q = self.q_norm(self.to_heads(self.to_q(x)))  # (b, nh, t, hd)
        k = self.k_norm(self.to_heads(self.to_k(x)))
        v = self.to_heads(self.to_v(x))
        if attn_mask is not None:
            # Add a head axis so a single mask broadcasts over all heads: (b, 1, t, t).
            attn_mask = attn_mask.unsqueeze(1)
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )  # (b, nh, t, hd)
        # Merge heads back into channels: (b, t, nh * hd).
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation, ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` must broadcast against ``x``.
    """
    gain = 1 + scale
    return x * gain + shift
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Timesteps are multiplied by ``self.scale`` (1000), encoded with
    ``frequency_embedding_size`` sinusoidal channels, and projected to
    ``hidden_size`` by a Linear -> SiLU -> Linear MLP.

    Args:
        hidden_size: size of the output embedding.
        frequency_embedding_size: number of sinusoidal channels fed to the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # from SinusoidalPosEmb: timesteps are multiplied by this factor before the
        # sinusoidal encoding (presumably because t arrives in [0, 1] — confirm with caller).
        self.scale = 1000

    # Must be a staticmethod: forward() calls it as self.timestep_embedding(t, dim).
    # Without the decorator, ``self`` is bound to ``t`` and the real timesteps land in ``dim``.
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings, laid out as
                 [cos(t * freqs), sin(t * freqs)], plus one zero column when D is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequencies max_period ** (-k / half) for k = 0 .. half - 1.
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Append one zero channel so the output is exactly `dim` wide.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Map timesteps of shape (N,) to embeddings of shape (N, hidden_size)."""
        t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
# Convolution related
class Transpose(torch.nn.Module):
    """Module wrapper around ``transpose`` so it can be placed inside ``nn.Sequential``.

    Args:
        dim0: first dimension to swap.
        dim1: second dimension to swap.
    """

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` swapped."""
        return x.transpose(self.dim0, self.dim1)
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution whose outputs depend only on current and earlier time steps.

    The channel-first input (b, c, t) is left-padded with ``kernel_size - 1``
    zeros on the last axis before an unpadded ``Conv1d``. The output therefore
    has the same length as the input, and step i sees only input steps <= i.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size)
        left_pad = kernel_size - 1
        # F.pad tuple order for the last dim is (left, right).
        self.causal_padding = (left_pad, 0)

    def forward(self, x: torch.Tensor):
        padded = F.pad(x, self.causal_padding)
        return super().forward(padded)
class CausalConvBlock(nn.Module):
    """Two causal convolutions with LayerNorm + Mish between them, on (b, t, c) input.

    Layer order: Transpose -> CausalConv1d -> Transpose -> LayerNorm -> Mish ->
    Transpose -> CausalConv1d -> Transpose.

    Args:
        in_channels: channel size of the input.
        out_channels: channel size of both convolutions' outputs.
        kernel_size: kernel size of both causal convolutions.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # The convolutions run channel-first, so each one is wrapped in
        # (b, t, c) <-> (b, c, t) transposes; the LayerNorm runs channel-last.
        # The Sequential indices (block.1, block.3, block.6 hold parameters) are
        # state_dict keys, so this order must not change.
        layers = [
            Transpose(1, 2),
            CausalConv1d(in_channels, out_channels, kernel_size),
            Transpose(1, 2),
            nn.LayerNorm(out_channels),
            nn.Mish(),
            Transpose(1, 2),
            CausalConv1d(out_channels, out_channels, kernel_size),
            Transpose(1, 2),
        ]
        self.block = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, t, c)
            mask: shape (b, t, 1); when given, it is multiplied into the input
                before the convolutions and into the output afterwards.
        """
        has_mask = mask is not None
        if has_mask:
            x = x * mask
        out = self.block(x)
        if has_mask:
            out = out * mask
        return out
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Three residual branches run in order: self-attention, causal convolution,
    then MLP. Each branch takes a LayerNorm'd input modulated by its own
    shift/scale and is multiplied by its own gate before being added back.
    All nine shift/scale/gate vectors come from one SiLU -> Linear projection
    of the conditioning vector ``c``.

    Args:
        hidden_size: channel size of the block.
        num_heads: number of attention heads.
        head_dim: channel size of each attention head.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        **block_kwargs: extra keyword arguments forwarded to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        # Note: norm2 feeds the MLP branch and norm3 the conv branch. The names
        # follow creation order, not execution order.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """Args
            x: shape (b, t, c)
            c: conditioning, shape (b, 1, c)
            attn_mask: shape (b, t, t), bool type attention mask passed to Attention
            conv_mask: shape (b, t, 1), non-pad mask passed unchanged to
                CausalConvBlock, which multiplies it into the (b, t, c) features.
                (DiT.forward transposes its (b, 1, t) mask into this layout.)
        """
        # Chunk order defines which slice of the projection feeds which branch.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
            = self.adaLN_modulation(c).chunk(9, dim=-1)
        # attention
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask)
        # conv
        x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv), mask=conv_mask)
        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
class FinalLayer(nn.Module):
    """
    The final layer of DiT.

    A non-affine LayerNorm, shifted and scaled by values projected from the
    conditioning vector, followed by a Linear projection to ``out_channels``.

    Args:
        hidden_size: channel size of the incoming features.
        out_channels: channel size of the output.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size * 2, bias=True),
        )
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with shift/scale derived from ``c``, then project."""
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The input features and an auxiliary condition are concatenated on the
    channel axis, projected to ``hidden_size``, and passed through ``depth``
    DiTBlocks. Each block has self-attention, causal-conv and MLP branches,
    adaLN-modulated by the timestep embedding. A FinalLayer then projects the
    result to ``out_channels``.

    Args:
        in_channels: channel count after concatenating ``x`` and ``c``.
        out_channels: channel count of the prediction.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        depth: number of DiTBlocks.
        num_heads: attention heads per block.
        head_dim: channel size of each attention head.
        hidden_size: model width.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mlp_ratio: float = 4.0,
        depth: int = 28,
        num_heads: int = 8,
        head_dim: int = 64,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.in_proj = nn.Linear(in_channels, hidden_size)
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear layer, then zero the adaLN and output layers.

        Zeroing the adaLN modulation makes every block's gates start at 0, so
        each block starts as an identity map. Zeroing the final linear layer
        makes the initial prediction 0.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """For non-streaming inference.

        Args:
            x: input features, shape (b, c_x, t)
            c: auxiliary condition, shape (b, c_c, t). It is concatenated with
                ``x`` along dim 1, so c_x + c_c must equal ``in_channels``.
            t: diffusion timesteps, flattened with ``view(-1)`` to shape (b,)
            attn_mask: optional attention mask, shape (b, t, t)
            conv_mask: optional non-pad mask, shape (b, 1, t)
        Returns:
            pred: shape (b, out_channels, t)
        """
        # Timestep embedding as a (b, 1, hidden_size) conditioning vector.
        t_emb = self.t_embedder(t.view(-1)).unsqueeze(1)
        # CausalConvBlock expects a channel-last mask: (b, 1, t) -> (b, t, 1).
        if conv_mask is not None:
            conv_mask = conv_mask.transpose(1, 2)

        x = torch.cat([x, c], dim=1)
        # The blocks operate on (b, t, channels) sequences.
        x = self.in_proj(x.transpose(1, 2))
        for block in self.blocks:
            x = block(x, t_emb, attn_mask=attn_mask, conv_mask=conv_mask)
        x = self.final_layer(x, t_emb)
        return x.transpose(1, 2)