| """M5 — Activation Steering: causally manipulate the shortcut direction. |
| |
| For one checkpoint (default: peak_ood_epoch from summary.json), we: |
| 1. Extract avgpool features for train (H0-H2) + OOD (H4) splits. |
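    2. Identify the dominant shortcut direction `v_s` as the top right singular
       vector of the centred hospital-mean matrix (equivalently, the top
       eigenvector of the unweighted between-hospital scatter) in standardized
       feature space. This is LDA-like, but no within-hospital whitening is applied.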
    3. Sweep α (default {-3, -2, -1, -0.5, 0, +0.5, +1, +2, +3}) and apply
           h' = h + α · σ_align · v_s
       in StandardScaler-normalized feature space (then map back to raw
       features), where σ_align is the std of the standardized train features
       projected onto v_s (so α counts in 'standard deviations of shortcut
       activation').
| 4. Re-classify OOD with the original head. |
    5. Score linear hospital + tumor probes, fitted once on unsteered train
       features, on the steered features and report accuracy curves.
| |
| Strong mechanistic claim if: |
| - tumor-head OOD acc declines monotonically as |α| grows |
| - hospital-probe acc on steered features rises with |α| |
| - tumor-probe acc on steered features stays approximately flat (the |
| *causal* feature isn't aligned with the shortcut direction) |
| |
| Usage |
| ----- |
| python -m experiments.mechinterp_m5_steering \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| [--epoch 50] # default: peak_ood_epoch from summary.json |
        [--max_samples 1000] [--alphas " -3,-2,-1,0,1,2,3 "] [--all_epochs]
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
|
|
| from experiments.mechinterp_m1 import ( |
| register_hooks, extract_features, load_model_from_checkpoint, |
| find_checkpoints, |
| ) |
| from experiments.mechinterp_m4_ablation import ( |
| _select_epoch, _build_loaders, |
| _classifier_logits_from_features, _accuracy, |
| ) |
|
|
|
|
| def _top_lda_direction(X: np.ndarray, hospital_ids: np.ndarray) -> np.ndarray: |
| """Return a unit vector aligned with the dominant between-hospital direction |
| in feature space (LDA-1).""" |
| classes = np.unique(hospital_ids) |
| global_mean = X.mean(axis=0, keepdims=True) |
| means = np.vstack([ |
| X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean |
| for c in classes |
| ]) |
| |
| |
| U, s, Vt = np.linalg.svd(means, full_matrices=False) |
| return Vt[0] |
|
|
|
|
def _make_probe():
    """Return a fresh, unfitted linear probe with the settings shared by all M5 probes."""
    # `multi_class="auto"` is omitted on purpose: it is already the default and
    # the argument is deprecated in recent scikit-learn releases.
    return LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)


def _holdout_split(n: int, frac: float, seed) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(fit_idx, eval_idx)``, a seeded random partition of ``range(n)``.

    ``eval_idx`` holds ``round(frac * n)`` rows, clamped so that both parts are
    non-empty whenever ``n >= 2``; for ``n < 2`` ``eval_idx`` is empty.
    """
    perm = np.random.default_rng(seed).permutation(n)
    if n < 2:
        return perm, perm[:0]
    n_eval = min(max(int(round(frac * n)), 1), n - 1)
    return perm[n_eval:], perm[:n_eval]


def run_steering(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    alphas: List[float] | None = None,
    probe_holdout_frac: float = 0.2,
) -> Dict:
    """Sweep activation steering along the top hospital direction for one checkpoint.

    Loads the checkpoint chosen by ``_select_epoch`` (``epoch=None`` defers to its
    default), extracts ``avgpool`` features for the train and OOD loaders and
    standardizes both with a scaler fitted on train.  ``v_s`` comes from
    ``_top_lda_direction``.  For each α the features are shifted by
    ``α · σ · v_s`` in standardized space, where σ is the std of the train
    projections onto ``v_s``.

    Per α the sweep records:
      * ``head_ood_acc``: the model's own head on the steered OOD features
        (mapped back to raw feature space first);
      * ``tumor_probe``: a tumor probe fitted on all unsteered train features,
        scored on the steered OOD features;
      * ``hospital_probe``: a hospital probe fitted on a
        ``1 - probe_holdout_frac`` share of the unsteered train features and
        scored on the steered held-out train rows.  It is NaN if the fit rows
        contain fewer than two hospitals.

    Args:
        run_dir: Run directory; an optional ``config.json`` supplies ``seed``.
        data_root: Dataset root passed to ``_build_loaders``.
        epoch: Checkpoint epoch, or ``None`` to let ``_select_epoch`` choose.
        max_samples: Train feature budget (OOD uses half of it).
        device: Torch device string for feature extraction and the head.
        alphas: Steering coefficients; defaults to ±{0.5, 1, 2, 3} and 0.
        probe_holdout_frac: Fraction of train rows held out to score the
            hospital probe; must lie strictly between 0 and 1.

    Returns:
        Dict with run metadata, ``v_norm``, ``sigma`` and the per-α ``sweep``.

    Raises:
        ValueError: If ``probe_holdout_frac`` is outside (0, 1), or if the
            train features contain fewer than two hospitals.
    """
    if alphas is None:
        alphas = [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0]
    if not 0.0 < probe_holdout_frac < 1.0:
        raise ValueError(
            f"probe_holdout_frac must be in (0, 1), got {probe_holdout_frac}"
        )

    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print("\n M5 — Activation Steering")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" alphas : {alphas}")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    # The run's configured seed (default 42) is passed to `_build_loaders` and
    # also seeds the hospital-probe hold-out split below.
    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    print(" Extracting features...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    layer = "avgpool"
    X_tr = np.asarray(feats_train[layer])
    X_tr = X_tr.reshape(X_tr.shape[0], -1)
    X_ood = np.asarray(feats_ood[layer])
    X_ood = X_ood.reshape(X_ood.shape[0], -1)
    hosp_train = np.asarray(hosp_train)
    tumor_train = np.asarray(tumor_train)
    tumor_ood = np.asarray(tumor_ood)

    # Steering happens in standardized space; the head later sees raw features.
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    v = _top_lda_direction(X_tr_n, hosp_train)
    # σ makes α dimensionless: α = 1 is one std of train activation along v_s.
    sigma = float(np.std(X_tr_n @ v))
    print(f" Top hospital direction v_s : ‖v‖={np.linalg.norm(v):.3f}, "
          f"σ(X_tr·v)={sigma:.3f}")

    tumor_clf = _make_probe().fit(X_tr_n, tumor_train)

    # The hospital probe can only be scored on hospitals it was trained on.
    # The OOD split holds a hospital absent from training (H4 per the module
    # docstring), so scoring it there yielded NaN (single hospital) or a
    # meaningless number.  Hold out part of the train split instead.
    fit_idx, eval_idx = _holdout_split(len(X_tr_n), probe_holdout_frac, seed)
    X_hold_n, hosp_hold = X_tr_n[eval_idx], hosp_train[eval_idx]
    hosp_clf = None
    if eval_idx.size and np.unique(hosp_train[fit_idx]).size > 1:
        hosp_clf = _make_probe().fit(X_tr_n[fit_idx], hosp_train[fit_idx])
    print(f" Hospital probe: fit on {fit_idx.size} train rows, "
          f"scored on {eval_idx.size} held-out train rows")

    sweep = []
    for alpha in alphas:
        shift = alpha * sigma * v[None, :]
        X_ood_steered_n = X_ood_n + shift
        X_ood_steered = scaler.inverse_transform(X_ood_steered_n)

        logits = _classifier_logits_from_features(model, X_ood_steered, layer, device)
        # float(): keep every metric a plain Python float so json.dumps works.
        head_acc = float(_accuracy(logits, tumor_ood))
        tumor_acc = float(tumor_clf.score(X_ood_steered_n, tumor_ood))
        if hosp_clf is not None:
            hosp_acc = float(hosp_clf.score(X_hold_n + shift, hosp_hold))
        else:
            hosp_acc = float("nan")

        sweep.append({
            "alpha": float(alpha),
            "head_ood_acc": head_acc,
            "hospital_probe": hosp_acc,
            "tumor_probe": tumor_acc,
        })
        hosp_txt = "nan" if np.isnan(hosp_acc) else f"{hosp_acc:.3f}"
        print(f" α={alpha:+.2f} head_ood={head_acc:.3f} "
              f"hosp_probe={hosp_txt:<5} "
              f"tumor_probe={tumor_acc:.3f}")

    return {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "v_norm": float(np.linalg.norm(v)),
        "sigma": sigma,
        "hospital_probe_eval": "heldout_train",
        "probe_holdout_frac": probe_holdout_frac,
        "sweep": sweep,
    }
|
|
|
|
def plot_steering(result: Dict, out_path: Path):
    """Save a two-panel figure of an M5 steering sweep to ``out_path``.

    Left panel: head OOD (H4) accuracy vs α.  Right panel: hospital- and
    tumor-probe accuracy vs α.  ``result`` is the dict returned by
    ``run_steering`` (keys ``sweep``, ``run_id``, ``epoch``, ``layer``).

    Sweep points are sorted by α before plotting, so an unsorted ``--alphas``
    list still draws clean lines.  NaN values are left for matplotlib to skip.
    """
    points = sorted(result["sweep"], key=lambda row: row["alpha"])
    xs = [row["alpha"] for row in points]
    head = np.asarray([row["head_ood_acc"] for row in points], dtype=float)
    hosp = [row["hospital_probe"] for row in points]
    tumor = [row["tumor_probe"] for row in points]

    fig, (ax_head, ax_probe) = plt.subplots(1, 2, figsize=(13, 5))

    # Left: causal effect of steering on the model's own classification head.
    ax_head.plot(xs, head, "k-o", lw=2, ms=7)
    ax_head.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax_head.set_xlabel("Steering coefficient α (in σ-units of shortcut direction)")
    ax_head.set_ylabel("Head OOD (H4) accuracy")
    ax_head.set_title("Causal effect of steering activations along v_s\n"
                      "(monotonic decline as |α| grows = causal evidence)",
                      fontweight="bold", fontsize=10)
    ax_head.grid(alpha=0.3)
    # Start from the historical 0.4-0.85 window, but only ever widen it to
    # cover the data.  A fixed 0.4 floor clipped exactly the strong accuracy
    # drops this panel exists to show.  Finite values only: max() over NaN
    # is order-dependent and a NaN limit makes set_ylim raise.
    finite = head[np.isfinite(head)]
    y_lo, y_hi = 0.4, 0.85
    if finite.size:
        y_lo = max(0.0, min(y_lo, float(finite.min()) - 0.05))
        y_hi = max(y_hi, float(finite.max()) + 0.05)
    ax_head.set_ylim(y_lo, y_hi)

    # Right: how linear probes respond to the same steering.
    ax_probe.plot(xs, hosp, "r-s", lw=2, ms=7, label="Hospital probe (↑ with |α| = good)")
    ax_probe.plot(xs, tumor, "g-^", lw=2, ms=7, label="Tumor probe (flat = causal disjoint)")
    ax_probe.axvline(0, color="gray", ls=":", lw=1, alpha=0.5)
    ax_probe.set_xlabel("Steering coefficient α")
    ax_probe.set_ylabel("Probe accuracy")
    ax_probe.set_title("Probe responses to steering", fontweight="bold", fontsize=10)
    ax_probe.legend(loc="best", fontsize=9)
    ax_probe.grid(alpha=0.3)
    ax_probe.set_ylim(0, 1.05)

    fig.suptitle(f"M5 — Activation Steering: {result['run_id']} "
                 f"• ep{result['epoch']} • layer={result['layer']}",
                 fontsize=11, fontweight="bold")
    fig.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def main():
    """CLI entry point: run the M5 steering sweep for one checkpoint or all of them.

    Single-checkpoint mode writes ``m5_steering_ep<epoch:05d>.json`` and
    ``.png`` into ``<run_dir>/mechinterp``.  ``--all_epochs`` instead writes
    ``m5_steering_trajectory.json`` with one result per unique checkpoint epoch
    reported by ``find_checkpoints`` (no plots).  A checkpoint that fails is
    skipped with a message rather than aborting the whole trajectory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_dir", required=True)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--epoch", type=int, default=None)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--alphas", default=None,
                        help="Comma-separated α values, e.g. ' -3,-2,-1,0,1,2,3 '")
    parser.add_argument("--all_epochs", action="store_true",
                        help="Sweep across all periodic checkpoints; output a trajectory")
    args = parser.parse_args()

    alphas = None
    if args.alphas is not None:
        # Skip empty tokens (e.g. a trailing comma) and report malformed
        # values as a usage error instead of a raw traceback.
        tokens = [tok for tok in args.alphas.split(",") if tok.strip()]
        try:
            alphas = [float(tok) for tok in tokens]
        except ValueError as err:
            parser.error(f"--alphas: {err}")
        if not alphas:
            parser.error("--alphas: no values given")

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    if args.all_epochs:
        # find_checkpoints may list an epoch more than once; keep each epoch
        # once, in the order it was first returned.
        epochs = list(dict.fromkeys(ep for ep, _ in find_checkpoints(str(run_dir))))
        traj = []
        for ep in epochs:
            try:
                traj.append(run_steering(
                    run_dir=run_dir, data_root=args.data_root, epoch=ep,
                    max_samples=args.max_samples, device=args.device, alphas=alphas,
                ))
            except Exception as err:  # best-effort: one bad checkpoint must not stop the sweep
                print(f" [skip ep{ep}] {type(err).__name__}: {err}")
        out = out_dir / "m5_steering_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        print(f"\n → {out}")
        return

    result = run_steering(
        run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
        max_samples=args.max_samples, device=args.device, alphas=alphas,
    )
    base = out_dir / f"m5_steering_ep{result['epoch']:05d}"
    json_path = base.with_suffix(".json")
    png_path = base.with_suffix(".png")
    json_path.write_text(json.dumps(result, indent=2))
    plot_steering(result, png_path)
    print(f"\n → {json_path}")
    print(f" → {png_path}")


if __name__ == "__main__":
    main()
|
|