|
|
""" |
|
|
MLX implementation of CAM++ model - ModelScope architecture (Clean implementation) |
|
|
|
|
|
Based on analysis of iic/speech_campplus_sv_zh_en_16k-common_advanced: |
|
|
- Dense connections: each layer's output is concatenated with all previous outputs |
|
|
- TDNN layers use kernel_size=1 (no temporal context in main conv) |
|
|
- CAM layers provide the actual feature extraction |
|
|
- Architecture: Input β Dense Blocks (with CAM) β Transitions β Dense Layer |
|
|
""" |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx.nn as nn |
|
|
from typing import Dict, List, Optional |
|
|
import json |
|
|
|
|
|
|
|
|
class EmbeddedCAM(nn.Module):
    """
    Context-Aware Masking module embedded within TDNN layers.

    Architecture (verified from ModelScope weights):
    - linear1: pointwise Conv1d (in_channels -> cam_channels//2), with bias
    - linear2: pointwise Conv1d (cam_channels//2 -> cam_channels//4), with bias
    - linear_local: kernel-size-3 Conv1d (in_channels -> cam_channels//4), no bias
    - Output: cam_channels//4 channels (e.g. 32 when cam_channels=128)
    """

    def __init__(self, in_channels: int, cam_channels: int = 128):
        super().__init__()

        half = cam_channels // 2
        quarter = cam_channels // 4

        # Two pointwise convolutions produce the per-frame mask logits.
        self.linear1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=half,
            kernel_size=1,
            bias=True
        )
        self.linear2 = nn.Conv1d(
            in_channels=half,
            out_channels=quarter,
            kernel_size=1,
            bias=True
        )

        # Kernel-size-3 convolution extracts the local features that get masked.
        self.linear_local = nn.Conv1d(
            in_channels=in_channels,
            out_channels=quarter,
            kernel_size=3,
            padding=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Apply context-aware masking.

        Args:
            x: Input (batch, length, in_channels) - channels_last format

        Returns:
            Output (batch, length, cam_channels//4)
        """
        # Mask branch: linear1 -> ReLU -> linear2 -> sigmoid.
        mask = nn.sigmoid(self.linear2(nn.relu(self.linear1(x))))

        # Local-feature branch, gated element-wise by the mask.
        return self.linear_local(x) * mask
|
|
|
|
|
|
|
|
class TDNNLayerWithCAM(nn.Module):
    """
    TDNN layer with embedded CAM (verified architecture).

    Flow:
    1. Main conv: kernel_size=1 (channel projection)
    2. BatchNorm
    3. ReLU
    4. CAM: extracts features and outputs cam_channels//4

    Note: the main conv projects to a fixed channel size (e.g. 128),
    then CAM reduces to cam_channels//4 (e.g. 32) for the dense connection.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
        cam_channels: int = 128
    ):
        super().__init__()

        # Pointwise projection onto the fixed internal width.
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding=0,
            bias=False
        )
        self.bn = nn.BatchNorm(out_channels, affine=True)
        self.activation = nn.ReLU()

        # CAM produces this layer's dense-connection contribution.
        self.cam = EmbeddedCAM(
            in_channels=out_channels,
            cam_channels=cam_channels
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass.

        Args:
            x: Input (batch, length, in_channels)

        Returns:
            CAM output (batch, length, cam_channels//4)
        """
        hidden = self.activation(self.bn(self.conv(x)))
        return self.cam(hidden)
|
|
|
|
|
|
|
|
class TransitionLayer(nn.Module):
    """
    Transition layer between dense blocks.

    Reduces the accumulated channels back to the base channel count.
    Architecture: BatchNorm -> ReLU -> 1x1 Conv
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.bn = nn.BatchNorm(in_channels, affine=True)
        self.activation = nn.ReLU()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        return self.conv(self.activation(self.bn(x)))
|
|
|
|
|
|
|
|
class CAMPPModelScopeV2(nn.Module):
    """
    Clean CAM++ implementation matching ModelScope architecture.

    Key features:
    - Dense connections: each layer's output is concatenated
    - TDNN layers use kernel_size=1
    - CAM provides feature extraction (outputs cam_channels//4 per layer)
    - Transitions reduce accumulated channels back to base

    Layer attributes are named 'block{N}_{i}' / 'transit{N}' / 'dense' so that
    parameter keys line up with the converted ModelScope checkpoint.

    Args:
        input_dim: Input feature dimension (e.g., 80 or 320)
        channels: Base channel count (e.g., 128 or 512)
        block_layers: Layers per block (e.g., [12, 24, 16]); defaults to [4, 9, 16]
        embedding_dim: Output embedding dimension (e.g., 192)
        cam_channels: CAM channel count (e.g., 128)
        input_kernel_size: Input layer kernel size (e.g., 5)
    """

    def __init__(
        self,
        input_dim: int = 80,
        channels: int = 512,
        block_layers: Optional[List[int]] = None,
        embedding_dim: int = 192,
        cam_channels: int = 128,
        input_kernel_size: int = 5
    ):
        super().__init__()

        if block_layers is None:
            block_layers = [4, 9, 16]

        self.input_dim = input_dim
        self.channels = channels
        self.block_layers = block_layers
        self.embedding_dim = embedding_dim
        self.cam_channels = cam_channels
        # Each TDNN+CAM layer contributes cam_channels//4 channels to the dense stack.
        self.growth_rate = cam_channels // 4

        # Input stem: wide-kernel conv -> BN -> ReLU.
        self.input_conv = nn.Conv1d(
            in_channels=input_dim,
            out_channels=channels,
            kernel_size=input_kernel_size,
            padding=input_kernel_size // 2,
            bias=False
        )
        self.input_bn = nn.BatchNorm(channels, affine=True)
        self.input_activation = nn.ReLU()

        # Dense block 0.
        block0_out = self._build_block('block0', block_layers[0], channels)
        self._block0_size = block_layers[0]

        # Transition 1: compress accumulated channels, widen to 2x base.
        transit1_out = channels * 2
        self.transit1 = TransitionLayer(block0_out, transit1_out)

        # Dense block 1.
        block1_out = self._build_block('block1', block_layers[1], transit1_out)
        self._block1_size = block_layers[1]

        # Transition 2: widen to 4x base.
        transit2_out = transit1_out * 2
        self.transit2 = TransitionLayer(block1_out, transit2_out)

        # Dense block 2.
        block2_out = self._build_block('block2', block_layers[2], transit2_out)
        self._block2_size = block_layers[2]

        # Final pointwise projection to the embedding dimension.
        self.dense = nn.Conv1d(
            in_channels=block2_out,
            out_channels=embedding_dim,
            kernel_size=1,
            bias=False
        )

    def _build_block(self, prefix: str, num_layers: int, in_channels: int) -> int:
        """Create a dense block of TDNN+CAM layers as attributes f'{prefix}_{i}'.

        Layer i consumes the block input plus all previous layers' outputs
        (in_channels + i * growth_rate channels). Returns the channel count
        after the full dense concatenation.
        """
        for i in range(num_layers):
            layer = TDNNLayerWithCAM(
                in_channels=in_channels + i * self.growth_rate,
                out_channels=self.channels,
                cam_channels=self.cam_channels
            )
            setattr(self, f'{prefix}_{i}', layer)
        return in_channels + num_layers * self.growth_rate

    def _run_block(self, prefix: str, num_layers: int, x: mx.array) -> mx.array:
        """Run one dense block: concatenate each layer's output on the channel axis."""
        for i in range(num_layers):
            layer = getattr(self, f'{prefix}_{i}')
            x = mx.concatenate([x, layer(x)], axis=2)
        return x

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass.

        Args:
            x: Input (batch, length, in_channels) - channels_last format

        Returns:
            Embeddings (batch, length, embedding_dim)
        """
        # Accept unbatched (length, features) input.
        if x.ndim == 2:
            x = mx.expand_dims(x, axis=0)

        # Heuristic layout fix: transpose channels-first input to channels-last.
        # NOTE(review): ambiguous when length == input_dim — confirm callers
        # always pass channels_last.
        if x.shape[2] != self.input_dim:
            x = mx.transpose(x, (0, 2, 1))

        out = self.input_conv(x)
        out = self.input_bn(out)
        out = self.input_activation(out)

        out = self._run_block('block0', self._block0_size, out)
        out = self.transit1(out)
        out = self._run_block('block1', self._block1_size, out)
        out = self.transit2(out)
        out = self._run_block('block2', self._block2_size, out)

        return self.dense(out)

    def extract_embedding(self, x: mx.array, pooling: str = "mean") -> mx.array:
        """
        Extract a fixed-size speaker embedding by pooling over time.

        Args:
            x: Input (batch, length, in_channels)
            pooling: "mean", "max", or "both" ("both" concatenates mean and max,
                doubling the embedding dimension)

        Returns:
            Embedding (batch, embedding_dim), or (batch, 2 * embedding_dim)
            when pooling == "both"

        Raises:
            ValueError: If `pooling` is not one of the supported modes.
        """
        frame_embeddings = self(x)

        if pooling == "mean":
            return mx.mean(frame_embeddings, axis=1)
        if pooling == "max":
            return mx.max(frame_embeddings, axis=1)
        if pooling == "both":
            mean_pool = mx.mean(frame_embeddings, axis=1)
            max_pool = mx.max(frame_embeddings, axis=1)
            return mx.concatenate([mean_pool, max_pool], axis=1)
        raise ValueError(f"Unknown pooling: {pooling}")

    def load_weights(self, file_or_weights, strict: bool = True):
        """
        Override load_weights to handle quantized weights with dequantization.

        Quantized tensors are stored as '<name>' (packed values) plus sidecar
        tensors '<name>:qSCALES_GS<g>_B<b>' and '<name>:qBIASES_GS<g>_B<b>';
        sidecars are consumed here and never passed to the parent loader.

        Args:
            file_or_weights: Path to a weights file loadable by mx.load, or a
                list of (name, array) tuples
            strict: If True, all parameters must match exactly

        Raises:
            ValueError: If a quantization sidecar key does not encode its
                group size and bit width.
        """
        import re

        if isinstance(file_or_weights, str):
            loaded_weights = mx.load(file_or_weights)
        else:
            loaded_weights = dict(file_or_weights)

        # Map each quantized tensor name to its scales key in a single pass.
        # (Previously every tensor rescanned all keys, i.e. O(n^2).)
        scales_keys: Dict[str, str] = {}
        for key in loaded_weights:
            base, sep, _ = key.partition(':qSCALES_GS')
            if sep:
                scales_keys[base] = key

        meta_pattern = re.compile(r'GS(\d+)_B(\d+)')
        dequantized_weights = {}
        for name, array in loaded_weights.items():
            # Sidecar tensors are consumed alongside their packed tensor.
            if ':qSCALES_GS' in name or ':qBIASES_GS' in name:
                continue

            scales_key = scales_keys.get(name)
            if scales_key is None:
                # Plain (unquantized) tensor: pass through untouched.
                dequantized_weights[name] = array
                continue

            match = meta_pattern.search(scales_key)
            if match is None:
                # A malformed sidecar key would otherwise cause the packed
                # integer data to be loaded as float weights — fail loudly.
                raise ValueError(
                    f"Cannot parse quantization metadata from key: {scales_key}"
                )

            group_size = int(match.group(1))
            bits = int(match.group(2))
            biases_key = f"{name}:qBIASES_GS{group_size}_B{bits}"
            dequantized_weights[name] = mx.dequantize(
                array,
                loaded_weights[scales_key],
                loaded_weights[biases_key],
                group_size=group_size,
                bits=bits
            )

        super().load_weights(list(dequantized_weights.items()), strict=strict)
|
|
|
|
|
|
|
|
def load_model(weights_path: str, config_path: Optional[str] = None) -> CAMPPModelScopeV2:
    """Load model from weights and config"""
    # Defaults match iic/speech_campplus_sv_zh_en_16k-common_advanced.
    config = {
        'input_dim': 80,
        'channels': 512,
        'block_layers': [4, 9, 16],
        'embedding_dim': 192,
        'cam_channels': 128,
        'input_kernel_size': 5
    }
    if config_path:
        with open(config_path, 'r') as f:
            config = json.load(f)

    model = CAMPPModelScopeV2(**config)
    model.load_weights(list(mx.load(weights_path).items()))
    return model
|
|
|