Text-to-Speech
MLX
Safetensors
English
dots_tts
tts
quantized
4-bit precision
8-bit precision
apple-silicon
dots.tts
Instructions to use smcleod/dots.tts-soar-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use smcleod/dots.tts-soar-mlx with MLX:
# Download the model from the Hub
# (quote the extra so zsh, the macOS default shell, does not glob the brackets)
pip install "huggingface_hub[hf_xet]"
huggingface-cli download --local-dir dots.tts-soar-mlx smcleod/dots.tts-soar-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| #!/usr/bin/env python3 | |
| """Convert the dots.tts-soar AudioVAE decoder weights to an mlx-swift layout. | |
| Reads vocoder.safetensors (full AudioVAE), selects only the decoder-side | |
| tensors (post_proj, dec_mi_layer, decoder.*), applies the PyTorch -> MLX | |
| conv weight permutations, and writes vocoder_decoder_mlx.safetensors. | |
| Run on CPU only. torch + numpy + safetensors, no MLX, no MPS. | |
| Decoder-side weights in the checkpoint are ALREADY weight-norm folded | |
| (plain `weight` tensors, no weight_g / weight_v on the decoder side), so no | |
| fusion is needed here. The 80 weight_v / weight_g pairs in the file all live | |
| on the encoder side (audio_encoder.*) which we drop. | |
| Layout conversion (the only transform applied): | |
| PyTorch Conv1d weight (out, in/groups, k) -> MLX (out, k, in/groups) permute(0,2,1) | |
| PyTorch ConvTranspose1d wt (in, out/groups, k) -> MLX (out, k, in/groups) permute(1,2,0) | |
| Linear and LSTM weights, biases, snake alpha/beta, and anti-alias filter | |
| buffers are copied unchanged (the Swift side handles LSTM gate ordering and | |
| filter layout directly). | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| import numpy as np | |
| import torch | |
| from safetensors.numpy import save_file | |
| from safetensors.torch import load_file | |
# Default input: the full AudioVAE checkpoint inside the local Hugging Face
# hub cache. NOTE(review): absolute, machine-specific path — adjust (or pass
# a path on the command line, if supported) before running elsewhere.
SRC = (
    "/Users/samm/.cache/huggingface/hub/"
    "models--rednote-hilab--dots.tts-soar/snapshots/"
    "1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35/vocoder.safetensors"
)

# Default output: decoder-only weights in MLX layout (also machine-specific).
DST = (
    "/Users/samm/git/dots-mlx-spike/"
    "vocoder_decoder_mlx.safetensors"
)

# Key prefixes that together form the decode path (latent -> waveform).
# Any tensor whose key starts with none of these (e.g. audio_encoder.*) is
# dropped by the converter.
DECODER_PREFIXES = (
    "post_proj.",
    "dec_mi_layer.",
    "decoder.",
)
| def is_conv1d(key: str) -> bool: | |
| """Conv1d weight: post_proj, decoder.conv_pre, decoder.conv_post, and every | |
| resblock conv (convs1/convs2). These are nn.Conv1d -> (out, in, k).""" | |
| if not key.endswith(".weight"): | |
| return False | |
| if key == "post_proj.weight": | |
| return True | |
| if key in ("decoder.conv_pre.weight", "decoder.conv_post.weight"): | |
| return True | |
| if key.startswith("decoder.resblocks.") and (".convs1." in key or ".convs2." in key): | |
| return True | |
| return False | |
| def is_convtranspose1d(key: str) -> bool: | |
| """ConvTranspose1d weight: the upsamplers decoder.ups.<i>.0.weight. | |
| nn.ConvTranspose1d -> (in, out, k).""" | |
| return ( | |
| key.startswith("decoder.ups.") | |
| and key.endswith(".0.weight") | |
| and key.count(".") == 4 | |
| ) | |
def main(argv: list[str] | None = None) -> int:
    """Extract the AudioVAE decoder weights and write them in MLX layout.

    Loads the full vocoder checkpoint and keeps only tensors whose key starts
    with one of ``DECODER_PREFIXES``. Every kept tensor is cast to float32.
    Conv1d weights are permuted (0, 2, 1) and ConvTranspose1d weights are
    permuted (1, 2, 0); all other tensors are copied unchanged. The result is
    written with ``safetensors.numpy.save_file``, and a summary plus a few
    spot-checked shapes are printed to stdout.

    Args:
        argv: Optional command-line arguments ``[src [dst]]``. When None,
            ``sys.argv[1:]`` is parsed. A missing position falls back to the
            module-level ``SRC`` / ``DST`` defaults, so running the script
            with no arguments behaves exactly as before.

    Returns:
        Process exit status: 0 on success, 1 if no decoder tensors were found.

    Raises:
        ValueError: if a key classified as a (transposed) conv weight is not
            a 3-D tensor.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert dots.tts-soar AudioVAE decoder weights to an MLX layout."
    )
    parser.add_argument(
        "src",
        nargs="?",
        default=SRC,
        help="input vocoder.safetensors (full AudioVAE checkpoint)",
    )
    parser.add_argument(
        "dst",
        nargs="?",
        default=DST,
        help="output path for the decoder-only MLX safetensors file",
    )
    args = parser.parse_args(argv)
    src: str = args.src
    dst: str = args.dst

    state = load_file(src)  # dict[str, torch.Tensor] on CPU
    out: dict[str, np.ndarray] = {}
    n_conv = n_convt = n_copy = 0
    dropped = 0
    # 3-D ".weight" tensors that neither predicate recognised. They are copied
    # unpermuted, which is presumably wrong if they are convs, so warn later.
    unpermuted_3d: list[str] = []

    for key, tensor in state.items():
        if not key.startswith(DECODER_PREFIXES):
            dropped += 1
            continue
        t = tensor.detach().to(torch.float32).cpu()
        if is_conv1d(key):
            # (out, in/groups, k) -> (out, k, in/groups)
            # Explicit check rather than `assert`, which `python -O` strips.
            if t.ndim != 3:
                raise ValueError(f"{key} expected 3D conv, got {t.shape}")
            t = t.permute(0, 2, 1).contiguous()
            n_conv += 1
        elif is_convtranspose1d(key):
            # (in, out/groups, k) -> (out/groups, k, in); this equals the
            # (out, k, in) layout when groups == 1.
            # NOTE(review): confirm the upsamplers are ungrouped.
            if t.ndim != 3:
                raise ValueError(f"{key} expected 3D convT, got {t.shape}")
            t = t.permute(1, 2, 0).contiguous()
            n_convt += 1
        else:
            if key.endswith(".weight") and t.ndim == 3:
                unpermuted_3d.append(key)
            n_copy += 1
        out[key] = t.numpy()

    if not out:
        # Nothing matched: likely the wrong input file. Don't write an empty one.
        print(
            f"error: no tensors with prefixes {DECODER_PREFIXES} found in {src}",
            file=sys.stderr,
        )
        return 1

    # NOTE(review): metadata says "pt" although the tensors are in MLX layout;
    # left unchanged in case the Swift-side loader checks it — confirm.
    save_file(out, dst, metadata={"format": "pt", "source": "dots.tts-soar vocoder decoder"})

    total = sum(int(np.prod(v.shape)) for v in out.values())
    print(f"wrote {dst}")
    print(f" tensors kept : {len(out)}")
    print(f" conv1d permuted : {n_conv}")
    print(f" convtranspose perm : {n_convt}")
    print(f" copied unchanged : {n_copy}")
    print(f" encoder-side dropped: {dropped}")
    print(f" total decoder params: {total} ({total/1e6:.2f}M)")

    for key in unpermuted_3d:
        print(f"warning: 3-D weight copied without permutation: {key}", file=sys.stderr)

    # Spot-check shapes. Use .get() so a checkpoint lacking one of these keys
    # reports it instead of raising KeyError after the file is already written.
    for k in ("post_proj.weight", "decoder.conv_pre.weight",
              "decoder.ups.0.0.weight", "decoder.conv_post.weight"):
        arr = out.get(k)
        shape = tuple(arr.shape) if arr is not None else "MISSING"
        print(f" {k}: {shape}")
    return 0
if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    raise SystemExit(main())