|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os, json, glob, math, inspect |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean subtraction, no bias."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learned per-feature gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # Scale each vector by the reciprocal of its RMS, then apply the gain.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps)
        return self.weight * normalized
|
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sin/cos positional table, precomputed once and sliced on demand."""

    def __init__(self, d_model, max_len=8192):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Not persistent: the table is deterministic, so it need not live in checkpoints.
        self.register_buffer("pe", table, persistent=False)

    def forward(self, L: int):
        # (1, L, d_model) slice, ready to broadcast over the batch dimension.
        return self.pe[:L].unsqueeze(0)
|
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward: silu(W1 x) * (W2 x), projected back by W3, then dropout."""

    def __init__(self, d_model, d_ff, pdrop=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.drop = nn.Dropout(pdrop)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        projected = self.w3(gate * value)
        return self.drop(projected)
|
|
|
|
|
class AttnBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a SwiGLU MLP, each with a residual."""

    def __init__(self, d_model, n_heads, d_ff, pdrop=0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=pdrop, batch_first=True)
        self.drop = nn.Dropout(pdrop)
        self.norm2 = RMSNorm(d_model)
        self.mlp = SwiGLU(d_model, d_ff, pdrop)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Both masks, when given, must be 2-D boolean (True = disallowed/padded).
        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool and attn_mask.dim() == 2
        if key_padding_mask is not None:
            assert key_padding_mask.dtype == torch.bool and key_padding_mask.dim() == 2
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask, need_weights=False)
        x = x + self.drop(attn_out)
        return x + self.drop(self.mlp(self.norm2(x)))
|
|
|
|
|
|
|
|
class HRMForCausalLM(nn.Module):
    """Hierarchical Reasoning Model with ACT-style halting, for causal language modeling.

    Each of up to `max_cycles` reasoning cycles runs the low-level block
    (`L_mod`) for `k_l_steps` refinement steps and then the high-level block
    (`H_mod`) once. A sigmoid halting head turns each cycle's H state into a
    per-token halting probability; the output representation is the
    halting-weighted mixture of the per-cycle H states (halting is forced on
    the last cycle so the weights sum to one). The LM head is weight-tied to
    the token embedding.

    Args:
        vocab_size: token vocabulary size.
        d_model: hidden width (must be divisible by n_heads).
        n_heads: attention heads per block.
        d_ff: SwiGLU hidden width.
        dropout: dropout probability used throughout.
        k_l_steps: low-level refinement steps per cycle.
        max_cycles: maximum number of reasoning cycles.
        ponder_loss_weight: weight of the ACT ponder penalty in the total loss.
    """

    def __init__(self, vocab_size: int, d_model=512, n_heads=8, d_ff=2048, dropout=0.1,
                 k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.k_l_steps = k_l_steps
        self.max_cycles = max_cycles
        self.ponder_w = ponder_loss_weight

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = SinusoidalPositionalEmbedding(d_model, max_len=8192)
        self.in_net = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(), RMSNorm(d_model))

        self.L_mod = AttnBlock(d_model, n_heads, d_ff, dropout)
        self.H_mod = AttnBlock(d_model, n_heads, d_ff, dropout)

        self.halt_head = nn.Linear(d_model, 1)

        self.out_norm = RMSNorm(d_model)

        # Weight tying: the output head shares the embedding matrix.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight

        # Per-(length, device) cache of boolean causal masks.
        self._cached_causal_bool = {}
        self.apply(self._init_weights)
        # BUGFIX: this must run AFTER apply(_init_weights). Previously the bias
        # was set before the generic init pass, and _init_weights zeroes every
        # Linear bias, so the intended "reluctant to halt early" initialization
        # was silently wiped out.
        nn.init.constant_(self.halt_head.bias, -1.5)

    def _init_weights(self, m):
        """GPT-style init: N(0, 0.02) for Linear/Embedding weights, zero biases."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def _causal_bool_mask(self, L: int, device):
        """Return (and cache) an upper-triangular boolean mask; True = blocked."""
        k = (L, device)
        if k not in self._cached_causal_bool:
            self._cached_causal_bool[k] = torch.triu(torch.ones(L, L, dtype=torch.bool, device=device), 1)
        return self._cached_causal_bool[k]

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the adaptive reasoning loop and optionally compute the training loss.

        Args:
            input_ids: (B, L) token ids.
            attention_mask: optional (B, L); 0 marks padding positions.
            labels: optional (B, L); next-token loss uses an internal one-position shift.

        Returns:
            dict with "logits" (B, L, vocab) and, when labels are given,
            "loss", "lm_loss" and "ponder_loss" (otherwise None).
        """
        B, L = input_ids.shape
        device = input_ids.device
        x_tok = self.tok_emb(input_ids)
        pos = self.pos_emb(L).to(device=device, dtype=x_tok.dtype)
        x = self.in_net(x_tok + pos)

        causal_bool = self._causal_bool_mask(L, device)
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None

        z_L = x.clone()
        z_H = torch.zeros_like(x)

        eps = 1e-6
        # Probability mass still "running" at each position (ACT remainder).
        rema = torch.ones((B, L), device=device, dtype=x_tok.dtype)
        collected_H = torch.zeros_like(z_H)
        ponder_terms = []

        for c in range(self.max_cycles):
            # Low-level refinement, conditioned on the current H state and the input.
            for _ in range(self.k_l_steps):
                z_L = self.L_mod(z_L + z_H + x, attn_mask=causal_bool, key_padding_mask=key_padding_mask)
            # High-level update integrating the refined low-level state.
            z_H = self.H_mod(z_H + z_L, attn_mask=causal_bool, key_padding_mask=key_padding_mask)

            p_halt = torch.sigmoid(self.halt_head(z_H)).squeeze(-1).clamp(eps, 1 - eps)
            # Force halting on the final cycle so the mixture weights sum to 1.
            halt_p = torch.ones_like(p_halt) if c == self.max_cycles - 1 else p_halt

            # Accumulate this cycle's H state, weighted by P(halt exactly now).
            contrib = (rema * halt_p).unsqueeze(-1)
            collected_H = collected_H + contrib * z_H

            ponder_terms.append(rema * halt_p)
            rema = rema * (1.0 - halt_p)
            # Everything has (numerically) halted: skip the remaining cycles.
            if torch.all(rema < 1e-4):
                break

        collected_H = self.out_norm(collected_H)
        logits = self.lm_head(collected_H)

        loss = lm_loss = ponder = None
        if labels is not None:
            sl = logits[:, :-1, :].contiguous()
            y = labels[:, 1:].contiguous()
            B_, Lm1, V = sl.shape
            # Cross-entropy in fp32 for numerical stability under bf16/fp16.
            lm_loss = F.cross_entropy(sl.float().view(B_ * Lm1, V), y.view(B_ * Lm1))
            # NOTE(review): because the last cycle forces halt_p=1, this sum
            # telescopes to ~1 unless the loop breaks early; confirm this is
            # the intended ponder cost (classic ACT also penalizes step count).
            ponder = torch.stack(ponder_terms, dim=-1).sum(dim=-1).mean()
            loss = lm_loss + self.ponder_w * ponder

        return {"loss": loss, "logits": logits, "lm_loss": lm_loss, "ponder_loss": ponder}

    def get_input_embeddings(self):
        """HF-style accessor for the token embedding module."""
        return self.tok_emb

    def set_input_embeddings(self, new_emb):
        """Replace the token embedding and re-tie the LM head to it."""
        self.tok_emb = new_emb
        if hasattr(self, "lm_head"):
            self.lm_head.weight = self.tok_emb.weight

    def tie_weights(self):
        """(Re)share the LM head weight with the token embedding."""
        if hasattr(self, "lm_head") and hasattr(self, "tok_emb"):
            self.lm_head.weight = self.tok_emb.weight
|
|
|
|
|
|
|
|
def _resolve_device(device: Optional[str]) -> torch.device: |
|
|
if device is None or device == "auto": |
|
|
if torch.cuda.is_available(): return torch.device("cuda") |
|
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") |
|
|
return torch.device("cpu") |
|
|
return torch.device(device) |
|
|
|
|
|
def _resolve_dtype(dtype: str) -> torch.dtype: |
|
|
d = str(dtype).lower() |
|
|
if d in ("fp32","float32","f32"): return torch.float32 |
|
|
if d in ("bf16","bfloat16"): return torch.bfloat16 |
|
|
if d in ("fp16","float16","half"):return torch.float16 |
|
|
if d == "auto": |
|
|
if torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)(): return torch.bfloat16 |
|
|
return torch.float32 |
|
|
raise ValueError(f"Unknown dtype {dtype}") |
|
|
|
|
|
def _find_checkpoint(path_or_dir: str) -> str: |
|
|
if os.path.isfile(path_or_dir): return path_or_dir |
|
|
if not os.path.isdir(path_or_dir): raise FileNotFoundError(f"Not a file or directory: {path_or_dir}") |
|
|
st = glob.glob(os.path.join(path_or_dir, "*.safetensors")) |
|
|
if len(st) == 1: return st[0] |
|
|
if len(st) > 1: |
|
|
for cand in ("model.safetensors","pytorch_model.safetensors"): |
|
|
p = os.path.join(path_or_dir, cand) |
|
|
if os.path.exists(p): return p |
|
|
return sorted(st)[0] |
|
|
for idx in ("model.safetensors.index.json","pytorch_model.bin.index.json"): |
|
|
p = os.path.join(path_or_dir, idx) |
|
|
if os.path.exists(p): return p |
|
|
for cand in ("pytorch_model.bin","model.bin","model.pt"): |
|
|
p = os.path.join(path_or_dir, cand) |
|
|
if os.path.exists(p): return p |
|
|
pt = glob.glob(os.path.join(path_or_dir, "*.pt")) + glob.glob(os.path.join(path_or_dir, "*.bin")) |
|
|
if pt: return sorted(pt)[0] |
|
|
raise FileNotFoundError(f"No checkpoint found in {path_or_dir}") |
|
|
|
|
|
def _torch_load(path: str): |
|
|
try: |
|
|
return torch.load(path, map_location="cpu", weights_only=True) |
|
|
except TypeError: |
|
|
return torch.load(path, map_location="cpu") |
|
|
|
|
|
def _normalize_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
|
def strip(k: str) -> str: |
|
|
for pref in ("module.","model.","transformer."): |
|
|
if k.startswith(pref): return k[len(pref):] |
|
|
return k |
|
|
return {strip(k): v for k, v in sd.items()} |
|
|
|
|
|
def _adapt_attention_keys(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: |
|
|
sd = dict(sd) |
|
|
def handle(prefix: str): |
|
|
qkv_w = sd.pop(f"{prefix}.qkv.weight", None) |
|
|
if qkv_w is not None: |
|
|
sd[f"{prefix}.in_proj_weight"] = qkv_w |
|
|
qkv_b = sd.pop(f"{prefix}.qkv.bias", None) |
|
|
if qkv_b is not None: |
|
|
sd[f"{prefix}.in_proj_bias"] = qkv_b |
|
|
|
|
|
q_w = sd.pop(f"{prefix}.q_proj.weight", None) |
|
|
k_w = sd.pop(f"{prefix}.k_proj.weight", None) |
|
|
v_w = sd.pop(f"{prefix}.v_proj.weight", None) |
|
|
if q_w is not None and k_w is not None and v_w is not None: |
|
|
sd[f"{prefix}.in_proj_weight"] = torch.cat([q_w, k_w, v_w], dim=0) |
|
|
|
|
|
q_b = sd.pop(f"{prefix}.q_proj.bias", None) |
|
|
k_b = sd.pop(f"{prefix}.k_proj.bias", None) |
|
|
v_b = sd.pop(f"{prefix}.v_proj.bias", None) |
|
|
if q_b is not None and k_b is not None and v_b is not None: |
|
|
sd[f"{prefix}.in_proj_bias"] = torch.cat([q_b, k_b, v_b], dim=0) |
|
|
|
|
|
o_w = sd.pop(f"{prefix}.o.weight", None) |
|
|
if o_w is not None: |
|
|
sd[f"{prefix}.out_proj.weight"] = o_w |
|
|
o_b = sd.pop(f"{prefix}.o.bias", None) |
|
|
if o_b is not None: |
|
|
sd[f"{prefix}.out_proj.bias"] = o_b |
|
|
|
|
|
if f"{prefix}.in_proj_weight" in sd and f"{prefix}.in_proj_bias" not in sd: |
|
|
E = sd[f"{prefix}.in_proj_weight"].shape[1] |
|
|
sd[f"{prefix}.in_proj_bias"] = torch.zeros(3 * E, dtype=sd[f"{prefix}.in_proj_weight"].dtype) |
|
|
for blk in ("L_mod.attn", "H_mod.attn"): |
|
|
handle(blk) |
|
|
return sd |
|
|
|
|
|
def _load_state_dict(ckpt_path: str) -> Dict[str, torch.Tensor]: |
|
|
if ckpt_path.endswith(".safetensors"): |
|
|
from safetensors.torch import load_file as safe_load |
|
|
return _normalize_keys(safe_load(ckpt_path, device="cpu")) |
|
|
if ckpt_path.endswith("model.safetensors.index.json"): |
|
|
base = os.path.dirname(ckpt_path) |
|
|
with open(ckpt_path, "r", encoding="utf-8") as f: |
|
|
idx = json.load(f) |
|
|
from safetensors import safe_open |
|
|
state = {} |
|
|
for shard in sorted(set(idx.get("weight_map", {}).values())): |
|
|
with safe_open(os.path.join(base, shard), framework="pt", device="cpu") as sf: |
|
|
for k in sf.keys(): |
|
|
state[k] = sf.get_tensor(k) |
|
|
return _normalize_keys(state) |
|
|
if ckpt_path.endswith("pytorch_model.bin.index.json"): |
|
|
base = os.path.dirname(ckpt_path) |
|
|
with open(ckpt_path, "r", encoding="utf-8") as f: |
|
|
idx = json.load(f) |
|
|
state = {} |
|
|
for shard in sorted(set(idx.get("weight_map", {}).values())): |
|
|
part = _torch_load(os.path.join(base, shard)) |
|
|
if isinstance(part, dict) and "state_dict" in part: |
|
|
part = part["state_dict"] |
|
|
state.update(part) |
|
|
return _normalize_keys(state) |
|
|
if ckpt_path.endswith((".pt",".bin")): |
|
|
obj = _torch_load(ckpt_path) |
|
|
if isinstance(obj, dict) and "state_dict" in obj: |
|
|
obj = obj["state_dict"] |
|
|
return _normalize_keys(obj) |
|
|
if ckpt_path.endswith(".json"): |
|
|
raise ValueError("Pass the directory, not the index/config JSON.") |
|
|
raise ValueError(f"Unsupported checkpoint type: {ckpt_path}") |
|
|
|
|
|
def _load_config_if_any(path_or_dir: str) -> Optional[Dict[str, Any]]: |
|
|
p = path_or_dir if path_or_dir.endswith(".json") else os.path.join(path_or_dir, "config.json") |
|
|
if os.path.exists(p): |
|
|
with open(p, "r", encoding="utf-8") as f: |
|
|
return json.load(f) |
|
|
return None |
|
|
|
|
|
def _infer_config_from_state(sd: Dict[str, torch.Tensor]) -> Dict[str, Any]: |
|
|
te = sd.get("tok_emb.weight", None) |
|
|
if te is None: |
|
|
te = sd.get("lm_head.weight", None) |
|
|
if te is None: |
|
|
raise ValueError("Cannot infer config: missing tok_emb.weight (or lm_head.weight).") |
|
|
vocab_size, d_model = te.shape |
|
|
w1 = sd.get("L_mod.mlp.w1.weight", None) |
|
|
if w1 is None: |
|
|
w1 = sd.get("H_mod.mlp.w1.weight", None) |
|
|
d_ff = int(w1.shape[0]) if w1 is not None else int(4 * d_model) |
|
|
return dict(vocab_size=int(vocab_size), d_model=int(d_model), n_heads=8, d_ff=int(d_ff), |
|
|
dropout=0.1, k_l_steps=4, max_cycles=8, ponder_loss_weight=1e-2) |
|
|
|
|
|
# Constructor keyword names HRMForCausalLM accepts; anything else is filtered out.
_ALLOWED_KW = {"vocab_size","d_model","n_heads","d_ff","dropout","k_l_steps","max_cycles","ponder_loss_weight"}
# HF-style config keys with no local equivalent; silently dropped before construction.
_DROP_KEYS = {"weight_tying","tie_word_embeddings","torch_dtype","architectures","model_type",
              "initializer_range","layer_norm_eps","max_position_embeddings","use_cache"}
|
|
|
|
|
def _sanitize_and_map_config(raw_cfg: Dict[str, Any], ModelCls):
    """Translate an HF-style config dict into kwargs accepted by ModelCls.__init__.

    Maps hidden_size/num_attention_heads/intermediate_size onto the local
    names, drops known-irrelevant keys, then keeps only keywords that both
    this module allows and ModelCls's constructor actually declares.
    """
    cfg = dict(raw_cfg) if raw_cfg else {}
    aliases = {"hidden_size": "d_model", "num_attention_heads": "n_heads", "intermediate_size": "d_ff"}
    for hf_name, local_name in aliases.items():
        if hf_name in cfg and local_name not in cfg:
            cfg[local_name] = cfg[hf_name]
    if "vocab_size" not in cfg and raw_cfg and "vocab_size" in raw_cfg:
        cfg["vocab_size"] = raw_cfg["vocab_size"]
    # Remove HF-only keys, then keep only the module's whitelisted kwargs.
    cfg = {k: v for k, v in cfg.items() if k not in _DROP_KEYS and k in _ALLOWED_KW}
    ctor_params = set(inspect.signature(ModelCls.__init__).parameters.keys()) - {"self"}
    return {k: v for k, v in cfg.items() if k in ctor_params}
|
|
|
|
|
def _complete_and_filter_for_model(sd: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: |
|
|
sd2 = dict(sd) |
|
|
msd = model.state_dict() |
|
|
for blk in ("L_mod.attn", "H_mod.attn"): |
|
|
ipw = f"{blk}.in_proj_weight" |
|
|
ipb = f"{blk}.in_proj_bias" |
|
|
if ipw in sd2 and ipb not in sd2 and ipb in msd: |
|
|
E = sd2[ipw].shape[1] |
|
|
sd2[ipb] = torch.zeros(3 * E, dtype=sd2[ipw].dtype) |
|
|
opw = f"{blk}.out_proj.weight" |
|
|
opb = f"{blk}.out_proj.bias" |
|
|
if opw in sd2 and opb not in sd2 and opb in msd: |
|
|
out_dim = msd[opb].shape[0] |
|
|
sd2[opb] = torch.zeros(out_dim, dtype=sd2[opw].dtype) |
|
|
|
|
|
sd2 = {k: v for k, v in sd2.items() if (k in msd) and (tuple(v.shape) == tuple(msd[k].shape))} |
|
|
return sd2 |
|
|
|
|
|
|
|
|
def _load_local_tokenizer(tok_dir: str):
    """Best-effort tokenizer loading from a local directory, with no network access.

    Tries, in order: transformers' AutoTokenizer; a raw ``tokenizer.json``
    wrapped in PreTrainedTokenizerFast; GPT-2 style ``vocab.json`` +
    ``merges.txt``. Each failure is printed and the next strategy is tried.
    Returns None if every strategy fails (including when transformers /
    tokenizers are not installed at all).
    """
    tok = None
    try:
        from transformers import AutoTokenizer, PreTrainedTokenizerFast, GPT2TokenizerFast
        # Strategy 1: let transformers pick the tokenizer class itself.
        try:
            tok = AutoTokenizer.from_pretrained(tok_dir, local_files_only=True, use_fast=True, trust_remote_code=True)
            return tok
        except Exception as e:
            print(f"[hrm_loader] AutoTokenizer fallback: {e}")
        # Strategy 2: wrap a bare tokenizer.json, forwarding only string-valued
        # entries of special_tokens_map.json as keyword arguments.
        tj = os.path.join(tok_dir, "tokenizer.json")
        if tok is None and os.path.exists(tj):
            try:
                from tokenizers import Tokenizer
                core = Tokenizer.from_file(tj)
                spec_path = os.path.join(tok_dir, "special_tokens_map.json")
                spec = {}
                if os.path.exists(spec_path):
                    with open(spec_path, "r", encoding="utf-8") as f:
                        spec = json.load(f)
                tok = PreTrainedTokenizerFast(tokenizer_object=core, **{k:v for k,v in spec.items() if isinstance(v,str)})
                return tok
            except Exception as e:
                print(f"[hrm_loader] tokenizer.json fallback failed: {e}")
        # Strategy 3: GPT-2 style vocab/merges pair, registering the standard
        # named special tokens if a special_tokens_map.json is present.
        vv = os.path.join(tok_dir, "vocab.json")
        mm = os.path.join(tok_dir, "merges.txt")
        if tok is None and os.path.exists(vv) and os.path.exists(mm):
            try:
                tok = GPT2TokenizerFast(vocab_file=vv, merges_file=mm)
                spec_path = os.path.join(tok_dir, "special_tokens_map.json")
                if os.path.exists(spec_path):
                    with open(spec_path, "r", encoding="utf-8") as f:
                        spec = json.load(f)
                    st = {k: spec[k] for k in ["bos_token","eos_token","unk_token","pad_token","sep_token","cls_token","mask_token"] if k in spec}
                    if st:
                        tok.add_special_tokens(st)
                return tok
            except Exception as e:
                print(f"[hrm_loader] GPT2TokenizerFast fallback failed: {e}")
    except Exception as e:
        # transformers/tokenizers missing entirely, or an unexpected failure.
        print(f"[hrm_loader] transformers/tokenizers unavailable or failed: {e}")
    return tok
|
|
|
|
|
def _maybe_resize_embeddings_(model: nn.Module, vocab_size_new: int): |
|
|
vocab_size_old = model.tok_emb.num_embeddings |
|
|
if vocab_size_new == vocab_size_old: |
|
|
return |
|
|
device = next(model.parameters()).device |
|
|
dtype = next(model.parameters()).dtype |
|
|
d_model = model.d_model |
|
|
old_w = model.tok_emb.weight.data.detach().to(device=device, dtype=dtype) |
|
|
new_emb = nn.Embedding(vocab_size_new, d_model, device=device, dtype=dtype) |
|
|
nn.init.normal_(new_emb.weight, mean=0.0, std=0.02) |
|
|
keep = min(vocab_size_old, vocab_size_new) |
|
|
new_emb.weight.data[:keep] = old_w[:keep] |
|
|
model.tok_emb = new_emb |
|
|
new_head = nn.Linear(d_model, vocab_size_new, bias=False, device=device, dtype=dtype) |
|
|
model.lm_head = new_head |
|
|
model.lm_head.weight = model.tok_emb.weight |
|
|
print(f"[hrm_loader] resized embeddings: {vocab_size_old} -> {vocab_size_new}") |
|
|
|
|
|
def _vocab_from_sd(sd: Dict[str, torch.Tensor]) -> Optional[int]: |
|
|
te = sd.get("tok_emb.weight", None) |
|
|
if te is None: |
|
|
te = sd.get("lm_head.weight", None) |
|
|
return int(te.shape[0]) if te is not None else None |
|
|
|
|
|
|
|
|
def load_hrm(
    checkpoint_or_dir: str,
    device: Optional[str] = "auto",
    dtype: str = "auto",
    strict: bool = True,
    override_config: Optional[Dict[str, Any]] = None,
    ModelCls=None,
    with_tokenizer: bool = False,
    tokenizer_path: Optional[str] = None,
):
    """Build an HRMForCausalLM (or ModelCls) from a local checkpoint.

    Pipeline: locate the checkpoint; load and key-normalize the state dict;
    adapt foreign attention layouts; reconcile vocab_size between config and
    checkpoint; construct the model; load weights (optionally strict);
    move/cast to the resolved device/dtype; re-tie lm_head to tok_emb; and
    optionally load a local tokenizer (resizing embeddings to its length).

    Returns the model, or (model, tokenizer) when with_tokenizer is True
    (tokenizer may be None if it could not be loaded).
    """
    if ModelCls is None:
        ModelCls = HRMForCausalLM

    ckpt = _find_checkpoint(checkpoint_or_dir)
    sd = _load_state_dict(ckpt)
    sd = _adapt_attention_keys(sd)

    # Weight tying means a checkpoint may omit lm_head; mirror the embedding.
    if "lm_head.weight" not in sd and "tok_emb.weight" in sd:
        sd["lm_head.weight"] = sd["tok_emb.weight"]

    cfg_dir = checkpoint_or_dir if os.path.isdir(checkpoint_or_dir) else os.path.dirname(ckpt)
    raw_cfg = _load_config_if_any(cfg_dir) or _infer_config_from_state(sd)
    if override_config:
        raw_cfg.update(override_config)
    cfg = _sanitize_and_map_config(raw_cfg, ModelCls)

    # Checkpoint tensor shapes win over whatever the config claims.
    sd_vocab = _vocab_from_sd(sd)
    if sd_vocab is not None and (cfg.get("vocab_size") is None or cfg["vocab_size"] != sd_vocab):
        print(f"[hrm_loader] adjusting vocab_size config {cfg.get('vocab_size')} -> {sd_vocab} from checkpoint")
        cfg["vocab_size"] = sd_vocab

    dev = _resolve_device(device)
    dt = _resolve_dtype(dtype)

    model = ModelCls(**cfg)
    sd = _complete_and_filter_for_model(sd, model)

    # Load non-strictly first so we can report a useful summary either way.
    ik = model.load_state_dict(sd, strict=False)
    missing = list(getattr(ik, "missing_keys", []))
    unexpected = list(getattr(ik, "unexpected_keys", []))
    if missing or unexpected:
        print(f"[hrm_loader] load_state_dict: missing={len(missing)} unexpected={len(unexpected)}")
        if missing: print("  missing (sample):", missing[:8])
        if unexpected:print("  unexpected (sample):", unexpected[:8])
        if strict:
            raise RuntimeError(
                "Strict load requested but state_dict mismatch remains.\n"
                f"Missing (n={len(missing)}): {missing[:12]}\n"
                f"Unexpected (n={len(unexpected)}): {unexpected[:12]}"
            )

    model.to(dev)
    if dt != torch.float32:
        model.to(dtype=dt)

    # Re-tie after load/move: load_state_dict can break parameter sharing.
    try:
        if hasattr(model, "lm_head") and hasattr(model, "tok_emb") and model.lm_head.weight is not model.tok_emb.weight:
            model.lm_head.weight = model.tok_emb.weight
    except Exception:
        pass

    model.eval()

    tokenizer = None
    if with_tokenizer:
        tdir = tokenizer_path or cfg_dir
        tokenizer = _load_local_tokenizer(tdir)
        if tokenizer is None:
            print(f"[hrm_loader] WARNING: could not load tokenizer from {tdir}")
        else:
            # Keep embeddings in sync with the tokenizer's vocabulary size.
            try:
                _maybe_resize_embeddings_(model, len(tokenizer))
            except Exception as e:
                print(f"[hrm_loader] embedding resize check failed: {e}")

    return (model, tokenizer) if with_tokenizer else model
|
|
|
|
|
# Public API: the model class and the convenience loader.
__all__ = ["HRMForCausalLM", "load_hrm"]
|
|
|