Text-to-Speech
MLX
Safetensors
English
dots_tts
tts
quantized
4-bit precision
8-bit precision
apple-silicon
dots.tts
Instructions to use smcleod/dots.tts-soar-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use smcleod/dots.tts-soar-mlx with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir dots.tts-soar-mlx smcleod/dots.tts-soar-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
File size: 4,053 Bytes
39057fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | #!/usr/bin/env python3
"""Convert the dots.tts-soar AudioVAE decoder weights to an mlx-swift layout.
Reads vocoder.safetensors (full AudioVAE), selects only the decoder-side
tensors (post_proj, dec_mi_layer, decoder.*), applies the PyTorch -> MLX
conv weight permutations, and writes vocoder_decoder_mlx.safetensors.
Run on CPU only. torch + numpy + safetensors, no MLX, no MPS.
Decoder-side weights in the checkpoint are ALREADY weight-norm folded
(plain `weight` tensors, no weight_g / weight_v on the decoder side), so no
fusion is needed here. The 80 weight_v / weight_g pairs in the file all live
on the encoder side (audio_encoder.*) which we drop.
Layout conversion (the only transform applied):
PyTorch Conv1d weight (out, in/groups, k) -> MLX (out, k, in/groups) permute(0,2,1)
PyTorch ConvTranspose1d wt (in, out/groups, k) -> MLX (out, k, in/groups) permute(1,2,0) (exact only for groups=1; otherwise this yields (out/groups, k, in))
Linear and LSTM weights, biases, snake alpha/beta, and anti-alias filter
buffers are copied unchanged (the Swift side handles LSTM gate ordering and
filter layout directly).
"""
from __future__ import annotations

import argparse
import sys

import numpy as np
import torch
from safetensors.numpy import save_file
from safetensors.torch import load_file
# Local Hugging Face hub snapshot of rednote-hilab/dots.tts-soar.
# This is a machine-specific absolute path.
_SNAPSHOT_DIR = (
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
)

# Default input: the full AudioVAE checkpoint (encoder + decoder).
SRC = f"{_SNAPSHOT_DIR}/vocoder.safetensors"

# Default output: the decoder-only weights in MLX conv layout.
DST = "/Users/samm/git/dots-mlx-spike/vocoder_decoder_mlx.safetensors"

# Key prefixes that make up the decode path (latent -> waveform); every
# other tensor in the checkpoint is dropped. Kept as a tuple so it can be
# passed directly to str.startswith.
DECODER_PREFIXES = ("post_proj.", "dec_mi_layer.", "decoder.")
def is_conv1d(key: str) -> bool:
    """Report whether ``key`` names a decoder-side ``nn.Conv1d`` weight.

    Matched keys are ``post_proj.weight``, ``decoder.conv_pre.weight``,
    ``decoder.conv_post.weight`` and every resblock conv weight under
    ``decoder.resblocks.`` whose path contains ``.convs1.`` or ``.convs2.``.
    PyTorch stores these as ``(out, in/groups, k)``.

    Args:
        key: A state-dict key from the checkpoint.

    Returns:
        True for a Conv1d weight key, False otherwise (including biases).
    """
    if not key.endswith(".weight"):
        return False

    standalone_convs = {
        "post_proj.weight",
        "decoder.conv_pre.weight",
        "decoder.conv_post.weight",
    }
    if key in standalone_convs:
        return True

    inside_resblock = key.startswith("decoder.resblocks.")
    return inside_resblock and any(tag in key for tag in (".convs1.", ".convs2."))
def is_convtranspose1d(key: str) -> bool:
    """Report whether ``key`` names an upsampler ``nn.ConvTranspose1d`` weight.

    Only keys shaped exactly like ``decoder.ups.<i>.0.weight`` match; PyTorch
    stores these as ``(in, out/groups, k)``.

    Args:
        key: A state-dict key from the checkpoint.

    Returns:
        True for an upsampler ConvTranspose1d weight key, False otherwise.
    """
    if not key.startswith("decoder.ups."):
        return False
    # Expected segments: decoder . ups . <i> . 0 . weight
    segments = key.split(".")
    return len(segments) == 5 and segments[3] == "0" and segments[4] == "weight"
# Suffixes that legacy torch.nn.utils.weight_norm leaves in a state dict when
# the weight-norm reparameterization has NOT been folded into `weight`.
_WEIGHT_NORM_SUFFIXES = (".weight_g", ".weight_v")


def _convert_tensor(key: str, tensor: torch.Tensor) -> tuple[np.ndarray, str]:
    """Convert one decoder-side tensor to a float32 numpy array in MLX layout.

    Args:
        key: State-dict key, used to pick the layout transform.
        tensor: The tensor loaded from the checkpoint.

    Returns:
        ``(array, kind)``, where ``kind`` is one of:

        - ``"conv"``: Conv1d weight, permuted (out, in/groups, k) -> (out, k, in/groups).
        - ``"convt"``: ConvTranspose1d weight, permuted with ``permute(1, 2, 0)``.
        - ``"copy"``: anything else, copied unchanged.

    Raises:
        ValueError: If a conv weight is not 3-D. Also raised if the key is an
            un-folded weight-norm parameter: this script performs no fusion,
            so such tensors would otherwise be copied through and the
            consumer would be missing the plain ``weight``.
    """
    if key.endswith(_WEIGHT_NORM_SUFFIXES) or ".parametrizations." in key:
        raise ValueError(
            f"{key}: un-folded weight-norm parameter on the decoder side; "
            "fold weight_g/weight_v into `weight` before converting"
        )
    t = tensor.detach().to(torch.float32).cpu()
    if is_conv1d(key):
        if t.ndim != 3:
            raise ValueError(f"{key} expected 3D conv, got {t.shape}")
        # (out, in/groups, k) -> (out, k, in/groups)
        return t.permute(0, 2, 1).contiguous().numpy(), "conv"
    if is_convtranspose1d(key):
        if t.ndim != 3:
            raise ValueError(f"{key} expected 3D convT, got {t.shape}")
        # (in, out/groups, k) -> (out/groups, k, in). This matches MLX's
        # (out, k, in/groups) only when groups == 1.
        return t.permute(1, 2, 0).contiguous().numpy(), "convt"
    return t.numpy(), "copy"


def main(argv: list[str] | None = None) -> int:
    """Extract the AudioVAE decoder weights, convert them to MLX layout, save.

    Loads the full checkpoint and keeps only keys starting with one of
    ``DECODER_PREFIXES``. It converts each kept tensor via
    ``_convert_tensor``, writes the result with safetensors, and prints a
    summary plus a few spot-check shapes.

    Args:
        argv: Command-line arguments without the program name. ``None`` means
            ``sys.argv[1:]``. Accepts optional positional ``src`` and ``dst``
            paths, which default to ``SRC`` and ``DST``.

    Returns:
        Process exit status (0 on success).

    Raises:
        ValueError: If no decoder-side tensors are found, or if a tensor
            fails the checks in ``_convert_tensor``.
        SystemExit: On invalid command-line arguments (raised by argparse).
    """
    parser = argparse.ArgumentParser(
        description="Convert dots.tts-soar AudioVAE decoder weights to MLX layout."
    )
    parser.add_argument(
        "src", nargs="?", default=SRC,
        help="input vocoder.safetensors (default: %(default)s)",
    )
    parser.add_argument(
        "dst", nargs="?", default=DST,
        help="output safetensors path (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    state = load_file(args.src)  # dict[str, torch.Tensor] on CPU
    out: dict[str, np.ndarray] = {}
    counts = {"conv": 0, "convt": 0, "copy": 0}
    dropped = 0
    for key, tensor in state.items():
        if not key.startswith(DECODER_PREFIXES):
            dropped += 1
            continue
        array, kind = _convert_tensor(key, tensor)
        out[key] = array
        counts[kind] += 1

    if not out:
        # Refuse to write an empty file: wrong checkpoint or changed key names.
        raise ValueError(f"no tensors with prefixes {DECODER_PREFIXES} found in {args.src}")

    # NOTE(review): "format" is "pt" even though the layout is MLX. Kept
    # unchanged; confirm whether the Swift loader inspects this field.
    save_file(out, args.dst, metadata={"format": "pt", "source": "dots.tts-soar vocoder decoder"})

    total = sum(int(v.size) for v in out.values())
    print(f"wrote {args.dst}")
    print(f" tensors kept : {len(out)}")
    print(f" conv1d permuted : {counts['conv']}")
    print(f" convtranspose perm : {counts['convt']}")
    print(f" copied unchanged : {counts['copy']}")
    print(f" encoder-side dropped: {dropped}")
    print(f" total decoder params: {total} ({total/1e6:.2f}M)")
    # Spot-check shapes; tolerate keys that a different checkpoint may lack.
    for k in ("post_proj.weight", "decoder.conv_pre.weight",
              "decoder.ups.0.0.weight", "decoder.conv_post.weight"):
        shape = tuple(out[k].shape) if k in out else "missing"
        print(f" {k}: {shape}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|