File size: 5,424 Bytes
39057fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
#!/usr/bin/env python3
"""Convert the dots.tts-soar reference-audio conditioning path weights to MLX layout.

Extracts:
  - patch_encoder.*           from model.safetensors   -> patchencoder_mlx.safetensors
  - audio_encoder.* / enc_mi_layer.* / pre_proj.*  from vocoder.safetensors -> audiovae_encoder_mlx.safetensors
  - latent_stats (mean/var)   from latent_stats.pt     -> latent_stats.json

Conventions applied:
  - Conv1d weights transposed (out,in,k) -> (out,k,in) for mlx.nn.Conv1d.
  - weight_norm params (weight_g,weight_v) folded to a single weight, then transposed.
  - Linear / RMSNorm / LayerNorm / LSTM / bias tensors copied without re-layout
    (note: every tensor, including these, is cast to float32 on load).

CPU + torch + numpy + safetensors only. No MLX, no MPS.
"""
from __future__ import annotations

import json
import os
import re
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Input HF snapshot directory and output directory. The defaults are the
# original author's local paths; set DOTS_TTS_SNAPSHOT / DOTS_MLX_OUT in the
# environment to run the conversion elsewhere without editing this file.
SNAP = Path(
    os.environ.get(
        "DOTS_TTS_SNAPSHOT",
        "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
        "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35",
    )
)
OUT = Path(os.environ.get("DOTS_MLX_OUT", "/Users/samm/git/dots-mlx-spike"))
# Source files inside the snapshot.
MODEL_ST = SNAP / "model.safetensors"
VOCODER_ST = SNAP / "vocoder.safetensors"
LATENT_STATS = SNAP / "latent_stats.pt"


def load_prefixed(path: Path, prefixes: tuple[str, ...]) -> dict[str, torch.Tensor]:
    """Read the tensors whose keys start with any of *prefixes* from a safetensors file.

    Tensors are loaded on CPU and cast to float32 with ``.float()``. Keys are
    returned unchanged (prefix included), in the file's key order.
    """
    with safe_open(str(path), framework="pt", device="cpu") as reader:
        # str.startswith accepts a tuple, so one call checks every prefix.
        return {
            name: reader.get_tensor(name).float()
            for name in reader.keys()
            if name.startswith(prefixes)
        }


def conv_transpose(w: torch.Tensor) -> torch.Tensor:
    """Reorder a PyTorch Conv1d weight (out, in, k) into mlx.nn.Conv1d's (out, k, in).

    Returns a contiguous copy; fails with AssertionError (showing the shape)
    when *w* is not 3-D.
    """
    assert w.dim() == 3, w.shape
    # Swapping axes 1 and 2 of a 3-D tensor is the same as permute(0, 2, 1).
    swapped = torch.transpose(w, 1, 2)
    return swapped.contiguous()


def fold_weight_norm(g: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fold a PyTorch weight_norm (dim=0) pair into one weight: w = g * v / ||v||.

    The L2 norm is taken per output channel, i.e. over every dimension of *v*
    except dim 0. For a Conv1d weight (out, in, k) that is dims (1, 2); for a
    Linear weight (out, in) it is dim 1. A 1-D *v* has nothing to reduce, so
    its per-element norm is ``|v|``.

    Args:
        g: Magnitude tensor; must broadcast against *v* (e.g. (out, 1, 1) for a
            3-D *v* — assumed to be how the checkpoint stores it).
        v: Direction tensor, output channels on dim 0.

    Returns:
        The folded weight, same layout as *v* (no transpose is applied here).
    """
    if v.ndim == 1:
        # sum(dim=()) would reduce over *all* dims in torch, so special-case it.
        norm = v.abs()
    else:
        reduce_dims = tuple(range(1, v.ndim))
        norm = v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()
    return g * v / norm


# ---- PatchEncoder (model.safetensors) ----
def convert_patch_encoder() -> dict[str, torch.Tensor]:
    """Extract the ``patch_encoder.*`` tensors from model.safetensors in MLX layout.

    The ``patch_encoder.`` prefix is stripped from every key. ``ds_proj.weight``
    goes through conv_transpose ((out, in, k) -> (out, k, in)); every other
    tensor is copied as a contiguous tensor.
    """
    prefix = "patch_encoder."
    converted: dict[str, torch.Tensor] = {}
    for full_key, tensor in load_prefixed(MODEL_ST, (prefix,)).items():
        short_key = full_key[len(prefix):]
        # ds_proj.weight is the only conv weight in the patch encoder, so it is
        # the only tensor that needs the Conv1d layout permute.
        if short_key != "ds_proj.weight":
            converted[short_key] = tensor.contiguous()
            continue
        converted[short_key] = conv_transpose(tensor)
    # q/k RMSNorm are affine-free (no weight keys present) -> nothing to emit.
    return converted


# ---- AudioVAE encoder (vocoder.safetensors) ----
def convert_audiovae_encoder() -> dict[str, torch.Tensor]:
    """Extract the AudioVAE encoder-side tensors from vocoder.safetensors in MLX layout.

    Reads every tensor under ``audio_encoder.``, ``enc_mi_layer.`` and
    ``pre_proj.``; keys keep their full prefix.

    * Each ``<name>.weight_g`` / ``<name>.weight_v`` pair is folded with
      fold_weight_norm and emitted as ``<name>.weight`` after conv_transpose
      (so every weight-normed tensor must be a 3-D conv weight).
    * Any other 3-D ``*.weight`` tensor is conv-transposed.
    * Everything else is copied as a contiguous tensor.

    Raises AssertionError if a weight_norm pair is missing one half.
    """
    tensors = load_prefixed(
        VOCODER_ST, ("audio_encoder.", "enc_mi_layer.", "pre_proj.")
    )

    # Bucket weight_norm halves under their shared "<name>.weight" stem
    # ("...weight_g" -> stem "...weight", half "g"); the rest pass through.
    norm_parts: dict[str, dict[str, torch.Tensor]] = {}
    passthrough: dict[str, torch.Tensor] = {}
    for key, tensor in tensors.items():
        stem, sep, half = key.rpartition("_")
        if sep and half in ("g", "v") and stem.endswith(".weight"):
            norm_parts.setdefault(stem, {})[half] = tensor
        else:
            passthrough[key] = tensor

    converted: dict[str, torch.Tensor] = {}
    for stem, halves in norm_parts.items():
        assert "g" in halves and "v" in halves, f"incomplete weight_norm pair: {stem}"
        # Folded weight is (out, in, k); conv_transpose makes it (out, k, in).
        converted[stem] = conv_transpose(fold_weight_norm(halves["g"], halves["v"]))

    # 3-D "*.weight" leftovers (e.g. pre_proj.weight) are plain convs; biases,
    # linear, LSTM and norm tensors are copied through.
    for key, tensor in passthrough.items():
        is_plain_conv = key.endswith(".weight") and tensor.ndim == 3
        converted[key] = conv_transpose(tensor) if is_plain_conv else tensor.contiguous()
    return converted


def convert_latent_stats(path: "Path | str | None" = None) -> dict:
    """Load the latent mean/var statistics and return them as JSON-ready lists.

    Args:
        path: Checkpoint to read. Defaults to ``LATENT_STATS`` (latent_stats.pt
            in the HF snapshot). The loaded object must support ``["mean"]`` and
            ``["var"]`` lookups whose values ``torch.as_tensor`` accepts.

    Returns:
        ``{"mean": [...], "var": [...], "shape": [...]}``. mean and var are cast
        to float32 and turned into nested Python lists; ``shape`` is the shape
        of ``mean``.
    """
    if path is None:
        path = LATENT_STATS
    # map_location="cpu": this converter is CPU-only, and tensors pickled on a
    # CUDA device would otherwise fail to load on a machine without CUDA.
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and
    # can execute code embedded in the file. Only run this on a trusted
    # snapshot; consider weights_only=True if the file holds only tensors and
    # plain containers.
    d = torch.load(path, map_location="cpu", weights_only=False)
    mean = torch.as_tensor(d["mean"]).float()
    var = torch.as_tensor(d["var"]).float()
    return {"mean": mean.tolist(), "var": var.tolist(), "shape": list(mean.shape)}


def _print_shapes(tensors: dict[str, torch.Tensor], keys: tuple[str, ...]) -> None:
    """Print ``  <key>: <shape>`` for each key, reporting keys absent from *tensors*."""
    for k in keys:
        t = tensors.get(k)
        shape = "<missing>" if t is None else tuple(t.shape)
        print(f"  {k}: {shape}")


def main() -> None:
    """Run all three conversions, write the outputs under OUT, and print a summary.

    Writes patchencoder_mlx.safetensors, audiovae_encoder_mlx.safetensors and
    latent_stats.json, printing tensor counts, a few spot-check shapes, and
    how many conv weights were transposed.
    """
    OUT.mkdir(parents=True, exist_ok=True)

    pe = convert_patch_encoder()
    save_file(pe, str(OUT / "patchencoder_mlx.safetensors"))
    print(f"patchencoder_mlx.safetensors: {len(pe)} tensors")
    # Spot-check shapes; a missing key is reported rather than raising
    # KeyError after the file has already been written.
    _print_shapes(pe, ("ds_proj.weight", "in_proj.weight", "out_proj.weight"))

    ae = convert_audiovae_encoder()
    save_file(ae, str(OUT / "audiovae_encoder_mlx.safetensors"))
    print(f"audiovae_encoder_mlx.safetensors: {len(ae)} tensors")
    # spot-check a folded+transposed down conv and the pre_proj
    _print_shapes(
        ae,
        (
            "audio_encoder.generator.0.layer.weight",
            "audio_encoder.generator.17.layer.weight",
            "pre_proj.weight",
            "enc_mi_layer.1.lstm.weight_ih_l0",
        ),
    )

    stats = convert_latent_stats()
    (OUT / "latent_stats.json").write_text(json.dumps(stats), encoding="utf-8")
    print(f"latent_stats.json: mean/var shape {stats['shape']}")

    # sanity: count conv tensors transposed. In the patch encoder only
    # ds_proj.weight is transposed; in the AudioVAE every 3-D "*.weight" is.
    n_pe_conv = int("ds_proj.weight" in pe)
    n_ae_conv = sum(
        1 for k, t in ae.items() if k.endswith(".weight") and t.ndim == 3
    )
    print(f"conv weights transposed: patch_encoder={n_pe_conv} audiovae={n_ae_conv}")


# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    main()