File size: 4,053 Bytes
39057fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""Convert the dots.tts-soar AudioVAE decoder weights to an mlx-swift layout.

Reads vocoder.safetensors (full AudioVAE), selects only the decoder-side
tensors (post_proj, dec_mi_layer, decoder.*), applies the PyTorch -> MLX
conv weight permutations, and writes vocoder_decoder_mlx.safetensors.

Run on CPU only. torch + numpy + safetensors, no MLX, no MPS.

Decoder-side weights in the checkpoint are ALREADY weight-norm folded
(plain `weight` tensors, no weight_g / weight_v on the decoder side), so no
fusion is needed here. The 80 weight_v / weight_g pairs in the file all live
on the encoder side (audio_encoder.*) which we drop.

Transforms applied (besides upcasting floating-point tensors to float32):
  PyTorch Conv1d weight        (out, in/groups, k)   -> MLX (out, k, in/groups)   permute(0,2,1)
  PyTorch ConvTranspose1d wt   (in,  out, k)         -> MLX (out, k, in)          permute(1,2,0)
    (the ConvTranspose1d mapping is exact only for groups == 1)
Linear and LSTM weights, biases, snake alpha/beta, and anti-alias filter
buffers keep their layout (the Swift side handles LSTM gate ordering and
filter layout directly).
"""

from __future__ import annotations

import sys

import numpy as np
import torch
from safetensors.numpy import save_file
from safetensors.torch import load_file

# Input: the full AudioVAE checkpoint (encoder + decoder) in the local HF cache.
SRC = (
    "/Users/samm/.cache/huggingface/hub/"
    "models--rednote-hilab--dots.tts-soar/snapshots/"
    "1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35/"
    "vocoder.safetensors"
)
# Output: decoder-only weights, permuted into MLX conv layout.
DST = "/Users/samm/git/dots-mlx-spike/vocoder_decoder_mlx.safetensors"

# Key prefixes of every module on the decode path (latent -> waveform);
# anything else in the checkpoint is dropped.
DECODER_PREFIXES = (
    "post_proj.",
    "dec_mi_layer.",
    "decoder.",
)


def is_conv1d(key: str) -> bool:
    """Return True if ``key`` names an nn.Conv1d weight on the decode path.

    Matches ``post_proj.weight``, ``decoder.conv_pre.weight``,
    ``decoder.conv_post.weight`` and every resblock conv weight
    (``decoder.resblocks.*.convs1.*`` / ``*.convs2.*``). PyTorch stores these
    as (out, in/groups, k).
    """
    if not key.endswith(".weight"):
        return False
    standalone = ("post_proj.weight", "decoder.conv_pre.weight", "decoder.conv_post.weight")
    in_resblock_conv = key.startswith("decoder.resblocks.") and any(
        tag in key for tag in (".convs1.", ".convs2.")
    )
    return key in standalone or in_resblock_conv


def is_convtranspose1d(key: str) -> bool:
    """Return True if ``key`` is an upsampler weight ``decoder.ups.<i>.0.weight``.

    Only names with exactly five dot-separated parts qualify, so deeper
    sub-module weights under ``decoder.ups`` are not matched. PyTorch stores
    nn.ConvTranspose1d weights as (in, out/groups, k).
    """
    parts = key.split(".")
    return (
        len(parts) == 5
        and parts[:2] == ["decoder", "ups"]
        and parts[3:] == ["0", "weight"]
    )


def _to_mlx_layout(key: str, tensor: torch.Tensor) -> tuple[np.ndarray, str]:
    """Convert one decoder tensor to the layout expected on the MLX side.

    Args:
        key: Checkpoint tensor name, used to decide whether the tensor is a
            Conv1d weight, a ConvTranspose1d weight, or something to copy.
        tensor: The tensor as loaded from the source checkpoint.

    Returns:
        ``(array, kind)`` where ``kind`` is ``"conv"``, ``"convt"`` or
        ``"copy"`` and ``array`` is a NumPy array ready for ``save_file``.

    Raises:
        ValueError: if a tensor classified as a conv / conv-transpose weight
            is not 3-D.
    """
    t = tensor.detach().cpu()
    # Upcast floating tensors to float32 (NumPy cannot hold bfloat16).
    # Non-floating buffers keep their dtype instead of silently becoming float.
    if t.is_floating_point():
        t = t.to(torch.float32)

    if is_conv1d(key):
        # (out, in/groups, k) -> (out, k, in/groups)
        kind, perm = "conv", (0, 2, 1)
    elif is_convtranspose1d(key):
        # (in, out, k) -> (out, k, in); this mapping is exact for groups == 1.
        kind, perm = "convt", (1, 2, 0)
    else:
        return t.numpy(), "copy"

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if t.ndim != 3:
        raise ValueError(f"{key}: expected a 3D {kind} weight, got shape {tuple(t.shape)}")
    return t.permute(*perm).contiguous().numpy(), kind


def main(argv: list[str] | None = None) -> int:
    """Extract the AudioVAE decoder weights and write them in MLX layout.

    Args:
        argv: Command-line arguments without the program name; ``None`` means
            ``sys.argv[1:]``. Two optional positionals, ``src`` and ``dst``,
            override the module-level ``SRC`` / ``DST`` paths.

    Returns:
        Process exit status: 0 on success, 1 if ``src`` contains no tensor
        whose name starts with one of ``DECODER_PREFIXES`` (nothing is written).

    Raises:
        ValueError: if a conv / conv-transpose weight is not 3-D.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert dots.tts-soar AudioVAE decoder weights to MLX layout.",
    )
    parser.add_argument("src", nargs="?", default=SRC, help="input vocoder.safetensors")
    parser.add_argument("dst", nargs="?", default=DST, help="output .safetensors path")
    args = parser.parse_args(argv)
    src, dst = args.src, args.dst

    state = load_file(src)  # dict[str, torch.Tensor] on CPU

    out: dict[str, np.ndarray] = {}
    counts = {"conv": 0, "convt": 0, "copy": 0}
    dropped = 0

    for key, tensor in state.items():
        if not key.startswith(DECODER_PREFIXES):
            dropped += 1  # not on the decode path
            continue
        array, kind = _to_mlx_layout(key, tensor)
        out[key] = array
        counts[kind] += 1

    if not out:
        # Refuse to write an empty file (wrong checkpoint or renamed prefixes).
        print(f"error: no tensors with prefixes {DECODER_PREFIXES} in {src}", file=sys.stderr)
        return 1

    save_file(out, dst, metadata={"format": "pt", "source": "dots.tts-soar vocoder decoder"})

    total = sum(int(v.size) for v in out.values())
    print(f"wrote {dst}")
    print(f"  tensors kept       : {len(out)}")
    print(f"  conv1d permuted    : {counts['conv']}")
    print(f"  convtranspose perm : {counts['convt']}")
    print(f"  copied unchanged   : {counts['copy']}")
    print(f"  encoder-side dropped: {dropped}")
    print(f"  total decoder params: {total} ({total/1e6:.2f}M)")
    # Spot-check a few shapes; a missing key is reported rather than raising
    # KeyError after the output file has already been written.
    for k in ("post_proj.weight", "decoder.conv_pre.weight",
              "decoder.ups.0.0.weight", "decoder.conv_post.weight"):
        arr = out.get(k)
        print(f"  {k}: {tuple(arr.shape) if arr is not None else 'MISSING'}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())