File size: 21,972 Bytes
cfef4e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
# hrm_utils.py — Minimal, robust HRM loader + tokenizer support
# --------------------------------------------------------------
# - Handles .pt/.bin/.safetensors (single file or HF sharded index)
# - Adapts q/k/v names to torch.nn.MultiheadAttention format
# - Infers config if config.json is missing
# - Prefers checkpoint vocab_size over config to avoid shape mismatches
# - Optional tokenizer load (local files) + embedding resize + weight tying
# - Returns (model, tokenizer) when with_tokenizer=True (else just model)

import os, json, glob, math, inspect
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------- Blocks ----------------
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction) with a learned scale."""
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))
    def forward(self, x):
        # Normalize by the RMS over the last dimension, then scale.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)

class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sin/cos positional table; forward(L) returns a (1, L, d_model) slice."""
    def __init__(self, d_model, max_len=8192):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency ladder: even dims get sin, odd dims get cos.
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Not persistent: the table is deterministic, no need to checkpoint it.
        self.register_buffer("pe", table, persistent=False)
    def forward(self, L: int):
        return self.pe[:L].unsqueeze(0)

class SwiGLU(nn.Module):
    """Gated feed-forward: w3(silu(w1 x) * (w2 x)), followed by dropout."""
    def __init__(self, d_model, d_ff, pdrop=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.drop = nn.Dropout(pdrop)
    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.drop(self.w3(gate * value))

class AttnBlock(nn.Module):
    """Pre-norm transformer block: self-attention then SwiGLU MLP, each residual."""
    def __init__(self, d_model, n_heads, d_ff, pdrop=0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn  = nn.MultiheadAttention(d_model, n_heads, dropout=pdrop, batch_first=True)
        self.drop  = nn.Dropout(pdrop)
        self.norm2 = RMSNorm(d_model)
        self.mlp   = SwiGLU(d_model, d_ff, pdrop)
    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Masks, when supplied, must be 2-D boolean (MultiheadAttention contract).
        for mask in (attn_mask, key_padding_mask):
            if mask is not None:
                assert mask.dtype == torch.bool and mask.dim() == 2
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                need_weights=False)
        x = x + self.drop(attn_out)
        return x + self.drop(self.mlp(self.norm2(x)))

# ---------------- Model ----------------
class HRMForCausalLM(nn.Module):
    """Hierarchical Reasoning Model with a tied-embedding causal-LM head.

    A fast low-level block (L_mod) runs ``k_l_steps`` refinement steps per
    cycle; a slow high-level block (H_mod) then updates once per cycle.  An
    ACT-style halting head spends per-token probability mass over up to
    ``max_cycles`` cycles, and the output representation is the
    halting-weighted sum of the per-cycle H states.
    """
    def __init__(self, vocab_size: int, d_model=512, n_heads=8, d_ff=2048, dropout=0.1,
                 k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.vocab_size = vocab_size
        self.d_model    = d_model
        self.k_l_steps  = k_l_steps   # L-module refinement steps per cycle
        self.max_cycles = max_cycles  # hard cap on ACT cycles
        self.ponder_w   = ponder_loss_weight

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = SinusoidalPositionalEmbedding(d_model, max_len=8192)
        self.in_net  = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), RMSNorm(d_model))

        self.L_mod = AttnBlock(d_model, n_heads, d_ff, dropout)
        self.H_mod = AttnBlock(d_model, n_heads, d_ff, dropout)

        self.halt_head = nn.Linear(d_model, 1)
        # Negative bias => initial halt prob ~ sigmoid(-1.5) ≈ 0.18, so early
        # training tends to use several cycles instead of halting immediately.
        nn.init.constant_(self.halt_head.bias, -1.5)

        self.out_norm = RMSNorm(d_model)

        self.lm_head   = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # tie

        # Causal masks cached by (seq_len, device).
        # NOTE(review): grows unboundedly across distinct lengths/devices.
        self._cached_causal_bool = {}
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # N(0, 0.02) for linear/embedding weights, zero biases.  Runs after the
        # weight tie above; the shared tensor is just re-initialized in place.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def _causal_bool_mask(self, L: int, device):
        # Boolean upper-triangular mask (True = blocked), cached per (L, device).
        k = (L, device)
        if k not in self._cached_causal_bool:
            self._cached_causal_bool[k] = torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), 1)
        return self._cached_causal_bool[k]

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the ACT cycle loop.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L); positions equal to 0 are treated
                as padding and excluded from attention.
            labels: optional (B, L) ids for shifted next-token cross-entropy.

        Returns:
            dict with "logits" (B, L, vocab) plus "loss"/"lm_loss"/"ponder_loss"
            (all None when labels is None).
        """
        B, L = input_ids.shape
        device = input_ids.device
        x_tok = self.tok_emb(input_ids)
        pos   = self.pos_emb(L).to(device=device, dtype=x_tok.dtype)  # keep dtype aligned
        x     = self.in_net(x_tok + pos)

        causal_bool = self._causal_bool_mask(L, device)
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        # Low-level state starts at the encoded input; high-level state at zero.
        z_L = x.clone()
        z_H = torch.zeros_like(x)

        eps = 1e-6
        # rema: per-token halting mass not yet spent (ACT "remainder").
        rema = torch.ones((B, L), device=device, dtype=x_tok.dtype)
        collected_H = torch.zeros_like(z_H)
        ponder_terms = []

        for c in range(self.max_cycles):
            # Fast module refines k_l_steps times, conditioned on z_H and the input.
            for _ in range(self.k_l_steps):
                z_L = self.L_mod(z_L + z_H + x, attn_mask=causal_bool, key_padding_mask=key_padding_mask)
            # Slow module updates once per cycle from the refined low-level state.
            z_H = self.H_mod(z_H + z_L, attn_mask=causal_bool, key_padding_mask=key_padding_mask)

            p_halt = torch.sigmoid(self.halt_head(z_H)).squeeze(-1).clamp(eps, 1 - eps)
            # On the final cycle force halting so all remaining mass is spent.
            last = torch.full_like(p_halt, fill_value=(c == self.max_cycles - 1), dtype=torch.bool)
            halt_p = torch.where(last, torch.ones_like(p_halt), p_halt)

            # Mass spent this cycle weights this cycle's H-state contribution.
            contrib = (rema * halt_p).unsqueeze(-1)
            collected_H = collected_H + contrib * z_H

            ponder_terms.append(rema * halt_p)
            rema = rema * (1.0 - halt_p)
            # Early exit once every token has effectively halted.
            if torch.all(rema < 1e-4):
                break

        collected_H = self.out_norm(collected_H)
        logits = self.lm_head(collected_H)

        loss = lm_loss = ponder = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1; CE in fp32.
            sl = logits[:, :-1, :].contiguous()
            y  = labels[:, 1:].contiguous()
            B_, Lm1, V = sl.shape
            lm_loss = F.cross_entropy(sl.float().view(B_ * Lm1, V), y.view(B_ * Lm1))
            # NOTE(review): this sums the *spent* halting mass (≈1 once mass is
            # exhausted) rather than counting cycles — confirm against the
            # intended ACT ponder cost.
            ponder  = torch.stack(ponder_terms, dim=-1).sum(dim=-1).mean()
            loss    = lm_loss + self.ponder_w * ponder

        return {"loss": loss, "logits": logits, "lm_loss": lm_loss, "ponder_loss": ponder}

    # ---- HF-style hooks ----
    def get_input_embeddings(self):
        # HF-compatible accessor for the token embedding module.
        return self.tok_emb
    def set_input_embeddings(self, new_emb):
        # Swap the embedding and re-tie the LM head to the new weight.
        self.tok_emb = new_emb
        if hasattr(self, "lm_head"):
            self.lm_head.weight = self.tok_emb.weight
    def tie_weights(self):
        # Idempotent re-tie of lm_head <- tok_emb (HF convention).
        if hasattr(self, "lm_head") and hasattr(self, "tok_emb"):
            self.lm_head.weight = self.tok_emb.weight

# -------------- Loader helpers --------------
def _resolve_device(device: Optional[str]) -> torch.device:
    if device is None or device == "auto":
        if torch.cuda.is_available(): return torch.device("cuda")
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps")
        return torch.device("cpu")
    return torch.device(device)

def _resolve_dtype(dtype: str) -> torch.dtype:
    d = str(dtype).lower()
    if d in ("fp32","float32","f32"): return torch.float32
    if d in ("bf16","bfloat16"):      return torch.bfloat16
    if d in ("fp16","float16","half"):return torch.float16
    if d == "auto":
        if torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)(): return torch.bfloat16
        return torch.float32
    raise ValueError(f"Unknown dtype {dtype}")

def _find_checkpoint(path_or_dir: str) -> str:
    if os.path.isfile(path_or_dir): return path_or_dir
    if not os.path.isdir(path_or_dir): raise FileNotFoundError(f"Not a file or directory: {path_or_dir}")
    st = glob.glob(os.path.join(path_or_dir, "*.safetensors"))
    if len(st) == 1: return st[0]
    if len(st) > 1:
        for cand in ("model.safetensors","pytorch_model.safetensors"):
            p = os.path.join(path_or_dir, cand)
            if os.path.exists(p): return p
        return sorted(st)[0]
    for idx in ("model.safetensors.index.json","pytorch_model.bin.index.json"):
        p = os.path.join(path_or_dir, idx)
        if os.path.exists(p): return p
    for cand in ("pytorch_model.bin","model.bin","model.pt"):
        p = os.path.join(path_or_dir, cand)
        if os.path.exists(p): return p
    pt = glob.glob(os.path.join(path_or_dir, "*.pt")) + glob.glob(os.path.join(path_or_dir, "*.bin"))
    if pt: return sorted(pt)[0]
    raise FileNotFoundError(f"No checkpoint found in {path_or_dir}")

def _torch_load(path: str):
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        return torch.load(path, map_location="cpu")

def _normalize_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    def strip(k: str) -> str:
        for pref in ("module.","model.","transformer."):
            if k.startswith(pref): return k[len(pref):]
        return k
    return {strip(k): v for k, v in sd.items()}

def _adapt_attention_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    sd = dict(sd)
    def handle(prefix: str):
        qkv_w = sd.pop(f"{prefix}.qkv.weight", None)
        if qkv_w is not None:
            sd[f"{prefix}.in_proj_weight"] = qkv_w
        qkv_b = sd.pop(f"{prefix}.qkv.bias", None)
        if qkv_b is not None:
            sd[f"{prefix}.in_proj_bias"] = qkv_b

        q_w = sd.pop(f"{prefix}.q_proj.weight", None)
        k_w = sd.pop(f"{prefix}.k_proj.weight", None)
        v_w = sd.pop(f"{prefix}.v_proj.weight", None)
        if q_w is not None and k_w is not None and v_w is not None:
            sd[f"{prefix}.in_proj_weight"] = torch.cat([q_w, k_w, v_w], dim=0)

        q_b = sd.pop(f"{prefix}.q_proj.bias", None)
        k_b = sd.pop(f"{prefix}.k_proj.bias", None)
        v_b = sd.pop(f"{prefix}.v_proj.bias", None)
        if q_b is not None and k_b is not None and v_b is not None:
            sd[f"{prefix}.in_proj_bias"] = torch.cat([q_b, k_b, v_b], dim=0)

        o_w = sd.pop(f"{prefix}.o.weight", None)
        if o_w is not None:
            sd[f"{prefix}.out_proj.weight"] = o_w
        o_b = sd.pop(f"{prefix}.o.bias", None)
        if o_b is not None:
            sd[f"{prefix}.out_proj.bias"] = o_b

        if f"{prefix}.in_proj_weight" in sd and f"{prefix}.in_proj_bias" not in sd:
            E = sd[f"{prefix}.in_proj_weight"].shape[1]
            sd[f"{prefix}.in_proj_bias"] = torch.zeros(3 * E, dtype=sd[f"{prefix}.in_proj_weight"].dtype)
    for blk in ("L_mod.attn", "H_mod.attn"):
        handle(blk)
    return sd

def _load_state_dict(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Load a state dict from any supported checkpoint format, keys normalized.

    Supports a single .safetensors file, HF sharded safetensors / .bin index
    JSONs, and plain .pt/.bin torch files (unwrapping a "state_dict" entry).
    The suffix checks are ordered so index JSONs are matched before the bare
    .json rejection.
    """
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file as safe_load
        return _normalize_keys(safe_load(ckpt_path, device="cpu"))

    if ckpt_path.endswith("model.safetensors.index.json"):
        base = os.path.dirname(ckpt_path)
        with open(ckpt_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        from safetensors import safe_open
        merged: Dict[str, torch.Tensor] = {}
        for shard in sorted(set(index.get("weight_map", {}).values())):
            with safe_open(os.path.join(base, shard), framework="pt", device="cpu") as sf:
                for key in sf.keys():
                    merged[key] = sf.get_tensor(key)
        return _normalize_keys(merged)

    if ckpt_path.endswith("pytorch_model.bin.index.json"):
        base = os.path.dirname(ckpt_path)
        with open(ckpt_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        merged = {}
        for shard in sorted(set(index.get("weight_map", {}).values())):
            part = _torch_load(os.path.join(base, shard))
            if isinstance(part, dict) and "state_dict" in part:
                part = part["state_dict"]
            merged.update(part)
        return _normalize_keys(merged)

    if ckpt_path.endswith((".pt", ".bin")):
        payload = _torch_load(ckpt_path)
        if isinstance(payload, dict) and "state_dict" in payload:
            payload = payload["state_dict"]
        return _normalize_keys(payload)

    if ckpt_path.endswith(".json"):
        raise ValueError("Pass the directory, not the index/config JSON.")
    raise ValueError(f"Unsupported checkpoint type: {ckpt_path}")

def _load_config_if_any(path_or_dir: str) -> Optional[Dict[str, Any]]:
    p = path_or_dir if path_or_dir.endswith(".json") else os.path.join(path_or_dir, "config.json")
    if os.path.exists(p):
        with open(p, "r", encoding="utf-8") as f:
            return json.load(f)
    return None

def _infer_config_from_state(sd: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    te = sd.get("tok_emb.weight", None)
    if te is None:
        te = sd.get("lm_head.weight", None)
    if te is None:
        raise ValueError("Cannot infer config: missing tok_emb.weight (or lm_head.weight).")
    vocab_size, d_model = te.shape
    w1 = sd.get("L_mod.mlp.w1.weight", None)
    if w1 is None:
        w1 = sd.get("H_mod.mlp.w1.weight", None)
    d_ff = int(w1.shape[0]) if w1 is not None else int(4 * d_model)
    return dict(vocab_size=int(vocab_size), d_model=int(d_model), n_heads=8, d_ff=int(d_ff),
                dropout=0.1, k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2)

_ALLOWED_KW = {"vocab_size","d_model","n_heads","d_ff","dropout","k_l_steps","max_cycles","ponder_loss_weight"}
_DROP_KEYS = {"weight_tying","tie_word_embeddings","torch_dtype","architectures","model_type",
              "initializer_range","layer_norm_eps","max_position_embeddings","use_cache"}

def _sanitize_and_map_config(raw_cfg: Dict[str, Any], ModelCls):
    cfg = dict(raw_cfg) if raw_cfg else {}
    for src, dst in {"hidden_size":"d_model","num_attention_heads":"n_heads","intermediate_size":"d_ff"}.items():
        if src in cfg and dst not in cfg:
            cfg[dst] = cfg[src]
    if "vocab_size" not in cfg and raw_cfg and "vocab_size" in raw_cfg:
        cfg["vocab_size"] = raw_cfg["vocab_size"]
    for k in list(cfg.keys()):
        if k in _DROP_KEYS:
            cfg.pop(k, None)
    cfg = {k: v for k, v in cfg.items() if k in _ALLOWED_KW}
    allowed = set(inspect.signature(ModelCls.__init__).parameters.keys()) - {"self"}
    cfg = {k: v for k, v in cfg.items() if k in allowed}
    return cfg

def _complete_and_filter_for_model(sd: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
    sd2 = dict(sd)
    msd = model.state_dict()
    for blk in ("L_mod.attn", "H_mod.attn"):
        ipw = f"{blk}.in_proj_weight"
        ipb = f"{blk}.in_proj_bias"
        if ipw in sd2 and ipb not in sd2 and ipb in msd:
            E = sd2[ipw].shape[1]
            sd2[ipb] = torch.zeros(3 * E, dtype=sd2[ipw].dtype)
        opw = f"{blk}.out_proj.weight"
        opb = f"{blk}.out_proj.bias"
        if opw in sd2 and opb not in sd2 and opb in msd:
            out_dim = msd[opb].shape[0]
            sd2[opb] = torch.zeros(out_dim, dtype=sd2[opw].dtype)
    # Drop unknown or mismatched-shape keys
    sd2 = {k: v for k, v in sd2.items() if (k in msd) and (tuple(v.shape) == tuple(msd[k].shape))}
    return sd2

# -------------- Tokenizer helpers --------------
def _load_local_tokenizer(tok_dir: str):
    """Best-effort tokenizer load from local files only; returns None on failure.

    Fallback chain: AutoTokenizer -> raw tokenizer.json via `tokenizers` ->
    GPT-2 style vocab.json + merges.txt.  Every failure is printed, never
    raised, so the caller can proceed without a tokenizer.
    """
    tok = None
    try:
        from transformers import AutoTokenizer, PreTrainedTokenizerFast, GPT2TokenizerFast
        # 1) Preferred path: let transformers resolve everything locally.
        try:
            tok = AutoTokenizer.from_pretrained(tok_dir, local_files_only=True, use_fast=True, trust_remote_code=True)
            return tok
        except Exception as e:
            print(f"[hrm_loader] AutoTokenizer fallback: {e}")
        # 2) Raw tokenizer.json wrapped in PreTrainedTokenizerFast.
        tj = os.path.join(tok_dir, "tokenizer.json")
        if tok is None and os.path.exists(tj):
            try:
                from tokenizers import Tokenizer
                core = Tokenizer.from_file(tj)
                spec_path = os.path.join(tok_dir, "special_tokens_map.json")
                spec = {}
                if os.path.exists(spec_path):
                    with open(spec_path, "r", encoding="utf-8") as f:
                        spec = json.load(f)
                # Only plain-string entries map cleanly onto kwargs like bos_token=...;
                # dict-valued entries (AddedToken specs) are skipped.
                tok = PreTrainedTokenizerFast(tokenizer_object=core, **{k:v for k,v in spec.items() if isinstance(v,str)})
                return tok
            except Exception as e:
                print(f"[hrm_loader] tokenizer.json fallback failed: {e}")
        # 3) GPT-2 style vocab + merges pair.
        vv = os.path.join(tok_dir, "vocab.json")
        mm = os.path.join(tok_dir, "merges.txt")
        if tok is None and os.path.exists(vv) and os.path.exists(mm):
            try:
                tok = GPT2TokenizerFast(vocab_file=vv, merges_file=mm)
                spec_path = os.path.join(tok_dir, "special_tokens_map.json")
                if os.path.exists(spec_path):
                    with open(spec_path, "r", encoding="utf-8") as f:
                        spec = json.load(f)
                    # Re-register the standard special tokens if the map provides them.
                    st = {k: spec[k] for k in ["bos_token","eos_token","unk_token","pad_token","sep_token","cls_token","mask_token"] if k in spec}
                    if st:
                        tok.add_special_tokens(st)
                return tok
            except Exception as e:
                print(f"[hrm_loader] GPT2TokenizerFast fallback failed: {e}")
    except Exception as e:
        # transformers (or tokenizers) is not installed / import failed entirely.
        print(f"[hrm_loader] transformers/tokenizers unavailable or failed: {e}")
    return tok

def _maybe_resize_embeddings_(model: nn.Module, vocab_size_new: int):
    vocab_size_old = model.tok_emb.num_embeddings
    if vocab_size_new == vocab_size_old:
        return
    device = next(model.parameters()).device
    dtype  = next(model.parameters()).dtype
    d_model = model.d_model
    old_w = model.tok_emb.weight.data.detach().to(device=device, dtype=dtype)
    new_emb = nn.Embedding(vocab_size_new, d_model, device=device, dtype=dtype)
    nn.init.normal_(new_emb.weight, mean=0.0, std=0.02)
    keep = min(vocab_size_old, vocab_size_new)
    new_emb.weight.data[:keep] = old_w[:keep]
    model.tok_emb = new_emb
    new_head = nn.Linear(d_model, vocab_size_new, bias=False, device=device, dtype=dtype)
    model.lm_head = new_head
    model.lm_head.weight = model.tok_emb.weight
    print(f"[hrm_loader] resized embeddings: {vocab_size_old} -> {vocab_size_new}")

def _vocab_from_sd(sd: Dict[str, torch.Tensor]) -> Optional[int]:
    te = sd.get("tok_emb.weight", None)
    if te is None:
        te = sd.get("lm_head.weight", None)
    return int(te.shape[0]) if te is not None else None

# -------------- Public loader --------------
def load_hrm(
    checkpoint_or_dir: str,
    device: Optional[str] = "auto",
    dtype: str = "auto",
    strict: bool = True,
    override_config: Optional[Dict[str, Any]] = None,
    ModelCls=None,
    with_tokenizer: bool = False,
    tokenizer_path: Optional[str] = None,
):
    """Build an HRM model from a local checkpoint (file or directory).

    Steps: locate + load the state dict, adapt attention key names, read or
    infer a config (checkpoint vocab_size wins over config), construct the
    model, load the filtered weights, move to device/dtype, re-tie the LM
    head, and optionally load a local tokenizer (resizing embeddings to it).

    Args:
        checkpoint_or_dir: weights file, or a directory containing weights
            (and optionally config.json / tokenizer files).
        device: "auto" (cuda > mps > cpu), None, or an explicit device string.
        dtype: "auto", "fp32", "bf16", "fp16" (and their aliases).
        strict: raise if any keys are still missing/unexpected after filtering.
        override_config: values merged over the loaded/inferred config.
        ModelCls: model class to instantiate (defaults to HRMForCausalLM).
        with_tokenizer: also load a tokenizer; changes the return shape.
        tokenizer_path: where to look for tokenizer files (defaults to the
            config directory).

    Returns:
        model, or (model, tokenizer) when with_tokenizer=True.  The tokenizer
        slot may be None if none could be loaded.
    """
    if ModelCls is None:
        ModelCls = HRMForCausalLM

    ckpt = _find_checkpoint(checkpoint_or_dir)
    sd = _load_state_dict(ckpt)
    sd = _adapt_attention_keys(sd)

    # NEW: If lm_head.weight is absent but tok_emb.weight exists (tied-weights checkpoint),
    # mirror it to avoid "missing lm_head.weight" in load_state_dict.
    if "lm_head.weight" not in sd and "tok_emb.weight" in sd:
        sd["lm_head.weight"] = sd["tok_emb.weight"]

    # Config: config.json next to the checkpoint if present, else shape-inferred.
    cfg_dir = checkpoint_or_dir if os.path.isdir(checkpoint_or_dir) else os.path.dirname(ckpt)
    raw_cfg = _load_config_if_any(cfg_dir) or _infer_config_from_state(sd)
    if override_config:
        raw_cfg.update(override_config)
    cfg = _sanitize_and_map_config(raw_cfg, ModelCls)

    # Prefer checkpoint vocab_size to avoid size mismatches
    sd_vocab = _vocab_from_sd(sd)
    if sd_vocab is not None and (cfg.get("vocab_size") is None or cfg["vocab_size"] != sd_vocab):
        print(f"[hrm_loader] adjusting vocab_size config {cfg.get('vocab_size')} -> {sd_vocab} from checkpoint")
        cfg["vocab_size"] = sd_vocab

    dev = _resolve_device(device)
    dt  = _resolve_dtype(dtype)

    model = ModelCls(**cfg)
    sd = _complete_and_filter_for_model(sd, model)

    # Load weights (safe: shapes now match)
    ik = model.load_state_dict(sd, strict=False)
    missing = list(getattr(ik, "missing_keys", []))
    unexpected = list(getattr(ik, "unexpected_keys", []))
    if missing or unexpected:
        print(f"[hrm_loader] load_state_dict: missing={len(missing)} unexpected={len(unexpected)}")
        if missing:   print("  missing (sample):", missing[:8])
        if unexpected:print("  unexpected (sample):", unexpected[:8])
        if strict:
            raise RuntimeError(
                "Strict load requested but state_dict mismatch remains.\n"
                f"Missing (n={len(missing)}): {missing[:12]}\n"
                f"Unexpected (n={len(unexpected)}): {unexpected[:12]}"
            )

    model.to(dev)
    if dt != torch.float32:
        model.to(dtype=dt)  # parameters + buffers

    # Re-tie in case load_state_dict replaced either side of the shared weight.
    try:
        if hasattr(model, "lm_head") and hasattr(model, "tok_emb") and model.lm_head.weight is not model.tok_emb.weight:
            model.lm_head.weight = model.tok_emb.weight
    except Exception:
        pass

    model.eval()

    tokenizer = None
    if with_tokenizer:
        tdir = tokenizer_path or cfg_dir
        tokenizer = _load_local_tokenizer(tdir)
        if tokenizer is None:
            print(f"[hrm_loader] WARNING: could not load tokenizer from {tdir}")
        else:
            try:
                # Grow/shrink embeddings so every id from this tokenizer is valid.
                _maybe_resize_embeddings_(model, len(tokenizer))
            except Exception as e:
                print(f"[hrm_loader] embedding resize check failed: {e}")

    return (model, tokenizer) if with_tokenizer else model

__all__ = ["HRMForCausalLM", "load_hrm"]