| """BLT-Reasoner training loop. |
| |
| Usage: |
| python -m experiments.blt_reasoner.train --config configs/pilot_qwen15b_gsm8k.json |
| |
| Pre-registered success criterion (do not evaluate raw GSM8K accuracy first): |
| Δ(normal-z − random-z) ≥ 15pp AND Δ(normal-z − zero-z) ≥ 25pp |
| on the held-out GSM8K mini-eval. If both hold, H1 (z carries information) |
| is supported and we proceed to scale. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import sys |
| import time |
| from dataclasses import asdict |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader |
|
|
| from .data import GSM8KDataset, MATHDataset, collate_batch |
|
|
|
|
def _build_dataset(name, split, max_examples):
    """Build the dataset selected by ``cfg['dataset']``.

    Args:
        name: dataset name, case-insensitive. ``None`` or ``""`` means
            ``"gsm8k"``. Supported values: ``"gsm8k"``, ``"math"``.
        split: split name forwarded to the dataset class (e.g. ``"train"``,
            ``"test"``).
        max_examples: forwarded unchanged as ``max_examples``.

    Returns:
        A ``GSM8KDataset`` or ``MATHDataset`` instance.

    Raises:
        ValueError: for any other name. An unknown name (e.g. a config typo)
            must not silently train/evaluate on GSM8K instead.
    """
    key = (name or "gsm8k").lower()
    if key == "math":
        return MATHDataset(split=split, max_examples=max_examples)
    if key == "gsm8k":
        return GSM8KDataset(split=split, max_examples=max_examples)
    raise ValueError(f"unknown dataset {name!r}; expected 'gsm8k' or 'math'")
| from .losses import ( |
| InfoNCEHead, LossWeights, encode_answer_for_infonce, infonce_loss, |
| kl_to_gaussian, lm_loss_on_y, |
| ) |
| from .model import BLTConfig, LatentProjector, build_base, forward_with_latent |
|
|
|
|
def set_seed(s: int):
    """Seed Python's ``random`` module and torch (CPU plus every CUDA device) with ``s``."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(s)
|
|
|
|
def make_optimizer(model, projector, head, cfg):
    """Create an AdamW optimizer over three parameter groups.

    Group order is LoRA (every ``requires_grad`` parameter of ``model``),
    then projector, then InfoNCE head. The training loop zips
    ``opt.param_groups`` with ``[lr_lora, lr_proj, lr_head]`` in exactly this
    order. bitsandbytes' paged 8-bit AdamW is preferred. If importing it fails
    for any reason, ``torch.optim.AdamW`` is used instead.

    Returns:
        ``(optimizer, used_8bit)``.
    """
    try:
        from bitsandbytes.optim import PagedAdamW8bit as AdamW
    except Exception:
        from torch.optim import AdamW
        use_8bit = False
    else:
        use_8bit = True

    trainable_base = [param for param in model.parameters() if param.requires_grad]
    param_groups = [
        dict(params=trainable_base, lr=cfg["lr_lora"]),
        dict(params=projector.parameters(), lr=cfg["lr_proj"]),
        dict(params=head.parameters(), lr=cfg["lr_head"]),
    ]
    optimizer = AdamW(param_groups, weight_decay=cfg["weight_decay"])
    return optimizer, use_8bit
|
|
|
|
def get_K_for_step(step: int, curriculum) -> int:
    """Return the number of latent slots K scheduled for ``step``.

    ``curriculum`` is a list of ``[step_threshold, K]`` pairs in ascending
    threshold order. The K of the last pair whose threshold is <= ``step``
    wins. Before the first threshold is reached, the first pair's K is used.
    An empty curriculum raises ``IndexError``.
    """
    fallback = curriculum[0][1]
    reached = [k for threshold, k in curriculum if step >= threshold]
    return reached[-1] if reached else fallback
|
|
|
|
def cosine_lr(step, warmup, total, base_lr):
    """Linear warmup from 0 to ``base_lr``, then cosine decay to 0 at ``total``.

    During warmup (``step < warmup``) the LR is ``base_lr * step / warmup``.
    Afterwards the decay progress is clamped to 1.0, so any step past
    ``total`` returns 0.
    """
    in_warmup = step < warmup
    if in_warmup:
        return base_lr * step / max(1, warmup)
    decay_span = max(1, total - warmup)
    frac = (step - warmup) / decay_span
    angle = math.pi * min(frac, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(angle))
|
|
|
|
def evaluate_quick(model, projector, tokenizer, val_loader, device, K,
                   block_z_to_x: bool = False, block_y_to_x: bool = True) -> dict:
    """Quick validation pass: masked per-token LM perplexity and token accuracy.

    Runs ``forward_with_latent`` on every batch of ``val_loader`` under
    ``torch.no_grad()`` and scores ``logits_y`` against ``y_ids`` position by
    position. Only positions where ``y_attn`` is nonzero are counted.

    Args:
        model: model passed to ``forward_with_latent``.
        projector: latent projector passed to ``forward_with_latent``.
        tokenizer: unused; kept for signature compatibility.
        val_loader: iterable of batches exposing ``x_ids``, ``x_attn``,
            ``y_ids`` and ``y_attn``.
        device: device the batch tensors are moved to.
        K: number of latent slots passed to ``forward_with_latent``.
        block_z_to_x: forwarded to ``forward_with_latent``.
        block_y_to_x: forwarded to ``forward_with_latent``. Defaults to True,
            the value this function always used. Pass the training config's
            value to evaluate under the training attention pattern.

    Returns:
        ``{"val_lm_ppl": exp(mean masked CE), "val_tok_acc": fraction of
        masked positions where argmax(logits_y) == y_ids}``. The perplexity is
        ``inf`` if the mean CE is too large for ``math.exp``.
    """
    # Remember the callers' modes and restore them even if a batch raises.
    model_was_training = model.training
    projector_is_module = isinstance(projector, torch.nn.Module)
    projector_was_training = projector.training if projector_is_module else None
    model.eval()
    if projector_is_module:
        projector.eval()  # disable any dropout inside the projector as well

    total_loss = 0.0
    total_tok = 0.0
    total_correct = 0.0
    try:
        with torch.no_grad():
            for batch in val_loader:
                x_ids = batch.x_ids.to(device)
                x_attn = batch.x_attn.to(device)
                y_ids = batch.y_ids.to(device)
                y_attn = batch.y_attn.to(device)
                logits_y, _z, _ = forward_with_latent(
                    model, x_ids, x_attn, y_ids, projector, K,
                    block_y_to_x=block_y_to_x, block_z_to_x=block_z_to_x,
                )
                B, L_y, V = logits_y.shape
                ce = torch.nn.functional.cross_entropy(
                    logits_y.reshape(-1, V), y_ids.reshape(-1), reduction="none"
                ).reshape(B, L_y)
                mask = y_attn.float()
                total_loss += (ce * mask).sum().item()
                total_tok += mask.sum().item()
                preds = logits_y.argmax(dim=-1)
                total_correct += ((preds == y_ids) * mask).sum().item()
    finally:
        model.train(model_was_training)
        if projector_is_module:
            projector.train(projector_was_training)

    denom = max(total_tok, 1)
    try:
        ppl = math.exp(total_loss / denom)
    except OverflowError:
        # A diverged model must not crash the training run from inside eval.
        ppl = float("inf")
    return {
        "val_lm_ppl": ppl,
        "val_tok_acc": total_correct / denom,
    }
|
|
|
|
# Serialized PEFT LoRA key suffixes that lack the adapter-name segment.
_PEFT_LORA_SUFFIXES = (".lora_A.weight", ".lora_B.weight",
                       ".lora_A.bias", ".lora_B.bias")


def _remap_peft_lora_keys(sd_raw):
    """Return a copy of ``sd_raw`` with ``.default`` inserted into LoRA keys.

    ``...q_proj.lora_A.weight`` becomes ``...q_proj.lora_A.default.weight``.
    The same rewrite applies to ``lora_B`` and ``.bias``. All other keys pass
    through unchanged.
    """
    remapped = {}
    for key, tensor in sd_raw.items():
        for suffix in _PEFT_LORA_SUFFIXES:
            if key.endswith(suffix):
                module, _, param = suffix.rpartition(".")  # e.g. ".lora_A", "weight"
                key = f"{key[:-len(suffix)]}{module}.default.{param}"
                break
        remapped[key] = tensor
    return remapped


def _load_state_from_ckpt(model, projector, head, ckpt_dir, device):
    """Restore LoRA adapter, projector, and InfoNCE head from a ckpt dir.

    Optimizer state is NOT saved by save_every (see train loop), so resuming
    re-initializes Adam moments. Loss curves will spike for a few hundred
    steps but the latent geometry survives -- adequate for instance-failure
    recovery.

    ckpt_dir may be a local path or a ``<hf-namespace>/<repo>[:subfolder]``
    ref. A string containing "/" that is not an existing local path is
    treated as a Hub ref and fetched with ``snapshot_download``.

    NB: peft serializes LoRA weights with keys like ``...lora_A.weight``, but
    the wrapped PeftModel's state_dict uses ``...lora_A.default.weight`` (the
    ``.default`` is the adapter name). A naive
    ``model.load_state_dict(sd, strict=False)`` silently discards all 224 LoRA
    matrices. We rewrite the keys first and report missing/unexpected keys.

    Returns:
        The step parsed from a directory name starting with
        ``ckpt-step<N>``; 0 for any other name (e.g. ``final``).

    Raises:
        FileNotFoundError: if the resolved checkpoint directory does not
            exist. Otherwise a mistyped path would restore nothing while
            still returning the step parsed from its name.
    """
    import re

    if "/" in ckpt_dir and not Path(ckpt_dir).exists():
        # Not a local path: resolve it as a Hugging Face Hub reference.
        from huggingface_hub import snapshot_download
        repo, _, sub = ckpt_dir.partition(":")
        local = snapshot_download(repo_id=repo,
                                  allow_patterns=(f"{sub}/*" if sub else None))
        ckpt_dir = str(Path(local) / sub) if sub else local
    ckpt = Path(ckpt_dir)
    if not ckpt.is_dir():
        raise FileNotFoundError(f"[resume] checkpoint directory not found: {ckpt}")

    adapter_file = ckpt / "model" / "adapter_model.safetensors"
    if adapter_file.exists():
        from safetensors.torch import load_file
        sd = _remap_peft_lora_keys(load_file(str(adapter_file)))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # `missing` presumably also lists frozen base weights that are not in
        # the adapter file, so only LoRA entries indicate a failed load.
        lora_missing = [m for m in missing if "lora" in m.lower()]
        print(f"[resume] adapter: total_missing={len(missing)} unexpected={len(unexpected)} "
              f"lora_missing={len(lora_missing)} (lora_missing>0 ⇒ LoRA load failed)", flush=True)
        if lora_missing:
            print(f"[resume] WARNING: first 3 lora_missing keys: {lora_missing[:3]}", flush=True)
        if unexpected:
            print(f"[resume] WARNING: first 3 unexpected keys: {unexpected[:3]}", flush=True)
    else:
        print(f"[resume] WARNING: no adapter file at {adapter_file}; "
              f"LoRA weights NOT restored", flush=True)

    # These files hold plain state dicts (see the save code). weights_only=True
    # avoids unpickling arbitrary objects from a possibly remote checkpoint.
    proj_file = ckpt / "projector.pt"
    if proj_file.exists():
        projector.load_state_dict(torch.load(proj_file, map_location=device, weights_only=True))
        print("[resume] projector loaded", flush=True)
    else:
        print(f"[resume] WARNING: {proj_file} not found; projector NOT restored", flush=True)
    head_file = ckpt / "head.pt"
    if head_file.exists():
        head.load_state_dict(torch.load(head_file, map_location=device, weights_only=True))
        print("[resume] head loaded", flush=True)
    else:
        print(f"[resume] WARNING: {head_file} not found; head NOT restored", flush=True)

    mm = re.match(r"ckpt-step(\d+)", ckpt.name)
    return int(mm.group(1)) if mm else 0
|
|
|
|
def _save_checkpoint(save_dir, model, projector, head):
    """Save the trainable state into ``save_dir``.

    Writes ``model/`` (via ``model.save_pretrained``), ``projector.pt`` and
    ``head.pt``, the paths ``_load_state_from_ckpt`` reads back. Optimizer
    state is deliberately not saved.
    """
    save_dir.mkdir(exist_ok=True)
    model.save_pretrained(save_dir / "model")
    torch.save(projector.state_dict(), save_dir / "projector.pt")
    torch.save(head.state_dict(), save_dir / "head.pt")


def _train(cfg, resume_from, out_dir, log, met_f):
    """Build model/data/optimizer from ``cfg`` and train until ``cfg["max_steps"]``.

    Args:
        cfg: parsed JSON config.
        resume_from: optional checkpoint ref for ``_load_state_from_ckpt``.
        out_dir: run directory. Checkpoints go to ``ckpt-step<N>/`` and
            ``final/``.
        log: callable writing one timestamped line to stdout and train.log.
        met_f: open text file receiving one JSON object per line.

    Raises:
        ValueError: if training steps remain but the train loader yields no
            batches. Without this check ``while step < max_steps`` would spin
            forever.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log(f"device={device}")

    # ---- Model, latent projector, InfoNCE head ----
    bcfg = BLTConfig(
        base_model=cfg["base_model"],
        use_lora=cfg["use_lora"],
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=cfg["K_latents"], block_y_to_x=cfg["block_y_to_x"],
        block_z_to_x=cfg.get("block_z_to_x", False),
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=cfg.get("gradient_checkpointing", False),
    )
    model, tokenizer = build_base(bcfg)
    model.to(device)

    # When the model is wrapped (e.g. by PEFT), read the config from the base model.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    model_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
                   "float32": torch.float32}[cfg["dtype"]]
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(model_dtype)
    # The head keeps its default dtype; its inputs are cast with .float() below.
    head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=256).to(device)

    n_params_trainable = (
        sum(p.numel() for p in model.parameters() if p.requires_grad)
        + sum(p.numel() for p in projector.parameters())
        + sum(p.numel() for p in head.parameters())
    )
    log(f"trainable params (LoRA+proj+head) = {n_params_trainable/1e6:.2f}M")

    # ---- Data ----
    ds_name = cfg.get("dataset", "gsm8k")
    train_ds = _build_dataset(ds_name, "train", cfg.get("data_train_size"))
    val_ds = _build_dataset(ds_name, "test", cfg.get("data_eval_size") or 200)
    log(f"dataset={ds_name} train={len(train_ds)} val={len(val_ds)}")

    def collate(b):
        return collate_batch(
            b, tokenizer,
            max_prompt_len=cfg["max_prompt_len"],
            max_answer_len=cfg["max_answer_len"],
        )

    def make_loader(ds, bs, shuffle):
        # Only the shuffled (training) loader drops its ragged last batch.
        return DataLoader(ds, batch_size=bs, shuffle=shuffle, drop_last=shuffle,
                          collate_fn=collate)

    train_loader = make_loader(train_ds, cfg["batch_size"], True)
    val_loader = make_loader(val_ds, cfg["batch_size"], False)

    # ---- Optimizer, loss weights, resume ----
    opt, used_8bit = make_optimizer(model, projector, head, cfg)
    log(f"optimizer 8bit={used_8bit}")
    weights = LossWeights(
        lambda_lm=cfg["lambda_lm"], lambda_id=cfg["lambda_id"],
        lambda_kl=cfg["lambda_kl"], tau_infonce=cfg["tau_infonce"],
    )

    step = 0
    if resume_from:
        step = _load_state_from_ckpt(model, projector, head, resume_from, device)
        log(f"[resume] restored from {resume_from} at step={step}")

    max_steps = cfg["max_steps"]
    if step < max_steps and len(train_loader) == 0:
        raise ValueError(
            f"training loader is empty ({len(train_ds)} examples, "
            f"batch_size={cfg['batch_size']}, drop_last=True); "
            "the step loop would never advance"
        )

    # ---- Per-run constants hoisted out of the step loop ----
    use_per_slot = cfg.get("infonce_per_slot", False)
    if use_per_slot:
        from .data import split_y_into_chunks
        from .losses import encode_chunks_per_slot, infonce_per_slot_loss
    lambda_decorr = cfg.get("lambda_decorr", 0.0)
    if lambda_decorr > 0:
        from .losses import slot_decorrelation_loss
    block_z_to_x = cfg.get("block_z_to_x", False)
    grad_accum = cfg["grad_accum"]
    # Same order as make_optimizer's param groups: LoRA, projector, head.
    base_lrs = [cfg["lr_lora"], cfg["lr_proj"], cfg["lr_head"]]
    clip_params = ([p for p in model.parameters() if p.requires_grad]
                   + list(projector.parameters()) + list(head.parameters()))

    def fresh_running():
        stats = {"loss": 0.0, "lm": 0.0, "id": 0.0, "kl": 0.0,
                 "decorr": 0.0, "z_norm": 0.0, "n": 0}
        if use_per_slot:
            stats.update(ps_acc_full=0.0, ps_acc_within=0.0)
        return stats

    accum_idx = 0
    t0 = time.time()
    running = fresh_running()

    model.train()
    while step < max_steps:
        for batch in train_loader:
            if step >= max_steps:
                break
            K = get_K_for_step(step, cfg["K_curriculum"])
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            y_ids = batch.y_ids.to(device)
            y_attn = batch.y_attn.to(device)

            # InfoNCE targets are encoded before the main forward pass, as before.
            if use_per_slot:
                src = (batch.full_answer_strs if batch.full_answer_strs is not None
                       else batch.final_strs)
                chunks_per_problem = [split_y_into_chunks(s, K) for s in src]
                f_y_chunks = encode_chunks_per_slot(
                    model, tokenizer, chunks_per_problem, device=device,
                    max_len=cfg.get("infonce_chunk_max_len", 32),
                )
            else:
                if cfg.get("infonce_full_answer", False):
                    target_text = (batch.full_answer_strs
                                   if batch.full_answer_strs is not None
                                   else batch.final_strs)
                    target_max_len = cfg.get("infonce_target_max_len", 128)
                else:
                    target_text = batch.final_strs
                    target_max_len = cfg.get("infonce_target_max_len", 16)
                f_y = encode_answer_for_infonce(
                    model, tokenizer, target_text, device=device, max_len=target_max_len,
                )

            logits_y, z, _ = forward_with_latent(
                model, x_ids, x_attn, y_ids, projector, K,
                block_y_to_x=cfg["block_y_to_x"],
                block_z_to_x=block_z_to_x,
            )
            L_lm = lm_loss_on_y(logits_y, y_ids, y_attn)

            if use_per_slot:
                ps_info = infonce_per_slot_loss(z, f_y_chunks, head, tau=weights.tau_infonce)
                L_id = ps_info["loss"]
                running["ps_acc_full"] += float(ps_info["acc_z2y"].item())
                running["ps_acc_within"] += float(ps_info["acc_within_problem"].item())
            else:
                # Pool over dim 1 (assumed to be the K latent slots; TODO confirm).
                z_pool = z.mean(dim=1)
                z_emb, y_emb = head(z_pool.float(), f_y.float())
                L_id = infonce_loss(z_emb, y_emb, tau=weights.tau_infonce)

            L_kl = kl_to_gaussian(z.float())
            if lambda_decorr > 0:
                L_decorr = slot_decorrelation_loss(z)
            else:
                L_decorr = torch.zeros((), device=device)

            loss = (weights.lambda_lm * L_lm
                    + weights.lambda_id * L_id
                    + weights.lambda_kl * L_kl
                    + lambda_decorr * L_decorr)
            (loss / grad_accum).backward()

            running["loss"] += loss.item()
            running["lm"] += L_lm.item()
            running["id"] += L_id.item()
            running["kl"] += L_kl.item()
            running["decorr"] += float(L_decorr.item())
            running["z_norm"] += z.float().pow(2).sum(dim=-1).mean().sqrt().item()
            running["n"] += 1

            accum_idx += 1
            if accum_idx % grad_accum != 0:
                continue

            # ---- Optimizer step, once every grad_accum micro-batches ----
            # The shared warmup+cosine factor is computed with base_lr=1.0 and
            # applied to each group's base LR. This avoids dividing by lr_lora,
            # which would fail when lr_lora == 0.
            lr_scale = cosine_lr(step, cfg["warmup_steps"], max_steps, 1.0)
            for pg, base in zip(opt.param_groups, base_lrs):
                pg["lr"] = base * lr_scale
            torch.nn.utils.clip_grad_norm_(clip_params, cfg["max_grad_norm"])
            opt.step()
            opt.zero_grad(set_to_none=True)
            step += 1

            if step % cfg["log_every"] == 0:
                n = max(running["n"], 1)
                log(f"step={step} K={K} loss={running['loss']/n:.4f} "
                    f"lm={running['lm']/n:.4f} id={running['id']/n:.4f} "
                    f"kl={running['kl']/n:.4f} z_norm={running['z_norm']/n:.3f} "
                    f"elapsed={time.time()-t0:.0f}s")
                record = {
                    "step": step, "K": K,
                    "loss": running["loss"] / n,
                    "lm": running["lm"] / n,
                    "id": running["id"] / n,
                    "kl": running["kl"] / n,
                    "z_norm": running["z_norm"] / n,
                    "elapsed_s": time.time() - t0,
                    "decorr": running["decorr"] / n,
                }
                if use_per_slot:
                    record["ps_acc_full"] = running["ps_acc_full"] / n
                    record["ps_acc_within"] = running["ps_acc_within"] / n
                met_f.write(json.dumps(record) + "\n")
                running = fresh_running()

            if cfg["eval_every"] > 0 and step % cfg["eval_every"] == 0:
                quick = evaluate_quick(model, projector, tokenizer, val_loader, device, K,
                                       block_z_to_x=block_z_to_x)
                log(f"[eval] step={step} {quick}")
                met_f.write(json.dumps({"step": step, "eval": quick}) + "\n")

            if cfg["save_every"] > 0 and step % cfg["save_every"] == 0:
                save_dir = out_dir / f"ckpt-step{step}"
                _save_checkpoint(save_dir, model, projector, head)
                log(f"[save] {save_dir}")

    save_dir = out_dir / "final"
    _save_checkpoint(save_dir, model, projector, head)
    log(f"[done] final saved at {save_dir}")


def main():
    """CLI entry point for BLT-Reasoner training.

    Parses ``--config`` (JSON) and an optional ``--resume_from``, then seeds
    the RNGs. It creates ``cfg["output_dir"]``, appends to ``train.log`` and
    ``metrics.jsonl`` there, and runs training. Both files are closed even if
    training raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--resume_from", default=None,
                        help="Local ckpt dir OR 'hf-namespace/repo[:subfolder]'. "
                             "Restores LoRA + projector + InfoNCE head; opt state is re-initialized.")
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = json.load(f)
    set_seed(cfg["seed"])

    out_dir = Path(cfg["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    log_path = out_dir / "train.log"
    metrics_path = out_dir / "metrics.jsonl"
    # Append mode with line buffering: resumed runs extend the same files,
    # and each line reaches disk as it is written.
    with open(log_path, "a", buffering=1) as log_f, \
            open(metrics_path, "a", buffering=1) as met_f:

        def log(msg):
            line = f"[{time.strftime('%H:%M:%S')}] {msg}"
            print(line, flush=True)
            log_f.write(line + "\n")

        log(f"config={json.dumps(cfg)}")
        _train(cfg, args.resume_from, out_dir, log, met_f)
|
|
|
|
# Script entry point. Run this file as a module (see Usage in the module
# docstring) so the package-relative imports (`from .data import ...`) resolve.
if __name__ == "__main__":
    main()
|
|