File size: 15,061 Bytes
ae41cb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
"""Flow matching audio head for speech-to-speech.

Generates audio from LLM hidden states via flow matching:
  LLM hidden -> llm_proj -> flow_net (LSD decode) -> Mimi latents -> Mimi decoder -> audio

Supports two modes:
1. Training from scratch with 512-dim Mimi embeddings (latent_proj_in/out)
2. Using pretrained pocket-tts flow_net with 32-dim normalized latents
"""

import logging
from functools import partial
from typing import Optional

import torch
import torch.nn as nn

from .modules.mlp import SimpleMLPAdaLN

logger = logging.getLogger(__name__)


def lsd_decode(
    v_t,
    x_0: torch.Tensor,
    num_steps: int = 1,
) -> torch.Tensor:
    """Lagrangian Self-Distillation decoding.

    Integrates the learned flow ODE from noise toward data with a fixed-step
    Euler scheme over a uniform time grid on [0, 1]. Each step queries the
    velocity network at the interval endpoints (s, t) and advances the state
    by the predicted velocity scaled by the step size.

    Args:
        v_t: Velocity function v(s, t, x) -> velocity
        x_0: Initial noise, shape [N, latent_dim]
        num_steps: Number of integration steps (0 returns x_0 unchanged)

    Returns:
        Decoded latents, shape [N, latent_dim]
    """
    state = x_0
    for step in range(num_steps):
        # Interval endpoints on the uniform grid: s = step/N, t = (step+1)/N.
        start_time = torch.full_like(x_0[..., :1], step / num_steps)
        end_time = torch.full_like(x_0[..., :1], (step + 1) / num_steps)
        velocity = v_t(start_time, end_time, state)
        # Euler update: x <- x + v * dt, with dt = 1/num_steps.
        state = state + velocity / num_steps
    return state


class AudioHead(nn.Module):
    """Flow matching head: LLM hidden -> Mimi latents -> audio.

    Architecture:
        - llm_proj: Linear projection from LLM hidden dim to flow conditioning
        - latent_proj_in/out: Project between Mimi 512-dim and flow 32-dim
        - flow_net: SimpleMLPAdaLN that predicts flow velocity
        - Mimi decoder for latent -> audio

    Args:
        config: ASRConfig with:
            - llm_dim: LLM hidden dimension (default: 2048)
            - lsd_decode_steps: Number of LSD integration steps (default: 1)
            - flow_temperature: Sampling temperature for noise (default: 1.0)
    """

    # Architecture dimensions
    COND_DIM = 1024  # Conditioning dimension
    LATENT_DIM = 32  # Flow latent dimension (matches Mimi's 32 codebooks)
    MIMI_DIM = 512  # Mimi encoder output dimension
    FLOW_DIM = 512  # Flow network hidden dimension
    FLOW_DEPTH = 6  # Number of residual blocks

    def __init__(self, config, llm_dim: Optional[int] = None) -> None:
        super().__init__()
        # llm_dim can be passed directly or from config
        self.llm_dim = llm_dim or getattr(config, "llm_dim", None) or 2048
        self.cond_dim = self.COND_DIM
        self.latent_dim = self.LATENT_DIM
        self.mimi_dim = self.MIMI_DIM
        self.lsd_steps = getattr(config, "lsd_decode_steps", 1)
        self.temp = getattr(config, "flow_temperature", 1.0)

        # LLM -> conditioning projection
        self.llm_proj = nn.Linear(self.llm_dim, self.cond_dim, bias=False)

        # Mimi embedding projections
        # Projects 512-dim Mimi embeddings to 32-dim flow latents and back
        self.latent_proj_in = nn.Linear(self.mimi_dim, self.latent_dim, bias=False)
        self.latent_proj_out = nn.Linear(self.latent_dim, self.mimi_dim, bias=False)

        # Flow network
        # num_time_conds=2: the network is conditioned on an (s, t) time pair,
        # matching the two time tensors passed by lsd_decode / _compute_loss.
        self.flow_net = SimpleMLPAdaLN(
            in_channels=self.latent_dim,
            model_channels=self.FLOW_DIM,
            out_channels=self.latent_dim,
            cond_channels=self.cond_dim,
            num_res_blocks=self.FLOW_DEPTH,
            num_time_conds=2,
        )

        # Normalization buffers for pretrained pocket-tts flow_net
        # When using pretrained weights, the flow operates in normalized 32-dim space
        self.register_buffer("emb_mean", torch.zeros(self.latent_dim))
        self.register_buffer("emb_std", torch.ones(self.latent_dim))
        self._use_pretrained_normalization = False

        # Mimi decoder components (loaded separately via load_mimi_decoder)
        self.mimi = None

    def load_mimi_decoder(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        """Load Mimi model for decoding latents to audio.

        The Mimi model is frozen (no grads) and put in eval mode; it is only
        used at inference time by decode_to_audio().
        """
        from transformers import MimiModel

        self.mimi = MimiModel.from_pretrained("kyutai/mimi")
        self.mimi.requires_grad_(False)
        self.mimi.eval()

        if device is not None:
            self.mimi = self.mimi.to(device)
        if dtype is not None:
            self.mimi = self.mimi.to(dtype)

        logger.info("Loaded Mimi decoder from kyutai/mimi")

    def load_pretrained_flow_net(
        self,
        weights_path: Optional[str] = None,
        freeze: bool = True,
    ) -> None:
        """Load pretrained pocket-tts flow_net weights.

        This enables using the pretrained flow matching network from pocket-tts,
        which operates in normalized 32-dim latent space.

        Args:
            weights_path: Path to safetensors file. If None, downloads from HuggingFace.
            freeze: Whether to freeze flow_net weights (default: True, only train llm_proj)
        """
        import safetensors.torch

        if weights_path is None:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id="kyutai/pocket-tts", filename="tts_b6369a24.safetensors"
            )

        state = safetensors.torch.load_file(weights_path)

        # Extract flow_net weights: the checkpoint namespaces them under
        # "flow_lm.flow_net."; strip the prefix to match our module names.
        flow_state = {}
        for k, v in state.items():
            if k.startswith("flow_lm.flow_net."):
                new_key = k.replace("flow_lm.flow_net.", "")
                flow_state[new_key] = v

        self.flow_net.load_state_dict(flow_state)
        logger.info(f"Loaded pretrained flow_net from {weights_path}")

        # Load normalization buffers
        # NOTE(review): _use_pretrained_normalization is only enabled when
        # emb_std is present; emb_mean alone does not enable it.
        if "flow_lm.emb_mean" in state:
            self.emb_mean.copy_(state["flow_lm.emb_mean"])
        if "flow_lm.emb_std" in state:
            self.emb_std.copy_(state["flow_lm.emb_std"])
            # Enable normalization for generate
            self._use_pretrained_normalization = True
            logger.info("Loaded emb_mean and emb_std for normalization")

        if freeze:
            self.flow_net.requires_grad_(False)
            logger.info("Froze flow_net weights (only llm_proj will train)")

    def forward(
        self,
        hidden_states: torch.Tensor,
        latent_targets: Optional[torch.Tensor] = None,
        latent_lengths: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for training or inference.

        Training mode is selected implicitly by passing latent_targets.

        Args:
            hidden_states: LLM hidden states, shape [batch, seq_len, llm_dim]
            latent_targets: Target Mimi latents for training, shape [batch, seq_len, 512]
            latent_lengths: Actual lengths per sample, shape [batch]

        Returns:
            Training: scalar flow matching loss
            Inference: generated Mimi latents, shape [batch, seq_len, 512]
        """
        # Project LLM hidden states to conditioning
        cond = self.llm_proj(hidden_states)

        if latent_targets is not None:
            return self._compute_loss(cond, latent_targets, latent_lengths)
        return self._generate(cond)

    def _compute_loss(
        self,
        cond: torch.Tensor,
        targets: torch.Tensor,
        lengths: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute flow matching loss with reconstruction term.

        The loss has two components:
        1. Flow matching loss: MSE between predicted and target velocities in 32-dim space
        2. Reconstruction loss: MSE between reconstructed and original 512-dim embeddings
           (this ensures latent_proj_out is trained)

        Args:
            cond: Conditioning from LLM, shape [batch, cond_seq_len, cond_dim]
            targets: Mimi embeddings, shape [batch, target_seq_len, 512]
            lengths: Optional lengths for masking
        """
        # Debug: check inputs for NaN/Inf
        if torch.isnan(cond).any() or torch.isinf(cond).any():
            logger.warning(
                f"NaN/Inf in cond! shape={cond.shape}, nan={torch.isnan(cond).sum()}, inf={torch.isinf(cond).sum()}"
            )
        if torch.isnan(targets).any() or torch.isinf(targets).any():
            logger.warning(f"NaN/Inf in targets! shape={targets.shape}")

        batch, cond_seq_len, _ = cond.shape
        target_seq_len = targets.shape[1]
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences
        # requires_grad=True keeps backward() safe even for the zero loss.
        if cond_seq_len == 0 or target_seq_len == 0:
            return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

        # Project 512-dim Mimi embeddings to 32-dim flow latents
        targets_proj = self.latent_proj_in(targets)

        # Compute reconstruction loss to train latent_proj_out
        # This ensures the projection learns a good inverse mapping
        targets_reconstructed = self.latent_proj_out(targets_proj)

        # Interpolate targets to match conditioning sequence length
        # (interpolate works on [batch, channels, length], hence the transposes)
        targets_for_interp = targets
        if target_seq_len != cond_seq_len:
            targets_proj = targets_proj.transpose(1, 2)
            targets_proj = torch.nn.functional.interpolate(
                targets_proj, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_proj = targets_proj.transpose(1, 2).contiguous()

            # Also interpolate original targets for reconstruction loss
            targets_for_interp = targets.transpose(1, 2)
            targets_for_interp = torch.nn.functional.interpolate(
                targets_for_interp, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_for_interp = targets_for_interp.transpose(1, 2).contiguous()

            # Interpolate reconstructed targets to match
            targets_reconstructed = targets_reconstructed.transpose(1, 2)
            targets_reconstructed = torch.nn.functional.interpolate(
                targets_reconstructed, size=cond_seq_len, mode="linear", align_corners=False
            )
            targets_reconstructed = targets_reconstructed.transpose(1, 2).contiguous()

            # Rescale per-sample lengths to the interpolated time axis
            # (truncating .long() keeps lengths <= cond_seq_len).
            if lengths is not None:
                scale = cond_seq_len / target_seq_len
                lengths = (lengths.float() * scale).long()

        seq_len = cond_seq_len
        x_1 = targets_proj

        # Random timesteps for each sample/position (match input dtype)
        t = torch.rand(batch, seq_len, 1, device=device, dtype=dtype)

        # Sample noise
        x_0 = torch.randn_like(x_1)

        # Linear interpolation: x_t = (1-t) * x_0 + t * x_1
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity: dx/dt = x_1 - x_0
        v_target = x_1 - x_0

        # Flatten for flow_net: [batch * seq_len, dim]
        cond_flat = cond.view(-1, self.cond_dim)
        t_flat = t.view(-1, 1)
        x_t_flat = x_t.view(-1, self.latent_dim)

        # Predict velocity
        # flow_net takes two time conditions (s, t); during training both are
        # set to the same sampled timestep t.
        v_pred = self.flow_net(cond_flat, t_flat, t_flat, x_t_flat)
        v_pred = v_pred.view(batch, seq_len, self.latent_dim)

        # Compute masked losses
        if lengths is not None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = positions < lengths.unsqueeze(1)

            # Check if mask is all False (no valid positions)
            if not mask.any():
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            flow_mask = mask.unsqueeze(-1).expand_as(v_pred)
            recon_mask = mask.unsqueeze(-1).expand_as(targets_reconstructed)

            flow_loss = ((v_pred - v_target) ** 2)[flow_mask].mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2)[recon_mask].mean()
        else:
            flow_loss = ((v_pred - v_target) ** 2).mean()
            recon_loss = ((targets_reconstructed - targets_for_interp) ** 2).mean()

        # Combined loss (reconstruction loss weighted at 0.1 to not dominate)
        return flow_loss + 0.1 * recon_loss

    def _generate(self, cond: torch.Tensor) -> torch.Tensor:
        """Generate Mimi embeddings via LSD decoding.

        Decodes one timestep at a time: each position's conditioning vector
        drives an independent LSD integration from fresh noise.

        Args:
            cond: Conditioning from LLM, shape [batch, seq_len, cond_dim]

        Returns:
            Generated Mimi embeddings, shape [batch, seq_len, 512]
        """
        batch, seq_len, _ = cond.shape
        device = cond.device
        dtype = cond.dtype

        # Handle empty sequences
        if seq_len == 0:
            return torch.empty(batch, 0, self.mimi_dim, device=device, dtype=dtype)

        # Clamp temperature to non-negative to avoid complex numbers from sqrt
        temp = max(0.0, self.temp)

        latents = []
        for t in range(seq_len):
            cond_t = cond[:, t]

            # Sample initial noise in 32-dim flow space
            noise = torch.randn(batch, self.latent_dim, device=device, dtype=dtype)
            noise = noise * (temp**0.5)

            def velocity_fn(cond_fixed, s, t, x):
                return self.flow_net(cond_fixed, s, t, x)

            # partial binds cond_t eagerly, so each timestep gets its own
            # conditioning (no late-binding closure issue).
            conditioned_flow = partial(velocity_fn, cond_t)
            latent = lsd_decode(conditioned_flow, noise, self.lsd_steps)
            latents.append(latent)

        latents = torch.stack(latents, dim=1)

        # Denormalize if using pretrained pocket-tts normalization
        if self._use_pretrained_normalization:
            latents = latents * self.emb_std + self.emb_mean

        # Project back to 512-dim Mimi embedding space
        return self.latent_proj_out(latents)

    def decode_to_audio(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode Mimi latents to audio waveform.

        Note: HuggingFace MimiModel.decode() expects discrete codes, not continuous
        embeddings. We bypass the quantizer and call upsample → decoder_transformer
        → decoder directly to decode from continuous latents.

        Args:
            latents: Mimi latents, shape [batch, seq_len, 512]

        Returns:
            Audio waveform, shape [batch, samples]

        Raises:
            RuntimeError: If load_mimi_decoder() has not been called.
        """
        if self.mimi is None:
            raise RuntimeError("Mimi decoder not loaded. Call load_mimi_decoder() first.")

        # [batch, seq, 512] → [batch, 512, seq]
        latents = latents.transpose(1, 2)

        with torch.no_grad():
            # Upsample latents (2x temporal upsampling)
            emb = self.mimi.upsample(latents)

            # Decoder transformer expects [batch, seq, dim]
            emb = emb.transpose(1, 2)
            decoder_out = self.mimi.decoder_transformer(emb)
            # Handle both ModelOutput objects and plain tuples.
            emb = getattr(decoder_out, "last_hidden_state", decoder_out[0])

            # Final decoder expects [batch, dim, seq]
            emb = emb.transpose(1, 2)
            audio = self.mimi.decoder(emb)

        # Drop the channel dimension — assumes mono output [batch, 1, samples].
        return audio.squeeze(1)

    def get_output_length(self, input_length: int) -> int:
        """Estimate output audio frames from input hidden state length.

        For Mimi at 12.5 Hz frame rate with 24kHz audio:
        Each latent frame = 24000 / 12.5 = 1920 audio samples
        """
        return input_length * 1920