| """ |
| Stage 4: Steering sweep (Apr 2026 update). |
| |
| Changes vs v1: |
| - Sweep restricted to α ∈ [0, 1] (no over-suppression / amplification) |
| - 2 versions (v1_raw, v_pca_subspace) instead of 4 |
| - --save_texts default True (so we always have CoT texts for inspection) |
| - --joint flag enables anti-leak coupling steering |
| - Robust collapse detection (ngram-based, length-relative) |
| - Real-reflection vs filler distinction in monitoring counts |
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, TEST_MATH_PATH, |
| PLAN_V1_RAW, PLAN_V_PCA_SUBSPACE, |
| MON_V1_RAW, MON_V_PCA_SUBSPACE, |
| RESULTS_DIR, |
| ) |
| from configs.model import ( |
| MODEL_CONFIG, ALPHA_SWEEP, GEN_CONFIG_FAST, |
| ANTI_LEAK_BETA, |
| ) |
| from src.utils import setup_logger, read_jsonl, append_jsonl, write_json, cleanup_memory |
| from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate |
| from src.detectors import ( |
| BehaviorDetector, compute_rr, |
| count_real_monitoring, is_collapsed, |
| ) |
| from src.planning_quality import compute_pqs |
| from src.steering import ( |
| ResidualSteerer, JointResidualSteerer, |
| build_force_prompt, is_neutral_alpha, |
| ) |
| from src.directions import load_directions |
|
|
|
|
# Append-only JSONL log of every sweep record; also read back as the
# resume ledger (see load_completed_keys).
SWEEP_LOG = RESULTS_DIR / "sweep_log.jsonl"
|
|
|
|
def get_direction_paths():
    """Map each steering dimension to its direction-file paths, keyed by version."""
    planning_paths = {
        "v1_raw": PLAN_V1_RAW,
        "v_pca_subspace": PLAN_V_PCA_SUBSPACE,
    }
    monitoring_paths = {
        "v1_raw": MON_V1_RAW,
        "v_pca_subspace": MON_V_PCA_SUBSPACE,
    }
    return {"planning": planning_paths, "monitoring": monitoring_paths}
|
|
|
|
def make_config_key(dim, version, alpha, idx, joint=False):
    """Stable resume key."""
    alpha_part = f"alpha{'NA' if alpha is None else alpha}"
    key = "|".join([dim, version, alpha_part, f"idx{idx}"])
    if joint:
        key += "_J"
    return key
|
|
|
|
def load_completed_keys(log_path: Path):
    """Return the set of resume keys for every record already in the sweep log.

    Malformed or blank lines are skipped silently: the log is best-effort
    append-only, and a corrupt line just means that experiment reruns.
    """
    if not log_path.exists():
        return set()
    import json as _json
    completed = set()
    with open(log_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                rec = _json.loads(raw)
                key = make_config_key(
                    rec.get("dim"), rec.get("version"),
                    rec.get("alpha"), rec.get("idx"),
                    joint=rec.get("joint", False),
                )
            except Exception:
                continue
            completed.add(key)
    return completed
|
|
|
|
def evaluate_cot(text, base_text, mon_det, plan_det):
    """Score one CoT: behavior counts, reflection split, PQS, collapse flags."""
    mon_total = mon_det.detect(text)["total"]
    plan_total = plan_det.detect(text)["total"]
    reflection = count_real_monitoring(text)
    quality = compute_pqs(text)
    # Collapse detection is length-relative, so it needs the baseline text.
    collapse = is_collapsed(text, base_text=base_text)
    metrics = {
        "mon_count": mon_total,
        "plan_count": plan_total,
        "mon_real": reflection["real_reflection"],
        "mon_filler": reflection["filler_only"],
        "mon_ambiguous": reflection["ambiguous"],
        "pqs": quality,
        "collapsed": collapse["collapsed"],
        "collapse_reason": collapse["reason"],
        "ngram_repetition": collapse["ngram_repetition"],
        "length_ratio": collapse["length_ratio"],
        "len_chars": len(text),
    }
    return metrics
|
|
|
|
def main():
    """Run the Stage-4 steering sweep.

    Pipeline:
      1. Parse CLI options; set up dirs and logging.
      2. Load model/tokenizer, steering directions for both dimensions
         (planning / monitoring), and the behavior detectors.
      3. Generate unsteered baseline CoTs (alpha=1 is the neutral point
         under the new semantics, so baselines need no hooks).
      4. Sweep (dim, version, alpha) x problems, appending one JSONL
         record per run to SWEEP_LOG. ``--resume`` skips keys already
         present in the log.
      5. Unless ``--skip_force_prompt``, also run prompt-based
         suppress/enhance baselines for comparison with activation steering.

    Fixes vs previous revision:
      - steerer.stop() now runs in a ``finally`` block, so residual hooks
        are detached even on KeyboardInterrupt or other non-Exception
        escapes (previously they could stay attached and silently steer
        every later generation).
      - Force-prompt section checks for a baseline BEFORE generating,
        instead of discarding a full generation when the baseline is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_test", type=int, default=30)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--max_new_tokens", type=int, default=GEN_CONFIG_FAST["max_new_tokens"])
    parser.add_argument("--skip_force_prompt", action="store_true")
    parser.add_argument("--only_dim", choices=["planning", "monitoring", "both"], default="both")
    parser.add_argument("--only_versions", nargs="+", default=None,
                        help="Subset of direction versions, e.g. v_pca_subspace")
    parser.add_argument("--save_texts", action="store_true", default=True,
                        help="Save full CoT text in log (default: True)")
    parser.add_argument("--no_save_texts", dest="save_texts", action="store_false")
    parser.add_argument("--joint", action="store_true",
                        help="Enable anti-leak joint steering (suppress both dims together)")
    parser.add_argument("--anti_leak_beta", type=float, default=ANTI_LEAK_BETA)
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("09_sweep", LOGS_DIR / "09_sweep.log")

    problems = read_jsonl(TEST_MATH_PATH)[: args.n_test]
    log.info(f"Test problems: {len(problems)}")
    log.info(f"Joint steering (anti-leak): {args.joint}, beta={args.anti_leak_beta}")
    log.info(f"Save texts: {args.save_texts}")

    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()

    # Load direction tensors per dimension/version and log how many layers
    # carry signal. A 1-D entry is a single steering vector (non-zero norm);
    # a 2-D entry is presumably a PCA subspace basis (rows = components) --
    # zero rows means no signal for that layer.
    dir_paths = get_direction_paths()
    loaded_dirs = {dim: {} for dim in dir_paths}
    for dim in dir_paths:
        for version, p in dir_paths[dim].items():
            loaded_dirs[dim][version] = load_directions(p)
            n_layers = len(loaded_dirs[dim][version])
            n_nonzero = sum(
                1 for w in loaded_dirs[dim][version].values()
                if (w.dim() == 1 and w.norm() > 1e-8) or
                   (w.dim() == 2 and w.shape[0] > 0)
            )
            log.info(f"Loaded {dim}/{version}: {n_layers} layers, {n_nonzero} non-zero")

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    completed = load_completed_keys(SWEEP_LOG) if args.resume else set()
    log.info(f"Resume: {len(completed)} experiments already logged")

    # ---- Baselines (no hooks attached) -----------------------------------
    log.info("Computing baselines (alpha=1, NEW semantics: no steering)...")
    baselines = {}
    for prob in tqdm(problems, desc="baselines"):
        prompt = build_thinking_prompt(tokenizer, prob["problem"], enable_thinking=True)
        try:
            text = generate(model, tokenizer, prompt, max_new_tokens=args.max_new_tokens)
        except Exception as e:
            log.error(f"baseline idx={prob.get('idx')} failed: {e}")
            continue
        # base_text=text: a baseline is compared against itself, so the
        # length-ratio collapse criterion cannot fire here.
        ev = evaluate_cot(text, base_text=text, mon_det=mon_det, plan_det=plan_det)
        baselines[prob["idx"]] = {"text": text, "prompt": prompt, **ev}
        cleanup_memory()
    log.info(f"Baselines done. {len(baselines)} OK")

    # ---- Main sweep ------------------------------------------------------
    dimensions = ["planning", "monitoring"] if args.only_dim == "both" else [args.only_dim]
    versions_to_use = args.only_versions or ["v1_raw", "v_pca_subspace"]

    # Upper bound for logging only; skipped/zero-signal runs are not subtracted.
    total_runs = len(dimensions) * len(versions_to_use) * len(ALPHA_SWEEP) * len(problems)
    log.info(f"Total sweep runs: {total_runs}")

    for dim in dimensions:
        other_dim = "monitoring" if dim == "planning" else "planning"
        for version in versions_to_use:
            target_dirs = loaded_dirs[dim][version]
            other_dirs = loaded_dirs[other_dim][version]

            def _has_signal(dirs):
                """True if any layer carries a usable direction (vector or subspace)."""
                for w in dirs.values():
                    if w.dim() == 1 and w.norm() > 1e-6:
                        return True
                    if w.dim() == 2 and w.shape[0] > 0:
                        return True
                return False
            if not _has_signal(target_dirs):
                log.warning(f"{dim}/{version}: all zero, skipping")
                continue

            for alpha in ALPHA_SWEEP:
                desc = f"{dim[:4]}/{version}/α={alpha:.2f}{' J' if args.joint else ''}"
                for prob in tqdm(problems, desc=desc, leave=False):
                    key = make_config_key(dim, version, alpha, prob["idx"], joint=args.joint)
                    if key in completed:
                        continue

                    base = baselines.get(prob["idx"])
                    if base is None:
                        # Baseline generation failed; nothing to compare against.
                        continue

                    if is_neutral_alpha(alpha):
                        # Neutral alpha == no steering: reuse the baseline text
                        # rather than paying for an identical generation.
                        steered_text = base["text"]
                    else:
                        if args.joint:
                            steerer = JointResidualSteerer(
                                model, target_dirs, other_dirs,
                                alpha=alpha, beta=args.anti_leak_beta,
                            )
                        else:
                            steerer = ResidualSteerer(model, target_dirs, alpha=alpha)
                        steerer.start()
                        try:
                            steered_text = generate(
                                model, tokenizer, base["prompt"],
                                max_new_tokens=args.max_new_tokens,
                            )
                        except Exception as e:
                            log.error(f"{key}: generation failed: {e}")
                            continue
                        finally:
                            # Always detach the steering hooks -- even on
                            # KeyboardInterrupt -- so later generations are
                            # not silently steered.
                            steerer.stop()

                    ev = evaluate_cot(steered_text, base_text=base["text"],
                                      mon_det=mon_det, plan_det=plan_det)

                    rec = {
                        "dim": dim,
                        "version": version,
                        "alpha": alpha,
                        "joint": args.joint,
                        "beta": args.anti_leak_beta if args.joint else None,
                        "idx": prob["idx"],
                        "base_mon": base["mon_count"],
                        "base_plan": base["plan_count"],
                        "base_mon_real": base["mon_real"],
                        "base_pqs": base["pqs"]["pqs"],
                        "base_len": base["len_chars"],
                        "steered_mon": ev["mon_count"],
                        "steered_plan": ev["plan_count"],
                        "steered_mon_real": ev["mon_real"],
                        "steered_mon_filler": ev["mon_filler"],
                        "steered_pqs": ev["pqs"]["pqs"],
                        "steered_q1": ev["pqs"]["q1_structural_depth"],
                        "steered_q2": ev["pqs"]["q2_strategy_diversity"],
                        "steered_q3": ev["pqs"]["q3_long_range_coherence"],
                        "steered_q4": ev["pqs"]["q4_premature_execution"],
                        "steered_len": ev["len_chars"],
                        "collapsed": ev["collapsed"],
                        "collapse_reason": ev["collapse_reason"],
                        "ngram_repetition": ev["ngram_repetition"],
                        "length_ratio": ev["length_ratio"],
                        "steered_text": steered_text if args.save_texts else None,
                    }
                    # Reduction rate: for monitoring we track both real-reflection
                    # and total counts; planning only has the total.
                    if dim == "monitoring":
                        rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"])
                        rec["rr_total"] = compute_rr(rec["base_mon"], rec["steered_mon"])
                    else:
                        rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"])

                    append_jsonl(rec, SWEEP_LOG)
                    cleanup_memory()

    # ---- Prompt-based control baselines ----------------------------------
    if not args.skip_force_prompt:
        log.info("Force-prompt baselines...")
        for dim in dimensions:
            for mode in ["suppress", "enhance"]:
                desc = f"force_{mode}/{dim[:4]}"
                version_name = f"force_{mode}"
                for prob in tqdm(problems, desc=desc, leave=False):
                    key = make_config_key(dim, version_name, None, prob['idx'], joint=False)
                    if key in completed:
                        continue
                    base = baselines.get(prob["idx"])
                    if base is None:
                        # Check BEFORE generating: without a baseline the
                        # record can't be built, so don't waste a generation.
                        continue
                    sys_prompt = build_force_prompt(
                        MODEL_CONFIG["default_system_prompt"], dim, mode
                    )
                    prompt = build_thinking_prompt(
                        tokenizer, prob["problem"],
                        system_prompt=sys_prompt, enable_thinking=True,
                    )
                    try:
                        text = generate(model, tokenizer, prompt,
                                        max_new_tokens=args.max_new_tokens)
                    except Exception as e:
                        log.error(f"{key}: failed: {e}")
                        continue
                    ev = evaluate_cot(text, base_text=base["text"],
                                      mon_det=mon_det, plan_det=plan_det)
                    rec = {
                        "dim": dim,
                        "version": version_name,
                        "alpha": None,
                        "joint": False,
                        "idx": prob["idx"],
                        "base_mon": base["mon_count"],
                        "base_plan": base["plan_count"],
                        "base_mon_real": base["mon_real"],
                        "base_pqs": base["pqs"]["pqs"],
                        "base_len": base["len_chars"],
                        "steered_mon": ev["mon_count"],
                        "steered_plan": ev["plan_count"],
                        "steered_mon_real": ev["mon_real"],
                        "steered_mon_filler": ev["mon_filler"],
                        "steered_pqs": ev["pqs"]["pqs"],
                        "steered_q1": ev["pqs"]["q1_structural_depth"],
                        "steered_q2": ev["pqs"]["q2_strategy_diversity"],
                        "steered_q3": ev["pqs"]["q3_long_range_coherence"],
                        "steered_q4": ev["pqs"]["q4_premature_execution"],
                        "steered_len": ev["len_chars"],
                        "collapsed": ev["collapsed"],
                        "collapse_reason": ev["collapse_reason"],
                        "ngram_repetition": ev["ngram_repetition"],
                        "length_ratio": ev["length_ratio"],
                        "steered_text": text if args.save_texts else None,
                    }
                    if dim == "monitoring":
                        rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"])
                    else:
                        rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"])
                    append_jsonl(rec, SWEEP_LOG)
                    cleanup_memory()

    log.info(f"Sweep complete. Log: {SWEEP_LOG}")
|
|
|
|
# Script entry point: run the full sweep when executed directly.
if __name__ == "__main__":
    main()
|
|