File size: 53,428 Bytes
368806f
 
 
656e7f6
368806f
 
656e7f6
 
 
 
 
 
 
 
 
 
 
368806f
 
 
656e7f6
 
 
 
 
 
 
 
 
 
368806f
 
 
 
 
 
656e7f6
368806f
 
 
 
 
 
 
 
 
 
 
 
 
ca0ebee
 
 
 
 
 
 
 
 
368806f
ca0ebee
368806f
656e7f6
 
 
 
368806f
 
 
656e7f6
368806f
 
77877c8
 
 
656e7f6
77877c8
 
656e7f6
77877c8
 
 
 
 
656e7f6
77877c8
656e7f6
77877c8
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77877c8
656e7f6
77877c8
 
656e7f6
 
 
77877c8
656e7f6
 
77877c8
656e7f6
 
 
77877c8
656e7f6
77877c8
 
 
656e7f6
77877c8
 
ca0ebee
 
 
656e7f6
ca0ebee
 
656e7f6
ca0ebee
 
 
 
656e7f6
ca0ebee
656e7f6
 
 
 
 
 
ca0ebee
 
656e7f6
ca0ebee
 
 
 
 
 
 
 
 
 
 
 
 
77877c8
 
 
 
 
 
 
ca0ebee
 
656e7f6
ca0ebee
656e7f6
ca0ebee
 
656e7f6
 
ca0ebee
 
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca0ebee
656e7f6
 
77877c8
 
656e7f6
77877c8
 
 
 
 
 
 
656e7f6
 
 
 
 
 
 
 
 
ca0ebee
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca0ebee
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca0ebee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656e7f6
 
 
 
 
 
 
 
 
 
 
368806f
 
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
 
 
656e7f6
 
 
 
 
368806f
 
 
 
 
 
656e7f6
 
 
 
 
 
368806f
656e7f6
 
368806f
656e7f6
368806f
 
 
 
 
 
 
 
656e7f6
368806f
656e7f6
 
 
 
 
 
 
 
 
 
 
 
368806f
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
 
 
656e7f6
368806f
 
656e7f6
 
 
 
 
 
 
 
 
 
368806f
656e7f6
 
 
368806f
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
656e7f6
 
 
 
 
 
 
 
 
 
368806f
 
656e7f6
368806f
 
656e7f6
 
 
 
368806f
 
656e7f6
 
 
368806f
656e7f6
368806f
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
 
 
656e7f6
368806f
656e7f6
368806f
656e7f6
368806f
 
 
656e7f6
 
 
368806f
656e7f6
 
368806f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77877c8
 
368806f
 
 
 
 
 
77877c8
 
368806f
 
 
 
 
 
17bd838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656e7f6
17bd838
 
656e7f6
17bd838
 
656e7f6
 
17bd838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17bd838
368806f
 
77877c8
368806f
77877c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
 
 
 
 
17bd838
77877c8
368806f
 
77877c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368806f
17bd838
 
 
 
 
 
 
368806f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
import mlx.core as mx
import numpy as np
import torch
from typing import Dict, Any, Tuple, Optional, List, Set
from datetime import datetime
import re
import logging

# Module-level logger, named after the module per stdlib convention
logger = logging.getLogger(__name__)

# Constants for conversion thresholds
MIN_QUANTIZATION_SIZE = 1000  # Don't quantize tensors with fewer elements than this
MIN_VERIFICATION_RATE = 95.0  # Minimum acceptable verification rate (%)
MAX_VERIFICATION_FAILURES = 2  # Maximum allowed verification failures
BATCHNORM_EPS = 1e-5  # BatchNorm epsilon (matches PyTorch's BatchNorm1d default)
BATCHNORM_MOMENTUM = 0.1  # BatchNorm momentum (matches PyTorch's BatchNorm1d default)

class ConversionUtils:
    """Utilities for converting PyTorch CAM++ models to MLX format"""

    def __init__(self, use_modelscope_architecture: bool = True):
        """
        Set up the converter.

        Args:
            use_modelscope_architecture: When True, target the ModelScope
                architecture (CAM embedded in each TDNN layer); when False,
                target the original architecture with a single shared CAM.
        """
        self.use_modelscope_architecture = use_modelscope_architecture
        # Dispatch table: layer kind -> conversion routine.
        self.layer_mapping = {
            'conv1d': self._convert_conv1d,
            'linear': self._convert_linear,
            'batchnorm': self._convert_batchnorm,
            'embedding': self._convert_embedding,
        }

    def convert_weights_to_mlx(self, pytorch_weights: Dict[str, torch.Tensor]) -> Tuple[Dict[str, mx.array], Dict[str, Any]]:
        """
        Convert a PyTorch state dict into MLX weights plus an inferred config.

        Args:
            pytorch_weights: Mapping of parameter name -> PyTorch tensor.

        Returns:
            Tuple of (mlx_weights, model_config).
        """
        model_config = self._analyze_model_structure(pytorch_weights)

        # Pipeline: drop unneeded params -> rename to the MLX scheme -> fill defaults.
        filtered = self._filter_weights(pytorch_weights)
        renamed = self._map_parameter_names(filtered)
        renamed = self._add_missing_parameters(renamed, model_config)

        mlx_weights: Dict[str, mx.array] = {}
        for name, tensor in renamed.items():
            # Non-tensor entries (integers, strings, ...) have no MLX counterpart.
            if not isinstance(tensor, torch.Tensor):
                continue
            converted = self._convert_tensor(name, tensor)
            # _convert_tensor yields None for params that are dropped
            # (e.g. num_batches_tracked); skip those.
            if converted is not None:
                mlx_weights[name] = converted

        return mlx_weights, model_config
    
    def _analyze_model_structure(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """
        Analyze the PyTorch model structure to infer configuration

        Args:
            pytorch_weights: PyTorch weights dictionary

        Returns:
            Model configuration dictionary
        """
        config = {
            'input_dim': 80,  # Default mel spectrogram features
            'embedding_dim': 192,  # Default embedding dimension for ModelScope
            'channels': 512,  # Default number of channels
            'cam_channels': 128,  # Default CAM channels
        }

        # Detect block structure for ModelScope architecture
        if self.use_modelscope_architecture:
            blocks = {1: set(), 2: set(), 3: set()}

            for name in pytorch_weights.keys():
                if 'xvector.block1.tdnnd' in name:
                    layer_num = name.split('tdnnd')[1].split('.')[0]
                    blocks[1].add(int(layer_num))
                elif 'xvector.block2.tdnnd' in name:
                    layer_num = name.split('tdnnd')[1].split('.')[0]
                    blocks[2].add(int(layer_num))
                elif 'xvector.block3.tdnnd' in name:
                    layer_num = name.split('tdnnd')[1].split('.')[0]
                    blocks[3].add(int(layer_num))

            # Set block_layers configuration
            if any(blocks.values()):
                config['block_layers'] = [
                    len(blocks[1]) if blocks[1] else 4,  # Default to 4 if not found
                    len(blocks[2]) if blocks[2] else 9,  # Default to 9 if not found
                    len(blocks[3]) if blocks[3] else 16  # Default to 16 if not found
                ]
                logger.info(f"Detected block structure: {config['block_layers']}")

        # Try to infer input dimension and kernel size from first conv layer
        for name, tensor in pytorch_weights.items():
            if 'xvector.tdnn.linear.weight' in name:
                if tensor.ndim == 3:  # Conv1d weight: (out_channels, in_channels, kernel_size)
                    config['input_dim'] = tensor.shape[1]  # in_channels
                    config['channels'] = tensor.shape[0]  # out_channels
                    config['input_kernel_size'] = tensor.shape[2]  # kernel_size
                    logger.info(f"Detected input layer: dim={config['input_dim']}, channels={config['channels']}, kernel_size={config['input_kernel_size']}")
                    break

        # Try to infer embedding dimension from dense layer
        for name, tensor in pytorch_weights.items():
            if 'xvector.dense.linear.weight' in name:
                if tensor.ndim == 3:  # Conv1d with kernel_size=1
                    config['embedding_dim'] = tensor.shape[0]  # out_channels
                    break

        # Count total parameters for estimation
        total_params = sum(tensor.numel() for tensor in pytorch_weights.values())
        config['total_params'] = total_params

        return config
    
    def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Translate PyTorch parameter names into their MLX equivalents.

        Parameters without an MLX counterpart are dropped.

        Args:
            pytorch_weights: Weights keyed by original PyTorch names.

        Returns:
            The same tensors, keyed by MLX-compatible parameter names.
        """
        # Pick the name-translation routine once, based on the target architecture.
        translate = (
            self._xvector_to_mlx_modelscope_name
            if self.use_modelscope_architecture
            else self._xvector_to_mlx_name
        )

        mapped = {}
        for original_name, tensor in pytorch_weights.items():
            mlx_name = translate(original_name)
            if mlx_name:  # None means "no MLX equivalent" -> skip
                mapped[mlx_name] = tensor
        return mapped
    
    def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]:
        """
        Add default values for MLX parameters that don't have PyTorch equivalents
        
        Args:
            mapped_weights: Already mapped weights
            model_config: Model configuration
            
        Returns:
            Weights with missing parameters added
        
        Note: This method intentionally does NOT add fake/random parameters.
        Adding untrained random weights will degrade model accuracy significantly.
        The conversion should only include weights that are actually mapped from 
        the source model. Better to fail explicitly when a layer is missing than 
        to add random weights that produce nonsensical outputs.
        """
        # Return mapped weights as-is without adding arbitrary fake parameters
        return mapped_weights
    
    def get_missing_mlx_parameters(self, pytorch_weights: Dict[str, torch.Tensor], mlx_weights: Dict[str, mx.array]) -> Dict[str, str]:
        """
        Report expected MLX parameters that are absent from the converted weights.

        Args:
            pytorch_weights: Original PyTorch weights (kept for interface
                compatibility; not consulted here).
            mlx_weights: Converted MLX weights.

        Returns:
            Mapping of each missing MLX parameter name to "NOT FOUND".
        """
        # Full roster of parameters the MLX CAM++ model is expected to carry.
        expected_mlx_params = {
            # Input layer
            'input_conv.weight', 'input_bn.weight', 'input_bn.bias', 
            'input_bn.running_mean', 'input_bn.running_var',
            # Dense blocks (0-2)
            'dense_blocks.0.layers.0.conv.weight', 'dense_blocks.0.layers.0.bn.weight', 'dense_blocks.0.layers.0.bn.bias',
            'dense_blocks.0.layers.0.bn.running_mean', 'dense_blocks.0.layers.0.bn.running_var',
            'dense_blocks.0.layers.1.conv.weight', 'dense_blocks.0.layers.1.bn.weight', 'dense_blocks.0.layers.1.bn.bias',
            'dense_blocks.0.layers.2.conv.weight', 'dense_blocks.0.layers.2.bn.weight', 'dense_blocks.0.layers.2.bn.bias',
            'dense_blocks.0.layers.3.conv.weight', 'dense_blocks.0.layers.3.bn.weight', 'dense_blocks.0.layers.3.bn.bias',
            # Transitions
            'transitions.0.layers.0.weight', 'transitions.0.layers.0.bias',
            'transitions.0.layers.0.running_mean', 'transitions.0.layers.0.running_var',
            'transitions.0.layers.2.weight',
            'transitions.1.layers.0.weight', 'transitions.1.layers.0.bias',
            'transitions.1.layers.0.running_mean', 'transitions.1.layers.0.running_var',
            'transitions.1.layers.2.weight',
            # CAM layer
            'cam.context_conv1.weight', 'cam.context_conv1.bias',
            'cam.context_conv3.weight', 'cam.context_conv3.bias',
            'cam.context_conv5.weight', 'cam.context_conv5.bias',
            'cam.mask_conv.weight', 'cam.mask_conv.bias',
            'cam.bn.weight', 'cam.bn.bias', 'cam.bn.running_mean', 'cam.bn.running_var',
            # Channel gating
            'channel_gating.fc.layers.0.weight', 'channel_gating.fc.layers.0.bias',
            'channel_gating.fc.layers.1.weight', 'channel_gating.fc.layers.1.bias',
            'channel_gating.fc.layers.2.weight', 'channel_gating.fc.layers.2.bias',
            # Pooling
            'pooling.attention_weights.weight', 'pooling.attention_weights.bias',
            'pooling.projection.weight', 'pooling.projection.bias',
            # Final layer
            'final_bn.weight', 'final_bn.bias', 'final_bn.running_mean', 'final_bn.running_var',
        }

        # Anything expected but not converted is reported as missing.
        return {
            name: "NOT FOUND"
            for name in expected_mlx_params
            if name not in mlx_weights
        }
    
    def _xvector_to_mlx_modelscope_name(self, xvector_name: str) -> Optional[str]:
        """
        Convert xvector parameter name to MLX ModelScope architecture parameter name

        This mapping is for ModelScope CAM++ models where CAM is embedded in each TDNN layer.

        Architecture:
        - Input layer (TDNN)
        - Block 1: 4 TDNN layers with embedded CAM
        - Transit 1
        - Block 2: 9 TDNN layers with embedded CAM
        - Transit 2
        - Block 3: 16 TDNN layers with embedded CAM
        - Dense layer (Conv1d kernel_size=1)

        Args:
            xvector_name: Original xvector parameter name from PyTorch model

        Returns:
            MLX-compatible parameter name, or None if parameter should be skipped
        """

        # ========== INPUT LAYER ==========
        if xvector_name == 'xvector.tdnn.linear.weight':
            return 'input_conv.weight'
        elif 'xvector.tdnn.nonlinear.batchnorm' in xvector_name:
            param_type = xvector_name.split('.')[-1]  # bias, weight, running_mean, running_var
            # Skip num_batches_tracked (PyTorch tracking statistic, not needed)
            if param_type == 'num_batches_tracked':
                return None
            return f'input_bn.{param_type}'

        # ========== DENSE BLOCKS WITH EMBEDDED CAM ==========
        # Extract block number and layer number
        import re
        block_match = re.match(r'xvector\.block(\d+)\.tdnnd(\d+)\.(.*)', xvector_name)
        if block_match:
            block_num = int(block_match.group(1))  # 1, 2, or 3
            layer_num = int(block_match.group(2))  # 1-indexed
            param_path = block_match.group(3)

            # Map to MLX block index (0, 1, 2)
            mlx_block_idx = block_num - 1
            # Map to MLX layer index (0-indexed)
            mlx_layer_idx = layer_num - 1

            # Main TDNN layer parameters
            if param_path.startswith('linear1.'):
                param_type = param_path.split('.')[-1]
                return f'block{mlx_block_idx}_{mlx_layer_idx}.conv.{param_type}'

            # PyTorch has TWO batch norms per layer:
            # - nonlinear1.batchnorm: sized for INPUT channels (applied before conv)
            # - nonlinear2.batchnorm: sized for OUTPUT channels (applied after conv)
            # MLX model only has one BN (after conv), so map nonlinear2 to bn
            elif param_path.startswith('nonlinear1.batchnorm.'):
                # Skip nonlinear1 batch norm - it's sized for input channels
                return None

            elif param_path.startswith('nonlinear2.batchnorm.'):
                param_type = param_path.split('.')[-1]
                # Skip num_batches_tracked
                if param_type == 'num_batches_tracked':
                    return None
                return f'block{mlx_block_idx}_{mlx_layer_idx}.bn.{param_type}'

            # Embedded CAM layer parameters
            elif param_path.startswith('cam_layer.linear1.'):
                param_type = param_path.split('.')[-1]
                return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear1.{param_type}'

            elif param_path.startswith('cam_layer.linear2.'):
                param_type = param_path.split('.')[-1]
                return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear2.{param_type}'

            elif param_path.startswith('cam_layer.linear_local.'):
                param_type = param_path.split('.')[-1]
                return f'block{mlx_block_idx}_{mlx_layer_idx}.cam.linear_local.{param_type}'

        # ========== TRANSITION LAYERS ==========
        if 'xvector.transit1.' in xvector_name:
            if '.linear.weight' in xvector_name:
                return 'transit1.conv.weight'
            elif 'nonlinear.batchnorm' in xvector_name:
                param_type = xvector_name.split('.')[-1]
                # Skip num_batches_tracked
                if param_type == 'num_batches_tracked':
                    return None
                return f'transit1.bn.{param_type}'

        if 'xvector.transit2.' in xvector_name:
            if '.linear.weight' in xvector_name:
                return 'transit2.conv.weight'
            elif 'nonlinear.batchnorm' in xvector_name:
                param_type = xvector_name.split('.')[-1]
                # Skip num_batches_tracked
                if param_type == 'num_batches_tracked':
                    return None
                return f'transit2.bn.{param_type}'

        # ========== DENSE LAYER ==========
        if 'xvector.dense.linear.' in xvector_name:
            param_type = xvector_name.split('.')[-1]
            return f'dense.{param_type}'

        # ========== SKIP UNMAPPED PARAMETERS ==========
        # These don't exist in ModelScope architecture
        if any(x in xvector_name for x in ['head.', 'output.', 'pool', 'final_bn']):
            logger.debug(f"Skipping parameter not in ModelScope architecture: {xvector_name}")
            return None

        # Log unexpected parameters
        if xvector_name.startswith('xvector.'):
            logger.debug(f"Skipping unmapped parameter: {xvector_name}")

        return None

    def _xvector_to_mlx_name(self, xvector_name: str) -> Optional[str]:
        """
        Convert xvector parameter name to MLX parameter name with comprehensive mapping

        This method maps PyTorch CAM++ xvector parameters to MLX CAMPPModel parameters.
        It handles:
        - Input layer (TDNN)
        - Dense blocks (3 blocks with 4, 6, 8 layers respectively)
        - Transition layers between blocks
        - Context-Aware Masking (CAM) layer
        - Channel gating mechanism
        - Multi-granularity pooling
        - Final batch normalization

        Args:
            xvector_name: Original xvector parameter name from PyTorch model

        Returns:
            MLX-compatible parameter name, or None if parameter should be skipped
        """

        # ========== INPUT LAYER MAPPING ==========
        if xvector_name == 'xvector.tdnn.linear.weight':
            return 'input_conv.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
            return 'input_bn.bias'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
            return 'input_bn.weight'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
            return 'input_bn.running_mean'
        elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
            return 'input_bn.running_var'

        # ========== DENSE BLOCKS MAPPING ==========
        # MLX architecture: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
        # Map PyTorch block1/block2/block3 to MLX dense_blocks.0/1/2

        # Block 0: Map first 4 layers of PyTorch block1
        for i in range(1, 13):  # Handle up to 12 layers (generous for real models)
            # Block 0 - first 4 layers
            if i <= 4 and f'xvector.block1.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'

            # Block 1 - first 6 layers of PyTorch block2
            # Skip block2.tdnnd1 and block2.tdnnd2 as they may be used for transition
            if i >= 3 and i <= 8 and f'xvector.block2.tdnnd{i}.' in xvector_name:
                layer_idx = i - 3  # Map block2.tdnnd3 -> layer 0, etc.
                if layer_idx < 6:  # Only map first 6 layers
                    if '.linear1.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
                    elif '.nonlinear1.batchnorm.bias' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
                    elif '.nonlinear1.batchnorm.weight' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
                    elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
                    elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                        return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'

            # Block 2 - first 8 layers of PyTorch block3
            if i <= 8 and f'xvector.block3.tdnnd{i}.' in xvector_name:
                layer_idx = i - 1
                if '.linear1.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
                elif '.nonlinear1.batchnorm.bias' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
                elif '.nonlinear1.batchnorm.weight' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
                elif '.nonlinear1.batchnorm.running_mean' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
                elif '.nonlinear1.batchnorm.running_var' in xvector_name:
                    return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'

        # ========== TRANSITION LAYERS MAPPING ==========
        # Transition 0: After block 0
        if 'xvector.transit1.' in xvector_name:
            if '.linear.weight' in xvector_name:
                return 'transitions.0.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name:
                return 'transitions.0.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name:
                return 'transitions.0.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name:
                return 'transitions.0.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name:
                return 'transitions.0.layers.0.running_var'

        # Transition 1: Use block2.tdnnd1 and tdnnd2 (before dense block 1)
        if 'xvector.transit2.' in xvector_name or 'xvector.block2.tdnnd1.' in xvector_name:
            # Map transit2 or beginning of block2 to transition 1
            if '.linear.weight' in xvector_name or 'xvector.block2.tdnnd2.linear1.weight' in xvector_name:
                return 'transitions.1.layers.2.weight'
            elif '.nonlinear.batchnorm.bias' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.bias' in xvector_name:
                return 'transitions.1.layers.0.bias'
            elif '.nonlinear.batchnorm.weight' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight' in xvector_name:
                return 'transitions.1.layers.0.weight'
            elif '.nonlinear.batchnorm.running_mean' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean' in xvector_name:
                return 'transitions.1.layers.0.running_mean'
            elif '.nonlinear.batchnorm.running_var' in xvector_name or 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var' in xvector_name:
                return 'transitions.1.layers.0.running_var'

        # ========== CAM LAYER MAPPING ==========
        # Context-aware masking with multi-scale convolutions
        # NOTE: Real ModelScope models have CAM embedded in EACH TDNN layer,
        # but MLX model has ONE shared CAM layer. We map only the first occurrence
        # from block1.tdnnd1.cam_layer and skip all others.
        if 'cam_layer' in xvector_name or 'cam.' in xvector_name:
            # Only map CAM from the first block's first layer
            # Skip CAM from all other layers to avoid conflicts
            is_first_cam = 'block1.tdnnd1.cam_layer' in xvector_name

            if not is_first_cam:
                logger.debug(f"Skipping embedded CAM layer (only using first occurrence): {xvector_name}")
                return None

            # Map first CAM layer to MLX shared CAM
            # ModelScope structure: linear1 (1x1 conv), linear2 (1x1 conv), linear_local (3x3 conv)
            # MLX structure: context_conv1 (1x1), context_conv3 (3x3), context_conv5 (5x5)
            if 'cam_layer.linear1.weight' in xvector_name:
                return 'cam.context_conv1.weight'
            elif 'cam_layer.linear1.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv1 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv1 has bias=False
            elif 'cam_layer.linear2.weight' in xvector_name:
                # Map linear2 (1x1) to context_conv3 - note: this is a compromise
                # Real model has 1x1 conv here, MLX expects 3x3
                logger.warning(f"Mapping 1x1 conv to context_conv3 (shape mismatch possible): {xvector_name}")
                return 'cam.context_conv3.weight'
            elif 'cam_layer.linear2.bias' in xvector_name:
                logger.debug(f"Skipping CAM context_conv3 bias (MLX uses bias=False): {xvector_name}")
                return None  # MLX context_conv3 has bias=False
            elif 'cam_layer.linear_local.weight' in xvector_name:
                # Map linear_local (3x3) to mask_conv
                return 'cam.mask_conv.weight'
            elif 'cam_layer.linear_local.bias' in xvector_name:
                # linear_local typically has no bias in ModelScope models
                logger.debug(f"Skipping CAM mask_conv bias: {xvector_name}")
                return None

            # Handle standalone cam. parameters (if model has separate CAM layer)
            elif 'context1.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name):
                return 'cam.context_conv1.weight'
            elif 'context3.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name):
                return 'cam.context_conv3.weight'
            elif 'context5.weight' in xvector_name or (not 'cam_layer' in xvector_name and 'linear3.weight' in xvector_name):
                return 'cam.context_conv5.weight'
            elif 'mask_conv.weight' in xvector_name:
                return 'cam.mask_conv.weight'
            elif 'fusion.weight' in xvector_name:
                return 'cam.fusion.weight'
            # Batch normalization
            elif 'batchnorm.weight' in xvector_name:
                return 'cam.bn.weight'
            elif 'batchnorm.bias' in xvector_name:
                return 'cam.bn.bias'
            elif 'running_mean' in xvector_name:
                return 'cam.bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'cam.bn.running_var'

        # ========== CHANNEL GATING MAPPING ==========
        # Channel-wise context gating (squeeze-excitation style)
        # NOTE: Real ModelScope models only have xvector.dense.linear (single layer)
        # MLX model expects 3-layer FC, but real model has only 1 layer
        if 'xvector.dense.' in xvector_name:
            if '.linear.weight' in xvector_name or 'xvector.dense.linear.weight' == xvector_name:
                # Map to first layer - this is the only dense layer in real model
                return 'channel_gating.fc.layers.0.weight'
            elif '.linear.bias' in xvector_name or 'xvector.dense.linear.bias' == xvector_name:
                # Check if bias exists (some models use Conv1d without bias)
                logger.debug(f"Mapping dense bias (may not exist in Conv1d): {xvector_name}")
                return 'channel_gating.fc.layers.0.bias'
            # The following layers don't exist in real ModelScope models
            elif 'linear_mid.weight' in xvector_name:
                logger.warning(f"Found linear_mid layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.1.weight'
            elif 'linear_mid.bias' in xvector_name:
                return 'channel_gating.fc.layers.1.bias'
            elif 'linear_out.weight' in xvector_name:
                logger.warning(f"Found linear_out layer (unexpected in ModelScope model): {xvector_name}")
                return 'channel_gating.fc.layers.2.weight'
            elif 'linear_out.bias' in xvector_name:
                return 'channel_gating.fc.layers.2.bias'

        # ========== POOLING LAYER MAPPING ==========
        # Multi-granularity statistical pooling
        # NOTE: Real ModelScope models typically DON'T have xvector.output or pooling layers
        # These models are feature extractors that end at xvector.dense
        if 'xvector.output.' in xvector_name or 'xvector.pool' in xvector_name:
            logger.warning(f"Found pooling/output layer (rare in ModelScope models): {xvector_name}")
            if 'xvector.output.linear.weight' == xvector_name:
                return 'pooling.attention_weights.weight'
            elif 'xvector.output.linear.bias' == xvector_name:
                return 'pooling.attention_weights.bias'
            elif 'pool_output.linear.weight' in xvector_name or 'pooling.linear.weight' in xvector_name:
                return 'pooling.projection.weight'
            elif 'pool_output.linear.bias' in xvector_name or 'pooling.linear.bias' in xvector_name:
                return 'pooling.projection.bias'

        # ========== FINAL BATCH NORMALIZATION ==========
        if 'xvector.out_nonlinear.batchnorm.' in xvector_name or 'xvector.final_bn.' in xvector_name:
            if '.bias' in xvector_name:
                return 'final_bn.bias'
            elif '.weight' in xvector_name:
                return 'final_bn.weight'
            elif 'running_mean' in xvector_name:
                return 'final_bn.running_mean'
            elif 'running_var' in xvector_name:
                return 'final_bn.running_var'

        # ========== SKIP UNMAPPED PARAMETERS ==========
        # Log skipped parameters for debugging
        if xvector_name.startswith('xvector.'):
            logger.debug(f"Skipping unmapped parameter: {xvector_name}")

        return None
    
    def _filter_weights(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out unnecessary parameters that shouldn't be converted to MLX.

        Classification-head parameters (``head.*``) are dropped because they
        are not needed for inference; everything else — including BatchNorm
        running statistics — is kept and left for the name-mapping step to
        resolve.

        Args:
            pytorch_weights: Original PyTorch weights dict

        Returns:
            Filtered weights dict (a new dict; the input is not modified)
        """
        filtered_weights: Dict[str, torch.Tensor] = {}
        skipped_params = []

        for name, tensor in pytorch_weights.items():
            # Skip classification head parameters (not needed for inference)
            if name.startswith('head.'):
                skipped_params.append(name)
                continue

            # Keep all other parameters including BatchNorm running statistics.
            # The mapping function will filter out parameters that don't have
            # MLX equivalents.
            filtered_weights[name] = tensor

        if skipped_params:
            # Use the module logger (not print) for consistency with the rest
            # of this file's diagnostics.
            logger.info(
                f"Filtered out {len(skipped_params)} unnecessary parameters: "
                f"{skipped_params[:5]}{'...' if len(skipped_params) > 5 else ''}"
            )

        return filtered_weights
    
    def _convert_tensor(self, name: str, tensor: torch.Tensor) -> Optional[mx.array]:
        """
        Convert a single tensor, dispatching on layer type and tensor rank.

        Args:
            name: Parameter name
            tensor: PyTorch tensor to convert

        Returns:
            MLX array, or None if parameter should be skipped (e.g., num_batches_tracked)
        """
        np_weight = tensor.detach().cpu().numpy()

        # Bias vectors need no layout changes — pass straight through.
        if name.endswith('.bias'):
            return mx.array(np_weight)

        # Start from the name-based guess, then let the actual rank override it.
        # This matters because Conv1d(kernel_size=1) weights are often named
        # like Linear layers.
        kind = self._identify_layer_type(name)
        rank = np_weight.ndim
        if rank == 3:
            # A 3D weight can only belong to a Conv1d, whatever the name says.
            kind = 'conv1d'
        elif rank == 2 and kind == 'conv1d':
            # A 2D weight cannot be Conv1d; treat it as Linear.
            kind = 'linear'
        elif rank == 1:
            # 1D vectors from BatchNorm layers (scale/shift/running stats).
            lowered = name.lower()
            if 'bn' in lowered or 'batchnorm' in lowered or 'running' in lowered:
                kind = 'batchnorm'

        # Apply the layer-specific converter, if one is registered.
        converter = self.layer_mapping.get(kind)
        if converter is not None:
            np_weight = converter(name, np_weight)
            if np_weight is None:
                # Converter asked for this parameter to be dropped
                # (e.g. num_batches_tracked).
                return None

        return mx.array(np_weight)
    
    def _identify_layer_type(self, name: str) -> str:
        """Classify a parameter by substrings of its lower-cased name.

        Checks run from most to least specific (BatchNorm before Conv before
        Linear); 'default' means no layer-specific conversion applies.
        """
        lowered = name.lower()
        rules = (
            ('batchnorm', ('bn', 'batchnorm', 'batch_norm')),
            ('conv1d', ('conv1d', 'conv')),
            ('linear', ('linear', 'fc', 'dense')),
            ('embedding', ('embed',)),
        )
        for kind, needles in rules:
            if any(needle in lowered for needle in needles):
                return kind
        return 'default'
    
    def _convert_conv1d(self, name: str, weight: np.ndarray) -> np.ndarray:
        """
        Convert Conv1d weights from PyTorch layout to MLX layout.

        PyTorch stores Conv1d weights as (out_channels, in_channels, kernel_size);
        MLX expects (out_channels, kernel_size, in_channels), so the last two
        axes are swapped. This applies to every kernel size, including 1
        (Conv1d-as-Linear usage).

        Args:
            name: Parameter name (for error reporting)
            weight: Weight tensor as numpy array

        Returns:
            Converted weight tensor

        Raises:
            ValueError: If weight shape is invalid for Conv1d
        """
        if weight.ndim != 3:
            raise ValueError(f"Conv1d weight {name} must be 3D, got shape {weight.shape}")

        # Sanity check: typical TDNN kernels are 1/3/5; anything above 11
        # is suspicious and worth flagging.
        kernel_size = weight.shape[2]
        if kernel_size > 11:
            logger.warning(f"Unusual kernel size {kernel_size} for Conv1d {name}")

        # Swap the channel and kernel axes to match MLX's expected layout.
        converted = weight.transpose(0, 2, 1)
        logger.debug(f"Transposed Conv1d weight {name}: {weight.shape} -> {converted.shape}")
        return converted

    def _convert_linear(self, name: str, weight: np.ndarray) -> np.ndarray:
        """
        Validate and pass through Linear layer weights.

        PyTorch and MLX both store Linear weights as (out_features, in_features),
        so no transformation is required — only a rank check.

        Args:
            name: Parameter name (for error reporting)
            weight: Weight tensor as numpy array

        Returns:
            The weight tensor, unchanged

        Raises:
            ValueError: If weight shape is invalid for Linear
        """
        if weight.ndim == 2:
            # Layouts already agree; nothing to convert.
            return weight
        raise ValueError(f"Linear weight {name} must be 2D, got shape {weight.shape}")

    def _convert_batchnorm(self, name: str, weight: np.ndarray) -> Optional[np.ndarray]:
        """
        Validate and pass through BatchNorm parameters.

        Args:
            name: Parameter name (for error reporting)
            weight: Weight/bias/running_mean/running_var tensor

        Returns:
            Converted tensor, or None if parameter should be skipped (e.g., num_batches_tracked)

        Raises:
            ValueError: If tensor shape is invalid for BatchNorm
        """
        # num_batches_tracked is a scalar training statistic with no MLX
        # counterpart — drop it.
        if 'num_batches_tracked' in name:
            logger.debug(f"Skipping num_batches_tracked (not needed in MLX): {name}")
            return None

        # Every real BatchNorm parameter is a per-channel 1D vector.
        if weight.ndim != 1:
            raise ValueError(f"BatchNorm parameter {name} must be 1D, got shape {weight.shape}")

        # Sanity-check running statistics: NaN/Inf here usually indicates an
        # untrained or numerically unstable source model.
        is_running_stat = ('running_mean' in name) or ('running_var' in name)
        if is_running_stat and np.isnan(weight).any():
            logger.warning(f"BatchNorm {name} contains NaN values - may indicate untrained model")
        if is_running_stat and np.isinf(weight).any():
            logger.warning(f"BatchNorm {name} contains Inf values - may indicate numerical instability")

        return weight

    def _convert_embedding(self, name: str, weight: np.ndarray) -> np.ndarray:
        """
        Pass through Embedding weights unchanged.

        PyTorch and MLX share the same (vocab, dim) embedding layout.

        Args:
            name: Parameter name (unused but kept for API consistency)
            weight: Embedding weight tensor

        Returns:
            The weight tensor, unmodified
        """
        del name  # Unused; retained so all converters share one signature.
        return weight
    
    def quantize_weights(self, weights: Dict[str, mx.array],
                        bits: int = 4, group_size: int = 64) -> Dict[str, mx.array]:
        """
        Quantize eligible weights with MLX's built-in affine quantization.

        The input dictionary is never mutated; a new dictionary is returned,
        so callers do not need to copy beforehand. Weights that are too
        small, one-dimensional, not a whole number of groups on the last
        axis, or that fail quantization are carried over in full precision.

        For each quantized weight ``name``, two companion entries hold the
        per-group metadata needed for dequantization:
        ``name:qSCALES_GS<g>_B<b>`` and ``name:qBIASES_GS<g>_B<b>``.

        Args:
            weights: MLX weights dictionary
            bits: Number of bits for quantization (2, 4, or 8)
            group_size: Group size for quantization (32, 64, or 128)

        Returns:
            Quantized weights dictionary (new copy)
        """
        out: Dict[str, mx.array] = {}
        n_kept = 0
        n_quantized = 0

        logger.info(f"Starting {bits}-bit quantization with group_size={group_size}...")

        for name, weight in weights.items():
            if not self._should_quantize(name, weight):
                # Small tensors / biases / batchnorm stay in full precision.
                out[name] = weight
                n_kept += 1
                continue

            # MLX quantization needs tensors of rank >= 2 whose last axis is
            # a whole number of groups.
            if len(weight.shape) < 2:
                logger.debug(f"Skipping {name}: 1D tensor")
                out[name] = weight
                n_kept += 1
                continue
            if weight.shape[-1] % group_size != 0:
                logger.debug(f"Skipping {name}: last dim {weight.shape[-1]} not divisible by {group_size}")
                out[name] = weight
                n_kept += 1
                continue

            try:
                # Affine quantization: packed weights plus per-group
                # scales and biases.
                packed, scales, biases = mx.quantize(weight, group_size=group_size, bits=bits)

                out[name] = packed
                out[f"{name}:qSCALES_GS{group_size}_B{bits}"] = scales
                out[f"{name}:qBIASES_GS{group_size}_B{bits}"] = biases
                n_quantized += 1

                # Report the achieved compression for this tensor.
                original_size = weight.size * 4  # float32 = 4 bytes
                quantized_size = packed.nbytes + scales.nbytes + biases.nbytes
                reduction = (1 - quantized_size / original_size) * 100
                logger.debug(f"Quantized {name}: {reduction:.1f}% size reduction ({original_size//1024}KB β†’ {quantized_size//1024}KB)")
            except Exception as e:
                # Keep the original on any failure rather than aborting the run.
                logger.warning(f"Failed to quantize {name}: {e}, keeping original")
                out[name] = weight
                n_kept += 1

        logger.info(f"Quantization complete: {n_quantized} weights quantized, {n_kept} kept in full precision")
        return out

    def _quantize_to_int8(self, weight: mx.array) -> mx.array:
        """
        Fake-quantize a tensor to int8 precision (quantize-then-dequantize).

        Symmetric quantization: values are scaled so the absolute maximum
        maps to 127, rounded, clipped to [-127, 127], then scaled back. The
        result is a float32 tensor carrying int8-level precision loss.

        Args:
            weight: Weight tensor to quantize

        Returns:
            Quantized weight tensor
        """
        # Per-tensor symmetric scale factor.
        scale = mx.max(mx.abs(weight)) / 127.0

        if scale == 0:
            # All-zero tensor: nothing to quantize; avoid dividing by zero.
            return weight

        # Round-trip through the int8 grid.
        levels = mx.clip(mx.round(weight / scale), -127, 127)
        return (levels * scale).astype(mx.float32)
    
    def _should_quantize(self, name: str, weight: mx.array) -> bool:
        """Decide whether a weight is worth quantizing.

        Excluded: tensors smaller than MIN_QUANTIZATION_SIZE elements, bias
        terms, and all batch-normalization parameters (their precision
        matters and their size is negligible). Everything else — the large
        Conv/Linear matrices — gets quantized.
        """
        # Tiny tensors gain nothing from quantization.
        if weight.size < MIN_QUANTIZATION_SIZE:
            return False

        lowered = name.lower()
        if 'bias' in lowered:
            return False

        # BatchNorm scale/shift and running statistics stay in full precision.
        bn_markers = ('bn', 'batchnorm', 'batch_norm', 'running_mean', 'running_var')
        return not any(marker in lowered for marker in bn_markers)
    
    def verify_conversion(self, pytorch_weights: Dict[str, torch.Tensor], 
                         mlx_weights: Dict[str, mx.array]) -> Dict[str, bool]:
        """
        Verify the conversion by comparing shapes and values per parameter.

        A parameter passes only when it exists in the converted dict under
        the same name, has an identical shape, and its values match within
        tolerance (rtol=1e-5, atol=1e-6). Note this check assumes no
        renaming or transposition was applied.

        Args:
            pytorch_weights: Original PyTorch weights
            mlx_weights: Converted MLX weights

        Returns:
            Dictionary of verification results
        """
        results: Dict[str, bool] = {}

        for name, torch_tensor in pytorch_weights.items():
            if name not in mlx_weights:
                # Missing from the converted set -> automatic failure.
                results[name] = False
                continue

            converted = mlx_weights[name]

            # Shapes must agree exactly (no transpose expected).
            ok = tuple(torch_tensor.shape) == tuple(converted.shape)
            if ok:
                # Values must also match within floating-point tolerance.
                reference = torch_tensor.detach().cpu().numpy()
                candidate = np.array(converted)
                ok = bool(np.allclose(reference, candidate, rtol=1e-5, atol=1e-6))
            results[name] = ok

        return results
    
    def check_conversion_status(self, pytorch_weights: Dict[str, torch.Tensor], 
                               mlx_weights: Dict[str, mx.array],
                               verification_results: Dict[str, bool]) -> Dict[str, Any]:
        """
        Build a deployment-readiness report for a finished conversion.

        Aggregates verification counts, scans the converted weights for
        NaN/Inf values and mixed dtypes, and decides whether the result is
        perfect and/or safe to deploy. "Perfect" requires zero errors and a
        100% verification rate; otherwise deployment is still allowed when
        the rate and failure count clear the configured thresholds and no
        NaN/Inf values were found.

        Args:
            pytorch_weights: Original PyTorch weights
            mlx_weights: Converted MLX weights
            verification_results: Results from verify_conversion

        Returns:
            Status dictionary with detailed report
        """
        passed = sum(1 for ok in verification_results.values() if ok)
        failed = sum(1 for ok in verification_results.values() if not ok)
        status: Dict[str, Any] = {
            'is_perfect': False,
            'total_source_weights': len(pytorch_weights),
            'total_converted_weights': len(mlx_weights),
            'verification_passed': passed,
            'verification_failed': failed,
            'verification_rate': 0.0,
            'errors': [],
            'warnings': [],
            'safe_to_deploy': False,
        }

        # Verification rate as a percentage of all checked parameters.
        if verification_results:
            status['verification_rate'] = (passed / len(verification_results)) * 100

        # --- Critical issues -------------------------------------------------
        if not mlx_weights:
            status['errors'].append("No weights were converted - conversion failed completely")

        if failed > 0:
            failed_weights = [n for n, ok in verification_results.items() if not ok]
            status['errors'].append(
                f"{status['verification_failed']} weight(s) failed verification: {failed_weights[:3]}"
                f"{'...' if len(failed_weights) > 3 else ''}"
            )

        # Fewer than half of the source weights converted suggests mapping problems.
        if len(mlx_weights) < len(pytorch_weights) * 0.5:
            status['warnings'].append(
                f"Only {len(mlx_weights)}/{len(pytorch_weights)} weights were converted "
                f"({(len(mlx_weights)/len(pytorch_weights)*100):.1f}%) - possible mapping issues"
            )

        # --- Consistency checks ----------------------------------------------
        dtype_set = {str(w.dtype) for w in mlx_weights.values()}
        if len(dtype_set) > 1:
            status['warnings'].append(f"Mixed data types detected in converted weights: {dtype_set}")

        # NaN/Inf anywhere in the converted tensors makes the model unusable.
        nan_inf_weights = []
        for name, weight in mlx_weights.items():
            values = np.array(weight)
            if np.isnan(values).any():
                nan_inf_weights.append(f"{name} (NaN)")
            elif np.isinf(values).any():
                nan_inf_weights.append(f"{name} (Inf)")

        if nan_inf_weights:
            status['errors'].append(f"Weights contain NaN/Inf: {nan_inf_weights[:3]}")

        # --- Final verdict ---------------------------------------------------
        status['is_perfect'] = (
            not status['errors']
            and status['verification_rate'] == 100.0
            and len(mlx_weights) > 0
        )

        # Conservative default: deploy only when perfect ...
        status['safe_to_deploy'] = status['is_perfect']

        # ... but tolerate minor verification noise when there are no errors.
        if not status['safe_to_deploy'] and not status['errors']:
            status['safe_to_deploy'] = (
                status['verification_rate'] >= MIN_VERIFICATION_RATE and
                status['verification_failed'] <= MAX_VERIFICATION_FAILURES and
                len(nan_inf_weights) == 0
            )

        return status
    
    def print_status_report(self, status: Dict[str, Any]) -> None:
        """Pretty-print a conversion status report to stdout.

        Expects the dictionary produced by check_conversion_status; shows
        the headline statistics, any errors/warnings, and the deployment
        verdict.
        """
        divider = "=" * 70
        print("\n" + divider)
        print("CONVERSION STATUS REPORT")
        print(divider)

        checked = status['verification_passed'] + status['verification_failed']
        print(f"\nπŸ“Š Conversion Statistics:")
        print(f"  Total source weights:     {status['total_source_weights']}")
        print(f"  Total converted weights:  {status['total_converted_weights']}")
        print(f"  Verification passed:      {status['verification_passed']}/{checked}")
        print(f"  Verification rate:        {status['verification_rate']:.1f}%")

        if status['errors']:
            print(f"\n❌ Errors ({len(status['errors'])}):")
            for message in status['errors']:
                print(f"  β€’ {message}")

        if status['warnings']:
            print(f"\n⚠️  Warnings ({len(status['warnings'])}):")
            for message in status['warnings']:
                print(f"  β€’ {message}")

        print(f"\nπŸ” Deployment Decision:")
        if status['is_perfect']:
            print(f"  Status: βœ… PERFECT - All checks passed")
        else:
            print(f"  Status: {'βœ… ACCEPTABLE' if status['safe_to_deploy'] else '❌ NOT SAFE'}")

        print(f"  Safe to deploy: {'βœ… YES' if status['safe_to_deploy'] else '❌ NO'}")
        print("\n" + divider)
    
    def create_model_metadata(self, original_repo: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the metadata dictionary shipped with the converted model.

        Entries from *config* are merged last and therefore override any of
        the defaults below.
        """
        metadata: Dict[str, Any] = {
            "converted_from": original_repo,
            "conversion_date": self.get_current_date(),
            "framework": "mlx",
            "model_type": "campp",
            "architecture": "d-tdnn",
            "license": "apache-2.0",
            "tags": ["speaker-recognition", "audio", "mlx", "apple-silicon"],
            "task": "speaker-verification",
            "library_name": "mlx",
            "datasets": ["voxceleb", "cnceleb"],
            "metrics": {
                "voxceleb1_eer": "0.65%",
                "parameters": "7.2M",
                "inference_speed": "optimized_for_apple_silicon"
            },
        }
        # Caller-supplied config wins on key collisions, as with {**base, **config}.
        metadata.update(config)
        return metadata
    
    def get_current_date(self) -> str:
        """Return the current local timestamp as an ISO-8601 string."""
        now = datetime.now()
        return now.isoformat()
    
    def estimate_model_performance(self, weights: Dict[str, mx.array]) -> Dict[str, Any]:
        """Summarize rough size/complexity characteristics of a weight set.

        Memory is estimated assuming 4 bytes (fp32) per parameter; the layer
        counts are inferred from parameter names, so they are heuristics
        rather than exact architecture counts.
        """
        param_count = sum(w.size for w in weights.values())

        # fp32 footprint in megabytes.
        memory_mb = (param_count * 4) / (1024 * 1024)

        # Name-based layer tallies.
        lowered_names = [n.lower() for n in weights.keys()]
        conv_layers = sum(1 for n in lowered_names if 'conv' in n)
        linear_layers = sum(1 for n in lowered_names if 'linear' in n or 'fc' in n)

        return {
            "total_parameters": param_count,
            "estimated_memory_mb": memory_mb,
            "conv_layers": conv_layers,
            "linear_layers": linear_layers,
            "model_complexity": "efficient" if param_count < 10e6 else "standard"
        }
    
    def optimize_for_inference(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Normalize weights for MLX inference (currently: force float32).

        Returns a new dict; each value is re-wrapped as an mx.array and cast
        to float32 when needed. This is the hook point for future
        MLX-specific optimizations.
        """
        prepared: Dict[str, mx.array] = {}

        for name, weight in weights.items():
            # Re-wrap to guarantee an MLX array in the optimal layout.
            candidate = mx.array(weight)

            # Keep everything in a single, consistent dtype.
            if candidate.dtype != mx.float32:
                candidate = candidate.astype(mx.float32)

            prepared[name] = candidate

        return prepared

def test_conversion():
    """Smoke-test the conversion utilities end to end.

    Builds a small dummy xvector state dict, converts it, checks each
    source parameter's mapping and shape, and prints a full status report.
    Returns True only when every mapping verifies AND the status report is
    perfect.
    """
    utils = ConversionUtils()

    # Minimal dummy PyTorch xvector weights covering each layer family:
    # input TDNN, one dense block, one transition layer, final batchnorm.
    dummy_weights = {
        # Input layer
        'xvector.tdnn.linear.weight': torch.randn(64, 80, 3),
        'xvector.tdnn.nonlinear.batchnorm.weight': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.bias': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_mean': torch.randn(64),
        'xvector.tdnn.nonlinear.batchnorm.running_var': torch.randn(64),

        # Dense block 0
        'xvector.block1.tdnnd1.linear1.weight': torch.randn(32, 64, 3),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(32),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(32),

        # Transition layer
        'xvector.transit1.linear.weight': torch.randn(256, 96, 1),
        'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(256),
        'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(256),

        # Final layer
        'xvector.out_nonlinear.batchnorm.weight': torch.randn(512),
        'xvector.out_nonlinear.batchnorm.bias': torch.randn(512),
    }

    # Run the full conversion pipeline.
    mlx_weights, config = utils.convert_weights_to_mlx(dummy_weights)

    verification = {}
    print("Conversion test results:")

    # Check that every source weight maps to a converted weight of the
    # same shape.
    for source_name, source_tensor in dummy_weights.items():
        target_name = utils._xvector_to_mlx_name(source_name)
        if target_name and target_name in mlx_weights:
            source_shape = source_tensor.shape
            target_shape = mlx_weights[target_name].shape
            ok = source_shape == target_shape
            verification[source_name] = ok
            print(f"  {'βœ…' if ok else '❌'} {source_name} -> {target_name} | Shape: {source_shape} -> {target_shape}")
        else:
            verification[source_name] = False
            print(f"  ❌ {source_name} (no mapping)")

    print(f"\nTotal weights converted: {len(mlx_weights)}")
    print(f"Inferred config: {config}")

    # Full deployment-readiness report.
    status_report = utils.check_conversion_status(dummy_weights, mlx_weights, verification)
    utils.print_status_report(status_report)

    # Success requires both per-weight checks and a perfect status report.
    return all(verification.values()) and status_report['is_perfect']

if __name__ == "__main__":
    # Run the self-test and summarize the outcome on one line.
    passed = test_conversion()
    verdict = 'passed' if passed else 'failed'
    print(f"\n{'βœ…' if passed else '❌'} Conversion utilities test {verdict}")