# sam_audio/model/transformer.py
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from functools import partial
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .config import TransformerConfig
from .patcher import Patcher
from .rope import RotaryEmbedding
def gate(x, g):
    return x * g
def modulate(x, shift, scale):
return x * (1 + scale) + shift
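# A minimal sketch of how `modulate` and `gate` combine in the adaLN-style
# blocks below (illustrative tensors; see DiTBlock.forward for the real flow):
#
#   h = modulate(norm(x), shift, scale)   # x_norm * (1 + scale) + shift
#   x = x + gate(attn(h), g)              # gated residual update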
def get_nonlinearity(kind: str):
return {
"relu": F.relu,
"gelu": F.gelu,
"swiglu": None,
"approx_gelu": partial(F.gelu, approximate="tanh"),
"srelu": lambda x: F.relu(x) ** 2,
"silu": F.silu,
}[kind]
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
return (output * self.weight).type_as(x)
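# Usage sketch (illustrative shapes): RMSNorm normalizes over the last dim in
# fp32 for numerical stability, rescales, then casts back to the input dtype.
#
#   norm = RMSNorm(dim=512)
#   y = norm(torch.randn(2, 16, 512))  # y.shape == (2, 16, 512)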
class ProjectionLayer(torch.nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
non_linearity: str,
dropout: float,
fc_bias: bool = False,
):
super().__init__()
self.swiglu = non_linearity == "swiglu"
self.dropout = dropout
self.w1 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
self.w2 = torch.nn.Linear(out_dim, out_dim, bias=fc_bias)
if self.swiglu:
self.w3 = torch.nn.Linear(in_dim, out_dim, bias=fc_bias)
# non-linearity
self.non_linearity = get_nonlinearity(non_linearity)
def forward(self, x):
hidden1 = self.w1(x)
if self.swiglu:
hidden3 = self.w3(x)
hidden = F.silu(hidden1) * hidden3
else:
hidden = self.non_linearity(hidden1)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.w2(hidden)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float = 1e-5,
use_qk_norm: bool = False,
fc_bias: bool = False,
):
super().__init__()
assert n_heads % n_kv_heads == 0
self.head_dim = head_dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.use_qk_norm = use_qk_norm
self.wq = torch.nn.Linear(dim, n_heads * head_dim, bias=fc_bias)
self.wk, self.wv = [
torch.nn.Linear(
dim,
n_kv_heads * head_dim,
bias=fc_bias,
)
for _ in range(2)
]
self.wo = torch.nn.Linear(
n_heads * head_dim,
dim,
bias=fc_bias,
)
        if self.use_qk_norm:
self.q_norm = RMSNorm(head_dim, eps=norm_eps)
self.k_norm = RMSNorm(head_dim, eps=norm_eps)
    def reshape_heads(self, x: torch.Tensor, heads: int) -> torch.Tensor:
        B, T, C = x.shape
        # split channels head_dim-major: B x T x C -> B x T x C/H x H
        x = x.reshape(B, T, C // heads, heads)
        # B x T x C/H x H -> B x H x T x C/H
        return x.permute(0, 3, 1, 2)
def forward(
self,
x: torch.Tensor,
cross_x: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
rope: Optional[RotaryEmbedding] = None,
):
# x: B, T, E
xq = self.wq(x)
if cross_x is not None:
xk, xv = self.wk(cross_x), self.wv(cross_x)
else:
xk, xv = self.wk(x), self.wv(x)
xk = self.reshape_heads(xk, self.n_kv_heads)
xv = self.reshape_heads(xv, self.n_kv_heads)
xq = self.reshape_heads(xq, self.n_heads)
if self.use_qk_norm:
xq = self.q_norm(xq)
xk = self.k_norm(xk)
if rope is not None:
xq = rope(xq, bhle=True)
xk = rope(xk, bhle=True)
        if self.n_kv_heads != self.n_heads:
            # repeat K/V heads so grouped-query attention matches the Q head count
            n_rep = self.n_heads // self.n_kv_heads
            xk = xk.repeat_interleave(n_rep, dim=1)
            xv = xv.repeat_interleave(n_rep, dim=1)
        attn_mask = None
        if key_padding_mask is not None:
            # (B, 1, 1, S) boolean mask; True = attend, per SDPA convention
            attn_mask = key_padding_mask[:, None, None, :]
        output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)
output = rearrange(output, "b h n d -> b n (h d)")
return self.wo(output)
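# Usage sketch (illustrative sizes): self-attention when cross_x is None,
# cross-attention otherwise; n_kv_heads < n_heads gives grouped-query K/V.
#
#   attn = Attention(dim=512, head_dim=64, n_heads=8, n_kv_heads=4)
#   y = attn(torch.randn(2, 100, 512))  # y.shape == (2, 100, 512)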
class FeedForward(torch.nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
ffn_dim_multiplier: float,
multiple_of: int,
dropout: float,
non_linearity: str = "swiglu",
fc_bias: bool = False,
):
super().__init__()
self.dropout = dropout
self.swiglu = non_linearity == "swiglu"
        # scale SwiGLU hidden dim by 2/3 so its three matrices roughly match
        # the parameter count of a two-matrix ReLU/GELU MLP
        if self.swiglu:
            hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
# round hidden dimension to `multiple_of`
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# layers
self.w1 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
self.w2 = torch.nn.Linear(hidden_dim, dim, bias=fc_bias)
if self.swiglu:
self.w3 = torch.nn.Linear(dim, hidden_dim, bias=fc_bias)
# non-linearity
self.non_linearity = get_nonlinearity(non_linearity)
def forward(
self,
x,
):
hidden1 = self.w1(x)
if self.swiglu:
hidden3 = self.w3(x)
hidden = F.silu(hidden1) * hidden3
else:
hidden = self.non_linearity(hidden1)
hidden = F.dropout(hidden, p=self.dropout, training=self.training)
return self.w2(hidden)
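# Sizing sketch (illustrative numbers): with dim=512, hidden_dim=4*512=2048,
# non_linearity="swiglu", ffn_dim_multiplier=1.0 and multiple_of=64:
#   2048 -> int(2 * 2048 / 3) = 1365 -> rounded up to 1408
# so 3 * 512 * 1408 SwiGLU weights roughly match 2 * 512 * 2048 for GELU.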
class TimestepEmbedder(torch.nn.Module):
def __init__(
self,
dim: int,
frequency_embedding_dim: int,
non_linearity: str,
dropout: float,
fc_bias: bool,
max_period: int = 10000,
):
super().__init__()
self.frequency_embedding_size = frequency_embedding_dim
self.projection = ProjectionLayer(
in_dim=frequency_embedding_dim,
out_dim=dim,
non_linearity=non_linearity,
dropout=dropout,
fc_bias=fc_bias,
)
half = frequency_embedding_dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
)
self.register_buffer("freqs", freqs, persistent=False)
    def timestep_embedding(self, t, dim):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
            These may be fractional.
        :param dim: the dimension of the output. If odd, the embedding is
            zero-padded by one channel.
        :return: an (N, D) Tensor of positional embeddings. The frequencies
            are precomputed in __init__ from `max_period`.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        freqs = self.freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding.to(t)
def forward(self, t):
x = self.timestep_embedding(t, self.frequency_embedding_size)
return self.projection(x)
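# Usage sketch (illustrative sizes): scalar timesteps -> fixed sinusoids ->
# learned projection to the model dimension.
#
#   emb = TimestepEmbedder(dim=512, frequency_embedding_dim=256,
#                          non_linearity="silu", dropout=0.0, fc_bias=True)
#   e = emb(torch.rand(8))  # e.shape == (8, 512)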
class ContextEmbedder(torch.nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
non_linearity: str,
dropout: float,
fc_bias: bool,
norm_eps: float = 1e-5,
context_norm: bool = False,
):
super().__init__()
self.context_norm = context_norm
if context_norm:
self.norm = RMSNorm(in_dim, norm_eps)
self.projection = ProjectionLayer(
in_dim=in_dim,
out_dim=out_dim,
non_linearity=non_linearity,
dropout=dropout,
fc_bias=fc_bias,
)
def forward(self, x):
if self.context_norm:
x = self.norm(x)
h = self.projection(x)
return h
class DiTBlock(torch.nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: Optional[int] = None,
dropout: float = 0.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
fc_bias: bool = False,
ffn_exp: int = 1,
        ffn_dim_multiplier: float = 4.0,
multiple_of: int = 64,
non_linearity: str = "silu",
no_cross_attention: bool = False,
):
super().__init__()
assert dim % n_heads == 0
self.n_heads = n_heads
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
self.dim = dim
self.dropout = dropout
self.head_dim = dim // n_heads
assert self.n_heads % self.n_kv_heads == 0
self.attention = Attention(
dim=dim,
head_dim=self.head_dim,
n_heads=self.n_heads,
n_kv_heads=self.n_kv_heads,
norm_eps=norm_eps,
use_qk_norm=qk_norm,
fc_bias=fc_bias,
)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=int(ffn_exp * dim),
ffn_dim_multiplier=ffn_dim_multiplier,
multiple_of=multiple_of,
dropout=dropout,
non_linearity=non_linearity,
fc_bias=fc_bias,
)
self.attention_norm, self.ffn_norm = [RMSNorm(dim, norm_eps) for _ in range(2)]
self.cross_attention = None
if not no_cross_attention:
self.cross_attention = Attention(
dim=dim,
head_dim=self.head_dim,
n_heads=self.n_heads,
n_kv_heads=self.n_heads,
norm_eps=norm_eps,
use_qk_norm=qk_norm,
fc_bias=fc_bias,
)
self.scale_shift_table = nn.Parameter(
torch.randn(6, self.dim) / self.dim**0.5,
)
def forward(
self,
x: torch.Tensor,
cross_x: Optional[torch.Tensor],
t: torch.Tensor,
padding_mask: Optional[torch.Tensor],
memory_padding_mask: Optional[torch.Tensor],
rope: Optional[RotaryEmbedding] = None,
):
biases = self.scale_shift_table[None] + t.reshape(x.size(0), 6, -1)
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
) = biases.chunk(6, dim=1)
assert self.attention is not None and self.attention_norm is not None
h_attn = self.attention(
modulate(self.attention_norm(x), shift_msa, scale_msa),
key_padding_mask=padding_mask,
rope=rope,
)
h = x + gate(h_attn, gate_msa)
if self.cross_attention is not None:
h_cross = self.cross_attention(
x=h,
cross_x=cross_x,
key_padding_mask=memory_padding_mask,
)
            h = h + h_cross  # residual
h_ff = self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp))
out = h + gate(h_ff, gate_mlp)
return out
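# Usage sketch for DiTBlock (illustrative sizes; `t` carries the six adaLN
# chunks, i.e. shape (B, 6 * dim)):
#
#   block = DiTBlock(dim=512, n_heads=8, non_linearity="swiglu")
#   out = block(torch.randn(2, 100, 512), cross_x=torch.randn(2, 20, 512),
#               t=torch.randn(2, 6 * 512), padding_mask=None,
#               memory_padding_mask=None)  # out.shape == (2, 100, 512)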
class DiT(torch.nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.dropout = config.dropout
        if config.in_channels is not None:
            # note: data_proj is constructed but not applied in forward() below
            self.data_proj = torch.nn.Linear(config.in_channels, config.dim)
# embeddings
self.rope_embeddings = None
# rotary embeddings
if config.use_rope:
self.rope_embeddings = RotaryEmbedding(
theta=max(10000, 2 * config.max_positions),
head_dim=config.dim // config.n_heads,
max_seqlen=config.max_positions,
)
self.rope_embeddings.reset_parameters()
# transformer blocks
self.layers = nn.ModuleList()
for _ in range(config.n_layers):
self.layers.append(
DiTBlock(
dim=config.dim,
n_heads=config.n_heads,
dropout=config.dropout,
norm_eps=config.norm_eps,
qk_norm=config.qk_norm,
fc_bias=config.fc_bias,
ffn_exp=config.ffn_exp,
ffn_dim_multiplier=config.ffn_dim_multiplier,
multiple_of=config.multiple_of,
non_linearity=config.non_linearity,
)
)
self.norm = RMSNorm(config.dim, config.norm_eps)
# output layer
self.output = torch.nn.Linear(
config.dim, config.out_channels, bias=config.fc_bias
)
self.x_embedder = Patcher(
in_channels=config.dim,
out_channels=config.dim,
patch_size=1,
)
self.y_embedder = ContextEmbedder(
in_dim=config.context_dim,
out_dim=config.dim,
non_linearity=config.context_non_linearity,
dropout=config.context_embedder_dropout,
fc_bias=config.fc_bias,
norm_eps=config.norm_eps,
context_norm=config.context_norm,
)
self.t_embedder = TimestepEmbedder(
config.dim,
config.frequency_embedding_dim,
non_linearity=config.timestep_non_linearity,
dropout=config.dropout,
fc_bias=config.fc_bias,
max_period=10000,
)
self.t_block_non_linearity = get_nonlinearity(config.t_block_non_linearity)
self.t_block = torch.nn.Linear(
config.dim,
config.dim * 6,
bias=config.t_block_bias,
)
self.final_layer_scale_shift_table = nn.Parameter(
torch.randn(2, config.dim) / config.dim**0.5,
)
def forward(
self,
x: torch.Tensor,
time: torch.Tensor,
*,
padding_mask: Optional[torch.Tensor] = None,
memory: Optional[torch.Tensor] = None,
memory_padding_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
x = rearrange(x, "b l c-> b c l")
h = self.x_embedder(x)
h = rearrange(h, "b c l -> b l c")
original_N = h.shape[1]
N = h.shape[1]
h = F.dropout(h, p=self.dropout, training=self.training)
t = self.t_embedder(time) # B -> B D
t0 = self.t_block_non_linearity(t)
t0 = self.t_block(t0) # B D -> B 6D
y = self.y_embedder(memory)
for layer in self.layers:
h = layer(
x=h,
cross_x=y,
t=t0,
padding_mask=padding_mask,
memory_padding_mask=memory_padding_mask,
rope=self.rope_embeddings,
)
shift, scale = (self.final_layer_scale_shift_table[None] + t[:, None]).chunk(
2, dim=1
)
# output layer
if self.norm is not None:
h = self.norm(h)
h = modulate(h, shift, scale)
h = F.dropout(h, p=self.dropout, training=self.training)
output = self.output(h)
        # trim back to the original sequence length if upstream ops changed it
        N = output.shape[1]
        if original_N != N:
            output = output[:, -original_N:]
return output
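# End-to-end usage sketch (illustrative; assumes a TransformerConfig exposing
# the fields referenced above -- exact defaults live in .config):
#
#   cfg = TransformerConfig(dim=512, n_layers=8, n_heads=8, ...)
#   model = DiT(cfg)
#   x = torch.randn(2, 100, cfg.dim)              # B, L, C latents
#   time = torch.rand(2)                          # diffusion timesteps
#   memory = torch.randn(2, 20, cfg.context_dim)  # conditioning context
#   out = model(x, time, memory=memory)           # (2, 100, cfg.out_channels)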