#!/usr/bin/env python3
"""Convert the dots.tts-soar AudioVAE decoder weights to an mlx-swift layout.

Reads vocoder.safetensors (full AudioVAE), selects only the decoder-side
tensors (post_proj, dec_mi_layer, decoder.*), applies the PyTorch -> MLX conv
weight permutations, and writes vocoder_decoder_mlx.safetensors.

Run on CPU only. torch + numpy + safetensors, no MLX, no MPS.

Usage:
    python3 <this script> [SRC [DST]]

SRC and DST default to the module-level SRC / DST paths below.

Decoder-side weights in the checkpoint are ALREADY weight-norm folded (plain
`weight` tensors, no weight_g / weight_v on the decoder side), so no fusion is
needed here. The 80 weight_v / weight_g pairs in the file all live on the
encoder side (audio_encoder.*), which we drop.

Layout conversion (the only transform applied):

    PyTorch Conv1d weight          (out, in/groups, k) -> MLX (out, k, in/groups)  permute(0, 2, 1)
    PyTorch ConvTranspose1d weight (in, out, k)        -> MLX (out, k, in)         permute(1, 2, 0)

The ConvTranspose1d row is only valid for groups == 1. PyTorch stores a grouped
transposed conv as (in, out/groups, k), and permute(1, 2, 0) would not map that
to MLX's (out, k, in/groups). This script assumes the upsamplers are ungrouped
(TODO confirm against the model definition).

Linear and LSTM weights, biases, snake alpha/beta, and anti-alias filter
buffers are copied unchanged (the Swift side handles LSTM gate ordering and
filter layout directly). Every kept tensor is cast to float32.
"""

from __future__ import annotations

import sys

import numpy as np
import torch

SRC = (
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35/vocoder.safetensors"
)
DST = "/Users/samm/git/dots-mlx-spike/vocoder_decoder_mlx.safetensors"

# Prefixes that make up the decode path (latent -> waveform).
DECODER_PREFIXES = ("post_proj.", "dec_mi_layer.", "decoder.")

# Keys whose converted shapes are printed after writing. Each must be present
# in the converted output, otherwise nothing is written.
SPOT_CHECK_KEYS = (
    "post_proj.weight",
    "decoder.conv_pre.weight",
    "decoder.ups.0.0.weight",
    "decoder.conv_post.weight",
)

# Conv1d weights that live outside the resblocks.
_PLAIN_CONV1D_KEYS = frozenset(
    {"post_proj.weight", "decoder.conv_pre.weight", "decoder.conv_post.weight"}
)


def is_conv1d(key: str) -> bool:
    """Return True if *key* names an nn.Conv1d weight, stored (out, in, k).

    Covers post_proj, decoder.conv_pre, decoder.conv_post, and every resblock
    conv (keys under decoder.resblocks. containing .convs1. or .convs2.).
    """
    if not key.endswith(".weight"):
        return False
    if key in _PLAIN_CONV1D_KEYS:
        return True
    return key.startswith("decoder.resblocks.") and (
        ".convs1." in key or ".convs2." in key
    )


def is_convtranspose1d(key: str) -> bool:
    """Return True if *key* names an upsampler nn.ConvTranspose1d weight.

    Matches exactly keys of the form ``decoder.ups.<i>.0.weight``. PyTorch
    stores these as (in, out, k).
    """
    return (
        key.startswith("decoder.ups.")
        and key.endswith(".0.weight")
        # Exactly four dots: rejects deeper keys that also end in ".0.weight".
        and key.count(".") == 4
    )


def _require_3d(key: str, t: torch.Tensor, kind: str) -> None:
    """Raise ValueError unless *t* is rank 3 (explicit check; -O strips asserts)."""
    if t.ndim != 3:
        raise ValueError(f"{key} expected 3D {kind}, got {t.shape}")


def convert_decoder_state(
    state: dict[str, torch.Tensor],
) -> tuple[dict[str, np.ndarray], dict[str, int]]:
    """Select the decoder-side tensors of *state* and convert them to MLX layout.

    Args:
        state: AudioVAE state dict (key -> tensor).

    Returns:
        ``(out, stats)``. ``out`` maps every key under DECODER_PREFIXES to a
        float32 numpy array, with Conv1d / ConvTranspose1d kernels permuted to
        MLX layout. ``stats`` counts ``conv`` and ``convt`` permutations,
        ``copy`` (kept without permutation) and ``dropped`` (keys outside
        DECODER_PREFIXES).

    Raises:
        ValueError: if a key classified as a Conv1d / ConvTranspose1d weight is
            not rank 3.
    """
    out: dict[str, np.ndarray] = {}
    stats = {"conv": 0, "convt": 0, "copy": 0, "dropped": 0}
    for key, tensor in state.items():
        if not key.startswith(DECODER_PREFIXES):
            stats["dropped"] += 1
            continue
        t = tensor.detach().to(torch.float32).cpu()
        if is_conv1d(key):
            # (out, in/groups, k) -> (out, k, in/groups)
            _require_3d(key, t, "conv")
            t = t.permute(0, 2, 1).contiguous()
            stats["conv"] += 1
        elif is_convtranspose1d(key):
            # (in, out, k) -> (out, k, in); correct for groups == 1 only.
            _require_3d(key, t, "convT")
            t = t.permute(1, 2, 0).contiguous()
            stats["convt"] += 1
        else:
            if t.ndim == 3 and key.endswith(".weight"):
                # A rank-3 ".weight" the patterns above did not match may be a
                # conv kernel. It would be written in PyTorch layout, so report
                # it instead of copying it silently.
                print(
                    f"warning: {key} {tuple(t.shape)} is a 3D weight not matched "
                    "as conv/convT; copied without permutation",
                    file=sys.stderr,
                )
            stats["copy"] += 1
        out[key] = t.numpy()
    return out, stats


def main(argv: list[str] | None = None) -> int:
    """Convert the SRC checkpoint into the DST MLX-layout file.

    Args:
        argv: optional ``[SRC [DST]]`` overrides. Defaults to ``sys.argv[1:]``.
            Missing entries fall back to the module-level SRC / DST.

    Returns:
        Process exit code: 0 on success, 1 if an expected decoder tensor is
        missing (nothing is written), 2 on bad usage.
    """
    # Imported here so the pure helpers above can be imported without safetensors.
    from safetensors.numpy import save_file
    from safetensors.torch import load_file

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) > 2:
        print(f"usage: {sys.argv[0]} [SRC [DST]]", file=sys.stderr)
        return 2
    src = args[0] if args else SRC
    dst = args[1] if len(args) > 1 else DST

    state = load_file(src)  # dict[str, torch.Tensor] on CPU
    out, stats = convert_decoder_state(state)

    # Validate before writing so a bad checkpoint never produces an output file.
    missing = [k for k in SPOT_CHECK_KEYS if k not in out]
    if missing:
        print(
            f"error: expected decoder tensors missing from {src}: {missing}; "
            "nothing written",
            file=sys.stderr,
        )
        return 1

    # NOTE(review): the "format" tag stays "pt" as before even though the tensors
    # are in MLX layout. Confirm whether the Swift loader reads it before changing.
    save_file(
        out,
        dst,
        metadata={"format": "pt", "source": "dots.tts-soar vocoder decoder"},
    )

    total = sum(int(v.size) for v in out.values())
    print(f"wrote {dst}")
    print(f"  tensors kept        : {len(out)}")
    print(f"  conv1d permuted     : {stats['conv']}")
    print(f"  convtranspose perm  : {stats['convt']}")
    print(f"  copied unchanged    : {stats['copy']}")
    print(f"  encoder-side dropped: {stats['dropped']}")
    print(f"  total decoder params: {total} ({total/1e6:.2f}M)")

    # Spot-check shapes (presence was verified above).
    for k in SPOT_CHECK_KEYS:
        print(f"    {k}: {tuple(out[k].shape)}")
    return 0


if __name__ == "__main__":
    sys.exit(main())