"""Flow matching audio head for speech-to-speech.
Generates audio from LLM hidden states via flow matching:
LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio
Supports two modes:
1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out)
2. Using pretrained pocket-tts flow_net with 32-dim normalized latents
"""
import logging
from functools import partial
from typing import Callable, Optional

import torch
import torch.nn as nn

from .modules.mlp import SimpleMLPAdaLN

logger = logging.getLogger(__name__)


def lsd_decode(
    v_t: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
"""Lagrangian Self-Distillation decoding.
Iteratively refines noise into latents using the flow velocity network.
Args:
v_t: Velocity function v(s, t, x) -> velocity
x_0: Initial noise, shape [N, latent_dim]
num_steps: Number of integration steps
Returns:
Decoded latents, shape [N, latent_dim]
"""
    current = x_0
    for i in range(num_steps):
        s = i / num_steps
        t = (i + 1) / num_steps
        s_tensor = torch.full_like(x_0[..., :1], s)
        t_tensor = torch.full_like(x_0[..., :1], t)
        flow_dir = v_t(s_tensor, t_tensor, current)
        current = current + flow_dir / num_steps
    return current
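

# A minimal sanity check for the integrator above (illustrative only; the real
# velocity field is the trained flow_net). On a straight flow-matching path the
# true velocity is the constant x_1 - x_0, so a single Euler step lands on x_1.
def _lsd_decode_example() -> None:
    x_0 = torch.randn(2, 32)
    x_1 = torch.randn(2, 32)
    out = lsd_decode(lambda s, t, x: x_1 - x_0, x_0, num_steps=1)
    assert torch.allclose(out, x_1)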


class AudioHead(nn.Module):
    """Flow matching head: LLM hidden -> Mimi latents -> audio.

    Architecture:
        - llm_proj: Linear projection from LLM hidden dim to flow conditioning
        - latent_proj_in/out: Project between Mimi 512-dim and flow 32-dim
        - flow_net: SimpleMLPAdaLN that predicts flow velocity
        - Mimi decoder for latent -> audio

    Args:
        config: ASRConfig with:
            - llm_dim: LLM hidden dimension (default: 2048)
            - lsd_decode_steps: Number of LSD integration steps (default: 1)
            - flow_temperature: Sampling temperature for noise (default: 1.0)
    """

    # Architecture dimensions
    COND_DIM = 1024  # Conditioning dimension
    LATENT_DIM = 32  # Flow latent dimension (matches Mimi's 32 codebooks)
    MIMI_DIM = 512  # Mimi encoder output dimension
    FLOW_DIM = 512  # Flow network hidden dimension
    FLOW_DEPTH = 6  # Number of residual blocks

    def __init__(self, config, llm_dim: Optional[int] = None):
        super().__init__()
        # llm_dim can be passed directly or taken from config
        self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048
        self.cond_dim = self.COND_DIM
        self.latent_dim = self.LATENT_DIM
        self.mimi_dim = self.MIMI_DIM
        self.lsd_steps = getattr(config, "lsd_decode_steps", 1)
        self.temp = getattr(config, "flow_temperature", 1.0)

        # LLM -> conditioning projection
        self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False)

        # Mimi embedding projections:
        # project 512-dim Mimi embeddings to 32-dim flow latents and back
        self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False)
        self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False)

        # Flow network
        self.flow_net = SimpleMLPAdaLN(
            in_channels=self.latent_dim,
            model_channels=self.FLOW_DIM,
            out_channels=self.latent_dim,
            cond_channels=self.cond_dim,
            num_res_blocks=self.FLOW_DEPTH,
            num_time_conds=2,
        )

        # Normalization buffers for the pretrained pocket-tts flow_net.
        # When using pretrained weights, the flow operates in normalized 32-dim space.
        self.register_buffer("emb_mean", torch.zeros(self.latent_dim))
        self.register_buffer("emb_std", torch.ones(self.latent_dim))
        self._use_pretrained_normalization = False

        # Mimi decoder components (loaded separately via load_mimi_decoder)
        self.mimi = None

    def load_mimi_decoder(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load the Mimi model for decoding latents to audio."""
        from transformers import MimiModel

        self.mimi = MimiModel.from_pretrained("kyutai/mimi")
        self.mimi.requires_grad_(False)
        self.mimi.eval()
        if device is not None:
            self.mimi = self.mimi.to(device)
        if dtype is not None:
            self.mimi = self.mimi.to(dtype)
        logger.info("Loaded Mimi decoder from kyutai/mimi")

    def load_pretrained_flow_net(
        self,
        weights_path: Optional[str] = None,
        freeze: bool = True,
    ):
        """Load pretrained pocket-tts flow_net weights.

        This enables using the pretrained flow matching network from pocket-tts,
        which operates in normalized 32-dim latent space.

        Args:
            weights_path: Path to safetensors file. If None, downloads from HuggingFace.
            freeze: Whether to freeze flow_net weights (default: True, only train llm_proj)
        """
        import safetensors.torch

        if weights_path is None:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
            )
        state = safetensors.torch.load_file(weights_path)

        # Extract flow_net weights
        flow_state = {}
        for k, v in state.items():
            if k.startswith("flow_lm.flow_net."):
                new_key = k.replace("flow_lm.flow_net.", "")
                flow_state[new_key] = v
        self.flow_net.load_state_dict(flow_state)
        logger.info(f"Loaded pretrained flow_net from {weights_path}")

        # Load normalization buffers
        if "flow_lm.emb_mean" in state:
            self.emb_mean.copy_(state["flow_lm.emb_mean"])
        if "flow_lm.emb_std" in state:
            self.emb_std.copy_(state["flow_lm.emb_std"])
        # Enable denormalization in _generate
        self._use_pretrained_normalization = True
        logger.info("Loaded emb_mean and emb_std for normalization")

        if freeze:
            self.flow_net.requires_grad_(False)
            logger.info("Froze flow_net weights (only llm_proj will train)")

    def forward(
        self,
        hidden_states: torch.Tensor,
        latent_targets: Optional[torch.Tensor] = None,
        latent_lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for training or inference.

        Args:
            hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
            latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
            latent_lengths: Actual lengths per sample, shape [batch]

        Returns:
            Training: scalar flow matching loss
            Inference: generated Mimi latents, shape [batch, seq_len, 512]
        """
        # Project LLM hidden states to conditioning
        cond = self.llm_proj(hidden_states)
        if latent_targets is not None:
            return self._compute_loss(cond, latent_targets, latent_lengths)
        return self._generate(cond)

    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with a reconstruction term.

        The loss has two components:

        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking
        """
        # Debug: check inputs for NaN/Inf
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # Project 512-dim Mimi embeddings to 32-dim flow latents
        targets_proj = self.latent_proj_in(targets)
        # Compute reconstruction to train latent_proj_out; this ensures the
        # projection pair learns a good inverse mapping
        targets_reconstructed = self.latent_proj_out(targets_proj)

        # Interpolate targets to match the conditioning sequence length
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()
            # Also interpolate original targets for the reconstruction loss
            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()
            # Interpolate reconstructed targets to match
            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()
            if lengths is not None:
                scale = cond_seq_len / target_seq_len
                lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        x_1 = targets_proj

        # Random timesteps for each sample/position (match input dtype)
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)
        # Sample noise
        x_0 = torch.randn_like(x_1)
        # Linear interpolation: x_t = (1 - t) * x_0 + t * x_1
        x_t = (1 - t) * x_0 + t * x_1
        # Target velocity: dx/dt = x_1 - x_0
        v_target = x_1 - x_0

        # Flatten for flow_net: [batch * seq_len, dim]
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # Predict velocity
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        # Compute masked losses
        if lengths is not None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)
            # Check if the mask is all False (no valid positions)
            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)
            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Combined loss (reconstruction weighted at 0.1 so it does not dominate)
        return flow_loss + 0.1 * recon_loss

    def _generate(self, cond: torch.Tensor) -> torch.Tensor:
        """Generate Mimi embeddings via LSD decoding.

        Args:
            cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]

        Returns:
            Generated Mimi embeddings, shape [batch, seq_len, 512]
        """
        batch, seq_len, _ = cond.shape
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences
        if seq_len == 0:
            return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)

        # Clamp temperature to non-negative to avoid complex numbers from sqrt
        temp = max(0.0, self.temp)

        latents = []
        for step in range(seq_len):
            cond_step = cond[:, step]
            # Sample initial noise in 32-dim flow space
            noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
            noise = noise * (temp**0.5)
            # Bind the conditioning so lsd_decode sees v(s, t, x)
            conditioned_flow = partial(self.flow_net, cond_step)
            latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
            latents.append(latent)
        latents = torch.stack(latents, dim=1)

        # Denormalize if using the pretrained pocket-tts normalization
        if self._use_pretrained_normalization:
            latents = latents * self.emb_std + self.emb_mean

        # Project back to 512-dim Mimi embedding space
        return self.latent_proj_out(latents)

    def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode Mimi latents to an audio waveform.

        Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
        embeddings. We bypass the quantizer and call upsample → decoder_transformer
        → decoder directly to decode from continuous latents.

        Args:
            latents: Mimi latents, shape [batch, seq_len, 512]

        Returns:
            Audio waveform, shape [batch, samples]
        """
        if self.mimi is None:
            raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")

        # [batch, seq, 512] → [batch, 512, seq]
        latents = latents.transpose(1, 2)
        with torch.no_grad():
            # Upsample latents (2x temporal upsampling)
            emb = self.mimi.upsample(latents)
            # Decoder transformer expects [batch, seq, dim]
            emb = emb.transpose(1, 2)
            decoder_out = self.mimi.decoder_transformer(emb)
            emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])
            # Final decoder expects [batch, dim, seq]
            emb = emb.transpose(1, 2)
            audio = self.mimi.decoder(emb)
        return audio.squeeze(1)

    def get_output_length(self, input_length: int) -> int:
        """Estimate output audio samples from input hidden state length.

        For Mimi at a 12.5 Hz frame rate with 24 kHz audio, each latent frame
        corresponds to 24000 / 12.5 = 1920 audio samples.
        """
        return input_length * 1920
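

# Minimal smoke-test sketch (illustrative; assumes this file is run as part of
# its package so the relative import of SimpleMLPAdaLN resolves, and that a
# simple namespace object is an acceptable stand-in for the real ASRConfig).
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(lsd_decode_steps=4, flow_temperature=1.0)
    head = AudioHead(cfg, llm_dim=2048)

    hidden = torch.randn(2, 10, 2048)            # fake LLM hidden states
    targets = torch.randn(2, 20, head.mimi_dim)  # fake 512-dim Mimi embeddings
    lengths = torch.tensor([20, 15])

    loss = head(hidden, targets, lengths)        # training path
    print("flow matching loss:", loss.item())

    with torch.no_grad():
        latents = head(hidden)                   # inference path
    print("generated latents:", tuple(latents.shape))  # (2, 10, 512)

    # To synthesize audio (downloads kyutai/mimi):
    #     head.load_mimi_decoder()
    #     audio = head.decode_to_audio(latents)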