| """M6 — Neuron-level Ablation (the textbook reviewer-asked intervention). |
| |
| Pipeline: |
| 1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool |
| features for train (H0-H2) and OOD (H4) splits. |
| 2. Score each of the 512 avgpool channels by *how predictive its activation |
| is of hospital ID*: we fit a single multiclass logistic regression on all |
| standardized channels jointly and use the L2 norm of each channel's |
| coefficient column as the per-neuron shortcut score: |
| score_c = ||β_{:,c}||_2 (β = LR coefficient matrix, classes × channels) |
| ↑ score_c → channel c is more strongly stain-shortcut-aligned. |
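| 3. Sweep top-K ∈ {0, 4, 8, 16, 32, 64, 128, 256} (default) ablated neurons. |
| Ablation zeroes a channel's *standardized* activation, i.e. replaces it |
| with its training-set mean. For each K measure: |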
| - head OOD acc (raw vs ablated) |
| - hospital-probe acc on raw vs ablated features |
| - tumor-probe acc on raw vs ablated features |
| 4. Strong mechanistic claim: |
| - hospital-probe acc collapses sharply with K (these neurons are |
| carrying hospital info) |
| - head OOD acc *improves* (or at least preserves) at small K (the |
| model was using shortcut neurons to harm OOD) |
| - tumor-probe acc stays flat (causal info is distributed elsewhere) |
| |
| Usage |
| ----- |
| python -m experiments.mechinterp_m6_neuron_ablation \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| [--epoch 50] [--max_samples 1000] \\ |
| [--ks "0,4,8,16,32,64,128,256"] |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
|
|
| from torch.utils.data import DataLoader, Subset |
| from torchvision import transforms |
|
|
| from experiments.mechinterp_m1 import ( |
| register_hooks, extract_features, load_model_from_checkpoint, |
| ) |
| from experiments.mechinterp_m4_ablation import ( |
| _select_epoch, _TransformWrapper, |
| _classifier_logits_from_features, _accuracy, |
| ) |
| from utils.camelyon_data import get_camelyon_subsets |
|
|
|
|
def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42):
    """Build train / ID-validation / OOD-test loaders over random subsets.

    Each split returned by ``get_camelyon_subsets`` is wrapped with the same
    resize + ImageNet-normalisation transform and then randomly sub-sampled:
    up to ``max_samples`` items for train and ``max_samples // 2`` for each of
    the ID and OOD splits.

    Args:
        data_root: Dataset root directory forwarded to ``get_camelyon_subsets``.
        max_samples: Upper bound on the number of training samples kept.
        seed: Seed for the subset draw. Note that this reseeds torch's
            *global* RNG via ``torch.manual_seed``.

    Returns:
        ``(train_loader, id_loader, ood_loader)``; all use batch size 128,
        no shuffling and no worker processes.
    """
    preprocess = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    wrapped = [
        _TransformWrapper(split, preprocess)
        for split in (train_ds, id_val_ds, ood_test_ds)
    ]
    budgets = (max_samples, max_samples // 2, max_samples // 2)

    # Draw every permutation before building any loader, in the fixed order
    # train -> ID -> OOD, so the selected subsets depend only on `seed`.
    torch.manual_seed(seed)
    picks = [
        torch.randperm(len(split))[:budget]
        for split, budget in zip(wrapped, budgets)
    ]

    loaders = tuple(
        DataLoader(Subset(split, chosen), batch_size=128,
                   shuffle=False, num_workers=0)
        for split, chosen in zip(wrapped, picks)
    )
    return loaders
|
|
|
|
def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray:
    """Return a (D,) array of per-channel scores; larger = more label-predictive.

    A single logistic regression is fit on *all* channels of ``X_n`` jointly
    to predict ``hosp``. Channel ``c`` is scored by the L2 norm of column
    ``c`` of the fitted coefficient matrix, i.e. how much weight the probe
    puts on that channel across all classes.

    A one-feature-at-a-time fit is deliberately avoided: its ``|coef|`` would
    be dominated by each feature's scale rather than by its usefulness.

    Args:
        X_n: ``(N, D)`` feature matrix. Callers in this file pass standardized
            features, which keeps coefficient magnitudes comparable across
            channels.
        hosp: ``(N,)`` integer labels to predict. Despite the name, any label
            vector works; this file also passes tumor labels to obtain a
            "morphology" ranking.

    Returns:
        ``(D,)`` array of non-negative scores.

    Raises:
        ValueError: Propagated from scikit-learn, e.g. when ``hosp`` contains
            a single class.
    """
    # `multi_class` is intentionally not passed: it has been deprecated since
    # scikit-learn 1.5 and its default already matches the former "auto".
    probe = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)
    probe.fit(X_n, hosp)

    # coef_ is (n_classes, D), or (1, D) for binary labels; reduce over classes.
    return np.linalg.norm(probe.coef_, axis=0)
|
|
|
|
def _ablate_and_eval(
    X_n, mask, scaler, head_target, model, layer, device,
    hosp_clf, tumor_clf, hosp_target, tumor_target,
):
    """Mask standardized features and score the head plus both linear probes.

    The mask is applied in standardized space. The masked features are then
    mapped back with ``scaler.inverse_transform`` before being pushed through
    the model's classifier head, while the probes are scored directly on the
    masked standardized features.

    NOTE(review): with a mean-centering scaler, a zeroed standardized channel
    maps back to that channel's fitted mean, not to a raw activation of 0.

    Args:
        X_n: ``(N, D)`` standardized features.
        mask: ``(D,)`` multiplicative mask (0 = ablated channel, 1 = kept).
        scaler: Fitted scaler providing ``inverse_transform``.
        head_target: Labels the classifier-head logits are scored against.
        model, layer, device: Forwarded to ``_classifier_logits_from_features``.
        hosp_clf: Fitted hospital probe, or ``None`` to skip that metric.
        tumor_clf: Fitted tumor probe.
        hosp_target: Hospital labels for the hospital probe.
        tumor_target: Tumor labels for the tumor probe.

    Returns:
        ``(head_acc, hosp_acc, tumor_acc)``. ``hosp_acc`` is NaN when no
        hospital probe is given or ``hosp_target`` holds fewer than two
        distinct values.
    """
    masked_n = X_n * mask[None, :]
    masked_raw = scaler.inverse_transform(masked_n)

    head_logits = _classifier_logits_from_features(model, masked_raw, layer, device)
    head_acc = _accuracy(head_logits, head_target)

    hosp_acc = float("nan")
    if hosp_clf is not None and len(np.unique(hosp_target)) > 1:
        hosp_acc = hosp_clf.score(masked_n, hosp_target)

    tumor_acc = tumor_clf.score(masked_n, tumor_target)
    return head_acc, hosp_acc, tumor_acc
|
|
|
|
def run_neuron_ablation(
    run_dir: Path,
    data_root: str,
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    ks: List[int] | None = None,
    n_random_samples: int = 5,
    include_morphology: bool = True,
    include_id: bool = True,
) -> Dict:
    """Run the top-K neuron-ablation sweep on avgpool features of one checkpoint.

    For every K in ``ks`` three ablations are evaluated on the OOD split (and
    on the ID split when ``include_id``):

    * **shortcut** – the K channels ranked most hospital-predictive,
    * **morphology** – the K channels ranked most tumor-predictive (control,
      only when ``include_morphology``),
    * **random** – K random channels, averaged over ``n_random_samples``
      draws (control).

    Args:
        run_dir: Run directory; an optional ``config.json`` in it supplies
            ``seed`` (default 42).
        data_root: Dataset root forwarded to the loader builders.
        epoch: Checkpoint epoch, resolved through ``_select_epoch``.
        max_samples: Training-sample budget; ID/OOD use ``max_samples // 2``.
        device: Device string forwarded to model loading and evaluation.
        ks: Ablation sizes; defaults to ``[0, 4, 8, 16, 32, 64, 128, 256]``.
            Values ``<= 0`` are treated as "no ablation".
        n_random_samples: Number of random K-subsets per K.
        include_morphology: Also run the tumor-ranked control.
        include_id: Also evaluate on the ID validation split.

    Returns:
        JSON-serialisable dict with run metadata and ``"sweep"``, a list with
        one row per K. Row keys: ``k``, ``shortcut_head_ood``,
        ``shortcut_hosp_probe``, ``shortcut_tumor_probe``,
        ``random_head_ood_mean`` / ``_std``, plus the optional
        ``*_head_id*`` and ``morphology_*`` entries.

    NOTE(review): ``shortcut_hosp_probe`` comes from ``_ablate_and_eval`` on
    the OOD split and is NaN whenever that split holds fewer than two distinct
    hospital labels.
    """
    if ks is None:
        ks = [0, 4, 8, 16, 32, 64, 128, 256]

    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print("\n M6 — Neuron Ablation (with random + morphology controls)")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" ks : {ks}")
    print(f" random ablation: {n_random_samples} samplings per K")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    # Reuse the training run's seed (if recorded) for sub-sampling and for
    # the random-ablation control.
    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)

    if include_id:
        train_loader, id_loader, ood_loader = _build_loaders_with_id(
            data_root, max_samples, seed=seed
        )
    else:
        from experiments.mechinterp_m4_ablation import _build_loaders as _bl
        train_loader, ood_loader = _bl(data_root, max_samples, seed=seed)
        id_loader = None

    print(" Extracting features (train + id + ood)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )
    feats_id, hosp_id, tumor_id = (None, None, None)
    if id_loader is not None:
        feats_id, hosp_id, tumor_id = extract_features(
            model, id_loader, device, max_samples=max_samples // 2
        )

    layer = "avgpool"

    def _to_2d(arr):
        # Flatten everything after the sample axis: (N, ...) -> (N, D).
        a = np.asarray(arr)
        return a.reshape(a.shape[0], -1)

    X_tr = _to_2d(feats_train[layer])
    X_ood = _to_2d(feats_ood[layer])
    X_id = _to_2d(feats_id[layer]) if feats_id is not None else None
    hosp_train = np.asarray(hosp_train)
    hosp_ood = np.asarray(hosp_ood)
    tumor_train = np.asarray(tumor_train)
    tumor_ood = np.asarray(tumor_ood)
    if X_id is not None:
        hosp_id = np.asarray(hosp_id)
        tumor_id = np.asarray(tumor_id)

    # Standardize with training statistics only; every split shares the scaler.
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)
    X_id_n = scaler.transform(X_id) if X_id is not None else None

    print(f" Scoring {X_tr.shape[1]} avgpool channels...")
    shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train)
    morphology_scores = (
        _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None
    )
    # Channel indices sorted by descending score.
    rank_shortcut = np.argsort(-shortcut_scores)
    rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None

    # Linear probes fit once on un-ablated training features. `multi_class`
    # is not passed: it has been deprecated since scikit-learn 1.5 and its
    # default already matches the former "auto".
    hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                  n_jobs=-1).fit(X_tr_n, hosp_train)
    tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                                   n_jobs=-1).fit(X_tr_n, tumor_train)

    rng = np.random.default_rng(seed)
    D = X_tr.shape[1]

    def _mask_zeroing(indices):
        # (D,) mask of ones with the given channel indices set to zero.
        m = np.ones(D)
        m[indices] = 0.0
        return m

    def _head_acc(X_split_n, tumor_split, hosp_split, mask):
        # Head accuracy only; the hospital probe is skipped by passing None.
        acc, _, _ = _ablate_and_eval(
            X_split_n, mask, scaler, tumor_split, model, layer, device,
            None, tumor_clf, hosp_split, tumor_split,
        )
        return acc

    sweep = []
    for k in ks:
        row = {"k": int(k)}
        n_zero = max(k, 0)  # k <= 0 means "ablate nothing"

        # --- targeted: top-K hospital-predictive channels -------------------
        mask_s = _mask_zeroing(rank_shortcut[:n_zero])
        h_ood, hp_ood, tp_ood = _ablate_and_eval(
            X_ood_n, mask_s, scaler, tumor_ood, model, layer, device,
            hosp_clf, tumor_clf, hosp_ood, tumor_ood,
        )
        row["shortcut_head_ood"] = float(h_ood)
        row["shortcut_hosp_probe"] = float(hp_ood)
        row["shortcut_tumor_probe"] = float(tp_ood)
        if X_id_n is not None:
            row["shortcut_head_id"] = float(_head_acc(X_id_n, tumor_id, hosp_id, mask_s))

        # --- control: top-K tumor-predictive channels -----------------------
        if rank_morphology is not None:
            mask_m = _mask_zeroing(rank_morphology[:n_zero])
            row["morphology_head_ood"] = float(
                _head_acc(X_ood_n, tumor_ood, hosp_ood, mask_m)
            )
            if X_id_n is not None:
                row["morphology_head_id"] = float(
                    _head_acc(X_id_n, tumor_id, hosp_id, mask_m)
                )

        # --- control: K random channels, averaged over several draws --------
        if k > 0:
            r_oods, r_ids = [], []
            for _ in range(n_random_samples):
                mask_r = _mask_zeroing(rng.permutation(D)[:k])
                r_oods.append(_head_acc(X_ood_n, tumor_ood, hosp_ood, mask_r))
                if X_id_n is not None:
                    r_ids.append(_head_acc(X_id_n, tumor_id, hosp_id, mask_r))
            row["random_head_ood_mean"] = float(np.mean(r_oods))
            row["random_head_ood_std"] = float(np.std(r_oods))
            if r_ids:
                row["random_head_id_mean"] = float(np.mean(r_ids))
                row["random_head_id_std"] = float(np.std(r_ids))
        else:
            # Nothing is ablated, so the random control equals the baseline.
            row["random_head_ood_mean"] = row["shortcut_head_ood"]
            row["random_head_ood_std"] = 0.0
            if X_id_n is not None:
                row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan"))
                row["random_head_id_std"] = 0.0

        sweep.append(row)

        print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} "
              f"random={row.get('random_head_ood_mean', float('nan')):.3f}±"
              f"{row.get('random_head_ood_std', 0):.3f} "
              + (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}"
                 if include_morphology else ""))

    return {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "feature_dim": int(X_tr.shape[1]),
        "shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]],
        "morphology_scores_top10": ([int(i) for i in rank_morphology[:10]]
                                    if rank_morphology is not None else []),
        "n_random_samples": n_random_samples,
        "include_id": include_id,
        "include_morphology": include_morphology,
        "sweep": sweep,
    }
|
|
|
|
def plot_neuron_ablation(result: Dict, out_path: Path):
    """Render the ablation sweep produced by ``run_neuron_ablation`` to a PNG.

    Panel 1 compares head OOD accuracy under targeted (shortcut), random and
    optional morphology-targeted ablation as a function of K. When the result
    includes ID evaluation, panel 2 overlays ID and OOD accuracy for the
    shortcut and random ablations.

    Args:
        result: Dict as returned by ``run_neuron_ablation``; reads ``sweep``,
            ``include_id``, ``include_morphology``, ``run_id`` and ``epoch``.
        out_path: Destination image path.
    """
    sweep = result["sweep"]
    ks = [r["k"] for r in sweep]

    has_id = result.get("include_id", False)
    has_morph = result.get("include_morphology", False)

    # squeeze=False always yields a 2-D axes array, so indexing below is
    # uniform regardless of the number of panels.
    fig, axes = plt.subplots(
        1, 2 if has_id else 1,
        figsize=(13, 5) if has_id else (8, 5),
        squeeze=False,
    )
    axes = axes[0]

    # ---- Panel 1: OOD accuracy, targeted vs controls ------------------------
    ax = axes[0]
    shortcut_ood = [r.get("shortcut_head_ood") for r in sweep]
    random_ood_mu = [r.get("random_head_ood_mean") for r in sweep]
    random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep]
    morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None

    ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)")
    ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)")
    ax.fill_between(ks,
                    [m - s for m, s in zip(random_ood_mu, random_ood_sd)],
                    [m + s for m, s in zip(random_ood_mu, random_ood_sd)],
                    color="black", alpha=0.15)
    if morphology_ood is not None:
        ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6,
                label="top-K morphology neurons (control)")

    # Draw the un-ablated reference only when a K=0 row was actually measured;
    # the first sweep row is not necessarily K=0 when custom ks are supplied.
    base_row = next((r for r in sweep if r["k"] == 0), None)
    base = base_row.get("shortcut_head_ood") if base_row is not None else None
    if base is not None:
        ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5,
                   label=f"K=0 baseline ({base:.3f})")
    ax.set_xlabel("K (neurons zeroed at avgpool)")
    ax.set_ylabel("Head OOD (H4) accuracy")
    ax.set_xscale("symlog", linthresh=4)
    ax.set_title("Targeted vs random ablation — OOD effect\n"
                 "(separation = shortcut neurons selectively hurt OOD)",
                 fontweight="bold", fontsize=10)
    ax.legend(loc="best", fontsize=8)
    ax.grid(alpha=0.3)

    # ---- Panel 2: ID vs OOD under shortcut / random ablation ----------------
    if has_id:
        ax = axes[1]
        shortcut_id = [r.get("shortcut_head_id") for r in sweep]
        random_id_mu = [r.get("random_head_id_mean") for r in sweep]
        ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)")
        ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)")
        ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)")
        ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)")
        ax.set_xlabel("K (neurons zeroed at avgpool)")
        ax.set_ylabel("Head accuracy")
        ax.set_xscale("symlog", linthresh=4)
        ax.set_title("ID vs OOD degradation tradeoff\n"
                     "(targeted: OOD steady or ↑ while ID slowly ↓ = good)",
                     fontweight="bold", fontsize=10)
        ax.legend(fontsize=8, loc="best")
        ax.grid(alpha=0.3)

    fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} "
                 f"• ep{result['epoch']}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def main():
    """CLI entry point: run the ablation sweep, then write its JSON and PNG.

    Outputs go to ``<run_dir>/mechinterp/m6_neuron_ablation_ep<epoch>.json``
    and the matching ``.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_dir", required=True)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--epoch", type=int, default=None)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--ks", default=None,
                        help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'")
    parser.add_argument("--n_random_samples", type=int, default=5,
                        help="Random ablation: averages over this many random K-subsets")
    parser.add_argument("--no_morphology", action="store_true",
                        help="Skip the morphology-targeted ablation control")
    parser.add_argument("--no_id", action="store_true",
                        help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)")
    cli = parser.parse_args()

    # None lets run_neuron_ablation fall back to its default K sweep.
    k_values = None if cli.ks is None else [int(tok) for tok in cli.ks.split(",")]

    run_dir = Path(cli.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    result = run_neuron_ablation(
        run_dir=run_dir,
        data_root=cli.data_root,
        epoch=cli.epoch,
        max_samples=cli.max_samples,
        device=cli.device,
        ks=k_values,
        n_random_samples=cli.n_random_samples,
        include_morphology=not cli.no_morphology,
        include_id=not cli.no_id,
    )

    stem = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}"
    json_path = stem.with_suffix(".json")
    png_path = stem.with_suffix(".png")

    # NOTE: json.dumps writes non-finite floats as bare NaN/Infinity tokens,
    # which strict JSON parsers reject.
    json_path.write_text(json.dumps(result, indent=2))
    plot_neuron_ablation(result, png_path)
    print(f"\n → {json_path}")
    print(f" → {png_path}")


if __name__ == "__main__":
    main()
|
|