| | """ |
| | MLX implementation of CAM++ model - ModelScope architecture (Clean implementation) |
| | |
| | Based on analysis of iic/speech_campplus_sv_zh_en_16k-common_advanced: |
| | - Dense connections: each layer's output is concatenated with all previous outputs |
| | - TDNN layers use kernel_size=1 (no temporal context in main conv) |
| | - CAM layers provide the actual feature extraction |
| | - Architecture: Input β Dense Blocks (with CAM) β Transitions β Dense Layer |
| | """ |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | from typing import Dict, List, Optional |
| | import json |
| |
|
| |
|
class EmbeddedCAM(nn.Module):
    """
    Context-aware masking module embedded within TDNN layers.

    Two parallel paths over the same input (verified from ModelScope weights):
    - Global path: linear1 (1x1 conv, in_channels -> cam_channels//2, bias)
      -> ReLU -> linear2 (1x1 conv, -> cam_channels//4, bias), squashed by a
      sigmoid to form a gating mask.
    - Local path: linear_local (3-wide conv, in_channels -> cam_channels//4,
      padding=1, no bias).

    Output is the local features gated element-wise by the mask, i.e.
    cam_channels//4 channels (e.g. 32 for cam_channels=128).
    """

    def __init__(self, in_channels: int, cam_channels: int = 128):
        super().__init__()

        half = cam_channels // 2
        quarter = cam_channels // 4

        # Global-context path: two pointwise convolutions with bias.
        self.linear1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=half,
            kernel_size=1,
            bias=True
        )
        self.linear2 = nn.Conv1d(
            in_channels=half,
            out_channels=quarter,
            kernel_size=1,
            bias=True
        )

        # Local path: 3-wide "same"-padded convolution, no bias.
        self.linear_local = nn.Conv1d(
            in_channels=in_channels,
            out_channels=quarter,
            kernel_size=3,
            padding=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Gate local features with a sigmoid mask derived from the global path.

        Args:
            x: Input (batch, length, in_channels) - channels_last format

        Returns:
            Output (batch, length, cam_channels//4)
        """
        gate = self.linear2(nn.relu(self.linear1(x)))
        local = self.linear_local(x)
        return local * nn.sigmoid(gate)
| |
|
| |
|
class TDNNLayerWithCAM(nn.Module):
    """
    TDNN layer whose output comes from an embedded CAM module.

    Pipeline:
        1x1 conv (channel projection, no bias) -> BatchNorm -> ReLU -> CAM

    The main conv projects to a fixed width (out_channels, e.g. 128); the CAM
    then reduces to cam_channels//4 (e.g. 32), which is the per-layer growth
    rate used by the dense connections in the full model.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
        cam_channels: int = 128
    ):
        super().__init__()

        # Pointwise projection to a fixed width before the CAM.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            bias=False
        )
        self.bn = nn.BatchNorm(out_channels, affine=True)
        self.activation = nn.ReLU()

        # CAM consumes the projected features and emits cam_channels//4.
        self.cam = EmbeddedCAM(
            in_channels=out_channels,
            cam_channels=cam_channels
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Project, normalize, activate, then apply context-aware masking.

        Args:
            x: Input (batch, length, in_channels)

        Returns:
            CAM output (batch, length, cam_channels//4)
        """
        projected = self.activation(self.bn(self.conv(x)))
        return self.cam(projected)
| |
|
| |
|
class TransitionLayer(nn.Module):
    """
    Channel-reducing bridge between dense blocks.

    Pre-activation design: BatchNorm -> ReLU -> bias-free 1x1 conv that maps
    the accumulated channel count back down to the base width.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.bn = nn.BatchNorm(in_channels, affine=True)
        self.activation = nn.ReLU()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        # BN -> ReLU -> 1x1 conv, composed in a single expression.
        return self.conv(self.activation(self.bn(x)))
| |
|
| |
|
class CAMPPModelScopeV2(nn.Module):
    """
    Clean CAM++ implementation matching the ModelScope architecture.

    Key features:
    - Dense connections: each TDNN+CAM layer contributes growth_rate
      (= cam_channels // 4) channels, concatenated onto the running feature map
    - TDNN layers use kernel_size=1; the embedded CAM modules do the feature
      extraction
    - Transition layers reduce the accumulated channels between blocks

    Args:
        input_dim: Input feature dimension (e.g., 80 or 320)
        channels: Base channel count (e.g., 128 or 512)
        block_layers: Layers per block (e.g., [12, 24, 16]); defaults to [4, 9, 16]
        embedding_dim: Output embedding dimension (e.g., 192)
        cam_channels: CAM channel count (e.g., 128)
        input_kernel_size: Input layer kernel size (e.g., 5)
    """

    def __init__(
        self,
        input_dim: int = 80,
        channels: int = 512,
        # Fixed annotation: default is None, so the type must be Optional.
        block_layers: Optional[List[int]] = None,
        embedding_dim: int = 192,
        cam_channels: int = 128,
        input_kernel_size: int = 5
    ):
        super().__init__()

        if block_layers is None:
            block_layers = [4, 9, 16]

        self.input_dim = input_dim
        self.channels = channels
        self.block_layers = block_layers
        self.embedding_dim = embedding_dim
        self.cam_channels = cam_channels
        # Channels each dense layer adds to the accumulated feature map.
        self.growth_rate = cam_channels // 4

        # Input stem: wide conv -> BN -> ReLU.
        self.input_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=channels,
            kernel_size=input_kernel_size,
            padding=input_kernel_size // 2,
            bias=False
        )
        self.input_bn = nn.BatchNorm(channels, affine=True)
        self.input_activation = nn.ReLU()

        # Dense block 0. Layers are attached via setattr so parameter names
        # match the checkpoint layout (block0_0, block0_1, ...).
        for i in range(block_layers[0]):
            in_ch = channels + i * self.growth_rate
            layer = TDNNLayerWithCAM(
                in_channels=in_ch,
                out_channels=channels,
                cam_channels=cam_channels
            )
            setattr(self, f'block0_{i}', layer)
        self._block0_size = block_layers[0]

        # Transition 1: reduce accumulated channels, doubling the base width.
        transit1_in = channels + block_layers[0] * self.growth_rate
        transit1_out = channels * 2
        self.transit1 = TransitionLayer(transit1_in, transit1_out)

        # Dense block 1.
        for i in range(block_layers[1]):
            in_ch = transit1_out + i * self.growth_rate
            layer = TDNNLayerWithCAM(
                in_channels=in_ch,
                out_channels=channels,
                cam_channels=cam_channels
            )
            setattr(self, f'block1_{i}', layer)
        self._block1_size = block_layers[1]

        # Transition 2: double the width again.
        transit2_in = transit1_out + block_layers[1] * self.growth_rate
        transit2_out = transit1_out * 2
        self.transit2 = TransitionLayer(transit2_in, transit2_out)

        # Dense block 2.
        for i in range(block_layers[2]):
            in_ch = transit2_out + i * self.growth_rate
            layer = TDNNLayerWithCAM(
                in_channels=in_ch,
                out_channels=channels,
                cam_channels=cam_channels
            )
            setattr(self, f'block2_{i}', layer)
        self._block2_size = block_layers[2]

        # Final 1x1 conv projecting the accumulated features to embedding_dim.
        dense_in = transit2_out + block_layers[2] * self.growth_rate
        self.dense = nn.Conv1d(
            in_channels=dense_in,
            out_channels=embedding_dim,
            kernel_size=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass.

        Args:
            x: Input (batch, length, input_dim) - channels_last format.
               A 2-D input is treated as a single unbatched example.

        Returns:
            Frame-level embeddings (batch, length, embedding_dim)
        """
        # Promote unbatched input to a batch of one.
        if x.ndim == 2:
            x = mx.expand_dims(x, axis=0)

        # Heuristic layout fix: if the last axis is not the feature dim,
        # assume (batch, channels, length) and transpose to channels-last.
        # NOTE(review): this misfires when length == input_dim; callers
        # should pass channels-last input.
        if x.shape[2] != self.input_dim:
            x = mx.transpose(x, (0, 2, 1))

        # Input stem.
        out = self.input_conv(x)
        out = self.input_bn(out)
        out = self.input_activation(out)

        # Dense block 0: concatenate each layer's CAM output onto the map.
        for i in range(self._block0_size):
            layer = getattr(self, f'block0_{i}')
            out = mx.concatenate([out, layer(out)], axis=2)

        out = self.transit1(out)

        # Dense block 1.
        for i in range(self._block1_size):
            layer = getattr(self, f'block1_{i}')
            out = mx.concatenate([out, layer(out)], axis=2)

        out = self.transit2(out)

        # Dense block 2.
        for i in range(self._block2_size):
            layer = getattr(self, f'block2_{i}')
            out = mx.concatenate([out, layer(out)], axis=2)

        return self.dense(out)

    def extract_embedding(self, x: mx.array, pooling: str = "mean") -> mx.array:
        """
        Extract a fixed-size speaker embedding by pooling over time.

        Args:
            x: Input (batch, length, input_dim)
            pooling: "mean", "max", or "both" (mean and max concatenated,
                giving 2 * embedding_dim features)

        Returns:
            Embedding (batch, embedding_dim) — or (batch, 2 * embedding_dim)
            for pooling="both".

        Raises:
            ValueError: if `pooling` is not one of the supported modes.
        """
        frame_embeddings = self(x)

        if pooling == "mean":
            embedding = mx.mean(frame_embeddings, axis=1)
        elif pooling == "max":
            embedding = mx.max(frame_embeddings, axis=1)
        elif pooling == "both":
            mean_pool = mx.mean(frame_embeddings, axis=1)
            max_pool = mx.max(frame_embeddings, axis=1)
            embedding = mx.concatenate([mean_pool, max_pool], axis=1)
        else:
            raise ValueError(f"Unknown pooling: {pooling}")

        return embedding

    def load_weights(self, file_or_weights, strict: bool = True):
        """
        Override load_weights to transparently dequantize quantized weights.

        Quantized parameters are stored alongside metadata entries named
        "<param>:qSCALES_GS<group_size>_B<bits>" and
        "<param>:qBIASES_GS<group_size>_B<bits>"; these are consumed here
        and never forwarded to the parent loader.

        Args:
            file_or_weights: Path to a weights file loadable by ``mx.load``,
                or an iterable of (name, array) pairs.
            strict: If True, all parameters must match exactly.

        Raises:
            ValueError: if a quantization metadata key exists but its
                group-size/bit-width suffix cannot be parsed.
        """
        import re  # hoisted out of the per-weight loop

        if isinstance(file_or_weights, str):
            loaded_weights = mx.load(file_or_weights)
        else:
            loaded_weights = dict(file_or_weights)

        dequantized_weights = {}

        for name, array in loaded_weights.items():
            # Metadata entries are consumed by their owning parameter below.
            if ':qSCALES_GS' in name or ':qBIASES_GS' in name:
                continue

            scales_key = next(
                (k for k in loaded_weights if k.startswith(f"{name}:qSCALES_GS")),
                None
            )

            if scales_key is None:
                # Plain (unquantized) parameter.
                dequantized_weights[name] = array
                continue

            match = re.search(r'GS(\d+)_B(\d+)', scales_key)
            if match is None:
                # Previously this silently stored the raw quantized array,
                # which would fail later with a confusing shape mismatch.
                raise ValueError(
                    f"Cannot parse quantization metadata key '{scales_key}' "
                    f"for parameter '{name}'"
                )

            group_size = int(match.group(1))
            bits = int(match.group(2))
            biases_key = f"{name}:qBIASES_GS{group_size}_B{bits}"

            dequantized_weights[name] = mx.dequantize(
                array,
                loaded_weights[scales_key],
                loaded_weights[biases_key],
                group_size=group_size,
                bits=bits
            )

        super().load_weights(list(dequantized_weights.items()), strict=strict)
| |
|
| |
|
def load_model(weights_path: str, config_path: Optional[str] = None) -> CAMPPModelScopeV2:
    """
    Build a CAMPPModelScopeV2 and load its weights.

    Args:
        weights_path: Path to a weights file loadable by ``mx.load``.
        config_path: Optional JSON file of constructor kwargs; when omitted,
            the default advanced-model configuration below is used.

    Returns:
        The constructed model with weights loaded.
    """
    if config_path:
        with open(config_path, 'r') as f:
            config = json.load(f)
    else:
        config = {
            'input_dim': 80,
            'channels': 512,
            'block_layers': [4, 9, 16],
            'embedding_dim': 192,
            'cam_channels': 128,
            'input_kernel_size': 5
        }

    model = CAMPPModelScopeV2(**config)
    # load_weights accepts a path directly and performs the mx.load (plus any
    # dequantization) itself; the previous mx.load + list(items) round-trip
    # here was redundant.
    model.load_weights(weights_path)

    return model
| |
|