# LocalSong / model.py
import math
from functools import lru_cache
from typing import Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.functional import scaled_dot_product_attention

def modulate(x, shift, scale):
    # adaLN-style modulation: scale around identity, then shift.
    return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class PatchEmbed(nn.Module):
def __init__(
self,
in_channels=8,
embed_dim=1152,
bias=True,
        patch_size=(1, 1),
):
super().__init__()
self.patch_h, self.patch_w = patch_size
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * self.patch_h * self.patch_w, embed_dim, bias=bias)
self.in_channels = in_channels
self.embed_dim = embed_dim
def forward(self, latent):
x = rearrange(latent, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=self.patch_h, p2=self.patch_w)
x = self.proj(x)
return x
class FinalLayer(nn.Module):
"""Final layer with configurable patch_size support"""
    def __init__(self, hidden_size, out_channels=8, patch_size=(1, 1)):
super().__init__()
self.patch_h, self.patch_w = patch_size
self.linear = nn.Linear(hidden_size, out_channels * self.patch_h * self.patch_w, bias=True)
self.out_channels = out_channels
self.patch_size = patch_size
def forward(self, x, target_height, target_width):
x = self.linear(x)
x = rearrange(x, 'b (h w) (c p1 p2) -> b c (h p1) (w p2)',
h=target_height, w=target_width,
p1=self.patch_h, p2=self.patch_w, c=self.out_channels)
return x
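
# Hedged sketch: a shape round-trip through PatchEmbed and FinalLayer with the
# (16, 1) patching used by LocalSongModel below. `_demo_patchify` is a
# hypothetical helper added for illustration and is never called at import.
def _demo_patchify():
    embed = PatchEmbed(in_channels=8, embed_dim=1152, patch_size=(16, 1))
    final = FinalLayer(1152, out_channels=8, patch_size=(16, 1))
    latent = torch.randn(2, 8, 256, 16)              # (B, C, H, W)
    tokens = embed(latent)                           # one token per patch
    assert tokens.shape == (2, (256 // 16) * (16 // 1), 1152)
    out = final(tokens, 256 // 16, 16 // 1)          # unpatchify
    assert out.shape == latent.shape
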
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        # Sinusoidal timestep embedding; note the unusually small max_period
        # (10 rather than the customary 10000).
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
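
# Hedged sketch of the sinusoidal timestep path; `_demo_timestep` is a
# hypothetical helper, never called at import. The frequency table is built
# in float32 on t's device, then mapped through the two-layer MLP.
def _demo_timestep():
    te = TimestepEmbedder(hidden_size=64)
    t = torch.rand(2)                                  # (B,) timesteps
    freq = te.timestep_embedding(t, te.frequency_embedding_size)
    assert freq.shape == (2, 256)                      # [cos | sin] halves
    assert te(t).shape == (2, 64)
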
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
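
# Hedged sketch: with unit weights, RMSNorm rescales each feature vector to
# (approximately) unit RMS; the statistic is computed in float32 and cast
# back to the input dtype. `_demo_rmsnorm` is a hypothetical helper.
def _demo_rmsnorm():
    norm = RMSNorm(16)
    x = 3.0 * torch.randn(4, 16)
    rms = norm(x).pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)
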
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
        # SwiGLU convention: keep 2/3 of the requested width so the
        # three-matrix MLP roughly matches a two-matrix MLP in parameters.
        hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
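
# Hedged sketch of the SwiGLU width bookkeeping; `_demo_ffn` is a
# hypothetical helper, never called at import.
def _demo_ffn():
    ff = FeedForward(dim=64, hidden_dim=4 * 64)
    assert ff.w1.out_features == int(2 * 4 * 64 / 3)   # 170, not 256
    x = torch.randn(2, 32, 64)
    assert ff(x).shape == x.shape
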
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=1.0):
    # Complex rotary table for a height x width grid: alternating channel
    # pairs carry x- and y-coordinate rotations. Returns (H*W, dim // 2).
    if isinstance(scale, float):
        scale = (scale, scale)
x_pos = torch.linspace(0, width * scale[0], width)
y_pos = torch.linspace(0, height * scale[1], height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
x_freqs = torch.outer(x_pos, freqs).float()
y_freqs = torch.outer(y_pos, freqs).float()
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)
freqs_cis = freqs_cis.reshape(height * width, -1)
return freqs_cis
@torch.compiler.disable
def apply_rotary_emb_2d(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, None, :, :]
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
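
# Hedged sketch of the 2-D rotary pipeline: channel pairs alternate between
# x- and y-coordinate rotations, and rotation by unit-modulus complex numbers
# preserves norms. `_demo_rope` is a hypothetical helper, never called at
# import.
def _demo_rope():
    head_dim, height, width = 16, 4, 8
    freqs = precompute_freqs_cis_2d(head_dim, height, width)
    assert freqs.shape == (height * width, head_dim // 2)  # complex table
    q = torch.randn(2, 4, height * width, head_dim)    # (B, heads, N, D)
    k = torch.randn_like(q)
    q_rot, k_rot = apply_rotary_emb_2d(q, k, freqs)
    assert q_rot.shape == q.shape
    assert torch.allclose(q_rot.norm(), q.norm(), atol=1e-3)
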
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = self.q_norm(q.contiguous())
k = self.k_norm(k.contiguous())
q, k = apply_rotary_emb_2d(q, k, freqs_cis=pos)
q = q.view(B, self.num_heads, -1, C // self.num_heads)
k = k.view(B, self.num_heads, -1, C // self.num_heads).contiguous()
v = v.view(B, self.num_heads, -1, C // self.num_heads).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_drop.p if self.training else 0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
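
# Hedged smoke test for rotary self-attention: queries/keys are RMS-normed
# per head, rotated by the 2-D table, then fed to SDPA. `_demo_self_attn`
# is a hypothetical helper, never called at import.
def _demo_self_attn():
    attn = RAttention(dim=64, num_heads=4)             # head_dim = 16
    pos = precompute_freqs_cis_2d(16, height=4, width=8)
    x = torch.randn(2, 4 * 8, 64)                      # (B, N, C)
    assert attn(x, pos, mask=None).shape == x.shape
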
class CrossAttention(nn.Module):
def __init__(
self,
dim: int,
context_dim: int,
num_heads: int,
qkv_bias: bool = False,
proj_drop: float = 0.0,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.kv_proj = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, context: torch.Tensor, context_mask: torch.Tensor = None) -> torch.Tensor:
B, N, C = x.shape
B_ctx, M, C_ctx = context.shape
q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
kv = self.kv_proj(context).reshape(B_ctx, M, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn_mask = None
if context_mask is not None:
attn_mask = torch.zeros(B, 1, 1, M, dtype=q.dtype, device=q.device)
attn_mask.masked_fill_(~context_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
attn = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.proj_drop.p if self.training else 0.0)
x = attn.permute(0, 2, 1, 3).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
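
# Hedged smoke test for the conditioning path: a boolean (B, M) context mask
# becomes a (B, 1, 1, M) additive -inf mask for SDPA. `_demo_cross_attn` is
# a hypothetical helper, never called at import.
def _demo_cross_attn():
    xattn = CrossAttention(dim=64, context_dim=48, num_heads=4)
    x = torch.randn(2, 32, 64)
    context = torch.randn(2, 5, 48)
    context_mask = torch.tensor([[True] * 5,
                                 [True, True, True, False, False]])
    assert xattn(x, context, context_mask).shape == x.shape
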
class DDTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, context_dim=None, is_encoder_block=False):
super().__init__()
self.hidden_size = hidden_size
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm_cross = RMSNorm(hidden_size, eps=1e-6) if context_dim else nn.Identity()
self.cross_attn = CrossAttention(hidden_size, context_dim, groups) if context_dim else None
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.is_encoder_block = is_encoder_block
if not is_encoder_block:
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None, context=None, context_mask=None, shared_adaLN=None):
if self.is_encoder_block:
adaLN_output = shared_adaLN(c)
else:
adaLN_output = self.adaLN_modulation(c)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN_output.chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
if self.cross_attn is not None and context is not None:
x = x + self.cross_attn(self.norm_cross(x), context=context, context_mask=context_mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
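
# Hedged smoke test for one block: adaLN modulation from the timestep
# conditioning, rotary self-attention, optional cross-attention, SwiGLU MLP.
# `_demo_block` is a hypothetical helper; with is_encoder_block=False the
# block owns its adaLN projection, so no shared_adaLN is needed.
def _demo_block():
    block = DDTBlock(hidden_size=64, groups=4, context_dim=48)
    pos = precompute_freqs_cis_2d(64 // 4, height=4, width=8)
    x = torch.randn(2, 32, 64)
    c = torch.randn(2, 1, 64)                          # timestep conditioning
    context = torch.randn(2, 5, 48)
    assert block(x, c, pos, context=context).shape == x.shape
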
class LocalSongModel(nn.Module):
def __init__(
self,
in_channels=8,
num_groups=16,
hidden_size=1024,
decoder_hidden_size=2048,
num_blocks=36,
patch_size=(16,1),
num_classes=2304,
max_tags=8,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.decoder_hidden_size = decoder_hidden_size
        self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.num_classes = num_classes
self.max_tags = max_tags
self.patch_h, self.patch_w = patch_size
self.x_embedder = PatchEmbed(
in_channels=in_channels,
embed_dim=decoder_hidden_size,
bias=True,
patch_size=patch_size
)
self.s_embedder = PatchEmbed(
in_channels=in_channels,
embed_dim=decoder_hidden_size,
bias=True,
patch_size=patch_size
)
        # Note: encoder_to_decoder is registered but not used in forward_emb below.
        self.encoder_to_decoder = nn.Linear(hidden_size, decoder_hidden_size, bias=False)
self.a_to_b_proj = nn.Linear(decoder_hidden_size, hidden_size, bias=False)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = nn.Embedding(num_classes + 1, hidden_size, padding_idx=0)
self.final_layer = FinalLayer(
decoder_hidden_size,
out_channels=in_channels,
patch_size=patch_size
)
self.shared_encoder_adaLN = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.shared_decoder_adaLN = nn.Sequential(
nn.Linear(hidden_size, 6 * decoder_hidden_size, bias=True)
)
        self.blocks = nn.ModuleList()
        for i in range(self.num_blocks):
            # Every block is built as an "encoder" block, i.e. it uses the
            # shared adaLN projection passed in at forward time. The first
            # block and the last three run at decoder width; the blocks in
            # between run at the (narrower) encoder width.
            if i < 1 or i >= self.num_blocks - 3:
                block_hidden_size = decoder_hidden_size
            else:
                block_hidden_size = hidden_size
            # Even-indexed blocks get a cross-attention branch over the tag
            # embeddings.
            context_dim = hidden_size if i % 2 == 0 else None
            self.blocks.append(
                DDTBlock(
                    block_hidden_size,
                    self.num_groups,
                    context_dim=context_dim,
                    is_encoder_block=True,
                )
            )
self.bc_projection = nn.Linear(decoder_hidden_size + hidden_size, decoder_hidden_size, bias=False)
self.initialize_weights()
self.precompute_encoder_pos = dict()
self.precompute_decoder_pos = dict()
    # lru_cache memoizes per (height, width, device) on top of the dict cache.
    @lru_cache
    def fetch_encoder_pos(self, height, width, device):
key = (height, width)
if key in self.precompute_encoder_pos:
return self.precompute_encoder_pos[key].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_encoder_pos[key] = pos
return pos
@lru_cache
def fetch_decoder_pos(self, height, width, device):
key = (height, width)
if key in self.precompute_decoder_pos:
return self.precompute_decoder_pos[key].to(device)
else:
pos = precompute_freqs_cis_2d(self.decoder_hidden_size // self.num_groups, height, width).to(device)
self.precompute_decoder_pos[key] = pos
return pos
def initialize_weights(self):
for embedder in [self.x_embedder, self.s_embedder]:
nn.init.xavier_uniform_(embedder.proj.weight)
if embedder.proj.bias is not None:
nn.init.constant_(embedder.proj.bias, 0)
nn.init.xavier_uniform_(self.encoder_to_decoder.weight)
nn.init.xavier_uniform_(self.a_to_b_proj.weight)
nn.init.normal_(self.y_embedder.weight, std=0.02)
with torch.no_grad():
self.y_embedder.weight[0].fill_(0)
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.constant_(self.shared_encoder_adaLN[-1].weight, 0)
nn.init.constant_(self.shared_encoder_adaLN[-1].bias, 0)
nn.init.constant_(self.shared_decoder_adaLN[-1].weight, 0)
nn.init.constant_(self.shared_decoder_adaLN[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
nn.init.xavier_uniform_(self.bc_projection.weight)
    def embed_condition(self, cond):
        # cond is a list of per-sample tag-id lists; id 0 is the padding index.
        device = self.y_embedder.weight.device
        max_len = self.max_tags
        batch_size = len(cond)
        padded_tags = torch.zeros(batch_size, max_len, dtype=torch.long, device=device)
for i, tags in enumerate(cond):
truncated_tags = tags[:max_len]
padded_tags[i, :len(truncated_tags)] = torch.tensor(truncated_tags, dtype=torch.long, device=device)
padding_mask = (padded_tags != 0)
embedded = self.y_embedder(padded_tags)
return embedded, padding_mask
def forward(self, x, t, y):
y_emb, padding_mask = self.embed_condition(y)
return self.forward_emb(x, t, y_emb, padding_mask)
@torch.compile()
def forward_emb(self, x, t, y_emb, padding_mask=None):
B, _, H, W = x.shape
h_patches = H // self.patch_h
w_patches = W // self.patch_w
encoder_pos = self.fetch_encoder_pos(h_patches, w_patches, x.device)
decoder_pos = self.fetch_decoder_pos(h_patches, w_patches, x.device)
t_emb = self.t_embedder(t.view(-1)).view(B, 1, self.hidden_size)
t_cond = nn.functional.silu(t_emb)
        s = self.s_embedder(x)
        # Section A: the first block, at decoder width.
        s_section_a = s
        for i in range(min(1, self.num_blocks)):
            block_context = y_emb if i % 2 == 0 else None
            s_section_a = self.blocks[i](s_section_a, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)
        # Section B: the middle blocks, projected down to encoder width.
        s_section_b = self.a_to_b_proj(s_section_a)
        for i in range(1, self.num_blocks - 3):
            block_context = y_emb if i % 2 == 0 else None
            s_section_b = self.blocks[i](s_section_b, t_cond, encoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_encoder_adaLN)
        # Section C: concatenate A and B, project back to decoder width, and
        # run the last three blocks.
        s = self.bc_projection(torch.cat([s_section_a, s_section_b], dim=-1))
        for i in range(max(1, self.num_blocks - 3), self.num_blocks):
            block_context = y_emb if i % 2 == 0 else None
            s = self.blocks[i](s, t_cond, decoder_pos, None, context=block_context, context_mask=padding_mask, shared_adaLN=self.shared_decoder_adaLN)
        s = self.final_layer(s, h_patches, w_patches)
        return s
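

if __name__ == "__main__":
    # Hedged end-to-end sketch with a deliberately tiny config; the released
    # checkpoint uses the defaults above. The first call triggers
    # torch.compile on forward_emb, so this needs a working compile backend
    # (or TORCH_COMPILE_DISABLE=1 in the environment).
    model = LocalSongModel(
        in_channels=8, num_groups=4, hidden_size=64,
        decoder_hidden_size=128, num_blocks=6,
        patch_size=(4, 1), num_classes=100, max_tags=4,
    )
    x = torch.randn(2, 8, 16, 8)       # (B, C, H, W), H divisible by 4
    t = torch.rand(2)                  # diffusion timesteps
    y = [[3, 17], [5]]                 # per-sample tag id lists (0 = pad)
    out = model(x, t, y)
    assert out.shape == x.shape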