# campp-mlx-converter / conversion_utils.py
# feat: Add batch conversion scripts for CAM++ models (commit 656e7f6)
import mlx.core as mx
import numpy as np
import torch
from typing import Dict, Any, Tuple, Optional, List, Set
from datetime import datetime
import re
import logging
# Set up logging
logger = logging.getLogger(__name__)
# Constants for conversion thresholds
MIN_QUANTIZATION_SIZE = 1000 # Don't quantize tensors smaller than this (element count)
MIN_VERIFICATION_RATE = 95.0 # Minimum acceptable verification rate (%)
MAX_VERIFICATION_FAILURES = 2 # Maximum allowed verification failures
BATCHNORM_EPS = 1e-5 # NOTE(review): matches PyTorch BatchNorm1d default eps; not referenced in this chunk -- confirm intended consumer
BATCHNORM_MOMENTUM = 0.1 # NOTE(review): matches PyTorch BatchNorm1d default momentum; not referenced in this chunk
class ConversionUtils:
"""Utilities for converting PyTorch CAM++ models to MLX format"""
def __init__(self, use_modelscope_architecture: bool = True):
"""
Initialize conversion utilities
Args:
use_modelscope_architecture: If True, use ModelScope architecture with embedded CAM
If False, use original architecture with shared CAM
"""
self.use_modelscope_architecture = use_modelscope_architecture
self.layer_mapping = {
'conv1d': self._convert_conv1d,
'linear': self._convert_linear,
'batchnorm': self._convert_batchnorm,
'embedding': self._convert_embedding
}
def convert_weights_to_mlx(self, pytorch_weights: Dict[str, torch.Tensor]) -> Tuple[Dict[str, mx.array], Dict[str, Any]]:
"""
Convert PyTorch weights to MLX format
Args:
pytorch_weights: Dictionary of PyTorch tensors
Returns:
Tuple of (mlx_weights, model_config)
"""
mlx_weights = {}
model_config = self._analyze_model_structure(pytorch_weights)
# Filter out unnecessary parameters (BatchNorm running stats, etc.)
filtered_weights = self._filter_weights(pytorch_weights)
# Map parameter names from PyTorch to MLX format
mapped_weights = self._map_parameter_names(filtered_weights)
# Add default values for missing MLX parameters
mapped_weights = self._add_missing_parameters(mapped_weights, model_config)
# Convert each weight tensor
for name, tensor in mapped_weights.items():
if isinstance(tensor, torch.Tensor):
converted = self._convert_tensor(name, tensor)
# Skip None values (e.g., num_batches_tracked)
if converted is not None:
mlx_weights[name] = converted
else:
# Handle non-tensor values (e.g., integers, strings)
continue
return mlx_weights, model_config
def _analyze_model_structure(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, Any]:
"""
Analyze the PyTorch model structure to infer configuration
Args:
pytorch_weights: PyTorch weights dictionary
Returns:
Model configuration dictionary
"""
config = {
'input_dim': 80, # Default mel spectrogram features
'embedding_dim': 192, # Default embedding dimension for ModelScope
'channels': 512, # Default number of channels
'cam_channels': 128, # Default CAM channels
}
# Detect block structure for ModelScope architecture
if self.use_modelscope_architecture:
blocks = {1: set(), 2: set(), 3: set()}
for name in pytorch_weights.keys():
if 'xvector.block1.tdnnd' in name:
layer_num = name.split('tdnnd')[1].split('.')[0]
blocks[1].add(int(layer_num))
elif 'xvector.block2.tdnnd' in name:
layer_num = name.split('tdnnd')[1].split('.')[0]
blocks[2].add(int(layer_num))
elif 'xvector.block3.tdnnd' in name:
layer_num = name.split('tdnnd')[1].split('.')[0]
blocks[3].add(int(layer_num))
# Set block_layers configuration
if any(blocks.values()):
config['block_layers'] = [
len(blocks[1]) if blocks[1] else 4, # Default to 4 if not found
len(blocks[2]) if blocks[2] else 9, # Default to 9 if not found
len(blocks[3]) if blocks[3] else 16 # Default to 16 if not found
]
logger.info(f"Detected block structure: {config['block_layers']}")
# Try to infer input dimension and kernel size from first conv layer
for name, tensor in pytorch_weights.items():
if 'xvector.tdnn.linear.weight' in name:
if tensor.ndim == 3: # Conv1d weight: (out_channels, in_channels, kernel_size)
config['input_dim'] = tensor.shape[1] # in_channels
config['channels'] = tensor.shape[0] # out_channels
config['input_kernel_size'] = tensor.shape[2] # kernel_size
logger.info(f"Detected input layer: dim={config['input_dim']}, channels={config['channels']}, kernel_size={config['input_kernel_size']}")
break
# Try to infer embedding dimension from dense layer
for name, tensor in pytorch_weights.items():
if 'xvector.dense.linear.weight' in name:
if tensor.ndim == 3: # Conv1d with kernel_size=1
config['embedding_dim'] = tensor.shape[0] # out_channels
break
# Count total parameters for estimation
total_params = sum(tensor.numel() for tensor in pytorch_weights.values())
config['total_params'] = total_params
return config
def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Map PyTorch parameter names to MLX parameter names
Args:
pytorch_weights: PyTorch weights with original names
Returns:
Weights with MLX-compatible parameter names
"""
mapped_weights = {}
for name, tensor in pytorch_weights.items():
# Choose mapping function based on architecture
if self.use_modelscope_architecture:
mlx_name = self._xvector_to_mlx_modelscope_name(name)
else:
mlx_name = self._xvector_to_mlx_name(name)
if mlx_name: # Only keep parameters that have MLX equivalents
mapped_weights[mlx_name] = tensor
return mapped_weights
def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]:
"""
Add default values for MLX parameters that don't have PyTorch equivalents
Args:
mapped_weights: Already mapped weights
model_config: Model configuration
Returns:
Weights with missing parameters added
Note: This method intentionally does NOT add fake/random parameters.
Adding untrained random weights will degrade model accuracy significantly.
The conversion should only include weights that are actually mapped from
the source model. Better to fail explicitly when a layer is missing than
to add random weights that produce nonsensical outputs.
"""
# Return mapped weights as-is without adding arbitrary fake parameters
return mapped_weights
def get_missing_mlx_parameters(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, str]:
"""
Get list of MLX parameters that don't have source PyTorch equivalents
Args:
pytorch_weights: Original PyTorch weights
mlx_weights: Converted MLX weights
Returns:
Dictionary mapping MLX parameter names to their source parameter names (or "NOT FOUND")
"""
missing_params = {}
# Define expected MLX model parameters
expected_mlx_params = {
# Input layer
'input_conv.weight', 'input_bn.weight', 'input_bn.bias',
'input_bn.running_mean', 'input_bn.running_var',
# Dense blocks (0-2)
'dense_blocks.0.layers.0.conv.weight', 'dense_blocks.0.layers.0.bn.weight', 'dense_blocks.0.layers.0.bn.bias',
'dense_blocks.0.layers.0.bn.running_mean', 'dense_blocks.0.layers.0.bn.running_var',
'dense_blocks.0.layers.1.conv.weight', 'dense_blocks.0.layers.1.bn.weight', 'dense_blocks.0.layers.1.bn.bias',
'dense_blocks.0.layers.2.conv.weight', 'dense_blocks.0.layers.2.bn.weight', 'dense_blocks.0.layers.2.bn.bias',
'dense_blocks.0.layers.3.conv.weight', 'dense_blocks.0.layers.3.bn.weight', 'dense_blocks.0.layers.3.bn.bias',
# Transitions
'transitions.0.layers.0.weight', 'transitions.0.layers.0.bias',
'transitions.0.layers.0.running_mean', 'transitions.0.layers.0.running_var',
'transitions.0.layers.2.weight',
'transitions.1.layers.0.weight', 'transitions.1.layers.0.bias',
'transitions.1.layers.0.running_mean', 'transitions.1.layers.0.running_var',
'transitions.1.layers.2.weight',
# CAM layer
'cam.context_conv1.weight', 'cam.context_conv1.bias',
'cam.context_conv3.weight', 'cam.context_conv3.bias',
'cam.context_conv5.weight', 'cam.context_conv5.bias',
'cam.mask_conv.weight', 'cam.mask_conv.bias',
'cam.bn.weight', 'cam.bn.bias', 'cam.bn.running_mean', 'cam.bn.running_var',
# Channel gating
'channel_gating.fc.layers.0.weight', 'channel_gating.fc.layers.0.bias',
'channel_gating.fc.layers.1.weight', 'channel_gating.fc.layers.1.bias',
'channel_gating.fc.layers.2.weight', 'channel_gating.fc.layers.2.bias',
# Pooling
'pooling.attention_weights.weight', 'pooling.attention_weights.bias',
'pooling.projection.weight', 'pooling.projection.bias',
# Final layer
'final_bn.weight', 'final_bn.bias', 'final_bn.running_mean', 'final_bn.running_var',
}
# Check which expected parameters are missing from converted weights
for param in expected_mlx_params:
if param not in mlx_weights:
missing_params[param] = "NOT FOUND"
return missing_params
def _xvector_to_mlx_modelscope_name(self, xvector_name: str) -> Optional[str]:
"""
Convert xvector parameter name to MLX ModelScope architecture parameter name
This mapping is for ModelScope CAM++ models where CAM is embedded in each TDNN layer.
Architecture:
- Input layer (TDNN)
- Block 1: 4 TDNN layers with embedded CAM
- Transit 1
- Block 2: 9 TDNN layers with embedded CAM
- Transit 2
- Block 3: 16 TDNN layers with embedded CAM
- Dense layer (Conv1d kernel_size=1)
Args:
xvector_name: Original xvector parameter name from PyTorch model
Returns:
MLX-compatible parameter name, or None if parameter should be skipped
"""
# ========== INPUT LAYER ==========
if xvector_name == 'xvector.tdnn.linear.weight':
return 'input_conv.weight'
elif 'xvector.tdnn.nonlinear.batchnorm' in xvector_name:
param_type = xvector_name.split('.')[-1] # bias, weight, running_mean, running_var
# Skip num_batches_tracked (PyTorch tracking statistic, not needed)
if param_type == 'num_batches_tracked':
return None
return f'input_bn.{param_type}'
# ========== DENSE BLOCKS WITH EMBEDDED CAM ==========
# Extract block number and layer number
import re
block_match = re.match(r'xvector\.block(\d+)\.tdnnd(\d+)\.(.*)', xvector_name)
if block_match:
block_num = int(block_match.group(1)) # 1, 2, or 3
layer_num = int(block_match.group(2)) # 1-indexed
param_path = block_match.group(3)
# Map to MLX block index (0, 1, 2)
mlx_block_idx = block_num - 1
# Map to MLX layer index (0-indexed)
mlx_layer_idx = layer_num - 1
# Main TDNN layer parameters
if param_path.startswith('linear1.'):
param_type = param_path.split('.')[-1]
return f'block{mlx_block_idx}_{mlx_layer_idx}.conv.{param_type}'
# PyTorch has TWO batch norms per layer:
# - nonlinear1.batchnorm: sized for INPUT channels (applied before conv)
# - nonlinear2.batchnorm: sized for OUTPUT channels (applied after conv)
# MLX model only has one BN (after conv), so map nonlinear2 to bn
elif param_path.startswith('nonlinear1.batchnorm.'):
# Skip nonlinear1 batch norm - it's sized for input channels
return None
elif param_path.startswith('nonlinear2.batchnorm.'):
param_type = param_path.split('.')[-1]
# Skip num_batches_tracked
if param_type == 'num_batches_tracked':
return None
return f'block{mlx_block_idx}_{mlx_layer_idx}.bn.{param_type}'
# Embedded CAM layer parameters
elif param_path.startswith('cam_layer.linear1.'):
param_type = param_path.split('.')[-1]
return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear1.{param_type}'
elif param_path.startswith('cam_layer.linear2.'):
param_type = param_path.split('.')[-1]
return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear2.{param_type}'
elif param_path.startswith('cam_layer.linear_local.'):
param_type = param_path.split('.')[-1]
return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear_local.{param_type}'
# ========== TRANSITION LAYERS ==========
if 'xvector.transit1.' in xvector_name:
if '.linear.weight' in xvector_name:
return 'transit1.conv.weight'
elif 'nonlinear.batchnorm' in xvector_name:
param_type = xvector_name.split('.')[-1]
# Skip num_batches_tracked
if param_type == 'num_batches_tracked':
return None
return f'transit1.bn.{param_type}'
if 'xvector.transit2.' in xvector_name:
if '.linear.weight' in xvector_name:
return 'transit2.conv.weight'
elif 'nonlinear.batchnorm' in xvector_name:
param_type = xvector_name.split('.')[-1]
# Skip num_batches_tracked
if param_type == 'num_batches_tracked':
return None
return f'transit2.bn.{param_type}'
# ========== DENSE LAYER ==========
if 'xvector.dense.linear.' in xvector_name:
param_type = xvector_name.split('.')[-1]
return f'dense.{param_type}'
# ========== SKIP UNMAPPED PARAMETERS ==========
# These don't exist in ModelScope architecture
if any(x in xvector_name for x in ['head.', 'output.', 'pool', 'final_bn']):
logger.debug(f"Skipping parameter not in ModelScope architecture: {xvector_name}")
return None
# Log unexpected parameters
if xvector_name.startswith('xvector.'):
logger.debug(f"Skipping unmapped parameter: {xvector_name}")
return None
def _xvector_to_mlx_name(self, xvector_name: str) -> Optional[str]:
"""
Convert xvector parameter name to MLX parameter name with comprehensive mapping
This method maps PyTorch CAM++ xvector parameters to MLX CAMPPModel parameters.
It handles:
- Input layer (TDNN)
- Dense blocks (3 blocks with 4, 6, 8 layers respectively)
- Transition layers between blocks
- Context-Aware Masking (CAM) layer
- Channel gating mechanism
- Multi-granularity pooling
- Final batch normalization
Args:
xvector_name: Original xvector parameter name from PyTorch model
Returns:
MLX-compatible parameter name, or None if parameter should be skipped
"""
# ========== INPUT LAYER MAPPING ==========
if xvector_name == 'xvector.tdnn.linear.weight':
return 'input_conv.weight'
elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
return 'input_bn.bias'
elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
return 'input_bn.weight'
elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
return 'input_bn.running_mean'
elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
return 'input_bn.running_var'
# ========== DENSE BLOCKS MAPPING ==========
# MLX architecture: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
# Map PyTorch block1/block2/block3 to MLX dense_blocks.0/1/2
# Block 0: Map first 4 layers of PyTorch block1
for i in range(1, 13): # Handle up to 12 layers (generous for real models)
# Block 0 - first 4 layers
if i <= 4 and f'xvector.block1.tdnnd{i}.' in xvector_name:
layer_idx = i - 1
if '.linear1.weight' in xvector_name:
return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
elif '.nonlinear1.batchnorm.bias' in xvector_name:
return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
elif '.nonlinear1.batchnorm.weight' in xvector_name:
return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
elif '.nonlinear1.batchnorm.running_var' in xvector_name:
return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'
# Block 1 - first 6 layers of PyTorch block2
# Skip block2.tdnnd1 and block2.tdnnd2 as they may be used for transition
if i >= 3 and i <= 8 and f'xvector.block2.tdnnd{i}.' in xvector_name:
layer_idx = i - 3 # Map block2.tdnnd3 -> layer 0, etc.
if layer_idx < 6: # Only map first 6 layers
if '.linear1.weight' in xvector_name:
return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
elif '.nonlinear1.batchnorm.bias' in xvector_name:
return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
elif '.nonlinear1.batchnorm.weight' in xvector_name:
return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
elif '.nonlinear1.batchnorm.running_var' in xvector_name:
return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'
# Block 2 - first 8 layers of PyTorch block3
if i <= 8 and f'xvector.block3.tdnnd{i}.' in xvector_name:
layer_idx = i - 1
if '.linear1.weight' in xvector_name:
return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
elif '.nonlinear1.batchnorm.bias' in xvector_name:
return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
elif '.nonlinear1.batchnorm.weight' in xvector_name:
return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
elif '.nonlinear1.batchnorm.running_var' in xvector_name:
return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'
# ========== TRANSITION LAYERS MAPPING ==========
# Transition 0: After block 0
if 'xvector.transit1.' in xvector_name:
if '.linear.weight' in xvector_name:
return 'transitions.0.layers.2.weight'
elif '.nonlinear.batchnorm.bias' in xvector_name:
return 'transitions.0.layers.0.bias'
elif '.nonlinear.batchnorm.weight' in xvector_name:
return 'transitions.0.layers.0.weight'
elif '.nonlinear.batchnorm.running_mean' in xvector_name:
return 'transitions.0.layers.0.running_mean'
elif '.nonlinear.batchnorm.running_var' in xvector_name:
return 'transitions.0.layers.0.running_var'
# Transition 1: Use block2.tdnnd1 and tdnnd2 (before dense block 1)
if 'xvector.transit2.' in xvector_name or 'xvector.block2.tdnnd1.' in xvector_name:
# Map transit2 or beginning of block2 to transition 1
if '.linear.weight' in xvector_name or 'xvector.block2.tdnnd2.linear1.weight' in xvector_name:
return 'transitions.1.layers.2.weight'
elif '.nonlinear.batchnorm.bias' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.bias' in xvector_name:
return 'transitions.1.layers.0.bias'
elif '.nonlinear.batchnorm.weight' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight' in xvector_name:
return 'transitions.1.layers.0.weight'
elif '.nonlinear.batchnorm.running_mean' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean' in xvector_name:
return 'transitions.1.layers.0.running_mean'
elif '.nonlinear.batchnorm.running_var' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var' in xvector_name:
return 'transitions.1.layers.0.running_var'
# ========== CAM LAYER MAPPING ==========
# Context-aware masking with multi-scale convolutions
# NOTE: Real ModelScope models have CAM embedded in EACH TDNN layer,
# but MLX model has ONE shared CAM layer. We map only the first occurrence
# from block1.tdnnd1.cam_layer and skip all others.
if 'cam_layer' in xvector_name or 'cam.' in xvector_name:
# Only map CAM from the first block's first layer
# Skip CAM from all other layers to avoid conflicts
is_first_cam = 'block1.tdnnd1.cam_layer' in xvector_name
if not is_first_cam:
logger.debug(f"Skipping embedded CAM layer (only using first occurrence): {xvector_name}")
return None
# Map first CAM layer to MLX shared CAM
# ModelScope structure: linear1 (1x1 conv), linear2 (1x1 conv), linear_local (3x3 conv)
# MLX structure: context_conv1 (1x1), context_conv3 (3x3), context_conv5 (5x5)
if 'cam_layer.linear1.weight' in xvector_name:
return 'cam.context_conv1.weight'
elif 'cam_layer.linear1.bias' in xvector_name:
logger.debug(f"Skipping CAM context_conv1 bias (MLX uses bias=False): {xvector_name}")
return None # MLX context_conv1 has bias=False
elif 'cam_layer.linear2.weight' in xvector_name:
# Map linear2 (1x1) to context_conv3 - note: this is a compromise
# Real model has 1x1 conv here, MLX expects 3x3
logger.warning(f"Mapping 1x1 conv to context_conv3 (shape mismatch possible): {xvector_name}")
return 'cam.context_conv3.weight'
elif 'cam_layer.linear2.bias' in xvector_name:
logger.debug(f"Skipping CAM context_conv3 bias (MLX uses bias=False): {xvector_name}")
return None # MLX context_conv3 has bias=False
elif 'cam_layer.linear_local.weight' in xvector_name:
# Map linear_local (3x3) to mask_conv
return 'cam.mask_conv.weight'
elif 'cam_layer.linear_local.bias' in xvector_name:
# linear_local typically has no bias in ModelScope models
logger.debug(f"Skipping CAM mask_conv bias: {xvector_name}")
return None
# Handle standalone cam. parameters (if model has separate CAM layer)
elif 'context1.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name):
return 'cam.context_conv1.weight'
elif 'context3.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name):
return 'cam.context_conv3.weight'
elif 'context5.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear3.weight' in xvector_name):
return 'cam.context_conv5.weight'
elif 'mask_conv.weight' in xvector_name:
return 'cam.mask_conv.weight'
elif 'fusion.weight' in xvector_name:
return 'cam.fusion.weight'
# Batch normalization
elif 'batchnorm.weight' in xvector_name:
return 'cam.bn.weight'
elif 'batchnorm.bias' in xvector_name:
return 'cam.bn.bias'
elif 'running_mean' in xvector_name:
return 'cam.bn.running_mean'
elif 'running_var' in xvector_name:
return 'cam.bn.running_var'
# ========== CHANNEL GATING MAPPING ==========
# Channel-wise context gating (squeeze-excitation style)
# NOTE: Real ModelScope models only have xvector.dense.linear (single layer)
# MLX model expects 3-layer FC, but real model has only 1 layer
if 'xvector.dense.' in xvector_name:
if '.linear.weight' in xvector_name or 'xvector.dense.linear.weight' == xvector_name:
# Map to first layer - this is the only dense layer in real model
return 'channel_gating.fc.layers.0.weight'
elif '.linear.bias' in xvector_name or 'xvector.dense.linear.bias' == xvector_name:
# Check if bias exists (some models use Conv1d without bias)
logger.debug(f"Mapping dense bias (may not exist in Conv1d): {xvector_name}")
return 'channel_gating.fc.layers.0.bias'
# The following layers don't exist in real ModelScope models
elif 'linear_mid.weight' in xvector_name:
logger.warning(f"Found linear_mid layer (unexpected in ModelScope model): {xvector_name}")
return 'channel_gating.fc.layers.1.weight'
elif 'linear_mid.bias' in xvector_name:
return 'channel_gating.fc.layers.1.bias'
elif 'linear_out.weight' in xvector_name:
logger.warning(f"Found linear_out layer (unexpected in ModelScope model): {xvector_name}")
return 'channel_gating.fc.layers.2.weight'
elif 'linear_out.bias' in xvector_name:
return 'channel_gating.fc.layers.2.bias'
# ========== POOLING LAYER MAPPING ==========
# Multi-granularity statistical pooling
# NOTE: Real ModelScope models typically DON'T have xvector.output or pooling layers
# These models are feature extractors that end at xvector.dense
if 'xvector.output.' in xvector_name or 'xvector.pool' in xvector_name:
logger.warning(f"Found pooling/output layer (rare in ModelScope models): {xvector_name}")
if 'xvector.output.linear.weight' == xvector_name:
return 'pooling.attention_weights.weight'
elif 'xvector.output.linear.bias' == xvector_name:
return 'pooling.attention_weights.bias'
elif 'pool_output.linear.weight' in xvector_name or 'pooling.linear.weight' in xvector_name:
return 'pooling.projection.weight'
elif 'pool_output.linear.bias' in xvector_name or 'pooling.linear.bias' in xvector_name:
return 'pooling.projection.bias'
# ========== FINAL BATCH NORMALIZATION ==========
if 'xvector.out_nonlinear.batchnorm.' in xvector_name or 'xvector.final_bn.' in xvector_name:
if '.bias' in xvector_name:
return 'final_bn.bias'
elif '.weight' in xvector_name:
return 'final_bn.weight'
elif 'running_mean' in xvector_name:
return 'final_bn.running_mean'
elif 'running_var' in xvector_name:
return 'final_bn.running_var'
# ========== SKIP UNMAPPED PARAMETERS ==========
# Log skipped parameters for debugging
if xvector_name.startswith('xvector.'):
logger.debug(f"Skipping unmapped parameter: {xvector_name}")
return None
def _filter_weights(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Filter out unnecessary parameters that shouldn't be converted to MLX
Args:
pytorch_weights: Original PyTorch weights dict
Returns:
Filtered weights dict
"""
filtered_weights = {}
skipped_params = []
for name, tensor in pytorch_weights.items():
# Skip classification head parameters (not needed for inference)
if name.startswith('head.'):
skipped_params.append(name)
continue
# Keep all other parameters including BatchNorm running statistics
# The mapping function will filter out parameters that don't have MLX equivalents
filtered_weights[name] = tensor
if skipped_params:
print(f"Filtered out {len(skipped_params)} unnecessary parameters: {skipped_params[:5]}{'...' if len(skipped_params) > 5 else ''}")
return filtered_weights
def _convert_tensor(self, name: str, tensor: torch.Tensor) -> Optional[mx.array]:
"""
Convert individual tensor based on layer type and shape
Args:
name: Parameter name
tensor: PyTorch tensor to convert
Returns:
MLX array, or None if parameter should be skipped (e.g., num_batches_tracked)
"""
# Convert to numpy first
numpy_tensor = tensor.detach().cpu().numpy()
# Biases don't need any conversion, just pass through
if name.endswith('.bias'):
return mx.array(numpy_tensor)
# Determine layer type from name AND shape
layer_type = self._identify_layer_type(name)
# Override layer type based on actual tensor shape
# This handles cases where Conv1d(kernel_size=1) is used but named like Linear
if numpy_tensor.ndim == 3:
# 3D tensor must be Conv1d, regardless of name
layer_type = 'conv1d'
elif numpy_tensor.ndim == 2 and layer_type == 'conv1d':
# 2D tensor can't be Conv1d, must be Linear
layer_type = 'linear'
elif numpy_tensor.ndim == 1:
# 1D tensor is likely BatchNorm or bias
if 'bn' in name.lower() or 'batchnorm' in name.lower() or 'running' in name.lower():
layer_type = 'batchnorm'
# Apply layer-specific transformations
if layer_type in self.layer_mapping:
numpy_tensor = self.layer_mapping[layer_type](name, numpy_tensor)
# Handle None returns (e.g., num_batches_tracked)
if numpy_tensor is None:
return None
# Convert to MLX array
return mx.array(numpy_tensor)
def _identify_layer_type(self, name: str) -> str:
"""Identify layer type from parameter name"""
name_lower = name.lower()
# BatchNorm check first (more specific)
if 'bn' in name_lower or 'batchnorm' in name_lower or 'batch_norm' in name_lower:
return 'batchnorm'
# Conv1d check (including 'conv' in name)
elif 'conv1d' in name_lower or 'conv' in name_lower:
return 'conv1d'
# Linear/FC check
elif 'linear' in name_lower or 'fc' in name_lower or 'dense' in name_lower:
return 'linear'
# Embedding check
elif 'embed' in name_lower:
return 'embedding'
else:
return 'default'
def _convert_conv1d(self, name: str, weight: np.ndarray) -> np.ndarray:
"""
Convert Conv1d weights from PyTorch to MLX format
PyTorch Conv1d: (out_channels, in_channels, kernel_size)
MLX Conv1d: (out_channels, kernel_size, in_channels) - DIFFERENT format!
Special case: Conv1d with kernel_size=1 can be used as Linear layer
Args:
name: Parameter name (for error reporting)
weight: Weight tensor as numpy array
Returns:
Converted weight tensor
Raises:
ValueError: If weight shape is invalid for Conv1d
"""
# Validate Conv1d weight shape
if weight.ndim != 3:
raise ValueError(f"Conv1d weight {name} must be 3D, got shape {weight.shape}")
out_channels, in_channels, kernel_size = weight.shape
# Validate kernel size is reasonable (1, 3, 5 are common)
if kernel_size > 11:
logger.warning(f"Unusual kernel size {kernel_size} for Conv1d {name}")
# MLX Conv1d uses (out_channels, kernel_size, in_channels) format
# Transpose from PyTorch's (out_channels, in_channels, kernel_size)
# This applies to ALL kernel sizes, including kernel_size=1
mlx_weight = weight.transpose(0, 2, 1)
logger.debug(f"Transposed Conv1d weight {name}: {weight.shape} -> {mlx_weight.shape}")
return mlx_weight
def _convert_linear(self, name: str, weight: np.ndarray) -> np.ndarray:
"""
Convert Linear layer weights
PyTorch Linear: (out_features, in_features)
MLX Linear: (out_features, in_features) - same format
Args:
name: Parameter name (for error reporting)
weight: Weight tensor as numpy array
Returns:
Converted weight tensor
Raises:
ValueError: If weight shape is invalid for Linear
"""
if weight.ndim != 2:
raise ValueError(f"Linear weight {name} must be 2D, got shape {weight.shape}")
return weight # No change needed for linear layers
def _convert_batchnorm(self, name: str, weight: np.ndarray) -> Optional[np.ndarray]:
"""
Convert BatchNorm parameters
Args:
name: Parameter name (for error reporting)
weight: Weight/bias/running_mean/running_var tensor
Returns:
Converted tensor, or None if parameter should be skipped (e.g., num_batches_tracked)
Raises:
ValueError: If tensor shape is invalid for BatchNorm
"""
# Skip num_batches_tracked (it's a scalar tracking statistic, not needed in MLX)
if 'num_batches_tracked' in name:
logger.debug(f"Skipping num_batches_tracked (not needed in MLX): {name}")
return None # Will be filtered out
# BatchNorm parameters should be 1D vectors
if weight.ndim != 1:
raise ValueError(f"BatchNorm parameter {name} must be 1D, got shape {weight.shape}")
# Check for NaN/Inf in running statistics
if 'running_mean' in name or 'running_var' in name:
if np.isnan(weight).any():
logger.warning(f"BatchNorm {name} contains NaN values - may indicate untrained model")
if np.isinf(weight).any():
logger.warning(f"BatchNorm {name} contains Inf values - may indicate numerical instability")
return weight
def _convert_embedding(self, name: str, weight: np.ndarray) -> np.ndarray:
"""
Convert Embedding layer weights
Args:
name: Parameter name (unused but kept for API consistency)
weight: Embedding weight tensor
Returns:
Converted weight tensor
"""
return weight # No change needed for embeddings
def quantize_weights(self, weights: Dict[str, mx.array],
bits: int = 4, group_size: int = 64) -> Dict[str, mx.array]:
"""
Quantize weights to reduce model size using MLX's built-in quantization
Note: This creates a copy of the weights dictionary, so callers don't need to copy before calling.
Args:
weights: MLX weights dictionary
bits: Number of bits for quantization (2, 4, or 8)
group_size: Group size for quantization (32, 64, or 128)
Returns:
Quantized weights dictionary (new copy)
"""
# Create a new dictionary to avoid modifying the original
quantized_weights = {}
skipped_count = 0
quantized_count = 0
logger.info(f"Starting {bits}-bit quantization with group_size={group_size}...")
for name, weight in weights.items():
if self._should_quantize(name, weight):
try:
# MLX quantization requires:
# 1. At least 2D tensors
# 2. Last dimension divisible by group_size
if len(weight.shape) < 2:
logger.debug(f"Skipping {name}: 1D tensor")
quantized_weights[name] = weight
skipped_count += 1
continue
if weight.shape[-1] % group_size != 0:
logger.debug(f"Skipping {name}: last dim {weight.shape[-1]} not divisible by {group_size}")
quantized_weights[name] = weight
skipped_count += 1
continue
# Quantize using MLX's affine quantization
w_q, scales, biases = mx.quantize(weight, group_size=group_size, bits=bits)
# Store quantized weights with special naming for scales and biases
# Format: name:qSCALES_GS64_B4 (scales for group_size=64, bits=4)
# This reduces the number of keys compared to separate metadata arrays
quantized_weights[name] = w_q
quantized_weights[f"{name}:qSCALES_GS{group_size}_B{bits}"] = scales
quantized_weights[f"{name}:qBIASES_GS{group_size}_B{bits}"] = biases
quantized_count += 1
# Log size reduction
original_size = weight.size * 4 # float32 = 4 bytes
# Quantized size = packed weights + scales + biases
quantized_size = w_q.nbytes + scales.nbytes + biases.nbytes
reduction = (1 - quantized_size / original_size) * 100
logger.debug(f"Quantized {name}: {reduction:.1f}% size reduction ({original_size//1024}KB β†’ {quantized_size//1024}KB)")
except Exception as e:
# If quantization fails for this weight, keep original
logger.warning(f"Failed to quantize {name}: {e}, keeping original")
quantized_weights[name] = weight
skipped_count += 1
else:
# Keep small weights in full precision
quantized_weights[name] = weight
skipped_count += 1
logger.info(f"Quantization complete: {quantized_count} weights quantized, {skipped_count} kept in full precision")
return quantized_weights
def _quantize_to_int8(self, weight: mx.array) -> mx.array:
"""
Quantize a weight tensor to 8-bit precision
Args:
weight: Weight tensor to quantize
Returns:
Quantized weight tensor
"""
# Simple symmetric quantization to int8 range
# Find scale factor
abs_max = mx.max(mx.abs(weight))
scale = abs_max / 127.0
if scale == 0:
return weight
# Quantize and dequantize
quantized = mx.round(weight / scale)
quantized = mx.clip(quantized, -127, 127)
dequantized = quantized * scale
return dequantized.astype(mx.float32)
def _should_quantize(self, name: str, weight: mx.array) -> bool:
    """Decide whether a named weight tensor is a quantization candidate.

    Large Conv/Linear matrices qualify; small tensors, bias vectors and
    batch-norm parameters/statistics are kept in full precision.
    """
    lowered = name.lower()
    # Tiny tensors gain nothing from quantization
    if weight.size < MIN_QUANTIZATION_SIZE:
        return False
    # Bias vectors stay in full precision
    if 'bias' in lowered:
        return False
    # Normalization parameters and running statistics are precision-sensitive
    bn_markers = ('bn', 'batchnorm', 'batch_norm', 'running_mean', 'running_var')
    if any(marker in lowered for marker in bn_markers):
        return False
    # Everything else (large weight matrices) gets quantized
    return True
def verify_conversion(self, pytorch_weights: Dict[str, torch.Tensor],
                      mlx_weights: Dict[str, mx.array]) -> Dict[str, bool]:
    """
    Verify the conversion by comparing each weight's shape and values.

    Args:
        pytorch_weights: Original PyTorch weights
        mlx_weights: Converted MLX weights

    Returns:
        Mapping of weight name -> True if that weight converted correctly
    """
    outcome: Dict[str, bool] = {}
    for key, source_tensor in pytorch_weights.items():
        converted = mlx_weights.get(key)
        if converted is None:
            # Weight never made it into the converted dictionary
            outcome[key] = False
            continue
        # Shapes must match exactly (the conversion performs no transposes)
        ok = source_tensor.shape == converted.shape
        if ok:
            # Numeric check: values should survive conversion within tolerance
            reference = source_tensor.detach().cpu().numpy()
            candidate = np.array(converted)
            ok = ok and np.allclose(reference, candidate, rtol=1e-5, atol=1e-6)
        outcome[key] = ok
    return outcome
def check_conversion_status(self, pytorch_weights: Dict[str, torch.Tensor],
                            mlx_weights: Dict[str, mx.array],
                            verification_results: Dict[str, bool]) -> Dict[str, Any]:
    """
    Check comprehensive status of conversion to ensure it's safe to deploy

    Args:
        pytorch_weights: Original PyTorch weights
        mlx_weights: Converted MLX weights
        verification_results: Results from verify_conversion

    Returns:
        Status dictionary with detailed report (counts, rate, errors,
        warnings, and the final safe-to-deploy decision)
    """
    passed = sum(1 for ok in verification_results.values() if ok)
    failed = sum(1 for ok in verification_results.values() if not ok)
    status: Dict[str, Any] = {
        'is_perfect': False,
        'total_source_weights': len(pytorch_weights),
        'total_converted_weights': len(mlx_weights),
        'verification_passed': passed,
        'verification_failed': failed,
        'verification_rate': 0.0,
        'errors': [],
        'warnings': [],
        'safe_to_deploy': False,
    }
    # Rate is over all verified entries (guard against empty input)
    if verification_results:
        status['verification_rate'] = (passed / len(verification_results)) * 100
    # Critical failure: nothing converted at all
    if not mlx_weights:
        status['errors'].append("No weights were converted - conversion failed completely")
    # Report individual verification failures (show at most three names)
    if status['verification_failed'] > 0:
        failed_weights = [key for key, ok in verification_results.items() if not ok]
        status['errors'].append(
            f"{status['verification_failed']} weight(s) failed verification: {failed_weights[:3]}"
            f"{'...' if len(failed_weights) > 3 else ''}"
        )
    # Warn when a large fraction of the source weights went unmapped
    if len(mlx_weights) < len(pytorch_weights) * 0.5:
        status['warnings'].append(
            f"Only {len(mlx_weights)}/{len(pytorch_weights)} weights were converted "
            f"({(len(mlx_weights)/len(pytorch_weights)*100):.1f}%) - possible mapping issues"
        )
    # Mixed dtypes usually indicate an inconsistent conversion path
    dtype_set = {str(weight.dtype) for weight in mlx_weights.values()}
    if len(dtype_set) > 1:
        status['warnings'].append(f"Mixed data types detected in converted weights: {dtype_set}")
    # Scan every converted tensor for NaN/Inf contamination
    nan_inf_weights = []
    for name, weight in mlx_weights.items():
        values = np.array(weight)
        if np.isnan(values).any():
            nan_inf_weights.append(f"{name} (NaN)")
        elif np.isinf(values).any():
            nan_inf_weights.append(f"{name} (Inf)")
    if nan_inf_weights:
        status['errors'].append(f"Weights contain NaN/Inf: {nan_inf_weights[:3]}")
    # Perfect = no errors, 100% verification rate, at least one weight
    status['is_perfect'] = (
        not status['errors']
        and status['verification_rate'] == 100.0
        and len(mlx_weights) > 0
    )
    # Conservative default: deploy only when perfect; otherwise allow a
    # small, bounded number of failures as long as nothing is corrupted
    status['safe_to_deploy'] = status['is_perfect']
    if not status['safe_to_deploy'] and not status['errors']:
        status['safe_to_deploy'] = (
            status['verification_rate'] >= MIN_VERIFICATION_RATE and
            status['verification_failed'] <= MAX_VERIFICATION_FAILURES and
            not nan_inf_weights
        )
    return status
def print_status_report(self, status: Dict[str, Any]) -> None:
    """Render a conversion status dictionary as a human-readable report."""
    divider = "=" * 70
    report = ["", divider, "CONVERSION STATUS REPORT", divider]
    # Summary statistics
    checked = status['verification_passed'] + status['verification_failed']
    report.append(f"\nπŸ“Š Conversion Statistics:")
    report.append(f" Total source weights: {status['total_source_weights']}")
    report.append(f" Total converted weights: {status['total_converted_weights']}")
    report.append(f" Verification passed: {status['verification_passed']}/{checked}")
    report.append(f" Verification rate: {status['verification_rate']:.1f}%")
    # Error / warning sections only appear when non-empty
    if status['errors']:
        report.append(f"\n❌ Errors ({len(status['errors'])}):")
        report.extend(f" β€’ {error}" for error in status['errors'])
    if status['warnings']:
        report.append(f"\n⚠️ Warnings ({len(status['warnings'])}):")
        report.extend(f" β€’ {warning}" for warning in status['warnings'])
    # Final verdict
    report.append(f"\nπŸ” Deployment Decision:")
    if status['is_perfect']:
        report.append(f" Status: βœ… PERFECT - All checks passed")
    else:
        report.append(f" Status: {'βœ… ACCEPTABLE' if status['safe_to_deploy'] else '❌ NOT SAFE'}")
    report.append(f" Safe to deploy: {'βœ… YES' if status['safe_to_deploy'] else '❌ NO'}")
    report.append("\n" + divider)
    # Single write keeps stdout identical to line-by-line printing
    print("\n".join(report))
def create_model_metadata(self, original_repo: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """Build the metadata dictionary describing the converted model.

    Entries from ``config`` are merged last, so they override any
    overlapping base keys.
    """
    base: Dict[str, Any] = {
        "converted_from": original_repo,
        "conversion_date": self.get_current_date(),
        "framework": "mlx",
        "model_type": "campp",
        "architecture": "d-tdnn",
        "license": "apache-2.0",
        "tags": ["speaker-recognition", "audio", "mlx", "apple-silicon"],
        "task": "speaker-verification",
        "library_name": "mlx",
        "datasets": ["voxceleb", "cnceleb"],
        "metrics": {
            "voxceleb1_eer": "0.65%",
            "parameters": "7.2M",
            "inference_speed": "optimized_for_apple_silicon"
        },
    }
    base.update(config)
    return base
def get_current_date(self) -> str:
    """Return the current timestamp as an ISO-8601 formatted string."""
    # NOTE(review): timestamp is naive (no timezone attached) — consumers
    # receive local time; confirm that is acceptable for metadata use
    current = datetime.now()
    return current.isoformat()
def estimate_model_performance(self, weights: Dict[str, mx.array]) -> Dict[str, Any]:
    """Estimate model performance characteristics.

    Args:
        weights: Mapping of parameter name -> converted weight array

    Returns:
        Dict with total parameter count, estimated memory footprint (MB),
        rough layer counts inferred from parameter names, and a coarse
        complexity label.
    """
    total_params = sum(w.size for w in weights.values())
    # Fix: use each array's actual byte count instead of assuming 4 bytes
    # per parameter — this pipeline can emit quantized (packed uint32) and
    # reduced-precision weights, for which size*4 over/under-estimates.
    # Fall back to the fp32 assumption if an object lacks `nbytes`.
    total_bytes = sum(getattr(w, 'nbytes', w.size * 4) for w in weights.values())
    memory_mb = total_bytes / (1024 * 1024)
    # Layer counts are heuristics based on naming conventions only
    conv_layers = sum(1 for name in weights if 'conv' in name.lower())
    linear_layers = sum(1 for name in weights if any(x in name.lower() for x in ['linear', 'fc']))
    return {
        "total_parameters": total_params,
        "estimated_memory_mb": memory_mb,
        "conv_layers": conv_layers,
        "linear_layers": linear_layers,
        # 10e6 == 10 million parameters
        "model_complexity": "efficient" if total_params < 10e6 else "standard"
    }
def optimize_for_inference(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Prepare converted weights for MLX inference.

    Currently this normalizes every tensor to float32; further MLX-specific
    optimizations can be layered in here later.
    """
    prepared: Dict[str, mx.array] = {}
    for key, value in weights.items():
        # Re-wrap to guarantee an MLX array in the output mapping
        tensor = mx.array(value)
        if tensor.dtype != mx.float32:
            tensor = tensor.astype(mx.float32)
        prepared[key] = tensor
    return prepared
def test_conversion():
    """Exercise the conversion utilities and run the full status check."""
    utils = ConversionUtils()
    # Dummy PyTorch xvector weights in the proper source naming scheme
    dummy_weights = {
        # Input layer
        'xvector.tdnn.linear.weight': torch.randn(64, 80, 3),
        'xvector.tdnn.nonlinear.batchnorm.weight': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.bias': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_mean': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_var': torch.randn(64),
        # Dense block 0
        'xvector.block1.tdnnd1.linear1.weight': torch.randn(32, 64, 3),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(32),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(32),
        # Transition layer
        'xvector.transit1.linear.weight': torch.randn(256, 96, 1),
        'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(256),
        'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(256),
        # Final layer
        'xvector.out_nonlinear.batchnorm.weight': torch.randn(512),
        'xvector.out_nonlinear.batchnorm.bias': torch.randn(512),
    }
    # Run the conversion under test
    mlx_weights, config = utils.convert_weights_to_mlx(dummy_weights)
    # Compare each source weight against its mapped MLX counterpart
    verification = {}
    print("Conversion test results:")
    for src_name, src_tensor in dummy_weights.items():
        target_name = utils._xvector_to_mlx_name(src_name)
        if target_name and target_name in mlx_weights:
            src_shape = src_tensor.shape
            dst_shape = mlx_weights[target_name].shape
            matched = src_shape == dst_shape
            verification[src_name] = matched
            marker = "βœ…" if matched else "❌"
            print(f" {marker} {src_name} -> {target_name} | Shape: {src_shape} -> {dst_shape}")
        else:
            verification[src_name] = False
            print(f" ❌ {src_name} (no mapping)")
    print(f"\nTotal weights converted: {len(mlx_weights)}")
    print(f"Inferred config: {config}")
    # Comprehensive status report drives the final pass/fail decision
    status_report = utils.check_conversion_status(dummy_weights, mlx_weights, verification)
    utils.print_status_report(status_report)
    # Success requires every per-weight check AND a perfect status report
    return all(verification.values()) and status_report['is_perfect']
if __name__ == "__main__":
    # Run the self-test and report the outcome on stdout
    test_passed = test_conversion()
    outcome_icon = 'βœ…' if test_passed else '❌'
    outcome_word = 'passed' if test_passed else 'failed'
    print(f"\n{outcome_icon} Conversion utilities test {outcome_word}")