dots.tts-soar-mlx / scripts /convert_speaker.py
smcleod's picture
Upload folder using huggingface_hub
39057fb verified
#!/usr/bin/env python3
"""Convert dots.tts-soar CAM++ speaker encoder weights to an MLX-friendly safetensors.
Source : speaker_encoder.safetensors (PyTorch state dict, key prefix `model.`)
Output : speaker_encoder_mlx.safetensors
Transforms applied:
- Strip the leading `model.` prefix from every CAM++ key (the torch wrapper
nests the CAMPPlus under `self.model`). The torchaudio resample buffer
`resample.kernel` is dropped: it is NOT used at inference for the soar model
because the encoder is built with sample_rate == vocoder.sample_rate and the
real path resamples 48k -> 16k via this buffer. We document the resample in
the spec and the Swift port performs resampling itself (see spec); the kernel
is a torchaudio sinc design we reproduce independently, so it is not exported.
- Conv1d weight (out, in/groups, k) -> MLX Conv1d (out, k, in/groups): permute (0,2,1).
- Conv2d weight (out, in, kH, kW) -> MLX Conv2d (out, kH, kW, in): permute (0,2,3,1).
- BatchNorm/Linear/LayerNorm 1D and 2D tensors are copied without any layout change.
- Every exported tensor is cast to float32 before saving.
- `num_batches_tracked` int64 scalars are dropped (not needed for inference).
Run on CPU only. No MPS / MLX import here.
"""
from __future__ import annotations
import collections
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Snapshot directory inside the local Hugging Face hub cache that holds the
# original dots.tts-soar checkpoint files.
SNAPSHOT = Path(
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
).joinpath("snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35")

# Input: the speaker-encoder state dict read by main().
SRC = SNAPSHOT.joinpath("speaker_encoder.safetensors")

# Output: the converted safetensors file written by main().
DST = Path("/Users/samm/git/dots-mlx-spike/speaker_encoder_mlx.safetensors")
def is_conv1d_weight(key: str, tensor: torch.Tensor) -> bool:
    """Return True if *key* ends with ``.weight`` and *tensor* is 3-D.

    The test is purely structural: any 3-D ``.weight`` entry is treated as a
    Conv1d kernel by this script. The tensor is not inspected when the key
    does not end with ``.weight``.
    """
    if not key.endswith(".weight"):
        return False
    return tensor.ndim == 3
def is_conv2d_weight(key: str, tensor: torch.Tensor) -> bool:
    """Return True if *key* ends with ``.weight`` and *tensor* is 4-D.

    The test is purely structural: any 4-D ``.weight`` entry is treated as a
    Conv2d kernel by this script. The tensor is not inspected when the key
    does not end with ``.weight``.
    """
    weight_suffix = ".weight"
    conv2d_rank = 4
    return key.endswith(weight_suffix) and tensor.ndim == conv2d_rank
def main() -> None:
    """Convert the checkpoint at ``SRC`` and write the result to ``DST``.

    For every entry in the source file:

    * keys starting with ``resample.`` and keys ending with
      ``num_batches_tracked`` are skipped;
    * a leading ``model.`` is stripped from the key;
    * 4-D ``.weight`` tensors are permuted with ``(0, 2, 3, 1)`` and 3-D
      ``.weight`` tensors with ``(0, 2, 1)``; everything else keeps its layout;
    * the tensor is made contiguous and cast to float32.

    The parent directory of ``DST`` is created if missing. A short summary
    (tensor count, parameter count, per-category stats and a few spot-check
    shapes) is printed to stdout.

    Raises:
        ValueError: if two source keys map to the same output key once the
            ``model.`` prefix has been stripped.
    """
    out: dict[str, torch.Tensor] = {}
    stats = collections.Counter()

    # Use the context manager so the checkpoint handle is released even if
    # the conversion fails part-way through.
    with safe_open(str(SRC), framework="pt") as f:
        for key in f.keys():
            # Decide on drops from the key alone, before reading any tensor
            # data, so discarded entries are never loaded.
            # Drop torchaudio resample kernel (reproduced independently in Swift).
            if key.startswith("resample."):
                stats["dropped_resample"] += 1
                continue
            # Drop BN step counters (inference uses running stats only).
            if key.endswith("num_batches_tracked"):
                stats["dropped_num_batches_tracked"] += 1
                continue

            # Strip the `model.` wrapper prefix.
            clean = key[len("model."):] if key.startswith("model.") else key
            if clean in out:
                # Without this check the later tensor would silently replace
                # the earlier one and the export would be incomplete.
                raise ValueError(
                    f"duplicate output key {clean!r} after stripping the "
                    f"'model.' prefix (source key {key!r})"
                )

            t = f.get_tensor(key)
            if is_conv2d_weight(clean, t):
                # (out, in, kH, kW) -> (out, kH, kW, in)
                t = t.permute(0, 2, 3, 1)
                stats["conv2d_transposed"] += 1
            elif is_conv1d_weight(clean, t):
                # (out, in/groups, k) -> (out, k, in/groups)
                t = t.permute(0, 2, 1)
                stats["conv1d_transposed"] += 1
            else:
                stats["copied"] += 1

            # Materialise the (possibly permuted) layout before the dtype
            # cast so the stored tensor is contiguous in its new axis order.
            out[clean] = t.contiguous().to(torch.float32)

    DST.parent.mkdir(parents=True, exist_ok=True)
    save_file(out, str(DST))

    total_params = sum(v.numel() for v in out.values())
    print(f"wrote {DST}")
    print(f"output tensors: {len(out)}")
    print(f"output params : {total_params}")
    print("stats:", dict(stats))
    # spot-check a couple of transposed shapes
    for k in ("head.conv1.weight", "xvector.tdnn.linear.weight",
              "xvector.dense.linear.weight"):
        if k in out:
            print(f"  {k}: {tuple(out[k].shape)}")
# Run the conversion only when this file is executed as a script, so that
# importing the module does not trigger any file I/O.
if __name__ == "__main__":
    main()