#!/usr/bin/env python
"""Convert the validated backbone + DiT into the dots.tts-soar-mlx model-repo layout.

Pins the on-disk convention the Swift loader + unified pipeline consume:

    <repo>/backbone/   mlx_lm.convert output (config.json + model.safetensors,
                       int4 g64) + tokenizer
    <repo>/dit/        model.safetensors (fp32) + config.json
                       (DiT TransformerConfig subset)
    <repo>/config.json, llm_config.json, latent_stats.pt   (upstream copies)

Vocoder / speaker / patch_encoder / audiovae_encoder dirs are produced by the
parallel component agents and folded in later.
"""

import json
import shutil
from pathlib import Path

import mlx.core as mx
from mlx_lm import convert

SNAP = Path(
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
)
HF_BACKBONE = "/Users/samm/git/dots-mlx-spike/qwen2-backbone-hf"
REPO = Path("/Users/samm/git/sammcj/dots.tts-soar-mlx")

# Key prefix of the DiT sub-module inside the upstream combined checkpoint.
PREFIX = "velocity_field_predictor."

# Upstream top-level files mirrored verbatim into the repo root.
UPSTREAM_FILES = ("config.json", "llm_config.json", "latent_stats.pt")

# DiT hyper-parameters written to <repo>/dit/config.json. These are literals,
# not read from the checkpoint, so they must be kept in sync with upstream by
# hand.
DIT_CONFIG = {
    "num_layers": 18,
    "num_heads": 16,
    "hidden_size": 1024,
    "ffn_hidden_size": 4096,
    "head_dim": 64,
    "in_dim": 1024,
    "out_dim": 128,
    "patch_size": 4,
    "qk_norm": True,
    "norm_layer": "RMSNorm",
    "rms_eps": 1.1920929e-7,  # == float32 machine epsilon (2**-23)
    "ln_eps": 1e-5,
    "modulation": True,
    "rotary_theta": 10000.0,
    "mode": "flow_matching",
}


def _convert_backbone(hf_backbone: str, repo: Path) -> Path:
    """Quantize the HF backbone to int4 (group size 64) into <repo>/backbone."""
    out = repo / "backbone"
    # Rebuild from scratch: convert() expects a fresh output path, and this
    # also drops stale files left by an earlier run.
    if out.exists():
        shutil.rmtree(out)
    convert(hf_backbone, str(out), quantize=True, q_bits=4, q_group_size=64)
    print(f"backbone int4 -> {out}")
    return out


def _export_dit(snap: Path, repo: Path, prefix: str = PREFIX) -> Path:
    """Extract the DiT tensors from the upstream checkpoint into <repo>/dit as fp32.

    Raises:
        RuntimeError: if no tensor key starts with *prefix* (would otherwise
            silently write an empty weights file).
    """
    checkpoint = snap / "model.safetensors"
    all_weights = mx.load(str(checkpoint))
    # Keep only the DiT sub-module, with keys made relative to the DiT by
    # stripping the prefix; every other tensor in the checkpoint is dropped.
    dit_w = {
        k[len(prefix):]: v.astype(mx.float32)
        for k, v in all_weights.items()
        if k.startswith(prefix)
    }
    if not dit_w:
        raise RuntimeError(f"no tensors with prefix {prefix!r} in {checkpoint}")

    out = repo / "dit"
    out.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(out / "model.safetensors"), dit_w)
    (out / "config.json").write_text(json.dumps(DIT_CONFIG, indent=2))
    print(f"dit fp32 ({len(dit_w)} tensors) -> {out}")
    return out


def _copy_upstream(snap: Path, repo: Path) -> None:
    """Mirror the upstream top-level files into the repo root, reporting misses."""
    for name in UPSTREAM_FILES:
        src = snap / name
        if not src.exists():
            print(f"skipped {name} (not found in {snap})")
            continue
        # HF cache snapshot entries are usually symlinks into ../blobs;
        # resolve() so the real file contents are copied.
        shutil.copy(src.resolve(), repo / name)
        print(f"copied {name}")


def main(
    snap: Path = SNAP,
    hf_backbone: str = HF_BACKBONE,
    repo: Path = REPO,
) -> None:
    """Build the model-repo layout under *repo* from *snap* and *hf_backbone*.

    Raises:
        FileNotFoundError: if the upstream checkpoint is missing. This is
            checked before the existing backbone output is deleted.
    """
    checkpoint = snap / "model.safetensors"
    if not checkpoint.is_file():
        raise FileNotFoundError(f"upstream checkpoint not found: {checkpoint}")
    _convert_backbone(hf_backbone, repo)
    _export_dit(snap, repo)
    _copy_upstream(snap, repo)
    print("done")


if __name__ == "__main__":
    main()