dots.tts-soar-mlx / scripts /extract_backbone.py
smcleod's picture
Upload folder using huggingface_hub
39057fb verified
#!/usr/bin/env python
"""Extract dots.tts Qwen2 backbone (llm.* tensors) into a standard HF Qwen2 checkpoint.
Strips the `llm.` prefix so `llm.model.layers.N...` -> `model.layers.N...`, the
verbatim HF Qwen2 naming. Writes config.json + model.safetensors and copies the
tokenizer files so mlx_lm.load() / mlx_lm.convert() can read the dir directly.
"""
import json
import os
import shutil
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file
# Input snapshot directory and output directory. Both default to the original
# hard-coded locations, but can be redirected without editing the script by
# setting DOTS_TTS_SNAPSHOT / DOTS_TTS_BACKBONE_OUT in the environment. An
# unset or empty variable falls back to the default.
SNAP = Path(
    os.environ.get("DOTS_TTS_SNAPSHOT")
    or (
        "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
        "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
    )
)
OUT = Path(
    os.environ.get("DOTS_TTS_BACKBONE_OUT")
    or "/Users/samm/git/dots-mlx-spike/qwen2-backbone-hf"
)
OUT.mkdir(parents=True, exist_ok=True)

# Combined checkpoint to read, and the key prefix that marks backbone tensors.
src = SNAP / "model.safetensors"
prefix = "llm."
# Split the checkpoint keys in a single pass: tensors whose name starts with
# `prefix` are loaded and stored under the name with the prefix stripped;
# every other key is only counted, grouped by its first dotted component.
backbone = {}
other_prefixes = {}
with safe_open(str(src), framework="pt") as reader:
    for tensor_name in reader.keys():
        if not tensor_name.startswith(prefix):
            group = tensor_name.partition(".")[0]
            other_prefixes.setdefault(group, 0)
            other_prefixes[group] += 1
            continue
        backbone[tensor_name[len(prefix):]] = reader.get_tensor(tensor_name)

# Report what was kept and what was skipped.
print(f"extracted {len(backbone)} backbone tensors")
print("skipped non-backbone top-level prefixes:")
for group, count in sorted(other_prefixes.items()):
    print(f" {group}: {count} tensors")
# Sanity: confirm the standard Qwen2 names are present.
# The check raises explicitly instead of using `assert`, so it still runs
# under `python -O`, and it reports every missing tensor at once rather than
# stopping at the first one.
must_have = [
    "model.embed_tokens.weight",
    "model.norm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.bias",
    "model.layers.27.mlp.down_proj.weight",
]
missing = []
for name in must_have:
    if name not in backbone:
        missing.append(name)
        continue
    print(f" ok {name} {tuple(backbone[name].shape)} {backbone[name].dtype}")
if missing:
    raise RuntimeError("MISSING expected tensor " + ", ".join(missing))

# Informational only: report whether a separate lm_head tensor was extracted.
has_lm_head = any(k.startswith("lm_head") for k in backbone)
print(f"lm_head present in checkpoint: {has_lm_head} (tie_word_embeddings=true -> expect False)")
# Cast to the model's native dtype (bf16) for an honest baseline; the source
# blob is fp32. mlx will down/upcast on convert anyway.
# Each entry is converted and overwritten in place, so the original tensor can
# be released as soon as its bf16 copy exists; building a second dict would
# keep both full copies of the weights alive until the rebind.
# Non-floating-point tensors are left untouched.
for tensor_name in list(backbone):
    tensor = backbone[tensor_name]
    if tensor.is_floating_point():
        backbone[tensor_name] = tensor.to(torch.bfloat16)

save_file(backbone, str(OUT / "model.safetensors"), metadata={"format": "pt"})
print(f"wrote {OUT / 'model.safetensors'}")
# Write a clean Qwen2 config.json for the extracted backbone. The values are
# hard-coded here (per the original note they mirror llm_config.json with the
# generation-only fields dropped); llm_config.json itself is not read.
cfg = dict(
    architectures=["Qwen2ForCausalLM"],
    model_type="qwen2",
    vocab_size=151672,
    hidden_size=1536,
    intermediate_size=8960,
    num_hidden_layers=28,
    num_attention_heads=12,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=131072,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    rope_theta=1000000.0,
    rope_scaling=None,
    use_sliding_window=False,
    sliding_window=None,
    max_window_layers=28,
    attention_dropout=0.0,
    tie_word_embeddings=True,
    bos_token_id=151643,
    eos_token_id=151643,
    torch_dtype="bfloat16",
    transformers_version="4.57.0",
)
config_path = OUT / "config.json"
config_text = json.dumps(cfg, indent=2)
config_path.write_text(config_text)
print(f"wrote {config_path}")
# Copy whichever tokenizer files exist in the snapshot into the output
# directory. A missing file is reported but does not abort the run.
tokenizer_files = (
    "added_tokens.json",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "vocab.json",
)
for filename in tokenizer_files:
    source_file = SNAP / filename
    if not source_file.exists():
        print(f"MISSING tokenizer file {filename}")
        continue
    # resolve() dereferences symlinks before copying (presumably because the
    # snapshot entries are symlinks into the HF cache blob store; confirm).
    shutil.copy(str(source_file.resolve()), str(OUT / filename))
    print(f"copied {filename}")

print("done")