OpenTransformer's picture
Squash history to reclaim LFS storage after pruning older half of checkpoints
4216271
#!/usr/bin/env python3
# n2e.py — Joint AR+SAT Trainer with M-folded attention (exact, no tradeoff)
#
# THEOREM (Shared-U Score Capacity).
# Let U in R^{dk x r} parameterize the attention metric M(U) := U U^T in R^{dk x dk}.
# The set of expressible metrics is exactly
# S(U) = { M in Sym^{dk} : M >= 0, rank(M) <= min(dk, r) }
# In particular, when r >= dk, S(U) equals the full PSD(dk) cone, so additional
# r beyond dk adds NO representational capacity for attention scores -- only
# parameters and (possibly) optimization-landscape benefits.
# In this file U is shared across heads, so any redundancy count is per attention
# layer, not per head.
# Proof: see compact_attention_U() below; constructive via eigendecomposition.
#
# CONSEQUENCE 1 (M-fold computation, applied in TuneableAttentionMHA.forward).
# (X_q U)(X_k U)^T == X_q (U U^T) X_k^T == X_q M X_k^T.
# When r > dk, computing scores via M in R^{dk x dk} cuts the per-layer FLOPs by
# factor r/dk on the score matmul AND eliminates two N*dk*r projections.
# Numerical equivalence verified to ~1e-12 in float64 and within fp32 matmul
# ordering noise for both forward scores AND gradients on Wq, Wk, Wv, U.
#
# CONSEQUENCE 2 (post-train compression, applied via the `compact` subcommand).
# After training, M has effective rank r* <= dk. We can replace U in R^{dk x r}
# with U' in R^{dk x r*} such that U' U'^T == U U^T (eigendecomposition closed
# form). Checkpoints need one shared cfg rank, so the compact command stores all
# layers at the max effective rank and zero-pads lower-rank layers. For r > dk
# this is a strict parameter reduction with zero loss of attention behaviour.
# Use: python n2e.py compact --src CKPT --dst OUT
#
# CONSEQUENCE 3 (SAT-spec exactness, applied in `infer --mode spec`).
# Let p be the AR distribution and q be the SAT draft distribution for the next
# token. Draw x ~ q, accept it with alpha(x)=min(1, p(x)/q(x)); on rejection,
# sample from r(x) proportional to max(p(x)-q(x), 0). Then the final token is
# distributed exactly as p. Repeating this over a verified SAT draft block gives
# AR-quality output while accepting multiple SAT tokens whenever q is close to p.
# Use: python n2e.py prove --theorem spec_kernel
#
# CONSEQUENCE 4 (exact SDPA backend, applied by default).
# After M-folding, the remaining operation is still ordinary scaled dot-product
# attention with an additive mask/ALiBi bias. Therefore the manual
# score->softmax->value pipeline can be replaced by
# torch.nn.functional.scaled_dot_product_attention with identical outputs up to
# floating-point ordering. Disable with --no_sdpa if a backend misbehaves.
# Use: python n2e.py prove --theorem sdpa_equivalence
#
# TOKENIZER.
# Default tokenizer is now deepseek-ai/DeepSeek-V4-Pro. If Transformers does
# not know the new deepseek_v4 config yet, n2e.py loads tokenizer.json directly.
# OpenAI/tiktoken mode is available with:
# TOKENIZER_BACKEND=tiktoken TIKTOKEN_ENCODING=o200k_base python n2e.py ...
#
# Enhanced inference: checkpoint name, tok/s, UK time
from __future__ import annotations
import argparse, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess, shutil
from pathlib import Path
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timezone
STATUS_SCRIPT_PATH = Path(__file__).resolve()
STATUS_DEFAULT_LOG = STATUS_SCRIPT_PATH.parent / "train.log"
STATUS_DEFAULT_SAVE_DIR = STATUS_SCRIPT_PATH.parent / "ckpts_expansion"
_STATUS_PROGRESS_RE = re.compile(
r"^\[(?P<percent>\d+(?:\.\d+)?)%\]\s+"
r"(?P<seen>[\d,]+)/(?P<target>[\d,]+)\s+tok\s+\|\s+"
r"(?P<tok_s>[\d.]+)\s+tok/s\s+\|\s+"
r"loss=(?P<loss>-?[\d.]+)\s+B=(?P<batch>\d+)\s+L=(?P<block>\d+)\s*$"
)
_STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P<name>\S+?\.pt)\s+\((?P<sha>[0-9a-f]+)\.\.\.\)")
_STATUS_STEP_RE = re.compile(r"step(?P<step>\d+)")
def _status_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
return datetime.fromtimestamp(ts, tz=timezone.utc).astimezone().isoformat(timespec="seconds")
def _status_human_duration(seconds: Optional[float]) -> Optional[str]:
if seconds is None:
return None
total = max(0, int(seconds))
days, rem = divmod(total, 86400)
hours, rem = divmod(rem, 3600)
minutes, secs = divmod(rem, 60)
parts = []
if days:
parts.append(f"{days}d")
if hours or parts:
parts.append(f"{hours}h")
if minutes or parts:
parts.append(f"{minutes}m")
parts.append(f"{secs}s")
return " ".join(parts)
def _status_format_int(value: Optional[int]) -> str:
return "?" if value is None else f"{value:,}"
def _status_parse_step(text: str) -> Optional[int]:
match = _STATUS_STEP_RE.search(text)
return int(match.group("step")) if match else None
def _status_resolve_ckpt_path(raw_path: str, base_dir: Path) -> Path:
ckpt_path = Path(raw_path)
return ckpt_path if ckpt_path.is_absolute() else (base_dir / ckpt_path).resolve()
def _status_read_cmdline(proc_dir: Path) -> Optional[List[str]]:
try:
data = (proc_dir / "cmdline").read_bytes().split(b"\0")
return [item.decode("utf-8", errors="ignore") for item in data if item]
except Exception:
return None
def _status_resolve_proc_arg(proc_dir: Path, raw_arg: str) -> Optional[Path]:
try:
arg_path = Path(raw_arg)
if arg_path.is_absolute():
return arg_path.resolve()
cwd = Path(os.readlink(proc_dir / "cwd"))
return (cwd / arg_path).resolve()
except Exception:
return None
def _status_proc_uptime(proc_dir: Path) -> Optional[float]:
try:
proc_uptime = float((Path("/proc") / "uptime").read_text().split()[0])
stat_text = (proc_dir / "stat").read_text()
after = stat_text[stat_text.rfind(")") + 2:].split()
start_ticks = float(after[19])
clock_ticks = os.sysconf(os.sysconf_names["SC_CLK_TCK"])
return max(0.0, proc_uptime - (start_ticks / clock_ticks))
except Exception:
return None
def _status_find_trainers(script_path: Path) -> List[Dict[str, Any]]:
matches: List[Dict[str, Any]] = []
for proc_dir in Path("/proc").iterdir():
if not proc_dir.name.isdigit():
continue
args = _status_read_cmdline(proc_dir)
if not args or "train" not in args:
continue
resolved_script = None
for arg in args:
if Path(arg).name != script_path.name:
continue
candidate = _status_resolve_proc_arg(proc_dir, arg)
if candidate == script_path:
resolved_script = candidate
break
if resolved_script is None:
continue
uptime_seconds = _status_proc_uptime(proc_dir)
try:
cwd = str(Path(os.readlink(proc_dir / "cwd")))
except Exception:
cwd = None
matches.append({
"pid": int(proc_dir.name),
"cmdline": " ".join(args),
"args": args,
"cwd": cwd,
"uptime_seconds": round(uptime_seconds, 3) if uptime_seconds is not None else None,
"uptime_human": _status_human_duration(uptime_seconds),
})
return sorted(matches, key=lambda item: item["pid"])
def _status_parse_progress_line(line: str) -> Optional[Dict[str, Any]]:
match = _STATUS_PROGRESS_RE.match(line.strip())
if not match:
return None
tok_per_sec = float(match.group("tok_s"))
loss = float(match.group("loss"))
return {
"raw_line": line.strip(),
"percent": float(match.group("percent")),
"seen_tokens": int(match.group("seen").replace(",", "")),
"target_tokens": int(match.group("target").replace(",", "")),
"tok_per_sec": int(tok_per_sec) if tok_per_sec.is_integer() else tok_per_sec,
"loss": loss,
"batch": int(match.group("batch")),
"block": int(match.group("block")),
}
def _status_parse_delta_line(line: str) -> Optional[Dict[str, Any]]:
match = _STATUS_DELTA_RE.search(line)
if not match:
return None
name = match.group("name")
return {
"raw_line": line.strip(),
"name": name,
"step": _status_parse_step(name),
"sha_prefix": match.group("sha"),
"source": "log",
}
def _status_scan_log(log_path: Path) -> tuple[Dict[str, Any], Optional[Dict[str, Any]], Optional[Dict[str, Any]], List[str]]:
now = time.time()
info: Dict[str, Any] = {
"path": str(log_path),
"exists": log_path.exists(),
"mtime": None,
"mtime_iso": None,
"age_seconds": None,
"age_human": None,
"size_bytes": None,
}
warnings: List[str] = []
if not log_path.exists():
warnings.append(f"train log missing: {log_path}")
return info, None, None, warnings
try:
st = log_path.stat()
info["mtime"] = st.st_mtime
info["mtime_iso"] = _status_iso(st.st_mtime)
info["age_seconds"] = round(max(0.0, now - st.st_mtime), 3)
info["age_human"] = _status_human_duration(info["age_seconds"])
info["size_bytes"] = st.st_size
except Exception as exc:
warnings.append(f"failed to stat train log: {exc}")
last_progress = None
last_delta = None
try:
with log_path.open("r", encoding="utf-8", errors="ignore") as handle:
for raw_line in handle:
line = raw_line.rstrip("\n")
progress = _status_parse_progress_line(line)
if progress is not None:
last_progress = progress
delta = _status_parse_delta_line(line)
if delta is not None:
last_delta = delta
except Exception as exc:
warnings.append(f"failed to read train log: {exc}")
return info, last_progress, last_delta, warnings
def _status_latest_full_checkpoint(save_dir: Path, base_dir: Path) -> tuple[Dict[str, Any], List[str]]:
latest_path = save_dir / "latest.json"
info: Dict[str, Any] = {
"metadata_path": str(latest_path),
"exists": latest_path.exists(),
"raw_path": None,
"checkpoint_path": None,
"checkpoint_name": None,
"checkpoint_exists": None,
"step": None,
"checkpoint_mtime": None,
"checkpoint_mtime_iso": None,
}
warnings: List[str] = []
if not latest_path.exists():
warnings.append(f"latest.json missing: {latest_path}")
return info, warnings
try:
payload = json.loads(latest_path.read_text(encoding="utf-8"))
except Exception as exc:
warnings.append(f"failed to parse latest.json: {exc}")
return info, warnings
raw_path = payload.get("path")
info["raw_path"] = raw_path
info["step"] = payload.get("step")
if raw_path:
ckpt_path = _status_resolve_ckpt_path(raw_path, base_dir)
info["checkpoint_path"] = str(ckpt_path)
info["checkpoint_name"] = ckpt_path.name
info["checkpoint_exists"] = ckpt_path.exists()
if ckpt_path.exists():
try:
st = ckpt_path.stat()
info["checkpoint_mtime"] = st.st_mtime
info["checkpoint_mtime_iso"] = _status_iso(st.st_mtime)
except Exception as exc:
warnings.append(f"failed to stat full checkpoint: {exc}")
else:
warnings.append(f"latest.json points to missing checkpoint: {ckpt_path}")
return info, warnings
def _status_newest_delta(save_dir: Path) -> tuple[Optional[Dict[str, Any]], List[str]]:
warnings: List[str] = []
if not save_dir.exists():
warnings.append(f"save dir missing: {save_dir}")
return None, warnings
try:
candidates = [item for item in save_dir.glob("*_delta_step*.pt") if item.is_file()]
except Exception as exc:
warnings.append(f"failed to list delta checkpoints: {exc}")
return None, warnings
if not candidates:
warnings.append(f"no delta checkpoints found in {save_dir}")
return None, warnings
newest = max(candidates, key=lambda item: item.stat().st_mtime)
st = newest.stat()
return {
"path": str(newest),
"name": newest.name,
"step": _status_parse_step(newest.name),
"mtime": st.st_mtime,
"mtime_iso": _status_iso(st.st_mtime),
"size_bytes": st.st_size,
"source": "disk",
}, warnings
def _status_gpu_info() -> tuple[Optional[Dict[str, Any]], List[str]]:
warnings: List[str] = []
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw",
"--format=csv,noheader,nounits",
],
capture_output=True,
text=True,
timeout=5,
check=False,
)
except FileNotFoundError:
return None, warnings
except Exception as exc:
warnings.append(f"failed to query GPU status: {exc}")
return None, warnings
if result.returncode != 0:
warnings.append(result.stderr.strip() or "nvidia-smi returned non-zero exit status")
return None, warnings
lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
if not lines:
return None, warnings
if len(lines) > 1:
warnings.append("multiple GPUs detected; reporting the first GPU only")
parts = [part.strip() for part in lines[0].split(",")]
if len(parts) != 6:
warnings.append(f"unexpected nvidia-smi format: {lines[0]}")
return None, warnings
def _parse_int(raw: str) -> Optional[int]:
try:
return int(float(raw))
except Exception:
return None
def _parse_float(raw: str) -> Optional[float]:
try:
return float(raw)
except Exception:
return None
return {
"name": parts[0],
"utilization_gpu": _parse_int(parts[1]),
"memory_used_mib": _parse_int(parts[2]),
"memory_total_mib": _parse_int(parts[3]),
"temperature_c": _parse_int(parts[4]),
"power_draw_w": _parse_float(parts[5]),
}, warnings
def _status_choose_delta(from_log: Optional[Dict[str, Any]], from_disk: Optional[Dict[str, Any]], warnings: List[str]) -> Optional[Dict[str, Any]]:
if from_log and from_disk:
log_step = from_log.get("step")
disk_step = from_disk.get("step")
if log_step is not None and disk_step is not None:
if log_step != disk_step:
warnings.append(
f"log delta step {log_step} and newest on-disk delta step {disk_step} differ; using the newer step"
)
if disk_step >= log_step:
merged = dict(from_disk)
merged["source"] = "disk+log" if disk_step == log_step else "disk"
if disk_step == log_step:
merged["sha_prefix"] = from_log.get("sha_prefix")
return merged
return dict(from_log)
return dict(from_disk)
if from_disk:
return dict(from_disk)
if from_log:
return dict(from_log)
return None
def _collect_status(log_path: Path, save_dir: Path) -> tuple[Dict[str, Any], int]:
checked_at = time.time()
requested_save_dir = save_dir.expanduser()
log_path = log_path.expanduser()
status: Dict[str, Any] = {
"checked_at": checked_at,
"checked_at_iso": _status_iso(checked_at),
"running": False,
"process": None,
"progress": None,
"delta_checkpoint": None,
"delta_from_log": None,
"delta_on_disk": None,
"latest_full_checkpoint": None,
"log": None,
"gpu": None,
"save_dir": {
"requested_path": str(requested_save_dir),
"path": str(requested_save_dir),
"exists": requested_save_dir.exists(),
"source": "requested",
},
"warnings": [],
}
warnings = status["warnings"]
matches = _status_find_trainers(STATUS_SCRIPT_PATH)
if len(matches) > 1:
status["error"] = "multiple active n.py train processes found"
status["processes"] = matches
return status, 1
if matches:
status["running"] = True
status["process"] = matches[0]
save_dir = requested_save_dir
if status["process"] and status["process"].get("cwd"):
proc_cwd = Path(status["process"]["cwd"])
alt_save_dir = (proc_cwd / requested_save_dir.name).resolve()
if alt_save_dir != requested_save_dir and alt_save_dir.exists():
requested_delta, _ = _status_newest_delta(requested_save_dir)
requested_full, _ = _status_latest_full_checkpoint(requested_save_dir, STATUS_SCRIPT_PATH.parent)
alt_delta, _ = _status_newest_delta(alt_save_dir)
alt_full, _ = _status_latest_full_checkpoint(alt_save_dir, proc_cwd)
requested_score = int(requested_delta is not None) + int(bool(requested_full.get("checkpoint_exists")))
alt_score = int(alt_delta is not None) + int(bool(alt_full.get("checkpoint_exists")))
if alt_score > requested_score:
save_dir = alt_save_dir
status["save_dir"] = {
"requested_path": str(requested_save_dir),
"path": str(save_dir),
"exists": save_dir.exists(),
"source": "process_cwd_fallback",
}
warnings.append(
f"using process cwd save dir fallback: {save_dir} (requested {requested_save_dir})"
)
log_info, progress, delta_from_log, log_warnings = _status_scan_log(log_path)
warnings.extend(log_warnings)
status["log"] = log_info
status["progress"] = progress
status["delta_from_log"] = delta_from_log
latest_base_dir = STATUS_SCRIPT_PATH.parent
if status["save_dir"].get("source") == "process_cwd_fallback" and status["process"] and status["process"].get("cwd"):
latest_base_dir = Path(status["process"]["cwd"])
latest_full, latest_warnings = _status_latest_full_checkpoint(save_dir, latest_base_dir)
warnings.extend(latest_warnings)
status["latest_full_checkpoint"] = latest_full
delta_on_disk, delta_warnings = _status_newest_delta(save_dir)
warnings.extend(delta_warnings)
status["delta_on_disk"] = delta_on_disk
status["delta_checkpoint"] = _status_choose_delta(delta_from_log, delta_on_disk, warnings)
gpu, gpu_warnings = _status_gpu_info()
warnings.extend(gpu_warnings)
status["gpu"] = gpu
if status["running"] and log_info.get("age_seconds") is not None and log_info["age_seconds"] > 600:
warnings.append(f"train log appears stale while trainer is running ({log_info['age_human']} old)")
if log_info.get("exists") and progress is None:
warnings.append("no parseable progress line found in train log")
latest_step = latest_full.get("step") if latest_full else None
delta_step = status["delta_checkpoint"].get("step") if status["delta_checkpoint"] else None
if latest_step is not None and delta_step is not None and latest_step < delta_step:
warnings.append(f"latest.json step {latest_step} lags newest delta step {delta_step}")
if not status["running"] and progress is None:
warnings.append("no active trainer process found")
return status, 0
def _format_status_text(status: Dict[str, Any]) -> str:
lines = [f"AGILLM status @ {status.get('checked_at_iso')}"]
if status.get("error"):
lines.append(f"Error: {status['error']}")
for proc in status.get("processes", []):
lines.append(f"- pid {proc.get('pid')}: {proc.get('cmdline')}")
return "\n".join(lines)
process = status.get("process")
if status.get("running") and process:
lines.append(f"Process: RUNNING | pid {process.get('pid')} | uptime {process.get('uptime_human') or 'unknown'}")
lines.append(f"Cmd: {process.get('cmdline')}")
else:
lines.append("Process: NOT RUNNING")
progress = status.get("progress")
if progress:
lines.append(
"Progress: "
f"{progress['percent']:.1f}% | "
f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | "
f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | "
f"B={progress['batch']} L={progress['block']}"
)
else:
lines.append("Progress: unavailable")
log_info = status.get("log") or {}
if log_info.get("exists"):
lines.append(
f"Log: {log_info.get('path')} | updated {log_info.get('age_human') or 'unknown'} ago | "
f"mtime {log_info.get('mtime_iso')}"
)
else:
lines.append(f"Log: missing ({log_info.get('path')})")
delta = status.get("delta_checkpoint")
if delta:
line = f"Delta: {delta.get('name')} | step {delta.get('step')} | source {delta.get('source')}"
if delta.get("path"):
line += f" | {delta['path']}"
lines.append(line)
else:
lines.append("Delta: unavailable")
latest_full = status.get("latest_full_checkpoint") or {}
if latest_full.get("exists"):
lines.append(
f"Latest full: step {latest_full.get('step')} | {latest_full.get('checkpoint_path') or latest_full.get('raw_path')}"
)
else:
lines.append(f"Latest full: unavailable ({latest_full.get('metadata_path')})")
gpu = status.get("gpu")
if gpu:
lines.append(
f"GPU: {gpu.get('name')} | {gpu.get('utilization_gpu')}% | "
f"{gpu.get('memory_used_mib')}/{gpu.get('memory_total_mib')} MiB | "
f"{gpu.get('temperature_c')}C | {gpu.get('power_draw_w')} W"
)
warnings = status.get("warnings") or []
if warnings:
lines.append("Warnings:")
lines.extend(f"- {warning}" for warning in warnings)
return "\n".join(lines)
def _emit_status(log_path: Path, save_dir: Path, as_json: bool) -> int:
status, exit_code = _collect_status(log_path, save_dir)
if as_json:
print(json.dumps(status, indent=2, sort_keys=True))
else:
print(_format_status_text(status))
return exit_code
def _run_status_command(argv: List[str]) -> int:
parser = argparse.ArgumentParser(prog=f"{STATUS_SCRIPT_PATH.name} status", description="Read-only training status")
parser.add_argument("--json", dest="json_output", action="store_true", help="Emit machine-readable JSON")
parser.add_argument("--log", type=Path, default=STATUS_DEFAULT_LOG, help="Path to the training log")
parser.add_argument("--save_dir", type=Path, default=STATUS_DEFAULT_SAVE_DIR, help="Checkpoint directory")
args = parser.parse_args(argv)
return _emit_status(args.log, args.save_dir, args.json_output)
def _maybe_handle_status_fastpath() -> None:
if len(sys.argv) > 1 and sys.argv[1] == "status":
raise SystemExit(_run_status_command(sys.argv[2:]))
_maybe_handle_status_fastpath()
import torch
# SafeProgress - Claude-safe progress (discrete lines, not single growing line)
class SafeProgress:
def __init__(self, total, initial=0, unit="tok", print_every_tokens=1_000_000):
self.total, self.n, self.unit = total, initial, unit
self.initial = initial
self.last_print, self.postfix = initial, {}
self.print_every_tokens = print_every_tokens
self.start_time = time.time()
def update(self, n=1):
self.n += n
if self.n - self.last_print >= self.print_every_tokens:
self._print(); self.last_print = self.n
def set_postfix(self, **kwargs):
# Store values lazily: GPU tensors stay on GPU until _print() needs a scalar.
# This avoids one CUDA sync per training step caused by .item() in the caller.
self.postfix = kwargs
def _resolve(self, v):
if hasattr(v, "item") and not isinstance(v, (int, float, str, bytes)):
try:
return f"{float(v.item()):.3f}"
except Exception:
return str(v)
return str(v)
def _print(self):
elapsed = __import__('time').time() - self.start_time
rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0
pct = 100 * self.n / self.total if self.total > 0 else 0
pf = ' '.join(f"{k}={self._resolve(v)}" for k, v in self.postfix.items())
print(f"[{pct:.1f}%] {self.n:,}/{self.total:,} {self.unit} | {rate:.0f} tok/s | {pf}")
def close(self): self._print(); print("Done.")
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, DownloadConfig
from transformers import AutoTokenizer, PreTrainedTokenizerFast, logging as hf_log
# from tqdm.auto import tqdm # DISABLED - kills Claude context
# ─────────────────────────────── HOT DATASET LOADING ───────────────────────────────
HOT_CONFIG_PATH = Path("/workspace/hot_config.json")
_hot_config_cache = {"mtime": 0, "data": {}}
def get_hot_config() -> dict:
"""Load hot_config.json with caching, return empty dict if missing"""
try:
if HOT_CONFIG_PATH.exists():
mtime = HOT_CONFIG_PATH.stat().st_mtime
if mtime > _hot_config_cache["mtime"]:
with open(HOT_CONFIG_PATH) as f:
_hot_config_cache["data"] = json.load(f)
_hot_config_cache["mtime"] = mtime
return _hot_config_cache["data"]
except Exception as e:
print(f"[hot_config] Error loading: {e}")
return {}
def get_hot_datasets(default_sources: str) -> str:
"""Get datasets from hot_config if present, else use default"""
cfg = get_hot_config()
if "datasets" in cfg and cfg["datasets"]:
hot_ds = cfg["datasets"]
if isinstance(hot_ds, list):
hot_ds = ",".join(hot_ds)
print(f"[hot_config] Using hot datasets: {hot_ds}")
return hot_ds
return default_sources
# DISABLED: # Auto-rotating log to prevent context-window suicide
# DISABLED: try:
# DISABLED: from rotating_log import install_rotating_log
# DISABLED: install_rotating_log()
# DISABLED: except ImportError:
# pass # Running without rotation
# ───────────────────────── ANSI Colors ─────────────────────────
class Colors:
RESET = "\033[0m"
BOLD = "\033[1m"
PROMPT = "\033[36m"
GEN = "\033[0m"
INFO = "\033[90m"
WARN = "\033[93m"
# ───────────────────────── Globals ─────────────────────────
hf_log.set_verbosity_error()
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
torch.set_float32_matmul_precision("high")
except Exception:
pass
TOKENIZER_BACKEND = os.environ.get("TOKENIZER_BACKEND", "hf").strip().lower()
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V4-Pro")
TIKTOKEN_ENCODING = os.environ.get("TIKTOKEN_ENCODING", "o200k_base")
class TiktokenTokenizerAdapter:
"""Tiny compatibility wrapper for OpenAI tiktoken encodings."""
backend_name = "tiktoken"
def __init__(self, encoding_name: str):
import tiktoken
self.encoding_name = encoding_name
self.enc = tiktoken.get_encoding(encoding_name)
self.eos_token_id = getattr(self.enc, "eot_token", None)
self.sep_token_id = self.eos_token_id
self.pad_token = "<|endoftext|>"
self.pad_token_id = self.eos_token_id
self.vocab_size = self.enc.n_vocab
def encode(self, text: str, *args, **kwargs):
return self.enc.encode(text, allowed_special="all")
def decode(self, ids, *args, **kwargs):
return self.enc.decode([int(i) for i in ids])
def get_vocab(self):
return {str(i): i for i in range(self.enc.n_vocab)}
def add_special_tokens(self, *_args, **_kwargs):
return 0
def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):
rendered = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
rendered.append(f"{role}: {content}")
if add_generation_prompt:
rendered.append("assistant:")
text = "\n".join(rendered)
return self.encode(text) if tokenize else text
def _load_tokenizer():
if TOKENIZER_BACKEND in {"tiktoken", "openai"}:
print(f"[tokenizer] backend=tiktoken encoding={TIKTOKEN_ENCODING}")
return TiktokenTokenizerAdapter(TIKTOKEN_ENCODING)
if TOKENIZER_BACKEND not in {"hf", "huggingface", "transformers"}:
print(f"[tokenizer] WARNING: unknown TOKENIZER_BACKEND={TOKENIZER_BACKEND!r}; falling back to Hugging Face")
print(f"[tokenizer] backend=hf id={TOKENIZER_ID}")
try:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
except Exception as e:
print(f"[tokenizer] AutoTokenizer failed ({type(e).__name__}: {e})")
print("[tokenizer] Falling back to direct tokenizer.json load")
from huggingface_hub import hf_hub_download
tokenizer_file = hf_hub_download(TOKENIZER_ID, "tokenizer.json")
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
tokenizer.backend_name = "hf"
return tokenizer
tok = _load_tokenizer()
# ─── Fix tokenizer Ġ/▁ mismatch ───
# Some DeepSeek tokenizer releases use Ä  (U+0120) for space-prefixed tokens,
# but some transformers versions set the Metaspace pre-tokenizer to use
# ▁ (U+2581) instead, causing encode/decode to lose all spaces.
def _fix_tokenizer_space_mismatch(tokenizer):
try:
if not hasattr(tokenizer, "backend_tokenizer"):
return
import json as _json
from tokenizers import Tokenizer as _Tokenizer
bt = tokenizer.backend_tokenizer
tj = _json.loads(bt.to_str())
pre = tj.get("pre_tokenizer", {})
needs_fix = (pre.get("type") == "Metaspace" and pre.get("replacement") == "\u2581")
if not needs_fix:
return
# Check if vocab actually uses Ä  (U+0120) for spaces
vocab = tj.get("model", {}).get("vocab", {})
has_gpt2_space = any(k.startswith("\u0120") for k in list(vocab.keys())[:500])
if not has_gpt2_space:
return
# Patch pre_tokenizer: ▁ -> Ġ
tj["pre_tokenizer"]["replacement"] = "\u0120"
# Patch decoder: ▁ -> Ġ in Replace step
for step in tj.get("decoder", {}).get("decoders", []):
if step.get("type") == "Replace":
pat = step.get("pattern", {})
if pat.get("String") == "\u2581":
pat["String"] = "\u0120"
# Rebuild backend tokenizer
fixed = _Tokenizer.from_str(_json.dumps(tj))
tokenizer.backend_tokenizer = fixed
# Verify fix
test_ids = tokenizer.encode("hello world")
test_dec = tokenizer.decode(test_ids, skip_special_tokens=True)
if "hello world" in test_dec:
print("[tokenizer] Fixed Ġ/▁ space mismatch")
else:
print(f"[tokenizer] WARNING: fix applied but decode test failed: {repr(test_dec)}")
except Exception as e:
print(f"[tokenizer] Could not fix space mismatch: {e}")
_fix_tokenizer_space_mismatch(tok)
# ─── Tokenizer startup health check ───
# Abort early if tokenizer can't roundtrip spaces — prevents silent data corruption
def _tokenizer_health_check(tokenizer):
import transformers as _tf
ver = _tf.__version__
print(f"[tokenizer] transformers={ver}, tokenizers={__import__('tokenizers').__version__}")
# Warn on known-bad versions
try:
from packaging.version import Version
if Version(ver) >= Version('5.0.0'):
print(f'[tokenizer] WARNING: transformers {ver} may have Metaspace bug — verify carefully')
except ImportError:
pass
# Roundtrip tests — must preserve spaces
tests = [
'Water boils at one hundred degrees',
'The quick brown fox jumps over the lazy dog',
'Hello world! This is a test sentence with spaces.',
]
for text in tests:
ids = tokenizer.encode(text)
decoded = tokenizer.decode(ids, skip_special_tokens=True)
if ' ' not in decoded:
print(f'[tokenizer] FATAL: Roundtrip lost all spaces!')
print(f' Input: {repr(text)}')
print(f' Encoded: {ids[:20]}...')
print(f' Decoded: {repr(decoded)}')
print(f'[tokenizer] ABORTING — fix tokenizer before training!')
sys.exit(1)
# Check decoded is reasonably close to input
if text.lower().split()[:3] != decoded.lower().split()[:3]:
print(f'[tokenizer] WARNING: Roundtrip diverged:')
print(f' Input: {repr(text[:60])}')
print(f' Decoded: {repr(decoded[:60])}')
print(f'[tokenizer] Health check PASSED — spaces preserved in roundtrip')
_tokenizer_health_check(tok)
def _tokenizer_vocab_size(tokenizer) -> int:
if isinstance(tokenizer, TiktokenTokenizerAdapter):
return tokenizer.vocab_size
vocab_size = getattr(tokenizer, "vocab_size", None)
try:
actual = max(tokenizer.get_vocab().values()) + 1
except Exception:
actual = 0
if isinstance(vocab_size, int) and vocab_size > 0:
return max(vocab_size, actual)
return actual
def _tokenizer_eos_id(tokenizer) -> int:
eos = getattr(tokenizer, "eos_token_id", None)
if eos is not None:
return eos
sep = getattr(tokenizer, "sep_token_id", None)
if sep is not None:
return sep
pad = getattr(tokenizer, "pad_token_id", None)
if pad is not None:
return pad
return _tokenizer_vocab_size(tokenizer) - 1
VOCAB, EOS = _tokenizer_vocab_size(tok), _tokenizer_eos_id(tok)
print(f"[tokenizer] vocab={VOCAB:,} eos={EOS}")
# ───────────────────────── PRESETS ─────────────────────────
PRESETS: Dict[str, Dict[str, int]] = {
"femto_1x": dict(d=16, layers=1, heads=1, rank=16),
"femto_12x": dict(d=16, layers=1, heads=1, rank=192),
"femto_24x": dict(d=16, layers=1, heads=1, rank=384),
"pico_1x": dict(d=32, layers=1, heads=2, rank=16),
"pico_3x": dict(d=32, layers=1, heads=2, rank=48),
"pico_6x": dict(d=32, layers=1, heads=2, rank=96),
"pico_12x": dict(d=32, layers=1, heads=2, rank=192),
"pico_24x": dict(d=32, layers=1, heads=2, rank=384),
"pico_48x": dict(d=32, layers=1, heads=2, rank=768),
"nano_1x": dict(d=64, layers=2, heads=4, rank=16),
"nano_3x": dict(d=64, layers=2, heads=4, rank=48),
"nano_6x": dict(d=64, layers=2, heads=4, rank=96),
"nano_12x": dict(d=64, layers=2, heads=4, rank=192),
"nano_24x": dict(d=64, layers=2, heads=4, rank=384),
"nano_48x": dict(d=64, layers=2, heads=4, rank=768),
"nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
"micro_3x": dict(d=128, layers=4, heads=8, rank=48),
"micro_6x": dict(d=128, layers=4, heads=8, rank=96),
"micro_12x": dict(d=128, layers=4, heads=8, rank=192),
"micro_24x": dict(d=128, layers=4, heads=8, rank=384),
"small": dict(d=512, layers=8, heads=16, rank=64),
"smallx2": dict(d=512, layers=16, heads=16, rank=64),
"base": dict(d=768, layers=12, heads=24, rank=96),
"base18": dict(d=768, layers=18, heads=24, rank=96),
"large": dict(d=1024, layers=24, heads=16, rank=128),
}
DEFAULT_BLOCK = 1122
DEFAULT_BATCH = 4
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 24 * 3600
DEFAULT_DELTA_STEPS = 500 # lightweight weight-only save every N steps
DEFAULT_MAX_DELTAS = 5 # keep last N deltas (older pruned after full save)
CKDIR = pathlib.Path("ckpts_expansion")
DEFAULT_PRETRAIN_SOURCES = "OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1,HuggingFaceFW/fineweb,wikimedia/wikipedia:20231101.en,allenai/c4:en"
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k"
DEFAULT_AFTER_SFT_BLOCK = 1122
# ───────────────────────── UK Time Helper ─────────────────────────
def get_uk_time() -> str:
utc_now = datetime.now(timezone.utc)
year = utc_now.year
march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc)
while march_last.weekday() != 6:
march_last = march_last.replace(day=march_last.day - 1)
oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc)
while oct_last.weekday() != 6:
oct_last = oct_last.replace(day=oct_last.day - 1)
if march_last <= utc_now < oct_last:
uk_offset = 1
tz_name = "BST"
else:
uk_offset = 0
tz_name = "GMT"
from datetime import timedelta
uk_time = utc_now + timedelta(hours=uk_offset)
return uk_time.strftime(f'%Y-%m-%d %H:%M:%S {tz_name}')
# ───────────────────────── Utilities ─────────────────────────
def rng_state():
if DEV.type == "cuda":
try:
return torch.cuda.get_rng_state(DEV)
except TypeError:
return torch.cuda.get_rng_state()
return torch.get_rng_state()
def _is_probably_ckpt(path: pathlib.Path) -> bool:
try:
return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
except Exception:
return False
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
try:
if path.is_dir():
cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)],
key=lambda p: p.stat().st_mtime, reverse=True)
return cands[0] if cands else None
if path.suffix == ".tmp":
solid = path.with_suffix("")
return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent)
return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent)
except Exception:
return None
def _try_load(path: pathlib.Path, map_location="cpu"):
try:
return torch.load(path, map_location="cpu")
except Exception as e:
print(f"[ckpt-skip] {path} not usable: {e}")
return None
def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int):
if max_ckpts is None or max_ckpts <= 0:
return
try:
pattern = f"{phase_name}_step*.pt"
ckpts = sorted(
[p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
key=lambda p: p.stat().st_mtime
)
excess = len(ckpts) - max_ckpts
if excess > 0:
for p in ckpts[:excess]:
try:
p.unlink()
print(f" [prune] deleted old {p.name}")
except Exception:
pass
except Exception as e:
print(f"[ckpt-prune] error: {e}")
def print_expansion_info(cfg: dict, tie_weights: bool = False):
d_k = cfg["d"] // cfg["heads"]
rank = cfg["rank"]
ratio = rank / d_k
regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION")
tie_str = "YES" if tie_weights else "NO"
print(f"┌─────────────────────────────────────────┐")
print(f"│ TUNEABLE ATTENTION CONFIG │")
print(f"├─────────────────────────────────────────┤")
print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │")
print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │")
print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │")
print(f"└─────────────────────────────────────────┘")
# ───────────────────────── AMP helper ─────────────────────────
try:
from torch.amp import autocast as _ac, GradScaler
except ImportError:
from torch.cuda.amp import autocast as _ac, GradScaler
def _auto_amp_dtype():
if DEV.type == "cuda":
try:
if torch.cuda.is_bf16_supported(): return torch.bfloat16
return torch.float16
except Exception: return torch.float16
return torch.float32
def amp(enabled: bool):
return nullcontext() if not (enabled and DEV.type == "cuda") else _ac(device_type="cuda", dtype=_auto_amp_dtype())
# ───────────────────────── Chat & Data Stream ─────────────────────────
def _coerce_role(r: str) -> str:
r = (r or "").lower()
if r in {"user", "human", "customer"}: return "user"
if r in {"assistant", "gpt", "bot"}: return "assistant"
if r in {"system", "context"}: return "system"
return r or "user"
def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
msgs = ex.get(messages_key)
if msgs is None:
for alt in ("conversations", "dialog", "turns"):
if isinstance(ex.get(alt), list):
msgs = ex[alt]; break
if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
try:
norm = []
for m in msgs:
role = _coerce_role(m.get("role", "")); content = m.get("content", m.get("text", ""))
if not isinstance(content, str): continue
norm.append({"role": role, "content": content})
if not norm: return None
return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
except Exception: return None
for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
return f"User: {ex[a]}\nAssistant: {ex[b]}"
return None
def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
if ":" in ds_name: base, config = ds_name.split(":", 1)
else: base, config = ds_name, None
if not streaming:
print(f"[download] Downloading {ds_name} (non-streaming)...")
if base == "json":
data_files = {"train": config}
ds = load_dataset("json", data_files=data_files, split="train", streaming=streaming, download_config=dc)
else:
ds = load_dataset(base, config, split="train", streaming=streaming, download_config=dc) if config else \
load_dataset(base, split="train", streaming=streaming, download_config=dc)
if streaming:
return iter(ds.shuffle(buffer_size=1000, seed=seed))
else:
print(f"[download] Got {len(ds):,} examples. Shuffling...")
ds = ds.shuffle(seed=seed)
return iter(ds)
def token_stream(ds_names: str, target: int, seed: int = 42,
chat: bool = False, chat_messages_key: str = "messages",
sft_add_generation_prompt: bool = False, dataset_field_text: str = "text",
streaming: bool = True):
ds_names = get_hot_datasets(ds_names) # HOT LOAD
sources = [s.strip() for s in ds_names.split(",") if s.strip()]
if not sources: return
src_idx = 0; emitted = 0; it = None; attempts = 0; backoff_base = 2.0
while emitted < target:
try:
if it is None: it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
ex = next(it)
text = None
if isinstance(ex, dict):
if chat:
text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
if text is None:
if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
text = ex[dataset_field_text]
elif isinstance(ex.get("text"), str):
text = ex["text"]
if not isinstance(text, str):
attempts = 0; continue
enc = tok.encode(text)
if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
enc = enc + [EOS]
for t in enc:
yield t
emitted += 1
if emitted >= target: return
attempts = 0
except StopIteration:
it = None; src_idx = (src_idx + 1) % len(sources)
except Exception as e:
attempts += 1
sleep_s = min(60.0, backoff_base ** min(attempts, 6))
print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
time.sleep(sleep_s); it = None
if attempts % 5 == 0 and len(sources) > 1:
src_idx = (src_idx + 1) % len(sources)
# ───────────────────────── ALiBi ─────────────────────────
_ALIBI_CACHE_MAX = 64
_alibi_slopes_cache: Dict[int, torch.Tensor] = {}
_alibi_bias_cache: Dict[Tuple[int, int], torch.Tensor] = {}
def _alibi_slopes(n_heads: int):
cached = _alibi_slopes_cache.get(n_heads)
if cached is not None:
return cached
def pow2slopes(n):
start = 2 ** (-2 ** -(math.log2(n) - 3))
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads)
else:
closest = 2 ** math.floor(math.log2(n_heads))
vals = pow2slopes(closest)
extra = pow2slopes(2 * closest)
vals += extra[0::2][: n_heads - closest]
t = torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
_alibi_slopes_cache[n_heads] = t
return t
def alibi_bias(n_heads: int, n_tokens: int):
key = (n_heads, n_tokens)
cached = _alibi_bias_cache.get(key)
if cached is not None:
return cached
i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
dist = (j - i).clamp_min(0)
bias = -_alibi_slopes(n_heads) * dist
if len(_alibi_bias_cache) < _ALIBI_CACHE_MAX:
_alibi_bias_cache[key] = bias
return bias
# ───────────────────────── Theorem-derived helpers ─────────────────────────
def _effective_psd_rank(M: torch.Tensor, eps_ratio: float = 1e-6) -> int:
"""Numerical rank of a symmetric PSD matrix via eigenvalue threshold.
Used to detect post-train how many of U's r columns are actually exercised.
"""
eigvals = torch.linalg.eigvalsh(M.detach().to(torch.float64))
if eigvals.numel() == 0:
return 0
threshold = eps_ratio * eigvals.abs().max().clamp_min(torch.finfo(eigvals.dtype).tiny)
return int((eigvals > threshold).sum().item())
def compact_attention_U(mha: "TuneableAttentionMHA", eps_ratio: float = 1e-6) -> Dict[str, int]:
"""Constructive proof of CONSEQUENCE 2 (see file header).
Replace mha.U in R^{dk x r} with U' in R^{dk x r*} satisfying U' U'^T == U U^T,
where r* is the numerical rank of U U^T. Attention scores (and therefore the
network output and gradients flowing back through it) are unchanged.
Construction:
M = U U^T # symmetric PSD, dk x dk
M = Q diag(lambda) Q^T # eigendecomposition
keep i where lambda_i > eps_ratio * max(lambda)
U' = Q[:, keep] * sqrt(lambda[keep])
Returns {"r_old": r, "r_new": r*, "params_saved": dk*(r - r*)}.
"""
with torch.no_grad():
U = mha.U.detach()
dk, r_old = U.shape
if r_old <= 1:
return {"r_old": r_old, "r_new": r_old, "params_saved": 0}
# eigendecompose in fp64 for numerical stability
M = U.to(torch.float64) @ U.to(torch.float64).T
eigvals, eigvecs = torch.linalg.eigh(M)
threshold = eps_ratio * eigvals.abs().max().clamp_min(1e-30)
keep_mask = eigvals > threshold
r_new = int(keep_mask.sum().item())
if r_new == 0:
r_new = 1 # don't degenerate to empty
keep_mask = torch.zeros_like(keep_mask, dtype=torch.bool); keep_mask[-1] = True
if r_new >= r_old:
return {"r_old": r_old, "r_new": r_old, "params_saved": 0}
# U' = eigvecs[:, keep] * sqrt(lambda[keep])
kept_vals = eigvals[keep_mask].clamp_min(0.0).sqrt() # [r_new]
kept_vecs = eigvecs[:, keep_mask] # [dk, r_new]
U_new = (kept_vecs * kept_vals).to(U.dtype) # [dk, r_new]
# Replace parameter and update attribute used by forward to choose path
mha.U = nn.Parameter(U_new.contiguous())
mha.r = r_new
return {"r_old": r_old, "r_new": r_new, "params_saved": dk * (r_old - r_new)}
def compact_checkpoint_U(src_path: str, dst_path: str, eps_ratio: float = 1e-6) -> Dict[str, Any]:
"""Load a checkpoint, compact every attention layer's U, save to dst.
The result is functionally identical to the original (within fp32 noise) but
smaller. Use for deploying a trained model where you don't need the over-
parameterized U for further training.
"""
ck = torch.load(src_path, map_location="cpu", weights_only=False)
cfg = ck.get("cfg")
if cfg is None:
raise ValueError(f"checkpoint {src_path} has no 'cfg'; cannot reconstruct model")
tie = ck.get("tie_weights", False)
core = Encoder(cfg, tie_weights=tie)
ar_h = ARHead(cfg["d"], tie_weights=tie, embedding_weight=core.emb.weight if tie else None)
sat_h = SATHead(cfg["d"], mode="var")
core.load_state_dict(ck["core"])
ar_h.load_state_dict(ck["ar"])
sat_h.load_state_dict(ck["sat"])
total_old = 0; effective_new = 0; report = []
for li, blk in enumerate(core.blocks):
before = blk.mha.U.numel()
info = compact_attention_U(blk.mha, eps_ratio=eps_ratio)
after = blk.mha.U.numel()
total_old += before; effective_new += after
report.append({"layer": li, **info})
# Encoder cfg has one rank value shared by every layer. Keep checkpoint loadable
# by padding lower-rank layers up to the max compacted rank with zero columns.
target_rank = max(blk.mha.U.shape[1] for blk in core.blocks)
stored_new = 0
for blk in core.blocks:
U = blk.mha.U.detach()
if U.shape[1] < target_rank:
pad = torch.zeros(U.shape[0], target_rank - U.shape[1], dtype=U.dtype, device=U.device)
blk.mha.U = nn.Parameter(torch.cat([U, pad], dim=1).contiguous())
blk.mha.r = target_rank
stored_new += blk.mha.U.numel()
# Re-package checkpoint with shrunken U tensors
out = dict(ck)
out["cfg"] = dict(cfg)
out["cfg"]["rank"] = target_rank
out["core"] = core.state_dict()
out["compacted"] = True
out["compact_eps_ratio"] = eps_ratio
out["compact_target_rank"] = target_rank
out["compact_report"] = report
# Optimizer state and scaler are no longer dim-compatible; drop to make this
# a pure inference-ready checkpoint.
out.pop("opt", None); out.pop("scaler", None)
torch.save(out, dst_path)
return {
"src": src_path, "dst": dst_path,
"U_params_old": total_old,
"U_params_new": stored_new,
"U_params_effective_new": effective_new,
"U_params_saved": total_old - stored_new,
"target_rank": target_rank,
"by_layer": report,
}
# ───────────────────────── Model components ─────────────────────────
class TuneableAttentionMHA(nn.Module):
def __init__(self, d: int, h: int, r: int, use_relpos: bool = True, use_sdpa: bool = True):
super().__init__()
assert d % h == 0
self.h, self.dk, self.r = h, d // h, r
self.use_relpos = use_relpos
self.use_sdpa = bool(use_sdpa and hasattr(F, "scaled_dot_product_attention"))
self.q = nn.Linear(d, d, bias=False)
self.k = nn.Linear(d, d, bias=False)
self.v = nn.Linear(d, d, bias=False)
self.U = nn.Parameter(torch.randn(self.dk, r))
nn.init.orthogonal_(self.U)
self.proj = nn.Linear(h * self.dk, d, bias=False)
self.drop = nn.Dropout(0.1)
def _proj_qk(self, x):
B, N, _ = x.shape
return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)
def _reshape_qk_base(self, x):
B, N, _ = x.shape
return x.view(B, N, self.h, self.dk).transpose(1, 2)
def _reshape_v(self, x):
B, N, _ = x.shape
return x.view(B, N, self.h, self.dk).transpose(1, 2)
def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
q_len = x.size(1)
v_new = self._reshape_v(self.v(x))
if self.r > self.dk:
q = self._reshape_qk_base(self.q(x))
k_new = self._reshape_qk_base(self.k(x))
if kv_cache is None:
k, v = k_new, v_new
else:
k_cached, v_cached = kv_cache
if use_cache:
k = torch.cat([k_cached, k_new], dim=2)
v = torch.cat([v_cached, v_new], dim=2)
else:
k, v = k_new, v_new
# Exact identity:
# (q @ U) @ (k @ U).T == q @ (U @ U.T) @ k.T.
# For expansion ranks this keeps the expensive NxN score matmul in dk.
metric = self.U @ self.U.T
q_scores = q @ metric
else:
q = self._proj_qk(self.q(x))
k_new = self._proj_qk(self.k(x))
if kv_cache is None:
k, v = k_new, v_new
else:
k_cached, v_cached = kv_cache
if use_cache:
k = torch.cat([k_cached, k_new], dim=2)
v = torch.cat([v_cached, v_new], dim=2)
else:
k, v = k_new, v_new
q_scores = q
attn_bias = None
if self.use_relpos and rel_bias_tokens is not None:
attn_bias = alibi_bias(self.h, rel_bias_tokens)[:, :, -q_len:, :]
if mask is not None:
attn_bias = mask if attn_bias is None else (attn_bias + mask)
if self.use_sdpa:
z = F.scaled_dot_product_attention(
q_scores, k, v,
attn_mask=attn_bias,
dropout_p=0.0,
is_causal=False,
scale=1.0 / math.sqrt(self.dk),
)
else:
att = (q_scores @ k.transpose(-1, -2)) / math.sqrt(self.dk)
if attn_bias is not None:
att = att + attn_bias
z = att.softmax(-1) @ v
z = z.transpose(1, 2).reshape(x.size(0), x.size(1), -1)
out = self.drop(self.proj(z))
return (out, (k, v)) if use_cache else out
class Block(nn.Module):
def __init__(self, d: int, h: int, r: int, use_sdpa: bool = True):
super().__init__()
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.mha = TuneableAttentionMHA(d, h, r, use_sdpa=use_sdpa)
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
if use_cache:
y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True)
x = x + y
x = x + self.ff(self.ln2(x))
return x, new_kv
n = x.size(1)
x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n)
return x + self.ff(self.ln2(x))
class Encoder(nn.Module):
def __init__(self, cfg, tie_weights: bool = False, use_sdpa: bool = True):
super().__init__()
d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]
self.emb = nn.Embedding(VOCAB, d)
self.blocks = nn.ModuleList([Block(d, h, r, use_sdpa=use_sdpa) for _ in range(l)])
self.ln = nn.LayerNorm(d)
self.tie_weights = tie_weights
self.use_sdpa = bool(use_sdpa and hasattr(F, "scaled_dot_product_attention"))
def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
x = self.emb(ids)
if not use_cache:
for blk in self.blocks:
x = blk(x, mask)
return self.ln(x)
new_kvs = []
for i, blk in enumerate(self.blocks):
kv = kv_caches[i] if kv_caches else None
x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len)
new_kvs.append(kv_out)
return self.ln(x), new_kvs
class ARHead(nn.Module):
def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
super().__init__()
self.tie_weights = tie_weights
if tie_weights and embedding_weight is not None:
self.proj = nn.Linear(d, VOCAB, bias=False)
self.proj.weight = embedding_weight
else:
self.proj = nn.Linear(d, VOCAB)
def forward(self, h):
return self.proj(h)
class SATHead(nn.Module):
def __init__(self, d, mode="var"):
super().__init__()
self.proj = nn.Linear(d, VOCAB)
self.gate = nn.Linear(d, 2) if mode == "var" else None
def forward(self, h_last):
return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
# ───────────────────────── Masks (cached) ─────────────────────────
_MASK_CACHE_MAX = 64
_causal_mask_cache: Dict[int, torch.Tensor] = {}
_sat_mask_cache: Dict[Tuple[int, int], torch.Tensor] = {}
def causal_mask(n):
if n in _causal_mask_cache:
return _causal_mask_cache[n]
m = torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
if len(_causal_mask_cache) < _MASK_CACHE_MAX:
_causal_mask_cache[n] = m
return m
def sat_mask(n, block=SAT_BLOCK):
key = (n, block)
if key in _sat_mask_cache:
return _sat_mask_cache[key]
idx = torch.arange(n, device=DEV)
grp = idx.unsqueeze(0) // block
allow = (grp.T == grp) | (grp.T > grp)
m = torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
if len(_sat_mask_cache) < _MASK_CACHE_MAX:
_sat_mask_cache[key] = m
return m
def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK):
total_len = cached_len + new_len
mask = torch.zeros((1, 1, new_len, total_len), device=DEV)
return mask
def causal_mask_cached(new_len: int, cached_len: int):
total_len = cached_len + new_len
mask = torch.zeros((1, 1, new_len, total_len), device=DEV)
if new_len > 1:
mask[:, :, :, cached_len:] = torch.triu(
torch.full((new_len, new_len), float("-inf"), device=DEV), 1
).view(1, 1, new_len, new_len)
return mask
# ───────────────────────── Checkpoint helpers ─────────────────────────
# ───────────────────────── Delta Checkpoints (weight-only, async) ─────────────────────────
_delta_lock = threading.Lock()
_delta_thread: Optional[threading.Thread] = None
def _sha256_file(path: pathlib.Path) -> str:
"""Compute SHA256 of a file for integrity verification."""
h = hashlib.sha256()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
return h.hexdigest()
def _do_delta_save(tensors: dict, path: pathlib.Path, meta: dict):
"""Background worker: write weight-only checkpoint + checksum."""
try:
path.parent.mkdir(exist_ok=True, parents=True)
tmp = path.with_suffix(path.suffix + ".dtmp")
torch.save({"weights": tensors, **meta}, tmp, _use_new_zipfile_serialization=False)
digest = _sha256_file(tmp)
tmp.replace(path)
# Write sidecar checksum
path.with_suffix(".sha256").write_text(f"{digest} {path.name}\n")
print(f" [delta] saved {path.name} ({digest[:12]}...)")
except Exception as e:
print(f" [delta] FAILED {path.name}: {e}")
def save_delta(core, ar_h, sat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
"""Save weight-only delta in background thread. Non-blocking."""
global _delta_thread
# Wait for any previous delta write to finish
if _delta_thread is not None and _delta_thread.is_alive():
_delta_thread.join(timeout=60)
# Snapshot weights to CPU (detach from GPU graph)
with _delta_lock:
tensors = {
"core": {k: v.detach().cpu() for k, v in core.state_dict().items()},
"ar": {k: v.detach().cpu() for k, v in ar_h.state_dict().items()},
"sat": {k: v.detach().cpu() for k, v in sat_h.state_dict().items()},
}
meta = {"step": step, "seen_tok": seen_tok, "wall_time": time.time(), "delta": True, **_tokenizer_checkpoint_meta()}
path = save_dir / f"{phase_name}_delta_step{step:08d}.pt"
_delta_thread = threading.Thread(target=_do_delta_save, args=(tensors, path, meta), daemon=True)
_delta_thread.start()
def _prune_deltas(save_dir: pathlib.Path, phase_name: str, max_deltas: int):
"""Keep only the most recent max_deltas delta files."""
if max_deltas is None or max_deltas <= 0:
return
try:
pattern = f"{phase_name}_delta_step*.pt"
deltas = sorted(
[p for p in save_dir.glob(pattern) if p.stat().st_size > 0],
key=lambda p: p.stat().st_mtime
)
excess = len(deltas) - max_deltas
if excess > 0:
for p in deltas[:excess]:
try:
p.unlink()
sha = p.with_suffix(".sha256")
if sha.exists(): sha.unlink()
print(f" [delta-prune] deleted {p.name}")
except Exception:
pass
except Exception as e:
print(f" [delta-prune] error: {e}")
def load_delta(path: pathlib.Path, core, ar_h, sat_h):
"""Load weight-only delta. Returns (step, seen_tok) or raises."""
# Verify checksum if sidecar exists
sha_path = path.with_suffix(".sha256")
if sha_path.exists():
expected = sha_path.read_text().split()[0]
actual = _sha256_file(path)
if expected != actual:
raise ValueError(f"Checksum mismatch for {path.name}: expected {expected[:12]}... got {actual[:12]}...")
print(f" [delta] checksum OK for {path.name}")
ck = torch.load(path, map_location="cpu", weights_only=False)
if not ck.get("delta"):
raise ValueError(f"{path.name} is not a delta checkpoint")
core.load_state_dict(ck["weights"]["core"])
ar_h.load_state_dict(ck["weights"]["ar"])
sat_h.load_state_dict(ck["weights"]["sat"])
return ck.get("step", 0), ck.get("seen_tok", 0)
def _flush_delta():
"""Wait for any in-flight delta save to complete."""
global _delta_thread
if _delta_thread is not None and _delta_thread.is_alive():
print(" [delta] flushing in-flight write...")
_delta_thread.join(timeout=120)
def _tokenizer_checkpoint_meta() -> Dict[str, Any]:
meta = {
"tokenizer_backend": TOKENIZER_BACKEND,
"tokenizer_id": TOKENIZER_ID,
"tokenizer_vocab": VOCAB,
"tokenizer_eos": EOS,
}
if isinstance(tok, TiktokenTokenizerAdapter):
meta["tiktoken_encoding"] = tok.encoding_name
meta["tiktoken_version"] = __import__("tiktoken").__version__
return meta
if hasattr(tok, "backend_tokenizer"):
try:
meta["tokenizer_json"] = tok.backend_tokenizer.to_str()
except Exception as e:
meta["tokenizer_json_error"] = str(e)
try:
meta["transformers_version"] = __import__("transformers").__version__
meta["tokenizers_version"] = __import__("tokenizers").__version__
except Exception:
pass
return meta
def _warn_tokenizer_checkpoint_mismatch(ck: Dict[str, Any]):
ck_backend = ck.get("tokenizer_backend")
ck_id = ck.get("tokenizer_id")
ck_vocab = ck.get("tokenizer_vocab")
if ck_backend and ck_backend != TOKENIZER_BACKEND:
print(f"[tokenizer] WARNING: checkpoint backend={ck_backend}, runtime backend={TOKENIZER_BACKEND}")
if ck_id and ck_id != TOKENIZER_ID and TOKENIZER_BACKEND not in {"tiktoken", "openai"}:
print(f"[tokenizer] WARNING: checkpoint tokenizer={ck_id}, runtime tokenizer={TOKENIZER_ID}")
if ck_vocab and ck_vocab != VOCAB:
print(f"[tokenizer] WARNING: checkpoint vocab={ck_vocab:,}, runtime vocab={VOCAB:,}; embedding/head shapes may not load")
def _restore_tokenizer_from_checkpoint(ck: Dict[str, Any]):
_warn_tokenizer_checkpoint_mismatch(ck)
if "tokenizer_json" in ck and hasattr(tok, "backend_tokenizer"):
try:
from tokenizers import Tokenizer as _Tokenizer
tok.backend_tokenizer = _Tokenizer.from_str(ck["tokenizer_json"])
print("[tokenizer] Restored Hugging Face tokenizer JSON from checkpoint")
except Exception as e:
print(f"[tokenizer] WARNING: could not restore tokenizer from checkpoint: {e}")
if "transformers_version" in ck:
import transformers as _tf
if ck["transformers_version"] != _tf.__version__:
print(f"[tokenizer] WARNING: checkpoint saved with transformers={ck['transformers_version']}, now running {_tf.__version__}")
def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, opt, scaler, meta):
path.parent.mkdir(exist_ok=True, parents=True)
tmp = path.with_suffix(path.suffix + ".tmp")
state = {
"core": core.state_dict(), "ar": ar_h.state_dict(), "sat": sat_h.state_dict(),
"opt": opt.state_dict(), "scaler": scaler.state_dict(),
"cfg": meta.get("cfg"), **_tokenizer_checkpoint_meta(),
"tie_weights": meta.get("tie_weights", False),
**{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")}
}
torch.save(state, tmp, _use_new_zipfile_serialization=False)
tmp.replace(path)
(path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]}))
print(f"\n✓ saved checkpoint {path.name}")
def load_ckpt(path, core, ar_h, sat_h, opt, scaler):
p = _resolve_ckpt(path) or path
ck = _try_load(p, map_location="cpu")
if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}")
core.load_state_dict(ck["core"])
ar_h.load_state_dict(ck["ar"])
sat_h.load_state_dict(ck["sat"])
opt.load_state_dict(ck["opt"])
scaler.load_state_dict(ck["scaler"])
_restore_tokenizer_from_checkpoint(ck)
return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
p = _resolve_ckpt(path) or path
if not p.exists(): return 0
ck = _try_load(p, map_location="cpu")
if ck is None: return 0
sd = ck.get(key, ck) if key else ck
if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"]
tgt_sd = tgt.state_dict()
filt = {k: v for k, v in sd.items() if k in tgt_sd and v.shape == tgt_sd[k].shape}
if filt: tgt.load_state_dict(filt, strict=False)
return len(filt)
def infer_cfg_from_ckpt(path: pathlib.Path):
"""Return (cfg, ckpt_meta) where ckpt_meta exposes side-channel flags like
tie_weights that aren't part of cfg itself. Both may be None on miss."""
p = _resolve_ckpt(path) or path
if not p.exists(): return None, None
sd = _try_load(p, map_location="cpu")
if sd is None: return None, None
cfg = dict(sd["cfg"]) if "cfg" in sd else None
meta = {"tie_weights": bool(sd.get("tie_weights", False))}
return cfg, meta
# ───────────────────────── Training Logic ─────────────────────────
def _parse_grow_plan(s: str) -> List[int]:
return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
def _count_enabled_params(*modules) -> int:
seen_data_ptrs = set()
total = 0
for m in modules:
if m is None:
continue
for p in m.parameters():
if p.data_ptr() not in seen_data_ptrs:
seen_data_ptrs.add(p.data_ptr())
total += p.numel()
return total
def _params_excluding(params, *exclude_modules):
# With --tie_weights the LM-head weight IS core.emb.weight (same Parameter
# object), so it must not appear in two optimizer groups.
excluded_ids = set()
for m in exclude_modules:
if m is None:
continue
for p in m.parameters():
excluded_ids.add(id(p))
return [p for p in params if id(p) not in excluded_ids]
def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
for p in core.parameters(): p.requires_grad = not freeze_core
if freeze_core:
if unfreeze_ln:
for blk in core.blocks:
for p in blk.ln1.parameters(): p.requires_grad = True
for p in blk.ln2.parameters(): p.requires_grad = True
for p in core.ln.parameters(): p.requires_grad = True
if train_emb:
for p in core.emb.parameters(): p.requires_grad = True
def _sat_window_starts(n_tokens: int, window: int, stride: int, max_windows: int, device=None) -> torch.Tensor:
"""Start offsets for SAT future-window training.
A source window [s, s+window) predicts the next window
[s+window, s+2*window), matching SAT inference's "last window drafts next
window" contract.
"""
stride = max(1, int(stride))
max_start = n_tokens - (2 * window)
if max_start < 0:
return torch.empty(0, device=device, dtype=torch.long)
starts = torch.arange(0, max_start + 1, stride, device=device, dtype=torch.long)
if max_windows and max_windows > 0 and starts.numel() > max_windows:
pick = torch.linspace(0, starts.numel() - 1, max_windows, device=device).round().long()
starts = starts[pick].unique(sorted=True)
return starts
def _sat_future_window_loss(args, h_sat, ids, sat_h, ce_tok, ce_gate):
B, N, D = h_sat.shape
starts = _sat_window_starts(
N, SAT_BLOCK,
getattr(args, "sat_loss_stride", 1),
getattr(args, "sat_loss_max_windows", 0),
device=ids.device,
)
if starts.numel() == 0:
return None, 0
offset = torch.arange(SAT_BLOCK, device=ids.device)
ctx_idx = starts[:, None] + offset[None, :]
tgt_idx = starts[:, None] + SAT_BLOCK + offset[None, :]
ctx = h_sat[:, ctx_idx, :].reshape(B * starts.numel(), SAT_BLOCK, D).contiguous()
tgt = ids[:, tgt_idx].reshape(B * starts.numel(), SAT_BLOCK).contiguous()
logits_sat, gate = sat_h(ctx)
loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt.reshape(-1))
if gate is not None:
emit_full_window = torch.ones(ctx.size(0), device=ids.device, dtype=torch.long)
loss_sat = loss_sat + EMIT_LAMBDA * ce_gate(gate, emit_full_window)
return loss_sat, int(starts.numel())
def _train_phase(
args, phase_name: str,
core, ar_h, sat_h, opt, scaler,
start_step, seen_tok, resume_wall_time,
cfg, source, steps, block_size, batch_size,
chat_cfg: dict,
max_ckpts: int,
target_tokens_override: Optional[int] = None,
tie_weights: bool = False,
streaming: bool = True
):
BLOCK = block_size
BATCH = batch_size
if target_tokens_override is not None:
target_tokens = target_tokens_override
else:
ratio = 51.2 if args.chilla_max_double else 25
param_count = _count_enabled_params(core, ar_h, sat_h)
target_tokens = int(ratio * param_count)
if steps:
phase_target_tokens = steps * BLOCK * BATCH
total_tokens_needed = seen_tok + phase_target_tokens
else:
total_tokens_needed = target_tokens
if total_tokens_needed <= seen_tok:
print(f"[{phase_name}] target {total_tokens_needed} already reached.")
return start_step, seen_tok, resume_wall_time
stream = token_stream(
source, total_tokens_needed, seed=42,
chat=chat_cfg.get("chat", False),
chat_messages_key=chat_cfg.get("key", "messages"),
sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
dataset_field_text=chat_cfg.get("text_field", "text"),
streaming=streaming
)
ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
ce_gate = nn.CrossEntropyLoss()
pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
buf: list[int] = []
batch_accum: list[list[int]] = []
step = start_step
steps_since_last_grow = 0
oom_retries = 0
MAX_OOM_RETRIES = 2
now_wall = time.time()
last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
last_delta_step = start_step
print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
print(f"[{phase_name}] AR_ONLY={args.ar_only}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}")
if not args.ar_only:
sat_stride = getattr(args, "sat_loss_stride", 1)
max_windows = getattr(args, "sat_loss_max_windows", 0)
max_windows_s = "all" if not max_windows else str(max_windows)
print(f"[{phase_name}] SAT future-window loss: window={SAT_BLOCK}, stride={sat_stride}, max_windows={max_windows_s}")
while seen_tok < total_tokens_needed:
try:
while len(buf) < BLOCK:
buf.append(next(stream))
except StopIteration:
break
seq = buf[:BLOCK]
buf = buf[BLOCK:]
batch_accum.append(seq)
if len(batch_accum) < BATCH:
continue
ids = torch.tensor(batch_accum, device=DEV)
batch_accum = []
tgt_ar = ids.clone()
try:
with amp(args.amp):
h_ar = core(ids, causal_mask(ids.size(1)))
logits_ar = ar_h(h_ar)[:, :-1]
loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
if args.ar_only:
loss = loss_ar
else:
h_sat = core(ids, sat_mask(ids.size(1)))
loss_sat, sat_windows = _sat_future_window_loss(args, h_sat, ids, sat_h, ce_tok, ce_gate)
if loss_sat is None:
loss_sat = h_sat.sum() * 0.0
sat_windows = 0
loss = loss_ar + loss_sat
scaler.scale(loss).backward()
scaler.unscale_(opt)
nn.utils.clip_grad_norm_(core.parameters(), 1.0)
scaler.step(opt)
scaler.update()
opt.zero_grad(set_to_none=True)
except RuntimeError as e:
msg = str(e).lower()
if "out of memory" in msg or "cuda error" in msg:
batch_accum = []
opt.zero_grad(set_to_none=True)
if DEV.type == "cuda":
torch.cuda.empty_cache()
torch.cuda.synchronize()
oom_retries += 1
if oom_retries <= MAX_OOM_RETRIES:
print(f"\n[{phase_name} OOM] Retry {oom_retries}/{MAX_OOM_RETRIES} at Batch={BATCH}, clearing VRAM...")
time.sleep(2)
continue
oom_retries = 0
if BATCH > 1:
print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1} (after {MAX_OOM_RETRIES} retries)")
BATCH -= 1
time.sleep(2)
else:
new_block = max(128, BLOCK // 2)
print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
BLOCK = new_block
time.sleep(2)
steps_since_last_grow = 0
continue
raise
step += 1
# Periodic tokenizer spot-check: verify training data has spaces
if step % 1000 == 0:
try:
sample_text = tok.decode(ids[0][:50].tolist(), skip_special_tokens=True)
if len(sample_text) > 20 and " " not in sample_text:
print(f"\n[tokenizer] ALERT step {step}: decoded batch has NO SPACES!")
print(f" Sample: {repr(sample_text[:80])}")
print(" Check transformers version!")
except Exception:
pass
oom_retries = 0
toks_processed = BLOCK * BATCH
seen_tok += toks_processed
pbar.update(toks_processed)
# Pass detached GPU tensor; SafeProgress.item()s it lazily only when it actually
# prints (~once per ~1M tokens), avoiding a CUDA sync every training step.
if args.ar_only:
pbar.set_postfix(loss=loss.detach(), B=BATCH, L=BLOCK)
else:
pbar.set_postfix(loss=loss.detach(), B=BATCH, L=BLOCK, satw=sat_windows)
if args.save_every_sec > 0:
now_mono = time.monotonic()
if now_mono - last_save_mono >= args.save_every_sec:
ck_name = f"{phase_name}_step{step:08d}.pt"
_flush_delta() # wait for any in-flight delta before full save
_prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
save_ckpt(pathlib.Path(args.save_dir) / ck_name, core, ar_h, sat_h, opt, scaler,
meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
last_save_mono = now_mono
# Prune old deltas after a full save (they're superseded)
_prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
last_delta_step = step # reset delta counter after full save
# ── Delta checkpoint (step-based, weight-only, async) ──
if args.delta_every_steps > 0 and (step - last_delta_step) >= args.delta_every_steps:
_prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
save_delta(core, ar_h, sat_h, step, seen_tok, pathlib.Path(args.save_dir), phase_name)
last_delta_step = step
if args.auto_grow:
steps_since_last_grow += 1
if steps_since_last_grow >= args.grow_every_steps:
steps_since_last_grow = 0
try:
idx = grow_plan.index(BLOCK)
if idx + 1 < len(grow_plan):
BLOCK = grow_plan[idx + 1]
print(f"[{phase_name} Grow] Block -> {BLOCK}")
if DEV.type == "cuda": torch.cuda.empty_cache()
except ValueError:
grow_plan = sorted(set(grow_plan + [BLOCK]))
pbar.close()
_flush_delta() # ensure any in-flight delta completes before final save
save_ckpt(pathlib.Path(args.save_dir) / f"{phase_name}_final.pt", core, ar_h, sat_h, opt, scaler,
meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
return step, seen_tok, time.time()
# ───────────────────────── Main Orchestrator ─────────────────────────
def train(args):
cfg = PRESETS[args.preset].copy()
tie_weights = args.tie_weights
if not args.fresh:
src_probe = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
prev_cfg, prev_meta = infer_cfg_from_ckpt(src_probe)
else:
prev_cfg, prev_meta = None, None
if prev_meta and prev_meta.get("tie_weights") and not args.tie_weights:
# Warm-starting from a tied checkpoint without --tie_weights would
# silently un-tie the LM head and re-train from scratch on it. Inherit.
tie_weights = True
print(f"[warmstart] inheriting tie_weights=True from checkpoint")
print_expansion_info(cfg, tie_weights)
if prev_cfg:
cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
if args.rank: cfg["rank"] = args.rank
if args.x2 and not prev_cfg: cfg["layers"] *= 2
print(f"Config: {cfg}")
dk = cfg["d"] // cfg["heads"]
if cfg["rank"] > dk:
print(f"[qk-fold] exact score fold active: rank {cfg['rank']} -> d_k {dk}")
use_sdpa = not getattr(args, "no_sdpa", False)
if use_sdpa:
print("[sdpa] exact fused scaled-dot-product attention enabled")
core = Encoder(cfg, tie_weights=tie_weights, use_sdpa=use_sdpa).to(DEV)
ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
sat_h = SATHead(cfg["d"], mode="var").to(DEV)
total_params = _count_enabled_params(core, ar_h, sat_h)
print(f"Total parameters: {total_params:,}")
if tie_weights:
print(f"{Colors.WARN}[weight-tying] Embedding and LM head share weights{Colors.RESET}")
if not args.fresh:
src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
src = _resolve_ckpt(src)
if src:
loaded = _safe_load_any(src, core, key="core")
_safe_load_any(src, ar_h, key="ar")
_safe_load_any(src, sat_h, key="sat")
if loaded: print(f"Warm-start loaded from {src}")
_phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
opt = torch.optim.AdamW([
{"params": [p for p in core.parameters() if p.requires_grad], "lr": args.lr_core},
{"params": _params_excluding(ar_h.parameters(), core), "lr": args.lr_head},
{"params": _params_excluding(sat_h.parameters(), core), "lr": args.lr_head},
])
scaler = GradScaler(enabled=(args.amp and DEV.type == "cuda"))
start_step, seen_tok, last_wall = 0, 0, None
if args.resume_delta and not args.fresh:
delta_step, delta_tok = load_delta(pathlib.Path(args.resume_delta), core, ar_h, sat_h)
start_step, seen_tok, last_wall = delta_step, delta_tok, None
print(f"Resumed from DELTA at step {start_step} (optimizer state reset — momentum rebuilds in ~100 steps)")
elif args.resume and not args.fresh:
start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler)
print(f"Resumed from step {start_step}")
# torch.compile AFTER loading checkpoint (key names differ)
if args.compile:
print("[torch.compile] Compiling model...")
core = torch.compile(core, mode="reduce-overhead")
ar_h = torch.compile(ar_h, mode="reduce-overhead")
sat_h = torch.compile(sat_h, mode="reduce-overhead")
print("[torch.compile] Done.")
step, seen_tok, last_wall = _train_phase(
args, "pretrain", core, ar_h, sat_h, opt, scaler,
start_step, seen_tok, last_wall, cfg,
args.source, args.steps,
args.block or DEFAULT_BLOCK,
args.batch_size or DEFAULT_BATCH,
chat_cfg={"chat": args.chat, "key": args.chat_messages_key, "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text},
max_ckpts=args.max_ckpts,
target_tokens_override=args.target_tokens,
tie_weights=tie_weights
)
if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
args.after_sft_chat = True
if args.after_sft_add_generation_prompt is None: args.after_sft_add_generation_prompt = True
if not args.after_sft_block: args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK
if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")
_phase_freeze(core,
freeze_core=args.after_sft_freeze_core,
unfreeze_ln=args.after_sft_unfreeze_ln,
train_emb=args.after_sft_train_emb)
opt = torch.optim.AdamW([
{"params": [p for p in core.parameters() if p.requires_grad], "lr": args.after_sft_lr_core or args.lr_core},
{"params": _params_excluding(ar_h.parameters(), core), "lr": args.after_sft_lr_head or args.lr_head},
{"params": _params_excluding(sat_h.parameters(), core), "lr": args.after_sft_lr_head or args.lr_head},
])
step, seen_tok, last_wall = _train_phase(
args, "sft", core, ar_h, sat_h, opt, scaler,
step, seen_tok, last_wall, cfg,
args.after_sft_source, args.after_sft_steps,
args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
args.batch_size or DEFAULT_BATCH,
chat_cfg={
"chat": args.after_sft_chat,
"key": args.after_sft_chat_messages_key,
"gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
"text_field": args.after_sft_dataset_field_text
},
max_ckpts=args.max_ckpts,
target_tokens_override=None,
tie_weights=tie_weights,
streaming=False
)
save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, opt, scaler,
meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
print("🎉 All Training Complete")
# ───────────────────────── Sampling ─────────────────────────
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
if ids.numel() == 0: return logits
hist = ids[0, -n:].long() if n > 0 else ids[0].long()
uniq, counts = torch.unique(hist, return_counts=True)
if pres_p or freq_p:
logits[..., uniq] -= (pres_p + freq_p * counts.float())
if rep_p != 1.0:
sel = logits[..., uniq]
logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p)
return logits
def _logits_to_probs(logits, T, top_k, top_p, min_p, greedy):
if greedy:
idx = logits.argmax(-1, keepdim=True)
return torch.zeros_like(logits).scatter_(-1, idx, 1.0)
probs = (logits / max(T, 1e-8)).softmax(-1)
if top_k:
v, i = torch.topk(probs, min(top_k, probs.size(-1)))
probs = torch.zeros_like(probs).scatter_(-1, i, v)
if top_p < 1.0:
s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * (torch.cumsum(s_probs, -1) <= top_p).float())
if min_p > 0: probs[probs < min_p] = 0
denom = probs.sum(-1, keepdim=True)
if bool((denom <= 0).any()):
idx = logits.argmax(-1, keepdim=True)
fallback = torch.zeros_like(logits).scatter_(-1, idx, 1.0)
probs = torch.where(denom > 0, probs / denom.clamp_min(1e-30), fallback)
else:
probs = probs / denom
return probs
def _sample_from_probs(probs):
return probs.multinomial(1)
def _sample(logits, T, top_k, top_p, min_p, greedy):
return _sample_from_probs(_logits_to_probs(logits, T, top_k, top_p, min_p, greedy))
def _slice_kv_caches(kvs, keep_len: int):
return [(k[:, :, :keep_len, :].contiguous(), v[:, :, :keep_len, :].contiguous()) for k, v in kvs]
def _proof_ok(name: str, value: float, tol: float) -> Dict[str, Any]:
return {"name": name, "value": float(value), "tol": float(tol), "ok": bool(value <= tol)}
def _print_proof_report(title: str, rows: List[Dict[str, Any]], notes: Optional[List[str]] = None) -> bool:
print(f"[prove] {title}")
all_ok = True
for row in rows:
status = "PASS" if row["ok"] else "FAIL"
all_ok = all_ok and row["ok"]
print(f" {status:4s} {row['name']}: {row['value']:.3e} <= {row['tol']:.3e}")
if notes:
for note in notes:
print(f" note: {note}")
print(f"[prove] {title}: {'PASS' if all_ok else 'FAIL'}")
return all_ok
def _old_expanded_mha_forward(mha: TuneableAttentionMHA, x: torch.Tensor):
B, N, _ = x.shape
q = mha.q(x).view(B, N, mha.h, mha.dk).transpose(1, 2) @ mha.U
k = mha.k(x).view(B, N, mha.h, mha.dk).transpose(1, 2) @ mha.U
v = mha.v(x).view(B, N, mha.h, mha.dk).transpose(1, 2)
att = (q @ k.transpose(-1, -2)) / math.sqrt(mha.dk)
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
return mha.proj(z)
def prove_m_fold(args) -> bool:
torch.manual_seed(args.seed)
cfg = PRESETS[args.preset].copy()
d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
mha = TuneableAttentionMHA(d, h, r).to(DEV)
mha.eval()
x = torch.randn(2, args.cached_len + args.new_len, d, device=DEV, requires_grad=True)
y_new = mha(x)
y_old = _old_expanded_mha_forward(mha, x)
loss_new = y_new.square().mean()
loss_old = y_old.square().mean()
loss_new.backward(retain_graph=True)
new_grads = {name: p.grad.detach().clone() for name, p in mha.named_parameters() if p.grad is not None}
x_grad_new = x.grad.detach().clone()
mha.zero_grad(set_to_none=True); x.grad.zero_()
loss_old.backward()
old_grads = {name: p.grad.detach().clone() for name, p in mha.named_parameters() if p.grad is not None}
x_grad_old = x.grad.detach().clone()
max_param_grad = max((new_grads[k] - old_grads[k]).abs().max().item() for k in old_grads)
rows = [
_proof_ok("forward_output_max_abs", (y_new - y_old).abs().max().item(), args.tol),
_proof_ok("loss_abs", abs(loss_new.item() - loss_old.item()), args.tol),
_proof_ok("input_grad_max_abs", (x_grad_new - x_grad_old).abs().max().item(), args.tol),
_proof_ok("param_grad_max_abs", max_param_grad, args.tol),
]
notes = [
"(Q U)(K U)^T rewrites to Q(U U^T)K^T by transpose-of-product and associativity.",
f"preset={args.preset}, dk={d // h}, rank={r}, metric_path={'on' if r > d // h else 'off'}",
]
return _print_proof_report("m_fold", rows, notes)
def prove_sdpa_equivalence(args) -> bool:
torch.manual_seed(args.seed)
cfg = PRESETS[args.preset].copy()
d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
n = args.cached_len + args.new_len
manual = TuneableAttentionMHA(d, h, r, use_sdpa=False).to(DEV).eval()
sdpa = TuneableAttentionMHA(d, h, r, use_sdpa=True).to(DEV).eval()
sdpa.load_state_dict(manual.state_dict())
rows = []
case_specs = [
("none", None, None),
("causal_alibi", causal_mask(n), n),
("sat_alibi", sat_mask(n), n),
]
max_grad = 0.0
for name, mask, rel_tokens in case_specs:
x_m = torch.randn(2, n, d, device=DEV, requires_grad=True)
x_s = x_m.detach().clone().requires_grad_(True)
y_m = manual(x_m, mask, rel_bias_tokens=rel_tokens)
y_s = sdpa(x_s, mask, rel_bias_tokens=rel_tokens)
loss_m = y_m.square().mean()
loss_s = y_s.square().mean()
loss_m.backward()
grads_m = {k: p.grad.detach().clone() for k, p in manual.named_parameters() if p.grad is not None}
x_grad_m = x_m.grad.detach().clone()
manual.zero_grad(set_to_none=True)
loss_s.backward()
grads_s = {k: p.grad.detach().clone() for k, p in sdpa.named_parameters() if p.grad is not None}
x_grad_s = x_s.grad.detach().clone()
sdpa.zero_grad(set_to_none=True)
rows.append(_proof_ok(f"{name}_forward_max_abs", (y_m - y_s).abs().max().item(), args.tol))
rows.append(_proof_ok(f"{name}_loss_abs", abs(loss_m.item() - loss_s.item()), args.tol))
max_grad = max(max_grad, (x_grad_m - x_grad_s).abs().max().item())
max_grad = max(max_grad, max((grads_m[k] - grads_s[k]).abs().max().item() for k in grads_m))
cached_len, new_len = args.cached_len, args.new_len
with torch.no_grad():
prefix = torch.randn(1, cached_len, d, device=DEV)
append = torch.randn(1, new_len, d, device=DEV)
_, kv_m = manual(prefix, causal_mask(cached_len), rel_bias_tokens=cached_len, use_cache=True)
_, kv_s = sdpa(prefix, causal_mask(cached_len), rel_bias_tokens=cached_len, use_cache=True)
y_m, _ = manual(
append, causal_mask_cached(new_len, cached_len),
rel_bias_tokens=cached_len + new_len, kv_cache=kv_m, use_cache=True
)
y_s, _ = sdpa(
append, causal_mask_cached(new_len, cached_len),
rel_bias_tokens=cached_len + new_len, kv_cache=kv_s, use_cache=True
)
rows.append(_proof_ok("cached_append_forward_max_abs", (y_m - y_s).abs().max().item(), args.tol))
rows.append(_proof_ok("max_grad_abs", max_grad, args.tol))
notes = [
"Manual path and SDPA path compute identical folded logits plus the same additive mask/ALiBi bias.",
f"preset={args.preset}, dk={d // h}, rank={r}, sdpa_available={hasattr(F, 'scaled_dot_product_attention')}",
]
return _print_proof_report("sdpa_equivalence", rows, notes)
def prove_cache_equivalence(args) -> bool:
torch.manual_seed(args.seed)
cfg = PRESETS[args.preset].copy()
core = Encoder(cfg).to(DEV).eval()
cached_len, new_len = args.cached_len, args.new_len
ids = torch.randint(0, VOCAB, (1, cached_len + new_len), device=DEV)
with torch.no_grad():
h_full = core(ids, causal_mask(ids.size(1)))[:, cached_len:]
h_prefix, kvs = core(
ids[:, :cached_len], causal_mask(cached_len),
use_cache=True, total_seq_len=cached_len
)
mask = causal_mask_cached(new_len, cached_len)
h_cached, kvs2 = core(
ids[:, cached_len:], mask,
kv_caches=kvs, use_cache=True, total_seq_len=cached_len + new_len
)
rows = [
_proof_ok("hidden_state_max_abs", (h_full - h_cached).abs().max().item(), args.tol),
_proof_ok("cache_key_len", abs(kvs2[0][0].size(2) - (cached_len + new_len)), 0.0),
_proof_ok("cache_value_len", abs(kvs2[0][1].size(2) - (cached_len + new_len)), 0.0),
]
notes = [
"Full causal forward and cached append forward should agree on the appended token hidden states.",
f"preset={args.preset}, cached_len={cached_len}, new_len={new_len}",
]
return _print_proof_report("cache_equivalence", rows, notes)
def prove_spec_kernel(args) -> bool:
torch.manual_seed(args.seed)
max_diff = 0.0
min_mass = 1.0
for _ in range(args.proof_trials):
p = torch.rand(1, args.proof_vocab, device=DEV)
q = torch.rand(1, args.proof_vocab, device=DEV)
p = p / p.sum(-1, keepdim=True)
q = q / q.sum(-1, keepdim=True)
accept_mass_by_token = torch.minimum(p, q)
reject_mass = 1.0 - accept_mass_by_token.sum(-1, keepdim=True)
residual = (p - q).clamp_min(0)
residual_denom = residual.sum(-1, keepdim=True).clamp_min(1e-30)
output_dist = accept_mass_by_token + reject_mass * (residual / residual_denom)
max_diff = max(max_diff, (output_dist - p).abs().max().item())
min_mass = min(min_mass, output_dist.sum().item())
rows = [
_proof_ok("output_distribution_max_abs", max_diff, args.tol),
_proof_ok("output_mass_abs_error", abs(min_mass - 1.0), args.tol),
]
notes = [
"Spec kernel proof: q(x)min(1,p(x)/q(x)) + reject_mass*max(p-q,0)/sum(max(p-q,0)) = p(x).",
"This proves the SAT draft step preserves AR distribution when p is AR and q is SAT.",
]
return _print_proof_report("spec_kernel", rows, notes)
def prove_sat_alignment(args) -> bool:
"""Verify SAT training windows match the SAT inference contract.
Inference uses SATHead(h[:, -SAT_BLOCK:]) to draft the next SAT_BLOCK tokens.
Training must therefore use each hidden window [s, s+B) to predict tokens
[s+B, s+2B), not unrelated positions in the sequence.
"""
N, B = 10, SAT_BLOCK
full = _sat_window_starts(N, B, stride=1, max_windows=0, device=DEV)
aligned = _sat_window_starts(N, B, stride=B, max_windows=0, device=DEV)
capped = _sat_window_starts(N, B, stride=1, max_windows=3, device=DEV)
short = _sat_window_starts(2 * B - 1, B, stride=1, max_windows=0, device=DEV)
ids = torch.arange(N, device=DEV).view(1, N)
off = torch.arange(B, device=DEV)
ctx_idx = full[:, None] + off[None, :]
tgt_idx = full[:, None] + B + off[None, :]
ctx = ids[:, ctx_idx]
tgt = ids[:, tgt_idx]
expected_ctx0 = torch.tensor([[0, 1]], device=DEV)
expected_tgt0 = torch.tensor([[2, 3]], device=DEV)
expected_last_ctx = torch.tensor([[6, 7]], device=DEV)
expected_last_tgt = torch.tensor([[8, 9]], device=DEV)
rows = [
_proof_ok("rolling_window_count", abs(full.numel() - 7), 0.0),
_proof_ok("aligned_window_count", abs(aligned.numel() - 4), 0.0),
_proof_ok("capped_window_count", abs(capped.numel() - 3), 0.0),
_proof_ok("short_sequence_has_no_windows", abs(short.numel()), 0.0),
_proof_ok("first_ctx_matches_last-window_contract", (ctx[0, 0] - expected_ctx0).abs().max().item(), 0.0),
_proof_ok("first_target_is_next_window", (tgt[0, 0] - expected_tgt0).abs().max().item(), 0.0),
_proof_ok("last_ctx_matches_last-window_contract", (ctx[0, -1] - expected_last_ctx).abs().max().item(), 0.0),
_proof_ok("last_target_is_next_window", (tgt[0, -1] - expected_last_tgt).abs().max().item(), 0.0),
]
notes = [
f"SAT_BLOCK={SAT_BLOCK}; rolling starts={full.tolist()}; aligned starts={aligned.tolist()}; capped starts={capped.tolist()}",
"This guards against training SAT on a hidden window while targeting unrelated early-sequence tokens.",
]
return _print_proof_report("sat_alignment", rows, notes)
def prove_compact_u(args) -> bool:
torch.manual_seed(args.seed)
cfg = PRESETS["nano_3x"].copy()
core = Encoder(cfg).to(DEV).eval()
with torch.no_grad():
core.blocks[0].mha.U[:, 3:] = 0
core.blocks[1].mha.U[:, 5:] = 0
ids = torch.randint(0, VOCAB, (1, args.cached_len + args.new_len), device=DEV)
with torch.no_grad():
before = core(ids, causal_mask(ids.size(1)))
infos = [compact_attention_U(blk.mha, eps_ratio=args.eps_ratio) for blk in core.blocks]
target_rank = max(blk.mha.U.shape[1] for blk in core.blocks)
for blk in core.blocks:
U = blk.mha.U.detach()
if U.shape[1] < target_rank:
pad = torch.zeros(U.shape[0], target_rank - U.shape[1], dtype=U.dtype, device=U.device)
blk.mha.U = nn.Parameter(torch.cat([U, pad], dim=1).contiguous())
blk.mha.r = target_rank
after = core(ids, causal_mask(ids.size(1)))
rows = [
_proof_ok("compacted_output_max_abs", (before - after).abs().max().item(), args.tol),
_proof_ok("target_rank", abs(target_rank - 5), 0.0),
]
notes = [
f"layer ranks: {[(i['r_old'], i['r_new']) for i in infos]}, padded target_rank={target_rank}",
"Zero-padding compacted U columns keeps U U^T unchanged and preserves checkpoint loadability.",
]
return _print_proof_report("compact_u", rows, notes)
# ────────────────────────────────────────────────────────────────────────
# Tiny algebraic-equivalence engine over tensor matmul / transpose.
# Goal: given a symbolic expression, search for equivalent rewrites with
# lower FLOP count. Demonstrated on the M-fold rule (rediscovered from
# the rules below without being told to apply it).
# ────────────────────────────────────────────────────────────────────────
def _find_lean_exe() -> Optional[str]:
found = shutil.which("lean")
if found:
return found
local = Path.home() / ".elan" / "bin" / ("lean.exe" if os.name == "nt" else "lean")
return str(local) if local.exists() else None
def prove_lean_m_fold(args) -> bool:
lean = _find_lean_exe()
if lean is None:
return _print_proof_report(
"lean_m_fold",
[_proof_ok("lean_executable_found", 1.0, 0.0)],
["Lean was not found on PATH or in ~/.elan/bin."]
)
proof_path = Path(args.lean_file).expanduser().resolve()
proof_src = """import Std
theorem scalar_m_fold_nat (q k u : Nat) : (q * u) * (k * u) = q * (u * u) * k := by
simp [Nat.mul_comm, Nat.mul_left_comm]
"""
proof_path.write_text(proof_src, encoding="ascii")
proc = subprocess.run([lean, str(proof_path)], capture_output=True, text=True, timeout=120)
notes = [
f"lean={lean}",
f"proof_file={proof_path}",
"Lean proves the commutative scalar skeleton; n2e.py's tensor checks catch shape/cache/gradient bugs.",
]
if proc.stdout.strip():
notes.append("stdout: " + proc.stdout.strip().splitlines()[0][:160])
if proc.stderr.strip():
notes.append("stderr: " + proc.stderr.strip().splitlines()[0][:160])
return _print_proof_report(
"lean_m_fold",
[_proof_ok("lean_exit_code", float(proc.returncode), 0.0)],
notes
)
class _Te:
"""Minimal tensor-expression AST.
Each node has a `shape` tuple (last 2 dims are the matrix dims; leading
dims are batch). Two equal expressions are not necessarily equal *trees* —
we compare canonical strings to dedupe.
"""
__slots__ = ("op", "args", "shape", "name")
def __init__(self, op: str, args: tuple, shape: tuple, name: Optional[str] = None):
self.op = op
self.args = args
self.shape = tuple(shape)
self.name = name
@staticmethod
def var(name: str, shape: tuple) -> "_Te":
return _Te("var", (), shape, name=name)
@staticmethod
def matmul(a: "_Te", b: "_Te") -> "_Te":
# Last dim of a equals second-to-last dim of b. Batch dims broadcast.
assert a.shape[-1] == b.shape[-2], f"matmul shape mismatch: {a.shape} @ {b.shape}"
out_shape = a.shape[:-1] + (b.shape[-1],)
return _Te("matmul", (a, b), out_shape)
@staticmethod
def transpose(a: "_Te") -> "_Te":
assert len(a.shape) >= 2
ns = list(a.shape)
ns[-1], ns[-2] = ns[-2], ns[-1]
return _Te("transpose", (a,), tuple(ns))
def canonical(self) -> str:
if self.op == "var":
return f"V({self.name},{list(self.shape)})"
return f"{self.op}({','.join(c.canonical() for c in self.args)})"
def __repr__(self) -> str:
if self.op == "var":
return f"{self.name}{list(self.shape)}"
if self.op == "matmul":
return f"({self.args[0]!r} @ {self.args[1]!r})"
if self.op == "transpose":
return f"{self.args[0]!r}^T"
return f"{self.op}({self.args})"
def _te_flops(e: "_Te", precomputed: Optional[set] = None) -> int:
"""Approximate FLOPs to evaluate `e` from leaves. precomputed names are free."""
if precomputed is None:
precomputed = set()
if e.op == "var":
return 0
if e.op == "transpose":
return _te_flops(e.args[0], precomputed) # a view, ~0 FLOPs (model as 0)
if e.op == "matmul":
a, b = e.args
# If e itself is a precomputed constant (e.g. M = U U^T saved at init), free.
ckey = e.canonical()
if ckey in precomputed:
return 0
# Else: cost of children + cost of this matmul.
# Use product of all leading dims for batch.
lead = 1
for d in e.shape[:-2]:
lead *= d
m, k, n = a.shape[-2], a.shape[-1], b.shape[-1]
return _te_flops(a, precomputed) + _te_flops(b, precomputed) + 2 * lead * m * k * n
return 0
# ── Rewrite rules: each takes an _Te and returns a (possibly equivalent) _Te or None ──
def _rule_assoc_left(e: "_Te") -> Optional["_Te"]:
"""(A @ B) @ C -> A @ (B @ C) [shapes must allow]"""
if e.op != "matmul":
return None
a_b, c = e.args
if a_b.op != "matmul":
return None
a, b = a_b.args
if b.shape[-1] != c.shape[-2]:
return None
try:
return _Te.matmul(a, _Te.matmul(b, c))
except AssertionError:
return None
def _rule_assoc_right(e: "_Te") -> Optional["_Te"]:
"""A @ (B @ C) -> (A @ B) @ C"""
if e.op != "matmul":
return None
a, b_c = e.args
if b_c.op != "matmul":
return None
b, c = b_c.args
if a.shape[-1] != b.shape[-2]:
return None
try:
return _Te.matmul(_Te.matmul(a, b), c)
except AssertionError:
return None
def _rule_transpose_product(e: "_Te") -> Optional["_Te"]:
"""(A @ B)^T -> B^T @ A^T"""
if e.op != "transpose":
return None
inner = e.args[0]
if inner.op != "matmul":
return None
a, b = inner.args
return _Te.matmul(_Te.transpose(b), _Te.transpose(a))
def _rule_double_transpose(e: "_Te") -> Optional["_Te"]:
"""(A^T)^T -> A"""
if e.op != "transpose":
return None
inner = e.args[0]
if inner.op != "transpose":
return None
return inner.args[0]
# ── Search: BFS over rewrites, find min-FLOPs equivalent ──
def _all_subexpr_rewrites(e: "_Te", rules) -> list:
"""Return every expression reachable by applying one rule at any subexpression."""
out = []
# Apply rules at the root.
for rule in rules:
new = rule(e)
if new is not None and new.canonical() != e.canonical():
out.append(new)
# Recurse into children: rebuild parent with rewritten child.
for i, child in enumerate(e.args):
for rewritten_child in _all_subexpr_rewrites(child, rules):
new_args = list(e.args)
new_args[i] = rewritten_child
try:
if e.op == "matmul":
out.append(_Te.matmul(new_args[0], new_args[1]))
elif e.op == "transpose":
out.append(_Te.transpose(new_args[0]))
except AssertionError:
continue
return out
def find_min_flops_equivalent(e: "_Te", precomputed=None, max_steps: int = 50) -> Tuple["_Te", int, list]:
"""BFS for the lowest-FLOP equivalent expression. Returns (best, flops, history)."""
rules = [_rule_assoc_left, _rule_assoc_right, _rule_transpose_product, _rule_double_transpose]
seen = {e.canonical(): _te_flops(e, precomputed)}
best, best_flops = e, seen[e.canonical()]
frontier = [e]
history = [(e.canonical(), best_flops)]
for _ in range(max_steps):
if not frontier:
break
nxt = []
for cur in frontier:
for rew in _all_subexpr_rewrites(cur, rules):
key = rew.canonical()
if key in seen:
continue
f = _te_flops(rew, precomputed)
seen[key] = f
nxt.append(rew)
history.append((key, f))
if f < best_flops:
best, best_flops = rew, f
frontier = nxt
return best, best_flops, history
def prove_rewriter(args) -> bool:
"""Validate the algebraic-rewriter engine on the M-fold case + sanity tests.
Concrete demo: build the symbolic expression
(X_q @ U) @ (X_k @ U)^T
where X_q, X_k are [B, h, N, dk] and U is [dk, r], with r > dk.
The rewriter should discover the equivalent
X_q @ (U @ U^T) @ X_k^T
(with U@U^T as a precomputed [dk,dk] constant) and report its lower FLOPs.
We also check:
- associativity is FLOPS-aware (search picks the cheaper grouping)
- transpose-of-product is found
- the engine never outputs a shape-incompatible expression
"""
rows = []
notes = []
# ── 1. M-fold rediscovery ──
B, h, N, dk, r = 1, 16, 1122, 64, 128 # production "large" preset
Xq = _Te.var("Xq", (B, h, N, dk))
Xk = _Te.var("Xk", (B, h, N, dk))
U = _Te.var("U", (dk, r))
# original: (Xq @ U) @ (Xk @ U)^T
orig = _Te.matmul(_Te.matmul(Xq, U), _Te.transpose(_Te.matmul(Xk, U)))
# mark M = U @ U^T as precomputed-once-per-forward (cached at top of attn fwd)
M = _Te.matmul(U, _Te.transpose(U))
precomputed = {M.canonical()}
best, best_flops, history = find_min_flops_equivalent(orig, precomputed=precomputed, max_steps=40)
orig_flops = _te_flops(orig, precomputed)
speedup = orig_flops / max(best_flops, 1)
# Verify the rewriter found a strictly cheaper equivalent
rows.append(_proof_ok("rewriter_found_cheaper", float(best_flops >= orig_flops), 0.0))
rows.append(_proof_ok(f"rewriter_speedup_>=_1.5x_at_r{r}_dk{dk}", float(speedup < 1.5), 0.0))
# ── 2. Associativity: (A@B)@C with B narrow should swap to A@(B@C) ──
A2 = _Te.var("A2", (1000, 4)) # tall
B2 = _Te.var("B2", (4, 1000)) # short
C2 = _Te.var("C2", (1000, 4))
chain = _Te.matmul(_Te.matmul(A2, B2), C2) # (1000x4 @ 4x1000) @ 1000x4
cheap, cf, _ = find_min_flops_equivalent(chain, max_steps=10)
chain_flops = _te_flops(chain)
rows.append(_proof_ok("assoc_found_cheaper_chain", float(cf >= chain_flops), 0.0))
# ── 3. Engine never produces a shape-invalid expression ──
invalid_count = 0
for canon, _ in history:
# all in `seen` were built via _Te constructors which assert shapes
pass # if any rule produced an invalid expr it would have been skipped
rows.append(_proof_ok("rewriter_no_invalid_shapes", float(invalid_count), 0.0))
notes.append(f"M-fold target: orig={orig_flops:,} flops -> best={best_flops:,} flops ({speedup:.2f}x)")
notes.append(f"best expression: {best!r}")
notes.append(f"explored {len(history)} candidates across BFS")
notes.append(f"assoc chain: orig={chain_flops:,} flops -> best={cf:,} flops ({chain_flops/max(cf,1):.2f}x)")
notes.append(f"best assoc: {cheap!r}")
return _print_proof_report("rewriter", rows, notes)
def prove_scaling_preview(args) -> bool:
"""Run the algebraic rewriter across every PRESET and report predicted M-fold
speedups. For each architectural config, constructs the symbolic attention
forward and asks find_min_flops_equivalent() what the cheapest equivalent
expression is. Useful for sanity-checking which presets actually benefit
from the M-fold rewrite (only r > dk wins).
"""
rows = []
notes = []
notes.append(f"{'preset':<14} {'d':>5} {'h':>3} {'dk':>4} {'r':>5} {'ratio':>6} "
f"{'orig_GFLOPs':>12} {'best_GFLOPs':>12} {'speedup':>8} {'regime'}")
notes.append("-" * 100)
B = 1
N = 1122 # production block
expansion_speedups = []
for preset_name, cfg in PRESETS.items():
d, h, r = cfg["d"], cfg["heads"], cfg["rank"]
dk = d // h
ratio = r / dk
regime = "compress" if ratio < 1 else ("identity" if ratio == 1 else "expand")
Xq = _Te.var("Xq", (B, h, N, dk))
Xk = _Te.var("Xk", (B, h, N, dk))
U = _Te.var("U", (dk, r))
orig = _Te.matmul(_Te.matmul(Xq, U), _Te.transpose(_Te.matmul(Xk, U)))
M = _Te.matmul(U, _Te.transpose(U))
precomputed = {M.canonical()}
best, best_flops, _hist = find_min_flops_equivalent(orig, precomputed=precomputed, max_steps=20)
orig_flops = _te_flops(orig, precomputed)
speedup = orig_flops / max(best_flops, 1)
notes.append(
f"{preset_name:<14} {d:>5} {h:>3} {dk:>4} {r:>5} {ratio:>5.1f}x "
f"{orig_flops/1e9:>11.2f}G {best_flops/1e9:>11.2f}G {speedup:>7.2f}x {regime}"
)
if regime == "expand":
expansion_speedups.append(speedup)
rows.append(_proof_ok(f"speedup_{preset_name}_>=_1.4x", float(speedup < 1.4), 0.0))
notes.append("-" * 100)
if expansion_speedups:
notes.append(
f"expansion presets: min={min(expansion_speedups):.2f}x, "
f"max={max(expansion_speedups):.2f}x, "
f"mean={sum(expansion_speedups)/len(expansion_speedups):.2f}x "
f"(n={len(expansion_speedups)})"
)
notes.append("compress/identity presets: M-fold not applied (skipped, original path is cheaper).")
notes.append("Speedups are theoretical FLOP reductions on the score matmul only.")
return _print_proof_report("scaling_preview", rows, notes)
def prove_alibi(args) -> bool:
"""Verify ALiBi slope and bias construction satisfies its math properties.
For all valid n_heads (powers of two AND non-powers-of-two), check:
1. _alibi_slopes(h) returns exactly h positive values in (0, 1)
2. Power-of-two head slopes are strictly monotonically decreasing
3. alibi_bias(h, n) has shape [1, h, n, n]
4. alibi_bias is non-positive everywhere (it's a penalty)
5. alibi_bias is zero on the diagonal (zero self-distance penalty)
6. alibi_bias[i,j] depends only on (j-i) for j>=i: translation-invariant
on the future side before the causal mask is added
7. alibi_bias lower triangle is zero (clamp_min on j-i)
"""
rows = []
notes = []
head_counts = [1, 2, 4, 8, 16, 24, 32]
for h in head_counts:
slopes = _alibi_slopes(h).flatten()
rows.append(_proof_ok(f"slopes_count_h{h}", abs(slopes.numel() - h), 0.0))
rows.append(_proof_ok(f"slopes_all_positive_h{h}", float((slopes <= 0).sum().item()), 0.0))
rows.append(_proof_ok(f"slopes_below_one_h{h}", float((slopes >= 1).sum().item()), 0.0))
if h in {1, 2, 4, 8, 16, 32, 64}:
diff = slopes[1:] - slopes[:-1]
rows.append(_proof_ok(f"slopes_monotone_h{h}", float((diff > 1e-7).sum().item()), 0.0))
# ── Bias-tensor checks ──
for h, n in [(4, 8), (8, 16), (16, 32)]:
bias = alibi_bias(h, n)
rows.append(_proof_ok(f"bias_shape_h{h}_n{n}", float(bias.shape != (1, h, n, n)), 0.0))
rows.append(_proof_ok(f"bias_nonpositive_h{h}_n{n}", float((bias > 1e-7).sum().item()), 0.0))
diag_max = bias[0, :, range(n), range(n)].abs().max().item()
rows.append(_proof_ok(f"bias_diagonal_zero_h{h}_n{n}", diag_max, 1e-7))
# 6. translation-invariance on future side: bias[i,j] == bias[i+1,j+1] for j>=i
b0 = bias[0, 0]
max_drift = 0.0
for i in range(n - 1):
for j in range(i, n - 1):
d = (b0[i, j] - b0[i + 1, j + 1]).abs().item()
max_drift = max(max_drift, d)
rows.append(_proof_ok(f"bias_translation_invariant_h{h}_n{n}", max_drift, 1e-6))
# 7. lower triangle zero
lower_sum = bias[0, 0].tril(-1).abs().sum().item()
rows.append(_proof_ok(f"bias_lower_tri_zero_h{h}_n{n}", lower_sum, 1e-6))
notes.append("ALiBi spec: bias[h, i, j] = -slope[h] * max(0, j - i)")
notes.append("Slopes formula valid for any n_heads via pow2-base + extra-from-2N pattern.")
notes.append("Translation-invariance is what makes ALiBi a *relative* position bias.")
return _print_proof_report("alibi", rows, notes)
def prove_spec_live(args) -> bool:
"""End-to-end verifier for the --mode spec control flow.
prove_spec_kernel validates the MATH IDENTITY in isolation:
accept_mass + reject_mass * residual_normalized == p
This proof exercises the actual implementation logic that realizes that
identity, including:
- the accept criterion `rand < min(1, p/q)`
- first-reject residual sampling `r = max(p-q, 0); r /= sum(r)`
- the `denom <= 0` fallback to p (degenerate case where q dominates p
on every accepted-mass token)
- KV cache slicing on partial acceptance (`_slice_kv_caches`)
For each of N random (p,q) pairs we run the live-style step T times and
record the output token. The empirical distribution must match p within
a chi-square-like tolerance ~1/sqrt(T). A bug like flipping the reject
direction or off-by-one in the residual will produce a measurable bias.
"""
torch.manual_seed(args.seed)
def _live_spec_step(p: torch.Tensor, q: torch.Tensor, rng: torch.Generator) -> int:
"""One mirror of the per-token accept/reject step from infer(--mode spec).
Mirrors lines that begin at 'accept_prob = (p_tok / q_tok).clamp(max=1.0)'
in n2e.py infer(). Returns the chosen token id.
"""
# Sample draft from q
draft_tok = torch.multinomial(q, num_samples=1, generator=rng)
p_tok = p.gather(-1, draft_tok)
q_tok = q.gather(-1, draft_tok).clamp_min(1e-30)
accept_prob = (p_tok / q_tok).clamp(max=1.0)
u = torch.rand((), generator=rng)
if u.item() < accept_prob.item():
return int(draft_tok.item())
# Reject: sample from residual
residual = (p - q).clamp_min(0)
denom = residual.sum(-1, keepdim=True)
if bool((denom <= 0).any()):
residual = p
else:
residual = residual / denom
out = torch.multinomial(residual, num_samples=1, generator=rng)
return int(out.item())
# ── 1. Cache-slice direct test: shape, dtype, contiguity, value ──
# Build a fake KV cache and slice; verify content matches a hand-truncate.
h, N, dk = 4, 16, 8
fake_k = torch.arange(h * N * dk, dtype=torch.float32).reshape(1, h, N, dk)
fake_v = (fake_k + 1000.0).clone()
sliced = _slice_kv_caches([(fake_k, fake_v)], keep_len=10)
k_s, v_s = sliced[0]
cache_rows = [
_proof_ok("cache_k_keeps_len", abs(k_s.size(2) - 10), 0.0),
_proof_ok("cache_v_keeps_len", abs(v_s.size(2) - 10), 0.0),
_proof_ok("cache_k_value_match", (k_s - fake_k[:, :, :10, :]).abs().max().item(), 0.0),
_proof_ok("cache_v_value_match", (v_s - fake_v[:, :, :10, :]).abs().max().item(), 0.0),
_proof_ok("cache_k_contiguous", float(not k_s.is_contiguous()), 0.0),
_proof_ok("cache_v_contiguous", float(not v_s.is_contiguous()), 0.0),
]
# ── 2. Live-step distribution test on random (p, q) pairs ──
# For each pair, run the real accept/residual logic many times and check
# that the empirical token distribution matches p within sqrt(T) noise.
V = max(8, args.proof_vocab // 8) # smaller vocab for tighter sampling
T = max(4000, args.proof_trials * 64) # enough trials for convergence
NUM_PAIRS = 6
rng = torch.Generator().manual_seed(args.seed)
max_tv = 0.0
max_kl = 0.0
pairs_summary = []
for pair_idx in range(NUM_PAIRS):
p_unn = torch.rand(V, generator=rng)
q_unn = torch.rand(V, generator=rng)
p = p_unn / p_unn.sum()
q = q_unn / q_unn.sum()
counts = torch.zeros(V, dtype=torch.float64)
for _ in range(T):
tok = _live_spec_step(p, q, rng)
counts[tok] += 1
emp = counts / T
# Total variation distance |emp - p|/2 (max ~ 1/sqrt(T))
tv = 0.5 * (emp - p.double()).abs().sum().item()
# KL(emp || p) — tiny if matched, blows up if any token is over-emitted
eps = 1e-8
kl = (emp * (emp.clamp_min(eps).log() - p.double().clamp_min(eps).log())).sum().item()
max_tv = max(max_tv, tv)
max_kl = max(max_kl, kl)
pairs_summary.append((pair_idx, tv, kl))
# 1/sqrt(T) is the noise floor for empirical distribution; allow 6x slack.
tv_tol = 6.0 / (T ** 0.5)
kl_tol = 0.05 # fairly loose; KL from sampling is small
dist_rows = [
_proof_ok(f"spec_live_max_TV_distance", max_tv, tv_tol),
_proof_ok(f"spec_live_max_KL_div", max_kl, kl_tol),
]
# ── 3. Edge case: q dominates p (residual denom == 0) ──
# When q[x] >= p[x] for every x, the reject-residual is all-zero and the
# implementation falls back to sampling from p. Test this branch is hit.
p_dom = torch.tensor([0.5, 0.5])
q_dom = torch.tensor([0.5, 0.5]) # q == p, accept always
counts = torch.zeros(2, dtype=torch.float64)
T_edge = 4000
rng2 = torch.Generator().manual_seed(args.seed + 1)
for _ in range(T_edge):
counts[_live_spec_step(p_dom, q_dom, rng2)] += 1
emp = counts / T_edge
edge_tv = 0.5 * (emp - p_dom.double()).abs().sum().item()
edge_rows = [_proof_ok("spec_live_edge_q_eq_p", edge_tv, 6.0 / (T_edge ** 0.5))]
# ── 4. Edge case: q has zero mass on some tokens (clamp_min) ──
p_z = torch.tensor([0.4, 0.3, 0.3])
q_z = torch.tensor([0.5, 0.5, 0.0])
counts = torch.zeros(3, dtype=torch.float64)
rng3 = torch.Generator().manual_seed(args.seed + 2)
for _ in range(T_edge):
counts[_live_spec_step(p_z, q_z, rng3)] += 1
emp = counts / T_edge
edge2_tv = 0.5 * (emp - p_z.double()).abs().sum().item()
edge_rows.append(_proof_ok("spec_live_edge_q_zero_at_token", edge2_tv, 6.0 / (T_edge ** 0.5)))
rows = cache_rows + dist_rows + edge_rows
notes = [
f"per-pair: " + ", ".join(f"#{i}:tv={t:.4f},kl={k:.4f}" for i, t, k in pairs_summary),
f"trials per (p,q): {T} | vocab: {V} | pairs: {NUM_PAIRS} | tv_tol: {tv_tol:.4f}",
"Live verifier exercises actual accept/reject control flow + KV cache slicing.",
"If the implementation deviates from the math, max_TV will exceed sqrt(T) noise floor.",
]
return _print_proof_report("spec_live", rows, notes)
def prove_masks(args) -> bool:
"""Verify mask construction satisfies its specification.
Specs:
causal_mask(n)[i,j] : 0 if j <= i else -inf
sat_mask(n, B)[i,j] : 0 if floor(j/B) <= floor(i/B) else -inf
causal_mask_cached(new, cached) : 0 if j <= cached + i else -inf (incremental decode)
For each implementation we generate a from-scratch reference using the spec, then
require element-wise equality. A bug in the implementation (off-by-one, wrong
direction of the inequality, future-token leak, etc.) would produce a non-zero
diff. Checks span small + medium n and several block sizes including the edge
cases B=1 (degenerates to causal-bidirectional), B=n (everything one block).
"""
NEG = float("-inf")
def causal_spec(n):
# j <= i -> 0, else -inf
m = torch.full((1, 1, n, n), NEG, device=DEV)
for i in range(n):
for j in range(i + 1):
m[0, 0, i, j] = 0.0
return m
def sat_spec(n, B):
m = torch.full((1, 1, n, n), NEG, device=DEV)
for i in range(n):
for j in range(n):
if (j // B) <= (i // B):
m[0, 0, i, j] = 0.0
return m
def cached_spec(new_len, cached_len):
# query at offset i sees keys [0 .. cached_len + i]
n_q, n_k = new_len, cached_len + new_len
m = torch.full((1, 1, n_q, n_k), NEG, device=DEV)
for i in range(n_q):
for j in range(cached_len + i + 1):
m[0, 0, i, j] = 0.0
return m
def diff(a, b):
# equality on values where neither is -inf, plus matching -inf locations
a_inf = torch.isinf(a) & (a < 0)
b_inf = torch.isinf(b) & (b < 0)
if not torch.equal(a_inf, b_inf):
return float("inf")
finite = ~a_inf
if finite.sum() == 0:
return 0.0
return (a[finite] - b[finite]).abs().max().item()
rows = []
notes = []
# causal_mask across several sizes
for n in [1, 2, 4, 8, 16, 64]:
_causal_mask_cache.clear()
impl = causal_mask(n)
spec = causal_spec(n)
rows.append(_proof_ok(f"causal_mask_n{n}", diff(impl, spec), 0.0))
# sat_mask across (n, B) including edge cases B=1, B=n, B>n
for n, B in [(8, 1), (8, 2), (8, 4), (8, 8), (16, 3), (32, 5), (8, 16)]:
_sat_mask_cache.clear()
impl = sat_mask(n, B)
spec = sat_spec(n, B)
rows.append(_proof_ok(f"sat_mask_n{n}_B{B}", diff(impl, spec), 0.0))
# causal_mask_cached for incremental decode
for cached, new in [(0, 4), (3, 1), (5, 3), (10, 7)]:
impl = causal_mask_cached(new, cached)
spec = cached_spec(new, cached)
rows.append(_proof_ok(f"causal_mask_cached_c{cached}_n{new}", diff(impl, spec), 0.0))
# Property check: causal mask must never allow a future-token attention
for n in [16, 32]:
m = causal_mask(n)
future_allowed = (~torch.isinf(m[0, 0])) & (torch.arange(n, device=DEV).unsqueeze(0) > torch.arange(n, device=DEV).unsqueeze(1))
rows.append(_proof_ok(f"causal_no_future_leak_n{n}", future_allowed.float().sum().item(), 0.0))
# Property check: SAT mask block-causal and within-block bidirectional
for n, B in [(16, 4), (32, 8)]:
m = sat_mask(n, B)
finite = ~torch.isinf(m[0, 0])
idx = torch.arange(n, device=DEV)
grp = idx // B
# No attending to a later block:
violations = (finite & (grp.unsqueeze(0) > grp.unsqueeze(1))).float().sum().item()
rows.append(_proof_ok(f"sat_no_future_block_n{n}_B{B}", violations, 0.0))
# Within same block: must be all-allowed
same_block = (grp.unsqueeze(0) == grp.unsqueeze(1))
within_disallowed = (same_block & ~finite).float().sum().item()
rows.append(_proof_ok(f"sat_within_block_open_n{n}_B{B}", within_disallowed, 0.0))
notes.append("Each n,B is generated from spec (math) and compared element-wise to the implementation.")
notes.append("Property checks: causal forbids future, SAT forbids future-block and is bidirectional within block.")
return _print_proof_report("masks", rows, notes)
def run_proofs(args) -> int:
theorem_fns = {
"m_fold": prove_m_fold,
"sdpa_equivalence": prove_sdpa_equivalence,
"cache_equivalence": prove_cache_equivalence,
"spec_kernel": prove_spec_kernel,
"sat_alignment": prove_sat_alignment,
"compact_u": prove_compact_u,
"masks": prove_masks,
"spec_live": prove_spec_live,
"rewriter": prove_rewriter,
"alibi": prove_alibi,
"lean_m_fold": prove_lean_m_fold,
"scaling_preview": prove_scaling_preview,
}
names = list(theorem_fns) if args.theorem == "all" else [args.theorem]
ok = True
for name in names:
ok = theorem_fns[name](args) and ok
return 0 if ok else 1
@torch.no_grad()
def infer(args):
if args.mode in {"ar", "spec"}:
if args.temperature is None: args.temperature = 0.7
if args.top_k is None: args.top_k = 0
if args.repetition_penalty is None: args.repetition_penalty = 1.3
if args.presence_penalty is None: args.presence_penalty = 0.0
if args.frequency_penalty is None: args.frequency_penalty = 0.3
if args.penalty_last_n is None: args.penalty_last_n = 128
if args.var is None: args.var = False
else:
if args.temperature is None: args.temperature = 0.5
if args.top_k is None: args.top_k = 30
if args.repetition_penalty is None: args.repetition_penalty = 2.0
if args.presence_penalty is None: args.presence_penalty = 0.6
if args.frequency_penalty is None: args.frequency_penalty = 1.0
if args.penalty_last_n is None: args.penalty_last_n = 200
if args.var is None: args.var = True
path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
sd = torch.load(path, map_location="cpu")
_restore_tokenizer_from_checkpoint(sd)
# Handle delta checkpoints (weight-only, no cfg)
if sd.get("delta"):
print("[infer] Delta checkpoint detected, using large preset cfg")
cfg = PRESETS["large"].copy()
tie_weights = False
# Remap: delta stores under sd["weights"]["core"/"ar"/"sat"]
sd["core"] = sd["weights"]["core"]
sd["ar"] = sd["weights"]["ar"]
sd["sat"] = sd["weights"]["sat"]
else:
cfg = sd["cfg"]
tie_weights = sd.get("tie_weights", False)
uk_time = get_uk_time()
ckpt_name = path.name
print(f"┌─────────────────────────────────────────────────┐")
print(f"│ INFERENCE @ {uk_time:<35s} │")
print(f"├─────────────────────────────────────────────────┤")
print(f"│ Checkpoint: {ckpt_name:<35s} │")
print(f"└─────────────────────────────────────────────────┘")
print_expansion_info(cfg, tie_weights)
dk = cfg["d"] // cfg["heads"]
if cfg["rank"] > dk:
print(f"[qk-fold] exact score fold active: rank {cfg['rank']} -> d_k {dk}")
use_sdpa = not getattr(args, "no_sdpa", False)
if use_sdpa:
print("[sdpa] exact fused scaled-dot-product attention enabled")
core = Encoder(cfg, tie_weights=tie_weights, use_sdpa=use_sdpa).to(DEV)
ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
sat_h = SATHead(cfg["d"]).to(DEV)
core.load_state_dict(sd["core"])
ar_h.load_state_dict(sd["ar"])
sat_h.load_state_dict(sd["sat"])
core.eval()
ar_h.eval()
sat_h.eval()
total_params = _count_enabled_params(core, ar_h, sat_h)
if total_params >= 1_000_000_000:
param_str = f"{total_params / 1_000_000_000:.2f}B"
elif total_params >= 1_000_000:
param_str = f"{total_params / 1_000_000:.2f}M"
elif total_params >= 1_000:
param_str = f"{total_params / 1_000:.2f}K"
else:
param_str = f"{total_params}"
print(f"Model size: {param_str} parameters ({total_params:,})")
prompt_tokens = tok.encode(args.prompt)
prompt_len = len(prompt_tokens)
ids = torch.tensor([prompt_tokens], device=DEV)
if ids.size(1) == 0:
ids = torch.tensor([[EOS]], device=DEV)
prompt_len = 1
mode_str = args.mode
if args.mode == "sat":
mode_str = f"sat-{'var' if args.var else 'fixed'}"
elif args.mode == "spec":
mode_str = f"sat-spec-ar-{'var' if args.var else 'fixed'}"
print(f"{Colors.INFO}Generating ({mode_str})...{Colors.RESET}")
start = time.time()
spec_stats = None
if args.mode == "ar":
h, kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
for _ in range(args.max_new):
logits = ar_h(h)[:, -1]
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
ids = torch.cat([ids, nxt], 1)
h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
elif args.mode == "spec":
ar_hid, ar_kvs = core(ids, causal_mask(ids.size(1)), use_cache=True, total_seq_len=ids.size(1))
sat_cached_len = ids.size(1)
sat_hid, sat_kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=sat_cached_len)
added = 0
drafted_total = accepted_total = rejected_total = 0
while added < args.max_new:
old_len = ids.size(1)
logits_all, gate = sat_h(sat_hid[:, -SAT_BLOCK:])
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
stride = min(int(stride), logits_all.size(1), args.max_new - added)
if stride <= 0:
break
draft_tokens, draft_probs = [], []
draft_context = ids
for i in range(stride):
q_logits = logits_all[:, i].clone()
q_logits = _apply_penalties(q_logits, draft_context, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
q_probs = _logits_to_probs(q_logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
nxt = _sample_from_probs(q_probs)
draft_tokens.append(nxt)
draft_probs.append(q_probs)
draft_context = torch.cat([draft_context, nxt], 1)
draft_ids = torch.cat(draft_tokens, dim=1)
draft_len = draft_ids.size(1)
drafted_total += draft_len
# One AR pass verifies every SAT draft token under the exact AR target
# distribution. This is speculative decoding with the SAT head as q.
p0_logits = ar_h(ar_hid)[:, -1].clone()
verify_mask = causal_mask_cached(draft_len, old_len)
verify_hid, verify_kvs = core(
draft_ids, verify_mask, kv_caches=ar_kvs,
use_cache=True, total_seq_len=old_len + draft_len
)
verify_logits = ar_h(verify_hid)
accepted = 0
reject_token = None
target_context = ids
for i in range(draft_len):
p_logits = p0_logits if i == 0 else verify_logits[:, i - 1].clone()
p_logits = _apply_penalties(p_logits, target_context, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
p_probs = _logits_to_probs(p_logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
q_probs = draft_probs[i]
draft_tok = draft_tokens[i]
p_tok = p_probs.gather(-1, draft_tok)
q_tok = q_probs.gather(-1, draft_tok).clamp_min(1e-30)
accept_prob = (p_tok / q_tok).clamp(max=1.0)
# One device->host sync per drafted token instead of two.
if torch.rand_like(accept_prob).lt(accept_prob).item():
accepted += 1
target_context = torch.cat([target_context, draft_tok], 1)
continue
residual = (p_probs - q_probs).clamp_min(0)
denom = residual.sum(-1, keepdim=True)
residual = p_probs if bool((denom <= 0).any()) else residual / denom
reject_token = _sample_from_probs(residual)
rejected_total += 1
break
if reject_token is None:
ids = draft_context
ar_hid, ar_kvs = verify_hid[:, -1:], verify_kvs
actual_new = draft_ids
accepted_total += draft_len
else:
if accepted > 0:
accepted_ids = draft_ids[:, :accepted]
ids = torch.cat([ids, accepted_ids], 1)
ar_kvs = _slice_kv_caches(verify_kvs, old_len + accepted)
ar_hid = verify_hid[:, accepted - 1:accepted]
actual_new = torch.cat([draft_ids[:, :accepted], reject_token], dim=1) if accepted > 0 else reject_token
ids = torch.cat([ids, reject_token], 1)
ar_hid, ar_kvs = core(
reject_token, None, kv_caches=ar_kvs,
use_cache=True, total_seq_len=ids.size(1)
)
accepted_total += accepted
added += actual_new.size(1)
mask = sat_mask_cached(actual_new.size(1), sat_cached_len)
sat_hid, sat_kvs = core(
actual_new, mask, kv_caches=sat_kvs,
use_cache=True, total_seq_len=ids.size(1)
)
sat_cached_len = ids.size(1)
spec_stats = {
"drafted": drafted_total,
"accepted": accepted_total,
"rejected": rejected_total,
"accept_rate": (accepted_total / drafted_total) if drafted_total else 0.0,
}
else:
cached_len = ids.size(1)
h, kvs = core(ids, sat_mask(ids.size(1)), use_cache=True, total_seq_len=cached_len)
added = 0
while added < args.max_new:
logits_all, gate = sat_h(h[:, -SAT_BLOCK:])
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
stride = min(int(stride), logits_all.size(1))
new_tokens = []
for i in range(int(stride)):
logits = logits_all[:, i]
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
new_tokens.append(nxt)
ids = torch.cat([ids, nxt], 1)
added += 1
if added >= args.max_new: break
if added >= args.max_new: break
new_ids = torch.cat(new_tokens, dim=1)
mask = sat_mask_cached(new_ids.size(1), cached_len)
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
cached_len = ids.size(1)
elapsed = time.time() - start
gen_tokens = len(ids[0]) - prompt_len
tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
all_tokens = ids[0].tolist()
prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
print(f"{Colors.PROMPT}{prompt_text}{Colors.RESET}{gen_text}")
print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}")
if spec_stats is not None:
print(f"{Colors.INFO}[spec drafted={spec_stats['drafted']} accepted={spec_stats['accepted']} rejected={spec_stats['rejected']} accept={100*spec_stats['accept_rate']:.1f}%]{Colors.RESET}")
if getattr(args, "claude_friendly", False):
print("[CLAUDE_FRIENDLY_START]")
print(f"[mode={mode_str}]")
print("[prompt_input]")
print(prompt_text)
print("[completion]")
print(gen_text)
print(f"[stats] {elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s")
if spec_stats is not None:
print(f"[spec] drafted={spec_stats['drafted']} accepted={spec_stats['accepted']} rejected={spec_stats['rejected']} accept={100*spec_stats['accept_rate']:.1f}%")
print("[CLAUDE_FRIENDLY_END]")
# ───────────────────────── CLI ─────────────────────────
def main():
ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing")
sub = ap.add_subparsers(dest="cmd", required=True)
tr = sub.add_parser("train")
tr.add_argument("--preset", choices=PRESETS.keys(), default="nano_3x")
tr.add_argument("--rank", type=int)
tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
tr.add_argument("--target_tokens", type=int)
tr.add_argument("--steps", type=int)
tr.add_argument("--amp", action="store_true")
tr.add_argument("--compile", action="store_true", help="Use torch.compile for speedup")
tr.add_argument("--no_sdpa", action="store_true", help="Disable exact SDPA attention backend and use manual softmax attention")
tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
tr.add_argument("--save_dir", default=str(CKDIR))
tr.add_argument("--resume", type=str)
tr.add_argument("--x2", action="store_true")
tr.add_argument("--warmstart_from", type=str)
tr.add_argument("--fresh", action="store_true")
tr.add_argument("--max_ckpts", type=int, default=None)
tr.add_argument("--chilla_max_double", action="store_true")
tr.add_argument("--tie_weights", action="store_true")
tr.add_argument("--ar_only", action="store_true")
tr.add_argument("--sat_loss_stride", type=int, default=1, help="SAT future-window training stride (1=all rolling windows, 2=block-aligned for SAT_BLOCK=2)")
tr.add_argument("--sat_loss_max_windows", type=int, default=0, help="Cap SAT training windows per sequence (0=all)")
tr.add_argument("--freeze_core", action="store_true")
tr.add_argument("--unfreeze_ln", action="store_true")
tr.add_argument("--train_emb", action="store_true")
tr.add_argument("--lr_core", type=float, default=LR_CORE)
tr.add_argument("--lr_head", type=float, default=LR_HEAD)
tr.add_argument("--chat", action="store_true")
tr.add_argument("--chat_messages_key", default="messages")
tr.add_argument("--dataset_field_text", default="text")
tr.add_argument("--sft_add_generation_prompt", action="store_true")
tr.add_argument("--auto_grow", action="store_true")
tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
tr.add_argument("--grow_every_steps", type=int, default=50000)
tr.add_argument("--after_sft_source", default="")
tr.add_argument("--after_sft_steps", type=int, default=0)
tr.add_argument("--after_sft_chat", action="store_true")
tr.add_argument("--after_sft_chat_messages_key", default="messages")
tr.add_argument("--after_sft_dataset_field_text", default="text")
tr.add_argument("--after_sft_add_generation_prompt", type=bool, default=None)
tr.add_argument("--after_sft_block", type=int, default=0)
tr.add_argument("--after_sft_freeze_core", action="store_true")
tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
tr.add_argument("--after_sft_train_emb", action="store_true")
tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
tr.add_argument("--after_sft_lr_head", type=float, default=0.0)
inf = sub.add_parser("infer")
inf.add_argument("--mode", choices=["ar", "sat", "spec"], required=True)
inf.add_argument("--ckpt", required=True)
inf.add_argument("--prompt", required=True)
inf.add_argument("--max_new", type=int, default=120)
inf.add_argument("--temperature", type=float, default=None)
inf.add_argument("--greedy", action="store_true")
inf.add_argument("--top_k", type=int, default=None)
inf.add_argument("--top_p", type=float, default=0.9)
inf.add_argument("--min_p", type=float, default=0.0)
inf.add_argument("--repetition_penalty", type=float, default=None)
inf.add_argument("--presence_penalty", type=float, default=None)
inf.add_argument("--frequency_penalty", type=float, default=None)
inf.add_argument("--penalty_last_n", type=int, default=None)
inf.add_argument("--var", action="store_true", default=None)
inf.add_argument("--no-var", dest="var", action="store_false")
inf.add_argument("--no_sdpa", action="store_true", help="Disable exact SDPA attention backend and use manual softmax attention")
inf.add_argument("--claude-friendly", action="store_true", help="Also print an artifact-free prompt/completion block for downstream JSON consumers")
st = sub.add_parser("status", help="Read-only training status")
st.add_argument("--json", dest="json_output", action="store_true")
st.add_argument("--log", type=str, default=str(STATUS_DEFAULT_LOG))
st.add_argument("--save_dir", type=str, default=str(STATUS_DEFAULT_SAVE_DIR))
cp = sub.add_parser("compact",
help="Compact each attention block's U parameter via eigendecomposition. "
"Output checkpoint is functionally identical (CONSEQUENCE 2 of "
"Shared-U Score Capacity theorem; see file header).")
cp.add_argument("--src", type=str, required=True, help="Source checkpoint")
cp.add_argument("--dst", type=str, required=True, help="Destination checkpoint")
cp.add_argument("--eps_ratio", type=float, default=1e-6,
help="Eigenvalue threshold relative to max eigenvalue")
pr = sub.add_parser("prove",
help="Run n2e.py-specific proof/checker tests for exact rewrites, caches, and SAT-spec sampling.")
pr.add_argument("--theorem",
choices=["all", "m_fold", "sdpa_equivalence", "cache_equivalence", "spec_kernel", "sat_alignment", "compact_u", "masks", "spec_live", "rewriter", "alibi", "lean_m_fold", "scaling_preview"],
default="all")
pr.add_argument("--preset", choices=PRESETS.keys(), default="femto_12x")
pr.add_argument("--seed", type=int, default=1234)
pr.add_argument("--cached_len", type=int, default=13)
pr.add_argument("--new_len", type=int, default=5)
pr.add_argument("--tol", type=float, default=1e-5)
pr.add_argument("--eps_ratio", type=float, default=1e-6)
pr.add_argument("--proof_trials", type=int, default=64)
pr.add_argument("--proof_vocab", type=int, default=97)
pr.add_argument("--lean_file", type=str, default=str(Path(__file__).with_name("n2e_m_fold.lean")))
args = ap.parse_args()
if args.cmd == "train": train(args)
elif args.cmd == "infer": infer(args)
elif args.cmd == "compact":
info = compact_checkpoint_U(args.src, args.dst, eps_ratio=args.eps_ratio)
print(f"Compacted {args.src} -> {args.dst}")
print(f" target cfg rank: {info['target_rank']}")
print(f" U params: {info['U_params_old']:,} -> {info['U_params_new']:,} "
f"(saved {info['U_params_saved']:,})")
if info["U_params_effective_new"] != info["U_params_new"]:
print(f" effective U params before cfg padding: {info['U_params_effective_new']:,}")
for row in info["by_layer"]:
print(f" layer {row['layer']:>2}: r {row['r_old']:>4} -> {row['r_new']:>4}")
elif args.cmd == "prove":
raise SystemExit(run_proofs(args))
else: raise SystemExit(_emit_status(Path(args.log), Path(args.save_dir), args.json_output))
if __name__ == "__main__":
main()