# hrm_utils.py — Minimal, robust HRM loader + tokenizer support # -------------------------------------------------------------- # - Handles .pt/.bin/.safetensors (single file or HF sharded index) # - Adapts q/k/v names to torch.nn.MultiheadAttention format # - Infers config if config.json is missing # - Prefers checkpoint vocab_size over config to avoid shape mismatches # - Optional tokenizer load (local files) + embedding resize + weight tying # - Returns (model, tokenizer) when with_tokenizer=True (else just model) import os, json, glob, math, inspect from typing import Optional, Dict, Any import torch import torch.nn as nn import torch.nn.functional as F # ---------------- Blocks ---------------- class RMSNorm(nn.Module): def __init__(self, d, eps=1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(d)) def forward(self, x): return self.weight * (x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)) class SinusoidalPositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=8192): super().__init__() pe = torch.zeros(max_len, d_model) pos = torch.arange(0, max_len).unsqueeze(1) div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(pos * div) pe[:, 1::2] = torch.cos(pos * div) self.register_buffer("pe", pe, persistent=False) def forward(self, L: int): return self.pe[:L].unsqueeze(0) class SwiGLU(nn.Module): def __init__(self, d_model, d_ff, pdrop=0.1): super().__init__() self.w1 = nn.Linear(d_model, d_ff, bias=False) self.w2 = nn.Linear(d_model, d_ff, bias=False) self.w3 = nn.Linear(d_ff, d_model, bias=False) self.drop = nn.Dropout(pdrop) def forward(self, x): return self.drop(self.w3(F.silu(self.w1(x)) * self.w2(x))) class AttnBlock(nn.Module): def __init__(self, d_model, n_heads, d_ff, pdrop=0.1): super().__init__() self.norm1 = RMSNorm(d_model) self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=pdrop, batch_first=True) self.drop = nn.Dropout(pdrop) self.norm2 = RMSNorm(d_model) self.mlp = SwiGLU(d_model, d_ff, pdrop) def forward(self, x, attn_mask=None, key_padding_mask=None): if attn_mask is not None: assert attn_mask.dtype == torch.bool and attn_mask.dim() == 2 if key_padding_mask is not None: assert key_padding_mask.dtype == torch.bool and key_padding_mask.dim() == 2 h = self.norm1(x) a, _ = self.attn(h, h, h, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) x = x + self.drop(a) x = x + self.drop(self.mlp(self.norm2(x))) return x # ---------------- Model ---------------- class HRMForCausalLM(nn.Module): def __init__(self, vocab_size: int, d_model=512, n_heads=8, d_ff=2048, dropout=0.1, k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2): super().__init__() assert d_model % n_heads == 0, "d_model must be divisible by n_heads" self.vocab_size = vocab_size self.d_model = d_model self.k_l_steps = k_l_steps self.max_cycles = max_cycles self.ponder_w = ponder_loss_weight self.tok_emb = nn.Embedding(vocab_size, d_model) self.pos_emb = SinusoidalPositionalEmbedding(d_model, max_len=8192) self.in_net = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), RMSNorm(d_model)) self.L_mod = AttnBlock(d_model, n_heads, d_ff, dropout) self.H_mod = AttnBlock(d_model, n_heads, d_ff, dropout) self.halt_head = nn.Linear(d_model, 1) nn.init.constant_(self.halt_head.bias, -1.5) self.out_norm = RMSNorm(d_model) self.lm_head = nn.Linear(d_model, vocab_size, bias=False) self.lm_head.weight = self.tok_emb.weight # tie self._cached_causal_bool = {} self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Linear, nn.Embedding)): nn.init.normal_(m.weight, mean=0.0, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.zeros_(m.bias) def _causal_bool_mask(self, L: int, device): k = (L, device) if k not in self._cached_causal_bool: self._cached_causal_bool[k] = torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), 1) return self._cached_causal_bool[k] def forward(self, input_ids, attention_mask=None, labels=None): B, L = input_ids.shape device = input_ids.device x_tok = self.tok_emb(input_ids) pos = self.pos_emb(L).to(device=device, dtype=x_tok.dtype) # keep dtype aligned x = self.in_net(x_tok + pos) causal_bool = self._causal_bool_mask(L, device) key_padding_mask = (attention_mask == 0) if attention_mask is not None else None z_L = x.clone() z_H = torch.zeros_like(x) eps = 1e-6 rema = torch.ones((B, L), device=device, dtype=x_tok.dtype) collected_H = torch.zeros_like(z_H) ponder_terms = [] for c in range(self.max_cycles): for _ in range(self.k_l_steps): z_L = self.L_mod(z_L + z_H + x, attn_mask=causal_bool, key_padding_mask=key_padding_mask) z_H = self.H_mod(z_H + z_L, attn_mask=causal_bool, key_padding_mask=key_padding_mask) p_halt = torch.sigmoid(self.halt_head(z_H)).squeeze(-1).clamp(eps, 1 - eps) last = torch.full_like(p_halt, fill_value=(c == self.max_cycles - 1), dtype=torch.bool) halt_p = torch.where(last, torch.ones_like(p_halt), p_halt) contrib = (rema * halt_p).unsqueeze(-1) collected_H = collected_H + contrib * z_H ponder_terms.append(rema * halt_p) rema = rema * (1.0 - halt_p) if torch.all(rema < 1e-4): break collected_H = self.out_norm(collected_H) logits = self.lm_head(collected_H) loss = lm_loss = ponder = None if labels is not None: sl = logits[:, :-1, :].contiguous() y = labels[:, 1:].contiguous() B_, Lm1, V = sl.shape lm_loss = F.cross_entropy(sl.float().view(B_ * Lm1, V), y.view(B_ * Lm1)) ponder = torch.stack(ponder_terms, dim=-1).sum(dim=-1).mean() loss = lm_loss + self.ponder_w * ponder return {"loss": loss, "logits": logits, "lm_loss": lm_loss, "ponder_loss": ponder} # ---- HF-style hooks ---- def get_input_embeddings(self): return self.tok_emb def set_input_embeddings(self, new_emb): self.tok_emb = new_emb if hasattr(self, "lm_head"): self.lm_head.weight = self.tok_emb.weight def tie_weights(self): if hasattr(self, "lm_head") and hasattr(self, "tok_emb"): self.lm_head.weight = self.tok_emb.weight # -------------- Loader helpers -------------- def _resolve_device(device: Optional[str]) -> torch.device: if device is None or device == "auto": if torch.cuda.is_available(): return torch.device("cuda") if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") return torch.device(device) def _resolve_dtype(dtype: str) -> torch.dtype: d = str(dtype).lower() if d in ("fp32","float32","f32"): return torch.float32 if d in ("bf16","bfloat16"): return torch.bfloat16 if d in ("fp16","float16","half"):return torch.float16 if d == "auto": if torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)(): return torch.bfloat16 return torch.float32 raise ValueError(f"Unknown dtype {dtype}") def _find_checkpoint(path_or_dir: str) -> str: if os.path.isfile(path_or_dir): return path_or_dir if not os.path.isdir(path_or_dir): raise FileNotFoundError(f"Not a file or directory: {path_or_dir}") st = glob.glob(os.path.join(path_or_dir, "*.safetensors")) if len(st) == 1: return st[0] if len(st) > 1: for cand in ("model.safetensors","pytorch_model.safetensors"): p = os.path.join(path_or_dir, cand) if os.path.exists(p): return p return sorted(st)[0] for idx in ("model.safetensors.index.json","pytorch_model.bin.index.json"): p = os.path.join(path_or_dir, idx) if os.path.exists(p): return p for cand in ("pytorch_model.bin","model.bin","model.pt"): p = os.path.join(path_or_dir, cand) if os.path.exists(p): return p pt = glob.glob(os.path.join(path_or_dir, "*.pt")) + glob.glob(os.path.join(path_or_dir, "*.bin")) if pt: return sorted(pt)[0] raise FileNotFoundError(f"No checkpoint found in {path_or_dir}") def _torch_load(path: str): try: return torch.load(path, map_location="cpu", weights_only=True) except TypeError: return torch.load(path, map_location="cpu") def _normalize_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def strip(k: str) -> str: for pref in ("module.","model.","transformer."): if k.startswith(pref): return k[len(pref):] return k return {strip(k): v for k, v in sd.items()} def _adapt_attention_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: sd = dict(sd) def handle(prefix: str): qkv_w = sd.pop(f"{prefix}.qkv.weight", None) if qkv_w is not None: sd[f"{prefix}.in_proj_weight"] = qkv_w qkv_b = sd.pop(f"{prefix}.qkv.bias", None) if qkv_b is not None: sd[f"{prefix}.in_proj_bias"] = qkv_b q_w = sd.pop(f"{prefix}.q_proj.weight", None) k_w = sd.pop(f"{prefix}.k_proj.weight", None) v_w = sd.pop(f"{prefix}.v_proj.weight", None) if q_w is not None and k_w is not None and v_w is not None: sd[f"{prefix}.in_proj_weight"] = torch.cat([q_w, k_w, v_w], dim=0) q_b = sd.pop(f"{prefix}.q_proj.bias", None) k_b = sd.pop(f"{prefix}.k_proj.bias", None) v_b = sd.pop(f"{prefix}.v_proj.bias", None) if q_b is not None and k_b is not None and v_b is not None: sd[f"{prefix}.in_proj_bias"] = torch.cat([q_b, k_b, v_b], dim=0) o_w = sd.pop(f"{prefix}.o.weight", None) if o_w is not None: sd[f"{prefix}.out_proj.weight"] = o_w o_b = sd.pop(f"{prefix}.o.bias", None) if o_b is not None: sd[f"{prefix}.out_proj.bias"] = o_b if f"{prefix}.in_proj_weight" in sd and f"{prefix}.in_proj_bias" not in sd: E = sd[f"{prefix}.in_proj_weight"].shape[1] sd[f"{prefix}.in_proj_bias"] = torch.zeros(3 * E, dtype=sd[f"{prefix}.in_proj_weight"].dtype) for blk in ("L_mod.attn", "H_mod.attn"): handle(blk) return sd def _load_state_dict(ckpt_path: str) -> Dict[str, torch.Tensor]: if ckpt_path.endswith(".safetensors"): from safetensors.torch import load_file as safe_load return _normalize_keys(safe_load(ckpt_path, device="cpu")) if ckpt_path.endswith("model.safetensors.index.json"): base = os.path.dirname(ckpt_path) with open(ckpt_path, "r", encoding="utf-8") as f: idx = json.load(f) from safetensors import safe_open state = {} for shard in sorted(set(idx.get("weight_map", {}).values())): with safe_open(os.path.join(base, shard), framework="pt", device="cpu") as sf: for k in sf.keys(): state[k] = sf.get_tensor(k) return _normalize_keys(state) if ckpt_path.endswith("pytorch_model.bin.index.json"): base = os.path.dirname(ckpt_path) with open(ckpt_path, "r", encoding="utf-8") as f: idx = json.load(f) state = {} for shard in sorted(set(idx.get("weight_map", {}).values())): part = _torch_load(os.path.join(base, shard)) if isinstance(part, dict) and "state_dict" in part: part = part["state_dict"] state.update(part) return _normalize_keys(state) if ckpt_path.endswith((".pt",".bin")): obj = _torch_load(ckpt_path) if isinstance(obj, dict) and "state_dict" in obj: obj = obj["state_dict"] return _normalize_keys(obj) if ckpt_path.endswith(".json"): raise ValueError("Pass the directory, not the index/config JSON.") raise ValueError(f"Unsupported checkpoint type: {ckpt_path}") def _load_config_if_any(path_or_dir: str) -> Optional[Dict[str, Any]]: p = path_or_dir if path_or_dir.endswith(".json") else os.path.join(path_or_dir, "config.json") if os.path.exists(p): with open(p, "r", encoding="utf-8") as f: return json.load(f) return None def _infer_config_from_state(sd: Dict[str, torch.Tensor]) -> Dict[str, Any]: te = sd.get("tok_emb.weight", None) if te is None: te = sd.get("lm_head.weight", None) if te is None: raise ValueError("Cannot infer config: missing tok_emb.weight (or lm_head.weight).") vocab_size, d_model = te.shape w1 = sd.get("L_mod.mlp.w1.weight", None) if w1 is None: w1 = sd.get("H_mod.mlp.w1.weight", None) d_ff = int(w1.shape[0]) if w1 is not None else int(4 * d_model) return dict(vocab_size=int(vocab_size), d_model=int(d_model), n_heads=8, d_ff=int(d_ff), dropout=0.1, k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2) _ALLOWED_KW = {"vocab_size","d_model","n_heads","d_ff","dropout","k_l_steps","max_cycles","ponder_loss_weight"} _DROP_KEYS = {"weight_tying","tie_word_embeddings","torch_dtype","architectures","model_type", "initializer_range","layer_norm_eps","max_position_embeddings","use_cache"} def _sanitize_and_map_config(raw_cfg: Dict[str, Any], ModelCls): cfg = dict(raw_cfg) if raw_cfg else {} for src, dst in {"hidden_size":"d_model","num_attention_heads":"n_heads","intermediate_size":"d_ff"}.items(): if src in cfg and dst not in cfg: cfg[dst] = cfg[src] if "vocab_size" not in cfg and raw_cfg and "vocab_size" in raw_cfg: cfg["vocab_size"] = raw_cfg["vocab_size"] for k in list(cfg.keys()): if k in _DROP_KEYS: cfg.pop(k, None) cfg = {k: v for k, v in cfg.items() if k in _ALLOWED_KW} allowed = set(inspect.signature(ModelCls.__init__).parameters.keys()) - {"self"} cfg = {k: v for k, v in cfg.items() if k in allowed} return cfg def _complete_and_filter_for_model(sd: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: sd2 = dict(sd) msd = model.state_dict() for blk in ("L_mod.attn", "H_mod.attn"): ipw = f"{blk}.in_proj_weight" ipb = f"{blk}.in_proj_bias" if ipw in sd2 and ipb not in sd2 and ipb in msd: E = sd2[ipw].shape[1] sd2[ipb] = torch.zeros(3 * E, dtype=sd2[ipw].dtype) opw = f"{blk}.out_proj.weight" opb = f"{blk}.out_proj.bias" if opw in sd2 and opb not in sd2 and opb in msd: out_dim = msd[opb].shape[0] sd2[opb] = torch.zeros(out_dim, dtype=sd2[opw].dtype) # Drop unknown or mismatched-shape keys sd2 = {k: v for k, v in sd2.items() if (k in msd) and (tuple(v.shape) == tuple(msd[k].shape))} return sd2 # -------------- Tokenizer helpers -------------- def _load_local_tokenizer(tok_dir: str): tok = None try: from transformers import AutoTokenizer, PreTrainedTokenizerFast, GPT2TokenizerFast try: tok = AutoTokenizer.from_pretrained(tok_dir, local_files_only=True, use_fast=True, trust_remote_code=True) return tok except Exception as e: print(f"[hrm_loader] AutoTokenizer fallback: {e}") tj = os.path.join(tok_dir, "tokenizer.json") if tok is None and os.path.exists(tj): try: from tokenizers import Tokenizer core = Tokenizer.from_file(tj) spec_path = os.path.join(tok_dir, "special_tokens_map.json") spec = {} if os.path.exists(spec_path): with open(spec_path, "r", encoding="utf-8") as f: spec = json.load(f) tok = PreTrainedTokenizerFast(tokenizer_object=core, **{k:v for k,v in spec.items() if isinstance(v,str)}) return tok except Exception as e: print(f"[hrm_loader] tokenizer.json fallback failed: {e}") vv = os.path.join(tok_dir, "vocab.json") mm = os.path.join(tok_dir, "merges.txt") if tok is None and os.path.exists(vv) and os.path.exists(mm): try: tok = GPT2TokenizerFast(vocab_file=vv, merges_file=mm) spec_path = os.path.join(tok_dir, "special_tokens_map.json") if os.path.exists(spec_path): with open(spec_path, "r", encoding="utf-8") as f: spec = json.load(f) st = {k: spec[k] for k in ["bos_token","eos_token","unk_token","pad_token","sep_token","cls_token","mask_token"] if k in spec} if st: tok.add_special_tokens(st) return tok except Exception as e: print(f"[hrm_loader] GPT2TokenizerFast fallback failed: {e}") except Exception as e: print(f"[hrm_loader] transformers/tokenizers unavailable or failed: {e}") return tok def _maybe_resize_embeddings_(model: nn.Module, vocab_size_new: int): vocab_size_old = model.tok_emb.num_embeddings if vocab_size_new == vocab_size_old: return device = next(model.parameters()).device dtype = next(model.parameters()).dtype d_model = model.d_model old_w = model.tok_emb.weight.data.detach().to(device=device, dtype=dtype) new_emb = nn.Embedding(vocab_size_new, d_model, device=device, dtype=dtype) nn.init.normal_(new_emb.weight, mean=0.0, std=0.02) keep = min(vocab_size_old, vocab_size_new) new_emb.weight.data[:keep] = old_w[:keep] model.tok_emb = new_emb new_head = nn.Linear(d_model, vocab_size_new, bias=False, device=device, dtype=dtype) model.lm_head = new_head model.lm_head.weight = model.tok_emb.weight print(f"[hrm_loader] resized embeddings: {vocab_size_old} -> {vocab_size_new}") def _vocab_from_sd(sd: Dict[str, torch.Tensor]) -> Optional[int]: te = sd.get("tok_emb.weight", None) if te is None: te = sd.get("lm_head.weight", None) return int(te.shape[0]) if te is not None else None # -------------- Public loader -------------- def load_hrm( checkpoint_or_dir: str, device: Optional[str] = "auto", dtype: str = "auto", strict: bool = True, override_config: Optional[Dict[str, Any]] = None, ModelCls=None, with_tokenizer: bool = False, tokenizer_path: Optional[str] = None, ): if ModelCls is None: ModelCls = HRMForCausalLM ckpt = _find_checkpoint(checkpoint_or_dir) sd = _load_state_dict(ckpt) sd = _adapt_attention_keys(sd) # NEW: If lm_head.weight is absent but tok_emb.weight exists (tied-weights checkpoint), # mirror it to avoid "missing lm_head.weight" in load_state_dict. if "lm_head.weight" not in sd and "tok_emb.weight" in sd: sd["lm_head.weight"] = sd["tok_emb.weight"] cfg_dir = checkpoint_or_dir if os.path.isdir(checkpoint_or_dir) else os.path.dirname(ckpt) raw_cfg = _load_config_if_any(cfg_dir) or _infer_config_from_state(sd) if override_config: raw_cfg.update(override_config) cfg = _sanitize_and_map_config(raw_cfg, ModelCls) # Prefer checkpoint vocab_size to avoid size mismatches sd_vocab = _vocab_from_sd(sd) if sd_vocab is not None and (cfg.get("vocab_size") is None or cfg["vocab_size"] != sd_vocab): print(f"[hrm_loader] adjusting vocab_size config {cfg.get('vocab_size')} -> {sd_vocab} from checkpoint") cfg["vocab_size"] = sd_vocab dev = _resolve_device(device) dt = _resolve_dtype(dtype) model = ModelCls(**cfg) sd = _complete_and_filter_for_model(sd, model) # Load weights (safe: shapes now match) ik = model.load_state_dict(sd, strict=False) missing = list(getattr(ik, "missing_keys", [])) unexpected = list(getattr(ik, "unexpected_keys", [])) if missing or unexpected: print(f"[hrm_loader] load_state_dict: missing={len(missing)} unexpected={len(unexpected)}") if missing: print(" missing (sample):", missing[:8]) if unexpected:print(" unexpected (sample):", unexpected[:8]) if strict: raise RuntimeError( "Strict load requested but state_dict mismatch remains.\n" f"Missing (n={len(missing)}): {missing[:12]}\n" f"Unexpected (n={len(unexpected)}): {unexpected[:12]}" ) model.to(dev) if dt != torch.float32: model.to(dtype=dt) # parameters + buffers try: if hasattr(model, "lm_head") and hasattr(model, "tok_emb") and model.lm_head.weight is not model.tok_emb.weight: model.lm_head.weight = model.tok_emb.weight except Exception: pass model.eval() tokenizer = None if with_tokenizer: tdir = tokenizer_path or cfg_dir tokenizer = _load_local_tokenizer(tdir) if tokenizer is None: print(f"[hrm_loader] WARNING: could not load tokenizer from {tdir}") else: try: _maybe_resize_embeddings_(model, len(tokenizer)) except Exception as e: print(f"[hrm_loader] embedding resize check failed: {e}") return (model, tokenizer) if with_tokenizer else model __all__ = ["HRMForCausalLM", "load_hrm"]