|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os, re, math, json, time, shutil, random |
|
|
from dataclasses import dataclass, asdict |
|
|
from typing import Optional, Callable, Dict, Any, Iterable, Tuple, List |
|
|
from contextlib import nullcontext |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
from transformers import AutoTokenizer, get_linear_schedule_with_warmup |
|
|
from datasets import load_dataset, DatasetDict |
|
|
|
|
|
|
|
|
try: |
|
|
import wandb |
|
|
except Exception: |
|
|
wandb = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: int = 1337):
    """Seed Python's, NumPy's, and torch's RNGs (incl. all CUDA devices) for reproducibility."""
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
def auto_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
|
|
|
|
|
|
|
def format_num(x):
    """Format `x` with up to 6 significant digits; fall back to str() for values
    that don't support float-style formatting (e.g. strings, None, dicts).

    Fix: the original used a bare `except:`, which also swallows
    KeyboardInterrupt/SystemExit; only the format failures are caught now.
    """
    try:
        return f"{x:.6g}"
    except (ValueError, TypeError):
        # str doesn't accept the 'g' format code (ValueError); objects without
        # __format__ float support raise TypeError.
        return str(x)
|
|
|
|
|
|
|
|
def save_safetensors_safe(model: nn.Module, path: str, metadata: Optional[Dict[str, str]] = None):
    """
    Save weights as .safetensors, handling tied weights (lm_head <- tok_emb) when needed.

    Strategy: prefer `safetensors.torch.save_model` (handles shared tensors);
    if that fails for any reason, fall back to `save_file` on a state_dict
    where a tied lm_head weight is cloned so no storage is shared. If both
    attempts fail (e.g. safetensors not installed), warn and continue.
    """
    meta = metadata or {}
    try:
        from safetensors.torch import save_model
        save_model(model, path, metadata=meta)
        return
    except Exception:
        pass

    try:
        from safetensors.torch import save_file
        state = model.state_dict()
        if "lm_head.weight" in state and "tok_emb.weight" in state:
            # Break the tie by cloning so save_file never sees shared storage.
            state["lm_head.weight"] = state["tok_emb.weight"].clone()
        save_file(state, path, metadata=meta)
    except Exception as e:
        print("[warn] safetensors not saved:", e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _gpt2_tokenizer_with_specials(
    additional: Optional[List[str]] = None,
    checkpoint_or_dir: Optional[str] = None,
) -> AutoTokenizer:
    """
    Build the training tokenizer.

    Loads from `checkpoint_or_dir` when given (falling back to 'gpt2' on any
    load failure), guarantees EOS and PAD exist (PAD aliases EOS), registers
    any extra special tokens not already in the vocabulary, and raises
    model_max_length so long corpora never trigger truncation.
    """
    tokenizer = None
    if checkpoint_or_dir is not None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(checkpoint_or_dir, use_fast=True)
        except Exception as e:
            print(f"[warn] Failed to load tokenizer from '{checkpoint_or_dir}': {e}")
            print("[warn] Falling back to 'gpt2' tokenizer.")

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

    # Guarantee an EOS token, then alias PAD to it if no PAD exists.
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "</s>"})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if additional:
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in additional if t not in vocab]
        if new_tokens:
            tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
            print(f"[info] Added {len(new_tokens)} special tokens to tokenizer")

    # Effectively disable the tokenizer's length-based truncation/warnings.
    tokenizer.model_max_length = 10_000_000
    tokenizer.init_kwargs["model_max_length"] = tokenizer.model_max_length
    return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CausalChunked(Dataset):
    """Flatten tokens then slice into non-overlapping blocks; x == labels."""

    def __init__(self, token_ids: Iterable[int], block_size: int):
        ids = list(token_ids)
        n_blocks = len(ids) // block_size
        n_full = n_blocks * block_size
        n_discarded = len(ids) - n_full

        # Report the tail that cannot fill a complete block.
        if n_discarded > 0 and len(ids) > 0:
            pct = n_discarded / len(ids) * 100
            print(f"[info] Discarded {n_discarded} tokens ({pct:.2f}%) that didn't fit into complete blocks")

        self.blocks = []
        for start in range(0, n_full, block_size):
            self.blocks.append(ids[start:start + block_size])

    def __len__(self) -> int:
        return len(self.blocks)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Labels are an independent copy so downstream masking can't mutate inputs.
        block = torch.tensor(self.blocks[idx], dtype=torch.long)
        return {"input_ids": block, "labels": block.clone()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mask_pad_labels(
    input_ids: torch.Tensor,
    labels: torch.Tensor,
    pad_id: Optional[int] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Return a copy of `labels` with padding positions set to -100 (the value
    CrossEntropyLoss ignores). An `attention_mask` takes precedence; otherwise
    positions where `input_ids == pad_id` are masked. With neither provided,
    the labels are returned unchanged (still a copy).

    Note: CausalChunked emits fixed-length, padding-free blocks, so this only
    matters when a caller supplies a padded dataloader of their own.
    """
    masked = labels.clone()
    if attention_mask is not None:
        masked[attention_mask == 0] = -100
    elif pad_id is not None:
        masked[input_ids == pad_id] = -100
    return masked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_dataset_fn(
    source: str = "hf:lemonilia/wikified_english_dictionary",
    split: str = "train",
    *,
    text_field: Optional[str] = "text",
    word_field: str = "word",
    article_field: str = "article",
    block_size: int = 128,
    batch_size: int = 8,
    num_workers: int = 0,
    shuffle: bool = True,
    checkpoint_or_dir: Optional[str] = None,
    additional_specials: Optional[List[str]] = None,
) -> Tuple[AutoTokenizer, DataLoader, Dict[str, int]]:
    """
    Load and tokenize a dataset for causal LM training. Returns (tokenizer, DataLoader, meta).

    source:
      - 'hf:<name_or_path>' to read a HuggingFace dataset
      - 'txt:/path1;/path2;...' to read local text files (semicolon-separated)

    Behavior:
      • If `text_field` is present, uses it.
      • Else if both `word_field` and `article_field` exist, merges them as:
          "<word>\\n<article>\\n\\n"
        while stripping any <|begin_of_thought|>...<|end_of_thought|> spans.
      • Else, falls back to a 'text' column if available.
      • Appends EOS between docs/files to avoid cross-boundary contamination.

    Raises:
      ValueError / FileNotFoundError on missing splits, columns, paths, or
      when the tokenized corpus produces no complete block of `block_size`.
    """
    # Tokenizer first: its vocab/EOS drive everything below.
    tokenizer = _gpt2_tokenizer_with_specials(
        additional=additional_specials,
        checkpoint_or_dir=checkpoint_or_dir,
    )
    eos_id = tokenizer.eos_token_id
    # Single flat stream of token ids; documents are separated by EOS.
    token_stream: List[int] = []

    if source.startswith("hf:"):
        ds_name = source[3:]
        raw = load_dataset(ds_name)
        if split not in raw:
            raise ValueError(f"[error] Split '{split}' not found. Available: {list(raw.keys())}")

        cols = raw[split].column_names

        # Case 1: the requested text column exists — tokenize it directly.
        if text_field is not None and text_field in cols:
            field_to_use = text_field

            def tok_map(batch):
                return tokenizer(batch[field_to_use], add_special_tokens=False)

            toks = raw.map(tok_map, batched=True, remove_columns=cols)

        # Case 2: dictionary-style data — merge word + article into one text
        # column, stripping chain-of-thought spans first.
        elif (text_field is None or text_field not in cols) and word_field in cols and article_field in cols:
            BEGIN_THOUGHT = re.compile(r"<\|begin_of_thought\|>.*?<\|end_of_thought\|>", re.DOTALL)

            def fmt(batch):
                out = []
                for w, a in zip(batch[word_field], batch[article_field]):
                    w = (w or "").strip()
                    a = re.sub(BEGIN_THOUGHT, "", (a or "")).strip()
                    out.append(w + "\n" + a + "\n\n")
                return {"text": out}

            raw = raw.map(fmt, batched=True)
            # Keep only the synthesized 'text' column in every split.
            raw = DatasetDict({
                sp: d.remove_columns([c for c in d.column_names if c != "text"])
                for sp, d in raw.items()
            })

            def tok_map(batch):
                return tokenizer(batch["text"], add_special_tokens=False)

            toks = raw.map(tok_map, batched=True, remove_columns=["text"])

        # Case 3: fall back to a plain 'text' column if it happens to exist.
        elif "text" in cols:
            def tok_map(batch):
                return tokenizer(batch["text"], add_special_tokens=False)

            toks = raw.map(tok_map, batched=True, remove_columns=cols)

        else:
            raise ValueError(
                f"[error] Could not find a text source.\n"
                f" - Requested text_field={text_field!r}\n"
                f" - Available columns: {cols}\n"
                f" - Set text_field accordingly, or set text_field=None if your dataset has "
                f" both '{word_field}' and '{article_field}' to auto-merge."
            )

        # Flatten per-document token lists into one stream, EOS-delimited.
        n_empty = 0
        for doc in toks[split]["input_ids"]:
            if not doc:
                n_empty += 1
                continue
            token_stream.extend(doc)
            if eos_id is not None:
                token_stream.append(eos_id)

        if n_empty > 0:
            print(f"[info] Skipped {n_empty} empty documents")

    elif source.startswith("txt:"):
        paths = [p for p in source[4:].split(";") if p]
        if not paths:
            raise ValueError("[error] No file paths provided after 'txt:'")

        for p in paths:
            if not os.path.exists(p):
                raise FileNotFoundError(f"[error] File not found: {p}")
            with open(p, "r", encoding="utf-8") as f:
                text = f.read()
            # Whitespace-only files contribute nothing (not even an EOS).
            if text.strip():
                ids = tokenizer(text, add_special_tokens=False)["input_ids"]
                token_stream.extend(ids)
                if eos_id is not None:
                    token_stream.append(eos_id)
    else:
        raise ValueError("[error] source must start with 'hf:' or 'txt:'")

    if not token_stream:
        raise ValueError("[error] No tokens extracted from the source. Check your data.")

    ds = CausalChunked(token_stream, block_size)
    if len(ds) == 0:
        raise ValueError(
            f"[error] Tokenized corpus ({len(token_stream)} tokens) is too small "
            f"for block_size={block_size}. No complete blocks produced."
        )

    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=True,  # partial final batch dropped for uniform shapes
        pin_memory=torch.cuda.is_available(),
        num_workers=num_workers,
    )

    meta = {
        "vocab_size": len(tokenizer),
        "eos_id": eos_id,
        "n_blocks": len(ds),
        "n_tokens": len(token_stream),
        "tokens_per_block": block_size,
    }

    print(f"[info] Dataset ready: {meta['n_blocks']} blocks, {meta['n_tokens']} tokens total")
    return tokenizer, loader, meta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Hyperparameters and bookkeeping options for MiniTRLTrainer."""

    # Run layout / duration
    output_dir: str = "outputs/hrm_run"      # checkpoints written under here as ckpt-*/
    num_epochs: int = 1
    max_steps: Optional[int] = None          # if set, overrides num_epochs (counted in optimizer updates)
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 1

    # Optimizer (AdamW) / schedule
    learning_rate: float = 1e-4
    betas: tuple = (0.9, 0.95)
    eps: float = 1e-8
    weight_decay: float = 0.01               # applied only to the decay param group
    warmup_ratio: float = 0.06               # fraction of total updates used for linear warmup
    max_grad_norm: float = 0.5

    # Logging / checkpoint cadence (all in optimizer-update steps)
    log_every: int = 100
    save_every: int = 2000
    eval_every: int = 2000
    save_total_limit: int = 3                # oldest ckpt-* dirs pruned beyond this

    # Mixed precision: bf16 preferred when supported, else fp16 (CUDA only)
    fp16: bool = False
    bf16: bool = True

    seed: int = 1337
    resume_from: Optional[str] = None        # directory containing trainer_state.pt
    early_stopping_patience: Optional[int] = None  # in steps; compared against accumulated eval_every
    best_metric: str = "eval/loss"           # key looked up in evaluate()'s metrics dict
    greater_is_better: bool = False
    torch_compile: bool = False

    # Weights & Biases (optional; silently skipped if wandb missing)
    wandb_enable: bool = False
    wandb_entity: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_run_name: Optional[str] = None
|
|
|
|
|
|
|
|
def _out_get(out: Any, key: str, default=None): |
|
|
if isinstance(out, dict): |
|
|
return out.get(key, default) |
|
|
return getattr(out, key, default) |
|
|
|
|
|
|
|
|
class MiniTRLTrainer:
    """
    TRL-like supervised trainer.

    Model forward must accept (input_ids, labels) and return something with:
      - loss (required)
      - logits (optional but recommended; used for sanity checks)
      - lm_loss (optional; logged if present)
      - ponder_loss (optional; logged if present)

    DataLoader must yield dicts with keys:
      - "input_ids" and (optionally) "labels". If "labels" missing, labels=input_ids.
      - If you pad to fixed length externally, also pass "attention_mask" so we can mask pad tokens.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        tokenizer: Optional[AutoTokenizer] = None,
        eval_loader: Optional[DataLoader] = None,
        # NOTE(review): mutable dataclass instance as a default argument — the
        # same TrainConfig object is shared across trainers constructed without
        # an explicit config; prefer `config: Optional[TrainConfig] = None`.
        config: TrainConfig = TrainConfig(),
        compute_metrics: Optional[Callable[[Dict[str, float]], Dict[str, float]]] = None,
        custom_loss_fn: Optional[Callable[[Any], torch.Tensor]] = None,
        device: Optional[torch.device] = None,
    ):
        """Wire up model, data, optimizer, scheduler, AMP, resume, and W&B."""
        self.model = model
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.tok = tokenizer
        self.cfg = config
        self.compute_metrics = compute_metrics  # accepted but not referenced elsewhere in this class
        self.custom_loss_fn = custom_loss_fn
        self.device = device or auto_device()
        set_seed(self.cfg.seed)

        self.model.to(self.device)
        if self.cfg.torch_compile:
            try:
                self.model = torch.compile(self.model)
            except Exception as e:
                # compile is best-effort; fall back to eager on failure
                print("[warn] torch.compile failed:", e)

        # AMP dtype selection: bf16 preferred if supported, else fp16 if
        # requested; autocast is only used on CUDA here (MPS gets None).
        if self.device.type == "cuda":
            self.amp_dtype = torch.bfloat16 if (self.cfg.bf16 and torch.cuda.is_bf16_supported()) else (torch.float16 if self.cfg.fp16 else None)
        else:
            self.amp_dtype = None

        # Two AdamW param groups: 1-D params, norms, biases, and the token
        # embedding are excluded from weight decay.
        decay, no_decay = [], []
        for n, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            nl = n.lower()
            if p.ndim == 1 or "norm" in nl or "bias" in nl or ("tok_emb.weight" in n):
                no_decay.append(p)
            else:
                decay.append(p)

        self.optimizer = torch.optim.AdamW(
            [{"params": decay, "weight_decay": self.cfg.weight_decay},
             {"params": no_decay, "weight_decay": 0.0}],
            lr=self.cfg.learning_rate, betas=self.cfg.betas, eps=self.cfg.eps
        )

        # Schedule horizon: max_steps wins over epochs when set.
        steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps))
        total_updates = self.cfg.max_steps if self.cfg.max_steps is not None else self.cfg.num_epochs * max(1, steps_per_epoch)
        total_updates = max(1, total_updates)
        warmup_steps = int(self.cfg.warmup_ratio * total_updates)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, warmup_steps, total_updates)

        # GradScaler only matters for fp16-on-CUDA; disabled otherwise.
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.amp_dtype == torch.float16 and self.device.type == "cuda"))

        self.global_step = 0
        self.best_metric_val = float("-inf") if self.cfg.greater_is_better else float("inf")
        self.no_improve_steps = 0

        os.makedirs(self.cfg.output_dir, exist_ok=True)
        self._maybe_resume()

        self._wandb_run = None
        if self.cfg.wandb_enable:
            if wandb is None:
                print("[warn] wandb_enable=True but wandb is not installed; proceeding without W&B.")
            else:
                self._wandb_run = wandb.init(
                    entity=self.cfg.wandb_entity,
                    project=self.cfg.wandb_project or "hrm",
                    name=self.cfg.wandb_run_name,
                    config=asdict(self.cfg),
                )

    def train(self):
        """Run the training loop until `max_updates` optimizer steps, with
        periodic logging, evaluation, checkpointing, and early stopping."""
        self.model.train()
        # token-weighted running loss between log lines
        log_acc_loss = 0.0
        log_acc_tokens = 0
        t0 = time.time()

        max_updates = self.cfg.max_steps
        if max_updates is None:
            steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps))
            max_updates = self.cfg.num_epochs * max(1, steps_per_epoch)

        pbar = tqdm(total=max_updates, initial=self.global_step, desc="Training", dynamic_ncols=True)

        while self.global_step < max_updates:
            for batch in self.train_loader:
                if self.global_step >= max_updates:
                    break

                input_ids = batch["input_ids"].to(self.device)
                labels = batch.get("labels", input_ids).to(self.device)

                # pad id: tokenizer wins, else model.config.pad_token_id if any
                pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else (
                    getattr(getattr(self.model, "config", None), "pad_token_id", None)
                )
                attn = batch.get("attention_mask", None)
                attn = attn.to(self.device) if attn is not None else None
                labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn)

                # autocast only when an AMP dtype was selected in __init__
                ctx = (torch.autocast(device_type=self.device.type, dtype=self.amp_dtype)
                       if (self.amp_dtype is not None and self.device.type in ("cuda", "mps"))
                       else nullcontext())

                with ctx:
                    out = self.model(input_ids=input_ids, labels=labels)
                    loss = _out_get(out, "loss")
                    if self.custom_loss_fn is not None:
                        # custom loss fully replaces the model's loss
                        loss = self.custom_loss_fn(out)
                    # scale down so accumulated gradients average correctly
                    loss = loss / max(1, self.cfg.gradient_accumulation_steps)

                # Fail fast on numeric blow-ups rather than training on NaNs.
                logits = _out_get(out, "logits", None)
                if logits is not None:
                    if not torch.isfinite(logits).all():
                        mx = logits.detach().float().abs().max().item()
                        raise FloatingPointError(f"logits non-finite (max|logit|={mx})")
                if not torch.isfinite(loss):
                    lmax = (logits.detach().float().abs().max().item() if logits is not None else float("nan"))
                    print(f"[dbg] non-finite loss; max|logit|={lmax}, lm={_out_get(out,'lm_loss')}, ponder={_out_get(out,'ponder_loss')}")
                    raise FloatingPointError("Loss became non-finite.")

                if self.scaler.is_enabled():
                    self.scaler.scale(loss).backward()
                else:
                    loss.backward()

                # NOTE(review): with gradient_accumulation_steps > 1 this never
                # fires — global_step only advances inside `if do_step`, so
                # (global_step + 1) % ga stays constant and the optimizer never
                # steps. Correct for ga == 1 only; should be driven by a
                # per-batch micro-step counter instead.
                do_step = ((self.global_step + 1) % self.cfg.gradient_accumulation_steps == 0)
                if do_step:
                    if self.scaler.is_enabled():
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm)

                    if self.scaler.is_enabled():
                        prev_scale = self.scaler.get_scale()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        # scale dropped => step was skipped (inf/nan grads);
                        # only advance the schedule when the step was taken
                        if self.scaler.get_scale() >= prev_scale:
                            self.scheduler.step()
                    else:
                        self.optimizer.step()
                        self.scheduler.step()

                    self.optimizer.zero_grad(set_to_none=True)
                    self.global_step += 1
                    pbar.update(1)

                # Token-weighted loss accumulation (lm_loss preferred when the
                # model reports it; falls back to the scaled total loss).
                tokens = int((labels != -100).sum().item())
                lm_for_log = _out_get(out, "lm_loss", loss.detach())
                log_acc_loss += float(lm_for_log) * max(1, tokens)
                log_acc_tokens += max(1, tokens)

                # NOTE(review): at global_step == 0, `0 % log_every == 0` is
                # true, so this logs every batch until the first update.
                if self.global_step % max(1, self.cfg.log_every) == 0:
                    avg_loss = log_acc_loss / max(1, log_acc_tokens)
                    msg = {
                        "step": self.global_step,
                        "lr": self.scheduler.get_last_lr()[0],
                        "avg_lm_loss": avg_loss,
                        "ppl~": math.exp(min(20.0, avg_loss)),  # clamp to avoid overflow
                        "ponder": (_out_get(out, "ponder_loss", None)),
                        "elapsed_s": int(time.time() - t0),
                    }
                    tqdm.write("[log] " + ", ".join(f"{k}={format_num(v)}" for k, v in msg.items() if v is not None))
                    if self._wandb_run is not None:
                        self._wandb_run.log({k: v for k, v in msg.items() if isinstance(v, (int, float))})
                    log_acc_loss = 0.0
                    log_acc_tokens = 0

                # Periodic evaluation + early stopping bookkeeping.
                if self.eval_loader and self.global_step % max(1, self.cfg.eval_every) == 0:
                    metrics = self.evaluate()
                    improved = self._check_improve(metrics[self.cfg.best_metric])
                    if self._wandb_run is not None:
                        self._wandb_run.log(metrics)
                    if self.cfg.early_stopping_patience is not None:
                        if improved:
                            self.no_improve_steps = 0
                        else:
                            self.no_improve_steps += self.cfg.eval_every
                            if self.no_improve_steps >= self.cfg.early_stopping_patience:
                                tqdm.write("[early-stop] patience exhausted.")
                                self._save_checkpoint(tag="final")
                                pbar.close()
                                return

                if self.global_step % max(1, self.cfg.save_every) == 0:
                    self._save_checkpoint()

        pbar.close()
        self._save_checkpoint(tag="final")

    @torch.no_grad()
    def evaluate(self) -> Dict[str, float]:
        """Token-weighted eval loss, approximate perplexity, and mean ponder
        loss over `eval_loader`. Leaves the model back in train() mode."""
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        total_ponder = 0.0
        n_batches = 0

        for batch in tqdm(self.eval_loader, desc="Eval", leave=False):
            input_ids = batch["input_ids"].to(self.device)
            labels = batch.get("labels", input_ids).to(self.device)

            # same pad-resolution logic as in train()
            pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else (
                getattr(getattr(self.model, "config", None), "pad_token_id", None)
            )
            attn = batch.get("attention_mask", None)
            attn = attn.to(self.device) if attn is not None else None
            labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn)

            out = self.model(input_ids=input_ids, labels=labels)
            lm = float(_out_get(out, "lm_loss", _out_get(out, "loss")))
            tokens = int((labels != -100).sum().item())

            total_loss += lm * max(1, tokens)
            total_tokens += max(1, tokens)
            pl = _out_get(out, "ponder_loss", None)
            if pl is not None:
                total_ponder += float(pl)
            n_batches += 1

        avg_loss = total_loss / max(1, total_tokens)
        ppl = math.exp(min(20.0, avg_loss))  # clamp exponent to avoid overflow
        avg_ponder = (total_ponder / max(1, n_batches)) if n_batches > 0 else float("nan")
        metrics = {"eval/loss": avg_loss, "eval/ppl": ppl, "eval/ponder": avg_ponder, "step": self.global_step}
        tqdm.write("[eval] " + ", ".join(f"{k}={format_num(v)}" for k, v in metrics.items()))
        self.model.train()
        return metrics

    def _save_checkpoint(self, tag: Optional[str] = None):
        """Write trainer state (.pt), model weights (.safetensors), and config
        JSON into output_dir/ckpt-<tag>, then prune old checkpoints."""
        tag = tag or f"step{self.global_step}"
        ckpt_dir = os.path.join(self.cfg.output_dir, f"ckpt-{tag}")
        os.makedirs(ckpt_dir, exist_ok=True)

        # Full resumable state for _maybe_resume().
        torch.save({
            "model": self.model.state_dict(),
            "opt": self.optimizer.state_dict(),
            "sched": self.scheduler.state_dict(),
            "scaler": (self.scaler.state_dict() if self.scaler.is_enabled() else None),
            "global_step": self.global_step,
            "config": asdict(self.cfg),
        }, os.path.join(ckpt_dir, "trainer_state.pt"))

        # Weights-only copy for inference/export consumers.
        save_safetensors_safe(self.model, os.path.join(ckpt_dir, "model.safetensors"),
                              metadata={"note": "MiniTRLTrainer save", "global_step": str(self.global_step)})
        with open(os.path.join(ckpt_dir, "config.json"), "w") as f:
            json.dump({"global_step": self.global_step, **asdict(self.cfg)}, f, indent=2)

        self._prune_checkpoints()

    def _prune_checkpoints(self):
        """Delete the oldest ckpt-* dirs (by mtime) beyond save_total_limit."""
        if self.cfg.save_total_limit is None:
            return
        subs = [d for d in os.listdir(self.cfg.output_dir) if d.startswith("ckpt-")]
        if len(subs) <= self.cfg.save_total_limit:
            return
        subs = sorted(subs, key=lambda s: os.path.getmtime(os.path.join(self.cfg.output_dir, s)))
        for d in subs[:-self.cfg.save_total_limit]:
            shutil.rmtree(os.path.join(self.cfg.output_dir, d), ignore_errors=True)

    def _maybe_resume(self):
        """Restore model/optimizer/scheduler/scaler/step from cfg.resume_from,
        if it points at a directory containing trainer_state.pt."""
        if not self.cfg.resume_from:
            return
        state_path = os.path.join(self.cfg.resume_from, "trainer_state.pt")
        if not os.path.exists(state_path):
            print(f"[resume] not found: {state_path}")
            return
        ckpt = torch.load(state_path, map_location="cpu")
        # strict=False tolerates architecture drift (e.g. added heads)
        self.model.load_state_dict(ckpt["model"], strict=False)
        self.optimizer.load_state_dict(ckpt["opt"])
        self.scheduler.load_state_dict(ckpt["sched"])
        if ckpt.get("scaler") and self.scaler.is_enabled():
            self.scaler.load_state_dict(ckpt["scaler"])
        self.global_step = int(ckpt.get("global_step", 0))
        print(f"[resume] loaded from {self.cfg.resume_from} @ step {self.global_step}")

    def _check_improve(self, val: float) -> bool:
        """Update best_metric_val per cfg.greater_is_better; return whether
        `val` is a new best."""
        improved = (val > self.best_metric_val) if self.cfg.greater_is_better else (val < self.best_metric_val)
        if improved:
            self.best_metric_val = val
        return improved
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _state_dict_for_safetensors(model): |
|
|
""" |
|
|
Build a CPU state_dict suitable for safetensors. |
|
|
If lm_head.weight is tied to tok_emb.weight, omit lm_head.weight to avoid duplicate storage. |
|
|
""" |
|
|
tied = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and ( |
|
|
getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None) |
|
|
) |
|
|
sd_cpu = {k: v.detach().cpu() for k, v in model.state_dict().items()} |
|
|
if tied and "lm_head.weight" in sd_cpu: |
|
|
sd_cpu.pop("lm_head.weight") |
|
|
return sd_cpu, tied |
|
|
|
|
|
|
|
|
def retie_output_embedding(model):
    """
    Re-tie output and input embeddings after loading weights, when the model
    exposes get_input_embeddings()/get_output_embeddings(). No-op otherwise.
    """
    if not (hasattr(model, "get_input_embeddings") and hasattr(model, "get_output_embeddings")):
        return
    inp = model.get_input_embeddings()
    out = model.get_output_embeddings()
    if inp is None or out is None:
        return
    # Only rebind when the tensors are actually distinct storages.
    if out.weight.data_ptr() != inp.weight.data_ptr():
        out.weight = inp.weight
|
|
|
|
|
|
|
|
def _chain_get(obj, attrs, default=None): |
|
|
""" |
|
|
Safe chained getattr: _chain_get(model, ["L_mod", "attn", "num_heads"], default=None) |
|
|
""" |
|
|
cur = obj |
|
|
for a in attrs: |
|
|
if not hasattr(cur, a): |
|
|
return default |
|
|
cur = getattr(cur, a) |
|
|
return cur |
|
|
|
|
|
|
|
|
def save_model_complete(model, save_dir, tokenizer=None, training_args=None, metadata=None):
    """
    Save model with all details: weights (.pt + .safetensors), config, architecture,
    parameter summaries, tokenizer (optional), and a README.

    Args:
        model: the nn.Module to serialize; HRM-specific attributes
            (vocab_size, d_model, L_mod, ...) are probed best-effort and may be None.
        save_dir: target directory (created if missing).
        tokenizer: optional HF tokenizer, saved via save_pretrained.
        training_args: optional dict; known keys (optimizer_state,
            scheduler_state, epoch, global_step) go into model.pt, and a
            JSON-safe copy of the whole dict goes into training_args.json.
        metadata: optional dict merged into metadata.json.

    Returns: save_dir
    """
    os.makedirs(save_dir, exist_ok=True)
    from datetime import datetime
    from collections import OrderedDict

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"Saving model to: {save_dir}")

    # 1) Full PyTorch checkpoint: weights plus any training state provided.
    print("1. Saving model weights (.pt)...")
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "timestamp": timestamp,
    }
    if training_args and "optimizer_state" in training_args:
        checkpoint["optimizer_state_dict"] = training_args["optimizer_state"]
    if training_args and "scheduler_state" in training_args:
        checkpoint["scheduler_state_dict"] = training_args["scheduler_state"]
    if training_args and "epoch" in training_args:
        checkpoint["epoch"] = training_args["epoch"]
    if training_args and "global_step" in training_args:
        checkpoint["global_step"] = training_args["global_step"]
    torch.save(checkpoint, os.path.join(save_dir, "model.pt"))
    print(" ✓ Saved: model.pt")

    # 1b) Safetensors copy (tied lm_head dropped by the helper); best-effort.
    print("1b. Saving model weights (.safetensors)...")
    try:
        from safetensors.torch import save_file
        sd_cpu, tied = _state_dict_for_safetensors(model)
        save_file(sd_cpu, os.path.join(save_dir, "model.safetensors"))
        if tied:
            print(" ℹ Weight tying detected: excluded lm_head.weight (re-tie on load).")
        print(" ✓ Saved: model.safetensors")
    except ImportError:
        print(" ⚠ safetensors not installed, skipping .safetensors format")
    except Exception as e:
        print(f" ⚠ Could not save safetensors: {e}")

    # 2) Best-effort config from model attributes (all may be None for
    # architectures that don't expose these HRM-style fields).
    print("2. Saving model config...")
    vocab_size = getattr(model, "vocab_size", None)
    d_model = getattr(model, "d_model", None)
    n_heads = _chain_get(model, ["L_mod", "attn", "num_heads"], default=None)
    d_ff = _chain_get(model, ["L_mod", "mlp", "w1", "out_features"], default=None)
    dropout = _chain_get(model, ["L_mod", "drop", "p"], default=None)
    k_l_steps = getattr(model, "k_l_steps", None)
    max_cycles = getattr(model, "max_cycles", None)
    ponder_w = getattr(model, "ponder_w", None)
    has_out_norm = hasattr(model, "out_norm")
    tied_flag = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and (
        getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None)
    )
    config = {
        "model_type": type(model).__name__,
        "vocab_size": vocab_size,
        "d_model": d_model,
        "n_heads": n_heads,
        "d_ff": d_ff,
        "dropout": dropout,
        "k_l_steps": k_l_steps,
        "max_cycles": max_cycles,
        "ponder_loss_weight": ponder_w,
        "has_out_norm": has_out_norm,
        "weight_tying": tied_flag,
        "tie_word_embeddings": tied_flag,
    }
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
    print(" ✓ Saved: config.json")

    # 3) Human-readable module tree.
    print("3. Saving model architecture...")
    with open(os.path.join(save_dir, "architecture.txt"), "w") as f:
        f.write(str(model))
    print(" ✓ Saved: architecture.txt")

    # 4) Per-parameter metadata plus totals.
    # NOTE(review): size_mb assumes 4 bytes per parameter (fp32); it
    # over/under-states for fp16/bf16/fp64 models — confirm if that matters.
    print("4. Saving parameter details...")
    param_info = []
    total_params = 0
    trainable_params = 0
    for name, p in model.named_parameters():
        n = p.numel()
        total_params += n
        if p.requires_grad:
            trainable_params += n
        param_info.append({
            "name": name,
            "shape": list(p.shape),
            "dtype": str(p.dtype),
            "requires_grad": p.requires_grad,
            "num_params": n,
            "device": str(p.device),
        })
    param_summary = {
        "total_parameters": total_params,
        "trainable_parameters": trainable_params,
        "non_trainable_parameters": total_params - trainable_params,
        "size_mb": total_params * 4 / (1024 ** 2),
        "parameters": param_info,
    }
    with open(os.path.join(save_dir, "parameters.json"), "w") as f:
        json.dump(param_summary, f, indent=2)
    print(" ✓ Saved: parameters.json")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable: {trainable_params:,}")
    print(f" Model size: {total_params * 4 / (1024**2):.2f} MB")

    # 5) Parameter counts per top-level child module.
    print("5. Saving layer-wise breakdown...")
    from collections import OrderedDict
    layer_params = OrderedDict()
    for name, module in model.named_children():
        num_params = sum(p.numel() for p in module.parameters())
        layer_params[name] = {
            "num_params": num_params,
            "percentage": 100 * num_params / total_params if total_params > 0 else 0,
        }
    with open(os.path.join(save_dir, "layer_params.json"), "w") as f:
        json.dump(layer_params, f, indent=2)
    print(" ✓ Saved: layer_params.json")

    # 6) JSON-safe dump of the training args (non-serializable values -> str).
    if training_args:
        print("6. Saving training arguments...")
        serializable_args = {}
        for k, v in training_args.items():
            if isinstance(v, (int, float, str, bool, list, dict, type(None))):
                serializable_args[k] = v
            else:
                serializable_args[k] = str(v)
        with open(os.path.join(save_dir, "training_args.json"), "w") as f:
            json.dump(serializable_args, f, indent=2)
        print(" ✓ Saved: training_args.json")

    # 7) Environment/device snapshot, merged with any caller-provided metadata.
    print("7. Saving metadata...")
    metadata_full = {
        "timestamp": timestamp,
        "pytorch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
        "device": str(next(model.parameters()).device),
        "dtype": str(next(model.parameters()).dtype),
    }
    if metadata:
        metadata_full.update(metadata)
    with open(os.path.join(save_dir, "metadata.json"), "w") as f:
        json.dump(metadata_full, f, indent=2)
    print(" ✓ Saved: metadata.json")

    # 8) Tokenizer files (best-effort).
    if tokenizer is not None:
        print("8. Saving tokenizer...")
        try:
            tokenizer.save_pretrained(save_dir)
            print(" ✓ Saved tokenizer files")
        except Exception as e:
            print(f" ⚠ Could not save tokenizer: {e}")

    # 9) Markdown summary of everything written above.
    print("9. Creating README...")
    readme_content = f"""# HRM/LM Model Checkpoint

## Model Information
- **Model Type**: {config['model_type']}
- **Timestamp**: {timestamp}
- **Total Parameters**: {total_params:,}
- **Trainable Parameters**: {trainable_params:,}
- **Model Size**: {total_params * 4 / (1024**2):.2f} MB

## Architecture (best-effort introspection)
- **Vocabulary Size**: {config['vocab_size']}
- **Hidden Dimension**: {config['d_model']}
- **Attention Heads**: {config['n_heads']}
- **FFN Dimension**: {config['d_ff']}
- **Dropout**: {config['dropout']}
- **L-mod Steps**: {config['k_l_steps']}
- **Max Cycles**: {config['max_cycles']}
- **Has Output Norm**: {config['has_out_norm']}
- **Weight Tying**: {config['weight_tying']} (tok_emb ↔ lm_head)

## Files
- `model.pt` — Full checkpoint (PyTorch)
- `model.safetensors` — Safetensors (excludes lm_head if tied)
- `config.json` — Model configuration summary
- `architecture.txt` — Stringified architecture
- `parameters.json` — Parameter metadata
- `layer_params.json` — Layer-wise parameter counts
- `training_args.json` — Training hyperparameters (if provided)
- `metadata.json` — Environment/device metadata
- Tokenizer files (if provided)
"""
    with open(os.path.join(save_dir, 'README.md'), 'w') as f:
        f.write(readme_content)
    print(f" ✓ Saved: README.md")

    print("\n" + "="*60)
    print("SAVE COMPLETE!")
    print("="*60)
    print(f"Location: {save_dir}")
    print(f"Files saved: {len(os.listdir(save_dir))}")
    print("\nSummary:")
    print(f" Total parameters: {total_params:,}")
    print(f" Model size: {total_params * 4 / (1024**2):.2f} MB")
    print(f" Config saved: ✓")
    print(f" Weights saved: ✓")
    print(f" Tokenizer saved: {'✓' if tokenizer else '✗'}")
    print("="*60)
    return save_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_via_callable(load_fn: str, **kwargs): |
|
|
""" |
|
|
load_fn: 'module.submodule:function_name' (e.g., 'hrm_utils:load_hrm') |
|
|
kwargs: forwarded to the function |
|
|
""" |
|
|
if ":" not in load_fn: |
|
|
raise ValueError("load_fn must look like 'module.submodule:function_name'") |
|
|
mod_name, fn_name = load_fn.split(":", 1) |
|
|
import importlib |
|
|
mod = importlib.import_module(mod_name) |
|
|
fn = getattr(mod, fn_name) |
|
|
return fn(**kwargs) |
|
|
|
|
|
|
|
|
def main():
    """
    CLI entry point with three subcommands:

      prepare -- tokenize a dataset and iterate a few dataloader batches as a
                 sanity check
      train   -- dynamically load a model via --load-fn, build a dataloader,
                 and run MiniTRLTrainer
      save    -- dynamically load a model and export a fully documented
                 checkpoint via save_model_complete

    Model loading is delegated to _load_via_callable ('module:function'), so a
    project can plug in its own loader without editing this file.
    """
    import argparse

    def _resolve_model(load_fn: str, load_args: str):
        """Invoke the dynamic loader and normalize its result to (model, tokenizer).

        Loaders may return either a bare model or a (model, tokenizer, ...)
        tuple; tokenizer is None when only a model was returned.
        (This logic was previously duplicated in the train and save branches.)
        """
        load_kwargs = json.loads(load_args or "{}")
        obj = _load_via_callable(load_fn, **load_kwargs)
        if isinstance(obj, tuple) and len(obj) >= 2:
            return obj[0], obj[1]
        return obj, None

    p = argparse.ArgumentParser(description="All-in-one HRM/LM data + trainer + checkpointing")
    sub = p.add_subparsers(dest="cmd", required=True)

    # --- prepare: dataset/tokenizer smoke test ---
    sp = sub.add_parser("prepare", help="Tokenize and build a quick dataloader")
    sp.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary")
    sp.add_argument("--split", default="train")
    sp.add_argument("--text-field", default="text")
    sp.add_argument("--block-size", type=int, default=128)
    sp.add_argument("--batch-size", type=int, default=8)
    sp.add_argument("--tokenizer-dir", default=None)

    # --- train: dynamic model load + training loop ---
    st = sub.add_parser("train", help="Train a model via dynamic loader")
    st.add_argument("--load-fn", required=True, help="module:function (e.g. hrm_utils:load_hrm)")
    st.add_argument("--load-args", default="{}", help="JSON dict of kwargs to pass to load-fn (e.g. '{\"name\":\"hrm_v0.04\",\"device\":\"cuda\",\"with_tokenizer\":true}')")
    st.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary")
    st.add_argument("--split", default="train")
    st.add_argument("--text-field", default="text")
    st.add_argument("--block-size", type=int, default=128)
    st.add_argument("--batch-size", type=int, default=8)
    st.add_argument("--epochs", type=int, default=1)
    st.add_argument("--lr", type=float, default=1e-4)
    st.add_argument("--output-dir", default="outputs/hrm_run")
    st.add_argument("--wandb", action="store_true")
    st.add_argument("--wandb-entity", default=None)
    st.add_argument("--wandb-project", default=None)
    st.add_argument("--wandb-run-name", default=None)

    # --- save: export a documented checkpoint ---
    ss = sub.add_parser("save", help="Save a fully-documented checkpoint for an already-loaded model")
    ss.add_argument("--load-fn", required=True)
    ss.add_argument("--load-args", default="{}")
    ss.add_argument("--save-dir", default="saved_models/hrm_export")
    ss.add_argument("--with-tokenizer", action="store_true")

    args = p.parse_args()

    if args.cmd == "prepare":
        tok, loader, meta = load_dataset_fn(
            source=args.source,
            split=args.split,
            text_field=args.text_field,
            block_size=args.block_size,
            batch_size=args.batch_size,
            checkpoint_or_dir=args.tokenizer_dir,
        )
        print("[ok] Prepared one pass through dataloader:")
        # Peek at a handful of batches to confirm shapes look sane.
        for i, b in enumerate(loader):
            print(" batch", i, {k: v.shape for k, v in b.items()})
            if i > 2: break

    elif args.cmd == "train":
        model, tokenizer = _resolve_model(args.load_fn, args.load_args)

        tok, train_loader, _ = load_dataset_fn(
            source=args.source,
            split=args.split,
            text_field=args.text_field,
            block_size=args.block_size,
            batch_size=args.batch_size,
            # Reuse the loader-provided tokenizer checkpoint when one exists.
            checkpoint_or_dir=(tokenizer.name_or_path if tokenizer is not None else None),
        )
        # Fall back to the dataset pipeline's tokenizer if the loader gave none.
        tokenizer = tokenizer or tok

        cfg = TrainConfig(
            output_dir=args.output_dir,
            num_epochs=args.epochs,
            learning_rate=args.lr,
            wandb_enable=bool(args.wandb),
            wandb_entity=args.wandb_entity,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
        )
        trainer = MiniTRLTrainer(
            model=model,
            train_loader=train_loader,
            tokenizer=tokenizer,
            eval_loader=None,
            config=cfg,
        )
        trainer.train()

    elif args.cmd == "save":
        model, tokenizer = _resolve_model(args.load_fn, args.load_args)
        save_model_complete(model, args.save_dir, tokenizer=(tokenizer if args.with_tokenizer else None))
|
|
|
|
|
|
|
|
# Script entry point: only dispatch to the CLI when run directly,
# so the module stays importable without side effects.
if __name__ == "__main__":
    main()
|
|
|