""" Downstream evaluation: pass@1 on MATH-500 holdout, AIME-24, GPQA-D. For each of: - baseline (alpha=1, no steering — NEW SEMANTICS) - plan_alpha_0 (full planning suppression) - mon_alpha_0 (full monitoring suppression) generate answers and grade. Grading is lenient numeric match for MATH / AIME; substring for GPQA. Output: results/downstream_accuracy.json """ import sys import argparse import re from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import torch from tqdm import tqdm from configs.paths import ( ensure_dirs, LOGS_DIR, TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH, PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE, RESULTS_DIR, DOWNSTREAM_ACC_JSON, ) from configs.model import GEN_CONFIG_FAST from src.utils import setup_logger, read_jsonl, write_json, cleanup_memory from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate from src.steering import ResidualSteerer, is_neutral_alpha from src.directions import load_directions def extract_boxed(text): """Extract LaTeX \\boxed{...} answer.""" match = re.search(r"\\boxed\{([^}]*)\}", text) if match: return match.group(1).strip() return None def extract_final_answer(text): """Fallback: last number-like token.""" boxed = extract_boxed(text) if boxed: return boxed # Try "Final Answer: X" or "answer is X" m = re.search(r"(?i)final\s+answer[:\s]+(.+?)(?:\n|\Z)", text) if m: return m.group(1).strip().rstrip(".") # Try last number m = re.findall(r"-?\d+\.?\d*", text[-500:]) if m: return m[-1] return "" def normalize_numeric(s): s = s.strip().replace(",", "") s = re.sub(r"\\frac\{(\-?\d+)\}\{(\-?\d+)\}", r"\1/\2", s) return s def grade_numeric(pred, gold): pred_n = normalize_numeric(str(pred)) gold_n = normalize_numeric(str(gold)) if pred_n == gold_n: return True try: return abs(float(pred_n) - float(gold_n)) < 1e-4 except Exception: return False def grade_substring(pred, gold): pred = str(pred).strip().lower() gold = str(gold).strip().lower() return gold in pred or pred in gold def _mcnemar_pvalue(b: int, c: int): """ Exact two-sided McNemar test p-value. Tests H0: P(baseline correct, steered wrong) = P(baseline wrong, steered correct) i.e. that the two configs have equal accuracy on this paired test. b = # samples where baseline is RIGHT but steered is WRONG (regressions) c = # samples where baseline is WRONG but steered is RIGHT (recoveries) Returns float p-value, or None if no discordant pairs (test undefined). Implementation: under H0, b ~ Binomial(n=b+c, p=0.5). Two-sided exact test. """ n = b + c if n == 0: return None # no discordant pairs — test is undefined # Two-sided exact: p = 2 * P(X <= min(b, c) | n, 0.5), capped at 1.0 # Compute Binomial CDF without scipy (small n typical): from math import comb k = min(b, c) cdf = sum(comb(n, i) for i in range(k + 1)) / (2 ** n) p = min(2 * cdf, 1.0) return float(p) def run_config(model, tokenizer, test_set, config_name, directions=None, alpha=1.0, grader="numeric", max_new_tokens=2048): """Run one eval config over test_set. Returns {accuracy, n, per_sample}. NEW SEMANTICS: alpha=1.0 is baseline (no steering applied). """ per_sample = [] correct = 0 for prob in tqdm(test_set, desc=config_name, leave=False): prompt = build_thinking_prompt(tokenizer, prob["problem"], enable_thinking=True) # Apply steering only if alpha is NOT the neutral value if directions is not None and not is_neutral_alpha(alpha): steerer = ResidualSteerer(model, directions, alpha=alpha) steerer.start() try: text = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens) finally: steerer.stop() else: text = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens) pred = extract_final_answer(text) gold = prob.get("answer", "") if grader == "numeric": ok = grade_numeric(pred, gold) else: ok = grade_substring(pred, gold) if ok: correct += 1 per_sample.append({ "idx": prob["idx"], "pred": pred, "gold": gold, "correct": bool(ok), }) cleanup_memory() return { "accuracy": correct / max(len(test_set), 1), "correct": correct, "n": len(test_set), "per_sample": per_sample, } def main(): parser = argparse.ArgumentParser() parser.add_argument("--configs", nargs="+", default=["baseline", "plan_alpha_0", "mon_alpha_0"], help="NEW SEMANTICS: alpha=1 baseline; alpha=0 max suppress; " "alpha=2 amplify. e.g. plan_alpha_0 = max planning suppress.") parser.add_argument("--max_new_tokens", type=int, default=2048) parser.add_argument("--resume", action="store_true") args = parser.parse_args() ensure_dirs() log = setup_logger("12_downstream", LOGS_DIR / "12_downstream.log") if args.resume and DOWNSTREAM_ACC_JSON.exists(): log.info(f"Downstream results already exist: {DOWNSTREAM_ACC_JSON}") return # Load test sets test_sets = {} if TEST_MATH_PATH.exists(): test_sets["MATH-500-holdout"] = (read_jsonl(TEST_MATH_PATH), "numeric") if TEST_AIME_PATH.exists(): test_sets["AIME-24"] = (read_jsonl(TEST_AIME_PATH), "numeric") if TEST_GPQA_PATH.exists(): test_sets["GPQA-D"] = (read_jsonl(TEST_GPQA_PATH), "substring") log.info(f"Test sets: {list(test_sets.keys())}") for name, (ds, _) in test_sets.items(): log.info(f" {name}: {len(ds)} problems") log.info("Loading model...") model, tokenizer = load_model_and_tokenizer() plan_dirs = load_directions(PLAN_V_PCA_SUBSPACE) mon_dirs = load_directions(MON_V_PCA_SUBSPACE) results = {} for config_name in args.configs: log.info(f"=== Config: {config_name} ===") if config_name == "baseline": directions, alpha = None, 1.0 # NEW SEMANTICS: alpha=1 means no steering elif config_name.startswith("plan_alpha_"): alpha = float(config_name.replace("plan_alpha_", "")) directions = plan_dirs elif config_name.startswith("mon_alpha_"): alpha = float(config_name.replace("mon_alpha_", "")) directions = mon_dirs else: log.warning(f"Unknown config: {config_name}, skipping") continue results[config_name] = {} for ts_name, (ts, grader) in test_sets.items(): r = run_config( model, tokenizer, ts, f"{config_name}/{ts_name}", directions=directions, alpha=alpha, grader=grader, max_new_tokens=args.max_new_tokens, ) log.info(f" {ts_name}: {r['correct']}/{r['n']} = {r['accuracy']:.3f}") results[config_name][ts_name] = { "accuracy": r["accuracy"], "correct": r["correct"], "n": r["n"], "per_sample": r["per_sample"], } # ========================================================= # Compute accuracy-drop statistics vs baseline # ========================================================= if "baseline" in results: log.info("=" * 60) log.info("Computing per-config accuracy drop vs baseline...") for config_name in list(results.keys()): if config_name == "baseline": continue for ts_name in list(results[config_name].keys()): base = results["baseline"].get(ts_name) cur = results[config_name].get(ts_name) if base is None or cur is None: continue # Build per-sample correctness lookup keyed by problem idx base_map = {p["idx"]: p["correct"] for p in base["per_sample"]} cur_map = {p["idx"]: p["correct"] for p in cur["per_sample"]} common_idx = set(base_map) & set(cur_map) n_common = len(common_idx) # Discordant pairs for McNemar # b: baseline correct, steered wrong ("regressions") # c: baseline wrong, steered correct ("recoveries") b = sum(1 for i in common_idx if base_map[i] and not cur_map[i]) c = sum(1 for i in common_idx if (not base_map[i]) and cur_map[i]) # McNemar p-value (exact binomial under H0: b ~ Bin(b+c, 0.5)) mcnemar_p = _mcnemar_pvalue(b, c) base_acc = base["accuracy"] cur_acc = cur["accuracy"] delta = base_acc - cur_acc # positive => accuracy DROPPED rel_drop = (delta / base_acc) if base_acc > 0 else 0.0 cur["vs_baseline"] = { "baseline_accuracy": base_acc, "steered_accuracy": cur_acc, "absolute_drop": delta, "relative_drop": rel_drop, "n_common": n_common, "n_regressions": b, "n_recoveries": c, "mcnemar_p_value": mcnemar_p, "significant_at_0_05": (mcnemar_p is not None and mcnemar_p < 0.05), } log.info( f" {config_name}/{ts_name}: " f"acc {base_acc:.3f} -> {cur_acc:.3f} " f"(Δ={delta:+.3f}, rel={rel_drop:+.1%}) " f"regressions={b} recoveries={c} " f"McNemar p={mcnemar_p if mcnemar_p is None else f'{mcnemar_p:.3g}'}" ) write_json(results, DOWNSTREAM_ACC_JSON) log.info(f"Saved: {DOWNSTREAM_ACC_JSON}") if __name__ == "__main__": main()