File size: 6,644 Bytes
"""Flow matching MLP with adaptive layer normalization.

Adapted from pocket-tts, originally from:
https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py
Reference: https://arxiv.org/abs/2406.11838
"""
import math
import torch
import torch.nn as nn
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Scale-and-shift ``x`` with adaptive-normalization parameters.

    Computes ``x * (1 + scale) + shift``. ``scale`` is an offset from an
    identity gain, so ``scale == 0`` and ``shift == 0`` return ``x`` unchanged.
    The three tensors combine under normal broadcasting rules.
    """
    gain = 1 + scale
    return x * gain + shift
class RMSNorm(nn.Module):
    """Learned-gain normalization over the last dimension.

    Despite the name, the denominator is the *unbiased variance* of ``x``
    along the last dimension (``Tensor.var`` with its default Bessel
    correction, i.e. mean-centred), not the plain mean of squares used by
    textbook RMSNorm. Changing this would alter the output of any weights
    trained with this definition.

    Args:
        dim: Size of the last dimension; length of the learned gain ``alpha``.
        eps: Added to the variance before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        denom = x.var(dim=-1, keepdim=True) + self.eps
        # Move the gain onto the dtype/device of the statistics before scaling.
        gain = self.alpha.to(denom) * denom.rsqrt()
        # The result is cast back to the caller's dtype.
        return (x * gain).to(input_dtype)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, written with plain tensor ops.

    Written by hand so that it supports JVP (for flow matching gradients).
    Normalizes with the biased (population) variance, then optionally applies
    a learned per-channel ``weight`` and ``bias``.

    Args:
        channels: Size of the last dimension.
        eps: Added to the variance before the square root.
        elementwise_affine: If True, create ``weight`` (ones) and ``bias``
            (zeros) parameters. If False, the module has no parameters and no
            ``weight``/``bias`` attributes.
    """

    def __init__(self, channels: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(dim=-1, keepdim=True)
        std = torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + self.eps)
        normed = centered / std
        # The affine parameters exist only when elementwise_affine was True.
        if not hasattr(self, "weight"):
            return normed
        return normed * self.weight + self.bias
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps with sinusoidal features followed by an MLP.

    The timestep is multiplied by ``frequency_embedding_size // 2``
    geometrically spaced frequencies, ``max_period ** (-k / half)`` for
    ``k = 0 .. half - 1``. The cosines and sines of those products are
    concatenated and passed through Linear -> SiLU -> Linear -> RMSNorm.

    Args:
        hidden_size: Size of the output embedding.
        frequency_embedding_size: Width of the sinusoidal feature vector fed
            to the MLP. Odd widths are supported by zero-padding one column.
        max_period: Sets the lowest frequency (just above ``1 / max_period``).
    """

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        max_period: int = 10000,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
            RMSNorm(hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size
        half = frequency_embedding_size // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half) / half)
        self.register_buffer("freqs", freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed timesteps ``t``.

        Args:
            t: Timesteps. Expected to carry a trailing singleton dimension
                (e.g. ``[N, 1]``) so they broadcast against the ``[half]``
                frequency vector.

        Returns:
            Embedding with last dimension ``hidden_size``.
        """
        if not t.is_floating_point():
            # Casting the frequencies to an integer dtype would truncate all
            # but the first (1.0) to 0, so promote t instead.
            t = t.to(self.freqs.dtype)
        args = t * self.freqs.to(t.dtype)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            # cos/sin give only 2 * (size // 2) features. Pad one zero column
            # so the width matches the first Linear layer's input size.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
        return self.mlp(embedding)
class ResBlock(nn.Module):
    """Pre-norm residual MLP block modulated by a conditioning vector ``y``.

    ``y`` is mapped (SiLU -> Linear) to three ``channels``-wide chunks:
    - a shift and a scale, applied to the normalized input;
    - a gate, which multiplies the MLP output before it is added back to ``x``.

    Args:
        channels: Width of both ``x`` and the conditioning ``y``.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        # Submodules are created in the same order as before. This keeps
        # random initialisation and state_dict key order unchanged.
        self.in_ln = LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        params = self.adaLN_modulation(y)
        shift, scale, gate = params.chunk(3, dim=-1)
        normed = modulate(self.in_ln(x), shift, scale)
        update = self.mlp(normed)
        return x + gate * update
class FinalLayer(nn.Module):
    """Output head: adaLN-modulated normalization, then a linear projection (DiT-style).

    Args:
        model_channels: Width of the hidden features and of the conditioning.
        out_channels: Width of the projected output.
    """

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # Creation order preserved (affects parameter init and state_dict order).
        self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift_scale = self.adaLN_modulation(c)
        shift, scale = shift_scale.chunk(2, dim=-1)
        h = self.norm_final(x)
        h = modulate(h, shift, scale)
        return self.linear(h)
class SimpleMLPAdaLN(nn.Module):
    """MLP for flow matching with adaptive layer normalization.

    Takes conditioning from an AR transformer and predicts flow velocity.

    Args:
        in_channels: Input latent dimension (e.g., 256 for Mimi)
        model_channels: Hidden dimension of the MLP
        out_channels: Output dimension (same as in_channels for flow matching)
        cond_channels: Conditioning dimension from LLM
        num_res_blocks: Number of residual blocks
        num_time_conds: Number of time conditions; must be 2 (start/end time in LSD)

    Raises:
        ValueError: If ``num_time_conds`` is not 2.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        cond_channels: int,
        num_res_blocks: int,
        num_time_conds: int = 2,
    ):
        super().__init__()
        # Use an explicit raise rather than ``assert``: asserts are stripped
        # under ``python -O``. forward() only ever supplies two times (s, t),
        # so any other count would misbehave there.
        if num_time_conds != 2:
            raise ValueError(
                "LSD requires exactly 2 time conditions (start, end), "
                f"got {num_time_conds}"
            )
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.num_time_conds = num_time_conds
        # Submodule creation order is unchanged (parameter init / state_dict keys).
        self.time_embed = nn.ModuleList(
            [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
        )
        self.cond_embed = nn.Linear(cond_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)

    def forward(
        self,
        c: torch.Tensor,
        s: torch.Tensor,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Predict flow velocity.

        Args:
            c: Conditioning from LLM, shape [N, cond_channels]
            s: Start time, shape [N, 1]
            t: Target time, shape [N, 1]
            x: Noisy latent, shape [N, in_channels]

        Returns:
            Predicted velocity, shape [N, out_channels]
        """
        x = self.input_proj(x)
        # Average the start- and target-time embeddings, one embedder per time.
        t_combined = sum(
            embed(timestep) for embed, timestep in zip(self.time_embed, (s, t))
        )
        t_combined = t_combined / self.num_time_conds
        # The shared conditioning vector is time embedding + projected LLM conditioning.
        y = t_combined + self.cond_embed(c)
        for block in self.res_blocks:
            x = block(x, y)
        return self.final_layer(x, y)
|