#!/usr/bin/env python3 """Convert the dots.tts-soar reference-audio conditioning path weights to MLX layout. Extracts: - patch_encoder.* from model.safetensors -> patchencoder_mlx.safetensors - audio_encoder.* / enc_mi_layer.* / pre_proj.* from vocoder.safetensors -> audiovae_encoder_mlx.safetensors - latent_stats (mean/var) from latent_stats.pt -> latent_stats.json Conventions applied: - Conv1d weights transposed (out,in,k) -> (out,k,in) for mlx.nn.Conv1d. - weight_norm params (weight_g,weight_v) folded to a single weight, then transposed. - Linear / RMSNorm / LayerNorm / LSTM / bias tensors copied unchanged. CPU + torch + numpy + safetensors only. No MLX, no MPS. """ from __future__ import annotations import json import re from pathlib import Path import torch from safetensors import safe_open from safetensors.torch import save_file SNAP = Path( "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/" "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35" ) OUT = Path("/Users/samm/git/dots-mlx-spike") MODEL_ST = SNAP / "model.safetensors" VOCODER_ST = SNAP / "vocoder.safetensors" LATENT_STATS = SNAP / "latent_stats.pt" def load_prefixed(path: Path, prefixes: tuple[str, ...]) -> dict[str, torch.Tensor]: out: dict[str, torch.Tensor] = {} with safe_open(str(path), framework="pt", device="cpu") as f: for k in f.keys(): if k.startswith(prefixes): out[k] = f.get_tensor(k).float() return out def conv_transpose(w: torch.Tensor) -> torch.Tensor: """(out, in, k) -> (out, k, in) for mlx.nn.Conv1d.""" assert w.ndim == 3, w.shape return w.permute(0, 2, 1).contiguous() def fold_weight_norm(g: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """w = g * v / ||v||, norm over dims (in, k) per output channel (PyTorch dim=0).""" norm = v.pow(2).sum(dim=(1, 2), keepdim=True).sqrt() return g * v / norm # ---- PatchEncoder (model.safetensors) ---- def convert_patch_encoder() -> dict[str, torch.Tensor]: raw = load_prefixed(MODEL_ST, ("patch_encoder.",)) out: dict[str, torch.Tensor] = {} for k, t in raw.items(): name = k[len("patch_encoder.") :] if name == "ds_proj.weight": # only conv weight in the patch encoder out[name] = conv_transpose(t) else: out[name] = t.contiguous() # q/k RMSNorm are affine-free (no weight keys present) -> nothing to emit. return out # ---- AudioVAE encoder (vocoder.safetensors) ---- def convert_audiovae_encoder() -> dict[str, torch.Tensor]: raw = load_prefixed( VOCODER_ST, ("audio_encoder.", "enc_mi_layer.", "pre_proj.") ) # Group weight_norm pairs by their base key (strip _g/_v suffix). wn_pairs: dict[str, dict[str, torch.Tensor]] = {} plain: dict[str, torch.Tensor] = {} for k, t in raw.items(): if k.endswith(".weight_g"): wn_pairs.setdefault(k[: -len("_g")], {})["g"] = t elif k.endswith(".weight_v"): wn_pairs.setdefault(k[: -len("_v")], {})["v"] = t else: plain[k] = t out: dict[str, torch.Tensor] = {} # Fold + transpose every weight_normed conv. for base, gv in wn_pairs.items(): assert "g" in gv and "v" in gv, f"incomplete weight_norm pair: {base}" folded = fold_weight_norm(gv["g"], gv["v"]) # (out,in,k) out[base] = conv_transpose(folded) # -> (out,k,in) # Plain tensors: conv weights (pre_proj) transposed; everything else copied. for k, t in plain.items(): if k.endswith(".weight") and t.ndim == 3: # pre_proj.weight (plain conv) out[k] = conv_transpose(t) else: out[k] = t.contiguous() # biases, linear, LSTM packed weights, layernorm return out def convert_latent_stats() -> dict: d = torch.load(LATENT_STATS, weights_only=False) mean = torch.as_tensor(d["mean"]).float() var = torch.as_tensor(d["var"]).float() return {"mean": mean.tolist(), "var": var.tolist(), "shape": list(mean.shape)} def main() -> None: OUT.mkdir(parents=True, exist_ok=True) pe = convert_patch_encoder() save_file(pe, str(OUT / "patchencoder_mlx.safetensors")) print(f"patchencoder_mlx.safetensors: {len(pe)} tensors") for k in ("ds_proj.weight", "in_proj.weight", "out_proj.weight"): print(f" {k}: {tuple(pe[k].shape)}") ae = convert_audiovae_encoder() save_file(ae, str(OUT / "audiovae_encoder_mlx.safetensors")) print(f"audiovae_encoder_mlx.safetensors: {len(ae)} tensors") # spot-check a folded+transposed down conv and the pre_proj for k in ( "audio_encoder.generator.0.layer.weight", "audio_encoder.generator.17.layer.weight", "pre_proj.weight", "enc_mi_layer.1.lstm.weight_ih_l0", ): if k in ae: print(f" {k}: {tuple(ae[k].shape)}") stats = convert_latent_stats() (OUT / "latent_stats.json").write_text(json.dumps(stats), encoding="utf-8") print(f"latent_stats.json: mean/var shape {stats['shape']}") # sanity: count conv tensors transposed n_pe_conv = sum(1 for k in pe if k == "ds_proj.weight") n_ae_conv = sum( 1 for k in ae if k.endswith(".weight") and ae[k].ndim == 3 ) print(f"conv weights transposed: patch_encoder={n_pe_conv} audiovae={n_ae_conv}") if __name__ == "__main__": main()