# -*- coding: utf-8 -*- from __future__ import annotations import os, re, math, json, time, shutil, random from dataclasses import dataclass, asdict from typing import Optional, Callable, Dict, Any, Iterable, Tuple, List from contextlib import nullcontext # Torch & friends import torch import torch.nn.functional as F from torch import nn from torch.utils.data import Dataset, DataLoader from tqdm import tqdm # Transformers / Datasets from transformers import AutoTokenizer, get_linear_schedule_with_warmup from datasets import load_dataset, DatasetDict # Optional: Weights & Biases try: import wandb # noqa except Exception: wandb = None # ========================================================= # Utils # ========================================================= def set_seed(seed: int = 1337): import numpy as np random.seed(seed); np.random.seed(seed); torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def auto_device(): if torch.cuda.is_available(): return torch.device("cuda") if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") def format_num(x): try: return f"{x:.6g}" except: return str(x) def save_safetensors_safe(model: nn.Module, path: str, metadata: Optional[Dict[str, str]] = None): """ Save weights as .safetensors, handling tied weights (lm_head <- tok_emb) when needed. """ try: from safetensors.torch import save_model # preserves shared storage & avoids duplication save_model(model, path, metadata=metadata or {}) except Exception: # Fallback that copies state_dict and de-duplicates lm_head if needed try: from safetensors.torch import save_file state = model.state_dict() if "lm_head.weight" in state and "tok_emb.weight" in state: state["lm_head.weight"] = state["tok_emb.weight"].clone() save_file(state, path, metadata=metadata or {}) except Exception as e: print("[warn] safetensors not saved:", e) # ========================================================= # Tokenizer helper # ========================================================= def _gpt2_tokenizer_with_specials( additional: Optional[List[str]] = None, checkpoint_or_dir: Optional[str] = None, ) -> AutoTokenizer: """ If `checkpoint_or_dir` is provided, load tokenizer from there; else use 'gpt2'. Ensures PAD exists (PAD→EOS), optionally adds extra specials, sets a huge model_max_length. """ tok = None if checkpoint_or_dir is not None: try: tok = AutoTokenizer.from_pretrained(checkpoint_or_dir, use_fast=True) except Exception as e: print(f"[warn] Failed to load tokenizer from '{checkpoint_or_dir}': {e}") print("[warn] Falling back to 'gpt2' tokenizer.") if tok is None: tok = AutoTokenizer.from_pretrained("gpt2", use_fast=True) if tok.eos_token is None: tok.add_special_tokens({"eos_token": ""}) if tok.pad_token is None: tok.pad_token = tok.eos_token if additional: new_tokens = [t for t in additional if t not in tok.get_vocab()] if new_tokens: tok.add_special_tokens({"additional_special_tokens": new_tokens}) print(f"[info] Added {len(new_tokens)} special tokens to tokenizer") tok.model_max_length = 10_000_000 tok.init_kwargs["model_max_length"] = tok.model_max_length return tok # ========================================================= # Fixed-block causal dataset # ========================================================= class CausalChunked(Dataset): """Flatten tokens then slice into non-overlapping blocks; x == labels.""" def __init__(self, token_ids: Iterable[int], block_size: int): ids = list(token_ids) n_full = (len(ids) // block_size) * block_size n_discarded = len(ids) - n_full if n_discarded > 0 and len(ids) > 0: pct = n_discarded / len(ids) * 100 print(f"[info] Discarded {n_discarded} tokens ({pct:.2f}%) that didn't fit into complete blocks") ids = ids[:n_full] self.blocks = [ids[i:i + block_size] for i in range(0, n_full, block_size)] def __len__(self) -> int: return len(self.blocks) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: x = torch.tensor(self.blocks[idx], dtype=torch.long) return {"input_ids": x, "labels": x.clone()} # ========================================================= # PAD-mask helper (for variable-length batches with padding) # ========================================================= def mask_pad_labels( input_ids: torch.Tensor, labels: torch.Tensor, pad_id: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Clone `labels` and set pad positions to -100 (ignored by CrossEntropyLoss). Prefers `attention_mask` if provided, otherwise uses pad_id to detect padding. Note: CausalChunked produces fixed-length blocks without padding, so this is only needed if you supply your own dataloader with padding. """ lab = labels.clone() if attention_mask is not None: lab[attention_mask == 0] = -100 elif pad_id is not None: lab[input_ids == pad_id] = -100 return lab # ========================================================= # Dataset loader (HF datasets or local txt files) # ========================================================= def load_dataset_fn( source: str = "hf:lemonilia/wikified_english_dictionary", split: str = "train", *, text_field: Optional[str] = "text", word_field: str = "word", article_field: str = "article", block_size: int = 128, batch_size: int = 8, num_workers: int = 0, shuffle: bool = True, checkpoint_or_dir: Optional[str] = None, additional_specials: Optional[List[str]] = None, ) -> Tuple[AutoTokenizer, DataLoader, Dict[str, int]]: """ Load and tokenize a dataset for causal LM training. Returns (tokenizer, DataLoader, meta). source: - 'hf:' to read a HuggingFace dataset - 'txt:/path1;/path2;...' to read local text files (semicolon-separated) Behavior: • If `text_field` is present, uses it. • Else if both `word_field` and `article_field` exist, merges them as: "\\n
\\n\\n" while stripping any <|begin_of_thought|>...<|end_of_thought|> spans. • Else, falls back to a 'text' column if available. • Appends EOS between docs/files to avoid cross-boundary contamination. """ tokenizer = _gpt2_tokenizer_with_specials( additional=additional_specials, checkpoint_or_dir=checkpoint_or_dir, ) eos_id = tokenizer.eos_token_id token_stream: List[int] = [] if source.startswith("hf:"): ds_name = source[3:] raw = load_dataset(ds_name) if split not in raw: raise ValueError(f"[error] Split '{split}' not found. Available: {list(raw.keys())}") cols = raw[split].column_names # A) explicit text field if text_field is not None and text_field in cols: field_to_use = text_field def tok_map(batch): return tokenizer(batch[field_to_use], add_special_tokens=False) toks = raw.map(tok_map, batched=True, remove_columns=cols) # B) merge word+article when requested/needed elif (text_field is None or text_field not in cols) and word_field in cols and article_field in cols: BEGIN_THOUGHT = re.compile(r"<\|begin_of_thought\|>.*?<\|end_of_thought\|>", re.DOTALL) def fmt(batch): out = [] for w, a in zip(batch[word_field], batch[article_field]): w = (w or "").strip() a = re.sub(BEGIN_THOUGHT, "", (a or "")).strip() out.append(w + "\n" + a + "\n\n") return {"text": out} raw = raw.map(fmt, batched=True) raw = DatasetDict({ sp: d.remove_columns([c for c in d.column_names if c != "text"]) for sp, d in raw.items() }) def tok_map(batch): return tokenizer(batch["text"], add_special_tokens=False) toks = raw.map(tok_map, batched=True, remove_columns=["text"]) # C) fallback 'text' elif "text" in cols: def tok_map(batch): return tokenizer(batch["text"], add_special_tokens=False) toks = raw.map(tok_map, batched=True, remove_columns=cols) else: raise ValueError( f"[error] Could not find a text source.\n" f" - Requested text_field={text_field!r}\n" f" - Available columns: {cols}\n" f" - Set text_field accordingly, or set text_field=None if your dataset has " f" both '{word_field}' and '{article_field}' to auto-merge." ) n_empty = 0 for doc in toks[split]["input_ids"]: if not doc: n_empty += 1 continue token_stream.extend(doc) if eos_id is not None: token_stream.append(eos_id) if n_empty > 0: print(f"[info] Skipped {n_empty} empty documents") elif source.startswith("txt:"): paths = [p for p in source[4:].split(";") if p] if not paths: raise ValueError("[error] No file paths provided after 'txt:'") for p in paths: if not os.path.exists(p): raise FileNotFoundError(f"[error] File not found: {p}") with open(p, "r", encoding="utf-8") as f: text = f.read() if text.strip(): ids = tokenizer(text, add_special_tokens=False)["input_ids"] token_stream.extend(ids) if eos_id is not None: token_stream.append(eos_id) else: raise ValueError("[error] source must start with 'hf:' or 'txt:'") if not token_stream: raise ValueError("[error] No tokens extracted from the source. Check your data.") ds = CausalChunked(token_stream, block_size) if len(ds) == 0: raise ValueError( f"[error] Tokenized corpus ({len(token_stream)} tokens) is too small " f"for block_size={block_size}. No complete blocks produced." ) loader = DataLoader( ds, batch_size=batch_size, shuffle=shuffle, drop_last=True, pin_memory=torch.cuda.is_available(), num_workers=num_workers, ) meta = { "vocab_size": len(tokenizer), "eos_id": eos_id, "n_blocks": len(ds), "n_tokens": len(token_stream), "tokens_per_block": block_size, } print(f"[info] Dataset ready: {meta['n_blocks']} blocks, {meta['n_tokens']} tokens total") return tokenizer, loader, meta # ========================================================= # Trainer # ========================================================= @dataclass class TrainConfig: output_dir: str = "outputs/hrm_run" num_epochs: int = 1 max_steps: Optional[int] = None # if set, overrides num_epochs per_device_train_batch_size: int = 8 gradient_accumulation_steps: int = 1 learning_rate: float = 1e-4 betas: tuple = (0.9, 0.95) eps: float = 1e-8 weight_decay: float = 0.01 warmup_ratio: float = 0.06 max_grad_norm: float = 0.5 log_every: int = 100 save_every: int = 2000 eval_every: int = 2000 save_total_limit: int = 3 fp16: bool = False bf16: bool = True # prefer bf16 if supported seed: int = 1337 resume_from: Optional[str] = None # path to checkpoint dir early_stopping_patience: Optional[int] = None # steps without eval improvement best_metric: str = "eval/loss" greater_is_better: bool = False torch_compile: bool = False # Optional W&B wandb_enable: bool = False wandb_entity: Optional[str] = None wandb_project: Optional[str] = None wandb_run_name: Optional[str] = None def _out_get(out: Any, key: str, default=None): if isinstance(out, dict): return out.get(key, default) return getattr(out, key, default) class MiniTRLTrainer: """ TRL-like supervised trainer: Model forward must accept (input_ids, labels) and return something with: - loss (required) - logits (optional but recommended; used for sanity checks) - lm_loss (optional; logged if present) - ponder_loss (optional; logged if present) DataLoader must yield dicts with keys: - "input_ids" and (optionally) "labels". If "labels" missing, labels=input_ids. - If you pad to fixed length externally, also pass "attention_mask" so we can mask pad tokens. """ def __init__( self, model: nn.Module, train_loader: DataLoader, tokenizer: Optional[AutoTokenizer] = None, eval_loader: Optional[DataLoader] = None, config: TrainConfig = TrainConfig(), compute_metrics: Optional[Callable[[Dict[str, float]], Dict[str, float]]] = None, custom_loss_fn: Optional[Callable[[Any], torch.Tensor]] = None, # receives model outputs device: Optional[torch.device] = None, ): self.model = model self.train_loader = train_loader self.eval_loader = eval_loader self.tok = tokenizer self.cfg = config self.compute_metrics = compute_metrics self.custom_loss_fn = custom_loss_fn self.device = device or auto_device() set_seed(self.cfg.seed) self.model.to(self.device) if self.cfg.torch_compile: try: self.model = torch.compile(self.model) except Exception as e: print("[warn] torch.compile failed:", e) # AMP dtype if self.device.type == "cuda": self.amp_dtype = torch.bfloat16 if (self.cfg.bf16 and torch.cuda.is_bf16_supported()) else (torch.float16 if self.cfg.fp16 else None) else: self.amp_dtype = None # Param groups with/without weight decay decay, no_decay = [], [] for n, p in self.model.named_parameters(): if not p.requires_grad: continue nl = n.lower() if p.ndim == 1 or "norm" in nl or "bias" in nl or ("tok_emb.weight" in n): no_decay.append(p) else: decay.append(p) self.optimizer = torch.optim.AdamW( [{"params": decay, "weight_decay": self.cfg.weight_decay}, {"params": no_decay, "weight_decay": 0.0}], lr=self.cfg.learning_rate, betas=self.cfg.betas, eps=self.cfg.eps ) # Scheduler steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps)) total_updates = self.cfg.max_steps if self.cfg.max_steps is not None else self.cfg.num_epochs * max(1, steps_per_epoch) total_updates = max(1, total_updates) # guard warmup_steps = int(self.cfg.warmup_ratio * total_updates) self.scheduler = get_linear_schedule_with_warmup(self.optimizer, warmup_steps, total_updates) # GradScaler only for fp16 self.scaler = torch.cuda.amp.GradScaler(enabled=(self.amp_dtype == torch.float16 and self.device.type == "cuda")) # State self.global_step = 0 self.best_metric_val = float("-inf") if self.cfg.greater_is_better else float("inf") self.no_improve_steps = 0 os.makedirs(self.cfg.output_dir, exist_ok=True) self._maybe_resume() # W&B self._wandb_run = None if self.cfg.wandb_enable: if wandb is None: print("[warn] wandb_enable=True but wandb is not installed; proceeding without W&B.") else: self._wandb_run = wandb.init( entity=self.cfg.wandb_entity, project=self.cfg.wandb_project or "hrm", name=self.cfg.wandb_run_name, config=asdict(self.cfg), ) # -------------------------- public API -------------------------- def train(self): self.model.train() log_acc_loss = 0.0 log_acc_tokens = 0 t0 = time.time() max_updates = self.cfg.max_steps if max_updates is None: steps_per_epoch = math.ceil(len(self.train_loader) / max(1, self.cfg.gradient_accumulation_steps)) max_updates = self.cfg.num_epochs * max(1, steps_per_epoch) pbar = tqdm(total=max_updates, initial=self.global_step, desc="Training", dynamic_ncols=True) while self.global_step < max_updates: for batch in self.train_loader: if self.global_step >= max_updates: break input_ids = batch["input_ids"].to(self.device) labels = batch.get("labels", input_ids).to(self.device) # Mask pads only if attention_mask/pad present pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else ( getattr(getattr(self.model, "config", None), "pad_token_id", None) ) attn = batch.get("attention_mask", None) attn = attn.to(self.device) if attn is not None else None labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn) ctx = (torch.autocast(device_type=self.device.type, dtype=self.amp_dtype) if (self.amp_dtype is not None and self.device.type in ("cuda", "mps")) else nullcontext()) with ctx: out = self.model(input_ids=input_ids, labels=labels) loss = _out_get(out, "loss") if self.custom_loss_fn is not None: loss = self.custom_loss_fn(out) loss = loss / max(1, self.cfg.gradient_accumulation_steps) logits = _out_get(out, "logits", None) if logits is not None: if not torch.isfinite(logits).all(): mx = logits.detach().float().abs().max().item() raise FloatingPointError(f"logits non-finite (max|logit|={mx})") if not torch.isfinite(loss): lmax = (logits.detach().float().abs().max().item() if logits is not None else float("nan")) print(f"[dbg] non-finite loss; max|logit|={lmax}, lm={_out_get(out,'lm_loss')}, ponder={_out_get(out,'ponder_loss')}") raise FloatingPointError("Loss became non-finite.") if self.scaler.is_enabled(): self.scaler.scale(loss).backward() else: loss.backward() do_step = ((self.global_step + 1) % self.cfg.gradient_accumulation_steps == 0) if do_step: if self.scaler.is_enabled(): self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.max_grad_norm) if self.scaler.is_enabled(): prev_scale = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() if self.scaler.get_scale() >= prev_scale: self.scheduler.step() else: self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad(set_to_none=True) self.global_step += 1 pbar.update(1) # Logging accumulators (token-weighted). Count only non-pad tokens. tokens = int((labels != -100).sum().item()) lm_for_log = _out_get(out, "lm_loss", loss.detach()) log_acc_loss += float(lm_for_log) * max(1, tokens) log_acc_tokens += max(1, tokens) if self.global_step % max(1, self.cfg.log_every) == 0: avg_loss = log_acc_loss / max(1, log_acc_tokens) msg = { "step": self.global_step, "lr": self.scheduler.get_last_lr()[0], "avg_lm_loss": avg_loss, "ppl~": math.exp(min(20.0, avg_loss)), "ponder": (_out_get(out, "ponder_loss", None)), "elapsed_s": int(time.time() - t0), } tqdm.write("[log] " + ", ".join(f"{k}={format_num(v)}" for k, v in msg.items() if v is not None)) if self._wandb_run is not None: self._wandb_run.log({k: v for k, v in msg.items() if isinstance(v, (int, float))}) log_acc_loss = 0.0 log_acc_tokens = 0 # Eval / early stop if self.eval_loader and self.global_step % max(1, self.cfg.eval_every) == 0: metrics = self.evaluate() improved = self._check_improve(metrics[self.cfg.best_metric]) if self._wandb_run is not None: self._wandb_run.log(metrics) if self.cfg.early_stopping_patience is not None: if improved: self.no_improve_steps = 0 else: self.no_improve_steps += self.cfg.eval_every if self.no_improve_steps >= self.cfg.early_stopping_patience: tqdm.write("[early-stop] patience exhausted.") self._save_checkpoint(tag="final") pbar.close() return if self.global_step % max(1, self.cfg.save_every) == 0: self._save_checkpoint() pbar.close() self._save_checkpoint(tag="final") @torch.no_grad() def evaluate(self) -> Dict[str, float]: self.model.eval() total_loss = 0.0 total_tokens = 0 total_ponder = 0.0 n_batches = 0 for batch in tqdm(self.eval_loader, desc="Eval", leave=False): input_ids = batch["input_ids"].to(self.device) labels = batch.get("labels", input_ids).to(self.device) pad_id = getattr(self.tok, "pad_token_id", None) if self.tok is not None else ( getattr(getattr(self.model, "config", None), "pad_token_id", None) ) attn = batch.get("attention_mask", None) attn = attn.to(self.device) if attn is not None else None labels = mask_pad_labels(input_ids, labels, pad_id=pad_id, attention_mask=attn) out = self.model(input_ids=input_ids, labels=labels) lm = float(_out_get(out, "lm_loss", _out_get(out, "loss"))) tokens = int((labels != -100).sum().item()) total_loss += lm * max(1, tokens) total_tokens += max(1, tokens) pl = _out_get(out, "ponder_loss", None) if pl is not None: total_ponder += float(pl) n_batches += 1 avg_loss = total_loss / max(1, total_tokens) ppl = math.exp(min(20.0, avg_loss)) avg_ponder = (total_ponder / max(1, n_batches)) if n_batches > 0 else float("nan") metrics = {"eval/loss": avg_loss, "eval/ppl": ppl, "eval/ponder": avg_ponder, "step": self.global_step} tqdm.write("[eval] " + ", ".join(f"{k}={format_num(v)}" for k, v in metrics.items())) self.model.train() return metrics # -------------------------- checkpoints -------------------------- def _save_checkpoint(self, tag: Optional[str] = None): tag = tag or f"step{self.global_step}" ckpt_dir = os.path.join(self.cfg.output_dir, f"ckpt-{tag}") os.makedirs(ckpt_dir, exist_ok=True) # trainer state (resumable) torch.save({ "model": self.model.state_dict(), "opt": self.optimizer.state_dict(), "sched": self.scheduler.state_dict(), "scaler": (self.scaler.state_dict() if self.scaler.is_enabled() else None), "global_step": self.global_step, "config": asdict(self.cfg), }, os.path.join(ckpt_dir, "trainer_state.pt")) # weights-only safetensors + minimal config save_safetensors_safe(self.model, os.path.join(ckpt_dir, "model.safetensors"), metadata={"note": "MiniTRLTrainer save", "global_step": str(self.global_step)}) with open(os.path.join(ckpt_dir, "config.json"), "w") as f: json.dump({"global_step": self.global_step, **asdict(self.cfg)}, f, indent=2) self._prune_checkpoints() def _prune_checkpoints(self): if self.cfg.save_total_limit is None: return subs = [d for d in os.listdir(self.cfg.output_dir) if d.startswith("ckpt-")] if len(subs) <= self.cfg.save_total_limit: return subs = sorted(subs, key=lambda s: os.path.getmtime(os.path.join(self.cfg.output_dir, s))) for d in subs[:-self.cfg.save_total_limit]: shutil.rmtree(os.path.join(self.cfg.output_dir, d), ignore_errors=True) def _maybe_resume(self): if not self.cfg.resume_from: return state_path = os.path.join(self.cfg.resume_from, "trainer_state.pt") if not os.path.exists(state_path): print(f"[resume] not found: {state_path}") return ckpt = torch.load(state_path, map_location="cpu") self.model.load_state_dict(ckpt["model"], strict=False) self.optimizer.load_state_dict(ckpt["opt"]) self.scheduler.load_state_dict(ckpt["sched"]) if ckpt.get("scaler") and self.scaler.is_enabled(): self.scaler.load_state_dict(ckpt["scaler"]) self.global_step = int(ckpt.get("global_step", 0)) print(f"[resume] loaded from {self.cfg.resume_from} @ step {self.global_step}") def _check_improve(self, val: float) -> bool: improved = (val > self.best_metric_val) if self.cfg.greater_is_better else (val < self.best_metric_val) if improved: self.best_metric_val = val return improved # ========================================================= # Checkpoint helpers (complete save) # ========================================================= def _state_dict_for_safetensors(model): """ Build a CPU state_dict suitable for safetensors. If lm_head.weight is tied to tok_emb.weight, omit lm_head.weight to avoid duplicate storage. """ tied = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and ( getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None) ) sd_cpu = {k: v.detach().cpu() for k, v in model.state_dict().items()} if tied and "lm_head.weight" in sd_cpu: sd_cpu.pop("lm_head.weight") return sd_cpu, tied def retie_output_embedding(model): """ Re-tie output and input embeddings after loading weights, if model provides get_*_embeddings(). """ if hasattr(model, "get_input_embeddings") and hasattr(model, "get_output_embeddings"): inp = model.get_input_embeddings() out = model.get_output_embeddings() if inp is not None and out is not None and out.weight.data_ptr() != inp.weight.data_ptr(): out.weight = inp.weight def _chain_get(obj, attrs, default=None): """ Safe chained getattr: _chain_get(model, ["L_mod", "attn", "num_heads"], default=None) """ cur = obj for a in attrs: if not hasattr(cur, a): return default cur = getattr(cur, a) return cur def save_model_complete(model, save_dir, tokenizer=None, training_args=None, metadata=None): """ Save model with all details: weights (.pt + .safetensors), config, architecture, parameter summaries, tokenizer (optional), and a README. Returns: save_dir """ os.makedirs(save_dir, exist_ok=True) from datetime import datetime from collections import OrderedDict timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") print(f"Saving model to: {save_dir}") # 1) Weights (.pt) print("1. Saving model weights (.pt)...") checkpoint = { "model_state_dict": model.state_dict(), "timestamp": timestamp, } if training_args and "optimizer_state" in training_args: checkpoint["optimizer_state_dict"] = training_args["optimizer_state"] if training_args and "scheduler_state" in training_args: checkpoint["scheduler_state_dict"] = training_args["scheduler_state"] if training_args and "epoch" in training_args: checkpoint["epoch"] = training_args["epoch"] if training_args and "global_step" in training_args: checkpoint["global_step"] = training_args["global_step"] torch.save(checkpoint, os.path.join(save_dir, "model.pt")) print(" ✓ Saved: model.pt") # 1b) Weights (.safetensors) print("1b. Saving model weights (.safetensors)...") try: from safetensors.torch import save_file sd_cpu, tied = _state_dict_for_safetensors(model) save_file(sd_cpu, os.path.join(save_dir, "model.safetensors")) if tied: print(" ℹ Weight tying detected: excluded lm_head.weight (re-tie on load).") print(" ✓ Saved: model.safetensors") except ImportError: print(" ⚠ safetensors not installed, skipping .safetensors format") except Exception as e: print(f" ⚠ Could not save safetensors: {e}") # 2) Save a minimal config (best-effort introspection) print("2. Saving model config...") vocab_size = getattr(model, "vocab_size", None) d_model = getattr(model, "d_model", None) n_heads = _chain_get(model, ["L_mod", "attn", "num_heads"], default=None) d_ff = _chain_get(model, ["L_mod", "mlp", "w1", "out_features"], default=None) dropout = _chain_get(model, ["L_mod", "drop", "p"], default=None) k_l_steps = getattr(model, "k_l_steps", None) max_cycles = getattr(model, "max_cycles", None) ponder_w = getattr(model, "ponder_w", None) has_out_norm = hasattr(model, "out_norm") tied_flag = hasattr(model, "lm_head") and hasattr(model, "tok_emb") and ( getattr(model.lm_head, "weight", None) is getattr(model.tok_emb, "weight", None) ) config = { "model_type": type(model).__name__, "vocab_size": vocab_size, "d_model": d_model, "n_heads": n_heads, "d_ff": d_ff, "dropout": dropout, "k_l_steps": k_l_steps, "max_cycles": max_cycles, "ponder_loss_weight": ponder_w, "has_out_norm": has_out_norm, "weight_tying": tied_flag, "tie_word_embeddings": tied_flag, } with open(os.path.join(save_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) print(" ✓ Saved: config.json") # 3) Architecture string print("3. Saving model architecture...") with open(os.path.join(save_dir, "architecture.txt"), "w") as f: f.write(str(model)) print(" ✓ Saved: architecture.txt") # 4) Parameter details print("4. Saving parameter details...") param_info = [] total_params = 0 trainable_params = 0 for name, p in model.named_parameters(): n = p.numel() total_params += n if p.requires_grad: trainable_params += n param_info.append({ "name": name, "shape": list(p.shape), "dtype": str(p.dtype), "requires_grad": p.requires_grad, "num_params": n, "device": str(p.device), }) param_summary = { "total_parameters": total_params, "trainable_parameters": trainable_params, "non_trainable_parameters": total_params - trainable_params, "size_mb": total_params * 4 / (1024 ** 2), # float32 estimate "parameters": param_info, } with open(os.path.join(save_dir, "parameters.json"), "w") as f: json.dump(param_summary, f, indent=2) print(" ✓ Saved: parameters.json") print(f" Total parameters: {total_params:,}") print(f" Trainable: {trainable_params:,}") print(f" Model size: {total_params * 4 / (1024**2):.2f} MB") # 5) Layer-wise breakdown (top-level children) print("5. Saving layer-wise breakdown...") from collections import OrderedDict layer_params = OrderedDict() for name, module in model.named_children(): num_params = sum(p.numel() for p in module.parameters()) layer_params[name] = { "num_params": num_params, "percentage": 100 * num_params / total_params if total_params > 0 else 0, } with open(os.path.join(save_dir, "layer_params.json"), "w") as f: json.dump(layer_params, f, indent=2) print(" ✓ Saved: layer_params.json") # 6) Training args (if provided) if training_args: print("6. Saving training arguments...") serializable_args = {} for k, v in training_args.items(): if isinstance(v, (int, float, str, bool, list, dict, type(None))): serializable_args[k] = v else: serializable_args[k] = str(v) with open(os.path.join(save_dir, "training_args.json"), "w") as f: json.dump(serializable_args, f, indent=2) print(" ✓ Saved: training_args.json") # 7) Metadata print("7. Saving metadata...") metadata_full = { "timestamp": timestamp, "pytorch_version": torch.__version__, "cuda_available": torch.cuda.is_available(), "cuda_version": torch.version.cuda if torch.cuda.is_available() else None, "device": str(next(model.parameters()).device), "dtype": str(next(model.parameters()).dtype), } if metadata: metadata_full.update(metadata) with open(os.path.join(save_dir, "metadata.json"), "w") as f: json.dump(metadata_full, f, indent=2) print(" ✓ Saved: metadata.json") # 8) Tokenizer (optional) if tokenizer is not None: print("8. Saving tokenizer...") try: tokenizer.save_pretrained(save_dir) print(" ✓ Saved tokenizer files") except Exception as e: print(f" ⚠ Could not save tokenizer: {e}") # 9) README print("9. Creating README...") readme_content = f"""# HRM/LM Model Checkpoint ## Model Information - **Model Type**: {config['model_type']} - **Timestamp**: {timestamp} - **Total Parameters**: {total_params:,} - **Trainable Parameters**: {trainable_params:,} - **Model Size**: {total_params * 4 / (1024**2):.2f} MB ## Architecture (best-effort introspection) - **Vocabulary Size**: {config['vocab_size']} - **Hidden Dimension**: {config['d_model']} - **Attention Heads**: {config['n_heads']} - **FFN Dimension**: {config['d_ff']} - **Dropout**: {config['dropout']} - **L-mod Steps**: {config['k_l_steps']} - **Max Cycles**: {config['max_cycles']} - **Has Output Norm**: {config['has_out_norm']} - **Weight Tying**: {config['weight_tying']} (tok_emb ↔ lm_head) ## Files - `model.pt` — Full checkpoint (PyTorch) - `model.safetensors` — Safetensors (excludes lm_head if tied) - `config.json` — Model configuration summary - `architecture.txt` — Stringified architecture - `parameters.json` — Parameter metadata - `layer_params.json` — Layer-wise parameter counts - `training_args.json` — Training hyperparameters (if provided) - `metadata.json` — Environment/device metadata - Tokenizer files (if provided) """ with open(os.path.join(save_dir, 'README.md'), 'w') as f: f.write(readme_content) print(f" ✓ Saved: README.md") print("\n" + "="*60) print("SAVE COMPLETE!") print("="*60) print(f"Location: {save_dir}") print(f"Files saved: {len(os.listdir(save_dir))}") print("\nSummary:") print(f" Total parameters: {total_params:,}") print(f" Model size: {total_params * 4 / (1024**2):.2f} MB") print(f" Config saved: ✓") print(f" Weights saved: ✓") print(f" Tokenizer saved: {'✓' if tokenizer else '✗'}") print("="*60) return save_dir # ========================================================= # Minimal CLI (dynamic model loading via --load-fn module:function) # ========================================================= def _load_via_callable(load_fn: str, **kwargs): """ load_fn: 'module.submodule:function_name' (e.g., 'hrm_utils:load_hrm') kwargs: forwarded to the function """ if ":" not in load_fn: raise ValueError("load_fn must look like 'module.submodule:function_name'") mod_name, fn_name = load_fn.split(":", 1) import importlib mod = importlib.import_module(mod_name) fn = getattr(mod, fn_name) return fn(**kwargs) def main(): import argparse p = argparse.ArgumentParser(description="All-in-one HRM/LM data + trainer + checkpointing") sub = p.add_subparsers(dest="cmd", required=True) # prepare data sp = sub.add_parser("prepare", help="Tokenize and build a quick dataloader") sp.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary") sp.add_argument("--split", default="train") sp.add_argument("--text-field", default="text") sp.add_argument("--block-size", type=int, default=128) sp.add_argument("--batch-size", type=int, default=8) sp.add_argument("--tokenizer-dir", default=None) # train st = sub.add_parser("train", help="Train a model via dynamic loader") st.add_argument("--load-fn", required=True, help="module:function (e.g. hrm_utils:load_hrm)") st.add_argument("--load-args", default="{}", help="JSON dict of kwargs to pass to load-fn (e.g. '{\"name\":\"hrm_v0.04\",\"device\":\"cuda\",\"with_tokenizer\":true}')") st.add_argument("--source", default="hf:lemonilia/wikified_english_dictionary") st.add_argument("--split", default="train") st.add_argument("--text-field", default="text") st.add_argument("--block-size", type=int, default=128) st.add_argument("--batch-size", type=int, default=8) st.add_argument("--epochs", type=int, default=1) st.add_argument("--lr", type=float, default=1e-4) st.add_argument("--output-dir", default="outputs/hrm_run") st.add_argument("--wandb", action="store_true") st.add_argument("--wandb-entity", default=None) st.add_argument("--wandb-project", default=None) st.add_argument("--wandb-run-name", default=None) # save checkpoint (complete) ss = sub.add_parser("save", help="Save a fully-documented checkpoint for an already-loaded model") ss.add_argument("--load-fn", required=True) ss.add_argument("--load-args", default="{}") ss.add_argument("--save-dir", default="saved_models/hrm_export") ss.add_argument("--with-tokenizer", action="store_true") args = p.parse_args() if args.cmd == "prepare": tok, loader, meta = load_dataset_fn( source=args.source, split=args.split, text_field=args.text_field, block_size=args.block_size, batch_size=args.batch_size, checkpoint_or_dir=args.tokenizer_dir, ) print("[ok] Prepared one pass through dataloader:") for i, b in enumerate(loader): print(" batch", i, {k: v.shape for k, v in b.items()}) if i > 2: break elif args.cmd == "train": load_kwargs = json.loads(args.load_args or "{}") obj = _load_via_callable(args.load_fn, **load_kwargs) if isinstance(obj, tuple) and len(obj) >= 2: model, tokenizer = obj[0], obj[1] else: # assume loader returns just model; tokenizer is optional/None model, tokenizer = obj, None tok, train_loader, _ = load_dataset_fn( source=args.source, split=args.split, text_field=args.text_field, block_size=args.block_size, batch_size=args.batch_size, checkpoint_or_dir=(tokenizer.name_or_path if tokenizer is not None else None), ) tokenizer = tokenizer or tok cfg = TrainConfig( output_dir=args.output_dir, num_epochs=args.epochs, learning_rate=args.lr, wandb_enable=bool(args.wandb), wandb_entity=args.wandb_entity, wandb_project=args.wandb_project, wandb_run_name=args.wandb_run_name, ) trainer = MiniTRLTrainer( model=model, train_loader=train_loader, tokenizer=tokenizer, eval_loader=None, # plug one in if you want config=cfg, ) trainer.train() elif args.cmd == "save": load_kwargs = json.loads(args.load_args or "{}") obj = _load_via_callable(args.load_fn, **load_kwargs) if isinstance(obj, tuple) and len(obj) >= 2: model, tokenizer = obj[0], obj[1] else: model, tokenizer = obj, None save_model_complete(model, args.save_dir, tokenizer=(tokenizer if args.with_tokenizer else None)) if __name__ == "__main__": main()