""" Stage 4: Steering sweep (Apr 2026 update). Changes vs v1: - Sweep restricted to α ∈ [0, 1] (no over-suppression / amplification) - 2 versions (v1_raw, v_pca_subspace) instead of 4 - --save_texts default True (so we always have CoT texts for inspection) - --joint flag enables anti-leak coupling steering - Robust collapse detection (ngram-based, length-relative) - Real-reflection vs filler distinction in monitoring counts """ import sys import argparse from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import torch from tqdm import tqdm from configs.paths import ( ensure_dirs, LOGS_DIR, TEST_MATH_PATH, PLAN_V1_RAW, PLAN_V_PCA_SUBSPACE, MON_V1_RAW, MON_V_PCA_SUBSPACE, RESULTS_DIR, ) from configs.model import ( MODEL_CONFIG, ALPHA_SWEEP, GEN_CONFIG_FAST, ANTI_LEAK_BETA, ) from src.utils import setup_logger, read_jsonl, append_jsonl, write_json, cleanup_memory from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate from src.detectors import ( BehaviorDetector, compute_rr, count_real_monitoring, is_collapsed, ) from src.planning_quality import compute_pqs from src.steering import ( ResidualSteerer, JointResidualSteerer, build_force_prompt, is_neutral_alpha, ) from src.directions import load_directions SWEEP_LOG = RESULTS_DIR / "sweep_log.jsonl" def get_direction_paths(): return { "planning": {"v1_raw": PLAN_V1_RAW, "v_pca_subspace": PLAN_V_PCA_SUBSPACE}, "monitoring": {"v1_raw": MON_V1_RAW, "v_pca_subspace": MON_V_PCA_SUBSPACE}, } def make_config_key(dim, version, alpha, idx, joint=False): """Stable resume key.""" a_str = "NA" if alpha is None else f"{alpha}" j_str = "_J" if joint else "" return f"{dim}|{version}|alpha{a_str}|idx{idx}{j_str}" def load_completed_keys(log_path: Path): done = set() if not log_path.exists(): return done import json as _json with open(log_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: obj = _json.loads(line) done.add(make_config_key( obj.get("dim"), obj.get("version"), obj.get("alpha"), obj.get("idx"), joint=obj.get("joint", False), )) except Exception: pass return done def evaluate_cot(text, base_text, mon_det, plan_det): """Compute all metrics for a CoT.""" mon_cnt = mon_det.detect(text)["total"] plan_cnt = plan_det.detect(text)["total"] real_mon = count_real_monitoring(text) pqs = compute_pqs(text) coll = is_collapsed(text, base_text=base_text) return { "mon_count": mon_cnt, "plan_count": plan_cnt, "mon_real": real_mon["real_reflection"], "mon_filler": real_mon["filler_only"], "mon_ambiguous": real_mon["ambiguous"], "pqs": pqs, "collapsed": coll["collapsed"], "collapse_reason": coll["reason"], "ngram_repetition": coll["ngram_repetition"], "length_ratio": coll["length_ratio"], "len_chars": len(text), } def main(): parser = argparse.ArgumentParser() parser.add_argument("--n_test", type=int, default=30) parser.add_argument("--resume", action="store_true") parser.add_argument("--max_new_tokens", type=int, default=GEN_CONFIG_FAST["max_new_tokens"]) parser.add_argument("--skip_force_prompt", action="store_true") parser.add_argument("--only_dim", choices=["planning", "monitoring", "both"], default="both") parser.add_argument("--only_versions", nargs="+", default=None, help="Subset of direction versions, e.g. v_pca_subspace") parser.add_argument("--save_texts", action="store_true", default=True, help="Save full CoT text in log (default: True)") parser.add_argument("--no_save_texts", dest="save_texts", action="store_false") parser.add_argument("--joint", action="store_true", help="Enable anti-leak joint steering (suppress both dims together)") parser.add_argument("--anti_leak_beta", type=float, default=ANTI_LEAK_BETA) args = parser.parse_args() ensure_dirs() log = setup_logger("09_sweep", LOGS_DIR / "09_sweep.log") problems = read_jsonl(TEST_MATH_PATH)[: args.n_test] log.info(f"Test problems: {len(problems)}") log.info(f"Joint steering (anti-leak): {args.joint}, beta={args.anti_leak_beta}") log.info(f"Save texts: {args.save_texts}") log.info("Loading model...") model, tokenizer = load_model_and_tokenizer() dir_paths = get_direction_paths() loaded_dirs = {dim: {} for dim in dir_paths} for dim in dir_paths: for version, p in dir_paths[dim].items(): loaded_dirs[dim][version] = load_directions(p) n_layers = len(loaded_dirs[dim][version]) n_nonzero = sum( 1 for w in loaded_dirs[dim][version].values() if (w.dim() == 1 and w.norm() > 1e-8) or (w.dim() == 2 and w.shape[0] > 0) ) log.info(f"Loaded {dim}/{version}: {n_layers} layers, {n_nonzero} non-zero") mon_det = BehaviorDetector("monitoring") plan_det = BehaviorDetector("planning") completed = load_completed_keys(SWEEP_LOG) if args.resume else set() log.info(f"Resume: {len(completed)} experiments already logged") # Baselines (alpha = 1.0) log.info("Computing baselines (alpha=1, NEW semantics: no steering)...") baselines = {} for prob in tqdm(problems, desc="baselines"): prompt = build_thinking_prompt(tokenizer, prob["problem"], enable_thinking=True) try: text = generate(model, tokenizer, prompt, max_new_tokens=args.max_new_tokens) except Exception as e: log.error(f"baseline idx={prob.get('idx')} failed: {e}") continue ev = evaluate_cot(text, base_text=text, mon_det=mon_det, plan_det=plan_det) baselines[prob["idx"]] = {"text": text, "prompt": prompt, **ev} cleanup_memory() log.info(f"Baselines done. {len(baselines)} OK") # Sweep dimensions = ["planning", "monitoring"] if args.only_dim == "both" else [args.only_dim] versions_to_use = args.only_versions or ["v1_raw", "v_pca_subspace"] total_runs = len(dimensions) * len(versions_to_use) * len(ALPHA_SWEEP) * len(problems) log.info(f"Total sweep runs: {total_runs}") for dim in dimensions: other_dim = "monitoring" if dim == "planning" else "planning" for version in versions_to_use: target_dirs = loaded_dirs[dim][version] other_dirs = loaded_dirs[other_dim][version] # Sanity check: skip if all directions zero def _has_signal(dirs): for w in dirs.values(): if w.dim() == 1 and w.norm() > 1e-6: return True if w.dim() == 2 and w.shape[0] > 0: return True return False if not _has_signal(target_dirs): log.warning(f"{dim}/{version}: all zero, skipping") continue for alpha in ALPHA_SWEEP: desc = f"{dim[:4]}/{version}/α={alpha:.2f}{' J' if args.joint else ''}" for prob in tqdm(problems, desc=desc, leave=False): key = make_config_key(dim, version, alpha, prob["idx"], joint=args.joint) if key in completed: continue base = baselines.get(prob["idx"]) if base is None: continue if is_neutral_alpha(alpha): steered_text = base["text"] else: if args.joint: steerer = JointResidualSteerer( model, target_dirs, other_dirs, alpha=alpha, beta=args.anti_leak_beta, ) else: steerer = ResidualSteerer(model, target_dirs, alpha=alpha) steerer.start() try: steered_text = generate( model, tokenizer, base["prompt"], max_new_tokens=args.max_new_tokens, ) except Exception as e: log.error(f"{key}: generation failed: {e}") steerer.stop() continue steerer.stop() ev = evaluate_cot(steered_text, base_text=base["text"], mon_det=mon_det, plan_det=plan_det) rec = { "dim": dim, "version": version, "alpha": alpha, "joint": args.joint, "beta": args.anti_leak_beta if args.joint else None, "idx": prob["idx"], "base_mon": base["mon_count"], "base_plan": base["plan_count"], "base_mon_real": base["mon_real"], "base_pqs": base["pqs"]["pqs"], "base_len": base["len_chars"], "steered_mon": ev["mon_count"], "steered_plan": ev["plan_count"], "steered_mon_real": ev["mon_real"], "steered_mon_filler": ev["mon_filler"], "steered_pqs": ev["pqs"]["pqs"], "steered_q1": ev["pqs"]["q1_structural_depth"], "steered_q2": ev["pqs"]["q2_strategy_diversity"], "steered_q3": ev["pqs"]["q3_long_range_coherence"], "steered_q4": ev["pqs"]["q4_premature_execution"], "steered_len": ev["len_chars"], "collapsed": ev["collapsed"], "collapse_reason": ev["collapse_reason"], "ngram_repetition": ev["ngram_repetition"], "length_ratio": ev["length_ratio"], "steered_text": steered_text if args.save_texts else None, } # Use real_reflection for monitoring RR (excludes filler) if dim == "monitoring": rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"]) rec["rr_total"] = compute_rr(rec["base_mon"], rec["steered_mon"]) else: rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"]) append_jsonl(rec, SWEEP_LOG) cleanup_memory() # Force-prompt baseline if not args.skip_force_prompt: log.info("Force-prompt baselines...") for dim in dimensions: for mode in ["suppress", "enhance"]: desc = f"force_{mode}/{dim[:4]}" version_name = f"force_{mode}" for prob in tqdm(problems, desc=desc, leave=False): key = make_config_key(dim, version_name, None, prob['idx'], joint=False) if key in completed: continue sys_prompt = build_force_prompt( MODEL_CONFIG["default_system_prompt"], dim, mode ) prompt = build_thinking_prompt( tokenizer, prob["problem"], system_prompt=sys_prompt, enable_thinking=True, ) try: text = generate(model, tokenizer, prompt, max_new_tokens=args.max_new_tokens) except Exception as e: log.error(f"{key}: failed: {e}") continue base = baselines.get(prob["idx"]) if base is None: continue ev = evaluate_cot(text, base_text=base["text"], mon_det=mon_det, plan_det=plan_det) rec = { "dim": dim, "version": version_name, "alpha": None, "joint": False, "idx": prob["idx"], "base_mon": base["mon_count"], "base_plan": base["plan_count"], "base_mon_real": base["mon_real"], "base_pqs": base["pqs"]["pqs"], "base_len": base["len_chars"], "steered_mon": ev["mon_count"], "steered_plan": ev["plan_count"], "steered_mon_real": ev["mon_real"], "steered_mon_filler": ev["mon_filler"], "steered_pqs": ev["pqs"]["pqs"], "steered_q1": ev["pqs"]["q1_structural_depth"], "steered_q2": ev["pqs"]["q2_strategy_diversity"], "steered_q3": ev["pqs"]["q3_long_range_coherence"], "steered_q4": ev["pqs"]["q4_premature_execution"], "steered_len": ev["len_chars"], "collapsed": ev["collapsed"], "collapse_reason": ev["collapse_reason"], "ngram_repetition": ev["ngram_repetition"], "length_ratio": ev["length_ratio"], "steered_text": text if args.save_texts else None, } if dim == "monitoring": rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"]) else: rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"]) append_jsonl(rec, SWEEP_LOG) cleanup_memory() log.info(f"Sweep complete. Log: {SWEEP_LOG}") if __name__ == "__main__": main()