# NOTE(review): the three lines below appear to be VCS/commit residue (author,
# commit message, commit hash) accidentally pasted at the top of the file;
# commented out so the module parses.
# Shen Feiyu
# add 1s
# faadabf
import math
import torch
from typing import Optional
import torch.nn as nn
import torch.nn.functional as F
class MLP(torch.nn.Module):
    """Position-wise two-layer feed-forward network.

    Pipeline: fc1 -> activation -> dropout -> (optional) norm -> fc2 -> dropout.
    Hidden and output widths default to ``in_features`` when not provided.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
    ):
        super().__init__()
        # Unspecified widths fall back to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        # Identity when no norm layer is requested.
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` is the feature dimension."""
        for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = stage(x)
        return x
class Attention(torch.nn.Module):
    """Multi-head self-attention with separate q/k/v projections and optional
    per-head normalization of q and k.

    The inner width is ``num_heads * head_dim`` and may differ from ``dim``;
    ``proj`` maps back to ``dim``.

    Fix: ``attn_mask`` now defaults to ``None`` — the body already handles the
    ``None`` case explicitly, so the parameter was de facto optional; the
    default makes the signature match the behavior (backward compatible).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        # NOTE: equals F.scaled_dot_product_attention's default softmax scale
        # (1/sqrt(head_dim)), so forward() does not need to pass it explicitly.
        self.scale = head_dim ** -0.5
        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        # Per-head q/k norm (applied on the head_dim axis) when qk_norm is set.
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.proj = nn.Linear(self.inner_dim, dim)

    def to_heads(self, ts: torch.Tensor):
        """Reshape (b, t, inner_dim) -> (b, num_heads, t, head_dim)."""
        b, t, c = ts.shape
        # c == inner_dim, so c // num_heads == head_dim.
        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
        ts = ts.transpose(1, 2)
        return ts

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Args:
            x (torch.Tensor): shape (b, t, c)
            attn_mask (torch.Tensor, optional): shape (b, t, t); bool or additive
                mask, broadcast over heads. ``None`` disables masking.
        Returns:
            torch.Tensor: shape (b, t, c)
        """
        b, t, c = x.shape
        q = self.to_heads(self.to_q(x))  # (b, nh, t, hd)
        k = self.to_heads(self.to_k(x))
        v = self.to_heads(self.to_v(x))
        q = self.q_norm(q)
        k = self.k_norm(k)
        if attn_mask is not None:
            # Add a head axis so the (b, t, t) mask broadcasts over heads.
            attn_mask = attn_mask.unsqueeze(1)
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            # Dropout only during training; SDPA takes the probability directly.
            dropout_p=self.attn_drop.p if self.training else 0.,
        )  # (b, nh, t, hd)
        x = x.transpose(1, 2).reshape(b, t, -1)  # merge heads -> (b, t, inner_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def modulate(x, shift, scale):
    """AdaLN modulation: multiply ``x`` by ``1 + scale``, then add ``shift``.

    All arguments broadcast; shift/scale are typically (b, 1, c).
    """
    scaled = x * (scale + 1)
    return scaled + shift
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
# from SinusoidalPosEmb
self.scale = 1000
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
# Convolution related
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose`` so axis swaps can live inside
    an ``nn.Sequential``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        """Return ``x`` with ``dim0`` and ``dim1`` swapped."""
        return x.transpose(self.dim0, self.dim1)
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution with left-only padding: output[t] depends only on
    input[..t], and the output length equals the input length."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super().__init__(in_channels, out_channels, kernel_size)
        # (left, right) padding: all on the left, none on the right.
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        """Pad causally, then run the ordinary Conv1d on (b, c, t) input."""
        padded = F.pad(x, self.causal_padding)
        return super().forward(padded)
class CausalConvBlock(nn.Module):
    """Two causal 1-D convolutions with LayerNorm + Mish in between.

    Operates on channel-last (b, t, c) tensors; the Transpose modules move the
    channel axis to the position Conv1d expects and back.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.block = torch.nn.Sequential(
            # conv1: (b, t, c) -> (b, c, t) -> causal conv -> (b, t, c)
            Transpose(1, 2),
            CausalConv1d(in_channels, out_channels, kernel_size),
            Transpose(1, 2),
            # norm & act on the channel-last layout
            nn.LayerNorm(out_channels),
            nn.Mish(),
            # conv2: same transpose sandwich
            Transpose(1, 2),
            CausalConv1d(out_channels, out_channels, kernel_size),
            Transpose(1, 2),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, t, c)
            mask: shape (b, t, 1); padded positions are zeroed both before
                and after the block (not between the two convolutions)
        """
        if mask is not None:
            x = x * mask
        out = self.block(x)
        if mask is not None:
            out = out * mask
        return out
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Residual branch order in forward: self-attention, causal convolution, MLP.
    Each branch is modulated (shift/scale) and gated by parameters regressed
    from the conditioning vector.
    """

    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim,
                              qkv_bias=True, qk_norm=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = MLP(in_features=hidden_size,
                       hidden_features=int(hidden_size * mlp_ratio),
                       act_layer=approx_gelu, drop=0)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
        # Regresses 3 (shift, scale, gate) triplets: attention, MLP, conv.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """Args
            x: shape (b, t, c)
            c: conditioning, shape (b, 1, c)
            attn_mask: shape (b, t, t), bool type attention mask
            conv_mask: shape (b, t, 1), bool/float non-pad mask as expected by
                CausalConvBlock (DiT.forward transposes it before calling here)
        """
        mods = self.adaLN_modulation(c).chunk(9, dim=-1)
        shift_msa, scale_msa, gate_msa = mods[0], mods[1], mods[2]
        shift_mlp, scale_mlp, gate_mlp = mods[3], mods[4], mods[5]
        shift_conv, scale_conv, gate_conv = mods[6], mods[7], mods[8]
        # Attention branch (norm1).
        h = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa * self.attn(h, attn_mask=attn_mask)
        # Causal-conv branch (norm3, despite the numbering).
        h = modulate(self.norm3(x), shift_conv, scale_conv)
        x = x + gate_conv * self.conv(h, mask=conv_mask)
        # MLP branch (norm2).
        h = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp * self.mlp(h)
        return x
class FinalLayer(nn.Module):
    """
    The final layer of DiT: adaLN modulation of a final LayerNorm, followed by
    a linear projection to the output channel count.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x, c):
        # Regress shift/scale from the conditioning, modulate, then project.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        h = modulate(self.norm_final(x), shift, scale)
        return self.linear(h)
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mlp_ratio: float = 4.0,
        depth: int = 28,
        num_heads: int = 8,
        head_dim: int = 64,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.t_embedder = TimestepEmbedder(hidden_size)
        # NOTE(review): forward() concatenates x and the aux condition along
        # the channel axis before in_proj, so in_channels presumably equals
        # x_channels + cond_channels — confirm against call sites.
        self.in_proj = nn.Linear(in_channels, hidden_size)
        self.blocks = nn.ModuleList(
            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        )
        self.final_layer = FinalLayer(hidden_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init all Linear layers, then zero the adaLN and output
        layers so every residual branch and the prediction start at zero
        (adaLN-Zero)."""
        # Base init for every Linear in the model.
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Timestep-embedding MLP gets a normal init instead of Xavier.
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in DiT blocks: gates start at 0,
        # so each block is initially the identity.
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out output layers: the model initially predicts zeros.
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    # For non-streaming inference.
    def forward(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, attn_mask: torch.Tensor = None, conv_mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, c, t)
            c: aux condition, shape (b, c, t)
            t: shape (b,)
            attn_mask: (b, t, t)
            conv_mask: (b, 1, t)
        Returns:
            pred: shape (b, out_channels, t)
        """
        # Embed the timestep once; broadcast over the sequence as (b, 1, hidden).
        t_emb = self.t_embedder(t.view(-1)).unsqueeze(1)
        # CausalConvBlock expects its mask as (b, t, 1).
        if conv_mask is not None:
            conv_mask = conv_mask.transpose(1, 2)
        # Concatenate the condition on the channel axis, go channel-last,
        # and project into the hidden width.
        h = torch.cat([x, c], dim=1)
        h = self.in_proj(h.transpose(1, 2))  # (b, t, hidden)
        for block in self.blocks:
            h = block(h, t_emb, attn_mask=attn_mask, conv_mask=conv_mask)
        h = self.final_layer(h, t_emb)
        return h.transpose(1, 2)  # back to (b, out_channels, t)