import mlx.core as mx
import numpy as np
import torch
from typing import Dict, Any, Tuple, Optional, List, Set
from datetime import datetime
import re
import logging

# Set up logging
logger = logging.getLogger(__name__)

# Constants for conversion thresholds
MIN_QUANTIZATION_SIZE = 1000  # Don't quantize tensors smaller than this
MIN_VERIFICATION_RATE = 95.0  # Minimum acceptable verification rate (%)
MAX_VERIFICATION_FAILURES = 2  # Maximum allowed verification failures
BATCHNORM_EPS = 1e-5
BATCHNORM_MOMENTUM = 0.1


class ConversionUtils:
    """Utilities for converting PyTorch CAM++ models to MLX format"""

    def __init__(self, use_modelscope_architecture: bool = True):
        """
        Initialize conversion utilities.

        Args:
            use_modelscope_architecture: If True, use ModelScope architecture with
                embedded CAM. If False, use original architecture with shared CAM.
        """
        self.use_modelscope_architecture = use_modelscope_architecture
        # Dispatch table: layer kind -> tensor conversion routine.
        self.layer_mapping = {
            'conv1d': self._convert_conv1d,
            'linear': self._convert_linear,
            'batchnorm': self._convert_batchnorm,
            'embedding': self._convert_embedding
        }

    def convert_weights_to_mlx(self, pytorch_weights: Dict[str, torch.Tensor]) -> Tuple[Dict[str, mx.array], Dict[str, Any]]:
        """
        Convert PyTorch weights to MLX format.

        Pipeline: analyze structure -> filter -> rename -> fill defaults ->
        convert each tensor.

        Args:
            pytorch_weights: Dictionary of PyTorch tensors

        Returns:
            Tuple of (mlx_weights, model_config)
        """
        converted: Dict[str, mx.array] = {}
        model_config = self._analyze_model_structure(pytorch_weights)

        # Drop parameters that should never reach MLX (e.g. the classifier
        # head), then rename the survivors to their MLX equivalents.
        kept = self._filter_weights(pytorch_weights)
        renamed = self._map_parameter_names(kept)
        renamed = self._add_missing_parameters(renamed, model_config)

        for name, tensor in renamed.items():
            # Non-tensor entries (ints, strings, ...) carry no weight data.
            if not isinstance(tensor, torch.Tensor):
                continue
            mlx_value = self._convert_tensor(name, tensor)
            # _convert_tensor returns None for parameters MLX does not need
            # (e.g. num_batches_tracked); those are dropped here.
            if mlx_value is not None:
                converted[name] = mlx_value

        return converted, model_config
filtered_weights = self._filter_weights(pytorch_weights) # Map parameter names from PyTorch to MLX format mapped_weights = self._map_parameter_names(filtered_weights) # Add default values for missing MLX parameters mapped_weights = self._add_missing_parameters(mapped_weights, model_config) # Convert each weight tensor for name, tensor in mapped_weights.items(): if isinstance(tensor, torch.Tensor): converted = self._convert_tensor(name, tensor) # Skip None values (e.g., num_batches_tracked) if converted is not None: mlx_weights[name] = converted else: # Handle non-tensor values (e.g., integers, strings) continue return mlx_weights, model_config def _analyze_model_structure(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, Any]: """ Analyze the PyTorch model structure to infer configuration Args: pytorch_weights: PyTorch weights dictionary Returns: Model configuration dictionary """ config = { 'input_dim': 80, # Default mel spectrogram features 'embedding_dim': 192, # Default embedding dimension for ModelScope 'channels': 512, # Default number of channels 'cam_channels': 128, # Default CAM channels } # Detect block structure for ModelScope architecture if self.use_modelscope_architecture: blocks = {1: set(), 2: set(), 3: set()} for name in pytorch_weights.keys(): if 'xvector.block1.tdnnd' in name: layer_num = name.split('tdnnd')[1].split('.')[0] blocks[1].add(int(layer_num)) elif 'xvector.block2.tdnnd' in name: layer_num = name.split('tdnnd')[1].split('.')[0] blocks[2].add(int(layer_num)) elif 'xvector.block3.tdnnd' in name: layer_num = name.split('tdnnd')[1].split('.')[0] blocks[3].add(int(layer_num)) # Set block_layers configuration if any(blocks.values()): config['block_layers'] = [ len(blocks[1]) if blocks[1] else 4, # Default to 4 if not found len(blocks[2]) if blocks[2] else 9, # Default to 9 if not found len(blocks[3]) if blocks[3] else 16 # Default to 16 if not found ] logger.info(f"Detected block structure: {config['block_layers']}") # Try to 
infer input dimension and kernel size from first conv layer for name, tensor in pytorch_weights.items(): if 'xvector.tdnn.linear.weight' in name: if tensor.ndim == 3: # Conv1d weight: (out_channels, in_channels, kernel_size) config['input_dim'] = tensor.shape[1] # in_channels config['channels'] = tensor.shape[0] # out_channels config['input_kernel_size'] = tensor.shape[2] # kernel_size logger.info(f"Detected input layer: dim={config['input_dim']}, channels={config['channels']}, kernel_size={config['input_kernel_size']}") break # Try to infer embedding dimension from dense layer for name, tensor in pytorch_weights.items(): if 'xvector.dense.linear.weight' in name: if tensor.ndim == 3: # Conv1d with kernel_size=1 config['embedding_dim'] = tensor.shape[0] # out_channels break # Count total parameters for estimation total_params = sum(tensor.numel() for tensor in pytorch_weights.values()) config['total_params'] = total_params return config def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Map PyTorch parameter names to MLX parameter names Args: pytorch_weights: PyTorch weights with original names Returns: Weights with MLX-compatible parameter names """ mapped_weights = {} for name, tensor in pytorch_weights.items(): # Choose mapping function based on architecture if self.use_modelscope_architecture: mlx_name = self._xvector_to_mlx_modelscope_name(name) else: mlx_name = self._xvector_to_mlx_name(name) if mlx_name: # Only keep parameters that have MLX equivalents mapped_weights[mlx_name] = tensor return mapped_weights def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]: """ Add default values for MLX parameters that don't have PyTorch equivalents Args: mapped_weights: Already mapped weights model_config: Model configuration Returns: Weights with missing parameters added Note: This method intentionally does NOT add fake/random parameters. 
Adding untrained random weights will degrade model accuracy significantly. The conversion should only include weights that are actually mapped from the source model. Better to fail explicitly when a layer is missing than to add random weights that produce nonsensical outputs. """ # Return mapped weights as-is without adding arbitrary fake parameters return mapped_weights def get_missing_mlx_parameters(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, str]: """ Get list of MLX parameters that don't have source PyTorch equivalents Args: pytorch_weights: Original PyTorch weights mlx_weights: Converted MLX weights Returns: Dictionary mapping MLX parameter names to their source parameter names (or "NOT FOUND") """ missing_params = {} # Define expected MLX model parameters expected_mlx_params = { # Input layer 'input_conv.weight', 'input_bn.weight', 'input_bn.bias', 'input_bn.running_mean', 'input_bn.running_var', # Dense blocks (0-2) 'dense_blocks.0.layers.0.conv.weight', 'dense_blocks.0.layers.0.bn.weight', 'dense_blocks.0.layers.0.bn.bias', 'dense_blocks.0.layers.0.bn.running_mean', 'dense_blocks.0.layers.0.bn.running_var', 'dense_blocks.0.layers.1.conv.weight', 'dense_blocks.0.layers.1.bn.weight', 'dense_blocks.0.layers.1.bn.bias', 'dense_blocks.0.layers.2.conv.weight', 'dense_blocks.0.layers.2.bn.weight', 'dense_blocks.0.layers.2.bn.bias', 'dense_blocks.0.layers.3.conv.weight', 'dense_blocks.0.layers.3.bn.weight', 'dense_blocks.0.layers.3.bn.bias', # Transitions 'transitions.0.layers.0.weight', 'transitions.0.layers.0.bias', 'transitions.0.layers.0.running_mean', 'transitions.0.layers.0.running_var', 'transitions.0.layers.2.weight', 'transitions.1.layers.0.weight', 'transitions.1.layers.0.bias', 'transitions.1.layers.0.running_mean', 'transitions.1.layers.0.running_var', 'transitions.1.layers.2.weight', # CAM layer 'cam.context_conv1.weight', 'cam.context_conv1.bias', 'cam.context_conv3.weight', 'cam.context_conv3.bias', 
'cam.context_conv5.weight', 'cam.context_conv5.bias', 'cam.mask_conv.weight', 'cam.mask_conv.bias', 'cam.bn.weight', 'cam.bn.bias', 'cam.bn.running_mean', 'cam.bn.running_var', # Channel gating 'channel_gating.fc.layers.0.weight', 'channel_gating.fc.layers.0.bias', 'channel_gating.fc.layers.1.weight', 'channel_gating.fc.layers.1.bias', 'channel_gating.fc.layers.2.weight', 'channel_gating.fc.layers.2.bias', # Pooling 'pooling.attention_weights.weight', 'pooling.attention_weights.bias', 'pooling.projection.weight', 'pooling.projection.bias', # Final layer 'final_bn.weight', 'final_bn.bias', 'final_bn.running_mean', 'final_bn.running_var', } # Check which expected parameters are missing from converted weights for param in expected_mlx_params: if param not in mlx_weights: missing_params[param] = "NOT FOUND" return missing_params def _xvector_to_mlx_modelscope_name(self, xvector_name: str) -> Optional[str]: """ Convert xvector parameter name to MLX ModelScope architecture parameter name This mapping is for ModelScope CAM++ models where CAM is embedded in each TDNN layer. 
Architecture: - Input layer (TDNN) - Block 1: 4 TDNN layers with embedded CAM - Transit 1 - Block 2: 9 TDNN layers with embedded CAM - Transit 2 - Block 3: 16 TDNN layers with embedded CAM - Dense layer (Conv1d kernel_size=1) Args: xvector_name: Original xvector parameter name from PyTorch model Returns: MLX-compatible parameter name, or None if parameter should be skipped """ # ========== INPUT LAYER ========== if xvector_name == 'xvector.tdnn.linear.weight': return 'input_conv.weight' elif 'xvector.tdnn.nonlinear.batchnorm' in xvector_name: param_type = xvector_name.split('.')[-1] # bias, weight, running_mean, running_var # Skip num_batches_tracked (PyTorch tracking statistic, not needed) if param_type == 'num_batches_tracked': return None return f'input_bn.{param_type}' # ========== DENSE BLOCKS WITH EMBEDDED CAM ========== # Extract block number and layer number import re block_match = re.match(r'xvector\.block(\d+)\.tdnnd(\d+)\.(.*)', xvector_name) if block_match: block_num = int(block_match.group(1)) # 1, 2, or 3 layer_num = int(block_match.group(2)) # 1-indexed param_path = block_match.group(3) # Map to MLX block index (0, 1, 2) mlx_block_idx = block_num - 1 # Map to MLX layer index (0-indexed) mlx_layer_idx = layer_num - 1 # Main TDNN layer parameters if param_path.startswith('linear1.'): param_type = param_path.split('.')[-1] return f'block{mlx_block_idx}_{mlx_layer_idx}.conv.{param_type}' # PyTorch has TWO batch norms per layer: # - nonlinear1.batchnorm: sized for INPUT channels (applied before conv) # - nonlinear2.batchnorm: sized for OUTPUT channels (applied after conv) # MLX model only has one BN (after conv), so map nonlinear2 to bn elif param_path.startswith('nonlinear1.batchnorm.'): # Skip nonlinear1 batch norm - it's sized for input channels return None elif param_path.startswith('nonlinear2.batchnorm.'): param_type = param_path.split('.')[-1] # Skip num_batches_tracked if param_type == 'num_batches_tracked': return None return 
f'block{mlx_block_idx}_{mlx_layer_idx}.bn.{param_type}' # Embedded CAM layer parameters elif param_path.startswith('cam_layer.linear1.'): param_type = param_path.split('.')[-1] return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear1.{param_type}' elif param_path.startswith('cam_layer.linear2.'): param_type = param_path.split('.')[-1] return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear2.{param_type}' elif param_path.startswith('cam_layer.linear_local.'): param_type = param_path.split('.')[-1] return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear_local.{param_type}' # ========== TRANSITION LAYERS ========== if 'xvector.transit1.' in xvector_name: if '.linear.weight' in xvector_name: return 'transit1.conv.weight' elif 'nonlinear.batchnorm' in xvector_name: param_type = xvector_name.split('.')[-1] # Skip num_batches_tracked if param_type == 'num_batches_tracked': return None return f'transit1.bn.{param_type}' if 'xvector.transit2.' in xvector_name: if '.linear.weight' in xvector_name: return 'transit2.conv.weight' elif 'nonlinear.batchnorm' in xvector_name: param_type = xvector_name.split('.')[-1] # Skip num_batches_tracked if param_type == 'num_batches_tracked': return None return f'transit2.bn.{param_type}' # ========== DENSE LAYER ========== if 'xvector.dense.linear.' 
    def _xvector_to_mlx_name(self, xvector_name: str) -> Optional[str]:
        """
        Convert xvector parameter name to MLX parameter name with comprehensive mapping

        This method maps PyTorch CAM++ xvector parameters to MLX CAMPPModel parameters.
        It handles:
        - Input layer (TDNN)
        - Dense blocks (3 blocks with 4, 6, 8 layers respectively)
        - Transition layers between blocks
        - Context-Aware Masking (CAM) layer
        - Channel gating mechanism
        - Multi-granularity pooling
        - Final batch normalization

        NOTE(review): the matching is order-dependent substring matching; each
        section is tried in turn and the first match wins, so reordering the
        checks changes the mapping.

        Args:
            xvector_name: Original xvector parameter name from PyTorch model

        Returns:
            MLX-compatible parameter name, or None if parameter should be skipped
        """
        # ========== INPUT LAYER MAPPING ==========
        if xvector_name == 'xvector.tdnn.linear.weight':
            return 'input_conv.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
            return 'input_bn.bias'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
            return 'input_bn.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
            return 'input_bn.running_mean'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
            return 'input_bn.running_var'

        # ========== DENSE BLOCKS MAPPING ==========
        # MLX architecture: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
        # Map PyTorch block1/block2/block3 to MLX dense_blocks.0/1/2
        # Block 0: Map first 4 layers of PyTorch block1
        for i in range(1, 13):  # Handle up to 12 layers (generous for real models)
            # Block 0 - first 4 layers
            if i <= 4 and f'xvector.block1.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'
            # Block 1 - first 6 layers of PyTorch block2
            # Skip block2.tdnnd1 and block2.tdnnd2 as they may be used for transition
            if i >= 3 and i <= 8 and f'xvector.block2.tdnnd{i}.' in xvector_name:
                layer_idx = i - 3  # Map block2.tdnnd3 -> layer 0, etc.
                if layer_idx < 6:  # Only map first 6 layers
                    if '.linear1.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
                    elif '.nonlinear1.batchnorm.bias' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
                    elif '.nonlinear1.batchnorm.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
                    elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
                    elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'
            # Block 2 - first 8 layers of PyTorch block3
            if i <= 8 and f'xvector.block3.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'

        # ========== TRANSITION LAYERS MAPPING ==========
        # Transition 0: After block 0
        if 'xvector.transit1.' in xvector_name:
            if '.linear.weight' in xvector_name:
                return 'transitions.0.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name:
                return 'transitions.0.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name:
                return 'transitions.0.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name:
                return 'transitions.0.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name:
                return 'transitions.0.layers.0.running_var'
        # Transition 1: Use block2.tdnnd1 and tdnnd2 (before dense block 1)
        if 'xvector.transit2.' in xvector_name or 'xvector.block2.tdnnd1.' in xvector_name:
            # Map transit2 or beginning of block2 to transition 1
            # NOTE(review): the 'xvector.block2.tdnnd2.*' alternatives below are
            # unreachable — the outer guard only admits transit2 or
            # block2.tdnnd1 names. Kept as-is; confirm intent before removing.
            if '.linear.weight' in xvector_name or 'xvector.block2.tdnnd2.linear1.weight' in xvector_name:
                return 'transitions.1.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.bias' in xvector_name:
                return 'transitions.1.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight' in xvector_name:
                return 'transitions.1.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean' in xvector_name:
                return 'transitions.1.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var' in xvector_name:
                return 'transitions.1.layers.0.running_var'

        # ========== CAM LAYER MAPPING ==========
        # Context-aware masking with multi-scale convolutions
        # NOTE: Real ModelScope models have CAM embedded in EACH TDNN layer,
        # but MLX model has ONE shared CAM layer. We map only the first occurrence
        # from block1.tdnnd1.cam_layer and skip all others.
        if 'cam_layer' in xvector_name or 'cam.' in xvector_name:
            # Only map CAM from the first block's first layer
            # Skip CAM from all other layers to avoid conflicts
            is_first_cam = 'block1.tdnnd1.cam_layer' in xvector_name
            if not is_first_cam:
                logger.debug(f"Skipping embedded CAM layer (only using first occurrence): {xvector_name}")
                return None
            # Map first CAM layer to MLX shared CAM
            # ModelScope structure: linear1 (1x1 conv), linear2 (1x1 conv), linear_local (3x3 conv)
            # MLX structure: context_conv1 (1x1), context_conv3 (3x3), context_conv5 (5x5)
            if 'cam_layer.linear1.weight' in xvector_name:
                return 'cam.context_conv1.weight'
            elif 'cam_layer.linear1.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv1 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv1 has bias=False
            elif 'cam_layer.linear2.weight' in xvector_name:
                # Map linear2 (1x1) to context_conv3 - note: this is a compromise
                # Real model has 1x1 conv here, MLX expects 3x3
                logger.warning(f"Mapping 1x1 conv to context_conv3 (shape mismatch possible): {xvector_name}")
                return 'cam.context_conv3.weight'
            elif 'cam_layer.linear2.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv3 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv3 has bias=False
            elif 'cam_layer.linear_local.weight' in xvector_name:
                # Map linear_local (3x3) to mask_conv
                return 'cam.mask_conv.weight'
            elif 'cam_layer.linear_local.bias' in xvector_name:
                # linear_local typically has no bias in ModelScope models
                logger.debug(f"Skipping CAM mask_conv bias: {xvector_name}")
                return None
            # Handle standalone cam. parameters (if model has separate CAM layer)
            elif 'context1.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name):
                return 'cam.context_conv1.weight'
            elif 'context3.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name):
                return 'cam.context_conv3.weight'
            elif 'context5.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear3.weight' in xvector_name):
                return 'cam.context_conv5.weight'
            elif 'mask_conv.weight' in xvector_name:
                return 'cam.mask_conv.weight'
            elif 'fusion.weight' in xvector_name:
                return 'cam.fusion.weight'
            # Batch normalization
            # NOTE(review): the bare 'running_mean'/'running_var' substring
            # checks are very broad — any CAM-ish name containing them maps to
            # cam.bn.* regardless of which sub-module owns the statistic.
            elif 'batchnorm.weight' in xvector_name:
                return 'cam.bn.weight'
            elif 'batchnorm.bias' in xvector_name:
                return 'cam.bn.bias'
            elif 'running_mean' in xvector_name:
                return 'cam.bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'cam.bn.running_var'

        # ========== CHANNEL GATING MAPPING ==========
        # Channel-wise context gating (squeeze-excitation style)
        # NOTE: Real ModelScope models only have xvector.dense.linear (single layer)
        # MLX model expects 3-layer FC, but real model has only 1 layer
        if 'xvector.dense.' in xvector_name:
            if '.linear.weight' in xvector_name or 'xvector.dense.linear.weight' == xvector_name:
                # Map to first layer - this is the only dense layer in real model
                return 'channel_gating.fc.layers.0.weight'
            elif '.linear.bias' in xvector_name or 'xvector.dense.linear.bias' == xvector_name:
                # Check if bias exists (some models use Conv1d without bias)
                logger.debug(f"Mapping dense bias (may not exist in Conv1d): {xvector_name}")
                return 'channel_gating.fc.layers.0.bias'
            # The following layers don't exist in real ModelScope models
            elif 'linear_mid.weight' in xvector_name:
                logger.warning(f"Found linear_mid layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.1.weight'
            elif 'linear_mid.bias' in xvector_name:
                return 'channel_gating.fc.layers.1.bias'
            elif 'linear_out.weight' in xvector_name:
                logger.warning(f"Found linear_out layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.2.weight'
            elif 'linear_out.bias' in xvector_name:
                return 'channel_gating.fc.layers.2.bias'

        # ========== POOLING LAYER MAPPING ==========
        # Multi-granularity statistical pooling
        # NOTE: Real ModelScope models typically DON'T have xvector.output or pooling layers
        # These models are feature extractors that end at xvector.dense
        if 'xvector.output.' in xvector_name or 'xvector.pool' in xvector_name:
            logger.warning(f"Found pooling/output layer (rare in ModelScope models): {xvector_name}")
            if 'xvector.output.linear.weight' == xvector_name:
                return 'pooling.attention_weights.weight'
            elif 'xvector.output.linear.bias' == xvector_name:
                return 'pooling.attention_weights.bias'
            elif 'pool_output.linear.weight' in xvector_name or 'pooling.linear.weight' in xvector_name:
                return 'pooling.projection.weight'
            elif 'pool_output.linear.bias' in xvector_name or 'pooling.linear.bias' in xvector_name:
                return 'pooling.projection.bias'

        # ========== FINAL BATCH NORMALIZATION ==========
        if 'xvector.out_nonlinear.batchnorm.' in xvector_name or 'xvector.final_bn.' in xvector_name:
            if '.bias' in xvector_name:
                return 'final_bn.bias'
            elif '.weight' in xvector_name:
                return 'final_bn.weight'
            elif 'running_mean' in xvector_name:
                return 'final_bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'final_bn.running_var'

        # ========== SKIP UNMAPPED PARAMETERS ==========
        # Log skipped parameters for debugging
        if xvector_name.startswith('xvector.'):
            logger.debug(f"Skipping unmapped parameter: {xvector_name}")
        return None
if len(skipped_params) > 5 else ''}") return filtered_weights def _convert_tensor(self, name: str, tensor: torch.Tensor) -> Optional[mx.array]: """ Convert individual tensor based on layer type and shape Args: name: Parameter name tensor: PyTorch tensor to convert Returns: MLX array, or None if parameter should be skipped (e.g., num_batches_tracked) """ # Convert to numpy first numpy_tensor = tensor.detach().cpu().numpy() # Biases don't need any conversion, just pass through if name.endswith('.bias'): return mx.array(numpy_tensor) # Determine layer type from name AND shape layer_type = self._identify_layer_type(name) # Override layer type based on actual tensor shape # This handles cases where Conv1d(kernel_size=1) is used but named like Linear if numpy_tensor.ndim == 3: # 3D tensor must be Conv1d, regardless of name layer_type = 'conv1d' elif numpy_tensor.ndim == 2 and layer_type == 'conv1d': # 2D tensor can't be Conv1d, must be Linear layer_type = 'linear' elif numpy_tensor.ndim == 1: # 1D tensor is likely BatchNorm or bias if 'bn' in name.lower() or 'batchnorm' in name.lower() or 'running' in name.lower(): layer_type = 'batchnorm' # Apply layer-specific transformations if layer_type in self.layer_mapping: numpy_tensor = self.layer_mapping[layer_type](name, numpy_tensor) # Handle None returns (e.g., num_batches_tracked) if numpy_tensor is None: return None # Convert to MLX array return mx.array(numpy_tensor) def _identify_layer_type(self, name: str) -> str: """Identify layer type from parameter name""" name_lower = name.lower() # BatchNorm check first (more specific) if 'bn' in name_lower or 'batchnorm' in name_lower or 'batch_norm' in name_lower: return 'batchnorm' # Conv1d check (including 'conv' in name) elif 'conv1d' in name_lower or 'conv' in name_lower: return 'conv1d' # Linear/FC check elif 'linear' in name_lower or 'fc' in name_lower or 'dense' in name_lower: return 'linear' # Embedding check elif 'embed' in name_lower: return 'embedding' else: return 
'default' def _convert_conv1d(self, name: str, weight: np.ndarray) -> np.ndarray: """ Convert Conv1d weights from PyTorch to MLX format PyTorch Conv1d: (out_channels, in_channels, kernel_size) MLX Conv1d: (out_channels, kernel_size, in_channels) - DIFFERENT format! Special case: Conv1d with kernel_size=1 can be used as Linear layer Args: name: Parameter name (for error reporting) weight: Weight tensor as numpy array Returns: Converted weight tensor Raises: ValueError: If weight shape is invalid for Conv1d """ # Validate Conv1d weight shape if weight.ndim != 3: raise ValueError(f"Conv1d weight {name} must be 3D, got shape {weight.shape}") out_channels, in_channels, kernel_size = weight.shape # Validate kernel size is reasonable (1, 3, 5 are common) if kernel_size > 11: logger.warning(f"Unusual kernel size {kernel_size} for Conv1d {name}") # MLX Conv1d uses (out_channels, kernel_size, in_channels) format # Transpose from PyTorch's (out_channels, in_channels, kernel_size) # This applies to ALL kernel sizes, including kernel_size=1 mlx_weight = weight.transpose(0, 2, 1) logger.debug(f"Transposed Conv1d weight {name}: {weight.shape} -> {mlx_weight.shape}") return mlx_weight def _convert_linear(self, name: str, weight: np.ndarray) -> np.ndarray: """ Convert Linear layer weights PyTorch Linear: (out_features, in_features) MLX Linear: (out_features, in_features) - same format Args: name: Parameter name (for error reporting) weight: Weight tensor as numpy array Returns: Converted weight tensor Raises: ValueError: If weight shape is invalid for Linear """ if weight.ndim != 2: raise ValueError(f"Linear weight {name} must be 2D, got shape {weight.shape}") return weight # No change needed for linear layers def _convert_batchnorm(self, name: str, weight: np.ndarray) -> Optional[np.ndarray]: """ Convert BatchNorm parameters Args: name: Parameter name (for error reporting) weight: Weight/bias/running_mean/running_var tensor Returns: Converted tensor, or None if parameter should be 
skipped (e.g., num_batches_tracked) Raises: ValueError: If tensor shape is invalid for BatchNorm """ # Skip num_batches_tracked (it's a scalar tracking statistic, not needed in MLX) if 'num_batches_tracked' in name: logger.debug(f"Skipping num_batches_tracked (not needed in MLX): {name}") return None # Will be filtered out # BatchNorm parameters should be 1D vectors if weight.ndim != 1: raise ValueError(f"BatchNorm parameter {name} must be 1D, got shape {weight.shape}") # Check for NaN/Inf in running statistics if 'running_mean' in name or 'running_var' in name: if np.isnan(weight).any(): logger.warning(f"BatchNorm {name} contains NaN values - may indicate untrained model") if np.isinf(weight).any(): logger.warning(f"BatchNorm {name} contains Inf values - may indicate numerical instability") return weight def _convert_embedding(self, name: str, weight: np.ndarray) -> np.ndarray: """ Convert Embedding layer weights Args: name: Parameter name (unused but kept for API consistency) weight: Embedding weight tensor Returns: Converted weight tensor """ return weight # No change needed for embeddings def quantize_weights(self, weights: Dict[str, mx.array], bits: int = 4, group_size: int = 64) -> Dict[str, mx.array]: """ Quantize weights to reduce model size using MLX's built-in quantization Note: This creates a copy of the weights dictionary, so callers don't need to copy before calling. Args: weights: MLX weights dictionary bits: Number of bits for quantization (2, 4, or 8) group_size: Group size for quantization (32, 64, or 128) Returns: Quantized weights dictionary (new copy) """ # Create a new dictionary to avoid modifying the original quantized_weights = {} skipped_count = 0 quantized_count = 0 logger.info(f"Starting {bits}-bit quantization with group_size={group_size}...") for name, weight in weights.items(): if self._should_quantize(name, weight): try: # MLX quantization requires: # 1. At least 2D tensors # 2. 
Last dimension divisible by group_size if len(weight.shape) < 2: logger.debug(f"Skipping {name}: 1D tensor") quantized_weights[name] = weight skipped_count += 1 continue if weight.shape[-1] % group_size != 0: logger.debug(f"Skipping {name}: last dim {weight.shape[-1]} not divisible by {group_size}") quantized_weights[name] = weight skipped_count += 1 continue # Quantize using MLX's affine quantization w_q, scales, biases = mx.quantize(weight, group_size=group_size, bits=bits) # Store quantized weights with special naming for scales and biases # Format: name:qSCALES_GS64_B4 (scales for group_size=64, bits=4) # This reduces the number of keys compared to separate metadata arrays quantized_weights[name] = w_q quantized_weights[f"{name}:qSCALES_GS{group_size}_B{bits}"] = scales quantized_weights[f"{name}:qBIASES_GS{group_size}_B{bits}"] = biases quantized_count += 1 # Log size reduction original_size = weight.size * 4 # float32 = 4 bytes # Quantized size = packed weights + scales + biases quantized_size = w_q.nbytes + scales.nbytes + biases.nbytes reduction = (1 - quantized_size / original_size) * 100 logger.debug(f"Quantized {name}: {reduction:.1f}% size reduction ({original_size//1024}KB → {quantized_size//1024}KB)") except Exception as e: # If quantization fails for this weight, keep original logger.warning(f"Failed to quantize {name}: {e}, keeping original") quantized_weights[name] = weight skipped_count += 1 else: # Keep small weights in full precision quantized_weights[name] = weight skipped_count += 1 logger.info(f"Quantization complete: {quantized_count} weights quantized, {skipped_count} kept in full precision") return quantized_weights def _quantize_to_int8(self, weight: mx.array) -> mx.array: """ Quantize a weight tensor to 8-bit precision Args: weight: Weight tensor to quantize Returns: Quantized weight tensor """ # Simple symmetric quantization to int8 range # Find scale factor abs_max = mx.max(mx.abs(weight)) scale = abs_max / 127.0 if scale == 0: return 
weight # Quantize and dequantize quantized = mx.round(weight / scale) quantized = mx.clip(quantized, -127, 127) dequantized = quantized * scale return dequantized.astype(mx.float32) def _should_quantize(self, name: str, weight: mx.array) -> bool: """Determine if a weight should be quantized""" # Don't quantize very small tensors or bias terms if weight.size < MIN_QUANTIZATION_SIZE: return False # Don't quantize bias terms if 'bias' in name.lower(): return False # Don't quantize batchnorm parameters (weight, bias, running_mean, running_var) if any(bn_key in name.lower() for bn_key in ['bn', 'batchnorm', 'batch_norm', 'running_mean', 'running_var']): return False # Quantize large weight matrices (Conv, Linear) return True def verify_conversion(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, bool]: """ Verify that conversion was successful by comparing shapes and values Args: pytorch_weights: Original PyTorch weights mlx_weights: Converted MLX weights Returns: Dictionary of verification results """ results = {} for name in pytorch_weights.keys(): if name in mlx_weights: pytorch_tensor = pytorch_weights[name] mlx_array = mlx_weights[name] # Compare basic properties pytorch_shape = pytorch_tensor.shape mlx_shape = mlx_array.shape # All layers should have matching shapes (no transpose) results[name] = pytorch_shape == mlx_shape # Additional verification: check if values are reasonable if results[name]: pytorch_values = pytorch_tensor.detach().cpu().numpy() mlx_values = np.array(mlx_array) # Check if the values are approximately equal value_check = np.allclose(pytorch_values, mlx_values, rtol=1e-5, atol=1e-6) results[name] = results[name] and value_check else: results[name] = False return results def check_conversion_status(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array], verification_results: Dict[str, bool]) -> Dict[str, Any]: """ Check comprehensive status of conversion to ensure it's safe 
to deploy Args: pytorch_weights: Original PyTorch weights mlx_weights: Converted MLX weights verification_results: Results from verify_conversion Returns: Status dictionary with detailed report """ status = { 'is_perfect': False, 'total_source_weights': len(pytorch_weights), 'total_converted_weights': len(mlx_weights), 'verification_passed': sum(1 for v in verification_results.values() if v), 'verification_failed': sum(1 for v in verification_results.values() if not v), 'verification_rate': 0.0, 'errors': [], 'warnings': [], 'safe_to_deploy': False, } # Calculate verification rate total_verified = len(verification_results) if total_verified > 0: status['verification_rate'] = (status['verification_passed'] / total_verified) * 100 # Check for critical issues if len(mlx_weights) == 0: status['errors'].append("No weights were converted - conversion failed completely") if status['verification_failed'] > 0: failed_weights = [name for name, result in verification_results.items() if not result] status['errors'].append( f"{status['verification_failed']} weight(s) failed verification: {failed_weights[:3]}" f"{'...' 
if len(failed_weights) > 3 else ''}" ) if len(mlx_weights) < len(pytorch_weights) * 0.5: status['warnings'].append( f"Only {len(mlx_weights)}/{len(pytorch_weights)} weights were converted " f"({(len(mlx_weights)/len(pytorch_weights)*100):.1f}%) - possible mapping issues" ) # Check data type consistency dtype_set = set() for weight in mlx_weights.values(): dtype_set.add(str(weight.dtype)) if len(dtype_set) > 1: status['warnings'].append(f"Mixed data types detected in converted weights: {dtype_set}") # Check for NaN or Inf values nan_inf_weights = [] for name, weight in mlx_weights.items(): weight_np = np.array(weight) if np.isnan(weight_np).any(): nan_inf_weights.append(f"{name} (NaN)") elif np.isinf(weight_np).any(): nan_inf_weights.append(f"{name} (Inf)") if nan_inf_weights: status['errors'].append(f"Weights contain NaN/Inf: {nan_inf_weights[:3]}") # Determine if safe to deploy status['is_perfect'] = ( len(status['errors']) == 0 and status['verification_rate'] == 100.0 and len(mlx_weights) > 0 ) # Conservative approach: only deploy if perfect status['safe_to_deploy'] = status['is_perfect'] if not status['safe_to_deploy'] and len(status['errors']) == 0: status['safe_to_deploy'] = ( status['verification_rate'] >= MIN_VERIFICATION_RATE and status['verification_failed'] <= MAX_VERIFICATION_FAILURES and len(nan_inf_weights) == 0 ) return status def print_status_report(self, status: Dict[str, Any]) -> None: """Print a formatted status report""" print("\n" + "="*70) print("CONVERSION STATUS REPORT") print("="*70) print(f"\nšŸ“Š Conversion Statistics:") print(f" Total source weights: {status['total_source_weights']}") print(f" Total converted weights: {status['total_converted_weights']}") print(f" Verification passed: {status['verification_passed']}/{status['verification_passed'] + status['verification_failed']}") print(f" Verification rate: {status['verification_rate']:.1f}%") if status['errors']: print(f"\nāŒ Errors ({len(status['errors'])}):") for error in 
status['errors']: print(f" • {error}") if status['warnings']: print(f"\nāš ļø Warnings ({len(status['warnings'])}):") for warning in status['warnings']: print(f" • {warning}") print(f"\nšŸ” Deployment Decision:") if status['is_perfect']: print(f" Status: āœ… PERFECT - All checks passed") else: print(f" Status: {'āœ… ACCEPTABLE' if status['safe_to_deploy'] else 'āŒ NOT SAFE'}") print(f" Safe to deploy: {'āœ… YES' if status['safe_to_deploy'] else 'āŒ NO'}") print("\n" + "="*70) def create_model_metadata(self, original_repo: str, config: Dict[str, Any]) -> Dict[str, Any]: """Create metadata for the converted model""" return { "converted_from": original_repo, "conversion_date": self.get_current_date(), "framework": "mlx", "model_type": "campp", "architecture": "d-tdnn", "license": "apache-2.0", "tags": ["speaker-recognition", "audio", "mlx", "apple-silicon"], "task": "speaker-verification", "library_name": "mlx", "datasets": ["voxceleb", "cnceleb"], "metrics": { "voxceleb1_eer": "0.65%", "parameters": "7.2M", "inference_speed": "optimized_for_apple_silicon" }, **config } def get_current_date(self) -> str: """Get current date in ISO format""" return datetime.now().isoformat() def estimate_model_performance(self, weights: Dict[str, mx.array]) -> Dict[str, Any]: """Estimate model performance characteristics""" total_params = sum(w.size for w in weights.values()) # Estimate memory usage (rough approximation) total_bytes = total_params * 4 # Assuming fp32 memory_mb = total_bytes / (1024 * 1024) # Estimate model complexity conv_layers = sum(1 for name in weights.keys() if 'conv' in name.lower()) linear_layers = sum(1 for name in weights.keys() if any(x in name.lower() for x in ['linear', 'fc'])) return { "total_parameters": total_params, "estimated_memory_mb": memory_mb, "conv_layers": conv_layers, "linear_layers": linear_layers, "model_complexity": "efficient" if total_params < 10e6 else "standard" } def optimize_for_inference(self, weights: Dict[str, mx.array]) -> 
Dict[str, mx.array]: """Apply MLX-specific optimizations for inference""" optimized_weights = {} for name, weight in weights.items(): # Ensure weights are in optimal format for MLX optimized_weight = mx.array(weight) # MLX-specific optimizations could go here # For now, just ensure proper data type if optimized_weight.dtype != mx.float32: optimized_weight = optimized_weight.astype(mx.float32) optimized_weights[name] = optimized_weight return optimized_weights def test_conversion(): """Test the conversion utilities with comprehensive status checking""" utils = ConversionUtils() # Create dummy PyTorch xvector weights (proper source format) dummy_weights = { # Input layer 'xvector.tdnn.linear.weight': torch.randn(64, 80, 3), 'xvector.tdnn.nonlinear.batchnorm.weight': torch.randn(64), 'xvector.tdnn.nonlinear.batchnorm.bias': torch.randn(64), 'xvector.tdnn.nonlinear.batchnorm.running_mean': torch.randn(64), 'xvector.tdnn.nonlinear.batchnorm.running_var': torch.randn(64), # Dense block 0 'xvector.block1.tdnnd1.linear1.weight': torch.randn(32, 64, 3), 'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(32), 'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(32), # Transition layer 'xvector.transit1.linear.weight': torch.randn(256, 96, 1), 'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(256), 'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(256), # Final layer 'xvector.out_nonlinear.batchnorm.weight': torch.randn(512), 'xvector.out_nonlinear.batchnorm.bias': torch.randn(512), } # Convert mlx_weights, config = utils.convert_weights_to_mlx(dummy_weights) # Verify conversion verification = {} print("Conversion test results:") # Get the mapping for each source weight for name, tensor in dummy_weights.items(): mlx_name = utils._xvector_to_mlx_name(name) if mlx_name and mlx_name in mlx_weights: pytorch_shape = tensor.shape mlx_shape = mlx_weights[mlx_name].shape matches = pytorch_shape == mlx_shape verification[name] = matches status = 
"āœ…" if matches else "āŒ" print(f" {status} {name} -> {mlx_name} | Shape: {pytorch_shape} -> {mlx_shape}") else: verification[name] = False status = "āŒ" print(f" {status} {name} (no mapping)") print(f"\nTotal weights converted: {len(mlx_weights)}") print(f"Inferred config: {config}") # Check conversion status status_report = utils.check_conversion_status(dummy_weights, mlx_weights, verification) utils.print_status_report(status_report) # Only return success if status is perfect and tests pass tests_passed = all(verification.values()) return tests_passed and status_report['is_perfect'] if __name__ == "__main__": test_passed = test_conversion() print(f"\n{'āœ…' if test_passed else 'āŒ'} Conversion utilities test {'passed' if test_passed else 'failed'}")