# v2/scripts/09_steering_sweep.py
# Uploaded via huggingface_hub by JulianHJR (revision e53f10b, verified).
"""
Stage 4: Steering sweep (Apr 2026 update).
Changes vs v1:
- Sweep restricted to α ∈ [0, 1] (no over-suppression / amplification)
- 2 versions (v1_raw, v_pca_subspace) instead of 4
- --save_texts default True (so we always have CoT texts for inspection)
- --joint flag enables anti-leak coupling steering
- Robust collapse detection (ngram-based, length-relative)
- Real-reflection vs filler distinction in monitoring counts
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, TEST_MATH_PATH,
PLAN_V1_RAW, PLAN_V_PCA_SUBSPACE,
MON_V1_RAW, MON_V_PCA_SUBSPACE,
RESULTS_DIR,
)
from configs.model import (
MODEL_CONFIG, ALPHA_SWEEP, GEN_CONFIG_FAST,
ANTI_LEAK_BETA,
)
from src.utils import setup_logger, read_jsonl, append_jsonl, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
from src.detectors import (
BehaviorDetector, compute_rr,
count_real_monitoring, is_collapsed,
)
from src.planning_quality import compute_pqs
from src.steering import (
ResidualSteerer, JointResidualSteerer,
build_force_prompt, is_neutral_alpha,
)
from src.directions import load_directions
# Append-only JSONL log of every sweep experiment; also the source of
# resume keys when --resume is passed (see load_completed_keys).
SWEEP_LOG = RESULTS_DIR / "sweep_log.jsonl"
def get_direction_paths():
    """Map each behavior dimension to its direction-version artifact paths."""
    planning_paths = {
        "v1_raw": PLAN_V1_RAW,
        "v_pca_subspace": PLAN_V_PCA_SUBSPACE,
    }
    monitoring_paths = {
        "v1_raw": MON_V1_RAW,
        "v_pca_subspace": MON_V_PCA_SUBSPACE,
    }
    return {"planning": planning_paths, "monitoring": monitoring_paths}
def make_config_key(dim, version, alpha, idx, joint=False):
    """Build the stable identifier for one sweep experiment (resume key)."""
    if alpha is None:
        alpha_part = "NA"
    else:
        alpha_part = f"{alpha}"
    suffix = "_J" if joint else ""
    return "|".join([dim, version, f"alpha{alpha_part}", f"idx{idx}"]) + suffix
def load_completed_keys(log_path: Path):
done = set()
if not log_path.exists():
return done
import json as _json
with open(log_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
obj = _json.loads(line)
done.add(make_config_key(
obj.get("dim"), obj.get("version"),
obj.get("alpha"), obj.get("idx"),
joint=obj.get("joint", False),
))
except Exception:
pass
return done
def evaluate_cot(text, base_text, mon_det, plan_det):
    """Score one CoT: detector counts, real-vs-filler monitoring, PQS, collapse."""
    mon_total = mon_det.detect(text)["total"]
    plan_total = plan_det.detect(text)["total"]
    reflection = count_real_monitoring(text)
    quality = compute_pqs(text)
    collapse = is_collapsed(text, base_text=base_text)
    metrics = {
        "mon_count": mon_total,
        "plan_count": plan_total,
        "mon_real": reflection["real_reflection"],
        "mon_filler": reflection["filler_only"],
        "mon_ambiguous": reflection["ambiguous"],
        "pqs": quality,
        "collapsed": collapse["collapsed"],
        "collapse_reason": collapse["reason"],
        "ngram_repetition": collapse["ngram_repetition"],
        "length_ratio": collapse["length_ratio"],
        "len_chars": len(text),
    }
    return metrics
def main():
    """Run the Stage-4 steering sweep: baselines, alpha-sweep, force-prompt controls.

    One JSONL record per experiment is appended to SWEEP_LOG; with --resume,
    any (dim, version, alpha, idx, joint) key already logged is skipped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_test", type=int, default=30)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--max_new_tokens", type=int, default=GEN_CONFIG_FAST["max_new_tokens"])
    parser.add_argument("--skip_force_prompt", action="store_true")
    parser.add_argument("--only_dim", choices=["planning", "monitoring", "both"], default="both")
    parser.add_argument("--only_versions", nargs="+", default=None,
                        help="Subset of direction versions, e.g. v_pca_subspace")
    # Texts are saved by default (per module docstring); --no_save_texts opts out.
    parser.add_argument("--save_texts", action="store_true", default=True,
                        help="Save full CoT text in log (default: True)")
    parser.add_argument("--no_save_texts", dest="save_texts", action="store_false")
    parser.add_argument("--joint", action="store_true",
                        help="Enable anti-leak joint steering (suppress both dims together)")
    parser.add_argument("--anti_leak_beta", type=float, default=ANTI_LEAK_BETA)
    args = parser.parse_args()
    ensure_dirs()
    log = setup_logger("09_sweep", LOGS_DIR / "09_sweep.log")
    problems = read_jsonl(TEST_MATH_PATH)[: args.n_test]
    log.info(f"Test problems: {len(problems)}")
    log.info(f"Joint steering (anti-leak): {args.joint}, beta={args.anti_leak_beta}")
    log.info(f"Save texts: {args.save_texts}")
    log.info("Loading model...")
    model, tokenizer = load_model_and_tokenizer()
    # Load every (dimension, version) direction artifact up front and report
    # how many layers actually carry signal.
    dir_paths = get_direction_paths()
    loaded_dirs = {dim: {} for dim in dir_paths}
    for dim in dir_paths:
        for version, p in dir_paths[dim].items():
            loaded_dirs[dim][version] = load_directions(p)
            n_layers = len(loaded_dirs[dim][version])
            # An entry is "non-zero" if it is a 1-D vector with non-trivial norm
            # or a non-empty 2-D subspace basis.
            n_nonzero = sum(
                1 for w in loaded_dirs[dim][version].values()
                if (w.dim() == 1 and w.norm() > 1e-8) or
                (w.dim() == 2 and w.shape[0] > 0)
            )
            log.info(f"Loaded {dim}/{version}: {n_layers} layers, {n_nonzero} non-zero")
    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")
    completed = load_completed_keys(SWEEP_LOG) if args.resume else set()
    log.info(f"Resume: {len(completed)} experiments already logged")
    # Baselines (alpha = 1.0)
    log.info("Computing baselines (alpha=1, NEW semantics: no steering)...")
    baselines = {}
    for prob in tqdm(problems, desc="baselines"):
        prompt = build_thinking_prompt(tokenizer, prob["problem"], enable_thinking=True)
        try:
            text = generate(model, tokenizer, prompt, max_new_tokens=args.max_new_tokens)
        except Exception as e:
            log.error(f"baseline idx={prob.get('idx')} failed: {e}")
            continue
        # A baseline is evaluated against itself, so collapse metrics are neutral.
        ev = evaluate_cot(text, base_text=text, mon_det=mon_det, plan_det=plan_det)
        baselines[prob["idx"]] = {"text": text, "prompt": prompt, **ev}
        cleanup_memory()
    log.info(f"Baselines done. {len(baselines)} OK")
    # Sweep
    dimensions = ["planning", "monitoring"] if args.only_dim == "both" else [args.only_dim]
    versions_to_use = args.only_versions or ["v1_raw", "v_pca_subspace"]
    total_runs = len(dimensions) * len(versions_to_use) * len(ALPHA_SWEEP) * len(problems)
    log.info(f"Total sweep runs: {total_runs}")
    for dim in dimensions:
        # The non-target dimension feeds the anti-leak coupling when --joint is set.
        other_dim = "monitoring" if dim == "planning" else "planning"
        for version in versions_to_use:
            target_dirs = loaded_dirs[dim][version]
            other_dirs = loaded_dirs[other_dim][version]
            # Sanity check: skip if all directions zero
            def _has_signal(dirs):
                # True if any layer carries a usable direction (vector or basis).
                for w in dirs.values():
                    if w.dim() == 1 and w.norm() > 1e-6:
                        return True
                    if w.dim() == 2 and w.shape[0] > 0:
                        return True
                return False
            if not _has_signal(target_dirs):
                log.warning(f"{dim}/{version}: all zero, skipping")
                continue
            for alpha in ALPHA_SWEEP:
                desc = f"{dim[:4]}/{version}/α={alpha:.2f}{' J' if args.joint else ''}"
                for prob in tqdm(problems, desc=desc, leave=False):
                    key = make_config_key(dim, version, alpha, prob["idx"], joint=args.joint)
                    if key in completed:
                        continue
                    base = baselines.get(prob["idx"])
                    if base is None:
                        continue
                    if is_neutral_alpha(alpha):
                        # Neutral alpha: reuse the baseline text instead of regenerating.
                        steered_text = base["text"]
                    else:
                        if args.joint:
                            steerer = JointResidualSteerer(
                                model, target_dirs, other_dirs,
                                alpha=alpha, beta=args.anti_leak_beta,
                            )
                        else:
                            steerer = ResidualSteerer(model, target_dirs, alpha=alpha)
                        steerer.start()
                        # NOTE(review): stop() is not wrapped in try/finally, so a
                        # non-Exception interrupt (e.g. KeyboardInterrupt) would
                        # leave the steering hooks attached — confirm acceptable.
                        try:
                            steered_text = generate(
                                model, tokenizer, base["prompt"],
                                max_new_tokens=args.max_new_tokens,
                            )
                        except Exception as e:
                            log.error(f"{key}: generation failed: {e}")
                            steerer.stop()
                            continue
                        steerer.stop()
                    ev = evaluate_cot(steered_text, base_text=base["text"],
                                      mon_det=mon_det, plan_det=plan_det)
                    rec = {
                        "dim": dim,
                        "version": version,
                        "alpha": alpha,
                        "joint": args.joint,
                        "beta": args.anti_leak_beta if args.joint else None,
                        "idx": prob["idx"],
                        "base_mon": base["mon_count"],
                        "base_plan": base["plan_count"],
                        "base_mon_real": base["mon_real"],
                        "base_pqs": base["pqs"]["pqs"],
                        "base_len": base["len_chars"],
                        "steered_mon": ev["mon_count"],
                        "steered_plan": ev["plan_count"],
                        "steered_mon_real": ev["mon_real"],
                        "steered_mon_filler": ev["mon_filler"],
                        "steered_pqs": ev["pqs"]["pqs"],
                        "steered_q1": ev["pqs"]["q1_structural_depth"],
                        "steered_q2": ev["pqs"]["q2_strategy_diversity"],
                        "steered_q3": ev["pqs"]["q3_long_range_coherence"],
                        "steered_q4": ev["pqs"]["q4_premature_execution"],
                        "steered_len": ev["len_chars"],
                        "collapsed": ev["collapsed"],
                        "collapse_reason": ev["collapse_reason"],
                        "ngram_repetition": ev["ngram_repetition"],
                        "length_ratio": ev["length_ratio"],
                        "steered_text": steered_text if args.save_texts else None,
                    }
                    # Use real_reflection for monitoring RR (excludes filler)
                    if dim == "monitoring":
                        rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"])
                        rec["rr_total"] = compute_rr(rec["base_mon"], rec["steered_mon"])
                    else:
                        rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"])
                    append_jsonl(rec, SWEEP_LOG)
                    cleanup_memory()
    # Force-prompt baseline
    if not args.skip_force_prompt:
        log.info("Force-prompt baselines...")
        for dim in dimensions:
            for mode in ["suppress", "enhance"]:
                desc = f"force_{mode}/{dim[:4]}"
                version_name = f"force_{mode}"
                for prob in tqdm(problems, desc=desc, leave=False):
                    key = make_config_key(dim, version_name, None, prob['idx'], joint=False)
                    if key in completed:
                        continue
                    sys_prompt = build_force_prompt(
                        MODEL_CONFIG["default_system_prompt"], dim, mode
                    )
                    prompt = build_thinking_prompt(
                        tokenizer, prob["problem"],
                        system_prompt=sys_prompt, enable_thinking=True,
                    )
                    try:
                        text = generate(model, tokenizer, prompt,
                                        max_new_tokens=args.max_new_tokens)
                    except Exception as e:
                        log.error(f"{key}: failed: {e}")
                        continue
                    # NOTE(review): the baseline check happens after generation,
                    # so a missing baseline wastes one generation — harmless but
                    # could be hoisted above the generate call.
                    base = baselines.get(prob["idx"])
                    if base is None:
                        continue
                    ev = evaluate_cot(text, base_text=base["text"],
                                      mon_det=mon_det, plan_det=plan_det)
                    rec = {
                        "dim": dim,
                        "version": version_name,
                        "alpha": None,
                        "joint": False,
                        "idx": prob["idx"],
                        "base_mon": base["mon_count"],
                        "base_plan": base["plan_count"],
                        "base_mon_real": base["mon_real"],
                        "base_pqs": base["pqs"]["pqs"],
                        "base_len": base["len_chars"],
                        "steered_mon": ev["mon_count"],
                        "steered_plan": ev["plan_count"],
                        "steered_mon_real": ev["mon_real"],
                        "steered_mon_filler": ev["mon_filler"],
                        "steered_pqs": ev["pqs"]["pqs"],
                        "steered_q1": ev["pqs"]["q1_structural_depth"],
                        "steered_q2": ev["pqs"]["q2_strategy_diversity"],
                        "steered_q3": ev["pqs"]["q3_long_range_coherence"],
                        "steered_q4": ev["pqs"]["q4_premature_execution"],
                        "steered_len": ev["len_chars"],
                        "collapsed": ev["collapsed"],
                        "collapse_reason": ev["collapse_reason"],
                        "ngram_repetition": ev["ngram_repetition"],
                        "length_ratio": ev["length_ratio"],
                        "steered_text": text if args.save_texts else None,
                    }
                    if dim == "monitoring":
                        rec["rr"] = compute_rr(rec["base_mon_real"], rec["steered_mon_real"])
                    else:
                        rec["rr"] = compute_rr(rec["base_plan"], rec["steered_plan"])
                    append_jsonl(rec, SWEEP_LOG)
                    cleanup_memory()
    log.info(f"Sweep complete. Log: {SWEEP_LOG}")
# Script entry point.
if __name__ == "__main__":
    main()