File size: 3,565 Bytes
39057fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
"""Extract dots.tts Qwen2 backbone (llm.* tensors) into a standard HF Qwen2 checkpoint.

Strips the `llm.` prefix so `llm.model.layers.N...` -> `model.layers.N...`, the
verbatim HF Qwen2 naming. Writes config.json + model.safetensors and copies the
tokenizer files so mlx_lm.load() / mlx_lm.convert() can read the dir directly.
"""
import json
import shutil
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file

# Source: local HF hub snapshot of rednote-hilab/dots.tts-soar.
SNAP = Path(
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
)
# Destination directory for the standalone HF Qwen2 checkpoint.
OUT = Path("/Users/samm/git/dots-mlx-spike/qwen2-backbone-hf")

src = SNAP / "model.safetensors"
# Backbone tensors in the combined checkpoint are stored under this prefix.
prefix = "llm."

# Fail fast, before creating any output, if the source checkpoint is missing;
# otherwise a wrong SNAP path leaves an empty OUT directory behind.
if not src.is_file():
    raise FileNotFoundError(f"source checkpoint not found: {src}")
OUT.mkdir(parents=True, exist_ok=True)

# Partition the combined checkpoint: each `llm.`-prefixed tensor is kept under
# its de-prefixed (plain HF Qwen2) name; every other tensor is only counted by
# its top-level name component so the dropped parts can be reported.
backbone = {}
other_prefixes = {}
with safe_open(str(src), framework="pt") as f:
    for tensor_name in f.keys():
        if not tensor_name.startswith(prefix):
            head = tensor_name.partition(".")[0]
            other_prefixes[head] = other_prefixes.get(head, 0) + 1
            continue
        backbone[tensor_name[len(prefix):]] = f.get_tensor(tensor_name)

print(f"extracted {len(backbone)} backbone tensors")
print("skipped non-backbone top-level prefixes:")
for top_name, count in sorted(other_prefixes.items()):
    print(f"  {top_name}: {count} tensors")

# Sanity: confirm the standard Qwen2 names are present. This uses an explicit
# raise rather than `assert` so the check still runs under `python -O`.
must_have = [
    "model.embed_tokens.weight",
    "model.norm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.bias",
    # Index 27 is the last layer of the 28-layer stack declared in the config.
    "model.layers.27.mlp.down_proj.weight",
]
for name in must_have:
    if name not in backbone:
        raise KeyError(f"MISSING expected tensor {name}")
    print(f"  ok {name} {tuple(backbone[name].shape)} {backbone[name].dtype}")

# With tie_word_embeddings=True the output head reuses embed_tokens, so no
# separate lm_head tensor is expected.
has_lm_head = any(k.startswith("lm_head") for k in backbone)
print(f"lm_head present in checkpoint: {has_lm_head} (tie_word_embeddings=true -> expect False)")

# Cast to the model's native dtype (bf16) for an honest baseline; the source
# blob is fp32. mlx will down/upcast on convert anyway.
import torch

# Convert entry by entry, rebinding each dict value, so each fp32 tensor can
# be released as soon as its bf16 copy exists. A dict comprehension would keep
# the complete fp32 dict and the complete bf16 dict alive together (~1.5x the
# fp32 footprint at peak). No tensor is bound to a loop name, so none is pinned
# after the loop. Non-floating tensors are left untouched.
for key in list(backbone):
    if backbone[key].is_floating_point():
        backbone[key] = backbone[key].to(torch.bfloat16)

save_file(backbone, str(OUT / "model.safetensors"), metadata={"format": "pt"})
print(f"wrote {OUT / 'model.safetensors'}")

# Build a clean Qwen2 config.json (values taken from llm_config.json, with
# generation cruft dropped). The values are hard-coded, so they are
# cross-checked against the extracted tensor shapes before writing. A mismatch
# would otherwise only surface as a shape error when the directory is loaded.
cfg = {
    "architectures": ["Qwen2ForCausalLM"],
    "model_type": "qwen2",
    "vocab_size": 151672,
    "hidden_size": 1536,
    "intermediate_size": 8960,
    "num_hidden_layers": 28,
    "num_attention_heads": 12,
    "num_key_value_heads": 2,
    "hidden_act": "silu",
    "max_position_embeddings": 131072,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-06,
    "use_cache": True,
    "rope_theta": 1000000.0,
    "rope_scaling": None,
    "use_sliding_window": False,
    "sliding_window": None,
    "max_window_layers": 28,
    "attention_dropout": 0.0,
    "tie_word_embeddings": True,
    "bos_token_id": 151643,
    "eos_token_id": 151643,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.57.0",
}


def _check_cfg_matches_weights(config, weights):
    """Raise ValueError if the hard-coded Qwen2 config disagrees with the weights.

    Compares the shapes of a few representative tensors (nn.Linear weights are
    stored as (out_features, in_features)) and the set of decoder-layer indices
    against the config. Every mismatch is collected and reported at once.

    Args:
        config: Qwen2 config dict (as written to config.json).
        weights: mapping of HF Qwen2 tensor names to tensors.

    Raises:
        ValueError: listing each disagreement found.
    """
    hidden = config["hidden_size"]
    head_dim = hidden // config["num_attention_heads"]
    expected_shapes = {
        "model.embed_tokens.weight": (config["vocab_size"], hidden),
        "model.layers.0.mlp.down_proj.weight": (hidden, config["intermediate_size"]),
        "model.layers.0.self_attn.k_proj.weight": (
            config["num_key_value_heads"] * head_dim,
            hidden,
        ),
    }
    problems = []
    for key, shape in expected_shapes.items():
        actual = tuple(weights[key].shape) if key in weights else None
        if actual != shape:
            problems.append(f"{key}: expected {shape}, got {actual}")

    # Keys look like `model.layers.<i>.<...>`; the layer index is component 2.
    layer_ids = {
        int(k.split(".")[2]) for k in weights if k.startswith("model.layers.")
    }
    n_layers = config["num_hidden_layers"]
    if layer_ids != set(range(n_layers)):
        problems.append(
            f"num_hidden_layers={n_layers} but weights contain {len(layer_ids)} "
            f"layer indices (max {max(layer_ids, default=-1)})"
        )

    if problems:
        raise ValueError(
            "config does not match extracted weights:\n  " + "\n  ".join(problems)
        )


_check_cfg_matches_weights(cfg, backbone)
(OUT / "config.json").write_text(json.dumps(cfg, indent=2))
print(f"wrote {OUT / 'config.json'}")

# Copy the tokenizer files next to the weights so the output directory is
# loadable on its own. `.resolve()` copies the real file behind each snapshot
# entry (presumably a symlink into the HF blob store). Absent files are only
# reported, not treated as fatal.
tokenizer_files = (
    "added_tokens.json",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "vocab.json",
)
for fname in tokenizer_files:
    source_file = SNAP / fname
    if not source_file.exists():
        print(f"MISSING tokenizer file {fname}")
        continue
    shutil.copy(str(source_file.resolve()), str(OUT / fname))
    print(f"copied {fname}")

print("done")