# Quark_utils / hrm_misc.py
# (HuggingFace upload residue: author "abanm", "Upload 4 files", commit cfef4e2 verified)
# -*- coding: utf-8 -*-
from __future__ import annotations
import os, re, math, json, time, shutil, random
from dataclasses import dataclass, asdict
from typing import Optional, Callable, Dict, Any, Iterable, Tuple, List
from contextlib import nullcontext
# Torch & friends
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
# Transformers / Datasets
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_dataset, DatasetDict
# Optional: Weights & Biases
try:
import wandb # noqa
except Exception:
wandb = None
# =========================================================
# Utils
# =========================================================
def set_seed(seed: int = 1337):
    """Seed the stdlib, NumPy and torch RNGs (incl. all CUDA devices) for reproducibility."""
    import numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def auto_device():
    """Return the preferred torch.device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def format_num(x):
    """Render `x` with 6 significant digits, falling back to str() for values
    that don't support the '.6g' format spec (strings, None, ...).

    Fix: the original bare `except:` also swallowed KeyboardInterrupt and
    SystemExit; only formatting failures should trigger the fallback.
    """
    try:
        return f"{x:.6g}"
    except (TypeError, ValueError):
        return str(x)
def save_safetensors_safe(model: nn.Module, path: str, metadata: Optional[Dict[str, str]] = None):
    """
    Save weights as .safetensors, handling tied weights (lm_head <- tok_emb) when needed.

    First tries safetensors' `save_model`, which understands shared storage.
    If that fails (missing package or save error), falls back to `save_file`
    on a state_dict copy where lm_head.weight is de-duplicated via clone().
    Best-effort: on total failure only a warning is printed.
    """
    meta = metadata or {}
    try:
        from safetensors.torch import save_model  # preserves shared storage & avoids duplication
        save_model(model, path, metadata=meta)
        return
    except Exception:
        pass
    try:
        from safetensors.torch import save_file
        weights = model.state_dict()
        if "lm_head.weight" in weights and "tok_emb.weight" in weights:
            # Break the aliasing so safetensors doesn't reject shared tensors.
            weights["lm_head.weight"] = weights["tok_emb.weight"].clone()
        save_file(weights, path, metadata=meta)
    except Exception as e:
        print("[warn] safetensors not saved:", e)
# =========================================================
# Tokenizer helper
# =========================================================
def _gpt2_tokenizer_with_specials(
    additional: Optional[List[str]] = None,
    checkpoint_or_dir: Optional[str] = None,
) -> AutoTokenizer:
    """
    Build a GPT-2-style fast tokenizer.

    Loads from `checkpoint_or_dir` when given (falling back to 'gpt2' if that
    fails), guarantees EOS and PAD tokens exist (PAD aliases EOS), optionally
    registers additional special tokens, and sets a huge model_max_length so
    long token streams are never truncated.
    """
    tokenizer = None
    if checkpoint_or_dir is not None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint_or_dir, use_fast=True)
        except Exception as e:
            print(f"[warn] Failed to load tokenizer from '{checkpoint_or_dir}': {e}")
            print("[warn] Falling back to 'gpt2' tokenizer.")
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
    # Guarantee EOS, then alias PAD to it when missing.
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "</s>"})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if additional:
        vocab = tokenizer.get_vocab()
        unseen = [t for t in additional if t not in vocab]
        if unseen:
            tokenizer.add_special_tokens({"additional_special_tokens": unseen})
            print(f"[info] Added {len(unseen)} special tokens to tokenizer")
    # Effectively "no limit": keep HF from warning about / truncating long inputs.
    tokenizer.model_max_length = 10_000_000
    tokenizer.init_kwargs["model_max_length"] = tokenizer.model_max_length
    return tokenizer
# =========================================================
# Fixed-block causal dataset
# =========================================================
class CausalChunked(Dataset):
    """Flatten a token stream into non-overlapping fixed-size blocks.

    Each item is a dict whose ``input_ids`` and ``labels`` tensors hold the
    same block (standard causal-LM setup; the model shifts labels internally).
    Tokens that do not fill a complete block are dropped, with an info print.
    """

    def __init__(self, token_ids: Iterable[int], block_size: int):
        stream = list(token_ids)
        usable = (len(stream) // block_size) * block_size
        dropped = len(stream) - usable
        if dropped > 0 and len(stream) > 0:
            pct = dropped / len(stream) * 100
            print(f"[info] Discarded {dropped} tokens ({pct:.2f}%) that didn't fit into complete blocks")
        self.blocks = [stream[start:start + block_size]
                       for start in range(0, usable, block_size)]

    def __len__(self) -> int:
        return len(self.blocks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        ids = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": ids, "labels": ids.clone()}
# =========================================================
# PAD-mask helper (for variable-length batches with padding)
# =========================================================
def mask_pad_labels(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    pad_id: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return a copy of `labels` where padding positions are set to -100
    (the ignore_index of CrossEntropyLoss).

    An `attention_mask` (zeros mark padding) takes precedence; otherwise
    positions where `input_ids == pad_id` are masked. With neither given,
    the clone is returned unchanged. `labels` itself is never mutated.
    Note: CausalChunked produces fixed-length blocks without padding, so
    this only matters for user-supplied padded dataloaders.
    """
    masked = labels.clone()
    if attention_mask is not None:
        masked[attention_mask == 0] = -100
    elif pad_id is not None:
        masked[input_ids == pad_id] = -100
    return masked
# =========================================================
# Dataset loader (HF datasets or local txt files)
# =========================================================
def load_dataset_fn(
    source: str = "hf:lemonilia/wikified_english_dictionary",
    split: str = "train",
    *,
    text_field: Optional[str] = "text",
    word_field: str = "word",
    article_field: str = "article",
    block_size: int = 128,
    batch_size: int = 8,
    num_workers: int = 0,
    shuffle: bool = True,
    checkpoint_or_dir: Optional[str] = None,
    additional_specials: Optional[List[str]] = None,
) -> Tuple[AutoTokenizer, DataLoader, Dict[str, int]]:
    """
    Load and tokenize a dataset for causal LM training. Returns (tokenizer, DataLoader, meta).
    source:
      - 'hf:<name_or_path>' to read a HuggingFace dataset
      - 'txt:/path1;/path2;...' to read local text files (semicolon-separated)
    Behavior:
      • If `text_field` is present, uses it.
      • Else if both `word_field` and `article_field` exist, merges them as:
            "<word>\\n<article>\\n\\n"
        while stripping any <|begin_of_thought|>...<|end_of_thought|> spans.
      • Else, falls back to a 'text' column if available.
      • Appends EOS between docs/files to avoid cross-boundary contamination.

    Raises ValueError for unknown sources/splits/columns or empty corpora,
    FileNotFoundError for missing txt paths.
    """
    tokenizer = _gpt2_tokenizer_with_specials(
        additional=additional_specials,
        checkpoint_or_dir=checkpoint_or_dir,
    )
    eos_id = tokenizer.eos_token_id
    # Flat stream of token ids across all documents, EOS-separated.
    token_stream: List[int] = []
    if source.startswith("hf:"):
        ds_name = source[3:]
        raw = load_dataset(ds_name)
        if split not in raw:
            raise ValueError(f"[error] Split '{split}' not found. Available: {list(raw.keys())}")
        # Columns are inspected on the requested split only.
        # NOTE(review): the .map calls below run over ALL splits of the
        # DatasetDict with this split's columns — confirm other splits share
        # the same schema, otherwise remove_columns may fail.
        cols = raw[split].column_names
        # A) explicit text field
        if text_field is not None and text_field in cols:
            # Bind once so the closure below does not capture a mutating name.
            field_to_use = text_field

            def tok_map(batch):
                return tokenizer(batch[field_to_use], add_special_tokens=False)
            toks = raw.map(tok_map, batched=True, remove_columns=cols)
        # B) merge word+article when requested/needed
        elif (text_field is None or text_field not in cols) and word_field in cols and article_field in cols:
            # Spans wrapped in these markers are dropped before training.
            BEGIN_THOUGHT = re.compile(r"<\|begin_of_thought\|>.*?<\|end_of_thought\|>", re.DOTALL)

            def fmt(batch):
                out = []
                for w, a in zip(batch[word_field], batch[article_field]):
                    w = (w or "").strip()
                    a = re.sub(BEGIN_THOUGHT, "", (a or "")).strip()
                    out.append(w + "\n" + a + "\n\n")
                return {"text": out}
            raw = raw.map(fmt, batched=True)
            # Keep only the merged 'text' column on every split.
            raw = DatasetDict({
                sp: d.remove_columns([c for c in d.column_names if c != "text"])
                for sp, d in raw.items()
            })

            def tok_map(batch):
                return tokenizer(batch["text"], add_special_tokens=False)
            toks = raw.map(tok_map, batched=True, remove_columns=["text"])
        # C) fallback 'text'
        elif "text" in cols:
            def tok_map(batch):
                return tokenizer(batch["text"], add_special_tokens=False)
            toks = raw.map(tok_map, batched=True, remove_columns=cols)
        else:
            raise ValueError(
                f"[error] Could not find a text source.\n"
                f" - Requested text_field={text_field!r}\n"
                f" - Available columns: {cols}\n"
                f" - Set text_field accordingly, or set text_field=None if your dataset has "
                f" both '{word_field}' and '{article_field}' to auto-merge."
            )
        # Flatten documents into one stream, separated by EOS.
        n_empty = 0
        for doc in toks[split]["input_ids"]:
            if not doc:
                n_empty += 1
                continue
            token_stream.extend(doc)
            if eos_id is not None:
                token_stream.append(eos_id)
        if n_empty > 0:
            print(f"[info] Skipped {n_empty} empty documents")
    elif source.startswith("txt:"):
        paths = [p for p in source[4:].split(";") if p]
        if not paths:
            raise ValueError("[error] No file paths provided after 'txt:'")
        for p in paths:
            if not os.path.exists(p):
                raise FileNotFoundError(f"[error] File not found: {p}")
            with open(p, "r", encoding="utf-8") as f:
                text = f.read()
            # Whitespace-only files are silently skipped.
            if text.strip():
                ids = tokenizer(text, add_special_tokens=False)["input_ids"]
                token_stream.extend(ids)
                if eos_id is not None:
                    token_stream.append(eos_id)
    else:
        raise ValueError("[error] source must start with 'hf:' or 'txt:'")
    if not token_stream:
        raise ValueError("[error] No tokens extracted from the source. Check your data.")
    ds = CausalChunked(token_stream, block_size)
    if len(ds) == 0:
        raise ValueError(
            f"[error] Tokenized corpus ({len(token_stream)} tokens) is too small "
            f"for block_size={block_size}. No complete blocks produced."
        )
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
    )
    # Small summary dict for callers (counts, ids).
    meta = {
        "vocab_size": len(tokenizer),
        "eos_id": eos_id,
        "n_blocks": len(ds),
        "n_tokens": len(token_stream),
        "tokens_per_block": block_size,
    }
    print(f"[info] Dataset ready: {meta['n_blocks']} blocks, {meta['n_tokens']} tokens total")
    return tokenizer, loader, meta
# =========================================================
# Trainer
# =========================================================
@dataclass
class TrainConfig:
    """Hyperparameters and bookkeeping options for MiniTRLTrainer."""
    output_dir: str = "outputs/hrm_run"            # where ckpt-* directories are written
    num_epochs: int = 1
    max_steps: Optional[int] = None                # if set, overrides num_epochs
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 1           # micro-batches per optimizer update
    learning_rate: float = 1e-4
    betas: tuple = (0.9, 0.95)                     # AdamW betas
    eps: float = 1e-8                              # AdamW epsilon
    weight_decay: float = 0.01                     # applied only to the "decay" param group
    warmup_ratio: float = 0.06                     # fraction of total updates spent warming up LR
    max_grad_norm: float = 0.5                     # gradient-clipping threshold
    log_every: int = 100                           # steps between console/W&B logs
    save_every: int = 2000                         # steps between periodic checkpoints
    eval_every: int = 2000                         # steps between eval passes
    save_total_limit: int = 3                      # oldest ckpt-* dirs beyond this are pruned
    fp16: bool = False                             # enable fp16 autocast (CUDA only)
    bf16: bool = True                              # prefer bf16 if supported
    seed: int = 1337
    resume_from: Optional[str] = None              # path to checkpoint dir
    early_stopping_patience: Optional[int] = None  # steps without eval improvement
    best_metric: str = "eval/loss"                 # metric key monitored for improvement
    greater_is_better: bool = False                # direction of improvement for best_metric
    torch_compile: bool = False                    # best-effort torch.compile of the model
    # Optional W&B
    wandb_enable: bool = False
    wandb_entity: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_run_name: Optional[str] = None
def _out_get(out: Any, key: str, default=None):
if isinstance(out, dict):
return out.get(key, default)
return getattr(out, key, default)
class MiniTRLTrainer:
    """
    TRL-like supervised trainer:
      Model forward must accept (input_ids, labels) and return something with:
        - loss (required)
        - logits (optional but recommended; used for sanity checks)
        - lm_loss (optional; logged if present)
        - ponder_loss (optional; logged if present)
      DataLoader must yield dicts with keys:
        - "input_ids" and (optionally) "labels". If "labels" missing, labels=input_ids.
        - If you pad to fixed length externally, also pass "attention_mask" so we can mask pad tokens.
    """
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        tokenizer: Optional[AutoTokenizer] = None,
        eval_loader: Optional[DataLoader] = None,
        # NOTE(review): mutable default — one TrainConfig instance is shared by
        # every construction that omits `config`; confirm nobody mutates it.
        config: TrainConfig = TrainConfig(),
        compute_metrics: Optional[Callable[[Dict[str, float]], Dict[str, float]]] = None,
        custom_loss_fn: Optional[Callable[[Any], torch.Tensor]] = None,  # receives model outputs
        device: Optional[torch.device] = None,
    ):
        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.tok = tokenizer
        self.cfg = config
        self.compute_metrics = compute_metrics
        self.custom_loss_fn = custom_loss_fn
        self.device = device or auto_device()
        set_seed(self.cfg.seed)
        self.model.to(self.device)
        # Best-effort compilation; failures fall back to eager mode.
        if self.cfg.torch_compile:
            try:
                self.model = torch.compile(self.model)
            except Exception as e:
                print("[warn] torch.compile failed:", e)
        # AMP dtype: bf16 preferred when supported, else fp16 when requested;
        # non-CUDA devices run full precision.
        if self.device.type == "cuda":
            self.amp_dtype = torch.bfloat16 if (self.cfg.bf16 and torch.cuda.is_bf16_supported()) else (torch.float16 if self.cfg.fp16 else None)
        else:
            self.amp_dtype = None
        # Param groups with/without weight decay: 1-D tensors, norms, biases
        # and the token embedding are excluded from decay.
        decay, no_decay = [], []
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            nl = n.lower()
            if p.ndim == 1 or "norm" in nl or "bias" in nl or ("tok_emb.weight" in n):
                no_decay.append(p)
            else:
                decay.append(p)
        self.optimizer = torch.optim.AdamW(
            [{"params": decay, "weight_decay": self.cfg.weight_decay},
             {"params": no_decay, "weight_decay": 0.0}],
            lr=self.cfg.learning_rate, betas=self.cfg.betas, eps=self.cfg.eps
        )
        # Scheduler sized from the planned number of optimizer updates.
        steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps))
        total_updates = self.cfg.max_steps if self.cfg.max_steps is not None else self.cfg.num_epochs * max(1, steps_per_epoch)
        total_updates = max(1, total_updates)  # guard
        warmup_steps = int(self.cfg.warmup_ratio * total_updates)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, warmup_steps, total_updates)
        # GradScaler only for fp16 (bf16 needs no loss scaling).
        # NOTE(review): torch.cuda.amp.GradScaler is the legacy spelling; newer
        # torch prefers torch.amp.GradScaler("cuda", ...) — confirm torch version.
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.amp_dtype == torch.float16 and self.device.type == "cuda"))
        # State
        self.global_step = 0
        self.best_metric_val = float("-inf") if self.cfg.greater_is_better else float("inf")
        self.no_improve_steps = 0
        os.makedirs(self.cfg.output_dir, exist_ok=True)
        self._maybe_resume()
        # W&B
        self._wandb_run = None
        if self.cfg.wandb_enable:
            if wandb is None:
                print("[warn] wandb_enable=True but wandb is not installed; proceeding without W&B.")
            else:
                self._wandb_run = wandb.init(
                    entity=self.cfg.wandb_entity,
                    project=self.cfg.wandb_project or "hrm",
                    name=self.cfg.wandb_run_name,
                    config=asdict(self.cfg),
                )
    # -------------------------- public API --------------------------
    def train(self):
        """Run the training loop until max_steps / num_epochs is reached, with
        periodic logging, evaluation, early stopping and checkpointing."""
        self.model.train()
        log_acc_loss = 0.0
        log_acc_tokens = 0
        t0 = time.time()
        max_updates = self.cfg.max_steps
        if max_updates is None:
            steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps))
            max_updates = self.cfg.num_epochs * max(1, steps_per_epoch)
        pbar = tqdm(total=max_updates, initial=self.global_step, desc="Training", dynamic_ncols=True)
        while self.global_step < max_updates:
            for batch in self.train_loader:
                if self.global_step >= max_updates:
                    break
                input_ids = batch["input_ids"].to(self.device)
                labels = batch.get("labels", input_ids).to(self.device)
                # Mask pads only if attention_mask/pad present
                pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else (
                    getattr(getattr(self.model, "config", None), "pad_token_id", None)
                )
                attn = batch.get("attention_mask", None)
                attn = attn.to(self.device) if attn is not None else None
                labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn)
                # Autocast on CUDA/MPS when an AMP dtype was selected.
                ctx = (torch.autocast(device_type=self.device.type, dtype=self.amp_dtype)
                       if (self.amp_dtype is not None and self.device.type in ("cuda", "mps"))
                       else nullcontext())
                with ctx:
                    out = self.model(input_ids=input_ids, labels=labels)
                    loss = _out_get(out, "loss")
                    if self.custom_loss_fn is not None:
                        loss = self.custom_loss_fn(out)
                    # Scale down so accumulated gradients average over micro-batches.
                    loss = loss / max(1, self.cfg.gradient_accumulation_steps)
                # Fail fast on numerical blow-ups (cheaper to debug near the source).
                logits = _out_get(out, "logits", None)
                if logits is not None:
                    if not torch.isfinite(logits).all():
                        mx = logits.detach().float().abs().max().item()
                        raise FloatingPointError(f"logits non-finite (max|logit|={mx})")
                if not torch.isfinite(loss):
                    lmax = (logits.detach().float().abs().max().item() if logits is not None else float("nan"))
                    print(f"[dbg] non-finite loss; max|logit|={lmax}, lm={_out_get(out,'lm_loss')}, ponder={_out_get(out,'ponder_loss')}")
                    raise FloatingPointError("Loss became non-finite.")
                if self.scaler.is_enabled():
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()
                # NOTE(review): global_step advances per micro-batch; with
                # gradient_accumulation_steps > 1 it therefore counts
                # micro-batches while max_updates counts optimizer updates —
                # confirm this mismatch is intended.
                do_step = ((self.global_step + 1) % self.cfg.gradient_accumulation_steps == 0)
                if do_step:
                    if self.scaler.is_enabled():
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)
                    if self.scaler.is_enabled():
                        prev_scale = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        # A shrinking scale means the step was skipped (inf/nan
                        # grads); only advance the LR schedule on real steps.
                        if self.scaler.get_scale() >= prev_scale:
                            self.scheduler.step()
                    else:
                        self.optimizer.step()
                        self.scheduler.step()
                    self.optimizer.zero_grad(set_to_none=True)
                self.global_step += 1
                pbar.update(1)
                # Logging accumulators (token-weighted). Count only non-pad tokens.
                tokens = int((labels != -100).sum().item())
                lm_for_log = _out_get(out, "lm_loss", loss.detach())
                log_acc_loss += float(lm_for_log) * max(1, tokens)
                log_acc_tokens += max(1, tokens)
                if self.global_step % max(1, self.cfg.log_every) == 0:
                    avg_loss = log_acc_loss / max(1, log_acc_tokens)
                    msg = {
                        "step": self.global_step,
                        "lr": self.scheduler.get_last_lr()[0],
                        "avg_lm_loss": avg_loss,
                        "ppl~": math.exp(min(20.0, avg_loss)),  # clamp exponent to avoid overflow
                        "ponder": (_out_get(out, "ponder_loss", None)),
                        "elapsed_s": int(time.time() - t0),
                    }
                    tqdm.write("[log] " + ", ".join(f"{k}={format_num(v)}" for k, v in msg.items() if v is not None))
                    if self._wandb_run is not None:
                        self._wandb_run.log({k: v for k, v in msg.items() if isinstance(v, (int, float))})
                    log_acc_loss = 0.0
                    log_acc_tokens = 0
                # Eval / early stop
                if self.eval_loader and self.global_step % max(1, self.cfg.eval_every) == 0:
                    metrics = self.evaluate()
                    improved = self._check_improve(metrics[self.cfg.best_metric])
                    if self._wandb_run is not None:
                        self._wandb_run.log(metrics)
                    if self.cfg.early_stopping_patience is not None:
                        if improved:
                            self.no_improve_steps = 0
                        else:
                            self.no_improve_steps += self.cfg.eval_every
                        if self.no_improve_steps >= self.cfg.early_stopping_patience:
                            tqdm.write("[early-stop] patience exhausted.")
                            self._save_checkpoint(tag="final")
                            pbar.close()
                            return
                if self.global_step % max(1, self.cfg.save_every) == 0:
                    self._save_checkpoint()
        pbar.close()
        self._save_checkpoint(tag="final")
    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Compute token-weighted eval loss/perplexity (plus mean ponder loss
        when the model reports one); restores train() mode before returning."""
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        total_ponder = 0.0
        n_batches = 0
        for batch in tqdm(self.eval_loader, desc="Eval", leave=False):
            input_ids = batch["input_ids"].to(self.device)
            labels = batch.get("labels", input_ids).to(self.device)
            pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else (
                getattr(getattr(self.model, "config", None), "pad_token_id", None)
            )
            attn = batch.get("attention_mask", None)
            attn = attn.to(self.device) if attn is not None else None
            labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn)
            out = self.model(input_ids=input_ids, labels=labels)
            lm = float(_out_get(out, "lm_loss", _out_get(out, "loss")))
            tokens = int((labels != -100).sum().item())
            total_loss += lm * max(1, tokens)
            total_tokens += max(1, tokens)
            pl = _out_get(out, "ponder_loss", None)
            if pl is not None:
                total_ponder += float(pl)
            n_batches += 1
        avg_loss = total_loss / max(1, total_tokens)
        ppl = math.exp(min(20.0, avg_loss))  # clamp exponent to avoid overflow
        avg_ponder = (total_ponder / max(1, n_batches)) if n_batches > 0 else float("nan")
        metrics = {"eval/loss": avg_loss, "eval/ppl": ppl, "eval/ponder": avg_ponder, "step": self.global_step}
        tqdm.write("[eval] " + ", ".join(f"{k}={format_num(v)}" for k, v in metrics.items()))
        self.model.train()
        return metrics
    # -------------------------- checkpoints --------------------------
    def _save_checkpoint(self, tag: Optional[str] = None):
        """Write a resumable checkpoint (trainer_state.pt + model.safetensors +
        config.json) into output_dir/ckpt-<tag>, then prune old checkpoints."""
        tag = tag or f"step{self.global_step}"
        ckpt_dir = os.path.join(self.cfg.output_dir, f"ckpt-{tag}")
        os.makedirs(ckpt_dir, exist_ok=True)
        # trainer state (resumable)
        torch.save({
            "model": self.model.state_dict(),
            "opt": self.optimizer.state_dict(),
            "sched": self.scheduler.state_dict(),
            "scaler": (self.scaler.state_dict() if self.scaler.is_enabled() else None),
            "global_step": self.global_step,
            "config": asdict(self.cfg),
        }, os.path.join(ckpt_dir, "trainer_state.pt"))
        # weights-only safetensors + minimal config
        save_safetensors_safe(self.model, os.path.join(ckpt_dir, "model.safetensors"),
                              metadata={"note": "MiniTRLTrainer save", "global_step": str(self.global_step)})
        with open(os.path.join(ckpt_dir, "config.json"), "w") as f:
            json.dump({"global_step": self.global_step, **asdict(self.cfg)}, f, indent=2)
        self._prune_checkpoints()
    def _prune_checkpoints(self):
        """Delete the oldest ckpt-* directories beyond save_total_limit (by mtime)."""
        if self.cfg.save_total_limit is None:
            return
        subs = [d for d in os.listdir(self.cfg.output_dir) if d.startswith("ckpt-")]
        if len(subs) <= self.cfg.save_total_limit:
            return
        subs = sorted(subs, key=lambda s: os.path.getmtime(os.path.join(self.cfg.output_dir, s)))
        for d in subs[:-self.cfg.save_total_limit]:
            shutil.rmtree(os.path.join(self.cfg.output_dir, d), ignore_errors=True)
    def _maybe_resume(self):
        """Restore model/optimizer/scheduler/scaler/step from cfg.resume_from, if set."""
        if not self.cfg.resume_from:
            return
        state_path = os.path.join(self.cfg.resume_from, "trainer_state.pt")
        if not os.path.exists(state_path):
            print(f"[resume] not found: {state_path}")
            return
        ckpt = torch.load(state_path, map_location="cpu")
        # strict=False tolerates architecture drift between runs.
        self.model.load_state_dict(ckpt["model"], strict=False)
        self.optimizer.load_state_dict(ckpt["opt"])
        self.scheduler.load_state_dict(ckpt["sched"])
        if ckpt.get("scaler") and self.scaler.is_enabled():
            self.scaler.load_state_dict(ckpt["scaler"])
        self.global_step = int(ckpt.get("global_step", 0))
        print(f"[resume] loaded from {self.cfg.resume_from} @ step {self.global_step}")
    def _check_improve(self, val: float) -> bool:
        """Update best_metric_val; return True when `val` improved on it."""
        improved = (val > self.best_metric_val) if self.cfg.greater_is_better else (val < self.best_metric_val)
        if improved:
            self.best_metric_val = val
        return improved
# =========================================================
# Checkpoint helpers (complete save)
# =========================================================
def _state_dict_for_safetensors(model):
"""
Build a CPU state_dict suitable for safetensors.
If lm_head.weight is tied to tok_emb.weight, omit lm_head.weight to avoid duplicate storage.
"""
tied = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and (
getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None)
)
sd_cpu = {k: v.detach().cpu() for k, v in model.state_dict().items()}
if tied and "lm_head.weight" in sd_cpu:
sd_cpu.pop("lm_head.weight")
return sd_cpu, tied
def retie_output_embedding(model):
    """
    Restore weight tying (output head sharing the input embedding matrix)
    after loading weights, when the model exposes the standard
    get_input_embeddings()/get_output_embeddings() accessors.
    """
    if not (hasattr(model, "get_input_embeddings") and hasattr(model, "get_output_embeddings")):
        return
    inp = model.get_input_embeddings()
    out = model.get_output_embeddings()
    if inp is None or out is None:
        return
    # Only rebind when the two weights currently live in different storage.
    if out.weight.data_ptr() != inp.weight.data_ptr():
        out.weight = inp.weight
def _chain_get(obj, attrs, default=None):
"""
Safe chained getattr: _chain_get(model, ["L_mod", "attn", "num_heads"], default=None)
"""
cur = obj
for a in attrs:
if not hasattr(cur, a):
return default
cur = getattr(cur, a)
return cur
def save_model_complete(model, save_dir, tokenizer=None, training_args=None, metadata=None):
    """
    Save model with all details: weights (.pt + .safetensors), config, architecture,
    parameter summaries, tokenizer (optional), and a README.

    Args:
        model: torch module to export (introspected best-effort for config fields).
        save_dir: destination directory (created if missing).
        tokenizer: optional HF tokenizer saved via save_pretrained.
        training_args: optional dict; recognized keys (optimizer_state,
            scheduler_state, epoch, global_step) go into model.pt, the rest is
            JSON-serialized (non-JSON values stringified) into training_args.json.
        metadata: optional dict merged into metadata.json.
    Returns: save_dir
    """
    os.makedirs(save_dir, exist_ok=True)
    from datetime import datetime
    from collections import OrderedDict
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"Saving model to: {save_dir}")
    # 1) Weights (.pt)
    print("1. Saving model weights (.pt)...")
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "timestamp": timestamp,
    }
    if training_args and "optimizer_state" in training_args:
        checkpoint["optimizer_state_dict"] = training_args["optimizer_state"]
    if training_args and "scheduler_state" in training_args:
        checkpoint["scheduler_state_dict"] = training_args["scheduler_state"]
    if training_args and "epoch" in training_args:
        checkpoint["epoch"] = training_args["epoch"]
    if training_args and "global_step" in training_args:
        checkpoint["global_step"] = training_args["global_step"]
    torch.save(checkpoint, os.path.join(save_dir, "model.pt"))
    print(" ✓ Saved: model.pt")
    # 1b) Weights (.safetensors) — best effort; skipped when the package is absent.
    print("1b. Saving model weights (.safetensors)...")
    try:
        from safetensors.torch import save_file
        sd_cpu, tied = _state_dict_for_safetensors(model)
        save_file(sd_cpu, os.path.join(save_dir, "model.safetensors"))
        if tied:
            print(" ℹ Weight tying detected: excluded lm_head.weight (re-tie on load).")
        print(" ✓ Saved: model.safetensors")
    except ImportError:
        print(" ⚠ safetensors not installed, skipping .safetensors format")
    except Exception as e:
        print(f" ⚠ Could not save safetensors: {e}")
    # 2) Save a minimal config (best-effort introspection)
    print("2. Saving model config...")
    vocab_size = getattr(model, "vocab_size", None)
    d_model = getattr(model, "d_model", None)
    n_heads = _chain_get(model, ["L_mod", "attn", "num_heads"], default=None)
    d_ff = _chain_get(model, ["L_mod", "mlp", "w1", "out_features"], default=None)
    dropout = _chain_get(model, ["L_mod", "drop", "p"], default=None)
    k_l_steps = getattr(model, "k_l_steps", None)
    max_cycles = getattr(model, "max_cycles", None)
    ponder_w = getattr(model, "ponder_w", None)
    has_out_norm = hasattr(model, "out_norm")
    tied_flag = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and (
        getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None)
    )
    config = {
        "model_type": type(model).__name__,
        "vocab_size": vocab_size,
        "d_model": d_model,
        "n_heads": n_heads,
        "d_ff": d_ff,
        "dropout": dropout,
        "k_l_steps": k_l_steps,
        "max_cycles": max_cycles,
        "ponder_loss_weight": ponder_w,
        "has_out_norm": has_out_norm,
        "weight_tying": tied_flag,
        "tie_word_embeddings": tied_flag,
    }
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
    print(" ✓ Saved: config.json")
    # 3) Architecture string
    print("3. Saving model architecture...")
    with open(os.path.join(save_dir, "architecture.txt"), "w") as f:
        f.write(str(model))
    print(" ✓ Saved: architecture.txt")
    # 4) Parameter details
    print("4. Saving parameter details...")
    param_info = []
    total_params = 0
    trainable_params = 0
    for name, p in model.named_parameters():
        n = p.numel()
        total_params += n
        if p.requires_grad:
            trainable_params += n
        param_info.append({
            "name": name,
            "shape": list(p.shape),
            "dtype": str(p.dtype),
            "requires_grad": p.requires_grad,
            "num_params": n,
            "device": str(p.device),
        })
    param_summary = {
        "total_parameters": total_params,
        "trainable_parameters": trainable_params,
        "non_trainable_parameters": total_params - trainable_params,
        "size_mb": total_params * 4 / (1024 ** 2),  # float32 estimate
        "parameters": param_info,
    }
    with open(os.path.join(save_dir, "parameters.json"), "w") as f:
        json.dump(param_summary, f, indent=2)
    print(" ✓ Saved: parameters.json")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable: {trainable_params:,}")
    print(f" Model size: {total_params * 4 / (1024**2):.2f} MB")
    # 5) Layer-wise breakdown (top-level children)
    print("5. Saving layer-wise breakdown...")
    from collections import OrderedDict
    layer_params = OrderedDict()
    for name, module in model.named_children():
        num_params = sum(p.numel() for p in module.parameters())
        layer_params[name] = {
            "num_params": num_params,
            "percentage": 100 * num_params / total_params if total_params > 0 else 0,
        }
    with open(os.path.join(save_dir, "layer_params.json"), "w") as f:
        json.dump(layer_params, f, indent=2)
    print(" ✓ Saved: layer_params.json")
    # 6) Training args (if provided)
    if training_args:
        print("6. Saving training arguments...")
        serializable_args = {}
        for k, v in training_args.items():
            # Non-JSON values (tensors, optimizers, ...) are stringified.
            if isinstance(v, (int, float, str, bool, list, dict, type(None))):
                serializable_args[k] = v
            else:
                serializable_args[k] = str(v)
        with open(os.path.join(save_dir, "training_args.json"), "w") as f:
            json.dump(serializable_args, f, indent=2)
        print(" ✓ Saved: training_args.json")
    # 7) Metadata
    print("7. Saving metadata...")
    metadata_full = {
        "timestamp": timestamp,
        "pytorch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
        "device": str(next(model.parameters()).device),
        "dtype": str(next(model.parameters()).dtype),
    }
    if metadata:
        metadata_full.update(metadata)
    with open(os.path.join(save_dir, "metadata.json"), "w") as f:
        json.dump(metadata_full, f, indent=2)
    print(" ✓ Saved: metadata.json")
    # 8) Tokenizer (optional)
    if tokenizer is not None:
        print("8. Saving tokenizer...")
        try:
            tokenizer.save_pretrained(save_dir)
            print(" ✓ Saved tokenizer files")
        except Exception as e:
            print(f" ⚠ Could not save tokenizer: {e}")
    # 9) README
    print("9. Creating README...")
    readme_content = f"""# HRM/LM Model Checkpoint
## Model Information
- **Model Type**: {config['model_type']}
- **Timestamp**: {timestamp}
- **Total Parameters**: {total_params:,}
- **Trainable Parameters**: {trainable_params:,}
- **Model Size**: {total_params * 4 / (1024**2):.2f} MB
## Architecture (best-effort introspection)
- **Vocabulary Size**: {config['vocab_size']}
- **Hidden Dimension**: {config['d_model']}
- **Attention Heads**: {config['n_heads']}
- **FFN Dimension**: {config['d_ff']}
- **Dropout**: {config['dropout']}
- **L-mod Steps**: {config['k_l_steps']}
- **Max Cycles**: {config['max_cycles']}
- **Has Output Norm**: {config['has_out_norm']}
- **Weight Tying**: {config['weight_tying']} (tok_emb ↔ lm_head)
## Files
- `model.pt` — Full checkpoint (PyTorch)
- `model.safetensors` — Safetensors (excludes lm_head if tied)
- `config.json` — Model configuration summary
- `architecture.txt` — Stringified architecture
- `parameters.json` — Parameter metadata
- `layer_params.json` — Layer-wise parameter counts
- `training_args.json` — Training hyperparameters (if provided)
- `metadata.json` — Environment/device metadata
- Tokenizer files (if provided)
"""
    with open(os.path.join(save_dir, 'README.md'), 'w') as f:
        f.write(readme_content)
    print(f" ✓ Saved: README.md")
    print("\n" + "="*60)
    print("SAVE COMPLETE!")
    print("="*60)
    print(f"Location: {save_dir}")
    print(f"Files saved: {len(os.listdir(save_dir))}")
    print("\nSummary:")
    print(f" Total parameters: {total_params:,}")
    print(f" Model size: {total_params * 4 / (1024**2):.2f} MB")
    print(f" Config saved: ✓")
    print(f" Weights saved: ✓")
    print(f" Tokenizer saved: {'✓' if tokenizer else '✗'}")
    print("="*60)
    return save_dir
# =========================================================
# Minimal CLI (dynamic model loading via --load-fn module:function)
# =========================================================
def _load_via_callable(load_fn: str, **kwargs):
"""
load_fn: 'module.submodule:function_name' (e.g., 'hrm_utils:load_hrm')
kwargs: forwarded to the function
"""
if ":" not in load_fn:
raise ValueError("load_fn must look like 'module.submodule:function_name'")
mod_name, fn_name = load_fn.split(":", 1)
import importlib
mod = importlib.import_module(mod_name)
fn = getattr(mod, fn_name)
return fn(**kwargs)
def main():
    """CLI entry point with three subcommands:
      prepare — tokenize a source and smoke-test the dataloader,
      train   — dynamically load a model (via --load-fn) and run MiniTRLTrainer,
      save    — dynamically load a model and export a fully-documented checkpoint.
    """
    import argparse
    p = argparse.ArgumentParser(description="All-in-one HRM/LM data + trainer + checkpointing")
    sub = p.add_subparsers(dest="cmd", required=True)
    # prepare data
    sp = sub.add_parser("prepare", help="Tokenize and build a quick dataloader")
    sp.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary")
    sp.add_argument("--split", default="train")
    sp.add_argument("--text-field", default="text")
    sp.add_argument("--block-size", type=int, default=128)
    sp.add_argument("--batch-size", type=int, default=8)
    sp.add_argument("--tokenizer-dir", default=None)
    # train
    st = sub.add_parser("train", help="Train a model via dynamic loader")
    st.add_argument("--load-fn", required=True, help="module:function (e.g. hrm_utils:load_hrm)")
    st.add_argument("--load-args", default="{}", help="JSON dict of kwargs to pass to load-fn (e.g. '{\"name\":\"hrm_v0.04\",\"device\":\"cuda\",\"with_tokenizer\":true}')")
    st.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary")
    st.add_argument("--split", default="train")
    st.add_argument("--text-field", default="text")
    st.add_argument("--block-size", type=int, default=128)
    st.add_argument("--batch-size", type=int, default=8)
    st.add_argument("--epochs", type=int, default=1)
    st.add_argument("--lr", type=float, default=1e-4)
    st.add_argument("--output-dir", default="outputs/hrm_run")
    st.add_argument("--wandb", action="store_true")
    st.add_argument("--wandb-entity", default=None)
    st.add_argument("--wandb-project", default=None)
    st.add_argument("--wandb-run-name", default=None)
    # save checkpoint (complete)
    ss = sub.add_parser("save", help="Save a fully-documented checkpoint for an already-loaded model")
    ss.add_argument("--load-fn", required=True)
    ss.add_argument("--load-args", default="{}")
    ss.add_argument("--save-dir", default="saved_models/hrm_export")
    ss.add_argument("--with-tokenizer", action="store_true")
    args = p.parse_args()
    if args.cmd == "prepare":
        tok, loader, meta = load_dataset_fn(
            source=args.source,
            split=args.split,
            text_field=args.text_field,
            block_size=args.block_size,
            batch_size=args.batch_size,
            checkpoint_or_dir=args.tokenizer_dir,
        )
        print("[ok] Prepared one pass through dataloader:")
        # Peek at the first few batches to confirm shapes look sane.
        for i, b in enumerate(loader):
            print(" batch", i, {k: v.shape for k, v in b.items()})
            if i > 2:
                break
    elif args.cmd == "train":
        load_kwargs = json.loads(args.load_args or "{}")
        obj = _load_via_callable(args.load_fn, **load_kwargs)
        # The loader may return (model, tokenizer, ...) or just the model.
        if isinstance(obj, tuple) and len(obj) >= 2:
            model, tokenizer = obj[0], obj[1]
        else:
            # assume loader returns just model; tokenizer is optional/None
            model, tokenizer = obj, None
        tok, train_loader, _ = load_dataset_fn(
            source=args.source,
            split=args.split,
            text_field=args.text_field,
            block_size=args.block_size,
            batch_size=args.batch_size,
            checkpoint_or_dir=(tokenizer.name_or_path if tokenizer is not None else None),
        )
        # Prefer the loader-supplied tokenizer, else the freshly built one.
        tokenizer = tokenizer or tok
        cfg = TrainConfig(
            output_dir=args.output_dir,
            num_epochs=args.epochs,
            learning_rate=args.lr,
            wandb_enable=bool(args.wandb),
            wandb_entity=args.wandb_entity,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
        )
        trainer = MiniTRLTrainer(
            model=model,
            train_loader=train_loader,
            tokenizer=tokenizer,
            eval_loader=None,  # plug one in if you want
            config=cfg,
        )
        trainer.train()
    elif args.cmd == "save":
        load_kwargs = json.loads(args.load_args or "{}")
        obj = _load_via_callable(args.load_fn, **load_kwargs)
        if isinstance(obj, tuple) and len(obj) >= 2:
            model, tokenizer = obj[0], obj[1]
        else:
            model, tokenizer = obj, None
        save_model_complete(model, args.save_dir, tokenizer=(tokenizer if args.with_tokenizer else None))
if __name__ == "__main__":
    main()