"""
AudioTextHTDemucs v2 - Text-conditioned source separation.

Changes from v1:
- Custom trainable decoder that outputs 1 source (not 4)
- HTDemucs encoder kept (frozen)
- CLAP text encoder (frozen)
- Cross-attention conditioning at bottleneck
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from einops import rearrange

from demucs.htdemucs import HTDemucs
from transformers import ClapModel, ClapTextModelWithProjection, RobertaTokenizerFast

class TextCrossAttention(nn.Module):
    """Cross-attention: audio features attend to text embeddings."""

    def __init__(self, feat_dim, text_dim, n_heads=8, dropout=0.0):
        super().__init__()
        self.q_proj = nn.Linear(feat_dim, feat_dim)
        self.k_proj = nn.Linear(text_dim, feat_dim)
        self.v_proj = nn.Linear(text_dim, feat_dim)
        self.attn = nn.MultiheadAttention(feat_dim, n_heads, batch_first=True, dropout=dropout)
        self.out_mlp = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.GELU(),
            nn.Linear(feat_dim, feat_dim),
        )
        self.norm_q = nn.LayerNorm(feat_dim)
        self.norm_out = nn.LayerNorm(feat_dim)

    def forward_attend(self, queries, text_emb):
        q = self.norm_q(queries)
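        # CLAP returns a pooled (B, D) embedding; give it a singleton sequence
        # axis so MultiheadAttention sees a length-1 key/value sequence.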
        if text_emb.dim() == 2:
            text_emb = text_emb.unsqueeze(1)
        k = self.k_proj(text_emb)
        v = self.v_proj(text_emb)
        q_proj = self.q_proj(q)
        attn_out, _ = self.attn(query=q_proj, key=k, value=v)
        out = queries + attn_out
        out = out + self.out_mlp(out)
        return self.norm_out(out)

    def forward(self, x, xt, text_emb):
        # `Fq` avoids shadowing torch.nn.functional, imported above as F.
        B, C, Fq, T = x.shape
        # Flatten the spectrogram grid into one token sequence per example.
        x_seq = rearrange(x, "b c f t -> b (f t) c")
        xt_seq = rearrange(xt, "b c t -> b t c")
        x_seq = self.forward_attend(x_seq, text_emb)
        xt_seq = self.forward_attend(xt_seq, text_emb)
        x = rearrange(x_seq, "b (f t) c -> b c f t", f=Fq, t=T)
        xt = rearrange(xt_seq, "b t c -> b c t")
        return x, xt


class FreqDecoder(nn.Module):
    """Frequency-domain decoder: mirrors HTDemucs encoder structure but outputs 1 source."""

    def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
        """
        channels: List of channel dims from bottleneck to output, e.g. [384, 192, 96, 48, 4]
        """
        super().__init__()
        self.layers = nn.ModuleList()
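        # Each stage upsamples along the frequency axis only ((k, 1) kernels,
        # (s, 1) strides); the time axis is re-aligned later via interpolation.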

        for i in range(len(channels) - 1):
            in_ch = channels[i]
            out_ch = channels[i + 1]
            is_last = (i == len(channels) - 2)

            self.layers.append(nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(kernel_size, 1), stride=(stride, 1), padding=(kernel_size//4, 0)),
                nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
                nn.GELU() if not is_last else nn.Identity(),
            ))

    def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
        """
        x: (B, C, F, T) bottleneck features
        skips: encoder skip connections (reversed order)
        target_lengths: target frequency dimensions for each layer
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            # Match target size
            if i < len(target_lengths):
                target_f = target_lengths[i]
                if x.shape[2] != target_f:
                    x = F.interpolate(x, size=(target_f, x.shape[3]), mode='bilinear', align_corners=False)
            # Add skip connection if available
            if i < len(skips):
                skip = skips[i]
                # Project skip to match channels if needed
                if skip.shape[1] != x.shape[1]:
                    skip = skip[:, :x.shape[1]]  # Simple channel truncation
                if skip.shape[2:] != x.shape[2:]:
                    skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear', align_corners=False)
                x = x + skip * 0.1  # Scaled residual
        return x


class TimeDecoder(nn.Module):
    """Time-domain decoder: outputs 1 source waveform."""

    def __init__(self, channels: List[int], kernel_size: int = 8, stride: int = 4):
        super().__init__()
        self.layers = nn.ModuleList()

        for i in range(len(channels) - 1):
            in_ch = channels[i]
            out_ch = channels[i + 1]
            is_last = (i == len(channels) - 2)

            self.layers.append(nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, kernel_size, stride, padding=kernel_size//4),
                nn.GroupNorm(1, out_ch) if not is_last else nn.Identity(),
                nn.GELU() if not is_last else nn.Identity(),
            ))

    def forward(self, x, skips: List[torch.Tensor], target_lengths: List[int]):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(target_lengths):
                target_t = target_lengths[i]
                if x.shape[2] != target_t:
                    x = F.interpolate(x, size=target_t, mode='linear', align_corners=False)
            if i < len(skips):
                skip = skips[i]
                if skip.shape[1] != x.shape[1]:
                    skip = skip[:, :x.shape[1]]
                if skip.shape[2] != x.shape[2]:
                    skip = F.interpolate(skip, size=x.shape[2], mode='linear', align_corners=False)
                x = x + skip * 0.1
        return x


class AudioTextHTDemucs(nn.Module):
    """
    Text-conditioned source separation.
    - HTDemucs encoder (frozen): extracts multi-scale audio features
    - CLAP (frozen): text embeddings
    - Cross-attention: conditions audio on text at bottleneck
    - Custom decoder (trainable): outputs single source
    """

    def __init__(
        self,
        htdemucs_model: HTDemucs,
        clap_encoder: ClapModel | ClapTextModelWithProjection,
        clap_tokenizer: RobertaTokenizerFast,
        model_dim: int = 384,
        text_dim: int = 512,
        num_heads: int = 8,
        sample_rate: int = 44100,
        segment: float = 7.8,
    ):
        super().__init__()

        self.htdemucs = htdemucs_model
        self.clap = clap_encoder
        self.tokenizer = clap_tokenizer
        self.sample_rate = sample_rate
        self.segment = segment

        # Freeze HTDemucs encoder
        for param in self.htdemucs.parameters():
            param.requires_grad = False

        # Freeze CLAP
        for param in self.clap.parameters():
            param.requires_grad = False
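        # NOTE: requires_grad=False only stops gradients; call .eval() on
        # self.htdemucs / self.clap as well if dropout and norm layers should
        # behave deterministically while the wrapper trains.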

        # Text cross-attention at bottleneck
        self.text_attn = TextCrossAttention(model_dim, text_dim, num_heads)

        # Custom decoders (trainable) - output 1 source with 2 channels (stereo)
        # Channel progression: 384 -> 192 -> 96 -> 48 -> 4 (projected to 2 stereo channels by the 1x1 convs below)
        self.freq_decoder = FreqDecoder([384, 192, 96, 48, 4])
        self.time_decoder = TimeDecoder([384, 192, 96, 48, 4])

        # Final projection to stereo
        self.freq_out = nn.Conv2d(4, 2, 1)
        self.time_out = nn.Conv1d(4, 2, 1)
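
        # Trainable modules: text_attn, freq_decoder, time_decoder, freq_out,
        # time_out; everything reached through self.htdemucs / self.clap is frozen.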

    def _encode(self, x, xt):
        """Run HTDemucs encoder, save skip connections."""
        saved = []
        saved_t = []
        lengths = []
        lengths_t = []

        for idx, encode in enumerate(self.htdemucs.encoder):
            lengths.append(x.shape[-1])
            inject = None

            if idx < len(self.htdemucs.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.htdemucs.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    saved_t.append(xt)
                else:
                    inject = xt

            x = encode(x, inject)

            if idx == 0 and self.htdemucs.freq_emb is not None:
                # HTDemucs injects a learned per-frequency positional embedding
                # after the first encoder layer.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.htdemucs.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.htdemucs.freq_emb_scale * emb

            saved.append(x)

        # Cross-transformer at bottleneck
        if self.htdemucs.crosstransformer:
            if self.htdemucs.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.htdemucs.channel_upsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.htdemucs.channel_upsampler_t(xt)

            x, xt = self.htdemucs.crosstransformer(x, xt)

            if self.htdemucs.bottom_channels:
                x = rearrange(x, "b c f t -> b c (f t)")
                x = self.htdemucs.channel_downsampler(x)
                x = rearrange(x, "b c (f t) -> b c f t", f=f)
                xt = self.htdemucs.channel_downsampler_t(xt)

        return x, xt, saved, saved_t, lengths, lengths_t

    def _get_clap_embeddings(self, text: List[str], device):
        inputs = self.tokenizer(text, padding=True, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            if isinstance(self.clap, ClapModel):
                # ClapModel exposes pooled, projected text features directly
                return self.clap.get_text_features(**inputs)
            # ClapTextModelWithProjection returns them on .text_embeds
            return self.clap(**inputs).text_embeds

    def forward(self, wav, text):
        """
        wav: (B, 2, T) stereo mixture
        text: List[str] prompts
        Returns: (B, 2, T) separated stereo source
        """
        device = wav.device
        B = wav.shape[0]
        original_length = wav.shape[-1]

        # Compute spectrogram (ensure all on same device)
        z = self.htdemucs._spec(wav).to(device)
        mag = self.htdemucs._magnitude(z).to(device)
        x = mag

        B, C, Fq, T_spec = x.shape

        # Normalize
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        xt = wav
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        # Encode (frozen)
        with torch.no_grad():
            x_enc, xt_enc, saved, saved_t, lengths, lengths_t = self._encode(x, xt)

        # Text conditioning via cross-attention (trainable)
        text_emb = self._get_clap_embeddings(text, device)
        x_cond, xt_cond = self.text_attn(x_enc, xt_enc, text_emb)

        # Decode with custom decoder (trainable)
        # Reverse skips for decoder
        saved_rev = saved[::-1]
        saved_t_rev = saved_t[::-1]
        lengths_rev = lengths[::-1]
        lengths_t_rev = lengths_t[::-1]

        # Frequency decoder
        x_dec = self.freq_decoder(x_cond, saved_rev, lengths_rev)
        x_dec = self.freq_out(x_dec)  # (B, 2, F, T)

        # Interpolate to match original spectrogram size
        x_dec = F.interpolate(x_dec, size=(Fq, T_spec), mode='bilinear', align_corners=False)

        # Apply as mask and invert spectrogram
        mask = torch.sigmoid(x_dec)  # (B, 2, F, T) in [0, 1]

        # With HTDemucs' default cac=True, `mag` holds real/imag planes rather
        # than magnitudes, so derive magnitude and phase from the complex STFT
        # z directly (z.abs() is correct in either case).
        z_stereo = z[:, :2, :, :]  # (B, 2, F, T) complex
        mag_stereo = z_stereo.abs()
        phase = z_stereo / (mag_stereo + 1e-8)  # unit-modulus complex phase
        masked_z = mag_stereo * mask * phase  # apply mask while preserving phase
        freq_wav = self.htdemucs._ispec(masked_z, original_length).to(device)

        # Time decoder
        xt_dec = self.time_decoder(xt_cond, saved_t_rev, lengths_t_rev)
        xt_dec = self.time_out(xt_dec)  # (B, 2, T)

        # Interpolate to original length
        if xt_dec.shape[-1] != original_length:
            xt_dec = F.interpolate(xt_dec, size=original_length, mode='linear', align_corners=False)

        # Denormalize time output
        xt_dec = xt_dec * stdt + meant

        # Combine frequency and time branches
        output = freq_wav + xt_dec

        return output


if __name__ == "__main__":
    from demucs import pretrained

    htdemucs = pretrained.get_model('htdemucs').models[0]
    clap = ClapModel.from_pretrained("laion/clap-htsat-unfused")
    tokenizer = RobertaTokenizerFast.from_pretrained("laion/clap-htsat-unfused")

    model = AudioTextHTDemucs(htdemucs, clap, tokenizer)

    # Count params
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total:,}")
    print(f"Trainable params: {trainable:,}")

    # Test forward
    wav = torch.randn(2, 2, 44100 * 3)
    prompts = ["drums", "bass"]
    out = model(wav, prompts)
    print(f"Input: {wav.shape} -> Output: {out.shape}")