dots.tts-soar-mlx / scripts /convert_refpath.py
smcleod's picture
Upload folder using huggingface_hub
39057fb verified
#!/usr/bin/env python3
"""Convert the dots.tts-soar reference-audio conditioning path weights to MLX layout.
Extracts:
- patch_encoder.* from model.safetensors -> patchencoder_mlx.safetensors
- audio_encoder.* / enc_mi_layer.* / pre_proj.* from vocoder.safetensors -> audiovae_encoder_mlx.safetensors
- latent_stats (mean/var) from latent_stats.pt -> latent_stats.json
Conventions applied:
- Conv1d weights transposed (out,in,k) -> (out,k,in) for mlx.nn.Conv1d.
- weight_norm params (weight_g,weight_v) folded to a single weight, then transposed.
- Linear / RMSNorm / LayerNorm / LSTM / bias tensors keep their layout (note: every tensor is cast to float32 on load).
CPU + torch + numpy + safetensors only. No MLX, no MPS.
"""
from __future__ import annotations

import json
import os
import re
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Input snapshot and output directory. The defaults are the original author's
# machine-specific paths; set DOTS_TTS_SNAPSHOT / DOTS_MLX_OUT to run the
# conversion elsewhere without editing the script ("~" is expanded).
SNAP = Path(
    os.environ.get(
        "DOTS_TTS_SNAPSHOT",
        "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
        "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35",
    )
).expanduser()
OUT = Path(os.environ.get("DOTS_MLX_OUT", "/Users/samm/git/dots-mlx-spike")).expanduser()
# Source checkpoints inside the snapshot.
MODEL_ST = SNAP / "model.safetensors"
VOCODER_ST = SNAP / "vocoder.safetensors"
LATENT_STATS = SNAP / "latent_stats.pt"
def load_prefixed(path: Path, prefixes: tuple[str, ...]) -> dict[str, torch.Tensor]:
    """Read every tensor in a safetensors file whose key starts with a prefix.

    Args:
        path: safetensors file to open (read on CPU).
        prefixes: key prefixes to keep; a key matching any of them is loaded.

    Returns:
        Mapping of full key (prefix included) -> tensor cast to float32 via
        ``.float()``, in the file's key order.
    """
    with safe_open(str(path), framework="pt", device="cpu") as handle:
        wanted = [name for name in handle.keys() if name.startswith(prefixes)]
        return {name: handle.get_tensor(name).float() for name in wanted}
def conv_transpose(w: torch.Tensor) -> torch.Tensor:
    """Reorder a PyTorch Conv1d kernel for ``mlx.nn.Conv1d``.

    PyTorch stores the kernel as (out, in, k), and MLX wants (out, k, in), so
    the last two axes are swapped. A contiguous copy is returned because
    safetensors' ``save_file`` refuses non-contiguous tensors.

    Raises:
        AssertionError: if *w* is not rank 3.
    """
    assert w.ndim == 3, w.shape
    swapped = w.transpose(1, 2)
    return swapped.contiguous()
def fold_weight_norm(g: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Fold ``weight_norm`` (dim=0) parameters into a single plain weight.

    Computes ``w = g * v / ||v||`` where the L2 norm is taken per output
    channel, i.e. over every axis of *v* except axis 0. This matches PyTorch's
    ``weight_norm(module, dim=0)`` for any rank >= 2: Conv1d kernels
    (out, in, k) as well as Linear weights (out, in). The result keeps the
    layout of *v*. No layout transpose is applied here.

    Args:
        g: per-output-channel magnitude, broadcastable against the
            (out, 1, ..., 1) norm tensor.
        v: direction tensor, output channels on axis 0.

    Raises:
        ValueError: if *v* has fewer than 2 dimensions.
    """
    if v.ndim < 2:
        raise ValueError(f"weight_norm direction must be rank >= 2, got shape {tuple(v.shape)}")
    # Reduce over every non-output axis. For rank 3 this is exactly (1, 2).
    reduce_dims = tuple(range(1, v.ndim))
    norm = v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()
    return g * v / norm
# ---- PatchEncoder (model.safetensors) ----
def convert_patch_encoder() -> dict[str, torch.Tensor]:
    """Extract ``patch_encoder.*`` from model.safetensors in MLX layout.

    The ``patch_encoder.`` prefix is stripped from every key.
    ``ds_proj.weight`` is treated as the only Conv1d kernel and is swapped to
    (out, k, in). Every other tensor is copied as-is (contiguous; already
    float32 from ``load_prefixed``).
    """
    prefix = "patch_encoder."
    converted: dict[str, torch.Tensor] = {}
    for full_key, tensor in load_prefixed(MODEL_ST, (prefix,)).items():
        short = full_key[len(prefix):]
        if short != "ds_proj.weight":
            converted[short] = tensor.contiguous()
            continue
        converted[short] = conv_transpose(tensor)
    # Nothing is synthesised for the q/k RMSNorm layers. The original author
    # notes they are affine-free, so the checkpoint has no weight keys for them.
    return converted
# ---- AudioVAE encoder (vocoder.safetensors) ----
def convert_audiovae_encoder() -> dict[str, torch.Tensor]:
    """Extract the AudioVAE encoder-side tensors from vocoder.safetensors.

    Covers ``audio_encoder.*``, ``enc_mi_layer.*`` and ``pre_proj.*`` (keys
    keep their full names):
      * ``<base>.weight_g`` / ``<base>.weight_v`` pairs are folded into one
        ``<base>.weight`` and swapped to MLX Conv1d layout (out, k, in);
      * any other rank-3 ``*.weight`` (a plain conv such as pre_proj) is
        swapped the same way;
      * everything else (biases, linear / LSTM / norm params) is copied.
    Folded weights come first in the result, then the remaining tensors in
    checkpoint order.

    Raises:
        AssertionError: if a ``weight_g`` has no matching ``weight_v`` (or
            vice versa), or a folded weight is not rank 3.
    """
    tensors = load_prefixed(
        VOCODER_ST, ("audio_encoder.", "enc_mi_layer.", "pre_proj.")
    )

    # Collect weight_norm halves under the shared "<base>.weight" key; both
    # suffixes are two characters long ("_g" / "_v").
    suffix_to_slot = {".weight_g": "g", ".weight_v": "v"}
    pairs: dict[str, dict[str, torch.Tensor]] = {}
    passthrough: dict[str, torch.Tensor] = {}
    for key, tensor in tensors.items():
        slot = next(
            (name for suffix, name in suffix_to_slot.items() if key.endswith(suffix)),
            None,
        )
        if slot is None:
            passthrough[key] = tensor
        else:
            pairs.setdefault(key[:-2], {})[slot] = tensor

    converted: dict[str, torch.Tensor] = {}
    for base, halves in pairs.items():
        assert "g" in halves and "v" in halves, f"incomplete weight_norm pair: {base}"
        converted[base] = conv_transpose(fold_weight_norm(halves["g"], halves["v"]))

    for key, tensor in passthrough.items():
        is_plain_conv = key.endswith(".weight") and tensor.ndim == 3
        converted[key] = conv_transpose(tensor) if is_plain_conv else tensor.contiguous()
    return converted
def convert_latent_stats(path: Path | None = None) -> dict:
    """Load the latent normalisation statistics as JSON-serialisable data.

    Args:
        path: checkpoint to read. Defaults to ``LATENT_STATS`` (the
            snapshot's ``latent_stats.pt``). It must hold a mapping with
            ``"mean"`` and ``"var"`` entries accepted by ``torch.as_tensor``.

    Returns:
        ``{"mean": [...], "var": [...], "shape": [...]}``. The values are
        float32 converted to (nested) Python lists, and ``shape`` is the
        shared shape of mean and var.

    Raises:
        KeyError: if ``"mean"`` or ``"var"`` is missing.
        ValueError: if mean and var have different shapes.
    """
    if path is None:
        path = LATENT_STATS
    # map_location="cpu": tensors saved from a CUDA run would otherwise fail
    # to deserialize on a machine without CUDA (this script is CPU-only).
    # NOTE(security): weights_only=False unpickles arbitrary objects and can
    # execute code from the file. Only use it on a trusted checkpoint, and
    # prefer weights_only=True if the file holds plain tensors.
    d = torch.load(path, map_location="cpu", weights_only=False)
    mean = torch.as_tensor(d["mean"]).float()
    var = torch.as_tensor(d["var"]).float()
    if mean.shape != var.shape:
        raise ValueError(
            "latent_stats mean/var shape mismatch: "
            f"{tuple(mean.shape)} vs {tuple(var.shape)}"
        )
    return {"mean": mean.tolist(), "var": var.tolist(), "shape": list(mean.shape)}
def main() -> None:
    """Run all three conversions into ``OUT`` and print a short summary.

    Writes patchencoder_mlx.safetensors, audiovae_encoder_mlx.safetensors and
    latent_stats.json, printing tensor counts and a few spot-check shapes.
    """
    OUT.mkdir(parents=True, exist_ok=True)

    patch = convert_patch_encoder()
    save_file(patch, str(OUT / "patchencoder_mlx.safetensors"))
    print(f"patchencoder_mlx.safetensors: {len(patch)} tensors")
    for name in ("ds_proj.weight", "in_proj.weight", "out_proj.weight"):
        print(f" {name}: {tuple(patch[name].shape)}")

    audio = convert_audiovae_encoder()
    save_file(audio, str(OUT / "audiovae_encoder_mlx.safetensors"))
    print(f"audiovae_encoder_mlx.safetensors: {len(audio)} tensors")
    # Spot-check a few representative tensors; keys absent from the
    # checkpoint are silently skipped.
    spot_keys = (
        "audio_encoder.generator.0.layer.weight",
        "audio_encoder.generator.17.layer.weight",
        "pre_proj.weight",
        "enc_mi_layer.1.lstm.weight_ih_l0",
    )
    for name in (key for key in spot_keys if key in audio):
        print(f" {name}: {tuple(audio[name].shape)}")

    stats = convert_latent_stats()
    (OUT / "latent_stats.json").write_text(json.dumps(stats), encoding="utf-8")
    print(f"latent_stats.json: mean/var shape {stats['shape']}")

    # Sanity summary: how many tensors ended up in MLX Conv1d layout.
    n_pe_conv = int("ds_proj.weight" in patch)
    n_ae_conv = sum(
        key.endswith(".weight") and tensor.ndim == 3 for key, tensor in audio.items()
    )
    print(f"conv weights transposed: patch_encoder={n_pe_conv} audiovae={n_ae_conv}")
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()