# BMP's picture
# Convert iic/speech_campplus_sv_zh_en_16k-common_advanced to MLX format
# e1ecf71 verified
"""
MLX implementation of CAM++ model - ModelScope architecture (Clean implementation)
Based on analysis of iic/speech_campplus_sv_zh_en_16k-common_advanced:
- Dense connections: each layer's output is concatenated with all previous outputs
- TDNN layers use kernel_size=1 (no temporal context in main conv)
- CAM layers provide the actual feature extraction
- Architecture: Input -> Dense Blocks (with CAM) -> Transitions -> Dense Layer
"""
import json
import re
from typing import Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
class EmbeddedCAM(nn.Module):
    """Context-aware masking block embedded inside each TDNN layer.

    Two parallel paths run over a channels-last input:
      - a global path: two 1x1 convolutions (with bias) separated by a ReLU,
        squashed through a sigmoid to form a gating mask, and
      - a local path: a single bias-free 3x3 convolution (padding keeps the
        sequence length unchanged).
    The local features are gated elementwise by the mask, producing
    ``cam_channels // 4`` output channels (e.g. 32 for cam_channels=128).
    Layer names (linear1 / linear2 / linear_local) match the ModelScope
    checkpoint so weights load directly.
    """

    def __init__(self, in_channels: int, cam_channels: int = 128):
        super().__init__()
        half = cam_channels // 2
        quarter = cam_channels // 4
        # Global path: in_channels -> half -> quarter, both pointwise, with bias.
        self.linear1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=half,
            kernel_size=1,
            bias=True,
        )
        self.linear2 = nn.Conv1d(
            in_channels=half,
            out_channels=quarter,
            kernel_size=1,
            bias=True,
        )
        # Local path: 3x3 conv without bias, same-length padding.
        self.linear_local = nn.Conv1d(
            in_channels=in_channels,
            out_channels=quarter,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Gate local 3x3 features with a sigmoid mask from the global path.

        Args:
            x: Input of shape (batch, length, in_channels), channels-last.
        Returns:
            Masked features of shape (batch, length, cam_channels // 4).
        """
        gate = nn.sigmoid(self.linear2(nn.relu(self.linear1(x))))
        return self.linear_local(x) * gate
class TDNNLayerWithCAM(nn.Module):
    """Dense-block unit: pointwise TDNN projection followed by embedded CAM.

    Pipeline: Conv1d(kernel_size=1, no bias) -> BatchNorm -> ReLU ->
    EmbeddedCAM. The projection maps the accumulated dense input down to a
    fixed ``out_channels`` width (e.g. 128); CAM then emits
    ``cam_channels // 4`` channels (e.g. 32), which is the per-layer growth
    rate of the dense connections.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
        cam_channels: int = 128,
    ):
        super().__init__()
        # Pointwise projection: no temporal context; bias omitted since a
        # BatchNorm immediately follows.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm(out_channels, affine=True)
        self.activation = nn.ReLU()
        # CAM consumes the normalized activations and reduces the width.
        self.cam = EmbeddedCAM(
            in_channels=out_channels,
            cam_channels=cam_channels,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Project, normalize, activate, then apply CAM.

        Args:
            x: Input of shape (batch, length, in_channels).
        Returns:
            CAM output of shape (batch, length, cam_channels // 4).
        """
        hidden = self.activation(self.bn(self.conv(x)))
        return self.cam(hidden)
class TransitionLayer(nn.Module):
    """Channel-reducing bridge between dense blocks.

    Pre-activation ordering: BatchNorm -> ReLU -> 1x1 Conv, squeezing the
    accumulated dense channels down to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.bn = nn.BatchNorm(in_channels, affine=True)
        self.activation = nn.ReLU()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.conv(self.activation(self.bn(x)))
class CAMPPModelScopeV2(nn.Module):
    """
    Clean CAM++ implementation matching the ModelScope architecture.

    Key features:
    - Dense connections: each layer's CAM output is concatenated onto its input
    - TDNN layers use kernel_size=1 (only the input stem has temporal context)
    - CAM provides feature extraction (outputs cam_channels // 4 per layer)
    - Transitions reduce accumulated channels and double the block base width

    Args:
        input_dim: Input feature dimension (e.g., 80 or 320)
        channels: Base channel count (e.g., 128 or 512)
        block_layers: Layers per block (defaults to [4, 9, 16])
        embedding_dim: Output embedding dimension (e.g., 192)
        cam_channels: CAM channel count (e.g., 128)
        input_kernel_size: Input layer kernel size (e.g., 5)
    """

    def __init__(
        self,
        input_dim: int = 80,
        channels: int = 512,
        block_layers: Optional[List[int]] = None,
        embedding_dim: int = 192,
        cam_channels: int = 128,
        input_kernel_size: int = 5
    ):
        super().__init__()
        if block_layers is None:
            block_layers = [4, 9, 16]
        self.input_dim = input_dim
        self.channels = channels
        self.block_layers = block_layers
        self.embedding_dim = embedding_dim
        self.cam_channels = cam_channels
        # Each TDNN+CAM layer contributes this many channels to the concat.
        self.growth_rate = cam_channels // 4

        # Input stem: the only conv with temporal context (kernel_size > 1).
        self.input_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=channels,
            kernel_size=input_kernel_size,
            padding=input_kernel_size // 2,
            bias=False
        )
        self.input_bn = nn.BatchNorm(channels, affine=True)
        self.input_activation = nn.ReLU()

        # Dense Block 0 starts at `channels`; each transition doubles the
        # base width, so blocks 1 and 2 start at 2x and 4x `channels`.
        self._build_block(0, block_layers[0], channels)

        transit1_in = channels + block_layers[0] * self.growth_rate
        transit1_out = channels * 2
        self.transit1 = TransitionLayer(transit1_in, transit1_out)

        self._build_block(1, block_layers[1], transit1_out)

        transit2_in = transit1_out + block_layers[1] * self.growth_rate
        transit2_out = transit1_out * 2  # 4x the original channel count
        self.transit2 = TransitionLayer(transit2_in, transit2_out)

        self._build_block(2, block_layers[2], transit2_out)

        # Final 1x1 projection down to the embedding dimension.
        dense_in = transit2_out + block_layers[2] * self.growth_rate
        self.dense = nn.Conv1d(
            in_channels=dense_in,
            out_channels=embedding_dim,
            kernel_size=1,
            bias=False
        )

    def _build_block(self, block_idx: int, num_layers: int,
                     base_channels: int) -> None:
        """Register `num_layers` TDNN+CAM layers as `block{block_idx}_{i}`.

        Layer i receives `base_channels + i * growth_rate` input channels
        because every previous layer's CAM output is concatenated onto the
        block input. Attribute names match the original checkpoint layout.
        """
        for i in range(num_layers):
            layer = TDNNLayerWithCAM(
                in_channels=base_channels + i * self.growth_rate,
                out_channels=self.channels,
                cam_channels=self.cam_channels
            )
            setattr(self, f'block{block_idx}_{i}', layer)
        # Stored so __call__ can iterate the block without knowing the config.
        setattr(self, f'_block{block_idx}_size', num_layers)

    def _run_block(self, block_idx: int, x: mx.array) -> mx.array:
        """Apply one dense block: concatenate each layer's CAM output onto x."""
        for i in range(getattr(self, f'_block{block_idx}_size')):
            layer = getattr(self, f'block{block_idx}_{i}')
            x = mx.concatenate([x, layer(x)], axis=2)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass.

        Args:
            x: Input (batch, length, in_channels) - channels_last format.
               A 2-D input is treated as a single unbatched utterance.
        Returns:
            Embeddings (batch, length, embedding_dim)
        """
        if x.ndim == 2:
            # Add a batch axis for a single utterance.
            x = mx.expand_dims(x, axis=0)
        if x.shape[2] != self.input_dim:
            # Accept channels-first input; MLX Conv1d wants channels-last.
            # NOTE(review): this heuristic is ambiguous when length equals
            # input_dim — callers should pass channels-last to be safe.
            x = mx.transpose(x, (0, 2, 1))

        out = self.input_activation(self.input_bn(self.input_conv(x)))

        out = self._run_block(0, out)
        out = self.transit1(out)
        out = self._run_block(1, out)
        out = self.transit2(out)
        out = self._run_block(2, out)

        return self.dense(out)

    def extract_embedding(self, x: mx.array, pooling: str = "mean") -> mx.array:
        """
        Extract a fixed-size speaker embedding by pooling over time.

        Args:
            x: Input (batch, length, in_channels)
            pooling: "mean", "max", or "both" ("both" concatenates mean and
                max pools, doubling the embedding size)
        Returns:
            Embedding (batch, embedding_dim) — or (batch, 2 * embedding_dim)
            for pooling="both".
        Raises:
            ValueError: If `pooling` is not one of the supported modes.
        """
        frame_embeddings = self(x)  # (batch, length, embedding_dim)
        if pooling == "mean":
            embedding = mx.mean(frame_embeddings, axis=1)
        elif pooling == "max":
            embedding = mx.max(frame_embeddings, axis=1)
        elif pooling == "both":
            mean_pool = mx.mean(frame_embeddings, axis=1)
            max_pool = mx.max(frame_embeddings, axis=1)
            embedding = mx.concatenate([mean_pool, max_pool], axis=1)
        else:
            raise ValueError(f"Unknown pooling: {pooling}")
        return embedding

    def load_weights(self, file_or_weights, strict: bool = True):
        """
        Load weights, transparently dequantizing MLX-quantized tensors.

        Quantized tensors are stored with sidecar arrays named
        `{name}:qSCALES_GS{group_size}_B{bits}` and
        `{name}:qBIASES_GS{group_size}_B{bits}`; these are consumed here and
        never forwarded to the parent loader. Weights whose metadata cannot
        be parsed are passed through unchanged.

        Args:
            file_or_weights: Path to a weights file loadable by `mx.load`
                (e.g. .npz / .safetensors) or an iterable of (name, array).
            strict: Forwarded to `nn.Module.load_weights`; if True, all
                parameters must match exactly.
        """
        if isinstance(file_or_weights, str):
            loaded = mx.load(file_or_weights)
        else:
            loaded = dict(file_or_weights)

        # Index the quantization sidecars in one pass. (The previous version
        # re-scanned every key per weight, which was quadratic in the number
        # of tensors, and re-imported `re` inside the loop.)
        meta_pattern = re.compile(r':qSCALES_GS(\d+)_B(\d+)$')
        scales_meta = {}  # base weight name -> (scales_key, group_size, bits)
        for key in loaded:
            match = meta_pattern.search(key)
            if match:
                scales_meta[key[:match.start()]] = (
                    key, int(match.group(1)), int(match.group(2))
                )

        weights = {}
        for name, array in loaded.items():
            if ':qSCALES_GS' in name or ':qBIASES_GS' in name:
                # Sidecar metadata — consumed via its base weight below.
                continue
            meta = scales_meta.get(name)
            if meta is None:
                # Regular (unquantized) tensor, or unparseable metadata:
                # pass through unchanged.
                weights[name] = array
                continue
            scales_key, group_size, bits = meta
            biases = loaded[f"{name}:qBIASES_GS{group_size}_B{bits}"]
            weights[name] = mx.dequantize(
                array, loaded[scales_key], biases,
                group_size=group_size, bits=bits
            )

        # Hand the fully-dequantized mapping to the stock MLX loader.
        super().load_weights(list(weights.items()), strict=strict)
def load_model(weights_path: str, config_path: Optional[str] = None) -> CAMPPModelScopeV2:
    """Instantiate a CAMPPModelScopeV2 and populate it with saved weights.

    Args:
        weights_path: Path to the saved weights (read via ``mx.load``).
        config_path: Optional JSON file supplying constructor kwargs; when
            omitted, the stock 512-channel configuration is used.
    Returns:
        The constructed model with weights loaded.
    """
    if config_path:
        with open(config_path, 'r') as f:
            config = json.load(f)
    else:
        config = {
            'input_dim': 80,
            'channels': 512,
            'block_layers': [4, 9, 16],
            'embedding_dim': 192,
            'cam_channels': 128,
            'input_kernel_size': 5
        }
    model = CAMPPModelScopeV2(**config)
    model.load_weights(list(mx.load(weights_path).items()))
    return model