# Quark_utils/hrm_utils.py — uploaded by abanm ("Upload 4 files", commit cfef4e2, verified)
# hrm_utils.py — Minimal, robust HRM loader + tokenizer support
# --------------------------------------------------------------
# - Handles .pt/.bin/.safetensors (single file or HF sharded index)
# - Adapts q/k/v names to torch.nn.MultiheadAttention format
# - Infers config if config.json is missing
# - Prefers checkpoint vocab_size over config to avoid shape mismatches
# - Optional tokenizer load (local files) + embedding resize + weight tying
# - Returns (model, tokenizer) when with_tokenizer=True (else just model)
import os, json, glob, math, inspect
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------- Blocks ----------------
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering), with a learned gain."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # Scale each vector by the reciprocal RMS of its last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
class SinusoidalPositionalEmbedding(nn.Module):
    """Classic fixed sin/cos positional table, precomputed up to max_len positions."""

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency ladder over the even feature indices.
        freq = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(position * freq)
        table[:, 1::2] = torch.cos(position * freq)
        # Non-persistent: the table is deterministic, so it is not saved in checkpoints.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, L: int):
        # (1, L, d_model) slice, broadcastable over the batch dimension.
        return self.pe[:L].unsqueeze(0)
class SwiGLU(nn.Module):
    """Gated feed-forward: silu(w1(x)) * w2(x) projected back through w3, then dropout."""

    def __init__(self, d_model, d_ff, pdrop=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.drop = nn.Dropout(pdrop)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.drop(self.w3(gate * value))
class AttnBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> self-attention -> residual,
    then RMSNorm -> SwiGLU MLP -> residual."""

    def __init__(self, d_model, n_heads, d_ff, pdrop=0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=pdrop, batch_first=True)
        self.drop = nn.Dropout(pdrop)
        self.norm2 = RMSNorm(d_model)
        self.mlp = SwiGLU(d_model, d_ff, pdrop)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Both masks, when supplied, must be 2-D boolean tensors
        # (the layout nn.MultiheadAttention expects).
        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool and attn_mask.dim() == 2
        if key_padding_mask is not None:
            assert key_padding_mask.dtype == torch.bool and key_padding_mask.dim() == 2
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask, need_weights=False)
        x = x + self.drop(attn_out)
        return x + self.drop(self.mlp(self.norm2(x)))
# ---------------- Model ----------------
class HRMForCausalLM(nn.Module):
    """Hierarchical Reasoning Model with ACT-style adaptive halting and a causal LM head.

    A low-level block ``L_mod`` runs ``k_l_steps`` refinement steps per cycle; a
    high-level block ``H_mod`` then updates once. After each cycle a halting head
    emits a per-token stop probability, and hidden states are accumulated weighted
    by the probability mass halting in that cycle. ``lm_head`` is weight-tied to
    ``tok_emb``.

    BUGFIX: the -1.5 halting-bias init now runs *after* ``self.apply(self._init_weights)``.
    Previously it ran before, and ``_init_weights`` zeroed every Linear bias,
    silently discarding the intended "don't halt early" prior.
    """

    def __init__(self, vocab_size: int, d_model=512, n_heads=8, d_ff=2048, dropout=0.1,
                 k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.k_l_steps = k_l_steps
        self.max_cycles = max_cycles
        self.ponder_w = ponder_loss_weight
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = SinusoidalPositionalEmbedding(d_model, max_len=8192)
        self.in_net = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), RMSNorm(d_model))
        self.L_mod = AttnBlock(d_model, n_heads, d_ff, dropout)
        self.H_mod = AttnBlock(d_model, n_heads, d_ff, dropout)
        self.halt_head = nn.Linear(d_model, 1)
        self.out_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # tie input/output embeddings
        self._cached_causal_bool = {}  # (seq_len, device) -> upper-triangular bool mask
        self.apply(self._init_weights)
        # Must come after apply(): _init_weights zeroes Linear biases, which would
        # otherwise clobber this deliberate negative prior on the halting head.
        nn.init.constant_(self.halt_head.bias, -1.5)

    def _init_weights(self, m):
        # Standard small-normal init for all Linear/Embedding weights, zero biases.
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def _causal_bool_mask(self, L: int, device):
        # Cache one strictly-upper-triangular (future-masking) bool mask per (length, device).
        k = (L, device)
        if k not in self._cached_causal_bool:
            self._cached_causal_bool[k] = torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), 1)
        return self._cached_causal_bool[k]

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the halting recurrence.

        Returns a dict with "loss", "logits", "lm_loss", "ponder_loss"
        (loss entries are None when labels is None).
        """
        B, L = input_ids.shape
        device = input_ids.device
        x_tok = self.tok_emb(input_ids)
        pos = self.pos_emb(L).to(device=device, dtype=x_tok.dtype)  # keep dtype aligned
        x = self.in_net(x_tok + pos)
        causal_bool = self._causal_bool_mask(L, device)
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        z_L = x.clone()
        z_H = torch.zeros_like(x)
        eps = 1e-6
        rema = torch.ones((B, L), device=device, dtype=x_tok.dtype)  # remaining halting mass
        collected_H = torch.zeros_like(z_H)
        ponder_terms = []
        for c in range(self.max_cycles):
            # Low-level module: several fast refinement steps conditioned on z_H and the input.
            for _ in range(self.k_l_steps):
                z_L = self.L_mod(z_L + z_H + x, attn_mask=causal_bool, key_padding_mask=key_padding_mask)
            # High-level module: one slow update per cycle.
            z_H = self.H_mod(z_H + z_L, attn_mask=causal_bool, key_padding_mask=key_padding_mask)
            p_halt = torch.sigmoid(self.halt_head(z_H)).squeeze(-1).clamp(eps, 1 - eps)
            # Force halting on the final cycle so the accumulated weights sum to ~1.
            last = torch.full_like(p_halt, fill_value=(c == self.max_cycles - 1), dtype=torch.bool)
            halt_p = torch.where(last, torch.ones_like(p_halt), p_halt)
            contrib = (rema * halt_p).unsqueeze(-1)
            collected_H = collected_H + contrib * z_H
            ponder_terms.append(rema * halt_p)
            rema = rema * (1.0 - halt_p)
            if torch.all(rema < 1e-4):
                break  # effectively all probability mass has halted
        collected_H = self.out_norm(collected_H)
        logits = self.lm_head(collected_H)
        loss = lm_loss = ponder = None
        if labels is not None:
            # Shift for next-token prediction. NOTE(review): no ignore_index is set,
            # so padded label positions (if any) contribute to the loss — confirm callers.
            sl = logits[:, :-1, :].contiguous()
            y = labels[:, 1:].contiguous()
            B_, Lm1, V = sl.shape
            lm_loss = F.cross_entropy(sl.float().view(B_ * Lm1, V), y.view(B_ * Lm1))
            ponder = torch.stack(ponder_terms, dim=-1).sum(dim=-1).mean()
            loss = lm_loss + self.ponder_w * ponder
        return {"loss": loss, "logits": logits, "lm_loss": lm_loss, "ponder_loss": ponder}

    # ---- HF-style hooks ----
    def get_input_embeddings(self):
        """HF-compat accessor for the token embedding module."""
        return self.tok_emb

    def set_input_embeddings(self, new_emb):
        """HF-compat setter; re-ties lm_head to the new embedding matrix."""
        self.tok_emb = new_emb
        if hasattr(self, "lm_head"):
            self.lm_head.weight = self.tok_emb.weight

    def tie_weights(self):
        """Re-tie lm_head to tok_emb (idempotent)."""
        if hasattr(self, "lm_head") and hasattr(self, "tok_emb"):
            self.lm_head.weight = self.tok_emb.weight
# -------------- Loader helpers --------------
def _resolve_device(device: Optional[str]) -> torch.device:
if device is None or device == "auto":
if torch.cuda.is_available(): return torch.device("cuda")
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps")
return torch.device("cpu")
return torch.device(device)
def _resolve_dtype(dtype: str) -> torch.dtype:
d = str(dtype).lower()
if d in ("fp32","float32","f32"): return torch.float32
if d in ("bf16","bfloat16"): return torch.bfloat16
if d in ("fp16","float16","half"):return torch.float16
if d == "auto":
if torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)(): return torch.bfloat16
return torch.float32
raise ValueError(f"Unknown dtype {dtype}")
def _find_checkpoint(path_or_dir: str) -> str:
if os.path.isfile(path_or_dir): return path_or_dir
if not os.path.isdir(path_or_dir): raise FileNotFoundError(f"Not a file or directory: {path_or_dir}")
st = glob.glob(os.path.join(path_or_dir, "*.safetensors"))
if len(st) == 1: return st[0]
if len(st) > 1:
for cand in ("model.safetensors","pytorch_model.safetensors"):
p = os.path.join(path_or_dir, cand)
if os.path.exists(p): return p
return sorted(st)[0]
for idx in ("model.safetensors.index.json","pytorch_model.bin.index.json"):
p = os.path.join(path_or_dir, idx)
if os.path.exists(p): return p
for cand in ("pytorch_model.bin","model.bin","model.pt"):
p = os.path.join(path_or_dir, cand)
if os.path.exists(p): return p
pt = glob.glob(os.path.join(path_or_dir, "*.pt")) + glob.glob(os.path.join(path_or_dir, "*.bin"))
if pt: return sorted(pt)[0]
raise FileNotFoundError(f"No checkpoint found in {path_or_dir}")
def _torch_load(path: str):
try:
return torch.load(path, map_location="cpu", weights_only=True)
except TypeError:
return torch.load(path, map_location="cpu")
def _normalize_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def strip(k: str) -> str:
for pref in ("module.","model.","transformer."):
if k.startswith(pref): return k[len(pref):]
return k
return {strip(k): v for k, v in sd.items()}
def _adapt_attention_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
sd = dict(sd)
def handle(prefix: str):
qkv_w = sd.pop(f"{prefix}.qkv.weight", None)
if qkv_w is not None:
sd[f"{prefix}.in_proj_weight"] = qkv_w
qkv_b = sd.pop(f"{prefix}.qkv.bias", None)
if qkv_b is not None:
sd[f"{prefix}.in_proj_bias"] = qkv_b
q_w = sd.pop(f"{prefix}.q_proj.weight", None)
k_w = sd.pop(f"{prefix}.k_proj.weight", None)
v_w = sd.pop(f"{prefix}.v_proj.weight", None)
if q_w is not None and k_w is not None and v_w is not None:
sd[f"{prefix}.in_proj_weight"] = torch.cat([q_w, k_w, v_w], dim=0)
q_b = sd.pop(f"{prefix}.q_proj.bias", None)
k_b = sd.pop(f"{prefix}.k_proj.bias", None)
v_b = sd.pop(f"{prefix}.v_proj.bias", None)
if q_b is not None and k_b is not None and v_b is not None:
sd[f"{prefix}.in_proj_bias"] = torch.cat([q_b, k_b, v_b], dim=0)
o_w = sd.pop(f"{prefix}.o.weight", None)
if o_w is not None:
sd[f"{prefix}.out_proj.weight"] = o_w
o_b = sd.pop(f"{prefix}.o.bias", None)
if o_b is not None:
sd[f"{prefix}.out_proj.bias"] = o_b
if f"{prefix}.in_proj_weight" in sd and f"{prefix}.in_proj_bias" not in sd:
E = sd[f"{prefix}.in_proj_weight"].shape[1]
sd[f"{prefix}.in_proj_bias"] = torch.zeros(3 * E, dtype=sd[f"{prefix}.in_proj_weight"].dtype)
for blk in ("L_mod.attn", "H_mod.attn"):
handle(blk)
return sd
def _load_state_dict(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Read tensors from a single checkpoint file or an HF shard index,
    then normalize wrapper key prefixes. Raises ValueError for unsupported paths."""
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file as safe_load
        return _normalize_keys(safe_load(ckpt_path, device="cpu"))
    if ckpt_path.endswith("model.safetensors.index.json"):
        shard_dir = os.path.dirname(ckpt_path)
        with open(ckpt_path, "r", encoding="utf-8") as fh:
            index = json.load(fh)
        from safetensors import safe_open
        tensors: Dict[str, torch.Tensor] = {}
        # weight_map values repeat per-tensor; dedupe to visit each shard once.
        for shard_name in sorted(set(index.get("weight_map", {}).values())):
            with safe_open(os.path.join(shard_dir, shard_name), framework="pt", device="cpu") as shard:
                for key in shard.keys():
                    tensors[key] = shard.get_tensor(key)
        return _normalize_keys(tensors)
    if ckpt_path.endswith("pytorch_model.bin.index.json"):
        shard_dir = os.path.dirname(ckpt_path)
        with open(ckpt_path, "r", encoding="utf-8") as fh:
            index = json.load(fh)
        tensors = {}
        for shard_name in sorted(set(index.get("weight_map", {}).values())):
            payload = _torch_load(os.path.join(shard_dir, shard_name))
            if isinstance(payload, dict) and "state_dict" in payload:
                payload = payload["state_dict"]
            tensors.update(payload)
        return _normalize_keys(tensors)
    if ckpt_path.endswith((".pt", ".bin")):
        payload = _torch_load(ckpt_path)
        if isinstance(payload, dict) and "state_dict" in payload:
            payload = payload["state_dict"]
        return _normalize_keys(payload)
    if ckpt_path.endswith(".json"):
        raise ValueError("Pass the directory, not the index/config JSON.")
    raise ValueError(f"Unsupported checkpoint type: {ckpt_path}")
def _load_config_if_any(path_or_dir: str) -> Optional[Dict[str, Any]]:
p = path_or_dir if path_or_dir.endswith(".json") else os.path.join(path_or_dir, "config.json")
if os.path.exists(p):
with open(p, "r", encoding="utf-8") as f:
return json.load(f)
return None
def _infer_config_from_state(sd: Dict[str, torch.Tensor]) -> Dict[str, Any]:
te = sd.get("tok_emb.weight", None)
if te is None:
te = sd.get("lm_head.weight", None)
if te is None:
raise ValueError("Cannot infer config: missing tok_emb.weight (or lm_head.weight).")
vocab_size, d_model = te.shape
w1 = sd.get("L_mod.mlp.w1.weight", None)
if w1 is None:
w1 = sd.get("H_mod.mlp.w1.weight", None)
d_ff = int(w1.shape[0]) if w1 is not None else int(4 * d_model)
return dict(vocab_size=int(vocab_size), d_model=int(d_model), n_heads=8, d_ff=int(d_ff),
dropout=0.1, k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2)
_ALLOWED_KW = {"vocab_size","d_model","n_heads","d_ff","dropout","k_l_steps","max_cycles","ponder_loss_weight"}
_DROP_KEYS = {"weight_tying","tie_word_embeddings","torch_dtype","architectures","model_type",
"initializer_range","layer_norm_eps","max_position_embeddings","use_cache"}
def _sanitize_and_map_config(raw_cfg: Dict[str, Any], ModelCls):
cfg = dict(raw_cfg) if raw_cfg else {}
for src, dst in {"hidden_size":"d_model","num_attention_heads":"n_heads","intermediate_size":"d_ff"}.items():
if src in cfg and dst not in cfg:
cfg[dst] = cfg[src]
if "vocab_size" not in cfg and raw_cfg and "vocab_size" in raw_cfg:
cfg["vocab_size"] = raw_cfg["vocab_size"]
for k in list(cfg.keys()):
if k in _DROP_KEYS:
cfg.pop(k, None)
cfg = {k: v for k, v in cfg.items() if k in _ALLOWED_KW}
allowed = set(inspect.signature(ModelCls.__init__).parameters.keys()) - {"self"}
cfg = {k: v for k, v in cfg.items() if k in allowed}
return cfg
def _complete_and_filter_for_model(sd: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
sd2 = dict(sd)
msd = model.state_dict()
for blk in ("L_mod.attn", "H_mod.attn"):
ipw = f"{blk}.in_proj_weight"
ipb = f"{blk}.in_proj_bias"
if ipw in sd2 and ipb not in sd2 and ipb in msd:
E = sd2[ipw].shape[1]
sd2[ipb] = torch.zeros(3 * E, dtype=sd2[ipw].dtype)
opw = f"{blk}.out_proj.weight"
opb = f"{blk}.out_proj.bias"
if opw in sd2 and opb not in sd2 and opb in msd:
out_dim = msd[opb].shape[0]
sd2[opb] = torch.zeros(out_dim, dtype=sd2[opw].dtype)
# Drop unknown or mismatched-shape keys
sd2 = {k: v for k, v in sd2.items() if (k in msd) and (tuple(v.shape) == tuple(msd[k].shape))}
return sd2
# -------------- Tokenizer helpers --------------
def _load_local_tokenizer(tok_dir: str):
tok = None
try:
from transformers import AutoTokenizer, PreTrainedTokenizerFast, GPT2TokenizerFast
try:
tok = AutoTokenizer.from_pretrained(tok_dir, local_files_only=True, use_fast=True, trust_remote_code=True)
return tok
except Exception as e:
print(f"[hrm_loader] AutoTokenizer fallback: {e}")
tj = os.path.join(tok_dir, "tokenizer.json")
if tok is None and os.path.exists(tj):
try:
from tokenizers import Tokenizer
core = Tokenizer.from_file(tj)
spec_path = os.path.join(tok_dir, "special_tokens_map.json")
spec = {}
if os.path.exists(spec_path):
with open(spec_path, "r", encoding="utf-8") as f:
spec = json.load(f)
tok = PreTrainedTokenizerFast(tokenizer_object=core, **{k:v for k,v in spec.items() if isinstance(v,str)})
return tok
except Exception as e:
print(f"[hrm_loader] tokenizer.json fallback failed: {e}")
vv = os.path.join(tok_dir, "vocab.json")
mm = os.path.join(tok_dir, "merges.txt")
if tok is None and os.path.exists(vv) and os.path.exists(mm):
try:
tok = GPT2TokenizerFast(vocab_file=vv, merges_file=mm)
spec_path = os.path.join(tok_dir, "special_tokens_map.json")
if os.path.exists(spec_path):
with open(spec_path, "r", encoding="utf-8") as f:
spec = json.load(f)
st = {k: spec[k] for k in ["bos_token","eos_token","unk_token","pad_token","sep_token","cls_token","mask_token"] if k in spec}
if st:
tok.add_special_tokens(st)
return tok
except Exception as e:
print(f"[hrm_loader] GPT2TokenizerFast fallback failed: {e}")
except Exception as e:
print(f"[hrm_loader] transformers/tokenizers unavailable or failed: {e}")
return tok
def _maybe_resize_embeddings_(model: nn.Module, vocab_size_new: int):
vocab_size_old = model.tok_emb.num_embeddings
if vocab_size_new == vocab_size_old:
return
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
d_model = model.d_model
old_w = model.tok_emb.weight.data.detach().to(device=device, dtype=dtype)
new_emb = nn.Embedding(vocab_size_new, d_model, device=device, dtype=dtype)
nn.init.normal_(new_emb.weight, mean=0.0, std=0.02)
keep = min(vocab_size_old, vocab_size_new)
new_emb.weight.data[:keep] = old_w[:keep]
model.tok_emb = new_emb
new_head = nn.Linear(d_model, vocab_size_new, bias=False, device=device, dtype=dtype)
model.lm_head = new_head
model.lm_head.weight = model.tok_emb.weight
print(f"[hrm_loader] resized embeddings: {vocab_size_old} -> {vocab_size_new}")
def _vocab_from_sd(sd: Dict[str, torch.Tensor]) -> Optional[int]:
te = sd.get("tok_emb.weight", None)
if te is None:
te = sd.get("lm_head.weight", None)
return int(te.shape[0]) if te is not None else None
# -------------- Public loader --------------
def load_hrm(
    checkpoint_or_dir: str,
    device: Optional[str] = "auto",
    dtype: str = "auto",
    strict: bool = True,
    override_config: Optional[Dict[str, Any]] = None,
    ModelCls=None,
    with_tokenizer: bool = False,
    tokenizer_path: Optional[str] = None,
):
    """Load an HRM checkpoint (file or directory) and return the model.

    Args:
        checkpoint_or_dir: path to a weights file, or a directory containing one.
        device: "auto"/None for auto-detection, or an explicit device string.
        dtype: "auto" or an alias like "fp32"/"bf16"/"fp16".
        strict: raise RuntimeError if keys are still missing/unexpected after adaptation.
        override_config: dict merged over the on-disk (or shape-inferred) config.
        ModelCls: model class to instantiate; defaults to HRMForCausalLM.
        with_tokenizer: also attempt a local tokenizer load.
        tokenizer_path: directory for the tokenizer; defaults to the config dir.

    Returns:
        model, or (model, tokenizer) when with_tokenizer=True (tokenizer may be None).
    """
    if ModelCls is None:
        ModelCls = HRMForCausalLM
    ckpt = _find_checkpoint(checkpoint_or_dir)
    sd = _load_state_dict(ckpt)
    sd = _adapt_attention_keys(sd)
    # If lm_head.weight is absent but tok_emb.weight exists (tied-weights checkpoint),
    # mirror it to avoid "missing lm_head.weight" in load_state_dict.
    if "lm_head.weight" not in sd and "tok_emb.weight" in sd:
        sd["lm_head.weight"] = sd["tok_emb.weight"]
    cfg_dir = checkpoint_or_dir if os.path.isdir(checkpoint_or_dir) else os.path.dirname(ckpt)
    # Prefer a real config.json; otherwise infer hyperparameters from tensor shapes.
    raw_cfg = _load_config_if_any(cfg_dir) or _infer_config_from_state(sd)
    if override_config:
        raw_cfg.update(override_config)
    cfg = _sanitize_and_map_config(raw_cfg, ModelCls)
    # Prefer checkpoint vocab_size to avoid size mismatches
    sd_vocab = _vocab_from_sd(sd)
    if sd_vocab is not None and (cfg.get("vocab_size") is None or cfg["vocab_size"] != sd_vocab):
        print(f"[hrm_loader] adjusting vocab_size config {cfg.get('vocab_size')} -> {sd_vocab} from checkpoint")
        cfg["vocab_size"] = sd_vocab
    dev = _resolve_device(device)
    dt = _resolve_dtype(dtype)
    model = ModelCls(**cfg)
    # Synthesize expected biases and drop shape-mismatched/unknown keys first.
    sd = _complete_and_filter_for_model(sd, model)
    # Load weights (safe: shapes now match)
    ik = model.load_state_dict(sd, strict=False)
    missing = list(getattr(ik, "missing_keys", []))
    unexpected = list(getattr(ik, "unexpected_keys", []))
    if missing or unexpected:
        print(f"[hrm_loader] load_state_dict: missing={len(missing)} unexpected={len(unexpected)}")
        if missing: print(" missing (sample):", missing[:8])
        if unexpected: print(" unexpected (sample):", unexpected[:8])
        if strict:
            raise RuntimeError(
                "Strict load requested but state_dict mismatch remains.\n"
                f"Missing (n={len(missing)}): {missing[:12]}\n"
                f"Unexpected (n={len(unexpected)}): {unexpected[:12]}"
            )
    model.to(dev)
    if dt != torch.float32:
        model.to(dtype=dt)  # parameters + buffers
    # Re-tie lm_head to tok_emb in case loading/moving broke the weight sharing.
    try:
        if hasattr(model, "lm_head") and hasattr(model, "tok_emb") and model.lm_head.weight is not model.tok_emb.weight:
            model.lm_head.weight = model.tok_emb.weight
    except Exception:
        pass
    model.eval()
    tokenizer = None
    if with_tokenizer:
        tdir = tokenizer_path or cfg_dir
        tokenizer = _load_local_tokenizer(tdir)
        if tokenizer is None:
            print(f"[hrm_loader] WARNING: could not load tokenizer from {tdir}")
        else:
            try:
                # Grow/shrink embeddings to the tokenizer's vocab (no-op on match).
                _maybe_resize_embeddings_(model, len(tokenizer))
            except Exception as e:
                print(f"[hrm_loader] embedding resize check failed: {e}")
    return (model, tokenizer) if with_tokenizer else model
__all__ = ["HRMForCausalLM", "load_hrm"]