| | """Flow matching audio head for speech-to-speech. |
| | |
| | Generates audio from LLM hidden states via flow matching: |
| | LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio |
| | |
| | Supports two modes: |
| | 1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out) |
| | 2. Using pretrained pocket-tts flow_net with 32-dim normalized latents |
| | """ |
| |
|
| | import logging |
| | from functools import partial |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .modules.mlp import SimpleMLPAdaLN |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def lsd_decode(
    v_t,
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
    """Lagrangian Self-Distillation decoding.

    Iteratively refines noise into latents using the flow velocity network.

    Args:
        v_t: Velocity function v(s, t, x) -> velocity
        x_0: Initial noise, shape [N, latent_dim]
        num_steps: Number of integration steps

    Returns:
        Decoded latents, shape [N, latent_dim]
    """
    current = x_0
    for i in range(num_steps):
        # Integrate over the interval [s, t] = [i/num_steps, (i+1)/num_steps]
        s = i / num_steps
        t = (i + 1) / num_steps
        s_tensor = torch.full_like(x_0[..., :1], s)
        t_tensor = torch.full_like(x_0[..., :1], t)
        flow_dir = v_t(s_tensor, t_tensor, current)
        # Euler step of size 1/num_steps along the predicted velocity
        current = current + flow_dir / num_steps
    return current
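
# Illustrative sketch (not part of the module API): with a constant velocity
# field, lsd_decode is plain Euler integration of dx/dt = v from t=0 to t=1,
# so the result is x_0 + v regardless of num_steps. The lambda and shapes
# below are hypothetical stand-ins.
#
#     v = lambda s, t, x: torch.ones_like(x)   # toy constant velocity field
#     x_0 = torch.zeros(4, 32)                 # [N, latent_dim] noise stand-in
#     out = lsd_decode(v, x_0, num_steps=4)    # 4 steps of +1/4 -> all ones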


class AudioHead(nn.Module):
    """Flow matching head: LLM hidden -> Mimi latents -> audio.

    Architecture:
        - llm_proj: linear projection from the LLM hidden dim to flow conditioning
        - latent_proj_in/out: project between Mimi's 512-dim and the flow's 32-dim space
        - flow_net: SimpleMLPAdaLN that predicts the flow velocity
        - Mimi decoder for latent -> audio

    Args:
        config: ASRConfig with:
            - llm_dim: LLM hidden dimension (default: 2048)
            - lsd_decode_steps: number of LSD integration steps (default: 1)
            - flow_temperature: sampling temperature for noise (default: 1.0)
    """

    COND_DIM = 1024   # flow conditioning width
    LATENT_DIM = 32   # flow-space latent dimension
    MIMI_DIM = 512    # Mimi embedding dimension
    FLOW_DIM = 512    # flow_net hidden width
    FLOW_DEPTH = 6    # flow_net residual blocks
|
| | def __init__(self, config, llm_dim: int = None): |
| | super().__init__() |
| | |
| | self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048 |
| | self.cond_dim = self.COND_DIM |
| | self.latent_dim = self.LATENT_DIM |
| | self.mimi_dim = self.MIMI_DIM |
| | self.lsd_steps = getattr(config, "lsd_decode_steps", 1) |
| | self.temp = getattr(config, "flow_temperature", 1.0) |
| |
|
| | |
| | self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False) |
| |
|
| | |
| | |
| | self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False) |
| | self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False) |
| |
|
| | |
| | self.flow_net = SimpleMLPAdaLN( |
| | in_channels=self.latent_dim, |
| | model_channels=self.FLOW_DIM, |
| | out_channels=self.latent_dim, |
| | cond_channels=self.cond_dim, |
| | num_res_blocks=self.FLOW_DEPTH, |
| | num_time_conds=2, |
| | ) |
| |
|
| | |
| | |
| | self.register_buffer("emb_mean", torch.zeros(self.latent_dim)) |
| | self.register_buffer("emb_std", torch.ones(self.latent_dim)) |
| | self._use_pretrained_normalization = False |
| |
|
| | |
| | self.mimi = None |
| |
|

    def load_mimi_decoder(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Load the frozen Mimi model for decoding latents to audio."""
        from transformers import MimiModel

        self.mimi = MimiModel.from_pretrained("kyutai/mimi")
        self.mimi.requires_grad_(False)
        self.mimi.eval()

        if device is not None:
            self.mimi = self.mimi.to(device)
        if dtype is not None:
            self.mimi = self.mimi.to(dtype)

        logger.info("Loaded Mimi decoder from kyutai/mimi")

    def load_pretrained_flow_net(
        self,
        weights_path: Optional[str] = None,
        freeze: bool = True,
    ):
        """Load pretrained pocket-tts flow_net weights.

        This enables using the pretrained flow matching network from pocket-tts,
        which operates in a normalized 32-dim latent space.

        Args:
            weights_path: Path to a safetensors file. If None, downloads from
                HuggingFace.
            freeze: Whether to freeze flow_net weights (default: True; only
                llm_proj is trained)
        """
        import safetensors.torch

        if weights_path is None:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
            )

        state = safetensors.torch.load_file(weights_path)

        # Keep only flow_net weights, stripping the "flow_lm.flow_net." prefix
        flow_state = {}
        for k, v in state.items():
            if k.startswith("flow_lm.flow_net."):
                new_key = k.replace("flow_lm.flow_net.", "")
                flow_state[new_key] = v

        self.flow_net.load_state_dict(flow_state)
        logger.info(f"Loaded pretrained flow_net from {weights_path}")

        # Load the latent normalization statistics used by the pretrained net
        if "flow_lm.emb_mean" in state:
            self.emb_mean.copy_(state["flow_lm.emb_mean"])
        if "flow_lm.emb_std" in state:
            self.emb_std.copy_(state["flow_lm.emb_std"])

        self._use_pretrained_normalization = True
        logger.info("Loaded emb_mean and emb_std for normalization")

        if freeze:
            self.flow_net.requires_grad_(False)
            logger.info("Froze flow_net weights (only llm_proj will train)")

    def forward(
        self,
        hidden_states: torch.Tensor,
        latent_targets: Optional[torch.Tensor] = None,
        latent_lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for training or inference.

        Args:
            hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
            latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
            latent_lengths: Actual lengths per sample, shape [batch]

        Returns:
            Training: scalar flow matching loss
            Inference: generated Mimi latents, shape [batch, seq_len, 512]
        """
        # Project LLM hidden states to flow conditioning
        cond = self.llm_proj(hidden_states)

        if latent_targets is not None:
            return self._compute_loss(cond, latent_targets, latent_lengths)
        return self._generate(cond)

    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with a reconstruction term.

        The loss has two components:
        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking
        """
        # Surface numerical problems from upstream early
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Degenerate batch: nothing to train on
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # Project Mimi 512-dim embeddings into the 32-dim flow space
        targets_proj = self.latent_proj_in(targets)

        # Round-trip through latent_proj_out so the reconstruction loss
        # below trains the output projection
        targets_reconstructed = self.latent_proj_out(targets_proj)

        # Resample targets along time to match the conditioning length
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()

            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()

            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()

        if lengths is not None:
            # Rescale lengths to the conditioning time base
            scale = cond_seq_len / target_seq_len
            lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        x_1 = targets_proj

        # Sample a random flow time per position
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)

        # Noise endpoint of the flow
        x_0 = torch.randn_like(x_1)

        # Straight-line interpolation between noise and data
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity of the straight path is constant in t
        v_target = x_1 - x_0

        # Flatten positions into the batch dimension for the MLP
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # flow_net takes (cond, s, t, x); both time inputs are t here
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        # Mask out padded positions when lengths are provided
        if lengths is not None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)

            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)

            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Small weight keeps the reconstruction term from dominating
        return flow_loss + 0.1 * recon_loss
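
    # Worked mini-example (illustrative): for a scalar latent with x_0 = 0 and
    # x_1 = 1, the path is x_t = (1 - t) * 0 + t * 1 = t and the target
    # velocity is x_1 - x_0 = 1 for every t, so a perfect flow_net outputs a
    # constant unit velocity and drives the flow loss to zero.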

    def _generate(self, cond: torch.Tensor) -> torch.Tensor:
        """Generate Mimi embeddings via LSD decoding.

        Args:
            cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]

        Returns:
            Generated Mimi embeddings, shape [batch, seq_len, 512]
        """
        batch, seq_len, _ = cond.shape
        device = cond.device
        dtype = cond.dtype

        # Nothing to generate
        if seq_len == 0:
            return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)

        # Clamp temperature to be non-negative
        temp = max(0.0, self.temp)

        latents = []
        for t in range(seq_len):
            cond_t = cond[:, t]

            # Temperature scales the noise variance, hence the sqrt on std
            noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
            noise = noise * (temp**0.5)

            def velocity_fn(cond_fixed, s, t, x):
                return self.flow_net(cond_fixed, s, t, x)

            # Bind the per-step conditioning so lsd_decode sees v(s, t, x)
            conditioned_flow = partial(velocity_fn, cond_t)
            latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
            latents.append(latent)

        latents = torch.stack(latents, dim=1)

        # Undo latent whitening when using pretrained pocket-tts statistics
        if self._use_pretrained_normalization:
            latents = latents * self.emb_std + self.emb_mean

        # Project from the 32-dim flow space back to Mimi's 512-dim embeddings
        return self.latent_proj_out(latents)

    def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode Mimi latents to an audio waveform.

        Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
        embeddings. We bypass the quantizer and call upsample -> decoder_transformer
        -> decoder directly to decode from continuous latents.

        Args:
            latents: Mimi latents, shape [batch, seq_len, 512]

        Returns:
            Audio waveform, shape [batch, samples]
        """
        if self.mimi is None:
            raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")

        # Mimi's convolutional modules expect channels-first: [batch, 512, seq_len]
        latents = latents.transpose(1, 2)

        with torch.no_grad():
            # Upsample from the 12.5 Hz latent frame rate
            emb = self.mimi.upsample(latents)

            # The decoder transformer runs channels-last
            emb = emb.transpose(1, 2)
            decoder_out = self.mimi.decoder_transformer(emb)
            emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])

            # The convolutional decoder runs channels-first and emits [batch, 1, samples]
            emb = emb.transpose(1, 2)
            audio = self.mimi.decoder(emb)

        return audio.squeeze(1)

    def get_output_length(self, input_length: int) -> int:
        """Estimate output audio samples from input hidden state length.

        For Mimi at a 12.5 Hz frame rate with 24 kHz audio, each latent frame
        corresponds to 24000 / 12.5 = 1920 audio samples.
        """
        return input_length * 1920
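

# End-to-end usage sketch (illustrative; `config` and `llm` are hypothetical
# stand-ins for the caller's objects, not defined in this module):
#
#     head = AudioHead(config, llm_dim=2048)
#     head.load_pretrained_flow_net()               # optional: frozen pocket-tts flow
#     head.load_mimi_decoder(device=torch.device("cuda"))
#
#     hidden = llm(...)                             # [batch, seq_len, llm_dim]
#     loss = head(hidden, latent_targets=targets)   # training: scalar loss
#     latents = head(hidden)                        # inference: [batch, seq_len, 512]
#     audio = head.decode_to_audio(latents)         # [batch, samples] at 24 kHz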