File size: 13,965 Bytes
2e82ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e1ecf71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e82ca2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
"""
MLX implementation of CAM++ model - ModelScope architecture (Clean implementation)

Based on analysis of iic/speech_campplus_sv_zh_en_16k-common_advanced:
- Dense connections: each layer's output is concatenated with all previous outputs
- TDNN layers use kernel_size=1 (no temporal context in main conv)
- CAM layers provide the actual feature extraction
- Architecture: Input → Dense Blocks (with CAM) → Transitions → Dense Layer
"""

import json
import os
import re
from typing import Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn


class EmbeddedCAM(nn.Module):
    """
    Context-Aware Masking (CAM) unit used inside each TDNN layer.

    Two branches are computed from the same input and multiplied together:
    - gate branch:  linear1 (kernel 1, bias) -> ReLU -> linear2 (kernel 1, bias)
                    -> sigmoid
    - local branch: linear_local (kernel 3, padding 1, no bias)

    Channel sizes:
    - linear1:      in_channels       -> cam_channels // 2
    - linear2:      cam_channels // 2 -> cam_channels // 4
    - linear_local: in_channels       -> cam_channels // 4

    The module therefore emits cam_channels // 4 channels
    (32 when cam_channels is 128).
    """

    def __init__(self, in_channels: int, cam_channels: int = 128):
        super().__init__()

        hidden_width = cam_channels // 2
        output_width = cam_channels // 4

        # Gate branch: two pointwise (kernel_size=1) convolutions.
        self.linear1 = nn.Conv1d(in_channels, hidden_width, kernel_size=1, bias=True)
        self.linear2 = nn.Conv1d(hidden_width, output_width, kernel_size=1, bias=True)

        # Local branch: kernel_size=3 with padding=1 keeps the sequence length.
        self.linear_local = nn.Conv1d(
            in_channels,
            output_width,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Gate the local features with a sigmoid mask.

        Args:
            x: Input (batch, length, in_channels) - channels_last format

        Returns:
            Output (batch, length, cam_channels//4)
        """
        # Mask in (0, 1) derived from the pointwise gate branch.
        gate = nn.sigmoid(self.linear2(nn.relu(self.linear1(x))))

        # Element-wise product of the local features and the mask.
        return self.linear_local(x) * gate


class TDNNLayerWithCAM(nn.Module):
    """
    One dense-block layer: pointwise TDNN projection followed by a CAM unit.

    Stages, applied in order:
    1. conv: kernel_size=1 convolution (in_channels -> out_channels, no bias)
    2. bn: batch normalisation over out_channels
    3. activation: ReLU
    4. cam: EmbeddedCAM, which reduces to cam_channels // 4 channels

    The cam_channels // 4 output (e.g. 32) is what the caller concatenates
    onto its running feature map for the dense connection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
        cam_channels: int = 128
    ):
        super().__init__()

        # Pointwise projection; kernel_size=1 means no temporal context here.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm(out_channels, affine=True)
        self.activation = nn.ReLU()

        # CAM consumes the projected features (out_channels wide).
        self.cam = EmbeddedCAM(in_channels=out_channels, cam_channels=cam_channels)

    def __call__(self, x: mx.array) -> mx.array:
        """
        Run the layer.

        Args:
            x: Input (batch, length, in_channels)

        Returns:
            CAM output (batch, length, cam_channels//4)
        """
        features = x
        for stage in (self.conv, self.bn, self.activation, self.cam):
            features = stage(features)
        return features


class TransitionLayer(nn.Module):
    """
    Bridge between two dense blocks.

    Maps the concatenated feature map (in_channels wide) to out_channels
    using BatchNorm -> ReLU -> kernel_size=1 convolution (no bias).
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.bn = nn.BatchNorm(in_channels, affine=True)
        self.activation = nn.ReLU()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Normalise, rectify, then project `x` to `out_channels`."""
        return self.conv(self.activation(self.bn(x)))


class CAMPPModelScopeV2(nn.Module):
    """
    Clean CAM++ implementation matching ModelScope architecture

    Key features:
    - Dense connections: each layer's output is concatenated
    - TDNN layers use kernel_size=1
    - CAM provides feature extraction (outputs cam_channels//4 per layer)
    - Transitions map the accumulated channels to the next block's width
      (2x the base width after block 0, 4x after block 1)

    Args:
        input_dim: Input feature dimension (e.g., 80 or 320)
        channels: Base channel count (e.g., 128 or 512)
        block_layers: Layers per block; exactly the first three entries are
            used (defaults to [4, 9, 16] when None)
        embedding_dim: Output embedding dimension (e.g., 192)
        cam_channels: CAM channel count (e.g., 128)
        input_kernel_size: Input layer kernel size (e.g., 5)
    """

    def __init__(
        self,
        input_dim: int = 80,
        channels: int = 512,
        block_layers: Optional[List[int]] = None,
        embedding_dim: int = 192,
        cam_channels: int = 128,
        input_kernel_size: int = 5
    ):
        super().__init__()

        if block_layers is None:
            block_layers = [4, 9, 16]

        self.input_dim = input_dim
        self.channels = channels
        self.block_layers = block_layers
        self.embedding_dim = embedding_dim
        self.cam_channels = cam_channels
        self.growth_rate = cam_channels // 4  # Each layer adds this many channels

        # Input layer
        self.input_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=channels,
            kernel_size=input_kernel_size,
            padding=input_kernel_size // 2,
            bias=False
        )
        self.input_bn = nn.BatchNorm(channels, affine=True)
        self.input_activation = nn.ReLU()

        # Dense Block 0
        transit1_in = self._add_dense_block(0, block_layers[0], channels)
        self._block0_size = block_layers[0]

        # Transition 1 - doubles channel count
        transit1_out = channels * 2
        self.transit1 = TransitionLayer(transit1_in, transit1_out)

        # Dense Block 1 - starts with doubled channels
        transit2_in = self._add_dense_block(1, block_layers[1], transit1_out)
        self._block1_size = block_layers[1]

        # Transition 2 - doubles channel count again (4x original channels)
        transit2_out = transit1_out * 2
        self.transit2 = TransitionLayer(transit2_in, transit2_out)

        # Dense Block 2 - starts with quadrupled channels
        dense_in = self._add_dense_block(2, block_layers[2], transit2_out)
        self._block2_size = block_layers[2]

        # Final dense layer
        self.dense = nn.Conv1d(
            in_channels=dense_in,
            out_channels=embedding_dim,
            kernel_size=1,
            bias=False
        )

    def _add_dense_block(self, block_index: int, num_layers: int, base_channels: int) -> int:
        """Register the layers of one dense block and return its output width.

        Layer ``i`` is stored as attribute ``block{block_index}_{i}`` and
        receives ``base_channels + i * growth_rate`` input channels, because
        every preceding layer's output has been concatenated onto the input.
        """
        for i in range(num_layers):
            layer = TDNNLayerWithCAM(
                in_channels=base_channels + i * self.growth_rate,
                out_channels=self.channels,
                cam_channels=self.cam_channels
            )
            setattr(self, f'block{block_index}_{i}', layer)
        return base_channels + num_layers * self.growth_rate

    def _run_dense_block(self, block_index: int, num_layers: int, out: mx.array) -> mx.array:
        """Apply one dense block, concatenating each layer's output on the channel axis."""
        for i in range(num_layers):
            layer = getattr(self, f'block{block_index}_{i}')
            out = mx.concatenate([out, layer(out)], axis=2)
        return out

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass

        Args:
            x: Input (batch, length, in_channels) - channels_last format.
               A 2-D input gets a leading batch axis added. If the last axis
               does not equal ``input_dim`` the input is assumed to be
               (batch, in_channels, length) and is transposed.

        Returns:
            Embeddings (batch, length, embedding_dim)
        """
        # Handle input format
        if x.ndim == 2:
            x = mx.expand_dims(x, axis=0)

        # MLX Conv1d expects (batch, length, in_channels)
        if x.shape[2] != self.input_dim:
            x = mx.transpose(x, (0, 2, 1))

        # Input layer
        out = self.input_conv(x)
        out = self.input_bn(out)
        out = self.input_activation(out)

        # Dense Block 0 (with concatenation)
        out = self._run_dense_block(0, self._block0_size, out)

        # Transition 1
        out = self.transit1(out)

        # Dense Block 1
        out = self._run_dense_block(1, self._block1_size, out)

        # Transition 2
        out = self.transit2(out)

        # Dense Block 2
        out = self._run_dense_block(2, self._block2_size, out)

        # Final dense layer
        return self.dense(out)

    def extract_embedding(self, x: mx.array, pooling: str = "mean") -> mx.array:
        """
        Extract fixed-size speaker embedding

        Args:
            x: Input (batch, length, in_channels)
            pooling: "mean", "max", or "both"

        Returns:
            Embedding (batch, embedding_dim) for "mean" and "max";
            (batch, 2 * embedding_dim) for "both" (mean followed by max).

        Raises:
            ValueError: If ``pooling`` is not one of the supported modes.
        """
        frame_embeddings = self(x)  # (batch, length, embedding_dim)

        if pooling == "mean":
            embedding = mx.mean(frame_embeddings, axis=1)
        elif pooling == "max":
            embedding = mx.max(frame_embeddings, axis=1)
        elif pooling == "both":
            mean_pool = mx.mean(frame_embeddings, axis=1)
            max_pool = mx.max(frame_embeddings, axis=1)
            embedding = mx.concatenate([mean_pool, max_pool], axis=1)
        else:
            raise ValueError(f"Unknown pooling: {pooling}")

        return embedding

    def load_weights(self, file_or_weights, strict: bool = True):
        """
        Override load_weights to handle quantized weights with dequantization

        Quantized tensors are stored as three entries: the packed weight under
        its plain name, plus ``name:qSCALES_GS<group>_B<bits>`` and
        ``name:qBIASES_GS<group>_B<bits>``. Such weights are dequantized before
        being handed to the parent implementation; all other entries pass
        through unchanged.

        Args:
            file_or_weights: Path to a weights file (str or os.PathLike), or
                a list of (name, array) tuples / a mapping of name -> array
            strict: If True, all parameters must match exactly

        Returns:
            Whatever the parent ``load_weights`` returns.

        Raises:
            KeyError: If a weight has quantization scales but the matching
                biases entry is missing.
        """
        # Load weights from file if needed
        if isinstance(file_or_weights, (str, os.PathLike)):
            loaded_weights = mx.load(os.fspath(file_or_weights))
        else:
            loaded_weights = dict(file_or_weights)

        scales_marker = ':qSCALES_GS'
        biases_marker = ':qBIASES_GS'

        # Index the scales entries once (first occurrence per weight name)
        # instead of rescanning every key for every weight.
        scales_keys = {}
        for key in loaded_weights:
            base_name, found, _ = key.partition(scales_marker)
            if found:
                scales_keys.setdefault(base_name, key)

        quant_meta = re.compile(r'GS(\d+)_B(\d+)')
        dequantized_weights = {}

        for name, array in loaded_weights.items():
            # Scales/biases entries are consumed together with their main weight
            if scales_marker in name or biases_marker in name:
                continue

            scales_key = scales_keys.get(name)
            if scales_key is None:
                # Regular weight (not quantized)
                dequantized_weights[name] = array
                continue

            # Parse only the suffix (":qSCALES_GS64_B4") so that the parameter
            # name itself can never be mistaken for quantization metadata.
            match = quant_meta.search(scales_key[len(name):])
            if match is None:
                # Fallback: couldn't parse, keep original
                dequantized_weights[name] = array
                continue

            group_size = int(match.group(1))
            bits = int(match.group(2))

            biases_key = f"{name}:qBIASES_GS{group_size}_B{bits}"
            if biases_key not in loaded_weights:
                raise KeyError(
                    f"Quantized weight '{name}' has scales '{scales_key}' "
                    f"but no matching biases '{biases_key}'"
                )

            dequantized_weights[name] = mx.dequantize(
                array,
                loaded_weights[scales_key],
                loaded_weights[biases_key],
                group_size=group_size,
                bits=bits
            )

        # Use the parent class load_weights with dequantized weights and
        # propagate its return value so call chaining keeps working.
        return super().load_weights(list(dequantized_weights.items()), strict=strict)


def load_model(weights_path: str, config_path: Optional[str] = None) -> CAMPPModelScopeV2:
    """Build a CAMPPModelScopeV2, load its weights, and return it ready for inference.

    Args:
        weights_path: Path to the weights file, read with ``mx.load``.
        config_path: Optional path to a JSON file whose keys are passed as
            keyword arguments to ``CAMPPModelScopeV2``. When omitted (or
            empty), the built-in default configuration below is used.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    if config_path:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    else:
        config = {
            'input_dim': 80,
            'channels': 512,
            'block_layers': [4, 9, 16],
            'embedding_dim': 192,
            'cam_channels': 128,
            'input_kernel_size': 5
        }

    model = CAMPPModelScopeV2(**config)
    weights = mx.load(weights_path)
    model.load_weights(list(weights.items()))

    # MLX modules start in training mode, where BatchNorm normalises with
    # per-batch statistics and updates its running stats. Switch to eval so
    # the loaded running statistics are the ones actually used.
    model.eval()

    return model