Spaces:
Sleeping
Sleeping
File size: 53,428 Bytes
368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f ca0ebee 368806f ca0ebee 368806f 656e7f6 368806f 656e7f6 368806f 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 656e7f6 77877c8 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 77877c8 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 77877c8 656e7f6 77877c8 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 ca0ebee 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 656e7f6 368806f 77877c8 368806f 77877c8 368806f 17bd838 656e7f6 17bd838 656e7f6 17bd838 656e7f6 17bd838 368806f 17bd838 368806f 77877c8 368806f 77877c8 368806f 17bd838 77877c8 368806f 77877c8 368806f 17bd838 368806f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 
735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 | import mlx.core as mx
import numpy as np
import torch
from typing import Dict, Any, Tuple, Optional, List, Set
from datetime import datetime
import re
import logging
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
# Constants for conversion thresholds
MIN_QUANTIZATION_SIZE = 1000  # Don't quantize tensors smaller than this
MIN_VERIFICATION_RATE = 95.0  # Minimum acceptable verification rate (%)
MAX_VERIFICATION_FAILURES = 2  # Maximum allowed verification failures
BATCHNORM_EPS = 1e-5  # BatchNorm epsilon (matches the PyTorch BatchNorm1d default)
BATCHNORM_MOMENTUM = 0.1  # BatchNorm momentum (matches the PyTorch BatchNorm1d default)
class ConversionUtils:
"""Utilities for converting PyTorch CAM++ models to MLX format"""
def __init__(self, use_modelscope_architecture: bool = True):
    """Set up the conversion utilities.

    Args:
        use_modelscope_architecture: When True, map weights for the ModelScope
            architecture (CAM embedded in each TDNN layer); when False, map for
            the original architecture with a single shared CAM.
    """
    self.use_modelscope_architecture = use_modelscope_architecture
    # Dispatch table: layer kind -> tensor conversion routine.
    self.layer_mapping = {
        'conv1d': self._convert_conv1d,
        'linear': self._convert_linear,
        'batchnorm': self._convert_batchnorm,
        'embedding': self._convert_embedding,
    }
def convert_weights_to_mlx(self, pytorch_weights: Dict[str, torch.Tensor]) -> Tuple[Dict[str, mx.array], Dict[str, Any]]:
    """Convert a PyTorch state dict into MLX arrays plus an inferred config.

    Pipeline: analyze structure -> filter unneeded params -> rename to the
    MLX naming scheme -> fill in missing params -> convert each tensor.

    Args:
        pytorch_weights: Dictionary of PyTorch tensors

    Returns:
        Tuple of (mlx_weights, model_config)
    """
    model_config = self._analyze_model_structure(pytorch_weights)
    prepared = self._add_missing_parameters(
        self._map_parameter_names(self._filter_weights(pytorch_weights)),
        model_config,
    )
    mlx_weights: Dict[str, mx.array] = {}
    for name, value in prepared.items():
        # Non-tensor entries (ints, strings) have no MLX counterpart.
        if not isinstance(value, torch.Tensor):
            continue
        converted = self._convert_tensor(name, value)
        # _convert_tensor returns None for params that must be dropped
        # (e.g. num_batches_tracked).
        if converted is not None:
            mlx_weights[name] = converted
    return mlx_weights, model_config
def _analyze_model_structure(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    """Infer model configuration from PyTorch parameter names and shapes.

    Args:
        pytorch_weights: PyTorch weights dictionary

    Returns:
        Configuration dict with input/embedding dims, channel counts,
        detected block layout, and total parameter count.
    """
    config: Dict[str, Any] = {
        'input_dim': 80,       # Default mel spectrogram features
        'embedding_dim': 192,  # Default embedding dimension for ModelScope
        'channels': 512,       # Default number of channels
        'cam_channels': 128,   # Default CAM channels
    }
    if self.use_modelscope_architecture:
        # Collect the distinct TDNN layer indices seen in each of the 3 blocks.
        blocks: Dict[int, Set[int]] = {1: set(), 2: set(), 3: set()}
        for name in pytorch_weights:
            for block_id in (1, 2, 3):
                if f'xvector.block{block_id}.tdnnd' in name:
                    blocks[block_id].add(int(name.split('tdnnd')[1].split('.')[0]))
                    break
        if any(blocks.values()):
            # Fall back to the canonical 4/9/16 layout for any empty block.
            config['block_layers'] = [
                len(blocks[1]) if blocks[1] else 4,
                len(blocks[2]) if blocks[2] else 9,
                len(blocks[3]) if blocks[3] else 16,
            ]
            logger.info(f"Detected block structure: {config['block_layers']}")
    # The input conv weight yields input dim, channel count and kernel size.
    for name, tensor in pytorch_weights.items():
        if 'xvector.tdnn.linear.weight' in name:
            if tensor.ndim == 3:  # Conv1d weight: (out_channels, in_channels, kernel_size)
                config['input_dim'] = tensor.shape[1]
                config['channels'] = tensor.shape[0]
                config['input_kernel_size'] = tensor.shape[2]
                logger.info(f"Detected input layer: dim={config['input_dim']}, channels={config['channels']}, kernel_size={config['input_kernel_size']}")
            break
    # The dense layer (Conv1d with kernel_size=1) yields the embedding dim.
    for name, tensor in pytorch_weights.items():
        if 'xvector.dense.linear.weight' in name:
            if tensor.ndim == 3:
                config['embedding_dim'] = tensor.shape[0]
            break
    # Total parameter count, for size estimation/reporting.
    config['total_params'] = sum(t.numel() for t in pytorch_weights.values())
    return config
def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename PyTorch parameters to their MLX equivalents.

    Args:
        pytorch_weights: PyTorch weights with original names

    Returns:
        Weights keyed by MLX-compatible names; parameters whose mapper
        returns None/empty (no MLX equivalent) are dropped.
    """
    # Pick the name mapper once, based on the target architecture.
    if self.use_modelscope_architecture:
        rename = self._xvector_to_mlx_modelscope_name
    else:
        rename = self._xvector_to_mlx_name
    mapped: Dict[str, torch.Tensor] = {}
    for source_name, tensor in pytorch_weights.items():
        target_name = rename(source_name)
        if target_name:
            mapped[target_name] = tensor
    return mapped
def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]:
    """Return the mapped weights unchanged.

    This hook deliberately does NOT synthesize fake/random parameters for
    layers missing from the source model: untrained random weights would
    silently wreck model accuracy. Failing explicitly on a missing layer is
    preferable to producing nonsensical outputs, so only weights actually
    mapped from the source checkpoint are kept.

    Args:
        mapped_weights: Already mapped weights
        model_config: Model configuration (currently unused)

    Returns:
        The input weights, untouched.
    """
    return mapped_weights
def get_missing_mlx_parameters(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, str]:
    """Report expected MLX parameters absent from the converted weights.

    Args:
        pytorch_weights: Original PyTorch weights (kept for interface
            compatibility; not consulted by the current implementation)
        mlx_weights: Converted MLX weights

    Returns:
        Dictionary mapping each missing MLX parameter name to "NOT FOUND".
    """
    # Full catalogue of parameters the MLX model is expected to carry.
    expected_mlx_params = {
        # Input layer
        'input_conv.weight', 'input_bn.weight', 'input_bn.bias',
        'input_bn.running_mean', 'input_bn.running_var',
        # Dense blocks (0-2)
        'dense_blocks.0.layers.0.conv.weight', 'dense_blocks.0.layers.0.bn.weight', 'dense_blocks.0.layers.0.bn.bias',
        'dense_blocks.0.layers.0.bn.running_mean', 'dense_blocks.0.layers.0.bn.running_var',
        'dense_blocks.0.layers.1.conv.weight', 'dense_blocks.0.layers.1.bn.weight', 'dense_blocks.0.layers.1.bn.bias',
        'dense_blocks.0.layers.2.conv.weight', 'dense_blocks.0.layers.2.bn.weight', 'dense_blocks.0.layers.2.bn.bias',
        'dense_blocks.0.layers.3.conv.weight', 'dense_blocks.0.layers.3.bn.weight', 'dense_blocks.0.layers.3.bn.bias',
        # Transitions
        'transitions.0.layers.0.weight', 'transitions.0.layers.0.bias',
        'transitions.0.layers.0.running_mean', 'transitions.0.layers.0.running_var',
        'transitions.0.layers.2.weight',
        'transitions.1.layers.0.weight', 'transitions.1.layers.0.bias',
        'transitions.1.layers.0.running_mean', 'transitions.1.layers.0.running_var',
        'transitions.1.layers.2.weight',
        # CAM layer
        'cam.context_conv1.weight', 'cam.context_conv1.bias',
        'cam.context_conv3.weight', 'cam.context_conv3.bias',
        'cam.context_conv5.weight', 'cam.context_conv5.bias',
        'cam.mask_conv.weight', 'cam.mask_conv.bias',
        'cam.bn.weight', 'cam.bn.bias', 'cam.bn.running_mean', 'cam.bn.running_var',
        # Channel gating
        'channel_gating.fc.layers.0.weight', 'channel_gating.fc.layers.0.bias',
        'channel_gating.fc.layers.1.weight', 'channel_gating.fc.layers.1.bias',
        'channel_gating.fc.layers.2.weight', 'channel_gating.fc.layers.2.bias',
        # Pooling
        'pooling.attention_weights.weight', 'pooling.attention_weights.bias',
        'pooling.projection.weight', 'pooling.projection.bias',
        # Final layer
        'final_bn.weight', 'final_bn.bias', 'final_bn.running_mean', 'final_bn.running_var',
    }
    # Anything expected but absent from the converted dict is reported.
    return {param: "NOT FOUND" for param in expected_mlx_params if param not in mlx_weights}
def _xvector_to_mlx_modelscope_name(self, xvector_name: str) -> Optional[str]:
    """
    Convert an xvector parameter name to its MLX ModelScope parameter name.

    This mapping targets ModelScope CAM++ checkpoints, where CAM is embedded
    in each TDNN layer. Expected architecture:
    - Input layer (TDNN)
    - Block 1: 4 TDNN layers with embedded CAM
    - Transit 1
    - Block 2: 9 TDNN layers with embedded CAM
    - Transit 2
    - Block 3: 16 TDNN layers with embedded CAM
    - Dense layer (Conv1d kernel_size=1)

    Args:
        xvector_name: Original xvector parameter name from the PyTorch model

    Returns:
        MLX-compatible parameter name, or None if the parameter should be
        skipped (num_batches_tracked, nonlinear1 batch norms, unmapped params)
    """
    # ========== INPUT LAYER ==========
    if xvector_name == 'xvector.tdnn.linear.weight':
        return 'input_conv.weight'
    elif 'xvector.tdnn.nonlinear.batchnorm' in xvector_name:
        param_type = xvector_name.split('.')[-1]  # bias, weight, running_mean, running_var
        # Skip num_batches_tracked (PyTorch tracking statistic, not needed)
        if param_type == 'num_batches_tracked':
            return None
        return f'input_bn.{param_type}'
    # ========== DENSE BLOCKS WITH EMBEDDED CAM ==========
    # `re` is imported at module level; the former function-local import was
    # redundant and has been removed.
    block_match = re.match(r'xvector\.block(\d+)\.tdnnd(\d+)\.(.*)', xvector_name)
    if block_match:
        block_num = int(block_match.group(1))  # 1, 2, or 3
        layer_num = int(block_match.group(2))  # 1-indexed
        param_path = block_match.group(3)
        mlx_block_idx = block_num - 1  # MLX blocks are 0-indexed
        mlx_layer_idx = layer_num - 1  # MLX layers are 0-indexed
        # Main TDNN layer parameters
        if param_path.startswith('linear1.'):
            param_type = param_path.split('.')[-1]
            return f'block{mlx_block_idx}_{mlx_layer_idx}.conv.{param_type}'
        # PyTorch has TWO batch norms per layer:
        # - nonlinear1.batchnorm: sized for INPUT channels (applied before conv)
        # - nonlinear2.batchnorm: sized for OUTPUT channels (applied after conv)
        # MLX model only has one BN (after conv), so map nonlinear2 to bn
        elif param_path.startswith('nonlinear1.batchnorm.'):
            # Skip nonlinear1 batch norm - it's sized for input channels
            return None
        elif param_path.startswith('nonlinear2.batchnorm.'):
            param_type = param_path.split('.')[-1]
            # Skip num_batches_tracked
            if param_type == 'num_batches_tracked':
                return None
            return f'block{mlx_block_idx}_{mlx_layer_idx}.bn.{param_type}'
        # Embedded CAM layer parameters
        elif param_path.startswith('cam_layer.linear1.'):
            param_type = param_path.split('.')[-1]
            return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear1.{param_type}'
        elif param_path.startswith('cam_layer.linear2.'):
            param_type = param_path.split('.')[-1]
            return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear2.{param_type}'
        elif param_path.startswith('cam_layer.linear_local.'):
            param_type = param_path.split('.')[-1]
            return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear_local.{param_type}'
    # ========== TRANSITION LAYERS ==========
    if 'xvector.transit1.' in xvector_name:
        if '.linear.weight' in xvector_name:
            return 'transit1.conv.weight'
        elif 'nonlinear.batchnorm' in xvector_name:
            param_type = xvector_name.split('.')[-1]
            # Skip num_batches_tracked
            if param_type == 'num_batches_tracked':
                return None
            return f'transit1.bn.{param_type}'
    if 'xvector.transit2.' in xvector_name:
        if '.linear.weight' in xvector_name:
            return 'transit2.conv.weight'
        elif 'nonlinear.batchnorm' in xvector_name:
            param_type = xvector_name.split('.')[-1]
            # Skip num_batches_tracked
            if param_type == 'num_batches_tracked':
                return None
            return f'transit2.bn.{param_type}'
    # ========== DENSE LAYER ==========
    if 'xvector.dense.linear.' in xvector_name:
        param_type = xvector_name.split('.')[-1]
        return f'dense.{param_type}'
    # ========== SKIP UNMAPPED PARAMETERS ==========
    # These don't exist in the ModelScope architecture
    if any(x in xvector_name for x in ['head.', 'output.', 'pool', 'final_bn']):
        logger.debug(f"Skipping parameter not in ModelScope architecture: {xvector_name}")
        return None
    # Log unexpected parameters
    if xvector_name.startswith('xvector.'):
        logger.debug(f"Skipping unmapped parameter: {xvector_name}")
    return None
def _xvector_to_mlx_name(self, xvector_name: str) -> Optional[str]:
    """
    Convert xvector parameter name to MLX parameter name with comprehensive mapping

    This method maps PyTorch CAM++ xvector parameters to MLX CAMPPModel parameters.
    It handles:
    - Input layer (TDNN)
    - Dense blocks (3 blocks with 4, 6, 8 layers respectively)
    - Transition layers between blocks
    - Context-Aware Masking (CAM) layer
    - Channel gating mechanism
    - Multi-granularity pooling
    - Final batch normalization

    NOTE(review): the substring checks below are order-dependent and earlier
    branches deliberately shadow later ones — do not reorder them.

    Args:
        xvector_name: Original xvector parameter name from PyTorch model

    Returns:
        MLX-compatible parameter name, or None if parameter should be skipped
    """
    # ========== INPUT LAYER MAPPING ==========
    if xvector_name == 'xvector.tdnn.linear.weight':
        return 'input_conv.weight'
    elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
        return 'input_bn.bias'
    elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
        return 'input_bn.weight'
    elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
        return 'input_bn.running_mean'
    elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
        return 'input_bn.running_var'
    # ========== DENSE BLOCKS MAPPING ==========
    # MLX architecture: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
    # Map PyTorch block1/block2/block3 to MLX dense_blocks.0/1/2
    # Block 0: Map first 4 layers of PyTorch block1
    for i in range(1, 13):  # Handle up to 12 layers (generous for real models)
        # Block 0 - first 4 layers
        if i <= 4 and f'xvector.block1.tdnnd{i}.' in xvector_name:
            layer_idx = i - 1
            if '.linear1.weight' in xvector_name:
                return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
            elif '.nonlinear1.batchnorm.bias' in xvector_name:
                return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
            elif '.nonlinear1.batchnorm.weight' in xvector_name:
                return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
            elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
            elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'
        # Block 1 - first 6 layers of PyTorch block2
        # Skip block2.tdnnd1 and block2.tdnnd2 as they may be used for transition
        if i >= 3 and i <= 8 and f'xvector.block2.tdnnd{i}.' in xvector_name:
            layer_idx = i - 3  # Map block2.tdnnd3 -> layer 0, etc.
            if layer_idx < 6:  # Only map first 6 layers
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'
        # Block 2 - first 8 layers of PyTorch block3
        if i <= 8 and f'xvector.block3.tdnnd{i}.' in xvector_name:
            layer_idx = i - 1
            if '.linear1.weight' in xvector_name:
                return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
            elif '.nonlinear1.batchnorm.bias' in xvector_name:
                return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
            elif '.nonlinear1.batchnorm.weight' in xvector_name:
                return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
            elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
            elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'
    # ========== TRANSITION LAYERS MAPPING ==========
    # Transition 0: After block 0
    if 'xvector.transit1.' in xvector_name:
        if '.linear.weight' in xvector_name:
            return 'transitions.0.layers.2.weight'
        elif '.nonlinear.batchnorm.bias' in xvector_name:
            return 'transitions.0.layers.0.bias'
        elif '.nonlinear.batchnorm.weight' in xvector_name:
            return 'transitions.0.layers.0.weight'
        elif '.nonlinear.batchnorm.running_mean' in xvector_name:
            return 'transitions.0.layers.0.running_mean'
        elif '.nonlinear.batchnorm.running_var' in xvector_name:
            return 'transitions.0.layers.0.running_var'
    # Transition 1: Use block2.tdnnd1 and tdnnd2 (before dense block 1)
    if 'xvector.transit2.' in xvector_name or 'xvector.block2.tdnnd1.' in xvector_name:
        # Map transit2 or beginning of block2 to transition 1
        if '.linear.weight' in xvector_name or 'xvector.block2.tdnnd2.linear1.weight' in xvector_name:
            return 'transitions.1.layers.2.weight'
        elif '.nonlinear.batchnorm.bias' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.bias' in xvector_name:
            return 'transitions.1.layers.0.bias'
        elif '.nonlinear.batchnorm.weight' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight' in xvector_name:
            return 'transitions.1.layers.0.weight'
        elif '.nonlinear.batchnorm.running_mean' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean' in xvector_name:
            return 'transitions.1.layers.0.running_mean'
        elif '.nonlinear.batchnorm.running_var' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var' in xvector_name:
            return 'transitions.1.layers.0.running_var'
    # ========== CAM LAYER MAPPING ==========
    # Context-aware masking with multi-scale convolutions
    # NOTE: Real ModelScope models have CAM embedded in EACH TDNN layer,
    # but MLX model has ONE shared CAM layer. We map only the first occurrence
    # from block1.tdnnd1.cam_layer and skip all others.
    if 'cam_layer' in xvector_name or 'cam.' in xvector_name:
        # Only map CAM from the first block's first layer
        # Skip CAM from all other layers to avoid conflicts
        is_first_cam = 'block1.tdnnd1.cam_layer' in xvector_name
        if not is_first_cam:
            logger.debug(f"Skipping embedded CAM layer (only using first occurrence): {xvector_name}")
            return None
        # Map first CAM layer to MLX shared CAM
        # ModelScope structure: linear1 (1x1 conv), linear2 (1x1 conv), linear_local (3x3 conv)
        # MLX structure: context_conv1 (1x1), context_conv3 (3x3), context_conv5 (5x5)
        if 'cam_layer.linear1.weight' in xvector_name:
            return 'cam.context_conv1.weight'
        elif 'cam_layer.linear1.bias' in xvector_name:
            logger.debug(f"Skipping CAM context_conv1 bias (MLX uses bias=False): {xvector_name}")
            return None  # MLX context_conv1 has bias=False
        elif 'cam_layer.linear2.weight' in xvector_name:
            # Map linear2 (1x1) to context_conv3 - note: this is a compromise
            # Real model has 1x1 conv here, MLX expects 3x3
            logger.warning(f"Mapping 1x1 conv to context_conv3 (shape mismatch possible): {xvector_name}")
            return 'cam.context_conv3.weight'
        elif 'cam_layer.linear2.bias' in xvector_name:
            logger.debug(f"Skipping CAM context_conv3 bias (MLX uses bias=False): {xvector_name}")
            return None  # MLX context_conv3 has bias=False
        elif 'cam_layer.linear_local.weight' in xvector_name:
            # Map linear_local (3x3) to mask_conv
            return 'cam.mask_conv.weight'
        elif 'cam_layer.linear_local.bias' in xvector_name:
            # linear_local typically has no bias in ModelScope models
            logger.debug(f"Skipping CAM mask_conv bias: {xvector_name}")
            return None
        # Handle standalone cam. parameters (if model has separate CAM layer)
        elif 'context1.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name):
            return 'cam.context_conv1.weight'
        elif 'context3.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name):
            return 'cam.context_conv3.weight'
        elif 'context5.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear3.weight' in xvector_name):
            return 'cam.context_conv5.weight'
        elif 'mask_conv.weight' in xvector_name:
            return 'cam.mask_conv.weight'
        elif 'fusion.weight' in xvector_name:
            return 'cam.fusion.weight'
        # Batch normalization
        elif 'batchnorm.weight' in xvector_name:
            return 'cam.bn.weight'
        elif 'batchnorm.bias' in xvector_name:
            return 'cam.bn.bias'
        elif 'running_mean' in xvector_name:
            return 'cam.bn.running_mean'
        elif 'running_var' in xvector_name:
            return 'cam.bn.running_var'
    # ========== CHANNEL GATING MAPPING ==========
    # Channel-wise context gating (squeeze-excitation style)
    # NOTE: Real ModelScope models only have xvector.dense.linear (single layer)
    # MLX model expects 3-layer FC, but real model has only 1 layer
    if 'xvector.dense.' in xvector_name:
        if '.linear.weight' in xvector_name or 'xvector.dense.linear.weight' == xvector_name:
            # Map to first layer - this is the only dense layer in real model
            return 'channel_gating.fc.layers.0.weight'
        elif '.linear.bias' in xvector_name or 'xvector.dense.linear.bias' == xvector_name:
            # Check if bias exists (some models use Conv1d without bias)
            logger.debug(f"Mapping dense bias (may not exist in Conv1d): {xvector_name}")
            return 'channel_gating.fc.layers.0.bias'
        # The following layers don't exist in real ModelScope models
        elif 'linear_mid.weight' in xvector_name:
            logger.warning(f"Found linear_mid layer (unexpected in ModelScope model): {xvector_name}")
            return 'channel_gating.fc.layers.1.weight'
        elif 'linear_mid.bias' in xvector_name:
            return 'channel_gating.fc.layers.1.bias'
        elif 'linear_out.weight' in xvector_name:
            logger.warning(f"Found linear_out layer (unexpected in ModelScope model): {xvector_name}")
            return 'channel_gating.fc.layers.2.weight'
        elif 'linear_out.bias' in xvector_name:
            return 'channel_gating.fc.layers.2.bias'
    # ========== POOLING LAYER MAPPING ==========
    # Multi-granularity statistical pooling
    # NOTE: Real ModelScope models typically DON'T have xvector.output or pooling layers
    # These models are feature extractors that end at xvector.dense
    if 'xvector.output.' in xvector_name or 'xvector.pool' in xvector_name:
        logger.warning(f"Found pooling/output layer (rare in ModelScope models): {xvector_name}")
        if 'xvector.output.linear.weight' == xvector_name:
            return 'pooling.attention_weights.weight'
        elif 'xvector.output.linear.bias' == xvector_name:
            return 'pooling.attention_weights.bias'
        elif 'pool_output.linear.weight' in xvector_name or 'pooling.linear.weight' in xvector_name:
            return 'pooling.projection.weight'
        elif 'pool_output.linear.bias' in xvector_name or 'pooling.linear.bias' in xvector_name:
            return 'pooling.projection.bias'
    # ========== FINAL BATCH NORMALIZATION ==========
    if 'xvector.out_nonlinear.batchnorm.' in xvector_name or 'xvector.final_bn.' in xvector_name:
        if '.bias' in xvector_name:
            return 'final_bn.bias'
        elif '.weight' in xvector_name:
            return 'final_bn.weight'
        elif 'running_mean' in xvector_name:
            return 'final_bn.running_mean'
        elif 'running_var' in xvector_name:
            return 'final_bn.running_var'
    # ========== SKIP UNMAPPED PARAMETERS ==========
    # Log skipped parameters for debugging
    if xvector_name.startswith('xvector.'):
        logger.debug(f"Skipping unmapped parameter: {xvector_name}")
    return None
def _filter_weights(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Filter out parameters that should never be converted to MLX.

    Only classification-head parameters are dropped here; everything else
    (including BatchNorm running statistics) is kept, because the later
    name-mapping step decides which parameters have MLX equivalents.

    Args:
        pytorch_weights: Original PyTorch weights dict

    Returns:
        Filtered weights dict
    """
    filtered_weights = {}
    skipped_params = []
    for name, tensor in pytorch_weights.items():
        # Skip classification head parameters (not needed for inference)
        if name.startswith('head.'):
            skipped_params.append(name)
            continue
        filtered_weights[name] = tensor
    if skipped_params:
        # Use the module logger instead of print() so callers control
        # verbosity — consistent with the rest of this module.
        logger.info(f"Filtered out {len(skipped_params)} unnecessary parameters: {skipped_params[:5]}{'...' if len(skipped_params) > 5 else ''}")
    return filtered_weights
def _convert_tensor(self, name: str, tensor: torch.Tensor) -> Optional[mx.array]:
    """Convert one tensor, choosing the layout transform by name and rank.

    Args:
        name: Parameter name
        tensor: PyTorch tensor to convert

    Returns:
        MLX array, or None if the parameter must be dropped
        (e.g. num_batches_tracked).
    """
    data = tensor.detach().cpu().numpy()
    # Bias vectors carry no layout information; pass them through as-is.
    if name.endswith('.bias'):
        return mx.array(data)
    # Start from the name-based guess, then let the actual rank override it.
    # This handles e.g. Conv1d(kernel_size=1) weights that are named like
    # Linear layers (and vice versa).
    kind = self._identify_layer_type(name)
    if data.ndim == 3:
        kind = 'conv1d'  # a 3-D weight can only be Conv1d
    elif data.ndim == 2 and kind == 'conv1d':
        kind = 'linear'  # a 2-D weight cannot be Conv1d despite its name
    elif data.ndim == 1:
        lowered = name.lower()
        if 'bn' in lowered or 'batchnorm' in lowered or 'running' in lowered:
            kind = 'batchnorm'
    converter = self.layer_mapping.get(kind)
    if converter is not None:
        data = converter(name, data)
        if data is None:
            # Converter asked to drop this parameter (e.g. num_batches_tracked).
            return None
    return mx.array(data)
def _identify_layer_type(self, name: str) -> str:
    """Guess the layer kind from a parameter name (case-insensitive)."""
    lowered = name.lower()
    # Order matters: batchnorm markers are checked before conv/linear
    # because a name can contain several of these substrings.
    checks = (
        ('batchnorm', ('bn', 'batchnorm', 'batch_norm')),
        ('conv1d', ('conv1d', 'conv')),
        ('linear', ('linear', 'fc', 'dense')),
        ('embedding', ('embed',)),
    )
    for kind, markers in checks:
        if any(marker in lowered for marker in markers):
            return kind
    return 'default'
def _convert_conv1d(self, name: str, weight: np.ndarray) -> np.ndarray:
    """
    Convert Conv1d weights from PyTorch layout to MLX layout.

    PyTorch stores Conv1d weights as (out_channels, in_channels, kernel_size)
    while MLX expects (out_channels, kernel_size, in_channels), so the last
    two axes are swapped. This applies to every kernel size, including the
    kernel_size=1 convolutions that act as Linear layers.

    Args:
        name: Parameter name (for error reporting)
        weight: Weight tensor as numpy array

    Returns:
        Weight tensor in MLX layout

    Raises:
        ValueError: If the weight is not a 3D tensor
    """
    if weight.ndim != 3:
        raise ValueError(f"Conv1d weight {name} must be 3D, got shape {weight.shape}")

    # Typical kernel sizes are small (1, 3, 5); flag anything suspicious.
    kernel_size = weight.shape[2]
    if kernel_size > 11:
        logger.warning(f"Unusual kernel size {kernel_size} for Conv1d {name}")

    # Swap the in_channels and kernel_size axes for MLX.
    mlx_weight = weight.transpose(0, 2, 1)
    logger.debug(f"Transposed Conv1d weight {name}: {weight.shape} -> {mlx_weight.shape}")
    return mlx_weight
def _convert_linear(self, name: str, weight: np.ndarray) -> np.ndarray:
"""
Convert Linear layer weights
PyTorch Linear: (out_features, in_features)
MLX Linear: (out_features, in_features) - same format
Args:
name: Parameter name (for error reporting)
weight: Weight tensor as numpy array
Returns:
Converted weight tensor
Raises:
ValueError: If weight shape is invalid for Linear
"""
if weight.ndim != 2:
raise ValueError(f"Linear weight {name} must be 2D, got shape {weight.shape}")
return weight # No change needed for linear layers
def _convert_batchnorm(self, name: str, weight: np.ndarray) -> Optional[np.ndarray]:
"""
Convert BatchNorm parameters
Args:
name: Parameter name (for error reporting)
weight: Weight/bias/running_mean/running_var tensor
Returns:
Converted tensor, or None if parameter should be skipped (e.g., num_batches_tracked)
Raises:
ValueError: If tensor shape is invalid for BatchNorm
"""
# Skip num_batches_tracked (it's a scalar tracking statistic, not needed in MLX)
if 'num_batches_tracked' in name:
logger.debug(f"Skipping num_batches_tracked (not needed in MLX): {name}")
return None # Will be filtered out
# BatchNorm parameters should be 1D vectors
if weight.ndim != 1:
raise ValueError(f"BatchNorm parameter {name} must be 1D, got shape {weight.shape}")
# Check for NaN/Inf in running statistics
if 'running_mean' in name or 'running_var' in name:
if np.isnan(weight).any():
logger.warning(f"BatchNorm {name} contains NaN values - may indicate untrained model")
if np.isinf(weight).any():
logger.warning(f"BatchNorm {name} contains Inf values - may indicate numerical instability")
return weight
def _convert_embedding(self, name: str, weight: np.ndarray) -> np.ndarray:
"""
Convert Embedding layer weights
Args:
name: Parameter name (unused but kept for API consistency)
weight: Embedding weight tensor
Returns:
Converted weight tensor
"""
return weight # No change needed for embeddings
def quantize_weights(self, weights: Dict[str, mx.array],
                     bits: int = 4, group_size: int = 64) -> Dict[str, mx.array]:
    """
    Quantize weights to reduce model size using MLX's built-in quantization.

    Tensors that are 1D, whose last dimension is not divisible by
    group_size, or that _should_quantize rejects (small tensors, biases,
    BatchNorm parameters) are kept in full precision. For quantized
    tensors, the scales and biases returned by mx.quantize are stored
    under derived keys ("{name}:qSCALES_GS{group_size}_B{bits}" /
    "{name}:qBIASES_GS{group_size}_B{bits}") so a loader can recover the
    quantization parameters from the key alone.

    Note: This creates a new dictionary; callers don't need to copy before calling.

    Args:
        weights: MLX weights dictionary
        bits: Number of bits for quantization (2, 4, or 8)
        group_size: Group size for quantization (32, 64, or 128)

    Returns:
        Quantized weights dictionary (new copy)
    """
    quantized_weights = {}
    skipped_count = 0
    quantized_count = 0
    logger.info(f"Starting {bits}-bit quantization with group_size={group_size}...")
    for name, weight in weights.items():
        if not self._should_quantize(name, weight):
            # Small tensors, biases and BatchNorm stay in full precision.
            quantized_weights[name] = weight
            skipped_count += 1
            continue
        try:
            # MLX quantization requires:
            # 1. At least 2D tensors
            # 2. Last dimension divisible by group_size
            if len(weight.shape) < 2:
                logger.debug(f"Skipping {name}: 1D tensor")
                quantized_weights[name] = weight
                skipped_count += 1
                continue
            if weight.shape[-1] % group_size != 0:
                logger.debug(f"Skipping {name}: last dim {weight.shape[-1]} not divisible by {group_size}")
                quantized_weights[name] = weight
                skipped_count += 1
                continue
            # Quantize using MLX's affine quantization.
            w_q, scales, biases = mx.quantize(weight, group_size=group_size, bits=bits)
            quantized_weights[name] = w_q
            quantized_weights[f"{name}:qSCALES_GS{group_size}_B{bits}"] = scales
            quantized_weights[f"{name}:qBIASES_GS{group_size}_B{bits}"] = biases
            quantized_count += 1
            # Log the size reduction (assumes float32 source, 4 bytes/element).
            original_size = weight.size * 4
            quantized_size = w_q.nbytes + scales.nbytes + biases.nbytes
            reduction = (1 - quantized_size / original_size) * 100
            # BUGFIX: the arrow in this message was a mojibake character; use "->".
            logger.debug(f"Quantized {name}: {reduction:.1f}% size reduction ({original_size//1024}KB -> {quantized_size//1024}KB)")
        except Exception as e:
            # If quantization fails for this weight, keep the original.
            logger.warning(f"Failed to quantize {name}: {e}, keeping original")
            quantized_weights[name] = weight
            skipped_count += 1
    logger.info(f"Quantization complete: {quantized_count} weights quantized, {skipped_count} kept in full precision")
    return quantized_weights
def _quantize_to_int8(self, weight: mx.array) -> mx.array:
    """
    Fake-quantize a tensor to 8-bit precision (quantize then dequantize).

    Uses symmetric quantization: values are scaled so the largest magnitude
    maps to 127, rounded, clipped to [-127, 127], then scaled back. The
    result is still float32 but carries int8-level precision loss.

    Args:
        weight: Weight tensor to quantize

    Returns:
        Dequantized float32 tensor
    """
    peak = mx.max(mx.abs(weight))
    scale = peak / 127.0
    # All-zero tensors have nothing to quantize.
    if scale == 0:
        return weight
    levels = mx.clip(mx.round(weight / scale), -127, 127)
    return (levels * scale).astype(mx.float32)
def _should_quantize(self, name: str, weight: mx.array) -> bool:
    """
    Decide whether a weight tensor is eligible for quantization.

    Rejects tensors smaller than MIN_QUANTIZATION_SIZE elements, bias
    terms, and every BatchNorm parameter (weight/bias/running stats).
    Everything else — the large Conv/Linear matrices — is eligible.
    """
    # Very small tensors aren't worth quantizing.
    if weight.size < MIN_QUANTIZATION_SIZE:
        return False
    lowered = name.lower()
    # Bias terms stay in full precision.
    if 'bias' in lowered:
        return False
    # BatchNorm parameters must not be quantized.
    bn_markers = ('bn', 'batchnorm', 'batch_norm', 'running_mean', 'running_var')
    return not any(marker in lowered for marker in bn_markers)
def verify_conversion(self, pytorch_weights: Dict[str, torch.Tensor],
                      mlx_weights: Dict[str, mx.array]) -> Dict[str, bool]:
    """
    Verify the conversion by comparing shapes and values per parameter.

    A parameter passes only when it exists in mlx_weights, keeps the same
    shape as its PyTorch source (no transpose expected for verified keys),
    and its values match within tolerance (rtol=1e-5, atol=1e-6).

    Args:
        pytorch_weights: Original PyTorch weights
        mlx_weights: Converted MLX weights

    Returns:
        Mapping of parameter name to True/False verification result
    """
    results = {}
    for name, source in pytorch_weights.items():
        if name not in mlx_weights:
            results[name] = False
            continue
        converted = mlx_weights[name]
        # Shapes must match exactly.
        if tuple(source.shape) != tuple(converted.shape):
            results[name] = False
            continue
        # Values must be approximately equal.
        results[name] = bool(np.allclose(
            source.detach().cpu().numpy(),
            np.array(converted),
            rtol=1e-5, atol=1e-6,
        ))
    return results
def check_conversion_status(self, pytorch_weights: Dict[str, torch.Tensor],
                            mlx_weights: Dict[str, mx.array],
                            verification_results: Dict[str, bool]) -> Dict[str, Any]:
    """
    Build a comprehensive status report to decide whether deployment is safe.

    Aggregates the verification results, flags missing/failed weights,
    mixed dtypes and NaN/Inf values, then applies a conservative
    deployment rule: perfect conversions deploy; otherwise an error-free,
    near-perfect conversion (rate >= MIN_VERIFICATION_RATE and at most
    MAX_VERIFICATION_FAILURES failures, no NaN/Inf) is also accepted.

    Args:
        pytorch_weights: Original PyTorch weights
        mlx_weights: Converted MLX weights
        verification_results: Results from verify_conversion

    Returns:
        Status dictionary with a detailed report
    """
    passed = sum(1 for ok in verification_results.values() if ok)
    failed = len(verification_results) - passed
    status = {
        'is_perfect': False,
        'total_source_weights': len(pytorch_weights),
        'total_converted_weights': len(mlx_weights),
        'verification_passed': passed,
        'verification_failed': failed,
        'verification_rate': 0.0,
        'errors': [],
        'warnings': [],
        'safe_to_deploy': False,
    }

    if verification_results:
        status['verification_rate'] = (passed / len(verification_results)) * 100

    # Critical failure: nothing came out of the conversion at all.
    if not mlx_weights:
        status['errors'].append("No weights were converted - conversion failed completely")

    if failed > 0:
        failed_names = [n for n, ok in verification_results.items() if not ok]
        suffix = '...' if len(failed_names) > 3 else ''
        status['errors'].append(
            f"{failed} weight(s) failed verification: {failed_names[:3]}{suffix}"
        )

    # Warn when fewer than half of the source weights made it across.
    if len(mlx_weights) < len(pytorch_weights) * 0.5:
        converted_pct = len(mlx_weights) / len(pytorch_weights) * 100
        status['warnings'].append(
            f"Only {len(mlx_weights)}/{len(pytorch_weights)} weights were converted "
            f"({converted_pct:.1f}%) - possible mapping issues"
        )

    # Mixed dtypes usually indicate an inconsistent conversion.
    dtypes = {str(w.dtype) for w in mlx_weights.values()}
    if len(dtypes) > 1:
        status['warnings'].append(f"Mixed data types detected in converted weights: {dtypes}")

    # Scan every converted tensor for NaN/Inf values.
    nan_inf_weights = []
    for name, weight in mlx_weights.items():
        values = np.array(weight)
        if np.isnan(values).any():
            nan_inf_weights.append(f"{name} (NaN)")
        elif np.isinf(values).any():
            nan_inf_weights.append(f"{name} (Inf)")
    if nan_inf_weights:
        status['errors'].append(f"Weights contain NaN/Inf: {nan_inf_weights[:3]}")

    status['is_perfect'] = (
        not status['errors']
        and status['verification_rate'] == 100.0
        and len(mlx_weights) > 0
    )

    # Conservative default: deploy only when perfect...
    status['safe_to_deploy'] = status['is_perfect']
    # ...but accept a near-perfect, error-free conversion as a fallback.
    if not status['safe_to_deploy'] and not status['errors']:
        status['safe_to_deploy'] = (
            status['verification_rate'] >= MIN_VERIFICATION_RATE and
            status['verification_failed'] <= MAX_VERIFICATION_FAILURES and
            len(nan_inf_weights) == 0
        )
    return status
def print_status_report(self, status: Dict[str, Any]) -> None:
    """
    Print a human-readable report for a status dict from check_conversion_status.

    BUGFIX: the original decorative emoji were mojibake sequences, some of
    which split f-string literals across physical lines (a syntax error);
    they are replaced here with plain ASCII markers.

    Args:
        status: Status dictionary produced by check_conversion_status
    """
    print("\n" + "=" * 70)
    print("CONVERSION STATUS REPORT")
    print("=" * 70)
    print("\nConversion Statistics:")
    print(f"  Total source weights: {status['total_source_weights']}")
    print(f"  Total converted weights: {status['total_converted_weights']}")
    total_checked = status['verification_passed'] + status['verification_failed']
    print(f"  Verification passed: {status['verification_passed']}/{total_checked}")
    print(f"  Verification rate: {status['verification_rate']:.1f}%")
    if status['errors']:
        print(f"\nErrors ({len(status['errors'])}):")
        for error in status['errors']:
            print(f"  * {error}")
    if status['warnings']:
        print(f"\nWarnings ({len(status['warnings'])}):")
        for warning in status['warnings']:
            print(f"  * {warning}")
    print("\nDeployment Decision:")
    if status['is_perfect']:
        print("  Status: PERFECT - All checks passed")
    else:
        print(f"  Status: {'ACCEPTABLE' if status['safe_to_deploy'] else 'NOT SAFE'}")
    print(f"  Safe to deploy: {'YES' if status['safe_to_deploy'] else 'NO'}")
    print("\n" + "=" * 70)
def create_model_metadata(self, original_repo: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the metadata dictionary for the converted model.

    Args:
        original_repo: Identifier of the source repository
        config: Extra fields merged in last; they override the defaults

    Returns:
        Metadata dictionary describing the converted MLX model
    """
    metadata: Dict[str, Any] = {
        "converted_from": original_repo,
        "conversion_date": self.get_current_date(),
        "framework": "mlx",
        "model_type": "campp",
        "architecture": "d-tdnn",
        "license": "apache-2.0",
        "tags": ["speaker-recognition", "audio", "mlx", "apple-silicon"],
        "task": "speaker-verification",
        "library_name": "mlx",
        "datasets": ["voxceleb", "cnceleb"],
        "metrics": {
            "voxceleb1_eer": "0.65%",
            "parameters": "7.2M",
            "inference_speed": "optimized_for_apple_silicon"
        },
    }
    # Caller-supplied config entries take precedence over the defaults.
    metadata.update(config)
    return metadata
def get_current_date(self) -> str:
    """Return the current local time as an ISO-8601 formatted string."""
    now = datetime.now()
    return now.isoformat()
def estimate_model_performance(self, weights: Dict[str, mx.array]) -> Dict[str, Any]:
    """
    Estimate basic performance characteristics of the converted model.

    Returns the parameter count, a rough fp32 memory estimate, layer counts
    inferred from parameter-name substrings, and a coarse complexity label
    ('efficient' below 10M parameters, otherwise 'standard').
    """
    names = list(weights.keys())
    total_params = sum(w.size for w in weights.values())

    # Rough memory footprint assuming 4-byte (fp32) parameters.
    memory_mb = (total_params * 4) / (1024 * 1024)

    # Count layers by name heuristics.
    conv_layers = sum(1 for n in names if 'conv' in n.lower())
    linear_layers = sum(
        1 for n in names if any(tag in n.lower() for tag in ('linear', 'fc'))
    )

    return {
        "total_parameters": total_params,
        "estimated_memory_mb": memory_mb,
        "conv_layers": conv_layers,
        "linear_layers": linear_layers,
        "model_complexity": "efficient" if total_params < 10e6 else "standard",
    }
def optimize_for_inference(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """
    Prepare weights for MLX inference.

    Re-materializes each tensor as an MLX array and normalizes the dtype to
    float32. Additional MLX-specific optimizations can be slotted in here.
    """
    prepared = {}
    for name, weight in weights.items():
        candidate = mx.array(weight)
        # Keep everything in float32 for consistent inference behavior.
        if candidate.dtype != mx.float32:
            candidate = candidate.astype(mx.float32)
        prepared[name] = candidate
    return prepared
def test_conversion():
    """
    Smoke-test the conversion utilities with comprehensive status checking.

    Builds a dummy xvector-style PyTorch state dict, converts it, checks
    each source weight's mapping and shape, and prints a status report.

    BUGFIX: the original pass/fail markers were mojibake sequences that
    split string literals across physical lines (a syntax error); replaced
    with plain ASCII "PASS"/"FAIL" markers.

    Returns:
        True when every mapping check passes and the status report is perfect.
    """
    utils = ConversionUtils()

    # Create dummy PyTorch xvector weights (proper source format).
    dummy_weights = {
        # Input layer
        'xvector.tdnn.linear.weight': torch.randn(64, 80, 3),
        'xvector.tdnn.nonlinear.batchnorm.weight': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.bias': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_mean': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_var': torch.randn(64),
        # Dense block 0
        'xvector.block1.tdnnd1.linear1.weight': torch.randn(32, 64, 3),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(32),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(32),
        # Transition layer
        'xvector.transit1.linear.weight': torch.randn(256, 96, 1),
        'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(256),
        'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(256),
        # Final layer
        'xvector.out_nonlinear.batchnorm.weight': torch.randn(512),
        'xvector.out_nonlinear.batchnorm.bias': torch.randn(512),
    }

    # Convert to MLX format.
    mlx_weights, config = utils.convert_weights_to_mlx(dummy_weights)

    # Verify the name mapping and shape for each source weight.
    verification = {}
    print("Conversion test results:")
    for name, tensor in dummy_weights.items():
        mlx_name = utils._xvector_to_mlx_name(name)
        if mlx_name and mlx_name in mlx_weights:
            pytorch_shape = tensor.shape
            mlx_shape = mlx_weights[mlx_name].shape
            matches = pytorch_shape == mlx_shape
            verification[name] = matches
            marker = "PASS" if matches else "FAIL"
            print(f" {marker} {name} -> {mlx_name} | Shape: {pytorch_shape} -> {mlx_shape}")
        else:
            verification[name] = False
            print(f" FAIL {name} (no mapping)")

    print(f"\nTotal weights converted: {len(mlx_weights)}")
    print(f"Inferred config: {config}")

    # Check conversion status and print the full report.
    status_report = utils.check_conversion_status(dummy_weights, mlx_weights, verification)
    utils.print_status_report(status_report)

    # Success only when every mapping matched AND the status is perfect.
    tests_passed = all(verification.values())
    return tests_passed and status_report['is_perfect']
if __name__ == "__main__":
test_passed = test_conversion()
print(f"\n{'β
' if test_passed else 'β'} Conversion utilities test {'passed' if test_passed else 'failed'}")
|