"""
Standalone inference tool (Apr 2026 update).

Given a math problem (or default), produce N outputs at different alphas.

Default alphas: 1.0, 0.5, 0.25, 0.0 (NEW SEMANTICS: 1=baseline, 0=full suppress).
Default direction version: v_pca_subspace (k-D subspace, more robust than v1).

Usage:
    # Single dim, multi-alpha
    python scripts/10_infer.py --dim planning --alphas 1.0 0.5 0.0

    # With anti-leak joint steering
    python scripts/10_infer.py --dim planning --joint --alphas 1.0 0.5 0.0

    # Use v1_raw for comparison
    python scripts/10_infer.py --dim planning --version v1_raw --alphas 1.0 0.0

This script is also called by runall.sh as a sanity-check on 2 sample
problems × {planning, monitoring} × {alpha=1, alpha=0}.
"""
import sys
import argparse
from pathlib import Path

# Make the repository root importable when the script is run directly.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch

from configs.paths import (
    ensure_dirs,
    LOGS_DIR,
    PLAN_V1_RAW,
    PLAN_V_PCA_SUBSPACE,
    MON_V1_RAW,
    MON_V_PCA_SUBSPACE,
    RESULTS_DIR,
)
from configs.model import GEN_CONFIG_FAST, ANTI_LEAK_BETA
from src.utils import setup_logger, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
from src.detectors import BehaviorDetector, count_real_monitoring, is_collapsed
from src.planning_quality import compute_pqs
from src.steering import (
    ResidualSteerer,
    JointResidualSteerer,
    is_neutral_alpha,
)
from src.directions import load_directions


# Direction file to load for each (dimension, direction-version) pair.
DIRECTION_PATHS = {
    "planning": {
        "v1_raw": PLAN_V1_RAW,
        "v_pca_subspace": PLAN_V_PCA_SUBSPACE,
    },
    "monitoring": {
        "v1_raw": MON_V1_RAW,
        "v_pca_subspace": MON_V_PCA_SUBSPACE,
    },
}

# Built-in problems: the first is the fallback when no problem is given,
# both are used with --auto_problems.
DEFAULT_PROBLEMS = [
    "Find the smallest positive integer n such that n^2 + n + 41 is composite.",
    "If $\\sin x + \\cos x = \\frac{1}{5}$ and $0 \\le x < \\pi$, find $\\tan x$.",
]


def run_inference(model, tokenizer, prompt, target_dirs, other_dirs, alpha,
                  max_new_tokens, joint=False, beta=ANTI_LEAK_BETA):
    """Generate one completion for ``prompt`` at the given steering ``alpha``.

    Args:
        model, tokenizer: Passed through to ``generate`` and the steerers.
        prompt: Prompt passed to ``generate``.
        target_dirs: Directions for the dimension being steered.
        other_dirs: Directions for the other dimension; only used for joint
            steering and may be ``None``.
        alpha: Steering strength. When ``is_neutral_alpha(alpha)`` is true no
            steerer is installed at all.
        max_new_tokens: Generation budget forwarded to ``generate``.
        joint: Use ``JointResidualSteerer`` (requires ``other_dirs``).
        beta: Coupling strength forwarded to ``JointResidualSteerer``.

    Returns:
        Whatever ``generate`` returns (used as text by the callers here).
    """
    if is_neutral_alpha(alpha):
        # Baseline: plain generation, no steerer attached.
        return generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)

    if joint and other_dirs is not None:
        steerer = JointResidualSteerer(model, target_dirs, other_dirs,
                                       alpha=alpha, beta=beta)
    else:
        steerer = ResidualSteerer(model, target_dirs, alpha=alpha)

    steerer.start()
    try:
        text = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    finally:
        # Always stop the steerer so later generations are not affected.
        steerer.stop()
    return text


def format_report(problem, alpha, text, base_text, mon_det, plan_det):
    """Build a human-readable report and a compact summary for one output.

    Args:
        problem: The problem statement (currently not used in the report).
        alpha: Steering strength used for ``text``; ``None`` is labelled as a
            force-prompt baseline.
        text: Generated output to analyse.
        base_text: Reference output passed to ``is_collapsed``.
        mon_det, plan_det: Detectors whose ``detect(text)`` result provides
            ``"total"`` and ``"by_type"`` entries.

    Returns:
        ``(report, summary)`` where ``report`` is a multi-line string and
        ``summary`` is a dict with keys ``alpha``, ``len``, ``mon_total``,
        ``mon_real``, ``plan_total``, ``pqs`` and ``collapsed``.
    """
    mon_full = mon_det.detect(text)
    plan_full = plan_det.detect(text)
    real_mon = count_real_monitoring(text)
    pqs = compute_pqs(text)
    coll = is_collapsed(text, base_text=base_text)

    # The None check must come first: the numeric comparisons below would
    # raise on None.
    if alpha is None:
        profile = "force-prompt baseline"
    elif abs(alpha - 1.0) < 1e-5:
        profile = "BASELINE (no steering, full ability)"
    elif abs(alpha - 0.0) < 1e-5:
        profile = "ZERO ABILITY (full suppression)"
    elif 0.0 < alpha < 1.0:
        profile = f"PARTIAL ABILITY ({alpha*100:.0f}% of native)"
    elif alpha > 1.0:
        profile = f"OVERTHINKER (amplified by {alpha-1.0:+.1f})"
    else:
        profile = f"OVER-SUPPRESSED ({alpha:.1f})"

    lines = [
        "=" * 70,
        f"[α = {alpha}] profile: {profile}",
        "=" * 70,
        f"Length: {len(text)} chars Length-ratio vs base: {coll['length_ratio']}",
        f"Collapsed: {coll['collapsed']} reason={coll['reason']} "
        f"ngram_rep={coll['ngram_repetition']:.3f}",
        "",
        f"Planning triggers (total): {plan_full['total']}",
        f" by_subtype: {plan_full['by_type']}",
        f"Monitoring triggers (total): {mon_full['total']}",
        f" REAL reflection: {real_mon['real_reflection']}",
        f" Filler ('wait, 5+3=...'): {real_mon['filler_only']}",
        f" Ambiguous: {real_mon['ambiguous']}",
        f" by_subtype: {mon_full['by_type']}",
        "",
        f"Planning Quality Score (PQS): {pqs['pqs']:.3f}",
        f" Q1 structural_depth: {pqs['q1_structural_depth']:.3f}",
        f" Q2 strategy_diversity: {pqs['q2_strategy_diversity']:.3f}",
        f" Q3 long_range_coherence: {pqs['q3_long_range_coherence']:.3f}",
        f" Q4 premature_execution: {pqs['q4_premature_execution']:.3f} (lower=better)",
        "",
        "=== Generated CoT (first 800 chars) ===",
        text[:800],
        "...",
        "=== last 200 chars ===",
        text[-200:],
        "",
    ]
    return "\n".join(lines), {
        "alpha": alpha,
        "len": len(text),
        "mon_total": mon_full["total"],
        "mon_real": real_mon["real_reflection"],
        "plan_total": plan_full["total"],
        "pqs": pqs["pqs"],
        "collapsed": coll["collapsed"],
    }


def _baseline_first(alphas):
    """Return ``alphas`` with neutral (baseline) values moved to the front.

    ``sorted`` is stable, so the relative order within the neutral and the
    non-neutral groups is preserved.
    """
    return sorted(alphas, key=lambda a: not is_neutral_alpha(a))


def main():
    """Parse CLI arguments, run steered generations, and print/save reports."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", choices=["planning", "monitoring"],
                        default="planning")
    parser.add_argument("--version", choices=["v1_raw", "v_pca_subspace"],
                        default="v_pca_subspace")
    parser.add_argument("--problem", type=str, default=None)
    parser.add_argument("--problem_file", type=str, default=None)
    parser.add_argument("--alphas", nargs="+", type=float,
                        default=[1.0, 0.5, 0.25, 0.0])
    parser.add_argument("--max_new_tokens", type=int, default=4096)
    parser.add_argument("--save_to", type=str, default=None)
    parser.add_argument("--joint", action="store_true",
                        help="Enable anti-leak coupling steering")
    parser.add_argument("--beta", type=float, default=ANTI_LEAK_BETA)
    parser.add_argument("--auto_problems", action="store_true",
                        help="Run on built-in default problems (used by runall sanity)")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("10_infer", LOGS_DIR / "10_infer.log")

    # Determine problems (precedence: --auto_problems, --problem,
    # --problem_file, then the first built-in problem).
    if args.auto_problems:
        problems_to_run = DEFAULT_PROBLEMS
    elif args.problem:
        problems_to_run = [args.problem]
    elif args.problem_file:
        # The whole file is treated as a single problem statement.
        problems_to_run = [Path(args.problem_file).read_text(encoding="utf-8").strip()]
    else:
        problems_to_run = [DEFAULT_PROBLEMS[0]]

    log.info(f"Dimension: {args.dim} Version: {args.version}")
    log.info(f"Alphas: {args.alphas}")
    log.info(f"Joint anti-leak: {args.joint} beta={args.beta}")
    log.info(f"Problems: {len(problems_to_run)}")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    log.info("Loading directions...")
    target_dirs = load_directions(DIRECTION_PATHS[args.dim][args.version])
    other_dim = "monitoring" if args.dim == "planning" else "planning"
    # The other dimension's directions are only needed for joint steering.
    other_dirs = (
        load_directions(DIRECTION_PATHS[other_dim][args.version])
        if args.joint else None
    )

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    # Generate baseline (neutral) alphas first so every steered output is
    # compared against the baseline no matter how --alphas was ordered.
    # Previously an alpha listed before the baseline was compared to itself.
    run_order = _baseline_first(args.alphas)

    full_outputs = []
    for prob_idx, problem in enumerate(problems_to_run):
        log.info(f"\n========== Problem {prob_idx+1}/{len(problems_to_run)} ==========")
        log.info(f"Q: {problem[:120]}")
        prompt = build_thinking_prompt(tokenizer, problem, enable_thinking=True)

        baseline_text = None
        prob_outputs = {}
        for a in run_order:
            log.info(f"-- α={a} --")
            text = run_inference(model, tokenizer, prompt, target_dirs, other_dirs,
                                 a, args.max_new_tokens,
                                 joint=args.joint, beta=args.beta)
            if is_neutral_alpha(a):
                baseline_text = text
            # Fall back to self-comparison when no (non-empty) baseline exists,
            # e.g. when --alphas contains no neutral value.
            base_for_eval = baseline_text if baseline_text else text
            report, summary = format_report(problem, a, text, base_for_eval,
                                            mon_det, plan_det)
            print(report)
            prob_outputs[str(a)] = {
                "text": text,
                "summary": summary,
            }
            cleanup_memory()

        full_outputs.append({
            "problem": problem,
            "outputs": prob_outputs,
        })

    if args.save_to:
        write_json({
            "dim": args.dim,
            "version": args.version,
            "alphas": args.alphas,
            "joint": args.joint,
            "beta": args.beta,
            "problems": full_outputs,
        }, Path(args.save_to))
        log.info(f"Saved to {args.save_to}")

    # Print summary table (rows follow the order given on the command line).
    print("\n" + "=" * 70)
    print("SUMMARY TABLE (across all problems)")
    print("=" * 70)
    print(f"{'α':>6} {'mon_total':>10} {'mon_real':>10} {'plan':>6} {'pqs':>6} {'len':>8} {'collapse':>10}")
    for a in args.alphas:
        # Aggregate over problems; outputs are keyed by str(alpha).
        rows = [po["outputs"][str(a)]["summary"] for po in full_outputs]
        mt = sum(r["mon_total"] for r in rows) / len(rows)
        mr = sum(r["mon_real"] for r in rows) / len(rows)
        pl = sum(r["plan_total"] for r in rows) / len(rows)
        pq = sum(r["pqs"] for r in rows) / len(rows)
        ln = sum(r["len"] for r in rows) / len(rows)
        coll_pct = sum(1 for r in rows if r["collapsed"]) / len(rows) * 100
        print(f"{a:>6.2f} {mt:>10.1f} {mr:>10.1f} {pl:>6.1f} {pq:>6.3f} {ln:>8.0f} {coll_pct:>9.0f}%")
    print()


if __name__ == "__main__":
    main()