"""
MLX implementation of CAM++ model - ModelScope architecture (Clean implementation)

Based on analysis of iic/speech_campplus_sv_zh_en_16k-common_advanced:
- Dense connections: each layer's output is concatenated with all previous outputs
- TDNN layers use kernel_size=1 (no temporal context in main conv)
- CAM layers provide the actual feature extraction
- Architecture: Input → Dense Blocks (with CAM) → Transitions → Dense Layer
"""

import json
import os
import re
from typing import Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn

# Matches the "GS<group_size>_B<bits>" part of a quantization key suffix,
# e.g. "weight:qSCALES_GS64_B4" -> group_size=64, bits=4.
_QUANT_PARAMS_RE = re.compile(r'GS(\d+)_B(\d+)')


class EmbeddedCAM(nn.Module):
    """
    Context-Aware Masking module embedded within TDNN layers

    Architecture (as described by the original author, from ModelScope weights):
    - linear1: kernel-size-1 Conv1d (in_channels → cam_channels//2) with bias
    - linear2: kernel-size-1 Conv1d (cam_channels//2 → cam_channels//4) with bias
    - linear_local: kernel-size-3 Conv1d (in_channels → cam_channels//4) without bias
    - Output: cam_channels//4 channels (e.g., 32 for cam_channels=128)
    """

    def __init__(self, in_channels: int, cam_channels: int = 128):
        super().__init__()

        # Gating path: 1x1 → 1x1 (produces the mask logits)
        self.linear1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=cam_channels // 2,  # e.g. 128 → 64
            kernel_size=1,
            bias=True
        )
        self.linear2 = nn.Conv1d(
            in_channels=cam_channels // 2,  # e.g. 64
            out_channels=cam_channels // 4,  # e.g. 64 → 32
            kernel_size=1,
            bias=True
        )

        # Local context path: kernel-size-3 conv, padding=1 keeps the length
        self.linear_local = nn.Conv1d(
            in_channels=in_channels,
            out_channels=cam_channels // 4,  # e.g. 128 → 32
            kernel_size=3,
            padding=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Apply context-aware masking

        Args:
            x: Input (batch, length, in_channels) - channels_last format

        Returns:
            Output (batch, length, cam_channels//4)
        """
        # Gating path: 1x1 → relu → 1x1 → sigmoid
        mask = nn.sigmoid(self.linear2(nn.relu(self.linear1(x))))

        # Local features are scaled element-wise by the mask
        return self.linear_local(x) * mask


class TDNNLayerWithCAM(nn.Module):
    """
    TDNN layer with embedded CAM

    Flow:
    1. Main conv: kernel_size=1 (channels projection)
    2. BatchNorm
    3. ReLU
    4. CAM: extracts features and outputs cam_channels//4

    Note: The main conv projects to a fixed channel size (e.g., 128),
    then CAM reduces to cam_channels//4 (e.g., 32) for dense connection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
        cam_channels: int = 128
    ):
        super().__init__()

        # Main TDNN: 1x1 conv (no temporal context)
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            bias=False
        )

        # BatchNorm on the conv output
        self.bn = nn.BatchNorm(out_channels, affine=True)

        # ReLU activation
        self.activation = nn.ReLU()

        # Embedded CAM (takes conv output, produces cam_channels//4)
        self.cam = EmbeddedCAM(
            in_channels=out_channels,
            cam_channels=cam_channels
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass

        Args:
            x: Input (batch, length, in_channels)

        Returns:
            CAM output (batch, length, cam_channels//4)
        """
        # Main conv + bn + relu, then CAM feature extraction
        out = self.activation(self.bn(self.conv(x)))
        return self.cam(out)


class TransitionLayer(nn.Module):
    """
    Transition layer between dense blocks

    Maps the accumulated (concatenated) channels to ``out_channels``.
    Architecture: BatchNorm → ReLU → 1x1 Conv
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.bn = nn.BatchNorm(in_channels, affine=True)
        self.activation = nn.ReLU()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Apply BatchNorm → ReLU → 1x1 Conv to a channels_last input."""
        return self.conv(self.activation(self.bn(x)))


class CAMPPModelScopeV2(nn.Module):
    """
    Clean CAM++ implementation matching ModelScope architecture

    Key features:
    - Dense connections: each layer's output is concatenated
    - TDNN layers use kernel_size=1
    - CAM provides feature extraction (outputs cam_channels//4 per layer)
    - Transitions project the accumulated channels (doubling the base width
      after each of the first two blocks)

    Args:
        input_dim: Input feature dimension (e.g., 80 or 320)
        channels: Base channel count (e.g., 128 or 512)
        block_layers: Layers per dense block; the first three entries are
            used. Defaults to [4, 9, 16] when None.
        embedding_dim: Output embedding dimension (e.g., 192)
        cam_channels: CAM channel count (e.g., 128)
        input_kernel_size: Input layer kernel size (e.g., 5)
    """

    def __init__(
        self,
        input_dim: int = 80,
        channels: int = 512,
        block_layers: Optional[List[int]] = None,
        embedding_dim: int = 192,
        cam_channels: int = 128,
        input_kernel_size: int = 5
    ):
        super().__init__()

        if block_layers is None:
            block_layers = [4, 9, 16]

        self.input_dim = input_dim
        self.channels = channels
        self.block_layers = block_layers
        self.embedding_dim = embedding_dim
        self.cam_channels = cam_channels
        self.growth_rate = cam_channels // 4  # Each layer adds this many channels

        # Input layer
        self.input_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=channels,
            kernel_size=input_kernel_size,
            padding=input_kernel_size // 2,
            bias=False
        )
        self.input_bn = nn.BatchNorm(channels, affine=True)
        self.input_activation = nn.ReLU()

        # Dense Block 0
        transit1_in = self._add_dense_block(0, block_layers[0], channels)

        # Transition 1 - doubles channel count
        transit1_out = channels * 2
        self.transit1 = TransitionLayer(transit1_in, transit1_out)

        # Dense Block 1 - starts with doubled channels
        transit2_in = self._add_dense_block(1, block_layers[1], transit1_out)

        # Transition 2 - doubles channel count again (4x original channels)
        transit2_out = transit1_out * 2
        self.transit2 = TransitionLayer(transit2_in, transit2_out)

        # Dense Block 2 - starts with quadrupled channels
        dense_in = self._add_dense_block(2, block_layers[2], transit2_out)

        # Final dense layer
        self.dense = nn.Conv1d(
            in_channels=dense_in,
            out_channels=embedding_dim,
            kernel_size=1,
            bias=False
        )

    def _add_dense_block(
        self,
        block_index: int,
        num_layers: int,
        in_channels: int
    ) -> int:
        """
        Register the layers of one dense block and return its output width.

        Layers are stored as attributes named ``block{block_index}_{i}`` and
        the layer count as ``_block{block_index}_size``; these names determine
        the parameter keys used when loading weights, so they must not change.
        """
        for i in range(num_layers):
            layer = TDNNLayerWithCAM(
                # Input grows by growth_rate per preceding layer (dense concat)
                in_channels=in_channels + i * self.growth_rate,
                out_channels=self.channels,
                cam_channels=self.cam_channels
            )
            setattr(self, f'block{block_index}_{i}', layer)
        setattr(self, f'_block{block_index}_size', num_layers)
        return in_channels + num_layers * self.growth_rate

    def _run_dense_block(self, block_index: int, out: mx.array) -> mx.array:
        """Run one dense block, concatenating each layer's output on the channel axis."""
        for i in range(getattr(self, f'_block{block_index}_size')):
            layer = getattr(self, f'block{block_index}_{i}')
            out = mx.concatenate([out, layer(out)], axis=2)
        return out

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass

        Args:
            x: Input (batch, length, in_channels) - channels_last format.
                A 2-D input gets a leading batch axis added. If the last axis
                does not equal ``input_dim`` the input is assumed to be
                (batch, in_channels, length) and is transposed.

        Returns:
            Embeddings (batch, length, embedding_dim)
        """
        # Handle input format
        if x.ndim == 2:
            x = mx.expand_dims(x, axis=0)

        # MLX Conv1d expects (batch, length, in_channels).
        # NOTE: this is a heuristic; a channels_first input whose length
        # happens to equal input_dim is not transposed.
        if x.shape[2] != self.input_dim:
            x = mx.transpose(x, (0, 2, 1))

        # Input layer
        out = self.input_activation(self.input_bn(self.input_conv(x)))

        # Dense blocks (with concatenation) separated by transitions
        out = self._run_dense_block(0, out)
        out = self.transit1(out)
        out = self._run_dense_block(1, out)
        out = self.transit2(out)
        out = self._run_dense_block(2, out)

        # Final dense layer
        return self.dense(out)

    def extract_embedding(self, x: mx.array, pooling: str = "mean") -> mx.array:
        """
        Extract fixed-size speaker embedding

        Args:
            x: Input (batch, length, in_channels)
            pooling: "mean", "max", or "both"

        Returns:
            Embedding (batch, embedding_dim), or (batch, 2 * embedding_dim)
            for "both" (mean and max concatenated).

        Raises:
            ValueError: If ``pooling`` is not one of the supported modes.
        """
        # Validate before running the (expensive) forward pass
        if pooling not in ("mean", "max", "both"):
            raise ValueError(f"Unknown pooling: {pooling}")

        frame_embeddings = self(x)  # (batch, length, embedding_dim)

        if pooling == "mean":
            return mx.mean(frame_embeddings, axis=1)
        if pooling == "max":
            return mx.max(frame_embeddings, axis=1)

        mean_pool = mx.mean(frame_embeddings, axis=1)
        max_pool = mx.max(frame_embeddings, axis=1)
        return mx.concatenate([mean_pool, max_pool], axis=1)

    def load_weights(self, file_or_weights, strict: bool = True):
        """
        Override load_weights to handle quantized weights with dequantization

        A weight ``name`` is treated as quantized when the checkpoint also
        contains ``name:qSCALES_GS<group>_B<bits>`` (and the matching
        ``name:qBIASES_GS<group>_B<bits>``); it is dequantized before being
        handed to the parent implementation.

        Args:
            file_or_weights: Path (str or os.PathLike) to a weights file, or
                a dict / iterable of (name, array) pairs
            strict: If True, all parameters must match exactly

        Returns:
            Whatever the parent ``nn.Module.load_weights`` returns (the module
            itself in current MLX), so calls can be chained.

        Raises:
            KeyError: If a quantized weight has scales but no matching biases.
        """
        # Load weights from file if needed
        if isinstance(file_or_weights, (str, os.PathLike)):
            loaded_weights = mx.load(os.fspath(file_or_weights))
        else:
            loaded_weights = dict(file_or_weights)

        # Index the scales entries once (base weight name -> first scales
        # key) instead of rescanning every key for every weight.
        scales_keys: Dict[str, str] = {}
        for key in loaded_weights:
            base, marker, _ = key.partition(':qSCALES_GS')
            if marker:
                scales_keys.setdefault(base, key)

        dequantized_weights = {}
        for name, array in loaded_weights.items():
            # Scales/biases entries are consumed together with their weight
            if ':qSCALES_GS' in name or ':qBIASES_GS' in name:
                continue

            scales_key = scales_keys.get(name)
            # Parse only the suffix so digits in the parameter name itself
            # can never be mistaken for the group size / bit width.
            match = (
                _QUANT_PARAMS_RE.search(scales_key, len(name))
                if scales_key is not None
                else None
            )
            if match is None:
                # Regular weight, or unparseable metadata: keep as-is
                dequantized_weights[name] = array
                continue

            group_size = int(match.group(1))
            bits = int(match.group(2))
            biases_key = f"{name}:qBIASES_GS{group_size}_B{bits}"
            if biases_key not in loaded_weights:
                raise KeyError(
                    f"Quantized weight '{name}' has scales '{scales_key}' "
                    f"but no matching biases '{biases_key}'"
                )

            dequantized_weights[name] = mx.dequantize(
                array,
                loaded_weights[scales_key],
                loaded_weights[biases_key],
                group_size=group_size,
                bits=bits
            )

        # Use the parent class load_weights with dequantized weights and
        # propagate its return value so chaining keeps working.
        return super().load_weights(
            list(dequantized_weights.items()), strict=strict
        )


def load_model(weights_path: str, config_path: Optional[str] = None) -> CAMPPModelScopeV2:
    """
    Load model from weights and config

    Args:
        weights_path: Path to the weights file readable by ``mx.load``
        config_path: Optional path to a JSON file whose keys are passed as
            keyword arguments to ``CAMPPModelScopeV2``. When omitted, a
            built-in default configuration is used.

    Returns:
        The constructed model with weights loaded (quantized weights are
        dequantized by ``CAMPPModelScopeV2.load_weights``).
    """
    if config_path:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    else:
        config = {
            'input_dim': 80,
            'channels': 512,
            'block_layers': [4, 9, 16],
            'embedding_dim': 192,
            'cam_channels': 128,
            'input_kernel_size': 5
        }

    model = CAMPPModelScopeV2(**config)
    weights = mx.load(weights_path)
    model.load_weights(list(weights.items()))

    return model