"""Metacognition probe — one forward pass per prompt, records every
confidence signal under test.

Pre-registered claim (see
`Tilelli LLM Research/METACOGNITION_STUDY_SCOPE_2026-05-23.md`):
router entropy is a competitive uncertainty signal against output-side
baselines, and better on OOD / gibberish / factual-misleading / long-input
regimes.

Reads a prompt-set JSONL and writes a signals JSONL with one row per prompt.
Scoring (AUROC + bootstrap CI) lives in `metacog_score.py`.
"""
from __future__ import annotations

import argparse
import json
import math
import os
import time
import warnings
from pathlib import Path

import torch

from tilelli.core.tilelli_lite import TilelliLiteLM
from tilelli.distillery.tokenize import ByteTokenizer
from tilelli.utils import safe_load_checkpoint

MAX_NEW_TOKENS = 48
DEFAULT_MAX_SEQ = 256
ABSTAIN_KEYS = ("weight", "bias")


def load_bridge(
    ckpt_path: str,
) -> tuple[TilelliLiteLM, torch.nn.Linear | None, ByteTokenizer]:
    """Re-create the deployed bridge's model + abstain head without the
    sessioning overhead.

    Args:
        ckpt_path: Path to the checkpoint handed to ``safe_load_checkpoint``.

    Returns:
        ``(model, abstain_head_or_None, tokenizer)``. The model is in eval
        mode. The abstain head is ``None`` unless the checkpoint carries both
        ``abstain.weight`` and ``abstain.bias``.
    """
    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    # First non-empty config wins; every hyper-parameter has a fallback below.
    cfg = (
        ckpt.get("base_model_cfg")
        or ckpt.get("model_cfg")
        or ckpt.get("config")
        or {}
    )
    model = TilelliLiteLM(
        vocab_size=cfg.get("vocab_size", 256),
        d_model=cfg.get("d_model", 256),
        n_layers=cfg.get("n_layers", 8),
        n_heads=cfg.get("n_heads", 8),
        top_k=cfg.get("top_k", 16),
        ffn_expand=cfg.get("dense_expand", 4),
        max_seq_len=cfg.get("max_seq_len", DEFAULT_MAX_SEQ),
        quantize=cfg.get("quantize", False),
    )

    raw = ckpt.get("model", ckpt)
    abstain_prefix, base_prefix = "abstain.", "base."
    base_state, abstain_state = {}, {}
    for k, v in raw.items():
        if k.startswith(abstain_prefix):
            abstain_state[k[len(abstain_prefix):]] = v
        elif k.startswith(base_prefix):
            # Strip "base." only when it is a real prefix; a plain
            # str.replace would also mangle keys that merely contain
            # "base." somewhere in the middle.
            base_state[k[len(base_prefix):]] = v
        else:
            base_state[k] = v

    # strict=False keeps the load best-effort, but a silently partial load
    # would corrupt every recorded signal, so surface missing parameters.
    load_result = model.load_state_dict(base_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", None) or ())
    if missing:
        warnings.warn(
            f"checkpoint {ckpt_path!r} does not provide {len(missing)} model "
            f"parameter(s) (first: {missing[0]!r}); those keep their "
            "freshly-constructed values",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()

    abstain_head = None
    if all(k in abstain_state for k in ABSTAIN_KEYS):
        out_dim, in_dim = abstain_state["weight"].shape
        abstain_head = torch.nn.Linear(in_dim, out_dim)
        abstain_head.weight.data.copy_(abstain_state["weight"])
        abstain_head.bias.data.copy_(abstain_state["bias"])
        abstain_head.eval()
    return model, abstain_head, ByteTokenizer()


@torch.no_grad()
def _features_at(model: TilelliLiteLM, ids: torch.Tensor) -> torch.Tensor:
    """Post-norm hidden state for every position; mirrors
    tilelli_bridge._features.

    Runs token embedding + positional embedding, then every block in
    ``model.blocks``, then ``model.final_norm``.
    """
    x = model.embed(ids)
    pos = torch.arange(ids.size(1), device=ids.device)
    x = x + model.pos_embed(pos)
    for blk in model.blocks:
        x = blk(x)
    return model.final_norm(x)


def _format_prompt(
    message: str,
    max_ctx: int,
    framing_overhead: int = 20,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> str:
    """Match the bridge's USER:/TILELLI: framing exactly.

    Args:
        message: Raw user message.
        max_ctx: Context length the prompt plus generation must fit in.
        framing_overhead: Room reserved for the ``USER:``/``TILELLI:`` frame.
        max_new_tokens: Room reserved for generation. Defaults to
            ``MAX_NEW_TOKENS``, which reproduces the previous fixed budget.

    Returns:
        The framed prompt. Messages longer than the budget are middle-elided
        to ``head + " ... " + tail``.
    """
    # The message budget is measured in characters and never drops below 32.
    budget = max(32, max_ctx - framing_overhead - max_new_tokens)
    if len(message) > budget:
        half = max(8, budget // 2 - 3)
        message = message[:half] + " ... " + message[-half:]
    return ("\nUSER: " + message + "\nTILELLI:").lstrip()


@torch.no_grad()
def probe_one(
    model: TilelliLiteLM,
    abstain_head: torch.nn.Linear | None,
    tokenizer: ByteTokenizer,
    message: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> dict:
    """Run prompt through the model, return per-prompt signal dict.

    Args:
        model: Model returned by ``load_bridge``.
        abstain_head: Optional linear abstain head; ``abstain_p`` is NaN
            when it is ``None``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        message: Raw user message (framed internally).
        max_new_tokens: Generation cap; also reserved in the prompt budget.

    Returns:
        Dict with ``prompt``, ``text``, ``n_generated``, ``prompt_len_bytes``
        and a ``signals`` sub-dict (max-softmax mean/last, router confidence,
        router entropy mean/variance/per-layer, abstain probability).
    """
    max_ctx = getattr(model, "max_seq_len", DEFAULT_MAX_SEQ)
    # Reserve room for the tokens we will actually generate, so a non-default
    # max_new_tokens cannot push prompt + generation past the context length.
    prompt = _format_prompt(message, max_ctx, max_new_tokens=max_new_tokens)
    ids = tokenizer.encode(prompt).long().unsqueeze(0)
    if ids.shape[1] > max_ctx:
        ids = ids[:, -max_ctx:]
    prompt_len = ids.shape[1]

    # Greedy generate with KV cache; collect per-step logits via probs.max.
    # The returned full sequence is discarded: it is rebuilt below from the
    # prompt plus the tokens that survive trimming.
    _, generated, conf_list = model.generate_with_cache(
        ids,
        n_new_tokens=max_new_tokens,
        stop_ids=(10, 0),
    )

    # Trim at fake-USER boundary (matches bridge behaviour).
    # NOTE(review): the first window inspected is generated[1:7], so a marker
    # sitting at the very start of the generation is not trimmed here —
    # left as-is to stay aligned with the bridge; confirm this is intended.
    for i in range(6, len(generated)):
        tail = bytes(b & 0xff for b in generated[i - 5:i + 1]).decode(
            "latin-1", errors="ignore"
        )
        if "\nUSER:" in tail or tail.endswith("USER:"):
            generated = generated[:i + 1]
            conf_list = conf_list[:i + 1]
            break

    # Rebuild full_ids from prompt + actually-emitted generated
    # (mirrors bridge fix).
    if generated:
        gen_tensor = torch.tensor(
            [generated], device=ids.device, dtype=ids.dtype
        )
        full_ids = torch.cat([ids, gen_tensor], dim=1)
    else:
        full_ids = ids
    text = tokenizer.decode(generated).split("\n")[0].split("USER:")[0].strip()

    # Router entropies over full sequence — shape (L, B, T).
    ents = model.router_entropies(full_ids)
    max_ent = math.log(3.0)  # 3 pathways in TilelliLite

    # Gen-position slice; aggregate per-layer mean + variance across layers.
    if generated:
        gen_ents = ents[:, :, prompt_len:]  # (L, B, n_new)
    else:
        # Empty generation — fall back to last prompt position.
        gen_ents = ents[:, :, -1:]
    per_layer_mean = gen_ents.mean(dim=(1, 2))  # (L,)
    router_entropy_mean = float(per_layer_mean.mean())
    router_entropy_var = float(per_layer_mean.var(unbiased=False))
    # Normalised confidence (1 = sure, 0 = uniform), clamped to [0, 1].
    router_conf = max(0.0, min(1.0, 1.0 - router_entropy_mean / max_ent))

    # Output-side baselines: mean and last max-softmax over generated tokens.
    # Raw logits are NOT stored; only the empirical max-softmax values
    # collected during generation are recorded.
    if conf_list:
        max_softmax_mean = sum(conf_list) / len(conf_list)
        max_softmax_last = conf_list[-1]
    else:
        max_softmax_mean = float("nan")
        max_softmax_last = float("nan")

    # Abstain head at last position of full sequence (matches bridge fix).
    abstain_p = float("nan")
    if abstain_head is not None:
        h = _features_at(model, full_ids)
        ab_logit = abstain_head(h[:, -1, :])
        abstain_p = float(torch.sigmoid(ab_logit).item())

    return {
        "prompt": message,
        "text": text or "(empty)",
        "n_generated": len(generated),
        # NOTE(review): this is the character count of the framed prompt; it
        # equals the byte count only for single-byte text — confirm what the
        # scorer expects.
        "prompt_len_bytes": len(prompt),
        "signals": {
            "max_softmax_mean": max_softmax_mean,
            "max_softmax_last": max_softmax_last,
            "router_conf": router_conf,
            "router_entropy_mean": router_entropy_mean,
            "router_entropy_var": router_entropy_var,
            "router_entropy_per_layer": per_layer_mean.tolist(),
            "abstain_p": abstain_p,
        },
    }


def main() -> None:
    """CLI entry point: probe every prompt in ``--in`` and stream one JSON
    row per prompt to ``--out``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, type=str,
                    help="path to a Tilelli chat .pt checkpoint")
    ap.add_argument("--in", dest="input_path", required=True, type=str,
                    help="prompt-set JSONL (one row per prompt: {regime, prompt, label})")
    ap.add_argument("--out", required=True, type=str,
                    help="output JSONL with one row per prompt (carries signals)")
    ap.add_argument("--limit", type=int, default=0,
                    help="cap prompts processed (0 = no cap)")
    ap.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
    args = ap.parse_args()

    t0 = time.time()
    model, abstain_head, tokenizer = load_bridge(args.ckpt)
    print(f"[probe] ckpt loaded in {time.time()-t0:.1f}s "
          f"({sum(p.numel() for p in model.parameters()):,} params, "
          f"abstain={'on' if abstain_head is not None else 'off'})")

    in_path = Path(args.input_path)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    n = 0
    t_probe = time.time()
    # Explicit UTF-8 so prompt bytes do not depend on the platform locale.
    with in_path.open(encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            res = probe_one(model, abstain_head, tokenizer, row["prompt"],
                            max_new_tokens=args.max_new_tokens)
            res["regime"] = row.get("regime", "unknown")
            res["label"] = row.get("label")
            res["meta"] = row.get("meta", {})
            fout.write(json.dumps(res) + "\n")
            fout.flush()  # see progress in real time; cost is negligible at ~0.1/s
            n += 1
            if args.limit and n >= args.limit:
                break
            if n % 10 == 0:
                rate = n / (time.time() - t_probe + 1e-6)
                if args.limit:
                    eta_s = (args.limit - n) / max(rate, 1e-6)
                    print(f"[probe] {n} prompts, {rate:.2f}/s, ETA {eta_s:.0f}s",
                          flush=True)
                else:
                    # Total is unknown without a cap, so no ETA is shown
                    # rather than one computed against a sentinel count.
                    print(f"[probe] {n} prompts, {rate:.2f}/s", flush=True)

    dt = time.time() - t_probe
    # Guard the rate against a zero elapsed time (empty input, coarse clock).
    print(f"[probe] done — {n} prompts in {dt:.1f}s ({n/max(dt, 1e-6):.2f}/s) → {out_path}")


if __name__ == "__main__":
    main()