File size: 11,338 Bytes
4698bfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
FlowMatchingTTS – adapts CosyVoice's Conditional Flow Matching pipeline
to the semacs-tts dataset (VQ codes → mel spectrogram).

Architecture mirrors MaskedDiffWithXvec from cosyvoice/flow/flow.py:
  codes  → Embedding → causal Transformer → Linear → InterpolateRegulator
         ↘ ConditionalCFM (ConditionalDecoder UNet1D) ← speaker emb
         → mel spectrogram loss
"""
import math
import random
import sys
import os

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt
from omegaconf import DictConfig

from flow_matching.utils.cfm import ConditionalCFM
from flow_matching.utils.decoder import ConditionalDecoder
from flow_matching.utils.length_regulator import InterpolateRegulator
from flow_matching.utils.mask import make_pad_mask


class SinusoidalPE(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: feature size of the input (last axis); may be even or odd.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions in the precomputed table; inputs must
            not be longer than this.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 8192):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)      # (max_len, 1)
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )                                                        # (ceil(d_model / 2),)
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:T])`` for ``x`` of shape (B, T, d_model)."""
        return self.dropout(x + self.pe[:, :x.size(1)])


def _forward_layer(layer, x, attn_mask, pad_mask):
    return layer(x, src_mask=attn_mask, src_key_padding_mask=pad_mask)


class CodeEncoder(nn.Module):
    """Stack of pre-LN Transformer encoder layers over already-embedded VQ codes.

    The input gets sinusoidal positions added, attention optionally only looks
    backwards (``causal=True``), padded frames are excluded through a
    key-padding mask, and a final LayerNorm is applied. With
    ``grad_checkpoint`` enabled, every layer is run under activation
    checkpointing while the module is in training mode.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        num_layers: int,
        ffn_dim: int,
        dropout: float,
        causal: bool,
        grad_checkpoint: bool,
    ):
        super().__init__()
        self.causal = causal
        self.grad_checkpoint = grad_checkpoint
        self.pos_enc = SinusoidalPE(hidden_dim, dropout)
        layer_kwargs = dict(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,    # pre-LN: normalise before attention/FFN for stability
        )
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(**layer_kwargs) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self._hidden_dim = hidden_dim

    def output_size(self) -> int:
        """Feature size of the encoder output (equal to ``hidden_dim``)."""
        return self._hidden_dim

    def forward(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        x:       (B, T, hidden_dim) – embedded codes
        lengths: (B,)               – valid code lengths per sample
        Returns: (B, T, hidden_dim) normalised hidden states, and ``lengths``
                 passed through unchanged.
        """
        _, seq_len, _ = x.shape
        hidden = self.pos_enc(x)

        key_padding = make_pad_mask(lengths, seq_len)   # (B, T), True = padded
        # True above the diagonal blocks attention to future positions.
        future_mask = (
            torch.ones(seq_len, seq_len, device=hidden.device, dtype=torch.bool).triu(1)
            if self.causal
            else None
        )

        checkpointing = self.grad_checkpoint and self.training
        for layer in self.layers:
            if checkpointing:
                hidden = ckpt.checkpoint(
                    _forward_layer, layer, hidden, future_mask, key_padding,
                    use_reentrant=False,
                )
            else:
                hidden = _forward_layer(layer, hidden, future_mask, key_padding)

        return self.norm(hidden), lengths


class FlowMatchingTTS(nn.Module):
    """
    Full TTS flow matching model.

    Training input (keys of the ``batch`` dict given to ``forward``):
        codes     (B, num_q, T_codes)  VQ code indices
        code_lens (B,)                 valid code lengths
        mel       (B, T_mel, n_mels)   target log-mel spectrogram
        mel_lens  (B,)                 valid mel frame lengths
        embedding (B, 192)             CAM++ speaker embedding (pre-extracted)

    Output:
        {'loss': scalar}               conditional flow matching loss
    """

    def __init__(self, cfg):
        """Build all sub-modules from ``cfg.model``, ``cfg.data`` and ``cfg.cfm``."""
        super().__init__()
        m    = cfg.model
        data = cfg.data
        cfm  = cfg.cfm

        self.n_mels    = data.n_mels       # 100
        hidden_dim     = m.hidden_dim      # 768
        spk_emb_dim    = m.spk_emb_dim     # 192

        # ── code embedding: sum across quantizers ────────────────────────
        self.code_embedding = nn.Embedding(m.codebook_size, hidden_dim)

        # ── causal transformer encoder ───────────────────────────────────
        self.encoder = CodeEncoder(
            hidden_dim=hidden_dim,
            num_heads=m.num_heads,
            num_layers=m.num_layers,
            ffn_dim=m.ffn_dim,
            dropout=m.dropout,
            causal=m.causal,
            grad_checkpoint=m.grad_checkpoint,
        )

        # ── project encoder output to mel dimension ──────────────────────
        self.encoder_proj = nn.Linear(hidden_dim, self.n_mels)

        # ── speaker embedding: 192 → n_mels ──────────────────────────────
        self.spk_embed_affine = nn.Linear(spk_emb_dim, self.n_mels)

        # ── length regulator: upsample codes to mel frame rate ───────────
        self.length_regulator = InterpolateRegulator(
            channels=self.n_mels,
            sampling_ratios=(),
        )

        # ── conditional flow matching decoder ────────────────────────────
        # ConditionalDecoder input = concat(x, mu, spks_t, cond) = 4 × n_mels
        cfm_params = DictConfig({
            'sigma_min':          cfm.sigma_min,
            'solver':             'euler',
            't_scheduler':        cfm.t_scheduler,
            'training_cfg_rate':  cfm.training_cfg_rate,
            'inference_cfg_rate': cfm.inference_cfg_rate,
            'reg_loss_type':      'l1',
        })
        estimator = ConditionalDecoder(
            in_channels=4 * self.n_mels,   # x + mu + spks_expanded + cond
            out_channels=self.n_mels,
            channels=(256, 256),
            dropout=0.05,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        self.decoder = ConditionalCFM(
            in_channels=self.n_mels,
            cfm_params=cfm_params,
            n_spks=1,
            spk_emb_dim=self.n_mels,   # already projected to n_mels
            estimator=estimator,
        )

    # ── helpers ──────────────────────────────────────────────────────────────

    @staticmethod
    def _match_length(x: torch.Tensor, length: int) -> torch.Tensor:
        """Crop or right-pad (with zeros) the last axis of ``x`` to ``length``.

        Every frame that survives keeps its index, so valid frames stay
        aligned with their length counts; only trailing frames are removed
        or added. Returns ``x`` itself when it already has ``length`` frames.
        """
        extra = x.shape[-1] - length
        if extra > 0:
            return x[..., :length]
        if extra < 0:
            return F.pad(x, (0, -extra))
        return x

    # ── forward (training) ───────────────────────────────────────────────────

    def forward(self, batch: dict, device) -> dict:
        """
        Compute the flow matching training loss.

        Same interface as cosyvoice MaskedDiffWithXvec:  model(batch, device).

        Batch keys (added by Executor before this call):
            codes      (B, num_q, T_codes)
            code_lens  (B,)
            mel        (B, T_mel, n_mels)
            mel_lens   (B,)
            embedding  (B, 192)  L2-normalised CAM++ speaker embedding

        Returns:
            {'loss': loss} where ``loss`` is the first value returned by
            ``self.decoder.compute_loss``.
        """
        codes     = batch['codes'].to(device)
        code_lens = batch['code_lens'].to(device)
        mel       = batch['mel'].to(device)
        mel_lens  = batch['mel_lens'].to(device)
        embedding = batch['embedding'].to(device)      # (B, 192)

        # Speaker projection (normalised again before the affine map)
        spk = F.normalize(embedding, dim=-1)
        spk = self.spk_embed_affine(spk)               # (B, n_mels)

        # Code embedding: sum over quantizer axis
        x = self.code_embedding(codes)                 # (B, num_q, T_codes, hidden_dim)
        x = x.sum(dim=1)                               # (B, T_codes, hidden_dim)

        # Encode
        h, _ = self.encoder(x, code_lens)              # (B, T_codes, hidden_dim)
        h    = self.encoder_proj(h)                    # (B, T_codes, n_mels)

        # Upsample to mel frame rate
        h, _ = self.length_regulator(h, mel_lens)      # (B, T_out, n_mels)

        # Build conditioning: random-length mel prefix (50 % chance per sample)
        conds = torch.zeros_like(mel)
        for i, j in enumerate(mel_lens.tolist()):
            if random.random() < 0.5:
                continue
            idx = random.randint(0, int(0.8 * j))
            conds[i, :idx] = mel[i, :idx]
        conds = conds.transpose(1, 2)                  # (B, n_mels, T_mel)

        # Transpose to (B, n_mels, T) for the CFM decoder
        mel_t = mel.transpose(1, 2).contiguous()       # (B, n_mels, T_mel)
        h_t   = h.transpose(1, 2).contiguous()         # (B, n_mels, T_out)
        t_out = h_t.shape[-1]

        # Align target and prompt to the regulator output by cropping or
        # zero-padding trailing frames (a no-op when T_mel == T_out).
        # Resampling here would time-warp the valid frames and misalign them
        # with h_t and mel_lens.
        mel_t = self._match_length(mel_t, t_out)
        conds = self._match_length(conds, t_out)

        # Build the mask at h_t's width so all decoder inputs share one length.
        mask = (~make_pad_mask(mel_lens, t_out)).to(h)  # (B, T_out)

        loss, _ = self.decoder.compute_loss(
            mel_t,
            mask.unsqueeze(1),
            h_t,
            spk,
            cond=conds,
        )
        return {'loss': loss}

    # ── inference ────────────────────────────────────────────────────────────

    @torch.inference_mode()
    def inference(
        self,
        codes:        torch.Tensor,   # (1, num_q, T_codes)
        code_lens:    torch.Tensor,   # (1,)
        prompt_mel:   torch.Tensor,   # (1, T_prompt, n_mels) or empty
        target_len:   int,            # desired output mel frames
        spk_emb:      torch.Tensor,   # (1, 192)
        n_timesteps:  int = 10,
        temperature:  float = 1.0,
    ) -> torch.Tensor:
        """
        Generate a mel spectrogram for a single utterance.

        The first ``min(T_prompt, target_len)`` frames of ``prompt_mel`` are
        used as the conditioning prefix; the remaining conditioning frames are
        zero. ``n_timesteps`` and ``temperature`` are forwarded to the decoder.

        Returns the decoder output: the generated mel, (1, n_mels, target_len).
        """
        spk = F.normalize(spk_emb, dim=-1)
        spk = self.spk_embed_affine(spk)               # (1, n_mels)

        x = self.code_embedding(codes).sum(dim=1)      # (1, T_codes, hidden_dim)
        h, _ = self.encoder(x, code_lens)
        h    = self.encoder_proj(h)                    # (1, T_codes, n_mels)

        out_lens = torch.tensor([target_len], device=codes.device)
        h, _     = self.length_regulator(h, out_lens)  # (1, target_len, n_mels)
        h_t      = h.transpose(1, 2).contiguous()      # (1, n_mels, target_len)

        conds = torch.zeros_like(h_t)
        if prompt_mel.shape[1] > 0:
            p = min(prompt_mel.shape[1], target_len)
            conds[:, :, :p] = prompt_mel[:, :p].transpose(1, 2)

        # Match h_t's dtype/device so half-precision models do not receive a
        # float32 mask (which would promote the decoder inputs to float32).
        mask = torch.ones(1, 1, target_len, device=h_t.device, dtype=h_t.dtype)

        mel_out = self.decoder(
            mu=h_t,
            mask=mask,
            spks=spk,
            cond=conds,
            n_timesteps=n_timesteps,
            temperature=temperature,
        )
        return mel_out   # (1, n_mels, target_len)