v2 / scripts /10_infer.py
JulianHJR's picture
Upload folder using huggingface_hub
e53f10b verified
"""
Standalone inference tool (Apr 2026 update).
Given a math problem (or a built-in default problem), generate one output per alpha value.
Default alphas: 1.0, 0.5, 0.25, 0.0 (NEW SEMANTICS: 1=baseline, 0=full suppress).
Default direction version: v_pca_subspace (k-D subspace, more robust than v1).
Usage:
# Single dim, multi-alpha
python scripts/10_infer.py --dim planning --alphas 1.0 0.5 0.0
# With anti-leak joint steering
python scripts/10_infer.py --dim planning --joint --alphas 1.0 0.5 0.0
# Use v1_raw for comparison
python scripts/10_infer.py --dim planning --version v1_raw --alphas 1.0 0.0
This script is also called by runall.sh as a sanity-check on
2 sample problems × {planning, monitoring} × {alpha=1, alpha=0}.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from configs.paths import (
ensure_dirs, LOGS_DIR,
PLAN_V1_RAW, PLAN_V_PCA_SUBSPACE,
MON_V1_RAW, MON_V_PCA_SUBSPACE,
RESULTS_DIR,
)
from configs.model import GEN_CONFIG_FAST, ANTI_LEAK_BETA
from src.utils import setup_logger, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
from src.detectors import BehaviorDetector, count_real_monitoring, is_collapsed
from src.planning_quality import compute_pqs
from src.steering import (
ResidualSteerer, JointResidualSteerer,
is_neutral_alpha,
)
from src.directions import load_directions
# Lookup table: steering dimension -> direction version -> path constant
# (imported from configs.paths). main() indexes it as
# DIRECTION_PATHS[args.dim][args.version] and hands the result to
# load_directions(); the keys mirror the --dim and --version CLI choices.
DIRECTION_PATHS = {
"planning": {
"v1_raw": PLAN_V1_RAW,
"v_pca_subspace": PLAN_V_PCA_SUBSPACE,
},
"monitoring": {
"v1_raw": MON_V1_RAW,
"v_pca_subspace": MON_V_PCA_SUBSPACE,
},
}
# Built-in sample problems. main() runs all of them with --auto_problems and
# only the first one when neither --problem nor --problem_file is given.
DEFAULT_PROBLEMS = [
"Find the smallest positive integer n such that n^2 + n + 41 is composite.",
"If $\\sin x + \\cos x = \\frac{1}{5}$ and $0 \\le x < \\pi$, find $\\tan x$.",
]
def run_inference(model, tokenizer, prompt, target_dirs, other_dirs,
                  alpha, max_new_tokens, joint=False, beta=ANTI_LEAK_BETA):
    """Generate a completion for ``prompt`` at steering strength ``alpha``.

    When ``is_neutral_alpha(alpha)`` is true no steerer is created and the
    model is sampled directly. Otherwise a steerer is built, started, and
    always stopped again once generation finishes or raises.

    Args:
        model, tokenizer: Forwarded unchanged to ``generate``.
        prompt: Prompt forwarded to ``generate``.
        target_dirs: Directions handed to the steerer for the target dimension.
        other_dirs: Directions for the other dimension; only used when
            ``joint`` is truthy and this is not ``None``.
        alpha: Steering strength passed to the steerer.
        max_new_tokens: Generation budget forwarded to ``generate``.
        joint: Request a ``JointResidualSteerer`` instead of a
            ``ResidualSteerer``.
        beta: Extra coefficient passed to ``JointResidualSteerer`` only.

    Returns:
        Whatever ``generate`` returns for this prompt.
    """
    if is_neutral_alpha(alpha):
        return generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)

    use_joint = joint and other_dirs is not None
    if use_joint:
        hook = JointResidualSteerer(model, target_dirs, other_dirs,
                                    alpha=alpha, beta=beta)
    else:
        hook = ResidualSteerer(model, target_dirs, alpha=alpha)

    hook.start()
    try:
        return generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
    finally:
        # Always stop the steerer, even when generation fails, so it does not
        # stay active for the next call on the same model.
        hook.stop()
def _alpha_profile(alpha):
    """Return the profile label shown in the report header for ``alpha``."""
    if alpha is None:
        return "force-prompt baseline"
    if abs(alpha - 1.0) < 1e-5:
        return "BASELINE (no steering, full ability)"
    if abs(alpha - 0.0) < 1e-5:
        return "ZERO ABILITY (full suppression)"
    if 0.0 < alpha < 1.0:
        return f"PARTIAL ABILITY ({alpha*100:.0f}% of native)"
    if alpha > 1.0:
        return f"OVERTHINKER (amplified by {alpha-1.0:+.1f})"
    # Anything left is below zero (or not comparable, e.g. NaN).
    return f"OVER-SUPPRESSED ({alpha:.1f})"


def format_report(problem, alpha, text, base_text, mon_det, plan_det):
    """Build the per-alpha text report and a compact summary for one output.

    Args:
        problem: Accepted for call-site symmetry; not used in the report.
        alpha: Steering strength the text was generated with (may be ``None``).
        text: Generated output to analyse.
        base_text: Reference text passed to ``is_collapsed`` as ``base_text``.
        mon_det, plan_det: Detector objects whose ``detect(text)`` result
            exposes ``"total"`` and ``"by_type"`` entries.

    Returns:
        ``(report, summary)`` where ``report`` is a newline-joined string and
        ``summary`` is a dict with the keys ``alpha``, ``len``, ``mon_total``,
        ``mon_real``, ``plan_total``, ``pqs`` and ``collapsed``.
    """
    monitoring = mon_det.detect(text)
    planning = plan_det.detect(text)
    reflection = count_real_monitoring(text)
    quality = compute_pqs(text)
    collapse = is_collapsed(text, base_text=base_text)
    profile = _alpha_profile(alpha)

    rule = "=" * 70
    report_lines = [rule, f"[α = {alpha}] profile: {profile}", rule]

    # Length / collapse diagnostics.
    report_lines.append(
        f"Length: {len(text)} chars Length-ratio vs base: {collapse['length_ratio']}"
    )
    report_lines.append(
        f"Collapsed: {collapse['collapsed']} reason={collapse['reason']} "
        f"ngram_rep={collapse['ngram_repetition']:.3f}"
    )
    report_lines.append("")

    # Trigger counts from both detectors plus the monitoring breakdown.
    report_lines.extend([
        f"Planning triggers (total): {planning['total']}",
        f" by_subtype: {planning['by_type']}",
        f"Monitoring triggers (total): {monitoring['total']}",
        f" REAL reflection: {reflection['real_reflection']}",
        f" Filler ('wait, 5+3=...'): {reflection['filler_only']}",
        f" Ambiguous: {reflection['ambiguous']}",
        f" by_subtype: {monitoring['by_type']}",
        "",
    ])

    # Planning quality score and its four components.
    report_lines.extend([
        f"Planning Quality Score (PQS): {quality['pqs']:.3f}",
        f" Q1 structural_depth: {quality['q1_structural_depth']:.3f}",
        f" Q2 strategy_diversity: {quality['q2_strategy_diversity']:.3f}",
        f" Q3 long_range_coherence: {quality['q3_long_range_coherence']:.3f}",
        f" Q4 premature_execution: {quality['q4_premature_execution']:.3f} (lower=better)",
        "",
    ])

    # Head and tail excerpts of the generated text.
    report_lines.extend([
        "=== Generated CoT (first 800 chars) ===",
        text[:800],
        "...",
        "=== last 200 chars ===",
        text[-200:],
        "",
    ])

    summary = {
        "alpha": alpha,
        "len": len(text),
        "mon_total": monitoring["total"],
        "mon_real": reflection["real_reflection"],
        "plan_total": planning["total"],
        "pqs": quality["pqs"],
        "collapsed": collapse["collapsed"],
    }
    return "\n".join(report_lines), summary
def _select_problems(args):
    """Choose the problems to run from the parsed CLI flags.

    Precedence: ``--auto_problems`` (all built-ins), then ``--problem``, then
    ``--problem_file`` (whole file, stripped, as a single problem), and
    finally the first built-in problem.
    """
    if args.auto_problems:
        return DEFAULT_PROBLEMS
    if args.problem:
        return [args.problem]
    if args.problem_file:
        return [Path(args.problem_file).read_text(encoding="utf-8").strip()]
    return [DEFAULT_PROBLEMS[0]]


def _print_summary_table(alphas, full_outputs):
    """Print per-alpha metric means (and collapse rate) across all problems."""
    print("\n" + "=" * 70)
    print("SUMMARY TABLE (across all problems)")
    print("=" * 70)
    print(f"{'α':>6} {'mon_total':>10} {'mon_real':>10} {'plan':>6} {'pqs':>6} {'len':>8} {'collapse':>10}")
    for a in alphas:
        # One summary row per problem for this alpha.
        rows = [po["outputs"][str(a)]["summary"] for po in full_outputs]
        n = len(rows)
        mt = sum(r["mon_total"] for r in rows) / n
        mr = sum(r["mon_real"] for r in rows) / n
        pl = sum(r["plan_total"] for r in rows) / n
        pq = sum(r["pqs"] for r in rows) / n
        ln = sum(r["len"] for r in rows) / n
        coll_pct = sum(1 for r in rows if r["collapsed"]) / n * 100
        print(f"{a:>6.2f} {mt:>10.1f} {mr:>10.1f} {pl:>6.1f} {pq:>6.3f} {ln:>8.0f} {coll_pct:>9.0f}%")
    print()


def main():
    """CLI entry point: generate at each requested alpha and report metrics.

    Parses the command line, loads the model and direction files, then for
    every selected problem generates one output per alpha, prints a report
    for each, optionally writes all outputs to ``--save_to`` as JSON, and
    finally prints a summary table averaged over problems.

    Neutral alphas (per ``is_neutral_alpha``) are generated before the other
    alphas of the same problem, regardless of their position in ``--alphas``,
    so that every steered output is compared against the unsteered baseline
    instead of against itself. As a consequence the per-alpha reports are
    printed in generation order; the saved outputs and the summary table keep
    the order given on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dim", choices=["planning", "monitoring"], default="planning")
    parser.add_argument("--version", choices=["v1_raw", "v_pca_subspace"],
                        default="v_pca_subspace")
    parser.add_argument("--problem", type=str, default=None)
    parser.add_argument("--problem_file", type=str, default=None)
    parser.add_argument("--alphas", nargs="+", type=float,
                        default=[1.0, 0.5, 0.25, 0.0])
    parser.add_argument("--max_new_tokens", type=int, default=4096)
    parser.add_argument("--save_to", type=str, default=None)
    parser.add_argument("--joint", action="store_true",
                        help="Enable anti-leak coupling steering")
    parser.add_argument("--beta", type=float, default=ANTI_LEAK_BETA)
    parser.add_argument("--auto_problems", action="store_true",
                        help="Run on built-in default problems (used by runall sanity)")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("10_infer", LOGS_DIR / "10_infer.log")

    problems_to_run = _select_problems(args)

    log.info(f"Dimension: {args.dim} Version: {args.version}")
    log.info(f"Alphas: {args.alphas}")
    log.info(f"Joint anti-leak: {args.joint} beta={args.beta}")
    log.info(f"Problems: {len(problems_to_run)}")

    # Stable sort: neutral alphas move to the front, everything else keeps
    # its relative command-line order.
    generation_order = sorted(args.alphas, key=lambda a: not is_neutral_alpha(a))
    if not any(is_neutral_alpha(a) for a in args.alphas):
        log.info("No neutral alpha in --alphas: each output is compared "
                 "against itself, so length-ratio is not informative.")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    log.info("Loading directions...")
    target_dirs = load_directions(DIRECTION_PATHS[args.dim][args.version])
    other_dim = "monitoring" if args.dim == "planning" else "planning"
    # The other dimension's directions are only needed for joint steering.
    other_dirs = load_directions(DIRECTION_PATHS[other_dim][args.version]) if args.joint else None

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    full_outputs = []
    for prob_idx, problem in enumerate(problems_to_run):
        log.info(f"\n========== Problem {prob_idx+1}/{len(problems_to_run)} ==========")
        log.info(f"Q: {problem[:120]}")
        prompt = build_thinking_prompt(tokenizer, problem, enable_thinking=True)

        baseline_text = None
        results = {}
        for a in generation_order:
            log.info(f"-- α={a} --")
            text = run_inference(model, tokenizer, prompt,
                                 target_dirs, other_dirs,
                                 a, args.max_new_tokens,
                                 joint=args.joint, beta=args.beta)
            if is_neutral_alpha(a):
                baseline_text = text
            # Fall back to the text itself when no baseline is available.
            base_for_eval = baseline_text or text
            report, summary = format_report(problem, a, text, base_for_eval, mon_det, plan_det)
            print(report)
            results[str(a)] = {
                "text": text,
                "summary": summary,
            }
            cleanup_memory()

        # Re-key in command-line order so saved output is independent of the
        # baseline-first generation order.
        prob_outputs = {str(a): results[str(a)] for a in args.alphas}
        full_outputs.append({
            "problem": problem,
            "outputs": prob_outputs,
        })

    if args.save_to:
        save_path = Path(args.save_to)
        # Make sure the destination directory exists so a long run is not
        # lost to a missing folder (harmless if write_json already does this).
        save_path.parent.mkdir(parents=True, exist_ok=True)
        write_json({
            "dim": args.dim, "version": args.version,
            "alphas": args.alphas, "joint": args.joint, "beta": args.beta,
            "problems": full_outputs,
        }, save_path)
        log.info(f"Saved to {args.save_to}")

    _print_summary_table(args.alphas, full_outputs)


if __name__ == "__main__":
    main()