# Mirrored from Tilelli-llm: src/tilelli/eval/metacog_probe.py
# Revision f86dc09 ("Mirror small files (code, paper, results)").
"""Metacognition probe: record every confidence signal under test, per prompt.

Each prompt is probed once: a greedy generation, followed by a re-scoring
pass over the prompt + generation sequence.

Pre-registered claim (see
`Tilelli LLM Research/METACOGNITION_STUDY_SCOPE_2026-05-23.md`): router
entropy is a competitive uncertainty signal against output-side baselines,
and is better in the OOD / gibberish / factual-misleading / long-input
regimes.

Reads a prompt-set JSONL and writes a signals JSONL with one row per
prompt. Scoring (AUROC + bootstrap CI) lives in `metacog_score.py`.
"""
from __future__ import annotations
import argparse
import json
import math
import os
import time
from pathlib import Path
import torch
from tilelli.core.tilelli_lite import TilelliLiteLM
from tilelli.distillery.tokenize import ByteTokenizer
from tilelli.utils import safe_load_checkpoint
# Default cap on generated tokens per prompt; _format_prompt also reserves
# this many positions out of the context when budgeting the message.
MAX_NEW_TOKENS = 48
# Context length used when neither the checkpoint config nor the model
# object provides max_seq_len.
DEFAULT_MAX_SEQ = 256
# State-dict entries (after the "abstain." prefix is stripped) that must all
# be present for load_bridge to build the abstain head.
ABSTAIN_KEYS = ("weight", "bias")
def load_bridge(ckpt_path: str):
    """Rebuild the deployed bridge's model and abstain head from a checkpoint.

    Only what the probe needs is constructed; the bridge's sessioning
    overhead is skipped.

    Args:
        ckpt_path: Path to a Tilelli chat checkpoint readable by
            ``safe_load_checkpoint``.

    Returns:
        ``(model, abstain_head, tokenizer)``: a ``TilelliLiteLM`` in eval
        mode, a ``torch.nn.Linear`` in eval mode (or ``None`` when the
        checkpoint lacks a complete set of ``abstain.*`` weights), and a
        fresh ``ByteTokenizer``.
    """
    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    # Hyper-parameters may be stored under any of these keys; the first
    # non-empty one wins and absent fields fall back to the defaults below.
    cfg = (ckpt.get("base_model_cfg") or ckpt.get("model_cfg")
           or ckpt.get("config") or {})
    model = TilelliLiteLM(
        vocab_size=cfg.get("vocab_size", 256),
        d_model=cfg.get("d_model", 256),
        n_layers=cfg.get("n_layers", 8),
        n_heads=cfg.get("n_heads", 8),
        top_k=cfg.get("top_k", 16),
        ffn_expand=cfg.get("dense_expand", 4),
        max_seq_len=cfg.get("max_seq_len", DEFAULT_MAX_SEQ),
        quantize=cfg.get("quantize", False),
    )

    # Weights are either nested under "model" or stored at the top level.
    raw = ckpt.get("model", ckpt)
    base_prefix, abstain_prefix = "base.", "abstain."
    base_state, abstain_state = {}, {}
    for key, value in raw.items():
        if key.startswith(abstain_prefix):
            abstain_state[key[len(abstain_prefix):]] = value
        elif key.startswith(base_prefix):
            # Strip the wrapper prefix only when it really is a prefix; a
            # plain str.replace would also eat a "base." that appears in the
            # middle of a parameter name.
            base_state[key[len(base_prefix):]] = value
        else:
            base_state[key] = value

    # The load stays non-strict (best effort), but mismatches are reported so
    # a wrong checkpoint/config pairing does not silently probe random weights.
    load_result = model.load_state_dict(base_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        print(f"[probe] warning: non-strict load left {len(missing)} missing "
              f"(e.g. {missing[:3]}) and {len(unexpected)} unexpected "
              f"(e.g. {unexpected[:3]}) state-dict keys", flush=True)
    model.eval()

    abstain_head = None
    if all(name in abstain_state for name in ABSTAIN_KEYS):
        # Size the head from the stored weight matrix, which is (out, in).
        out_dim, in_dim = abstain_state["weight"].shape
        abstain_head = torch.nn.Linear(in_dim, out_dim)
        abstain_head.weight.data.copy_(abstain_state["weight"])
        abstain_head.bias.data.copy_(abstain_state["bias"])
        abstain_head.eval()
    return model, abstain_head, ByteTokenizer()
@torch.no_grad()
def _features_at(model: TilelliLiteLM, ids: torch.Tensor) -> torch.Tensor:
    """Return the final-norm hidden state at every position of ``ids``.

    Token and position embeddings are summed, passed through each block in
    order, then through ``model.final_norm``. Written to mirror
    ``tilelli_bridge._features``.
    """
    seq_len = ids.size(1)
    positions = torch.arange(seq_len, device=ids.device)
    hidden = model.embed(ids) + model.pos_embed(positions)
    for layer in model.blocks:
        hidden = layer(hidden)
    return model.final_norm(hidden)
def _format_prompt(message: str, max_ctx: int, framing_overhead: int = 20) -> str:
    """Wrap ``message`` in the USER:/TILELLI: framing used by the bridge.

    The message gets whatever is left of ``max_ctx`` after the framing
    overhead and ``MAX_NEW_TOKENS`` are reserved, with a floor of 32. A
    message longer than that keeps only its head and tail, joined by
    ``" ... "``. Lengths are measured in characters of ``message``.
    """
    room = max(32, max_ctx - framing_overhead - MAX_NEW_TOKENS)
    if len(message) > room:
        keep = max(8, room // 2 - 3)
        head, tail = message[:keep], message[-keep:]
        message = f"{head} ... {tail}"
    framed = f"\nUSER: {message}\nTILELLI:"
    # The leading newline exists only to be stripped here.
    return framed.lstrip()
@torch.no_grad()
def probe_one(
    model: TilelliLiteLM,
    abstain_head: torch.nn.Linear | None,
    tokenizer: ByteTokenizer,
    message: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> dict:
    """Probe one prompt and return its confidence signals.

    The message is framed the way the bridge frames it and greedily decoded.
    The prompt + generation sequence is then re-scored for router entropy
    and, when a head is supplied, for the abstain probability.

    Args:
        model: Language model exposing ``generate_with_cache`` and
            ``router_entropies``.
        abstain_head: Linear head applied to the last hidden state, or
            ``None`` to record ``abstain_p`` as NaN.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        message: Raw, unframed user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A dict with ``prompt``, ``text`` (``"(empty)"`` when nothing usable
        was generated), ``n_generated``, ``prompt_len_bytes`` and a nested
        ``signals`` dict holding ``max_softmax_mean``, ``max_softmax_last``,
        ``router_conf``, ``router_entropy_mean``, ``router_entropy_var``,
        ``router_entropy_per_layer`` and ``abstain_p``. A signal that cannot
        be computed is NaN.
    """
    max_ctx = getattr(model, "max_seq_len", DEFAULT_MAX_SEQ)
    prompt = _format_prompt(message, max_ctx)
    ids = tokenizer.encode(prompt).long().unsqueeze(0)
    if ids.shape[1] > max_ctx:
        # Keep the most recent context when the encoded prompt is too long.
        ids = ids[:, -max_ctx:]
    prompt_len = ids.shape[1]

    # Greedy generation with KV cache. conf_list carries one max-softmax
    # value per generated token. The returned id tensor is discarded because
    # the sequence is rebuilt below from what is actually kept.
    # stop_ids are presumably the newline and NUL bytes; verify against the tokenizer.
    _, generated, conf_list = model.generate_with_cache(
        ids, n_new_tokens=max_new_tokens, stop_ids=(10, 0),
    )

    # Trim at fake-USER boundary (matches bridge behaviour).
    for i in range(6, len(generated)):
        tail = bytes(b & 0xff for b in generated[i-5:i+1]).decode("latin-1", errors="ignore")
        if "\nUSER:" in tail or tail.endswith("USER:"):
            generated = generated[:i+1]
            conf_list = conf_list[:i+1]
            break

    # Rebuild full_ids from prompt + actually-emitted generated (mirrors bridge fix).
    # NOTE(review): full_ids can be longer than max_ctx when the prompt nearly
    # fills the context; confirm the model tolerates that length.
    if generated:
        gen_tensor = torch.tensor([generated], device=ids.device, dtype=ids.dtype)
        full_ids = torch.cat([ids, gen_tensor], dim=1)
    else:
        full_ids = ids
    text = tokenizer.decode(generated).split("\n")[0].split("USER:")[0].strip()

    # Router entropies over full sequence — shape (L, B, T).
    ents = model.router_entropies(full_ids)
    max_ent = math.log(3.0)  # 3 pathways in TilelliLite
    # Gen-position slice; aggregate per-layer mean + variance across layers.
    if generated:
        gen_ents = ents[:, :, prompt_len:]  # (L, B, n_new)
    else:
        # Empty generation — fall back to last prompt position.
        gen_ents = ents[:, :, -1:]
    per_layer_mean = gen_ents.mean(dim=(1, 2))  # (L,)
    router_entropy_mean = float(per_layer_mean.mean())
    router_entropy_var = float(per_layer_mean.var(unbiased=False))

    # Normalised confidence (1 = sure, 0 = uniform). NaN is handled first
    # because max(0.0, min(1.0, nan)) evaluates to 1.0, which would report an
    # undefined entropy as maximal confidence.
    if math.isnan(router_entropy_mean):
        router_conf = float("nan")
    else:
        router_conf = max(0.0, min(1.0, 1.0 - router_entropy_mean / max_ent))

    # Output-side baselines: mean and last max-softmax over generated tokens.
    if conf_list:
        max_softmax_mean = sum(conf_list) / len(conf_list)
        max_softmax_last = conf_list[-1]
    else:
        max_softmax_mean = float("nan")
        max_softmax_last = float("nan")

    # Abstain head at last position of full sequence (matches bridge fix).
    abstain_p = float("nan")
    if abstain_head is not None:
        h = _features_at(model, full_ids)
        ab_logit = abstain_head(h[:, -1, :])
        abstain_p = float(torch.sigmoid(ab_logit).item())

    # Report real bytes: len(prompt) counts characters and under-reports any
    # non-ASCII prompt. Assumes the byte tokenizer works on UTF-8; confirm.
    # errors="replace" keeps a lone surrogate from raising here.
    prompt_len_bytes = len(prompt.encode("utf-8", errors="replace"))

    return {
        "prompt": message,
        "text": text or "(empty)",
        "n_generated": len(generated),
        "prompt_len_bytes": prompt_len_bytes,
        "signals": {
            "max_softmax_mean": max_softmax_mean,
            "max_softmax_last": max_softmax_last,
            "router_conf": router_conf,
            "router_entropy_mean": router_entropy_mean,
            "router_entropy_var": router_entropy_var,
            "router_entropy_per_layer": per_layer_mean.tolist(),
            "abstain_p": abstain_p,
        },
    }
def main():
    """CLI entry point: probe every prompt in a JSONL file.

    Loads the checkpoint once, then reads the input one row at a time and
    writes one output row per prompt, flushing after each so progress is
    visible while the run is in flight. Each output row is the ``probe_one``
    result plus the input row's ``regime``, ``label`` and ``meta``.

    Raises:
        KeyError: If an input row has no ``"prompt"`` field.
        json.JSONDecodeError: If a non-empty input line is not valid JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, type=str,
                    help="path to a Tilelli chat .pt checkpoint")
    ap.add_argument("--in", dest="input_path", required=True, type=str,
                    help="prompt-set JSONL (one row per prompt: {regime, prompt, label})")
    ap.add_argument("--out", required=True, type=str,
                    help="output JSONL with one row per prompt (carries signals)")
    ap.add_argument("--limit", type=int, default=0,
                    help="cap prompts processed (0 = no cap)")
    ap.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
    args = ap.parse_args()

    t0 = time.time()
    model, abstain_head, tokenizer = load_bridge(args.ckpt)
    print(f"[probe] ckpt loaded in {time.time()-t0:.1f}s "
          f"({sum(p.numel() for p in model.parameters()):,} params, "
          f"abstain={'on' if abstain_head is not None else 'off'})")

    in_path = Path(args.input_path)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    n = 0
    t_probe = time.time()
    # Explicit UTF-8: the locale default would mis-decode non-ASCII prompts
    # on platforms whose default encoding is not UTF-8.
    with in_path.open(encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            res = probe_one(model, abstain_head, tokenizer,
                            row["prompt"], max_new_tokens=args.max_new_tokens)
            res["regime"] = row.get("regime", "unknown")
            res["label"] = row.get("label")
            res["meta"] = row.get("meta", {})
            fout.write(json.dumps(res) + "\n")
            fout.flush()  # see progress in real time; cost is negligible at ~0.1/s
            n += 1
            if args.limit and n >= args.limit:
                break
            if n % 10 == 0:
                rate = n / (time.time() - t_probe + 1e-6)
                if args.limit:
                    # The total is only known when --limit is given. This is
                    # an upper bound if the file holds fewer rows than that.
                    eta_s = (args.limit - n) / max(rate, 1e-6)
                    print(f"[probe] {n} prompts, {rate:.2f}/s, ETA {eta_s:.0f}s", flush=True)
                else:
                    print(f"[probe] {n} prompts, {rate:.2f}/s", flush=True)

    dt = time.time() - t_probe
    # A coarse clock can report zero elapsed time on an empty input.
    final_rate = n / dt if dt > 0 else 0.0
    print(f"[probe] done — {n} prompts in {dt:.1f}s ({final_rate:.2f}/s) → {out_path}")


if __name__ == "__main__":
    main()