| | """Flow matching MLP with adaptive layer normalization. |
| | |
| | Adapted from pocket-tts, originally from: |
| | https://github.com/LTH14/mar/blob/fe470ac24afbee924668d8c5c83e9fec60af3a73/models/diffloss.py |
| | |
| | Reference: https://arxiv.org/abs/2406.11838 |
| | """ |
| |
|
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply adaptive normalization modulation: scale ``x`` then offset it."""
    scaled = x * (1 + scale)
    return scaled + shift
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    NOTE(review): this scales by ``1/sqrt(var + eps)`` using the
    mean-subtracted (unbiased) variance rather than the mean of squares,
    so it is not strict RMSNorm. Kept byte-for-byte to match the upstream
    implementation and any trained checkpoints.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learnable per-channel gain, initialized to identity.
        self.alpha = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Variance is computed over the last (channel) dimension.
        variance = x.var(dim=-1, keepdim=True) + self.eps
        # Cast the gain to the variance's dtype/device before scaling.
        gain = self.alpha.to(variance) * torch.rsqrt(variance)
        return (x * gain).to(orig_dtype)
| |
|
| |
|
class LayerNorm(nn.Module):
    """LayerNorm that supports JVP (for flow matching gradients).

    Written without ``F.layer_norm`` so forward-mode autodiff works; the
    affine parameters exist only when ``elementwise_affine`` is True.
    """

    def __init__(self, channels: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(channels))
            self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mu = x.mean(dim=-1, keepdim=True)
        # Biased variance, matching standard LayerNorm semantics.
        sigma2 = x.var(dim=-1, unbiased=False, keepdim=True)
        out = (x - mu) / torch.sqrt(sigma2 + self.eps)
        # Affine parameters are only present when elementwise_affine=True.
        if hasattr(self, "weight"):
            out = out * self.weight + self.bias
        return out
| |
|
| |
|
class TimestepEmbedder(nn.Module):
    """Embeds scalar timesteps into vector representations.

    Builds a sinusoidal frequency embedding (cos then sin halves) and
    projects it through a small MLP finished with RMSNorm.
    """

    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        max_period: int = 10000,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
            RMSNorm(hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size
        n_freqs = frequency_embedding_size // 2
        # Geometric frequency ladder from 1 down to 1/max_period.
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=n_freqs) / n_freqs)
        # Buffer (not a Parameter): moves with .to()/device but is not trained.
        self.register_buffer("freqs", freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # assumes t has a trailing singleton dim (e.g. [N, 1]) so it
        # broadcasts against freqs — TODO confirm against callers
        phases = t * self.freqs.to(t.dtype)
        sincos = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        return self.mlp(sincos)
| |
|
| |
|
class ResBlock(nn.Module):
    """Residual block with adaptive layer normalization.

    The conditioning vector ``y`` produces a shift/scale for the input
    norm plus a gate on the residual branch (DiT-style adaLN).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.in_ln = LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # One projection yields all three modulation signals.
        mod = self.adaLN_modulation(y)
        shift, scale, gate = mod.chunk(3, dim=-1)
        hidden = modulate(self.in_ln(x), shift, scale)
        # Gated residual connection.
        return x + gate * self.mlp(hidden)
| |
|
| |
|
class FinalLayer(nn.Module):
    """Final layer with adaptive normalization (DiT-style).

    Applies a non-affine LayerNorm modulated by the conditioning vector,
    then projects to the output dimension.
    """

    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # Non-affine norm: the adaLN shift/scale plays the affine role.
        self.norm_final = LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
| |
|
| |
|
class SimpleMLPAdaLN(nn.Module):
    """MLP for flow matching with adaptive layer normalization.

    Takes conditioning from an AR transformer and predicts flow velocity.

    Args:
        in_channels: Input/output latent dimension (e.g., 256 for Mimi)
        model_channels: Hidden dimension of the MLP
        out_channels: Output dimension (same as in_channels for flow matching)
        cond_channels: Conditioning dimension from LLM
        num_res_blocks: Number of residual blocks
        num_time_conds: Number of time conditions (2 for start/end time in LSD)

    Raises:
        ValueError: If ``num_time_conds`` is not 2.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        cond_channels: int,
        num_res_blocks: int,
        num_time_conds: int = 2,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.num_time_conds = num_time_conds

        # Validate explicitly: `assert` would be stripped under `python -O`,
        # silently accepting an unsupported configuration.
        if num_time_conds != 2:
            raise ValueError("LSD requires exactly 2 time conditions (start, end)")

        # One timestep embedder per time condition (start, end).
        self.time_embed = nn.ModuleList(
            [TimestepEmbedder(model_channels) for _ in range(num_time_conds)]
        )
        self.cond_embed = nn.Linear(cond_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)])
        self.final_layer = FinalLayer(model_channels, out_channels)

    def forward(
        self,
        c: torch.Tensor,
        s: torch.Tensor,
        t: torch.Tensor,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Predict flow velocity.

        Args:
            c: Conditioning from LLM, shape [N, cond_channels]
            s: Start time, shape [N, 1]
            t: Target time, shape [N, 1]
            x: Noisy latent, shape [N, in_channels]

        Returns:
            Predicted velocity, shape [N, out_channels]
        """
        x = self.input_proj(x)

        # Average the per-time embeddings (start, end) into one vector.
        t_combined = sum(
            embed(time) for embed, time in zip(self.time_embed, (s, t))
        ) / self.num_time_conds

        # Conditioning vector shared by every adaLN block.
        y = t_combined + self.cond_embed(c)

        for block in self.res_blocks:
            x = block(x, y)

        return self.final_layer(x, y)
| |
|