#!/usr/bin/env python
"""Extract the dots.tts Qwen2 backbone (llm.* tensors) into a standard HF Qwen2 checkpoint.

Strips the ``llm.`` prefix so ``llm.model.layers.N...`` -> ``model.layers.N...``,
the verbatim HF Qwen2 naming. Writes config.json + model.safetensors and copies
the tokenizer files so mlx_lm.load() / mlx_lm.convert() can read the output
directory directly.

Floating-point tensors are cast to bfloat16 (the model's native dtype) for an
honest baseline; the source blob is fp32 and mlx will down/upcast on convert
anyway.

Usage:
    python <this script> [--snapshot SNAPSHOT_DIR] [--out OUTPUT_DIR]
"""
import argparse
import json
import shutil
import sys
from collections import Counter
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Default locations (overridable from the command line).
SNAP = Path(
    "/Users/samm/.cache/huggingface/hub/models--rednote-hilab--dots.tts-soar/"
    "snapshots/1fd9452e55c2c9f38fe1a8ee09eaf7448c222d35"
)
OUT = Path("/Users/samm/git/dots-mlx-spike/qwen2-backbone-hf")

PREFIX = "llm."
TARGET_DTYPE = torch.bfloat16

# Standard Qwen2 tensor names that must exist once PREFIX is stripped. The
# layer-0 k_proj / down_proj entries are also read by _build_config.
MUST_HAVE = (
    "model.embed_tokens.weight",
    "model.norm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.bias",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.27.mlp.down_proj.weight",
)

# Cannot be recovered from tensor shapes: in Qwen2, q_proj is
# hidden_size x hidden_size whatever the head count.
NUM_ATTENTION_HEADS = 12

TOKENIZER_FILES = (
    "added_tokens.json",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "vocab.json",
)


def _parse_args(argv=None):
    """Parse CLI options; both paths default to the module-level SNAP / OUT."""
    parser = argparse.ArgumentParser(
        description="Extract the dots.tts Qwen2 backbone into an HF Qwen2 checkpoint."
    )
    parser.add_argument(
        "--snapshot",
        type=Path,
        default=SNAP,
        help="HF snapshot dir holding model.safetensors and tokenizer files (default: %(default)s)",
    )
    parser.add_argument(
        "--out",
        type=Path,
        default=OUT,
        help="output directory for the Qwen2 checkpoint (default: %(default)s)",
    )
    return parser.parse_args(argv)


def _load_backbone(src, prefix=PREFIX, dtype=TARGET_DTYPE):
    """Load every tensor whose key starts with *prefix* from the safetensors file *src*.

    Returns ``(weights, skipped, source_dtypes)``:
      * weights       -- dict of prefix-stripped key -> tensor; floating-point
                         tensors are cast to *dtype*, others are kept as-is.
      * skipped       -- Counter of top-level key prefix -> number of tensors
                         that did not match *prefix*.
      * source_dtypes -- Counter of the on-disk dtypes of the kept tensors.
    """
    weights = {}
    skipped = Counter()
    source_dtypes = Counter()
    with safe_open(str(src), framework="pt") as f:
        for key in f.keys():
            if not key.startswith(prefix):
                skipped[key.split(".", 1)[0]] += 1
                continue
            tensor = f.get_tensor(key)
            source_dtypes[str(tensor.dtype)] += 1
            # Cast tensor-by-tensor so a full-precision copy of the whole
            # backbone is never resident alongside the bf16 copy.
            weights[key[len(prefix):]] = (
                tensor.to(dtype) if tensor.is_floating_point() else tensor
            )
    return weights, skipped, source_dtypes


def _layer_indices(weights):
    """Return the sorted decoder-layer indices N found in ``model.layers.N.*`` keys."""
    return sorted(
        {int(k.split(".")[2]) for k in weights if k.startswith("model.layers.")}
    )


def _validate(weights):
    """Check the standard Qwen2 names exist and decoder layers are numbered 0..N-1.

    Raises ValueError instead of using ``assert``, which ``python -O`` strips.
    """
    missing = [name for name in MUST_HAVE if name not in weights]
    if missing:
        raise ValueError("MISSING expected tensor(s): " + ", ".join(missing))
    for name in MUST_HAVE:
        tensor = weights[name]
        print(f"  ok {name} {tuple(tensor.shape)} {tensor.dtype}")
    layers = _layer_indices(weights)
    if layers != list(range(len(layers))):
        raise ValueError(f"decoder layer indices are not contiguous from 0: {layers}")
    print(f"decoder layers: {len(layers)}")


def _build_config(weights, has_lm_head):
    """Build a clean HF Qwen2 config dict consistent with *weights*.

    Shape-determined fields are read from the tensors so config.json cannot
    drift from the checkpoint. The remaining hyper-parameters (head count,
    RoPE / norm settings, token ids) are fixed values for the dots.tts backbone.
    Word embeddings are declared tied only when the checkpoint has no lm_head.
    """
    # nn.Embedding weight: (num_embeddings, embedding_dim).
    vocab_size, hidden_size = weights["model.embed_tokens.weight"].shape
    # nn.Linear weight is (out_features, in_features); down_proj maps intermediate -> hidden.
    intermediate_size = weights["model.layers.0.mlp.down_proj.weight"].shape[1]
    num_layers = len(_layer_indices(weights))

    if hidden_size % NUM_ATTENTION_HEADS:
        raise ValueError(
            f"hidden_size {hidden_size} not divisible by {NUM_ATTENTION_HEADS} heads"
        )
    head_dim = hidden_size // NUM_ATTENTION_HEADS
    kv_rows = weights["model.layers.0.self_attn.k_proj.weight"].shape[0]
    if kv_rows % head_dim:
        raise ValueError(f"k_proj rows {kv_rows} not divisible by head_dim {head_dim}")

    return {
        "architectures": ["Qwen2ForCausalLM"],
        "model_type": "qwen2",
        "vocab_size": int(vocab_size),
        "hidden_size": int(hidden_size),
        "intermediate_size": int(intermediate_size),
        "num_hidden_layers": num_layers,
        "num_attention_heads": NUM_ATTENTION_HEADS,
        "num_key_value_heads": kv_rows // head_dim,
        "hidden_act": "silu",
        "max_position_embeddings": 131072,
        "initializer_range": 0.02,
        "rms_norm_eps": 1e-06,
        "use_cache": True,
        "rope_theta": 1000000.0,
        "rope_scaling": None,
        "use_sliding_window": False,
        "sliding_window": None,
        "max_window_layers": num_layers,
        "attention_dropout": 0.0,
        "tie_word_embeddings": not has_lm_head,
        "bos_token_id": 151643,
        "eos_token_id": 151643,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.57.0",
    }


def _copy_tokenizer(snap, out):
    """Best-effort copy of the tokenizer files; absent files are reported, not fatal."""
    for name in TOKENIZER_FILES:
        source = snap / name
        if source.exists():
            # resolve() dereferences the HF-cache symlink to the real blob.
            shutil.copy(str(source.resolve()), str(out / name))
            print(f"copied {name}")
        else:
            print(f"MISSING tokenizer file {name}")


def main(argv=None):
    """Run the extraction: load, validate, write weights + config, copy tokenizer."""
    args = _parse_args(argv)
    snap, out = args.snapshot, args.out

    src = snap / "model.safetensors"
    if not src.is_file():
        raise FileNotFoundError(f"source checkpoint not found: {src}")

    backbone, skipped, source_dtypes = _load_backbone(src)
    print(
        f"extracted {len(backbone)} backbone tensors "
        f"(source dtypes {dict(source_dtypes)}, floats cast to {TARGET_DTYPE})"
    )
    print("skipped non-backbone top-level prefixes:")
    for top, count in sorted(skipped.items()):
        print(f"  {top}: {count} tensors")

    _validate(backbone)

    has_lm_head = any(k.startswith("lm_head") for k in backbone)
    print(
        f"lm_head present in checkpoint: {has_lm_head} "
        f"(-> tie_word_embeddings={not has_lm_head})"
    )

    # Create the output dir only after the source has been validated, so a bad
    # snapshot path does not leave an empty directory behind.
    out.mkdir(parents=True, exist_ok=True)

    save_file(backbone, str(out / "model.safetensors"), metadata={"format": "pt"})
    print(f"wrote {out / 'model.safetensors'}")

    cfg = _build_config(backbone, has_lm_head)
    (out / "config.json").write_text(json.dumps(cfg, indent=2), encoding="utf-8")
    print(f"wrote {out / 'config.json'}")

    _copy_tokenizer(snap, out)
    print("done")
    return 0


if __name__ == "__main__":
    sys.exit(main())