File size: 2,281 Bytes
39057fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#!/usr/bin/env python
"""Convert the validated backbone + DiT into the dots.tts-soar-mlx model-repo layout.

Pins the on-disk convention the Swift loader + unified pipeline consume:
  <repo>/backbone/      mlx_lm.convert output (config.json + model.safetensors, int4 g64) + tokenizer
  <repo>/dit/           model.safetensors (fp32) + config.json (DiT TransformerConfig subset)
  <repo>/config.json, llm_config.json, latent_stats.pt (upstream copies, each only if present in the snapshot)
Vocoder / speaker / patch_encoder / audiovae_encoder dirs are produced by the
parallel component agents and folded in later.
"""
import json
import os
import shutil
from pathlib import Path

import mlx.core as mx
from mlx_lm import convert

# Input/output locations. Every default below is the original hard-coded path;
# each one can be overridden with the named environment variable so the script
# can run on another machine without editing the source.

# Upstream snapshot directory; model.safetensors and the top-level files that
# get copied into the repo are read from here. Override: DOTS_TTS_SNAP.
SNAP = Path(
    os.environ.get(
        "DOTS_TTS_SNAP",
        "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
        "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35",
    )
)
# Backbone checkpoint path passed to mlx_lm.convert (kept as a plain str).
# Override: DOTS_TTS_HF_BACKBONE.
HF_BACKBONE = os.environ.get(
    "DOTS_TTS_HF_BACKBONE", "/Users/samm/git/dots-mlx-spike/qwen2-backbone-hf"
)
# Root of the output model repo (backbone/, dit/ and the copied files).
# Override: DOTS_TTS_REPO.
REPO = Path(os.environ.get("DOTS_TTS_REPO", "/Users/samm/git/sammcj/dots.tts-soar-mlx"))

# --- backbone: int4 group-64 via mlx_lm.convert ---
# Any output left over from an earlier run is deleted first, then the backbone
# is converted straight into <repo>/backbone with 4-bit, group-size-64 quantisation.
backbone_dir = REPO / "backbone"
quant_opts = {"quantize": True, "q_bits": 4, "q_group_size": 64}
if backbone_dir.exists():
    shutil.rmtree(backbone_dir)
convert(HF_BACKBONE, str(backbone_dir), **quant_opts)
print(f"backbone int4 -> {backbone_dir}")

# --- DiT: strip velocity_field_predictor. prefix, save fp32 ---
# Load the full upstream checkpoint, keep only the entries whose name starts
# with PREFIX, drop that prefix from the name and cast each tensor to float32.
dit_out = REPO / "dit"
allw = mx.load(str(SNAP / "model.safetensors"))
PREFIX = "velocity_field_predictor."
dit_w = {k[len(PREFIX):]: v.astype(mx.float32) for k, v in allw.items() if k.startswith(PREFIX)}
if not dit_w:
    # Without this check a renamed or missing prefix would produce an empty
    # model.safetensors and the script would still report success.
    raise SystemExit(
        f"no tensors with prefix {PREFIX!r} found in {SNAP / 'model.safetensors'}"
    )
# Only touch the output directory once there is something valid to write.
dit_out.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(dit_out / "model.safetensors"), dit_w)
# DiT configuration written next to the weights. The values are hard-coded
# here rather than read from the upstream config files.
# NOTE(review): rms_eps is numerically float32 machine epsilon (2**-23);
# presumably it mirrors the upstream RMSNorm default, so confirm against upstream.
dit_cfg = {
    "num_layers": 18, "num_heads": 16, "hidden_size": 1024, "ffn_hidden_size": 4096,
    "head_dim": 64, "in_dim": 1024, "out_dim": 128, "patch_size": 4,
    "qk_norm": True, "norm_layer": "RMSNorm", "rms_eps": 1.1920929e-7,
    "ln_eps": 1e-5, "modulation": True, "rotary_theta": 10000.0, "mode": "flow_matching",
}
(dit_out / "config.json").write_text(json.dumps(dit_cfg, indent=2))
print(f"dit fp32 ({len(dit_w)} tensors) -> {dit_out}")

# --- top-level upstream copies ---
# Each listed file is copied from the snapshot into the repo root when it
# exists. A missing file is reported rather than skipped silently, so an
# incomplete repo is visible in the output; the run itself still continues.
for name in ["config.json", "llm_config.json", "latent_stats.pt"]:
    s = SNAP / name
    if not s.exists():
        print(f"skipped {name} (not found in {SNAP})")
        continue
    # resolve() turns the snapshot entry into its real target path before the
    # copy (snapshot entries are presumably symlinks into the HF cache).
    shutil.copy(str(s.resolve()), str(REPO / name))
    print(f"copied {name}")

print("done")