| """Metacognition probe — one generation run per prompt; every confidence |
| signal under test is recorded from that same run. |
| |
| Pre-registered claim (see `Tilelli LLM Research/METACOGNITION_STUDY_SCOPE_2026-05-23.md`): |
| router entropy is a competitive uncertainty signal against output-side |
| baselines, and better on OOD / gibberish / factual-misleading / long-input |
| regimes. |
| |
| Reads a prompt-set JSONL and writes a signals JSONL with one row per |
| prompt. Scoring (AUROC + bootstrap CI) lives in `metacog_score.py`. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import time |
| from pathlib import Path |
|
|
| import torch |
|
|
| from tilelli.core.tilelli_lite import TilelliLiteLM |
| from tilelli.distillery.tokenize import ByteTokenizer |
| from tilelli.utils import safe_load_checkpoint |
|
|
|
|
# Default number of tokens to generate per prompt; also reserved out of the
# context window when the prompt is formatted.
MAX_NEW_TOKENS: int = 48
# Context length assumed when neither the checkpoint config nor the model
# object provides `max_seq_len`.
DEFAULT_MAX_SEQ: int = 256
# State-dict entries (after the "abstain." prefix is stripped) that must all be
# present for the abstain head to be rebuilt.
ABSTAIN_KEYS: tuple[str, ...] = ("weight", "bias")
|
|
|
|
def load_bridge(ckpt_path: str) -> tuple[TilelliLiteLM, torch.nn.Linear | None, ByteTokenizer]:
    """Re-create the deployed bridge's model + abstain head without the
    sessioning overhead.

    The model hyper-parameters come from the first non-empty of the
    checkpoint's ``base_model_cfg`` / ``model_cfg`` / ``config`` entries, with
    built-in fallbacks for anything absent. Weights are read from
    ``ckpt["model"]`` when present, otherwise from the checkpoint mapping
    itself. Keys prefixed ``abstain.`` feed the abstain head; every other key
    (with a leading ``base.`` stripped) feeds the language model.

    Args:
        ckpt_path: path to the checkpoint file.

    Returns:
        ``(model, abstain_head_or_None, tokenizer)``. The model and the head
        are both in eval mode. The head is ``None`` unless every key in
        ``ABSTAIN_KEYS`` was found under the ``abstain.`` prefix.

    Warns:
        RuntimeWarning: when the non-strict weight load reports missing or
            unexpected keys, since the probe would otherwise silently measure
            a partially initialised model.
    """
    import warnings  # local import: only needed for the partial-load warning

    ckpt = safe_load_checkpoint(ckpt_path, trusted=True)
    cfg = (ckpt.get("base_model_cfg") or ckpt.get("model_cfg")
           or ckpt.get("config") or {})
    model = TilelliLiteLM(
        vocab_size=cfg.get("vocab_size", 256),
        d_model=cfg.get("d_model", 256),
        n_layers=cfg.get("n_layers", 8),
        n_heads=cfg.get("n_heads", 8),
        top_k=cfg.get("top_k", 16),
        ffn_expand=cfg.get("dense_expand", 4),
        max_seq_len=cfg.get("max_seq_len", DEFAULT_MAX_SEQ),
        quantize=cfg.get("quantize", False),
    )

    raw = ckpt.get("model", ckpt)
    base_prefix, abstain_prefix = "base.", "abstain."
    base_state, abstain_state = {}, {}
    for key, value in raw.items():
        if key.startswith(abstain_prefix):
            abstain_state[key[len(abstain_prefix):]] = value
        elif key.startswith(base_prefix):
            # Strip only a *leading* "base."; str.replace would also rewrite a
            # "base." that appears in the middle of a parameter name.
            base_state[key[len(base_prefix):]] = value
        else:
            base_state[key] = value

    # strict=False tolerates checkpoint/model drift, but the drift must not go
    # unnoticed: report whatever did not line up.
    load_result = model.load_state_dict(base_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        warnings.warn(
            f"checkpoint {ckpt_path!r} only partially matches the model: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()

    abstain_head = None
    if all(k in abstain_state for k in ABSTAIN_KEYS):
        # The Linear layer is sized from the stored weight matrix (out, in).
        out_dim, in_dim = abstain_state["weight"].shape
        abstain_head = torch.nn.Linear(in_dim, out_dim)
        abstain_head.weight.data.copy_(abstain_state["weight"])
        abstain_head.bias.data.copy_(abstain_state["bias"])
        abstain_head.eval()

    return model, abstain_head, ByteTokenizer()
|
|
|
|
@torch.no_grad()
def _features_at(model: TilelliLiteLM, ids: torch.Tensor) -> torch.Tensor:
    """Return the final-norm hidden state at every position of *ids*.

    Token embeddings and positional embeddings are summed, passed through each
    of ``model.blocks`` in order, then through ``model.final_norm``. Runs
    without gradient tracking. Kept in step with ``tilelli_bridge._features``.
    """
    positions = torch.arange(ids.size(1), device=ids.device)
    hidden = model.embed(ids) + model.pos_embed(positions)
    for block in model.blocks:
        hidden = block(hidden)
    return model.final_norm(hidden)
|
|
|
|
def _format_prompt(message: str, max_ctx: int, framing_overhead: int = 20) -> str:
    """Wrap *message* in the bridge's USER:/TILELLI: framing.

    The message gets a character budget of
    ``max_ctx - framing_overhead - MAX_NEW_TOKENS`` (never less than 32).
    A longer message keeps its head and tail, joined by `` ... ``, so the
    framed prompt still leaves room for generation.

    Args:
        message: raw user text.
        max_ctx: context length of the model.
        framing_overhead: characters reserved for the framing itself.

    Returns:
        The framed prompt, with leading whitespace stripped.
    """
    room = max(32, max_ctx - framing_overhead - MAX_NEW_TOKENS)
    if len(message) > room:
        keep = max(8, room // 2 - 3)
        head, tail = message[:keep], message[-keep:]
        message = head + " ... " + tail
    framed = "\nUSER: " + message + "\nTILELLI:"
    return framed.lstrip()
|
|
|
|
def _cut_at_user_marker(
    generated: list[int], conf_list: list[float],
) -> tuple[list[int], list[float]]:
    """Truncate a generation (and its confidences) right after the first ``USER:``."""
    marker = b"USER:"
    # Token ids are treated as raw byte values, so the search runs on bytes.
    hit = bytes(tok & 0xFF for tok in generated).find(marker)
    if hit == -1:
        return generated, conf_list
    end = hit + len(marker)
    return generated[:end], conf_list[:end]


@torch.no_grad()
def probe_one(
    model: TilelliLiteLM,
    abstain_head: torch.nn.Linear | None,
    tokenizer: ByteTokenizer,
    message: str,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> dict:
    """Run one prompt through the model and return its per-prompt signal record.

    Args:
        model: language model exposing ``generate_with_cache`` and
            ``router_entropies``.
        abstain_head: optional linear head applied to the last hidden state;
            ``None`` disables the ``abstain_p`` signal.
        tokenizer: tokenizer used to encode the prompt and decode the reply.
        message: raw user text (framed by ``_format_prompt`` before encoding).
        max_new_tokens: generation cap passed to ``generate_with_cache``.

    Returns:
        A dict with ``prompt``, ``text``, ``n_generated``, ``prompt_len_bytes``
        and a ``signals`` sub-dict holding ``max_softmax_mean``,
        ``max_softmax_last``, ``router_conf``, ``router_entropy_mean``,
        ``router_entropy_var``, ``router_entropy_per_layer`` and ``abstain_p``.
        Max-softmax signals are NaN when nothing was generated; ``abstain_p``
        is NaN when there is no abstain head.
    """
    max_ctx = getattr(model, "max_seq_len", DEFAULT_MAX_SEQ)
    prompt = _format_prompt(message, max_ctx)
    ids = tokenizer.encode(prompt).long().unsqueeze(0)
    if ids.shape[1] > max_ctx:
        # Keep the most recent tokens so the trailing framing survives.
        ids = ids[:, -max_ctx:]
    prompt_len = ids.shape[1]

    # The sequence returned by the generator is discarded: it is rebuilt below
    # from the (possibly truncated) generated tokens.
    _, generated, conf_list = model.generate_with_cache(
        ids, n_new_tokens=max_new_tokens, stop_ids=(10, 0),
    )

    # Stop at the first "USER:" the model emits on its own, so the signals do
    # not average over a runaway continuation. The scan starts at position 0;
    # the previous window loop began at index 6 and therefore missed a marker
    # starting at position 0 or 1.
    generated, conf_list = _cut_at_user_marker(generated, conf_list)

    if generated:
        gen_tensor = torch.tensor([generated], device=ids.device, dtype=ids.dtype)
        full_ids = torch.cat([ids, gen_tensor], dim=1)
    else:
        full_ids = ids

    text = tokenizer.decode(generated).split("\n")[0].split("USER:")[0].strip()

    # Router entropies over prompt + generation.
    # NOTE(review): indexed as [layer, ?, position]; presumably
    # (n_layers, batch, seq) — confirm against router_entropies.
    ents = model.router_entropies(full_ids)
    # NOTE(review): log(3) is used as the entropy ceiling, i.e. a 3-way routing
    # distribution is assumed — confirm against the router.
    max_ent = math.log(3.0)

    if generated:
        gen_ents = ents[:, :, prompt_len:]
    else:
        # Nothing generated: fall back to the last prompt position.
        gen_ents = ents[:, :, -1:]
    per_layer_mean = gen_ents.mean(dim=(1, 2))
    router_entropy_mean = float(per_layer_mean.mean())
    router_entropy_var = float(per_layer_mean.var(unbiased=False))
    # Map mean entropy to a [0, 1] confidence (1 = zero entropy).
    router_conf = max(0.0, min(1.0, 1.0 - router_entropy_mean / max_ent))

    if conf_list:
        # float() keeps the values JSON-serialisable even if the generator
        # hands back 0-d tensors rather than Python floats.
        max_softmax_mean = float(sum(conf_list) / len(conf_list))
        max_softmax_last = float(conf_list[-1])
    else:
        max_softmax_mean = float("nan")
        max_softmax_last = float("nan")

    abstain_p = float("nan")
    if abstain_head is not None:
        h = _features_at(model, full_ids)
        ab_logit = abstain_head(h[:, -1, :])
        abstain_p = float(torch.sigmoid(ab_logit).item())

    return {
        "prompt": message,
        "text": text or "(empty)",
        "n_generated": len(generated),
        # NOTE(review): this is len() of the framed str (code points), which
        # equals the byte count only for ASCII prompts.
        "prompt_len_bytes": len(prompt),
        "signals": {
            "max_softmax_mean": max_softmax_mean,
            "max_softmax_last": max_softmax_last,
            "router_conf": router_conf,
            "router_entropy_mean": router_entropy_mean,
            "router_entropy_var": router_entropy_var,
            "router_entropy_per_layer": per_layer_mean.tolist(),
            "abstain_p": abstain_p,
        },
    }
|
|
|
|
def main():
    """CLI entry point: probe every prompt in a JSONL file and write one
    signals row per prompt to the output JSONL.

    Each input row must carry ``prompt``; ``regime``, ``label`` and ``meta``
    are copied through when present. Rows are flushed as they are written so a
    long run can be inspected (or salvaged) before it finishes.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, type=str,
                    help="path to a Tilelli chat .pt checkpoint")
    ap.add_argument("--in", dest="input_path", required=True, type=str,
                    help="prompt-set JSONL (one row per prompt: {regime, prompt, label})")
    ap.add_argument("--out", required=True, type=str,
                    help="output JSONL with one row per prompt (carries signals)")
    ap.add_argument("--limit", type=int, default=0,
                    help="cap prompts processed (0 = no cap)")
    ap.add_argument("--max-new-tokens", type=int, default=MAX_NEW_TOKENS)
    args = ap.parse_args()

    t0 = time.time()
    model, abstain_head, tokenizer = load_bridge(args.ckpt)
    print(f"[probe] ckpt loaded in {time.time()-t0:.1f}s "
          f"({sum(p.numel() for p in model.parameters()):,} params, "
          f"abstain={'on' if abstain_head is not None else 'off'})")

    in_path = Path(args.input_path)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Work out how many prompts will really be processed so the ETA is
    # meaningful even without --limit. Only regular files are pre-counted,
    # because a pipe cannot be read twice.
    total = args.limit or None
    if in_path.is_file():
        with in_path.open(encoding="utf-8") as fin:
            n_rows = sum(1 for ln in fin if ln.strip())
        total = min(n_rows, args.limit) if args.limit else n_rows

    n = 0
    t_probe = time.time()
    # JSONL is UTF-8; do not depend on the platform's default encoding.
    with in_path.open(encoding="utf-8") as fin, \
            out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            res = probe_one(model, abstain_head, tokenizer,
                            row["prompt"], max_new_tokens=args.max_new_tokens)
            res["regime"] = row.get("regime", "unknown")
            res["label"] = row.get("label")
            res["meta"] = row.get("meta", {})
            fout.write(json.dumps(res) + "\n")
            fout.flush()
            n += 1
            if args.limit and n >= args.limit:
                break
            if n % 10 == 0:
                rate = n / (time.time() - t_probe + 1e-6)
                if total is None:
                    eta_txt = "?"
                else:
                    eta_txt = f"{max(total - n, 0) / max(rate, 1e-6):.0f}s"
                print(f"[probe] {n} prompts, {rate:.2f}/s, ETA {eta_txt}", flush=True)
    # Guard the final rate against a zero elapsed time (empty input, coarse clock).
    dt = max(time.time() - t_probe, 1e-6)
    print(f"[probe] done — {n} prompts in {dt:.1f}s ({n/dt:.2f}/s) → {out_path}")
|
|
|
|
# Run the probe CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|