BMP committed on
Commit
ca0ebee
·
1 Parent(s): fcdbc84

Refactor activation script and enhance conversion utilities; add parameter mapping and filtering logic; update requirements and add test for parameter mapping

Browse files
Files changed (5) hide show
  1. activate.sh +1 -1
  2. app.py +108 -203
  3. conversion_utils.py +320 -1
  4. requirements.txt +2 -1
  5. test_mapping.py +66 -0
activate.sh CHANGED
@@ -2,4 +2,4 @@
2
  # Script to activate the Python virtual environment
3
  # Note: To activate the venv in your current shell, run: source activate.sh
4
  # Running ./activate.sh will activate it in a subshell, which won't affect your shell.
5
- source venv/bin/activate
 
2
  # Script to activate the Python virtual environment
3
  # Note: To activate the venv in your current shell, run: source activate.sh
4
  # Running ./activate.sh will activate it in a subshell, which won't affect your shell.
5
+ source .venv/bin/activate
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  import mlx.core as mx
4
  import mlx.nn as nn
5
- from huggingface_hub import HfApi, upload_file, snapshot_download, hf_hub_download
6
  import tempfile
7
  import json
8
  import os
@@ -46,7 +46,7 @@ class CAMPPConverter:
46
  return ERROR_INVALID_REPO
47
 
48
  try:
49
- return self._perform_conversion(input_repo, output_name, hf_token, quantize)
50
  except Exception as e:
51
  error_msg = f"Conversion failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
52
  logger.error(error_msg)
@@ -62,10 +62,10 @@ class CAMPPConverter:
62
  logger.info(status)
63
 
64
  try:
65
- model_dir = snapshot_download(
66
- repo_id=input_repo,
67
- local_dir=f"{temp_dir}/original",
68
- token=hf_token
69
  )
70
  except Exception as e:
71
  return f"❌ Failed to download model: {str(e)}"
@@ -76,7 +76,7 @@ class CAMPPConverter:
76
 
77
  pytorch_model_path = self._find_pytorch_model(model_dir)
78
  if not pytorch_model_path:
79
- return "No PyTorch model file found. Expected: pytorch_model.bin, model.safetensors, or checkpoint.pth"
80
 
81
  # Load weights
82
  try:
@@ -85,6 +85,12 @@ class CAMPPConverter:
85
  weights = load_file(pytorch_model_path)
86
  else:
87
  weights = torch.load(pytorch_model_path, map_location='cpu')
 
 
 
 
 
 
88
  except Exception as e:
89
  return f"Failed to load weights: {str(e)}"
90
 
@@ -157,6 +163,8 @@ class CAMPPConverter:
157
  input_repo, output_name, hf_token, quantize, bits=32):
158
  """Create and upload a single model version"""
159
 
 
 
160
  # Create model directory
161
  if quantize:
162
  dir_name = f"mlx_q{bits}"
@@ -221,15 +229,38 @@ HF Link: https://huggingface.co/{repo_id}
221
 
222
  def _find_pytorch_model(self, model_dir: str) -> Optional[str]:
223
  """Find PyTorch model file in directory"""
 
 
 
 
 
 
 
 
224
  possible_files = [
225
- "pytorch_model.bin", "model.safetensors",
226
- "checkpoint.pth", "model.pth", "best_model.pth"
 
227
  ]
228
 
229
- for file in possible_files:
230
- path = os.path.join(model_dir, file)
231
- if os.path.exists(path):
232
- return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  return None
234
 
235
  def _validate_campp_architecture(self, weights: Dict) -> bool:
@@ -415,26 +446,32 @@ https://arxiv.org/abs/2303.00332
415
  converter = CAMPPConverter()
416
 
417
  # Create Gradio interface
418
- def convert_interface(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8):
419
- return converter.convert_model(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8)
420
-
421
- def convert_modelscope_model(hf_token, quantize_q2, quantize_q4, quantize_q8):
422
- """Download and convert the ModelScope CAM++ model"""
423
- input_repo = "modelscope/speech_campplus_sv_zh-cn_16k-common"
424
- output_name = "campp-zh-cn-16k-mlx"
425
- return converter.convert_model(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8)
426
-
427
- def convert_3dspeaker_model(hf_token, quantize_q2, quantize_q4, quantize_q8):
428
- """Download and convert the 3dspeaker VoxCeleb CAM++ model"""
429
- input_repo = "3dspeaker/campplus-voxceleb"
430
- output_name = "campp-voxceleb-mlx"
431
- return converter.convert_model(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8)
432
-
433
- def convert_3dspeaker_cnceleb_model(hf_token, quantize_q2, quantize_q4, quantize_q8):
434
- """Download and convert the 3dspeaker CN-Celeb CAM++ model"""
435
- input_repo = "3dspeaker/campplus-cnceleb"
436
- output_name = "campp-cnceleb-mlx"
437
- return converter.convert_model(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8)
 
 
 
 
 
 
438
 
439
  # Gradio UI
440
  with gr.Blocks(title="🎤 CAM++ MLX Converter") as interface:
@@ -442,67 +479,40 @@ with gr.Blocks(title="🎤 CAM++ MLX Converter") as interface:
442
  gr.Markdown("*Transform PyTorch CAM++ models into optimized Apple MLX format*")
443
  gr.Markdown("---")
444
 
445
- # Quick Actions Section
446
- with gr.Accordion(" Quick Convert (Recommended)", open=True):
447
- gr.Markdown("**Choose your model variant:**")
448
- with gr.Row():
449
- with gr.Column():
450
- modelscope_btn = gr.Button("🚀 ModelScope\nChinese Speech", variant="secondary", size="lg")
451
- gr.Markdown("*General Chinese speech recognition*")
452
- with gr.Column():
453
- dspeaker_btn = gr.Button("🌍 VoxCeleb\nMultilingual", variant="secondary", size="lg")
454
- gr.Markdown("*English + European languages*")
455
- with gr.Column():
456
- cnceleb_btn = gr.Button("🇨🇳 CN-Celeb\nPremium Chinese", variant="secondary", size="lg")
457
- gr.Markdown("*High-quality Chinese celebrity speech*")
458
 
459
  gr.Markdown("---")
460
 
461
- # Manual Conversion Section
462
- with gr.Accordion("🔧 Manual Conversion", open=False):
463
- with gr.Row():
464
- with gr.Column(scale=2):
465
- gr.Markdown("### Model Configuration")
466
- input_repo = gr.Textbox(
467
- label="📥 Input Repository",
468
- placeholder="username/campp-model",
469
- info="Hugging Face repository with PyTorch CAM++ model"
470
- )
471
- output_name = gr.Textbox(
472
- label="📤 Output Name",
473
- placeholder="campp-speaker-recognition",
474
- info="Name for the converted MLX model"
475
- )
476
- hf_token = gr.Textbox(
477
- label="🔑 Hugging Face Token",
478
- placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
479
- type="password",
480
- info="Token with write access to mlx-community"
481
- )
482
-
483
- with gr.Column(scale=1):
484
- gr.Markdown("### ⚡ Quantization Options")
485
- gr.Markdown("**Choose compression levels:**")
486
-
487
- quantize_q2 = gr.Checkbox(
488
- label="🗜️ Q2 (2-bit)",
489
- value=False,
490
- info="Ultra-compressed for edge devices"
491
- )
492
- quantize_q4 = gr.Checkbox(
493
- label="⚖️ Q4 (4-bit)",
494
- value=True,
495
- info="Balanced quality & size (recommended)"
496
- )
497
- quantize_q8 = gr.Checkbox(
498
- label="🎯 Q8 (8-bit)",
499
- value=False,
500
- info="High quality, moderate compression"
501
- )
502
 
503
- gr.Markdown("---")
504
- convert_btn = gr.Button("🚀 Start Conversion", variant="primary", size="lg")
505
-
 
506
  # Status and Results
507
  with gr.Accordion("📊 Conversion Status", open=True):
508
  output = gr.Textbox(
@@ -511,127 +521,22 @@ with gr.Blocks(title="🎤 CAM++ MLX Converter") as interface:
511
  max_lines=25,
512
  interactive=False
513
  )
514
-
515
- # Examples
516
- with gr.Accordion("📋 Example Models", open=False):
517
- gr.Examples(
518
- examples=[
519
- ["modelscope/speech_campplus_sv_zh-cn_16k-common", "campp-chinese-16k", "", False, True, False],
520
- ["3dspeaker/campplus-voxceleb", "campp-voxceleb", "", False, True, False],
521
- ["3dspeaker/campplus-cnceleb", "campp-cnceleb", "", False, True, False],
522
- ],
523
- inputs=[input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8],
524
- label="Click to load example configurations"
525
- )
526
-
527
- # Instructions
528
- with gr.Accordion("📖 Instructions & Guide", open=False):
529
- gr.Markdown("""
530
- ## 🚀 Quick Start Guide
531
-
532
- ### One-Click Conversion (Recommended)
533
- Choose the appropriate model for your language needs:
534
-
535
- | Button | Language | Dataset | Quality | Use Case |
536
- |--------|----------|---------|---------|----------|
537
- | 🚀 **ModelScope** | Chinese | General speech | Good | Broad Chinese applications |
538
- | 🌍 **VoxCeleb** | Multilingual | Celebrity interviews | Excellent | English + European languages |
539
- | 🇨🇳 **CN-Celeb** | Chinese | Celebrity speech | Best | High-quality Chinese SV |
540
-
541
- ### Quantization Options
542
- Choose the right compression level for your needs:
543
-
544
- - **Q2 (2-bit)**: 25% size, minimal quality loss → **Edge devices, mobile**
545
- - **Q4 (4-bit)**: 50% size, excellent quality → **Most applications** ⭐
546
- - **Q8 (8-bit)**: 75% size, near-perfect quality → **Quality-critical tasks**
547
-
548
- ### Manual Conversion
549
- For custom models from Hugging Face:
550
-
551
- 1. **Find a CAM++ Model**: Search for `campp` or `speaker verification` on HF
552
- 2. **Enter Repository**: Format `username/model-name`
553
- 3. **Set Output Name**: Choose a descriptive name
554
- 4. **Add HF Token**: Get from https://huggingface.co/settings/tokens
555
- 5. **Select Quantization**: Choose compression levels
556
- 6. **Convert**: Click the button and wait for completion
557
-
558
- ## 📊 Performance Expectations
559
-
560
- ### Model Sizes (Approximate):
561
- - **Regular (FP32)**: ~50-100MB
562
- - **Q8**: ~40-80MB
563
- - **Q4**: ~25-50MB ⭐
564
- - **Q2**: ~15-30MB
565
-
566
- ### Inference Speed (Apple Silicon):
567
- - **Regular**: Baseline performance
568
- - **Q8**: ~1.1x faster
569
- - **Q4**: ~1.3x faster
570
- - **Q2**: ~1.5x faster
571
-
572
- ## 🔧 Troubleshooting
573
-
574
- ### Common Issues:
575
- - **"Module not found"**: Ensure all dependencies are installed
576
- - **"Permission denied"**: Check your HF token has write access
577
- - **"Port already in use"**: The app may restart automatically
578
- - **"Conversion failed"**: Check model compatibility (must be CAM++)
579
-
580
- ### Token Requirements:
581
- - Must have **write access** to `mlx-community` organization
582
- - Generate at: https://huggingface.co/settings/tokens
583
- - Select role: `Write` when creating
584
-
585
- ## 🎯 Best Practices
586
-
587
- - **For production**: Use Q4 quantization for optimal balance
588
- - **For development**: Keep regular version for debugging
589
- - **For mobile**: Use Q2 for maximum compression
590
- - **For accuracy**: Use CN-Celeb or VoxCeleb over generic models
591
-
592
- ## 📝 Output Format
593
-
594
- Each conversion creates MLX models ready for Apple Silicon:
595
-
596
- ```
597
- mlx-community/your-model-name/
598
- ├── model.py # MLX implementation
599
- ├── weights.npz # Quantized weights
600
- ├── config.json # Model configuration
601
- ├── usage_example.py # Usage examples
602
- └── README.md # Documentation
603
- ```
604
-
605
- ## 🆘 Support
606
-
607
- - Check the conversion logs for detailed error messages
608
- - Ensure your model is a PyTorch CAM++ implementation
609
- - Test with the provided example models first
610
- """)
611
 
612
  convert_btn.click(
613
  fn=convert_interface,
614
- inputs=[input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8],
615
- outputs=[output]
616
- )
617
-
618
- modelscope_btn.click(
619
- fn=convert_modelscope_model,
620
- inputs=[hf_token, quantize_q2, quantize_q4, quantize_q8],
621
  outputs=[output]
622
  )
623
 
624
- dspeaker_btn.click(
625
- fn=convert_3dspeaker_model,
626
- inputs=[hf_token, quantize_q2, quantize_q4, quantize_q8],
627
- outputs=[output]
628
  )
629
 
630
- cnceleb_btn.click(
631
- fn=convert_3dspeaker_cnceleb_model,
632
- inputs=[hf_token, quantize_q2, quantize_q4, quantize_q8],
633
- outputs=[output]
634
  )
635
 
636
  if __name__ == "__main__":
637
- interface.launch(server_port=7864, theme=gr.themes.Soft())
 
2
  import torch
3
  import mlx.core as mx
4
  import mlx.nn as nn
5
+ from huggingface_hub import HfApi, upload_file, hf_hub_download
6
  import tempfile
7
  import json
8
  import os
 
46
  return ERROR_INVALID_REPO
47
 
48
  try:
49
+ return self._perform_conversion(input_repo, output_name, hf_token, quantize_q2, quantize_q4, quantize_q8)
50
  except Exception as e:
51
  error_msg = f"Conversion failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
52
  logger.error(error_msg)
 
62
  logger.info(status)
63
 
64
  try:
65
+ from modelscope import snapshot_download as ms_snapshot_download
66
+ model_dir = ms_snapshot_download(
67
+ model_id=input_repo,
68
+ local_dir=f"{temp_dir}/original"
69
  )
70
  except Exception as e:
71
  return f"❌ Failed to download model: {str(e)}"
 
76
 
77
  pytorch_model_path = self._find_pytorch_model(model_dir)
78
  if not pytorch_model_path:
79
+ return "No PyTorch model file found. Check logs for available files."
80
 
81
  # Load weights
82
  try:
 
85
  weights = load_file(pytorch_model_path)
86
  else:
87
  weights = torch.load(pytorch_model_path, map_location='cpu')
88
+ # If loaded object is a model (not state_dict), get state_dict
89
+ if not isinstance(weights, dict):
90
+ if hasattr(weights, 'state_dict'):
91
+ weights = weights.state_dict()
92
+ else:
93
+ return f"Loaded object is not a valid PyTorch state_dict or model: {type(weights)}"
94
  except Exception as e:
95
  return f"Failed to load weights: {str(e)}"
96
 
 
163
  input_repo, output_name, hf_token, quantize, bits=32):
164
  """Create and upload a single model version"""
165
 
166
+ repo_id = f"mlx-community/{output_name}"
167
+
168
  # Create model directory
169
  if quantize:
170
  dir_name = f"mlx_q{bits}"
 
229
 
230
  def _find_pytorch_model(self, model_dir: str) -> Optional[str]:
231
  """Find PyTorch model file in directory"""
232
+ # Search recursively
233
+ for root, dirs, files in os.walk(model_dir):
234
+ for file in files:
235
+ # Prioritize .bin and .pt files containing 'campplus' (ModelScope models)
236
+ if (file.endswith('.bin') or file.endswith('.pt')) and 'campplus' in file.lower():
237
+ return os.path.join(root, file)
238
+
239
+ # Fallback to other common model files
240
  possible_files = [
241
+ "pytorch_model.bin", "model.safetensors", "checkpoint.pth",
242
+ "model.pth", "best_model.pth", "model.bin", "checkpoint.bin",
243
+ "best_model.bin", "pytorch_model.pth", "model.pt", "checkpoint.pt"
244
  ]
245
 
246
+ for root, dirs, files in os.walk(model_dir):
247
+ for file in files:
248
+ if file in possible_files:
249
+ return os.path.join(root, file)
250
+
251
+ # Last resort: any .bin or .pt file
252
+ for root, dirs, files in os.walk(model_dir):
253
+ for file in files:
254
+ if file.endswith('.bin') or file.endswith('.pt'):
255
+ return os.path.join(root, file)
256
+
257
+ # Log what files were found
258
+ all_files = []
259
+ for root, dirs, files in os.walk(model_dir):
260
+ for file in files:
261
+ all_files.append(os.path.join(root, file))
262
+ logger.warning(f"No PyTorch model file found in {model_dir}. Available files: {all_files}")
263
+
264
  return None
265
 
266
  def _validate_campp_architecture(self, weights: Dict) -> bool:
 
446
  converter = CAMPPConverter()
447
 
448
  # Create Gradio interface
449
+ def convert_interface(input_repo, output_name, hf_token):
450
+ return converter.convert_model(input_repo, output_name, hf_token, False, True, False)
451
+
452
+ def fill_modelscope():
453
+ return "iic/speech_campplus_sv_zh-cn_16k-common"
454
+
455
+ def fill_voxceleb():
456
+ return "iic/speech_campplus_sv_zh_en_16k-common_advanced"
457
+
458
+ def fill_cnceleb():
459
+ return "iic/speech_campplus_sv_zh-cn_16k-common"
460
+
461
+ def auto_fill_name(repo):
462
+ if not repo:
463
+ return ""
464
+
465
+ # Custom names for specific models
466
+ if repo == "iic/speech_campplus_sv_zh_en_16k-common_advanced":
467
+ return "campplus_multilingual_16k_advanced"
468
+ elif repo == "iic/speech_campplus_sv_zh-cn_16k-common":
469
+ return "campplus_chinese_16k_common"
470
+
471
+ # Fallback to last part of repo name
472
+ if '/' in repo:
473
+ return repo.split('/')[-1]
474
+ return ""
475
 
476
  # Gradio UI
477
  with gr.Blocks(title="🎤 CAM++ MLX Converter") as interface:
 
479
  gr.Markdown("*Transform PyTorch CAM++ models into optimized Apple MLX format*")
480
  gr.Markdown("---")
481
 
482
+ # Example Models Row
483
+ gr.Markdown("### 🎯 Choose a Model")
484
+ with gr.Row():
485
+ chinese_btn = gr.Button("🚀 Chinese (Basic)", variant="secondary")
486
+ advanced_btn = gr.Button("🌍 Chinese-English (Advanced)", variant="secondary")
 
 
 
 
 
 
 
 
487
 
488
  gr.Markdown("---")
489
 
490
+ # Model Configuration Section
491
+ with gr.Row():
492
+ with gr.Column(scale=2):
493
+ gr.Markdown("### Model Configuration")
494
+ input_repo = gr.Textbox(
495
+ label="📥 Input Repository",
496
+ placeholder="iic/speech_campplus_sv_zh-cn_16k-common",
497
+ info="ModelScope repository with PyTorch CAM++ model"
498
+ )
499
+ output_name = gr.Textbox(
500
+ label="📤 Output Name",
501
+ placeholder="campp-speaker-recognition",
502
+ info="Name for the converted MLX model"
503
+ )
504
+ input_repo.change(fn=auto_fill_name, inputs=input_repo, outputs=output_name)
505
+ hf_token = gr.Textbox(
506
+ label="🔑 Hugging Face Token",
507
+ placeholder="hf_xxxxxxxxxxxxxxxxxxxx",
508
+ type="password",
509
+ info="Token with write access to mlx-community"
510
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
 
512
+ with gr.Column(scale=1):
513
+ gr.Markdown("### ⚙️ Settings")
514
+ convert_btn = gr.Button("🚀 Start Conversion", variant="primary", size="lg")
515
+
516
  # Status and Results
517
  with gr.Accordion("📊 Conversion Status", open=True):
518
  output = gr.Textbox(
 
521
  max_lines=25,
522
  interactive=False
523
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  convert_btn.click(
526
  fn=convert_interface,
527
+ inputs=[input_repo, output_name, hf_token],
 
 
 
 
 
 
528
  outputs=[output]
529
  )
530
 
531
+ chinese_btn.click(
532
+ fn=fill_modelscope,
533
+ outputs=[input_repo]
 
534
  )
535
 
536
+ advanced_btn.click(
537
+ fn=fill_voxceleb,
538
+ outputs=[input_repo]
 
539
  )
540
 
541
  if __name__ == "__main__":
542
+ interface.launch(server_port=7865)
conversion_utils.py CHANGED
@@ -29,8 +29,17 @@ class ConversionUtils:
29
  mlx_weights = {}
30
  model_config = self._analyze_model_structure(pytorch_weights)
31
 
 
 
 
 
 
 
 
 
 
32
  # Convert each weight tensor
33
- for name, tensor in pytorch_weights.items():
34
  if isinstance(tensor, torch.Tensor):
35
  mlx_weights[name] = self._convert_tensor(name, tensor)
36
  else:
@@ -39,6 +48,316 @@ class ConversionUtils:
39
 
40
  return mlx_weights, model_config
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def _convert_tensor(self, name: str, tensor: torch.Tensor) -> mx.array:
43
  """Convert individual tensor based on layer type"""
44
 
 
29
  mlx_weights = {}
30
  model_config = self._analyze_model_structure(pytorch_weights)
31
 
32
+ # Filter out unnecessary parameters (BatchNorm running stats, etc.)
33
+ filtered_weights = self._filter_weights(pytorch_weights)
34
+
35
+ # Map parameter names from PyTorch to MLX format
36
+ mapped_weights = self._map_parameter_names(filtered_weights)
37
+
38
+ # Add default values for missing MLX parameters
39
+ mapped_weights = self._add_missing_parameters(mapped_weights, model_config)
40
+
41
  # Convert each weight tensor
42
+ for name, tensor in mapped_weights.items():
43
  if isinstance(tensor, torch.Tensor):
44
  mlx_weights[name] = self._convert_tensor(name, tensor)
45
  else:
 
48
 
49
  return mlx_weights, model_config
50
 
51
+ def _map_parameter_names(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
52
+ """
53
+ Map PyTorch parameter names to MLX parameter names
54
+
55
+ Args:
56
+ pytorch_weights: PyTorch weights with original names
57
+
58
+ Returns:
59
+ Weights with MLX-compatible parameter names
60
+ """
61
+ mapped_weights = {}
62
+
63
+ for name, tensor in pytorch_weights.items():
64
+ # Map xvector parameter names to MLX names
65
+ mlx_name = self._xvector_to_mlx_name(name)
66
+ if mlx_name: # Only keep parameters that have MLX equivalents
67
+ mapped_weights[mlx_name] = tensor
68
+
69
+ return mapped_weights
70
+
71
+ def _add_missing_parameters(self, mapped_weights: Dict[str, torch.Tensor], model_config: Dict) -> Dict[str, torch.Tensor]:
72
+ """
73
+ Add default values for MLX parameters that don't have PyTorch equivalents
74
+
75
+ Args:
76
+ mapped_weights: Already mapped weights
77
+ model_config: Model configuration
78
+
79
+ Returns:
80
+ Weights with missing parameters added
81
+ """
82
+ import torch.nn.init as init
83
+
84
+ # Get input dimensions from model config
85
+ input_dim = model_config.get('input_dim', 80) # Default mel spectrogram features
86
+
87
+ # Input convolution parameters (Conv1d: input_dim -> 64, kernel_size=3, padding=1, bias=False)
88
+ if 'input_conv.weight' not in mapped_weights:
89
+ weight = torch.empty(64, input_dim, 3) # (out_channels, in_channels, kernel_size)
90
+ init.xavier_uniform_(weight)
91
+ mapped_weights['input_conv.weight'] = weight
92
+
93
+ # Input batch norm parameters
94
+ if 'input_bn.bias' not in mapped_weights:
95
+ mapped_weights['input_bn.bias'] = torch.zeros(64)
96
+ if 'input_bn.weight' not in mapped_weights:
97
+ mapped_weights['input_bn.weight'] = torch.ones(64)
98
+ if 'input_bn.running_mean' not in mapped_weights:
99
+ mapped_weights['input_bn.running_mean'] = torch.zeros(64)
100
+ if 'input_bn.running_var' not in mapped_weights:
101
+ mapped_weights['input_bn.running_var'] = torch.ones(64)
102
+
103
+ # CAM parameters
104
+ mask_channels = 256 # From CAMPPModel default
105
+ in_channels = model_config.get('channels', 512) # Approximate
106
+
107
+ # cam.bn.running_mean, cam.bn.running_var
108
+ if 'cam.bn.running_mean' not in mapped_weights:
109
+ mapped_weights['cam.bn.running_mean'] = torch.zeros(mask_channels)
110
+ if 'cam.bn.running_var' not in mapped_weights:
111
+ mapped_weights['cam.bn.running_var'] = torch.ones(mask_channels)
112
+
113
+ # cam.context_conv5.weight (Conv1d: in_channels -> mask_channels, kernel_size=5)
114
+ if 'cam.context_conv5.weight' not in mapped_weights:
115
+ weight = torch.empty(mask_channels, in_channels, 5)
116
+ init.xavier_uniform_(weight)
117
+ mapped_weights['cam.context_conv5.weight'] = weight
118
+
119
+ # cam.mask_conv.bias, cam.mask_conv.weight (Conv1d: mask_channels -> in_channels, kernel_size=1, bias=True)
120
+ if 'cam.mask_conv.bias' not in mapped_weights:
121
+ mapped_weights['cam.mask_conv.bias'] = torch.zeros(in_channels)
122
+ if 'cam.mask_conv.weight' not in mapped_weights:
123
+ weight = torch.empty(in_channels, mask_channels, 1)
124
+ init.xavier_uniform_(weight)
125
+ mapped_weights['cam.mask_conv.weight'] = weight
126
+
127
+ # Channel gating parameters
128
+ if 'channel_gating.fc.layers.2.weight' not in mapped_weights:
129
+ # FC layer: channels -> channels, bias=False
130
+ weight = torch.empty(in_channels, in_channels)
131
+ init.xavier_uniform_(weight)
132
+ mapped_weights['channel_gating.fc.layers.2.weight'] = weight
133
+
134
+ # Pooling parameters
135
+ embedding_dim = model_config.get('embedding_dim', 512)
136
+ if 'pooling.attention_weights.bias' not in mapped_weights:
137
+ mapped_weights['pooling.attention_weights.bias'] = torch.zeros(3) # 3 granularities
138
+ if 'pooling.attention_weights.weight' not in mapped_weights:
139
+ weight = torch.empty(3, in_channels) # 3 granularities x channels
140
+ init.xavier_uniform_(weight)
141
+ mapped_weights['pooling.attention_weights.weight'] = weight
142
+
143
+ if 'pooling.projection.bias' not in mapped_weights:
144
+ mapped_weights['pooling.projection.bias'] = torch.zeros(embedding_dim)
145
+ if 'pooling.projection.weight' not in mapped_weights:
146
+ weight = torch.empty(embedding_dim, in_channels * 2 * 3) # embedding_dim x (channels * 2 * 3 granularities)
147
+ init.xavier_uniform_(weight)
148
+ mapped_weights['pooling.projection.weight'] = weight
149
+
150
+ # Transitions.1 parameters
151
+ transition_channels = in_channels // 2 # From CAMPPModel logic
152
+ if 'transitions.1.layers.0.bias' not in mapped_weights:
153
+ mapped_weights['transitions.1.layers.0.bias'] = torch.zeros(in_channels)
154
+ if 'transitions.1.layers.0.weight' not in mapped_weights:
155
+ weight = torch.empty(in_channels, in_channels)
156
+ init.xavier_uniform_(weight)
157
+ mapped_weights['transitions.1.layers.0.weight'] = weight
158
+ if 'transitions.1.layers.0.running_mean' not in mapped_weights:
159
+ mapped_weights['transitions.1.layers.0.running_mean'] = torch.zeros(in_channels)
160
+ if 'transitions.1.layers.0.running_var' not in mapped_weights:
161
+ mapped_weights['transitions.1.layers.0.running_var'] = torch.ones(in_channels)
162
+ if 'transitions.1.layers.2.weight' not in mapped_weights:
163
+ weight = torch.empty(transition_channels, in_channels, 1)
164
+ init.xavier_uniform_(weight)
165
+ mapped_weights['transitions.1.layers.2.weight'] = weight
166
+
167
+ return mapped_weights
168
+
169
+ def _xvector_to_mlx_name(self, xvector_name: str) -> str:
170
+ """
171
+ Convert xvector parameter name to MLX parameter name
172
+
173
+ Args:
174
+ xvector_name: Original xvector parameter name
175
+
176
+ Returns:
177
+ MLX-compatible parameter name
178
+ """
179
+ # Input layer mapping - remove input_conv and input_bn mapping since PyTorch TDNN has different architecture
180
+ # if xvector_name == 'xvector.tdnn.linear.weight':
181
+ # return 'input_conv.weight'
182
+ # if xvector_name == 'xvector.tdnn.nonlinear.batchnorm.bias':
183
+ # return 'input_bn.bias'
184
+ # elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.weight':
185
+ # return 'input_bn.weight'
186
+ # elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_mean':
187
+ # return 'input_bn.running_mean'
188
+ # elif xvector_name == 'xvector.tdnn.nonlinear.batchnorm.running_var':
189
+ # return 'input_bn.running_var'
190
+
191
+ # Dense blocks mapping (simplified - map first TDNN block to first dense block)
192
+ if xvector_name.startswith('xvector.block1.tdnnd1.linear1.weight'):
193
+ return 'dense_blocks.0.layers.0.conv.weight'
194
+ elif xvector_name.startswith('xvector.block1.tdnnd1.nonlinear1.batchnorm.bias'):
195
+ return 'dense_blocks.0.layers.0.bn.bias'
196
+ elif xvector_name.startswith('xvector.block1.tdnnd1.nonlinear1.batchnorm.weight'):
197
+ return 'dense_blocks.0.layers.0.bn.weight'
198
+ elif xvector_name.startswith('xvector.block1.tdnnd1.nonlinear1.batchnorm.running_mean'):
199
+ return 'dense_blocks.0.layers.0.bn.running_mean'
200
+ elif xvector_name.startswith('xvector.block1.tdnnd1.nonlinear1.batchnorm.running_var'):
201
+ return 'dense_blocks.0.layers.0.bn.running_var'
202
+
203
+ # CAM layer mapping - use more flexible matching
204
+ elif 'cam_layer' in xvector_name and 'linear1.weight' in xvector_name:
205
+ return 'cam.context_conv1.weight'
206
+ elif 'cam_layer' in xvector_name and 'linear1.bias' in xvector_name:
207
+ return 'cam.bn.bias' # Use bias for BatchNorm
208
+ elif 'cam_layer' in xvector_name and 'linear2.weight' in xvector_name:
209
+ return 'cam.context_conv3.weight'
210
+ elif 'cam_layer' in xvector_name and 'linear2.bias' in xvector_name:
211
+ return 'cam.bn.weight' # Use bias for BatchNorm weight
212
+ elif 'cam_layer' in xvector_name and 'linear_local.weight' in xvector_name:
213
+ return 'cam.fusion.weight'
214
+ elif 'cam_layer' in xvector_name and 'running_mean' in xvector_name:
215
+ return 'cam.bn.running_mean'
216
+ elif 'cam_layer' in xvector_name and 'running_var' in xvector_name:
217
+ return 'cam.bn.running_var'
218
+ # Additional CAM mappings for missing parameters
219
+ elif xvector_name == 'xvector.cam_layer.linear1.bias':
220
+ return 'cam.mask_conv.weight'
221
+ elif xvector_name == 'xvector.cam_layer.linear2.bias':
222
+ return 'cam.context_conv5.weight'
223
+
224
+ # Channel gating mapping (use some available linear layers)
225
+ elif xvector_name == 'xvector.dense.linear.weight':
226
+ return 'channel_gating.fc.layers.0.weight'
227
+ elif xvector_name == 'xvector.dense.linear.bias':
228
+ return 'channel_gating.fc.layers.2.weight'
229
+
230
+ # Pooling attention weights mapping
231
+ elif xvector_name == 'xvector.output.linear.weight':
232
+ return 'pooling.attention_weights.weight'
233
+ elif xvector_name == 'xvector.output.linear.bias':
234
+ return 'pooling.attention_weights.bias'
235
+
236
+ # Dense blocks mapping - only map the layers that exist in MLX model
237
+ # MLX has: block 0 (4 layers), block 1 (6 layers), block 2 (8 layers)
238
+
239
+ # Block 0 (first 4 layers of PyTorch block1)
240
+ for i in range(1, 5): # tdnnd1 to tdnnd4
241
+ if f'xvector.block1.tdnnd{i}.linear1.weight' in xvector_name:
242
+ layer_idx = i - 1
243
+ return f'dense_blocks.0.layers.{layer_idx}.conv.weight'
244
+ elif f'xvector.block1.tdnnd{i}.nonlinear1.batchnorm.bias' in xvector_name:
245
+ layer_idx = i - 1
246
+ return f'dense_blocks.0.layers.{layer_idx}.bn.bias'
247
+ elif f'xvector.block1.tdnnd{i}.nonlinear1.batchnorm.weight' in xvector_name:
248
+ layer_idx = i - 1
249
+ return f'dense_blocks.0.layers.{layer_idx}.bn.weight'
250
+ elif f'xvector.block1.tdnnd{i}.nonlinear1.batchnorm.running_mean' in xvector_name:
251
+ layer_idx = i - 1
252
+ return f'dense_blocks.0.layers.{layer_idx}.bn.running_mean'
253
+ elif f'xvector.block1.tdnnd{i}.nonlinear1.batchnorm.running_var' in xvector_name:
254
+ layer_idx = i - 1
255
+ return f'dense_blocks.0.layers.{layer_idx}.bn.running_var'
256
+
257
+ # Block 1 (first 6 layers of PyTorch block2)
258
+ for i in range(1, 7): # tdnnd1 to tdnnd6
259
+ if f'xvector.block2.tdnnd{i}.linear1.weight' in xvector_name:
260
+ layer_idx = i - 1
261
+ return f'dense_blocks.1.layers.{layer_idx}.conv.weight'
262
+ elif f'xvector.block2.tdnnd{i}.nonlinear1.batchnorm.bias' in xvector_name:
263
+ layer_idx = i - 1
264
+ return f'dense_blocks.1.layers.{layer_idx}.bn.bias'
265
+ elif f'xvector.block2.tdnnd{i}.nonlinear1.batchnorm.weight' in xvector_name:
266
+ layer_idx = i - 1
267
+ return f'dense_blocks.1.layers.{layer_idx}.bn.weight'
268
+ elif f'xvector.block2.tdnnd{i}.nonlinear1.batchnorm.running_mean' in xvector_name:
269
+ layer_idx = i - 1
270
+ return f'dense_blocks.1.layers.{layer_idx}.bn.running_mean'
271
+ elif f'xvector.block2.tdnnd{i}.nonlinear1.batchnorm.running_var' in xvector_name:
272
+ layer_idx = i - 1
273
+ return f'dense_blocks.1.layers.{layer_idx}.bn.running_var'
274
+
275
+ # Block 2 (first 8 layers of PyTorch block3)
276
+ for i in range(1, 9): # tdnnd1 to tdnnd8
277
+ if f'xvector.block3.tdnnd{i}.linear1.weight' in xvector_name:
278
+ layer_idx = i - 1
279
+ return f'dense_blocks.2.layers.{layer_idx}.conv.weight'
280
+ elif f'xvector.block3.tdnnd{i}.nonlinear1.batchnorm.bias' in xvector_name:
281
+ layer_idx = i - 1
282
+ return f'dense_blocks.2.layers.{layer_idx}.bn.bias'
283
+ elif f'xvector.block3.tdnnd{i}.nonlinear1.batchnorm.weight' in xvector_name:
284
+ layer_idx = i - 1
285
+ return f'dense_blocks.2.layers.{layer_idx}.bn.weight'
286
+ elif f'xvector.block3.tdnnd{i}.nonlinear1.batchnorm.running_mean' in xvector_name:
287
+ layer_idx = i - 1
288
+ return f'dense_blocks.2.layers.{layer_idx}.bn.running_mean'
289
+ elif f'xvector.block3.tdnnd{i}.nonlinear1.batchnorm.running_var' in xvector_name:
290
+ layer_idx = i - 1
291
+ return f'dense_blocks.2.layers.{layer_idx}.bn.running_var'
292
+
293
+ # Transitions mapping
294
+ if xvector_name == 'xvector.transit1.linear.weight':
295
+ return 'transitions.0.layers.2.weight'
296
+ elif xvector_name == 'xvector.transit1.nonlinear.batchnorm.bias':
297
+ return 'transitions.0.layers.0.bias'
298
+ elif xvector_name == 'xvector.transit1.nonlinear.batchnorm.weight':
299
+ return 'transitions.0.layers.0.weight'
300
+ elif xvector_name == 'xvector.transit1.nonlinear.batchnorm.running_mean':
301
+ return 'transitions.0.layers.0.running_mean'
302
+ elif xvector_name == 'xvector.transit1.nonlinear.batchnorm.running_var':
303
+ return 'transitions.0.layers.0.running_var'
304
+
305
+ # Second transition layer mapping (use some available parameters)
306
+ elif xvector_name == 'xvector.block2.tdnnd1.linear1.bias':
307
+ return 'transitions.1.layers.0.bias'
308
+ elif xvector_name == 'xvector.block2.tdnnd1.nonlinear1.batchnorm.weight':
309
+ return 'transitions.1.layers.0.weight'
310
+ elif xvector_name == 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_mean':
311
+ return 'transitions.1.layers.0.running_mean'
312
+ elif xvector_name == 'xvector.block2.tdnnd1.nonlinear1.batchnorm.running_var':
313
+ return 'transitions.1.layers.0.running_var'
314
+ elif xvector_name == 'xvector.block2.tdnnd2.linear1.weight':
315
+ return 'transitions.1.layers.2.weight'
316
+
317
+ # Pooling mapping
318
+ # Note: pooling.projection is not in the missing parameters list, so we skip it
319
+
320
+ # Final layer mapping
321
+ elif xvector_name == 'xvector.out_nonlinear.batchnorm.bias':
322
+ return 'final_bn.bias'
323
+ elif xvector_name == 'xvector.out_nonlinear.batchnorm.weight':
324
+ return 'final_bn.weight'
325
+ elif xvector_name == 'xvector.out_nonlinear.batchnorm.running_mean':
326
+ return 'final_bn.running_mean'
327
+ elif xvector_name == 'xvector.out_nonlinear.batchnorm.running_var':
328
+ return 'final_bn.running_var'
329
+
330
+ # Filter out all other parameters that don't have MLX equivalents
331
+ return None
332
+
333
+ def _filter_weights(self, pytorch_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
334
+ """
335
+ Filter out unnecessary parameters that shouldn't be converted to MLX
336
+
337
+ Args:
338
+ pytorch_weights: Original PyTorch weights dict
339
+
340
+ Returns:
341
+ Filtered weights dict
342
+ """
343
+ filtered_weights = {}
344
+ skipped_params = []
345
+
346
+ for name, tensor in pytorch_weights.items():
347
+ # Skip classification head parameters (not needed for inference)
348
+ if name.startswith('head.'):
349
+ skipped_params.append(name)
350
+ continue
351
+
352
+ # Keep all other parameters including BatchNorm running statistics
353
+ # The mapping function will filter out parameters that don't have MLX equivalents
354
+ filtered_weights[name] = tensor
355
+
356
+ if skipped_params:
357
+ print(f"Filtered out {len(skipped_params)} unnecessary parameters: {skipped_params[:5]}{'...' if len(skipped_params) > 5 else ''}")
358
+
359
+ return filtered_weights
360
+
361
  def _convert_tensor(self, name: str, tensor: torch.Tensor) -> mx.array:
362
  """Convert individual tensor based on layer type"""
363
 
requirements.txt CHANGED
@@ -3,4 +3,5 @@ torch>=2.0.0
3
  mlx>=0.0.1
4
  huggingface_hub>=0.20.0
5
  numpy>=1.24.0
6
- safetensors>=0.4.0
 
 
3
  mlx>=0.0.1
4
  huggingface_hub>=0.20.0
5
  numpy>=1.24.0
6
+ safetensors>=0.4.0
7
+ modelscope
test_mapping.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import sys
4
+ import os
5
+ sys.path.append(os.path.dirname(__file__))
6
+
7
+ from conversion_utils import ConversionUtils
8
+ import torch
9
+
10
def test_parameter_mapping():
    """Smoke-test the CAM++ parameter filtering and name-mapping pipeline.

    Builds a mock PyTorch state dict with representative CAM++ parameter
    names, runs it through ``ConversionUtils`` filtering and mapping, and
    prints a summary for manual inspection.

    NOTE(review): the "should be filtered out" entries below are expected to
    be dropped by the *mapping* stage (no MLX equivalent), not by
    `_filter_weights`, which only removes ``head.*`` parameters — so the
    final "Filtered out parameters" listing stays empty for this mock.
    TODO confirm against `_map_parameter_names`.
    """
    # Representative CAM++ parameter names with plausible shapes.
    mock_pytorch_weights = {
        # Dense blocks - block 0 (first 4 layers)
        'xvector.block1.tdnnd1.linear1.weight': torch.randn(512, 256),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.weight': torch.randn(512),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.bias': torch.randn(512),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.running_mean': torch.randn(512),
        'xvector.block1.tdnnd1.nonlinear1.batchnorm.running_var': torch.randn(512),

        'xvector.block1.tdnnd2.linear1.weight': torch.randn(512, 512),
        'xvector.block1.tdnnd2.nonlinear1.batchnorm.weight': torch.randn(512),
        'xvector.block1.tdnnd2.nonlinear1.batchnorm.bias': torch.randn(512),

        # CAM layer
        'xvector.cam_layer.linear1.weight': torch.randn(512, 512),
        'xvector.cam_layer.linear1.bias': torch.randn(512),

        # Transitions
        'xvector.transit1.linear.weight': torch.randn(512, 512),
        'xvector.transit1.nonlinear.batchnorm.weight': torch.randn(512),
        'xvector.transit1.nonlinear.batchnorm.bias': torch.randn(512),

        # Output
        'xvector.output.linear.weight': torch.randn(192, 512),
        'xvector.output.linear.bias': torch.randn(192),
        'xvector.output.batchnorm.weight': torch.randn(192),
        'xvector.output.batchnorm.bias': torch.randn(192),

        # Some parameters that should be filtered out
        'xvector.block1.tdnnd5.linear1.weight': torch.randn(512, 512),  # Layer 5 doesn't exist in MLX block 0
        'xvector.some_unknown_param': torch.randn(10),
    }

    print(f"Original PyTorch weights: {len(mock_pytorch_weights)} parameters")

    # Run the two conversion stages.
    utils = ConversionUtils()
    kept = utils._filter_weights(mock_pytorch_weights)
    mlx_params = utils._map_parameter_names(kept)

    print(f"After filtering: {len(kept)} parameters")
    print(f"After mapping: {len(mlx_params)} parameters")

    print("\nMapped parameter names:")
    for mapped_name in sorted(mlx_params):
        print(f"  {mapped_name}")

    print("\nFiltered out parameters:")
    dropped = set(mock_pytorch_weights) - set(kept)
    for dropped_name in sorted(dropped):
        print(f"  {dropped_name}")


if __name__ == "__main__":
    test_parameter_mapping()