# NOTE(review): the "Spaces:" / "Sleeping" lines here are web-page scrape
# residue (Hugging Face Spaces UI chrome), not Python source — delete them.
| import mlx.core as mx | |
| import numpy as np | |
| import torch | |
| from typing import Dict, Any, Tuple, Optional, List, Set | |
| from datetime import datetime | |
| import re | |
| import logging | |
# Set up logging; handler/level configuration is left to the host application.
logger = logging.getLogger(__name__)
# Constants for conversion thresholds
MIN_QUANTIZATION_SIZE = 1000  # Don't quantize tensors smaller than this
MIN_VERIFICATION_RATE = 95.0  # Minimum acceptable verification rate (%)
MAX_VERIFICATION_FAILURES = 2  # Maximum allowed verification failures
# NOTE(review): presumably mirrors PyTorch nn.BatchNorm1d defaults so MLX-side
# layers match the source model — confirm against the MLX model definition.
BATCHNORM_EPS = 1e-5
BATCHNORM_MOMENTUM = 0.1
| class ConversionUtils: | |
| """Utilities for converting PyTorch CAM++ models to MLX format""" | |
| def __init__(self, use_modelscope_architecture: bool = True): | |
| """ | |
| Initialize conversion utilities | |
| Args: | |
| use_modelscope_architecture: If True, use ModelScope architecture with embedded CAM | |
| If False, use original architecture with shared CAM | |
| """ | |
| self.use_modelscope_architecture = use_modelscope_architecture | |
| self.layer_mapping = { | |
| 'conv1d': self._convert_conv1d, | |
| 'linear': self._convert_linear, | |
| 'batchnorm': self._convert_batchnorm, | |
| 'embedding': self._convert_embedding | |
| } | |
| def convert_weights_to_mlx(self, pytorch_weights: Dict[str, torch.Tensor]) -> Tuple[Dict[str, mx.array], Dict[str, Any]]: | |
| """ | |
| Convert PyTorch weights to MLX format | |
| Args: | |
| pytorch_weights: Dictionary of PyTorch tensors | |
| Returns: | |
| Tuple of (mlx_weights, model_config) | |
| """ | |
| mlx_weights = {} | |
| model_config = self._analyze_model_structure(pytorch_weights) | |
| # Filter out unnecessary parameters (BatchNorm running stats, etc.) | |
| filtered_weights = self._filter_weights(pytorch_weights) | |
| # Map parameter names from PyTorch to MLX format | |
| mapped_weights = self._map_parameter_names(filtered_weights) | |
| # Add default values for missing MLX parameters | |
| mapped_weights = self._add_missing_parameters(mapped_weights, model_config) | |
| # Convert each weight tensor | |
| for name, tensor in mapped_weights.items(): | |
| if isinstance(tensor, torch.Tensor): | |
| converted = self._convert_tensor(name, tensor) | |
| # Skip None values (e.g., num_batches_tracked) | |
| if converted is not None: | |
| mlx_weights[name] = converted | |
| else: | |
| # Handle non-tensor values (e.g., integers, strings) | |
| continue | |
| return mlx_weights, model_config | |
| def _analyze_model_structure(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, Any]: | |
| """ | |
| Analyze the PyTorch model structure to infer configuration | |
| Args: | |
| pytorch_weights: PyTorch weights dictionary | |
| Returns: | |
| Model configuration dictionary | |
| """ | |
| config = { | |
| 'input_dim': 80, # Default mel spectrogram features | |
| 'embedding_dim': 192, # Default embedding dimension for ModelScope | |
| 'channels': 512, # Default number of channels | |
| 'cam_channels': 128, # Default CAM channels | |
| } | |
| # Detect block structure for ModelScope architecture | |
| if self.use_modelscope_architecture: | |
| blocks = {1: set(), 2: set(), 3: set()} | |
| for name in pytorch_weights.keys(): | |
| if 'xvector.block1.tdnnd' in name: | |
| layer_num = name.split('tdnnd')[1].split('.')[0] | |
| blocks[1].add(int(layer_num)) | |
| elif 'xvector.block2.tdnnd' in name: | |
| layer_num = name.split('tdnnd')[1].split('.')[0] | |
| blocks[2].add(int(layer_num)) | |
| elif 'xvector.block3.tdnnd' in name: | |
| layer_num = name.split('tdnnd')[1].split('.')[0] | |
| blocks[3].add(int(layer_num)) | |
| # Set block_layers configuration | |
| if any(blocks.values()): | |
| config['block_layers'] = [ | |
| len(blocks[1]) if blocks[1] else 4, # Default to 4 if not found | |
| len(blocks[2]) if blocks[2] else 9, # Default to 9 if not found | |
| len(blocks[3]) if blocks[3] else 16 # Default to 16 if not found | |
| ] | |
| logger.info(f"Detected block structure: {config['block_layers']}") | |
| # Try to infer input dimension and kernel size from first conv layer | |
| for name, tensor in pytorch_weights.items(): | |
| if 'xvector.tdnn.linear.weight' in name: | |
| if tensor.ndim == 3: # Conv1d weight: (out_channels, in_channels, kernel_size) | |
| config['input_dim'] = tensor.shape[1] # in_channels | |
| config['channels'] = tensor.shape[0] # out_channels | |
| config['input_kernel_size'] = tensor.shape[2] # kernel_size | |
| logger.info(f"Detected input layer: dim={config['input_dim']}, channels={config['channels']}, kernel_size={config['input_kernel_size']}") | |
| break | |
| # Try to infer embedding dimension from dense layer | |
| for name, tensor in pytorch_weights.items(): | |
| if 'xvector.dense.linear.weight' in name: | |
| if tensor.ndim == 3: # Conv1d with kernel_size=1 | |
| config['embedding_dim'] = tensor.shape[0] # out_channels | |
| break | |
| # Count total parameters for estimation | |
| total_params = sum(tensor.numel() for tensor in pytorch_weights.values()) | |
| config['total_params'] = total_params | |
| return config | |
| def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
| """ | |
| Map PyTorch parameter names to MLX parameter names | |
| Args: | |
| pytorch_weights: PyTorch weights with original names | |
| Returns: | |
| Weights with MLX-compatible parameter names | |
| """ | |
| mapped_weights = {} | |
| for name, tensor in pytorch_weights.items(): | |
| # Choose mapping function based on architecture | |
| if self.use_modelscope_architecture: | |
| mlx_name = self._xvector_to_mlx_modelscope_name(name) | |
| else: | |
| mlx_name = self._xvector_to_mlx_name(name) | |
| if mlx_name: # Only keep parameters that have MLX equivalents | |
| mapped_weights[mlx_name] = tensor | |
| return mapped_weights | |
| def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]: | |
| """ | |
| Add default values for MLX parameters that don't have PyTorch equivalents | |
| Args: | |
| mapped_weights: Already mapped weights | |
| model_config: Model configuration | |
| Returns: | |
| Weights with missing parameters added | |
| Note: This method intentionally does NOT add fake/random parameters. | |
| Adding untrained random weights will degrade model accuracy significantly. | |
| The conversion should only include weights that are actually mapped from | |
| the source model. Better to fail explicitly when a layer is missing than | |
| to add random weights that produce nonsensical outputs. | |
| """ | |
| # Return mapped weights as-is without adding arbitrary fake parameters | |
| return mapped_weights | |
| def get_missing_mlx_parameters(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, str]: | |
| """ | |
| Get list of MLX parameters that don't have source PyTorch equivalents | |
| Args: | |
| pytorch_weights: Original PyTorch weights | |
| mlx_weights: Converted MLX weights | |
| Returns: | |
| Dictionary mapping MLX parameter names to their source parameter names (or "NOT FOUND") | |
| """ | |
| missing_params = {} | |
| # Define expected MLX model parameters | |
| expected_mlx_params = { | |
| # Input layer | |
| 'input_conv.weight', 'input_bn.weight', 'input_bn.bias', | |
| 'input_bn.running_mean', 'input_bn.running_var', | |
| # Dense blocks (0-2) | |
| 'dense_blocks.0.layers.0.conv.weight', 'dense_blocks.0.layers.0.bn.weight', 'dense_blocks.0.layers.0.bn.bias', | |
| 'dense_blocks.0.layers.0.bn.running_mean', 'dense_blocks.0.layers.0.bn.running_var', | |
| 'dense_blocks.0.layers.1.conv.weight', 'dense_blocks.0.layers.1.bn.weight', 'dense_blocks.0.layers.1.bn.bias', | |
| 'dense_blocks.0.layers.2.conv.weight', 'dense_blocks.0.layers.2.bn.weight', 'dense_blocks.0.layers.2.bn.bias', | |
| 'dense_blocks.0.layers.3.conv.weight', 'dense_blocks.0.layers.3.bn.weight', 'dense_blocks.0.layers.3.bn.bias', | |
| # Transitions | |
| 'transitions.0.layers.0.weight', 'transitions.0.layers.0.bias', | |
| 'transitions.0.layers.0.running_mean', 'transitions.0.layers.0.running_var', | |
| 'transitions.0.layers.2.weight', | |
| 'transitions.1.layers.0.weight', 'transitions.1.layers.0.bias', | |
| 'transitions.1.layers.0.running_mean', 'transitions.1.layers.0.running_var', | |
| 'transitions.1.layers.2.weight', | |
| # CAM layer | |
| 'cam.context_conv1.weight', 'cam.context_conv1.bias', | |
| 'cam.context_conv3.weight', 'cam.context_conv3.bias', | |
| 'cam.context_conv5.weight', 'cam.context_conv5.bias', | |
| 'cam.mask_conv.weight', 'cam.mask_conv.bias', | |
| 'cam.bn.weight', 'cam.bn.bias', 'cam.bn.running_mean', 'cam.bn.running_var', | |
| # Channel gating | |
| 'channel_gating.fc.layers.0.weight', 'channel_gating.fc.layers.0.bias', | |
| 'channel_gating.fc.layers.1.weight', 'channel_gating.fc.layers.1.bias', | |
| 'channel_gating.fc.layers.2.weight', 'channel_gating.fc.layers.2.bias', | |
| # Pooling | |
| 'pooling.attention_weights.weight', 'pooling.attention_weights.bias', | |
| 'pooling.projection.weight', 'pooling.projection.bias', | |
| # Final layer | |
| 'final_bn.weight', 'final_bn.bias', 'final_bn.running_mean', 'final_bn.running_var', | |
| } | |
| # Check which expected parameters are missing from converted weights | |
| for param in expected_mlx_params: | |
| if param not in mlx_weights: | |
| missing_params[param] = "NOT FOUND" | |
| return missing_params | |
| def _xvector_to_mlx_modelscope_name(self, xvector_name: str) -> Optional[str]: | |
| """ | |
| Convert xvector parameter name to MLX ModelScope architecture parameter name | |
| This mapping is for ModelScope CAM++ models where CAM is embedded in each TDNN layer. | |
| Architecture: | |
| - Input layer (TDNN) | |
| - Block 1: 4 TDNN layers with embedded CAM | |
| - Transit 1 | |
| - Block 2: 9 TDNN layers with embedded CAM | |
| - Transit 2 | |
| - Block 3: 16 TDNN layers with embedded CAM | |
| - Dense layer (Conv1d kernel_size=1) | |
| Args: | |
| xvector_name: Original xvector parameter name from PyTorch model | |
| Returns: | |
| MLX-compatible parameter name, or None if parameter should be skipped | |
| """ | |
| # ========== INPUT LAYER ========== | |
| if xvector_name == 'xvector.tdnn.linear.weight': | |
| return 'input_conv.weight' | |
| elif 'xvector.tdnn.nonlinear.batchnorm' in xvector_name: | |
| param_type = xvector_name.split('.')[-1] # bias, weight, running_mean, running_var | |
| # Skip num_batches_tracked (PyTorch tracking statistic, not needed) | |
| if param_type == 'num_batches_tracked': | |
| return None | |
| return f'input_bn.{param_type}' | |
| # ========== DENSE BLOCKS WITH EMBEDDED CAM ========== | |
| # Extract block number and layer number | |
| import re | |
| block_match = re.match(r'xvector\.block(\d+)\.tdnnd(\d+)\.(.*)', xvector_name) | |
| if block_match: | |
| block_num = int(block_match.group(1)) # 1, 2, or 3 | |
| layer_num = int(block_match.group(2)) # 1-indexed | |
| param_path = block_match.group(3) | |
| # Map to MLX block index (0, 1, 2) | |
| mlx_block_idx = block_num - 1 | |
| # Map to MLX layer index (0-indexed) | |
| mlx_layer_idx = layer_num - 1 | |
| # Main TDNN layer parameters | |
| if param_path.startswith('linear1.'): | |
| param_type = param_path.split('.')[-1] | |
| return f'block{mlx_block_idx}_{mlx_layer_idx}.conv.{param_type}' | |
| # PyTorch has TWO batch norms per layer: | |
| # - nonlinear1.batchnorm: sized for INPUT channels (applied before conv) | |
| # - nonlinear2.batchnorm: sized for OUTPUT channels (applied after conv) | |
| # MLX model only has one BN (after conv), so map nonlinear2 to bn | |
| elif param_path.startswith('nonlinear1.batchnorm.'): | |
| # Skip nonlinear1 batch norm - it's sized for input channels | |
| return None | |
| elif param_path.startswith('nonlinear2.batchnorm.'): | |
| param_type = param_path.split('.')[-1] | |
| # Skip num_batches_tracked | |
| if param_type == 'num_batches_tracked': | |
| return None | |
| return f'block{mlx_block_idx}_{mlx_layer_idx}.bn.{param_type}' | |
| # Embedded CAM layer parameters | |
| elif param_path.startswith('cam_layer.linear1.'): | |
| param_type = param_path.split('.')[-1] | |
| return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear1.{param_type}' | |
| elif param_path.startswith('cam_layer.linear2.'): | |
| param_type = param_path.split('.')[-1] | |
| return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear2.{param_type}' | |
| elif param_path.startswith('cam_layer.linear_local.'): | |
| param_type = param_path.split('.')[-1] | |
| return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear_local.{param_type}' | |
| # ========== TRANSITION LAYERS ========== | |
| if 'xvector.transit1.' in xvector_name: | |
| if '.linear.weight' in xvector_name: | |
| return 'transit1.conv.weight' | |
| elif 'nonlinear.batchnorm' in xvector_name: | |
| param_type = xvector_name.split('.')[-1] | |
| # Skip num_batches_tracked | |
| if param_type == 'num_batches_tracked': | |
| return None | |
| return f'transit1.bn.{param_type}' | |
| if 'xvector.transit2.' in xvector_name: | |
| if '.linear.weight' in xvector_name: | |
| return 'transit2.conv.weight' | |
| elif 'nonlinear.batchnorm' in xvector_name: | |
| param_type = xvector_name.split('.')[-1] | |
| # Skip num_batches_tracked | |
| if param_type == 'num_batches_tracked': | |
| return None | |
| return f'transit2.bn.{param_type}' | |
| # ========== DENSE LAYER ========== | |
| if 'xvector.dense.linear.' in xvector_name: | |
| param_type = xvector_name.split('.')[-1] | |
| return f'dense.{param_type}' | |
| # ========== SKIP UNMAPPED PARAMETERS ========== | |
| # These don't exist in ModelScope architecture | |
| if any(x in xvector_name for x in ['head.', 'output.', 'pool', 'final_bn']): | |
| logger.debug(f"Skipping parameter not in ModelScope architecture: {xvector_name}") | |
| return None | |
| # Log unexpected parameters | |
| if xvector_name.startswith('xvector.'): | |
| logger.debug(f"Skipping unmapped parameter: {xvector_name}") | |
| return None | |
    def _xvector_to_mlx_name(self, xvector_name: str) -> Optional[str]:
        """
        Convert xvector parameter name to MLX parameter name with comprehensive mapping

        This method maps PyTorch CAM++ xvector parameters to MLX CAMPPModel parameters.
        It handles:
        - Input layer (TDNN)
        - Dense blocks (3 blocks with 4, 6, 8 layers respectively)
        - Transition layers between blocks
        - Context-Aware Masking (CAM) layer
        - Channel gating mechanism
        - Multi-granularity pooling
        - Final batch normalization

        NOTE(review): branch ORDER is load-bearing — most checks are substring
        matches, so earlier, more specific branches deliberately shadow later
        ones. Do not reorder without regression tests.

        Args:
            xvector_name: Original xvector parameter name from PyTorch model

        Returns:
            MLX-compatible parameter name, or None if parameter should be skipped
        """
        # ========== INPUT LAYER MAPPING ==========
        if xvector_name == 'xvector.tdnn.linear.weight':
            return 'input_conv.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
            return 'input_bn.bias'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
            return 'input_bn.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
            return 'input_bn.running_mean'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
            return 'input_bn.running_var'
        # ========== DENSE BLOCKS MAPPING ==========
        # MLX architecture: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
        # Map PyTorch block1/block2/block3 to MLX dense_blocks.0/1/2
        # The trailing '.' in f'...tdnnd{i}.' prevents 'tdnnd1.' from matching
        # 'tdnnd12.' and similar prefixes.
        for i in range(1, 13):  # Handle up to 12 layers (generous for real models)
            # Block 0 - first 4 layers of PyTorch block1
            if i <= 4 and f'xvector.block1.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'
            # Block 1 - first 6 layers of PyTorch block2
            # Skip block2.tdnnd1 and block2.tdnnd2 as they may be used for transition
            if i >= 3 and i <= 8 and f'xvector.block2.tdnnd{i}.' in xvector_name:
                layer_idx = i - 3  # Map block2.tdnnd3 -> layer 0, etc.
                if layer_idx < 6:  # Only map first 6 layers
                    if '.linear1.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
                    elif '.nonlinear1.batchnorm.bias' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
                    elif '.nonlinear1.batchnorm.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
                    elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
                    elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'
            # Block 2 - first 8 layers of PyTorch block3
            if i <= 8 and f'xvector.block3.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'
        # ========== TRANSITION LAYERS MAPPING ==========
        # Transition 0: After block 0
        if 'xvector.transit1.' in xvector_name:
            if '.linear.weight' in xvector_name:
                return 'transitions.0.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name:
                return 'transitions.0.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name:
                return 'transitions.0.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name:
                return 'transitions.0.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name:
                return 'transitions.0.layers.0.running_var'
        # Transition 1: Use block2.tdnnd1 and tdnnd2 (before dense block 1)
        if 'xvector.transit2.' in xvector_name or 'xvector.block2.tdnnd1.' in xvector_name:
            # Map transit2 or beginning of block2 to transition 1
            if '.linear.weight' in xvector_name or 'xvector.block2.tdnnd2.linear1.weight' in xvector_name:
                return 'transitions.1.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.bias' in xvector_name:
                return 'transitions.1.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight' in xvector_name:
                return 'transitions.1.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean' in xvector_name:
                return 'transitions.1.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var' in xvector_name:
                return 'transitions.1.layers.0.running_var'
        # ========== CAM LAYER MAPPING ==========
        # Context-aware masking with multi-scale convolutions
        # NOTE: Real ModelScope models have CAM embedded in EACH TDNN layer,
        # but MLX model has ONE shared CAM layer. We map only the first occurrence
        # from block1.tdnnd1.cam_layer and skip all others.
        if 'cam_layer' in xvector_name or 'cam.' in xvector_name:
            # Only map CAM from the first block's first layer
            # Skip CAM from all other layers to avoid conflicts
            is_first_cam = 'block1.tdnnd1.cam_layer' in xvector_name
            if not is_first_cam:
                logger.debug(f"Skipping embedded CAM layer (only using first occurrence): {xvector_name}")
                return None
            # Map first CAM layer to MLX shared CAM
            # ModelScope structure: linear1 (1x1 conv), linear2 (1x1 conv), linear_local (3x3 conv)
            # MLX structure: context_conv1 (1x1), context_conv3 (3x3), context_conv5 (5x5)
            if 'cam_layer.linear1.weight' in xvector_name:
                return 'cam.context_conv1.weight'
            elif 'cam_layer.linear1.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv1 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv1 has bias=False
            elif 'cam_layer.linear2.weight' in xvector_name:
                # Map linear2 (1x1) to context_conv3 - note: this is a compromise
                # Real model has 1x1 conv here, MLX expects 3x3
                logger.warning(f"Mapping 1x1 conv to context_conv3 (shape mismatch possible): {xvector_name}")
                return 'cam.context_conv3.weight'
            elif 'cam_layer.linear2.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv3 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv3 has bias=False
            elif 'cam_layer.linear_local.weight' in xvector_name:
                # Map linear_local (3x3) to mask_conv
                return 'cam.mask_conv.weight'
            elif 'cam_layer.linear_local.bias' in xvector_name:
                # linear_local typically has no bias in ModelScope models
                logger.debug(f"Skipping CAM mask_conv bias: {xvector_name}")
                return None
            # Handle standalone cam. parameters (if model has separate CAM layer)
            # NOTE(review): these standalone branches look unreachable — any name
            # that lacks 'block1.tdnnd1.cam_layer' has already returned None via
            # the is_first_cam guard above, and first-CAM names always contain
            # 'cam_layer'. Confirm before relying on standalone-CAM mapping.
            elif 'context1.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name):
                return 'cam.context_conv1.weight'
            elif 'context3.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name):
                return 'cam.context_conv3.weight'
            elif 'context5.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear3.weight' in xvector_name):
                return 'cam.context_conv5.weight'
            elif 'mask_conv.weight' in xvector_name:
                return 'cam.mask_conv.weight'
            elif 'fusion.weight' in xvector_name:
                return 'cam.fusion.weight'
            # Batch normalization (of the first embedded CAM layer)
            elif 'batchnorm.weight' in xvector_name:
                return 'cam.bn.weight'
            elif 'batchnorm.bias' in xvector_name:
                return 'cam.bn.bias'
            elif 'running_mean' in xvector_name:
                return 'cam.bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'cam.bn.running_var'
        # ========== CHANNEL GATING MAPPING ==========
        # Channel-wise context gating (squeeze-excitation style)
        # NOTE: Real ModelScope models only have xvector.dense.linear (single layer)
        # MLX model expects 3-layer FC, but real model has only 1 layer
        if 'xvector.dense.' in xvector_name:
            if '.linear.weight' in xvector_name or 'xvector.dense.linear.weight' == xvector_name:
                # Map to first layer - this is the only dense layer in real model
                return 'channel_gating.fc.layers.0.weight'
            elif '.linear.bias' in xvector_name or 'xvector.dense.linear.bias' == xvector_name:
                # Check if bias exists (some models use Conv1d without bias)
                logger.debug(f"Mapping dense bias (may not exist in Conv1d): {xvector_name}")
                return 'channel_gating.fc.layers.0.bias'
            # The following layers don't exist in real ModelScope models
            elif 'linear_mid.weight' in xvector_name:
                logger.warning(f"Found linear_mid layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.1.weight'
            elif 'linear_mid.bias' in xvector_name:
                return 'channel_gating.fc.layers.1.bias'
            elif 'linear_out.weight' in xvector_name:
                logger.warning(f"Found linear_out layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.2.weight'
            elif 'linear_out.bias' in xvector_name:
                return 'channel_gating.fc.layers.2.bias'
        # ========== POOLING LAYER MAPPING ==========
        # Multi-granularity statistical pooling
        # NOTE: Real ModelScope models typically DON'T have xvector.output or pooling layers
        # These models are feature extractors that end at xvector.dense
        if 'xvector.output.' in xvector_name or 'xvector.pool' in xvector_name:
            logger.warning(f"Found pooling/output layer (rare in ModelScope models): {xvector_name}")
            if 'xvector.output.linear.weight' == xvector_name:
                return 'pooling.attention_weights.weight'
            elif 'xvector.output.linear.bias' == xvector_name:
                return 'pooling.attention_weights.bias'
            elif 'pool_output.linear.weight' in xvector_name or 'pooling.linear.weight' in xvector_name:
                return 'pooling.projection.weight'
            elif 'pool_output.linear.bias' in xvector_name or 'pooling.linear.bias' in xvector_name:
                return 'pooling.projection.bias'
        # ========== FINAL BATCH NORMALIZATION ==========
        if 'xvector.out_nonlinear.batchnorm.' in xvector_name or 'xvector.final_bn.' in xvector_name:
            if '.bias' in xvector_name:
                return 'final_bn.bias'
            elif '.weight' in xvector_name:
                return 'final_bn.weight'
            elif 'running_mean' in xvector_name:
                return 'final_bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'final_bn.running_var'
        # ========== SKIP UNMAPPED PARAMETERS ==========
        # Log skipped parameters for debugging
        if xvector_name.startswith('xvector.'):
            logger.debug(f"Skipping unmapped parameter: {xvector_name}")
        return None
| def _filter_weights(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
| """ | |
| Filter out unnecessary parameters that shouldn't be converted to MLX | |
| Args: | |
| pytorch_weights: Original PyTorch weights dict | |
| Returns: | |
| Filtered weights dict | |
| """ | |
| filtered_weights = {} | |
| skipped_params = [] | |
| for name, tensor in pytorch_weights.items(): | |
| # Skip classification head parameters (not needed for inference) | |
| if name.startswith('head.'): | |
| skipped_params.append(name) | |
| continue | |
| # Keep all other parameters including BatchNorm running statistics | |
| # The mapping function will filter out parameters that don't have MLX equivalents | |
| filtered_weights[name] = tensor | |
| if skipped_params: | |
| print(f"Filtered out {len(skipped_params)} unnecessary parameters: {skipped_params[:5]}{'...' if len(skipped_params) > 5 else ''}") | |
| return filtered_weights | |
| def _convert_tensor(self, name: str, tensor: torch.Tensor) -> Optional[mx.array]: | |
| """ | |
| Convert individual tensor based on layer type and shape | |
| Args: | |
| name: Parameter name | |
| tensor: PyTorch tensor to convert | |
| Returns: | |
| MLX array, or None if parameter should be skipped (e.g., num_batches_tracked) | |
| """ | |
| # Convert to numpy first | |
| numpy_tensor = tensor.detach().cpu().numpy() | |
| # Biases don't need any conversion, just pass through | |
| if name.endswith('.bias'): | |
| return mx.array(numpy_tensor) | |
| # Determine layer type from name AND shape | |
| layer_type = self._identify_layer_type(name) | |
| # Override layer type based on actual tensor shape | |
| # This handles cases where Conv1d(kernel_size=1) is used but named like Linear | |
| if numpy_tensor.ndim == 3: | |
| # 3D tensor must be Conv1d, regardless of name | |
| layer_type = 'conv1d' | |
| elif numpy_tensor.ndim == 2 and layer_type == 'conv1d': | |
| # 2D tensor can't be Conv1d, must be Linear | |
| layer_type = 'linear' | |
| elif numpy_tensor.ndim == 1: | |
| # 1D tensor is likely BatchNorm or bias | |
| if 'bn' in name.lower() or 'batchnorm' in name.lower() or 'running' in name.lower(): | |
| layer_type = 'batchnorm' | |
| # Apply layer-specific transformations | |
| if layer_type in self.layer_mapping: | |
| numpy_tensor = self.layer_mapping[layer_type](name, numpy_tensor) | |
| # Handle None returns (e.g., num_batches_tracked) | |
| if numpy_tensor is None: | |
| return None | |
| # Convert to MLX array | |
| return mx.array(numpy_tensor) | |
| def _identify_layer_type(self, name: str) -> str: | |
| """Identify layer type from parameter name""" | |
| name_lower = name.lower() | |
| # BatchNorm check first (more specific) | |
| if 'bn' in name_lower or 'batchnorm' in name_lower or 'batch_norm' in name_lower: | |
| return 'batchnorm' | |
| # Conv1d check (including 'conv' in name) | |
| elif 'conv1d' in name_lower or 'conv' in name_lower: | |
| return 'conv1d' | |
| # Linear/FC check | |
| elif 'linear' in name_lower or 'fc' in name_lower or 'dense' in name_lower: | |
| return 'linear' | |
| # Embedding check | |
| elif 'embed' in name_lower: | |
| return 'embedding' | |
| else: | |
| return 'default' | |
| def _convert_conv1d(self, name: str, weight: np.ndarray) -> np.ndarray: | |
| """ | |
| Convert Conv1d weights from PyTorch to MLX format | |
| PyTorch Conv1d: (out_channels, in_channels, kernel_size) | |
| MLX Conv1d: (out_channels, kernel_size, in_channels) - DIFFERENT format! | |
| Special case: Conv1d with kernel_size=1 can be used as Linear layer | |
| Args: | |
| name: Parameter name (for error reporting) | |
| weight: Weight tensor as numpy array | |
| Returns: | |
| Converted weight tensor | |
| Raises: | |
| ValueError: If weight shape is invalid for Conv1d | |
| """ | |
| # Validate Conv1d weight shape | |
| if weight.ndim != 3: | |
| raise ValueError(f"Conv1d weight {name} must be 3D, got shape {weight.shape}") | |
| out_channels, in_channels, kernel_size = weight.shape | |
| # Validate kernel size is reasonable (1, 3, 5 are common) | |
| if kernel_size > 11: | |
| logger.warning(f"Unusual kernel size {kernel_size} for Conv1d {name}") | |
| # MLX Conv1d uses (out_channels, kernel_size, in_channels) format | |
| # Transpose from PyTorch's (out_channels, in_channels, kernel_size) | |
| # This applies to ALL kernel sizes, including kernel_size=1 | |
| mlx_weight = weight.transpose(0, 2, 1) | |
| logger.debug(f"Transposed Conv1d weight {name}: {weight.shape} -> {mlx_weight.shape}") | |
| return mlx_weight | |
| def _convert_linear(self, name: str, weight: np.ndarray) -> np.ndarray: | |
| """ | |
| Convert Linear layer weights | |
| PyTorch Linear: (out_features, in_features) | |
| MLX Linear: (out_features, in_features) - same format | |
| Args: | |
| name: Parameter name (for error reporting) | |
| weight: Weight tensor as numpy array | |
| Returns: | |
| Converted weight tensor | |
| Raises: | |
| ValueError: If weight shape is invalid for Linear | |
| """ | |
| if weight.ndim != 2: | |
| raise ValueError(f"Linear weight {name} must be 2D, got shape {weight.shape}") | |
| return weight # No change needed for linear layers | |
| def _convert_batchnorm(self, name: str, weight: np.ndarray) -> Optional[np.ndarray]: | |
| """ | |
| Convert BatchNorm parameters | |
| Args: | |
| name: Parameter name (for error reporting) | |
| weight: Weight/bias/running_mean/running_var tensor | |
| Returns: | |
| Converted tensor, or None if parameter should be skipped (e.g., num_batches_tracked) | |
| Raises: | |
| ValueError: If tensor shape is invalid for BatchNorm | |
| """ | |
| # Skip num_batches_tracked (it's a scalar tracking statistic, not needed in MLX) | |
| if 'num_batches_tracked' in name: | |
| logger.debug(f"Skipping num_batches_tracked (not needed in MLX): {name}") | |
| return None # Will be filtered out | |
| # BatchNorm parameters should be 1D vectors | |
| if weight.ndim != 1: | |
| raise ValueError(f"BatchNorm parameter {name} must be 1D, got shape {weight.shape}") | |
| # Check for NaN/Inf in running statistics | |
| if 'running_mean' in name or 'running_var' in name: | |
| if np.isnan(weight).any(): | |
| logger.warning(f"BatchNorm {name} contains NaN values - may indicate untrained model") | |
| if np.isinf(weight).any(): | |
| logger.warning(f"BatchNorm {name} contains Inf values - may indicate numerical instability") | |
| return weight | |
| def _convert_embedding(self, name: str, weight: np.ndarray) -> np.ndarray: | |
| """ | |
| Convert Embedding layer weights | |
| Args: | |
| name: Parameter name (unused but kept for API consistency) | |
| weight: Embedding weight tensor | |
| Returns: | |
| Converted weight tensor | |
| """ | |
| return weight # No change needed for embeddings | |
| def quantize_weights(self, weights: Dict[str, mx.array], | |
| bits: int = 4, group_size: int = 64) -> Dict[str, mx.array]: | |
| """ | |
| Quantize weights to reduce model size using MLX's built-in quantization | |
| Note: This creates a copy of the weights dictionary, so callers don't need to copy before calling. | |
| Args: | |
| weights: MLX weights dictionary | |
| bits: Number of bits for quantization (2, 4, or 8) | |
| group_size: Group size for quantization (32, 64, or 128) | |
| Returns: | |
| Quantized weights dictionary (new copy) | |
| """ | |
| # Create a new dictionary to avoid modifying the original | |
| quantized_weights = {} | |
| skipped_count = 0 | |
| quantized_count = 0 | |
| logger.info(f"Starting {bits}-bit quantization with group_size={group_size}...") | |
| for name, weight in weights.items(): | |
| if self._should_quantize(name, weight): | |
| try: | |
| # MLX quantization requires: | |
| # 1. At least 2D tensors | |
| # 2. Last dimension divisible by group_size | |
| if len(weight.shape) < 2: | |
| logger.debug(f"Skipping {name}: 1D tensor") | |
| quantized_weights[name] = weight | |
| skipped_count += 1 | |
| continue | |
| if weight.shape[-1] % group_size != 0: | |
| logger.debug(f"Skipping {name}: last dim {weight.shape[-1]} not divisible by {group_size}") | |
| quantized_weights[name] = weight | |
| skipped_count += 1 | |
| continue | |
| # Quantize using MLX's affine quantization | |
| w_q, scales, biases = mx.quantize(weight, group_size=group_size, bits=bits) | |
| # Store quantized weights with special naming for scales and biases | |
| # Format: name:qSCALES_GS64_B4 (scales for group_size=64, bits=4) | |
| # This reduces the number of keys compared to separate metadata arrays | |
| quantized_weights[name] = w_q | |
| quantized_weights[f"{name}:qSCALES_GS{group_size}_B{bits}"] = scales | |
| quantized_weights[f"{name}:qBIASES_GS{group_size}_B{bits}"] = biases | |
| quantized_count += 1 | |
| # Log size reduction | |
| original_size = weight.size * 4 # float32 = 4 bytes | |
| # Quantized size = packed weights + scales + biases | |
| quantized_size = w_q.nbytes + scales.nbytes + biases.nbytes | |
| reduction = (1 - quantized_size / original_size) * 100 | |
| logger.debug(f"Quantized {name}: {reduction:.1f}% size reduction ({original_size//1024}KB β {quantized_size//1024}KB)") | |
| except Exception as e: | |
| # If quantization fails for this weight, keep original | |
| logger.warning(f"Failed to quantize {name}: {e}, keeping original") | |
| quantized_weights[name] = weight | |
| skipped_count += 1 | |
| else: | |
| # Keep small weights in full precision | |
| quantized_weights[name] = weight | |
| skipped_count += 1 | |
| logger.info(f"Quantization complete: {quantized_count} weights quantized, {skipped_count} kept in full precision") | |
| return quantized_weights | |
| def _quantize_to_int8(self, weight: mx.array) -> mx.array: | |
| """ | |
| Quantize a weight tensor to 8-bit precision | |
| Args: | |
| weight: Weight tensor to quantize | |
| Returns: | |
| Quantized weight tensor | |
| """ | |
| # Simple symmetric quantization to int8 range | |
| # Find scale factor | |
| abs_max = mx.max(mx.abs(weight)) | |
| scale = abs_max / 127.0 | |
| if scale == 0: | |
| return weight | |
| # Quantize and dequantize | |
| quantized = mx.round(weight / scale) | |
| quantized = mx.clip(quantized, -127, 127) | |
| dequantized = quantized * scale | |
| return dequantized.astype(mx.float32) | |
| def _should_quantize(self, name: str, weight: mx.array) -> bool: | |
| """Determine if a weight should be quantized""" | |
| # Don't quantize very small tensors or bias terms | |
| if weight.size < MIN_QUANTIZATION_SIZE: | |
| return False | |
| # Don't quantize bias terms | |
| if 'bias' in name.lower(): | |
| return False | |
| # Don't quantize batchnorm parameters (weight, bias, running_mean, running_var) | |
| if any(bn_key in name.lower() for bn_key in ['bn', 'batchnorm', 'batch_norm', 'running_mean', 'running_var']): | |
| return False | |
| # Quantize large weight matrices (Conv, Linear) | |
| return True | |
| def verify_conversion(self, pytorch_weights: Dict[str, torch.Tensor], | |
| mlx_weights: Dict[str, mx.array]) -> Dict[str, bool]: | |
| """ | |
| Verify that conversion was successful by comparing shapes and values | |
| Args: | |
| pytorch_weights: Original PyTorch weights | |
| mlx_weights: Converted MLX weights | |
| Returns: | |
| Dictionary of verification results | |
| """ | |
| results = {} | |
| for name in pytorch_weights.keys(): | |
| if name in mlx_weights: | |
| pytorch_tensor = pytorch_weights[name] | |
| mlx_array = mlx_weights[name] | |
| # Compare basic properties | |
| pytorch_shape = pytorch_tensor.shape | |
| mlx_shape = mlx_array.shape | |
| # All layers should have matching shapes (no transpose) | |
| results[name] = pytorch_shape == mlx_shape | |
| # Additional verification: check if values are reasonable | |
| if results[name]: | |
| pytorch_values = pytorch_tensor.detach().cpu().numpy() | |
| mlx_values = np.array(mlx_array) | |
| # Check if the values are approximately equal | |
| value_check = np.allclose(pytorch_values, mlx_values, rtol=1e-5, atol=1e-6) | |
| results[name] = results[name] and value_check | |
| else: | |
| results[name] = False | |
| return results | |
| def check_conversion_status(self, pytorch_weights: Dict[str, torch.Tensor], | |
| mlx_weights: Dict[str, mx.array], | |
| verification_results: Dict[str, bool]) -> Dict[str, Any]: | |
| """ | |
| Check comprehensive status of conversion to ensure it's safe to deploy | |
| Args: | |
| pytorch_weights: Original PyTorch weights | |
| mlx_weights: Converted MLX weights | |
| verification_results: Results from verify_conversion | |
| Returns: | |
| Status dictionary with detailed report | |
| """ | |
| status = { | |
| 'is_perfect': False, | |
| 'total_source_weights': len(pytorch_weights), | |
| 'total_converted_weights': len(mlx_weights), | |
| 'verification_passed': sum(1 for v in verification_results.values() if v), | |
| 'verification_failed': sum(1 for v in verification_results.values() if not v), | |
| 'verification_rate': 0.0, | |
| 'errors': [], | |
| 'warnings': [], | |
| 'safe_to_deploy': False, | |
| } | |
| # Calculate verification rate | |
| total_verified = len(verification_results) | |
| if total_verified > 0: | |
| status['verification_rate'] = (status['verification_passed'] / total_verified) * 100 | |
| # Check for critical issues | |
| if len(mlx_weights) == 0: | |
| status['errors'].append("No weights were converted - conversion failed completely") | |
| if status['verification_failed'] > 0: | |
| failed_weights = [name for name, result in verification_results.items() if not result] | |
| status['errors'].append( | |
| f"{status['verification_failed']} weight(s) failed verification: {failed_weights[:3]}" | |
| f"{'...' if len(failed_weights) > 3 else ''}" | |
| ) | |
| if len(mlx_weights) < len(pytorch_weights) * 0.5: | |
| status['warnings'].append( | |
| f"Only {len(mlx_weights)}/{len(pytorch_weights)} weights were converted " | |
| f"({(len(mlx_weights)/len(pytorch_weights)*100):.1f}%) - possible mapping issues" | |
| ) | |
| # Check data type consistency | |
| dtype_set = set() | |
| for weight in mlx_weights.values(): | |
| dtype_set.add(str(weight.dtype)) | |
| if len(dtype_set) > 1: | |
| status['warnings'].append(f"Mixed data types detected in converted weights: {dtype_set}") | |
| # Check for NaN or Inf values | |
| nan_inf_weights = [] | |
| for name, weight in mlx_weights.items(): | |
| weight_np = np.array(weight) | |
| if np.isnan(weight_np).any(): | |
| nan_inf_weights.append(f"{name} (NaN)") | |
| elif np.isinf(weight_np).any(): | |
| nan_inf_weights.append(f"{name} (Inf)") | |
| if nan_inf_weights: | |
| status['errors'].append(f"Weights contain NaN/Inf: {nan_inf_weights[:3]}") | |
| # Determine if safe to deploy | |
| status['is_perfect'] = ( | |
| len(status['errors']) == 0 and | |
| status['verification_rate'] == 100.0 and | |
| len(mlx_weights) > 0 | |
| ) | |
| # Conservative approach: only deploy if perfect | |
| status['safe_to_deploy'] = status['is_perfect'] | |
| if not status['safe_to_deploy'] and len(status['errors']) == 0: | |
| status['safe_to_deploy'] = ( | |
| status['verification_rate'] >= MIN_VERIFICATION_RATE and | |
| status['verification_failed'] <= MAX_VERIFICATION_FAILURES and | |
| len(nan_inf_weights) == 0 | |
| ) | |
| return status | |
| def print_status_report(self, status: Dict[str, Any]) -> None: | |
| """Print a formatted status report""" | |
| print("\n" + "="*70) | |
| print("CONVERSION STATUS REPORT") | |
| print("="*70) | |
| print(f"\nπ Conversion Statistics:") | |
| print(f" Total source weights: {status['total_source_weights']}") | |
| print(f" Total converted weights: {status['total_converted_weights']}") | |
| print(f" Verification passed: {status['verification_passed']}/{status['verification_passed'] + status['verification_failed']}") | |
| print(f" Verification rate: {status['verification_rate']:.1f}%") | |
| if status['errors']: | |
| print(f"\nβ Errors ({len(status['errors'])}):") | |
| for error in status['errors']: | |
| print(f" β’ {error}") | |
| if status['warnings']: | |
| print(f"\nβ οΈ Warnings ({len(status['warnings'])}):") | |
| for warning in status['warnings']: | |
| print(f" β’ {warning}") | |
| print(f"\nπ Deployment Decision:") | |
| if status['is_perfect']: | |
| print(f" Status: β PERFECT - All checks passed") | |
| else: | |
| print(f" Status: {'β ACCEPTABLE' if status['safe_to_deploy'] else 'β NOT SAFE'}") | |
| print(f" Safe to deploy: {'β YES' if status['safe_to_deploy'] else 'β NO'}") | |
| print("\n" + "="*70) | |
| def create_model_metadata(self, original_repo: str, config: Dict[str, Any]) -> Dict[str, Any]: | |
| """Create metadata for the converted model""" | |
| return { | |
| "converted_from": original_repo, | |
| "conversion_date": self.get_current_date(), | |
| "framework": "mlx", | |
| "model_type": "campp", | |
| "architecture": "d-tdnn", | |
| "license": "apache-2.0", | |
| "tags": ["speaker-recognition", "audio", "mlx", "apple-silicon"], | |
| "task": "speaker-verification", | |
| "library_name": "mlx", | |
| "datasets": ["voxceleb", "cnceleb"], | |
| "metrics": { | |
| "voxceleb1_eer": "0.65%", | |
| "parameters": "7.2M", | |
| "inference_speed": "optimized_for_apple_silicon" | |
| }, | |
| **config | |
| } | |
| def get_current_date(self) -> str: | |
| """Get current date in ISO format""" | |
| return datetime.now().isoformat() | |
| def estimate_model_performance(self, weights: Dict[str, mx.array]) -> Dict[str, Any]: | |
| """Estimate model performance characteristics""" | |
| total_params = sum(w.size for w in weights.values()) | |
| # Estimate memory usage (rough approximation) | |
| total_bytes = total_params * 4 # Assuming fp32 | |
| memory_mb = total_bytes / (1024 * 1024) | |
| # Estimate model complexity | |
| conv_layers = sum(1 for name in weights.keys() if 'conv' in name.lower()) | |
| linear_layers = sum(1 for name in weights.keys() if any(x in name.lower() for x in ['linear', 'fc'])) | |
| return { | |
| "total_parameters": total_params, | |
| "estimated_memory_mb": memory_mb, | |
| "conv_layers": conv_layers, | |
| "linear_layers": linear_layers, | |
| "model_complexity": "efficient" if total_params < 10e6 else "standard" | |
| } | |
| def optimize_for_inference(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: | |
| """Apply MLX-specific optimizations for inference""" | |
| optimized_weights = {} | |
| for name, weight in weights.items(): | |
| # Ensure weights are in optimal format for MLX | |
| optimized_weight = mx.array(weight) | |
| # MLX-specific optimizations could go here | |
| # For now, just ensure proper data type | |
| if optimized_weight.dtype != mx.float32: | |
| optimized_weight = optimized_weight.astype(mx.float32) | |
| optimized_weights[name] = optimized_weight | |
| return optimized_weights | |
def test_conversion():
    """Test the conversion utilities with comprehensive status checking"""
    utils = ConversionUtils()
    # Dummy PyTorch xvector weights in the proper source format
    dummy_weights = {
        # Input layer
        'xvector.tdnn.linear.weight': torch.randn(64, 80, 3),
        'xvector.tdnn.nonlinear.batchnorm.weight': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.bias': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_mean': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_var': torch.randn(64),
        # Dense block 0
        'xvector.block1.tdnnd1.linear1.weight': torch.randn(32, 64, 3),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(32),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(32),
        # Transition layer
        'xvector.transit1.linear.weight': torch.randn(256, 96, 1),
        'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(256),
        'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(256),
        # Final layer
        'xvector.out_nonlinear.batchnorm.weight': torch.randn(512),
        'xvector.out_nonlinear.batchnorm.bias': torch.randn(512),
    }
    # Run the conversion
    mlx_weights, config = utils.convert_weights_to_mlx(dummy_weights)
    # Check each source weight against its mapped MLX counterpart
    verification = {}
    print("Conversion test results:")
    for name, tensor in dummy_weights.items():
        mlx_name = utils._xvector_to_mlx_name(name)
        if mlx_name and mlx_name in mlx_weights:
            source_shape = tensor.shape
            target_shape = mlx_weights[mlx_name].shape
            matches = source_shape == target_shape
            verification[name] = matches
            marker = "β " if matches else "β"
            print(f"  {marker} {name} -> {mlx_name} | Shape: {source_shape} -> {target_shape}")
        else:
            verification[name] = False
            print(f"  β {name} (no mapping)")
    print(f"\nTotal weights converted: {len(mlx_weights)}")
    print(f"Inferred config: {config}")
    # Summarize with the full status report
    status_report = utils.check_conversion_status(dummy_weights, mlx_weights, verification)
    utils.print_status_report(status_report)
    # Success requires both per-weight matches and a perfect overall status
    return all(verification.values()) and status_report['is_perfect']
if __name__ == "__main__":
    # Run the self-test and report a one-line verdict
    test_passed = test_conversion()
    outcome = "passed" if test_passed else "failed"
    icon = "β " if test_passed else "β"
    print(f"\n{icon} Conversion utilities test {outcome}")