| """ |
| Standalone inference tool (Apr 2026 update). |
| |
| Given a math problem (or default), produce N outputs at different alphas. |
| Default alphas: 1.0, 0.5, 0.25, 0.0 (NEW SEMANTICS: 1=baseline, 0=full suppress). |
| |
| Default direction version: v_pca_subspace (k-D subspace, more robust than v1). |
| |
| Usage: |
| # Single dim, multi-alpha |
| python scripts/10_infer.py --dim planning --alphas 1.0 0.5 0.0 |
| |
| # With anti-leak joint steering |
| python scripts/10_infer.py --dim planning --joint --alphas 1.0 0.5 0.0 |
| |
| # Use v1_raw for comparison |
| python scripts/10_infer.py --dim planning --version v1_raw --alphas 1.0 0.0 |
| |
| This script is also called by runall.sh as a sanity-check on |
| 2 sample problems × {planning, monitoring} × {alpha=1, alpha=0}. |
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, |
| PLAN_V1_RAW, PLAN_V_PCA_SUBSPACE, |
| MON_V1_RAW, MON_V_PCA_SUBSPACE, |
| RESULTS_DIR, |
| ) |
| from configs.model import GEN_CONFIG_FAST, ANTI_LEAK_BETA |
| from src.utils import setup_logger, write_json, cleanup_memory |
| from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate |
| from src.detectors import BehaviorDetector, count_real_monitoring, is_collapsed |
| from src.planning_quality import compute_pqs |
| from src.steering import ( |
| ResidualSteerer, JointResidualSteerer, |
| is_neutral_alpha, |
| ) |
| from src.directions import load_directions |
|
|
|
|
# Dimension -> direction-version -> on-disk artifact path.
# "v1_raw" is the raw v1 direction; "v_pca_subspace" is the k-D PCA subspace
# (the script's default --version; see module docstring).
DIRECTION_PATHS = {
    "planning": {
        "v1_raw": PLAN_V1_RAW,
        "v_pca_subspace": PLAN_V_PCA_SUBSPACE,
    },
    "monitoring": {
        "v1_raw": MON_V1_RAW,
        "v_pca_subspace": MON_V_PCA_SUBSPACE,
    },
}
|
|
# Built-in sample problems: used in full by --auto_problems (the runall.sh
# sanity check), and the first entry is the fallback when no --problem /
# --problem_file is given.
DEFAULT_PROBLEMS = [
    "Find the smallest positive integer n such that n^2 + n + 41 is composite.",
    "If $\\sin x + \\cos x = \\frac{1}{5}$ and $0 \\le x < \\pi$, find $\\tan x$.",
]
|
|
|
|
def run_inference(model, tokenizer, prompt, target_dirs, other_dirs,
                  alpha, max_new_tokens, joint=False, beta=ANTI_LEAK_BETA):
    """Generate one completion for *prompt*, steering residuals by *alpha*.

    A neutral alpha (per ``is_neutral_alpha``; under the new semantics 1.0 is
    the unsteered baseline) skips hook installation entirely. Otherwise a
    single-direction ``ResidualSteerer`` is installed, or — when *joint* is
    set and *other_dirs* was loaded — a ``JointResidualSteerer`` that couples
    in anti-leak suppression with weight *beta*.

    Returns the generated text.
    """
    # Fast path: the baseline needs no hooks at all.
    if is_neutral_alpha(alpha):
        return generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)

    if joint and other_dirs is not None:
        hook = JointResidualSteerer(model, target_dirs, other_dirs,
                                    alpha=alpha, beta=beta)
    else:
        hook = ResidualSteerer(model, target_dirs, alpha=alpha)

    hook.start()
    try:
        # Hooks are always removed, even if generation raises (e.g. OOM).
        return generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    finally:
        hook.stop()
|
|
|
|
def format_report(problem, alpha, text, base_text, mon_det, plan_det):
    """Build the printable per-run report and its compact summary dict.

    *problem* is accepted for signature symmetry with callers but is not
    embedded in the report body. *base_text* is the baseline generation used
    by collapse detection. Returns ``(report_text, summary)``.
    """
    # Run every detector / scorer exactly once up front.
    mon_full = mon_det.detect(text)
    plan_full = plan_det.detect(text)
    real_mon = count_real_monitoring(text)
    pqs = compute_pqs(text)
    coll = is_collapsed(text, base_text=base_text)

    def _profile_label(a):
        # Human-readable ability profile (1 = native ability, 0 = suppressed).
        if a is None:
            return "force-prompt baseline"
        if abs(a - 1.0) < 1e-5:
            return "BASELINE (no steering, full ability)"
        if abs(a - 0.0) < 1e-5:
            return "ZERO ABILITY (full suppression)"
        if 0.0 < a < 1.0:
            return f"PARTIAL ABILITY ({a*100:.0f}% of native)"
        if a > 1.0:
            return f"OVERTHINKER (amplified by {a-1.0:+.1f})"
        return f"OVER-SUPPRESSED ({a:.1f})"

    rule = "=" * 70
    header = [
        rule,
        f"[α = {alpha}] profile: {_profile_label(alpha)}",
        rule,
        f"Length: {len(text)} chars Length-ratio vs base: {coll['length_ratio']}",
        f"Collapsed: {coll['collapsed']} reason={coll['reason']} "
        f"ngram_rep={coll['ngram_repetition']:.3f}",
        "",
    ]
    triggers = [
        f"Planning triggers (total): {plan_full['total']}",
        f"  by_subtype: {plan_full['by_type']}",
        f"Monitoring triggers (total): {mon_full['total']}",
        f"  REAL reflection: {real_mon['real_reflection']}",
        f"  Filler ('wait, 5+3=...'): {real_mon['filler_only']}",
        f"  Ambiguous: {real_mon['ambiguous']}",
        f"  by_subtype: {mon_full['by_type']}",
        "",
    ]
    quality = [
        f"Planning Quality Score (PQS): {pqs['pqs']:.3f}",
        f"  Q1 structural_depth: {pqs['q1_structural_depth']:.3f}",
        f"  Q2 strategy_diversity: {pqs['q2_strategy_diversity']:.3f}",
        f"  Q3 long_range_coherence: {pqs['q3_long_range_coherence']:.3f}",
        f"  Q4 premature_execution: {pqs['q4_premature_execution']:.3f} (lower=better)",
        "",
    ]
    excerpt = [
        "=== Generated CoT (first 800 chars) ===",
        text[:800],
        "...",
        "=== last 200 chars ===",
        text[-200:],
        "",
    ]

    summary = {
        "alpha": alpha,
        "len": len(text),
        "mon_total": mon_full["total"],
        "mon_real": real_mon["real_reflection"],
        "plan_total": plan_full["total"],
        "pqs": pqs["pqs"],
        "collapsed": coll["collapsed"],
    }
    return "\n".join(header + triggers + quality + excerpt), summary
|
|
|
|
def main() -> None:
    """CLI entry point: steered inference across a sweep of alphas.

    For each selected problem, generates one CoT per alpha (optionally with
    joint anti-leak steering), prints a per-run report, optionally dumps all
    outputs to JSON via --save_to, and finally prints a summary table of
    per-alpha metrics averaged across problems.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", choices=["planning", "monitoring"], default="planning")
    parser.add_argument("--version", choices=["v1_raw", "v_pca_subspace"],
                        default="v_pca_subspace")
    parser.add_argument("--problem", type=str, default=None)
    parser.add_argument("--problem_file", type=str, default=None)
    parser.add_argument("--alphas", nargs="+", type=float,
                        default=[1.0, 0.5, 0.25, 0.0])
    parser.add_argument("--max_new_tokens", type=int, default=4096)
    parser.add_argument("--save_to", type=str, default=None)
    parser.add_argument("--joint", action="store_true",
                        help="Enable anti-leak coupling steering")
    parser.add_argument("--beta", type=float, default=ANTI_LEAK_BETA)
    parser.add_argument("--auto_problems", action="store_true",
                        help="Run on built-in default problems (used by runall sanity)")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("10_infer", LOGS_DIR / "10_infer.log")

    # Problem selection priority: --auto_problems > --problem > --problem_file,
    # falling back to the first built-in default problem.
    if args.auto_problems:
        problems_to_run = DEFAULT_PROBLEMS
    elif args.problem:
        problems_to_run = [args.problem]
    elif args.problem_file:
        problems_to_run = [Path(args.problem_file).read_text(encoding="utf-8").strip()]
    else:
        problems_to_run = [DEFAULT_PROBLEMS[0]]

    log.info(f"Dimension: {args.dim} Version: {args.version}")
    log.info(f"Alphas: {args.alphas}")
    log.info(f"Joint anti-leak: {args.joint} beta={args.beta}")
    log.info(f"Problems: {len(problems_to_run)}")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    log.info("Loading directions...")
    target_dirs = load_directions(DIRECTION_PATHS[args.dim][args.version])
    # The "other" dimension's directions are only needed for joint anti-leak.
    other_dim = "monitoring" if args.dim == "planning" else "planning"
    other_dirs = load_directions(DIRECTION_PATHS[other_dim][args.version]) if args.joint else None

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    full_outputs = []
    for prob_idx, problem in enumerate(problems_to_run):
        log.info(f"\n========== Problem {prob_idx+1}/{len(problems_to_run)} ==========")
        log.info(f"Q: {problem[:120]}")
        prompt = build_thinking_prompt(tokenizer, problem, enable_thinking=True)

        baseline_text = None
        prob_outputs = {}

        for a in args.alphas:
            log.info(f"-- α={a} --")
            text = run_inference(model, tokenizer, prompt,
                                 target_dirs, other_dirs,
                                 a, args.max_new_tokens,
                                 joint=args.joint, beta=args.beta)
            # Remember the (first encountered) neutral-alpha output so later
            # steered runs can be compared against a real baseline.
            # NOTE(review): if a steered alpha runs BEFORE any neutral alpha,
            # base_for_eval falls back to the steered text itself, so its
            # collapse length-ratio is trivially 1 — the default alpha order
            # (1.0 first) avoids this; confirm intended for custom --alphas.
            if is_neutral_alpha(a):
                baseline_text = text
            base_for_eval = baseline_text if baseline_text else text
            report, summary = format_report(problem, a, text, base_for_eval, mon_det, plan_det)
            print(report)
            # Keyed by str(alpha) so the structure is JSON-serializable.
            prob_outputs[str(a)] = {
                "text": text,
                "summary": summary,
            }
            cleanup_memory()  # free KV-cache / intermediate tensors between runs
        full_outputs.append({
            "problem": problem,
            "outputs": prob_outputs,
        })

    if args.save_to:
        write_json({
            "dim": args.dim, "version": args.version,
            "alphas": args.alphas, "joint": args.joint, "beta": args.beta,
            "problems": full_outputs,
        }, Path(args.save_to))
        log.info(f"Saved to {args.save_to}")

    # Summary table: each metric averaged over all problems at a given alpha.
    print("\n" + "=" * 70)
    print("SUMMARY TABLE (across all problems)")
    print("=" * 70)
    print(f"{'α':>6} {'mon_total':>10} {'mon_real':>10} {'plan':>6} {'pqs':>6} {'len':>8} {'collapse':>10}")
    for a in args.alphas:
        # Collect the per-problem summaries recorded for this alpha.
        rows = []
        for po in full_outputs:
            rows.append(po["outputs"][str(a)]["summary"])
        mt = sum(r["mon_total"] for r in rows) / len(rows)
        mr = sum(r["mon_real"] for r in rows) / len(rows)
        pl = sum(r["plan_total"] for r in rows) / len(rows)
        pq = sum(r["pqs"] for r in rows) / len(rows)
        ln = sum(r["len"] for r in rows) / len(rows)
        # Percentage of problems flagged as collapsed at this alpha.
        coll_pct = sum(1 for r in rows if r["collapsed"]) / len(rows) * 100
        print(f"{a:>6.2f} {mt:>10.1f} {mr:>10.1f} {pl:>6.1f} {pq:>6.3f} {ln:>8.0f} {coll_pct:>9.0f}%")
    print()
|
|
|
|
# Script entry point (also invoked by runall.sh as a sanity check).
if __name__ == "__main__":
    main()
|
|