#!/usr/bin/env python3
"""Convert dots.tts-soar CAM++ speaker encoder weights to an MLX-friendly safetensors.

Source : speaker_encoder.safetensors (PyTorch state dict, key prefix `model.`)
Output : speaker_encoder_mlx.safetensors

Transforms applied:

- Strip the leading `model.` prefix from every CAM++ key (the torch wrapper
  nests the CAMPPlus under `self.model`).
- Drop every `resample.*` tensor (the torchaudio resample buffer
  `resample.kernel`). In the torch reference the encoder is built with
  sample_rate == vocoder.sample_rate and audio is resampled 48k -> 16k through
  this buffer; the Swift port performs that resampling itself (see spec),
  reproducing the torchaudio sinc design independently, so the kernel is not
  exported.
- Conv1d weight (out, in/groups, k) -> MLX Conv1d (out, k, in/groups):
  permute (0, 2, 1).
- Conv2d weight (out, in, kH, kW) -> MLX Conv2d (out, kH, kW, in):
  permute (0, 2, 3, 1).
- Every other tensor (BatchNorm / Linear / LayerNorm 1D and 2D tensors) keeps
  its shape and layout.
- `num_batches_tracked` int64 scalars are dropped (not needed for inference).
- Every exported tensor is cast to float32.

Source and destination default to the paths below; override them with
``--src`` / ``--dst``.

Run on CPU only. No MPS / MLX import here.
"""
from __future__ import annotations

import argparse
import collections
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path

import torch

SNAPSHOT = Path(
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
)
SRC = SNAPSHOT / "speaker_encoder.safetensors"
DST = Path("/Users/samm/git/dots-mlx-spike/speaker_encoder_mlx.safetensors")

# Prefix added by the torch wrapper that nests CAMPPlus under `self.model`.
MODEL_PREFIX = "model."
# Source keys under this prefix (the torchaudio resample buffer) are not exported.
RESAMPLE_PREFIX = "resample."
# Exported keys whose shapes are printed after conversion as a spot check.
SPOT_CHECK_KEYS = (
    "head.conv1.weight",
    "xvector.tdnn.linear.weight",
    "xvector.dense.linear.weight",
)


def is_conv1d_weight(key: str, tensor: torch.Tensor) -> bool:
    """Return True for a `.weight` tensor of rank 3, treated as a Conv1d kernel."""
    return key.endswith(".weight") and tensor.ndim == 3


def is_conv2d_weight(key: str, tensor: torch.Tensor) -> bool:
    """Return True for a `.weight` tensor of rank 4, treated as a Conv2d kernel."""
    return key.endswith(".weight") and tensor.ndim == 4


def _drop_reason(key: str) -> str | None:
    """Return the stats counter name for a source key that is not exported, else None."""
    if key.startswith(RESAMPLE_PREFIX):
        # Torchaudio resample kernel: reproduced independently in Swift.
        return "dropped_resample"
    if key.endswith("num_batches_tracked"):
        # BN step counters: inference uses running stats only.
        return "dropped_num_batches_tracked"
    return None


def _to_mlx_layout(key: str, tensor: torch.Tensor) -> tuple[torch.Tensor, str]:
    """Return *tensor* in MLX layout as contiguous float32, plus its stats counter name."""
    if is_conv2d_weight(key, tensor):
        # (out, in, kH, kW) -> (out, kH, kW, in)
        tensor, kind = tensor.permute(0, 2, 3, 1), "conv2d_transposed"
    elif is_conv1d_weight(key, tensor):
        # (out, in/groups, k) -> (out, k, in/groups)
        tensor, kind = tensor.permute(0, 2, 1), "conv1d_transposed"
    else:
        kind = "copied"
    # save_file needs contiguous tensors; permute() only returns a strided view.
    return tensor.contiguous().to(torch.float32), kind


def convert_state_dict(
    keys: Iterable[str],
    load: Callable[[str], torch.Tensor],
) -> tuple[dict[str, torch.Tensor], collections.Counter[str]]:
    """Convert a CAM++ torch state dict into the MLX export.

    Args:
        keys: Source state-dict keys, processed in the given order.
        load: Returns the tensor for a source key. It is only called for keys
            that are exported, so dropped buffers are never loaded.

    Returns:
        ``(out, stats)``: the exported float32 tensors keyed by their
        prefix-stripped names, and a count of how each source key was handled.

    Raises:
        ValueError: if two source keys map to the same exported name once the
            `model.` prefix is stripped (e.g. `model.x.weight` and `x.weight`).
    """
    out: dict[str, torch.Tensor] = {}
    stats: collections.Counter[str] = collections.Counter()
    # Source key that produced each exported name, for the collision error.
    origin: dict[str, str] = {}
    for key in keys:
        reason = _drop_reason(key)
        if reason is not None:
            stats[reason] += 1
            continue

        # Strip the `model.` wrapper prefix.
        clean = key[len(MODEL_PREFIX):] if key.startswith(MODEL_PREFIX) else key
        if clean in origin:
            # Refuse to silently overwrite one tensor with another.
            raise ValueError(
                f"source keys {origin[clean]!r} and {key!r} both map to {clean!r}"
            )
        origin[clean] = key

        tensor, kind = _to_mlx_layout(clean, load(key))
        out[clean] = tensor
        stats[kind] += 1
    return out, stats


def main(argv: Sequence[str] | None = None) -> None:
    """Convert the speaker encoder checkpoint and print a short summary.

    Args:
        argv: Command-line arguments without the program name; ``None`` reads
            ``sys.argv``. With no arguments the module-level SRC / DST are used.
    """
    # Imported here so the pure conversion helpers above stay importable (and
    # testable) in environments without safetensors.
    from safetensors import safe_open
    from safetensors.torch import save_file

    parser = argparse.ArgumentParser(
        description="Convert CAM++ speaker encoder weights to MLX layout."
    )
    parser.add_argument("--src", type=Path, default=SRC, help="source .safetensors")
    parser.add_argument("--dst", type=Path, default=DST, help="output .safetensors")
    args = parser.parse_args(argv)

    # The context manager releases the file handle; device="cpu" keeps this CPU-only.
    with safe_open(str(args.src), framework="pt", device="cpu") as f:
        out, stats = convert_state_dict(f.keys(), f.get_tensor)

    args.dst.parent.mkdir(parents=True, exist_ok=True)
    save_file(out, str(args.dst))

    total_params = sum(v.numel() for v in out.values())
    print(f"wrote {args.dst}")
    print(f"output tensors: {len(out)}")
    print(f"output params : {total_params}")
    print("stats:", dict(stats))
    # spot-check a couple of transposed shapes
    for k in SPOT_CHECK_KEYS:
        if k in out:
            print(f" {k}: {tuple(out[k].shape)}")


if __name__ == "__main__":
    main()