| """M4 — Representation Ablation: causal intervention on the shortcut subspace. |
| |
| Pipeline: |
| 1. Pick a checkpoint (peak-OOD epoch by default). |
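2. Extract features at avgpool (currently the only supported `--layer`) for
   train (H0-H2) + OOD (H4) splits and standardize them with train statistics.
3. Fit a hospital-classification logistic-regression probe on train features
   and build a basis W of the *shortcut subspace* in feature space: by default
   (`--subspace_method lda`) the centered per-hospital mean directions, plus
   the top `--subspace_dim` principal components when that dimension exceeds
   the number of hospitals; with `--subspace_method probe`, the probe's
   weight rows.
4. Build the orthogonal projector P onto rowspace(W) (which equals
   P = W^T (W W^T)^-1 W when W has full row rank) and define
   `ablate(h) = h - P h`.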
| 5. Re-classify OOD images with the *same* trained classifier head, fed: |
| (a) raw features h — baseline OOD accuracy |
| (b) ablated features h - Ph — post-intervention OOD accuracy |
| 6. Also report: |
| (c) shortcut accuracy (probe.score on h vs h-Ph) |
| (d) tumor probe accuracy on h vs h-Ph (sanity: the causal feature |
| should survive the intervention) |
| (e) head's tumor classification accuracy on H4 with raw vs ablated features |
| |
| If the intervention is causal: |
| - shortcut probe accuracy: collapses |
| - OOD accuracy: improves (or at least doesn't decay as much) |
| - tumor probe accuracy: largely preserved |
| |
| Usage |
| ----- |
| python -m experiments.mechinterp_m4_ablation \\ |
| --run_dir experiments/runs/<id> \\ |
| --data_root data/wilds \\ |
| --layer avgpool \\ |
| [--epoch 50] # default: peak_ood_epoch from summary.json |
| [--max_samples 1000] |
| |
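Output:
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.json   (<E> = epoch, zero-padded to 5 digits)
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.png
    with --all_epochs: <run_dir>/mechinterp/m4_ablation_<layer>_trajectory.{json,png}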
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| from torch.utils.data import DataLoader, Subset |
| from torchvision import transforms |
|
|
# Put the directory two levels above this file (the one containing the
# `experiments/` package, per the `python -m experiments....` usage) at the
# front of sys.path so the project-local imports below resolve.
ROOT = Path(__file__).resolve().parent.parent
sys.path[:0] = [str(ROOT)]
|
|
| |
| from experiments.mechinterp_m1 import ( |
| register_hooks, |
| extract_features, |
| load_model_from_checkpoint, |
| find_checkpoints, |
| ) |
| from utils.camelyon_data import get_camelyon_subsets |
|
|
|
|
class _TransformWrapper:
    """Apply ``transform`` lazily to the image of each dataset item.

    The wrapped dataset must yield items that unpack into exactly three
    values ``(img, label, metadata)``; this wrapper returns
    ``(transform(img), label, metadata)``. Length and indexing are delegated
    to the wrapped dataset.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, target, meta = self.dataset[idx]
        transformed = self.transform(image)
        return transformed, target, meta
|
|
|
|
def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
    """Build reproducible train / OOD DataLoaders over random subsets.

    Images are resized to 96x96, converted to tensors and normalized with the
    ImageNet mean/std. Up to ``max_samples`` train items and
    ``max_samples // 2`` OOD items are drawn uniformly without replacement.

    Args:
        data_root: root directory passed to ``get_camelyon_subsets``
            (``download=False``, so the data must already be on disk).
        max_samples: cap on the number of train samples; the OOD subset is
            capped at half of it.
        seed: seed for the subset selection.

    Returns:
        ``(train_loader, ood_loader)``: unshuffled, batch size 128, no worker
        processes.
    """
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # The ID-validation split is not used here.
    train_ds, _id_val_ds, ood_test_ds, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    train_t = _TransformWrapper(train_ds, transform)
    ood_t = _TransformWrapper(ood_test_ds, transform)

    # A private generator instead of torch.manual_seed: selecting subsets must
    # not reseed the process-wide RNG. A fresh CPU generator seeded the same
    # way produces the same permutations, so the chosen indices are unchanged.
    # int() mirrors torch.manual_seed, which coerces its argument.
    gen = torch.Generator().manual_seed(int(seed))
    train_idx = torch.randperm(len(train_t), generator=gen)[:max_samples].tolist()
    ood_idx = torch.randperm(len(ood_t), generator=gen)[:max_samples // 2].tolist()

    # Plain int indices, so the wrapped datasets never receive tensor indices.
    loader_kwargs = dict(batch_size=128, shuffle=False, num_workers=0)
    train_loader = DataLoader(Subset(train_t, train_idx), **loader_kwargs)
    ood_loader = DataLoader(Subset(ood_t, ood_idx), **loader_kwargs)
    return train_loader, ood_loader
|
|
|
|
def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
    """Resolve which ``(epoch, checkpoint_path)`` to analyse.

    Resolution order:
      1. ``requested``: must equal some checkpoint epoch, else ValueError.
      2. ``peak_ood_epoch`` from ``<run_dir>/results/summary.json``, when that
         file exists and the value is present and > 0. The checkpoint with
         the nearest epoch wins (ties go to the first one listed by
         ``find_checkpoints``).
      3. Otherwise the last entry returned by ``find_checkpoints``.

    Raises:
        FileNotFoundError: the run has no checkpoints.
        ValueError: ``requested`` does not match any checkpoint epoch.
    """
    available = find_checkpoints(str(run_dir))
    if not available:
        raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")

    if requested is not None:
        hit = next(((ep, path) for ep, path in available if ep == requested), None)
        if hit is None:
            raise ValueError(f"Requested epoch {requested} not in checkpoints "
                             f"({[ep for ep, _ in available]})")
        return hit[0], Path(hit[1])

    summary_file = run_dir / "results" / "summary.json"
    peak = (
        json.loads(summary_file.read_text()).get("peak_ood_epoch", None)
        if summary_file.exists()
        else None
    )

    use_peak = peak is not None and peak > 0
    chosen = (min(available, key=lambda item: abs(item[0] - peak))
              if use_peak else available[-1])
    return chosen[0], Path(chosen[1])
|
|
|
|
def _build_projector(W: np.ndarray) -> np.ndarray:
    """Return the (d, d) orthogonal projector onto the row space of ``W`` (k, d).

    An orthonormal basis of rowspace(W) is taken from the right singular
    vectors of ``W``. Directions whose singular value does not exceed
    ``max(k, d) * eps * s_max`` (the same default tolerance numpy's
    ``matrix_rank`` uses) are treated as numerically zero and dropped. For
    full-row-rank ``W`` the result equals ``W^T (W W^T)^-1 W``, and it stays
    well defined when ``W`` is rank-deficient.
    """
    _, sing_vals, vh = np.linalg.svd(W, full_matrices=False)
    if sing_vals.size:
        cutoff = max(W.shape) * np.finfo(sing_vals.dtype).eps * sing_vals.max()
    else:
        cutoff = 0.0
    onb = vh[sing_vals > cutoff]  # rows: orthonormal basis of rowspace(W)
    return onb.T @ onb
|
|
|
|
def _build_shortcut_subspace(
    X: np.ndarray, hospital_ids: np.ndarray,
    method: str = "lda", subspace_dim: int = 32
) -> np.ndarray:
    """Return a (k, d) matrix whose row span is the 'shortcut subspace'.

    Args:
        X: (n, d) feature matrix.
        hospital_ids: length-n hospital label for each row of ``X``.
        method: ``'probe'`` or ``'lda'``; anything else raises ValueError.
        subspace_dim: target dimension for ``method='lda'``.

    method='probe': fit a logistic-regression hospital probe and return its
        ``coef_`` rows. That is one row for two hospitals, otherwise one row
        per hospital, so the subspace is small.
    method='lda': "between-class" directions, i.e. each hospital's feature
        mean minus the global mean (one row per hospital). When
        ``subspace_dim`` exceeds the number of hospitals, the top
        ``subspace_dim`` principal components of the centered features are
        appended (fewer if SVD yields fewer than ``subspace_dim``
        components). They are ordered by their strongest absolute
        correlation with any hospital indicator. All of those components are
        kept, so the ordering changes only the row order, not the span.

    Raises:
        ValueError: unknown ``method``.
    """
    hospital_ids = np.asarray(hospital_ids)

    if method == "probe":
        # `multi_class` is deprecated in scikit-learn >= 1.5; its default is
        # already "auto", so omitting it fits the same model.
        clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)
        clf.fit(X, hospital_ids)
        return clf.coef_

    if method == "lda":
        classes = np.unique(hospital_ids)
        global_mean = X.mean(axis=0, keepdims=True)
        between = np.vstack([
            X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
            for c in classes
        ])

        if subspace_dim <= between.shape[0]:
            return between

        # Augment with leading principal components of the centered features.
        centered = X - global_mean
        _, _, Vt = np.linalg.svd(centered, full_matrices=False)
        # SVD yields only min(n, d) components; never index past them.
        top = Vt[:subspace_dim]
        n_pc = top.shape[0]

        one_hot = np.eye(len(classes))[np.searchsorted(classes, hospital_ids)]
        proj = centered @ top.T
        corrs = np.array([
            np.max(np.abs([np.corrcoef(proj[:, k], one_hot[:, c])[0, 1]
                           for c in range(len(classes))]))
            for k in range(n_pc)
        ])
        # Undefined correlations (constant columns -> NaN) rank as 0.
        order = np.argsort(-np.nan_to_num(corrs))
        return np.vstack([between, top[order]])

    raise ValueError(f"Unknown method: {method}")
|
|
|
|
def _classifier_logits_from_features(
    model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
    """Run only the model's classification head on (possibly edited) features.

    Only ``layer='avgpool'`` is supported: ``features`` are fed straight into
    the head, so they must already be in the shape the head consumes
    (presumably (N, C) pooled features; confirm against the caller). Other
    layers would need the remaining network stages re-applied, which is not
    implemented.

    The head is looked up in this order: ``model.get_classifier()`` (timm
    style), then ``model.fc``, then ``model.classifier``. It is moved to
    ``device``, switched to eval mode and run without gradient tracking.

    Returns:
        The head's outputs as a numpy array on the CPU.

    Raises:
        NotImplementedError: ``layer`` is not ``'avgpool'``.
        RuntimeError: none of the head attributes exists on ``model``.
    """
    if layer != "avgpool":
        raise NotImplementedError(
            "Re-applying the classifier head from intermediate spatial layers "
            "is not yet supported. Use --layer avgpool for the head-level "
            "ablation."
        )

    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
    else:
        for attr_name in ("fc", "classifier"):
            if hasattr(model, attr_name):
                head = getattr(model, attr_name)
                break
        else:
            raise RuntimeError("Could not locate classifier head on the model.")

    head = head.to(device)
    head.eval()
    with torch.no_grad():
        inputs = torch.tensor(features, dtype=torch.float32, device=device)
        outputs = head(inputs)
    return outputs.cpu().numpy()
|
|
|
|
def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Return the fraction of rows whose predicted class equals ``labels``.

    A 1-D ``logits`` array, or one with a single column, is read as one
    binary logit per row: predict 1 when it is strictly > 0, else 0.
    Otherwise the arg-max column is the prediction.
    """
    single_logit = logits.ndim == 1 or logits.shape[1] == 1
    if single_logit:
        predictions = (logits.flatten() > 0).astype(int)
    else:
        predictions = logits.argmax(axis=1)
    hits = predictions == labels
    return float(hits.mean())
|
|
|
|
def _fit_logreg_probe(X: np.ndarray, y):
    """Fit and return the linear probe used for both hospital and tumor labels.

    L2-regularised logistic regression (lbfgs, C=1, 500 iterations).
    ``multi_class`` is deliberately not passed: it is deprecated in
    scikit-learn >= 1.5, and its default ("auto") is the value wanted anyway.
    """
    clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)
    clf.fit(X, y)
    return clf


def run_ablation(
    run_dir: Path,
    data_root: str,
    layer: str = "avgpool",
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    subspace_method: str = "lda",
    subspace_dim: int = 32,
) -> Dict:
    """Run the M4 shortcut-subspace ablation on one checkpoint.

    Features are standardized with a scaler fit on the train subset. The
    shortcut subspace and its projector P are built in that standardized
    space, and ablation removes the in-subspace component ``h_n - P h_n``.
    Ablated OOD features are mapped back to the raw scale before being fed
    to the model's own classifier head.

    Args:
        run_dir: run directory with checkpoints. It may also hold
            ``config.json`` (read for ``seed``) and ``results/summary.json``
            (read for ``peak_ood_epoch``).
        data_root: dataset root passed to the loader builder.
        layer: feature layer; only ``"avgpool"`` is supported.
        epoch: checkpoint epoch; ``None`` selects via ``_select_epoch``.
        max_samples: train-subset size; the OOD subset uses half.
        device: torch device for the model and its head.
        subspace_method: ``"lda"`` or ``"probe"``.
        subspace_dim: target subspace dimension for ``"lda"``.

    Returns:
        Dict of probe and head accuracies on raw vs ablated features, plus
        their differences under ``"intervention_effect"``. OOD hospital-probe
        metrics are NaN when the OOD subset contains a single hospital id.

    Raises:
        NotImplementedError: ``layer != "avgpool"`` (checked before any work).
        KeyError: ``layer`` is not among the extracted features.
        Errors from checkpoint selection propagate unchanged.
    """
    if layer != "avgpool":
        # Fail before loading anything: the head can only be re-applied to
        # pooled features, so any other layer would fail after all the
        # expensive extraction and probe fitting.
        raise NotImplementedError(
            "Re-applying the classifier head from intermediate spatial layers "
            "is not yet supported. Use --layer avgpool for the head-level "
            "ablation."
        )

    epoch, ckpt_path = _select_epoch(run_dir, epoch)

    print("\n M4 — Representation Ablation")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" layer : {layer}")

    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)

    cfg_path = run_dir / "config.json"
    seed = 42
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    print(f" Extracting features (up to {max_samples} train / "
          f"{max_samples // 2} OOD samples)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2
    )

    if layer not in feats_train:
        raise KeyError(f"Layer '{layer}' not in extracted features "
                       f"({list(feats_train.keys())})")

    # Normalise label containers once so every comparison below is elementwise.
    hosp_train = np.asarray(hosp_train)
    tumor_train = np.asarray(tumor_train)
    hosp_ood = np.asarray(hosp_ood)
    tumor_ood = np.asarray(tumor_ood)

    X_tr = np.asarray(feats_train[layer])
    X_ood = np.asarray(feats_ood[layer])
    if X_tr.ndim > 2:
        # Flatten any trailing (e.g. spatial) dimensions to (N, D).
        X_tr = X_tr.reshape(X_tr.shape[0], -1)
        X_ood = X_ood.reshape(X_ood.shape[0], -1)

    # Standardize with train statistics; probes and the ablation live here.
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    print(" Fitting hospital probe on H0/H1/H2 train features...")
    hosp_clf = _fit_logreg_probe(X_tr_n, hosp_train)
    hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train)

    W = _build_shortcut_subspace(X_tr_n, hosp_train,
                                 method=subspace_method,
                                 subspace_dim=subspace_dim)
    P = _build_projector(W)
    rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8))
    print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} "
          f"(probe train acc {hosp_acc_train:.3f})")

    def ablate_norm(X_n):
        # Remove each row's component inside the shortcut subspace.
        return X_n - X_n @ P.T

    X_tr_ablated_n = ablate_norm(X_tr_n)
    X_ood_ablated_n = ablate_norm(X_ood_n)
    # Back to the raw feature scale the classifier head was trained on.
    X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)

    print(" Re-fitting tumor probe on train features...")
    tumor_clf = _fit_logreg_probe(X_tr_n, tumor_train)
    tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train)

    # In-distribution shortcut check: the probe knows these hospitals, so
    # this is defined even when the OOD subset has a single hospital id.
    hosp_acc_train_ablated = hosp_clf.score(X_tr_ablated_n, hosp_train)

    ood_multi_hospital = len(np.unique(hosp_ood)) > 1
    if ood_multi_hospital:
        hosp_acc_ood_raw = hosp_clf.score(X_ood_n, hosp_ood)
        hosp_acc_ood_ablated = hosp_clf.score(X_ood_ablated_n, hosp_ood)
    else:
        hosp_acc_ood_raw = float("nan")
        hosp_acc_ood_ablated = float("nan")
    tumor_acc_ood_raw = tumor_clf.score(X_ood_n, tumor_ood)
    tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood)

    print(" Re-classifying OOD with model head (raw vs ablated features)...")
    logits_raw = _classifier_logits_from_features(model, X_ood, layer, device)
    logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)
    head_acc_raw = _accuracy(logits_raw, tumor_ood)
    head_acc_ablated = _accuracy(logits_ablated, tumor_ood)

    result = {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "shortcut_subspace_dim": rank_subspace,
        "hospital_probe_train_acc": hosp_acc_train,
        "hospital_probe_train_ablated": hosp_acc_train_ablated,
        "tumor_probe_train_acc": tumor_acc_train,
        "hospital_probe_ood_raw": hosp_acc_ood_raw,
        "hospital_probe_ood_ablated": hosp_acc_ood_ablated,
        "tumor_probe_ood_raw": tumor_acc_ood_raw,
        "tumor_probe_ood_ablated": tumor_acc_ood_ablated,
        "head_ood_acc_raw": head_acc_raw,
        "head_ood_acc_ablated": head_acc_ablated,
        "intervention_effect": {
            "shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated,
            "shortcut_collapse_train": hosp_acc_train - hosp_acc_train_ablated,
            "ood_improvement": head_acc_ablated - head_acc_raw,
            "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
        },
    }

    effect = result["intervention_effect"]
    print("\n RESULTS")
    print(f" hospital probe (train): {hosp_acc_train:.3f} → {hosp_acc_train_ablated:.3f} "
          f"(Δ {effect['shortcut_collapse_train']:+.3f})")
    print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} → {hosp_acc_ood_ablated:.3f} "
          f"(Δ {effect['shortcut_collapse']:+.3f})")
    print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} → {tumor_acc_ood_ablated:.3f} "
          f"(Δ {effect['tumor_preservation']:+.3f})")
    print(f" head OOD acc : {head_acc_raw:.3f} → {head_acc_ablated:.3f} "
          f"(Δ {effect['ood_improvement']:+.3f})")

    return result
|
|
|
|
def plot_ablation(result: Dict, out_path: Path):
    """Save a grouped bar chart of raw vs shortcut-ablated OOD accuracies.

    Three groups are read from ``result`` (as returned by ``run_ablation``):
    hospital-probe, tumor-probe and classifier-head accuracy on OOD data.
    Non-finite values are drawn as empty bars labelled "n/a", instead of
    placing a text label at a NaN coordinate. For example, the hospital-probe
    metrics are NaN when the OOD split has a single hospital.
    """
    raw_keys = ["hospital_probe_ood_raw", "tumor_probe_ood_raw", "head_ood_acc_raw"]
    ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"]
    labels = ["Hospital probe\n(↓ = causal effect)",
              "Tumor probe\n(stable = good)",
              "Head OOD acc\n(↑ = causal effect)"]
    raws = [result[k] for k in raw_keys]
    ablateds = [result[k] for k in ablated_keys]

    fig, ax = plt.subplots(figsize=(9, 5))
    x = np.arange(len(labels))
    w = 0.35
    b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444")
    b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33")
    for bars in (b1, b2):
        for b in bars:
            height = b.get_height()
            finite = bool(np.isfinite(height))
            ax.text(b.get_x() + b.get_width() / 2,
                    (height if finite else 0.0) + 0.005,
                    f"{height:.3f}" if finite else "n/a",
                    ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.set_title(f"M4 — Causal Ablation of Shortcut Subspace\n"
                 f"{result['run_id']} • ep{result['epoch']} • layer={result['layer']} "
                 f"• subspace dim={result['shortcut_subspace_dim']}",
                 fontsize=10, fontweight="bold")
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3, axis="y")
    fig.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def main():
    """Command-line entry point for M4.

    Without ``--all_epochs``: analyse one checkpoint and write
    ``<run_dir>/mechinterp/m4_ablation_<layer>_ep<EEEEE>.json`` and ``.png``.
    With ``--all_epochs``: analyse every distinct checkpoint epoch in
    ascending order, reporting and skipping epochs that fail. Writes
    ``m4_ablation_<layer>_trajectory.json``, plus a ``.png`` when at least one
    epoch succeeded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_dir", required=True)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--layer", default="avgpool",
                        choices=["avgpool"])
    parser.add_argument("--epoch", type=int, default=None,
                        help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
    parser.add_argument("--max_samples", type=int, default=1000)
    # Fall back to CPU so the script also runs on machines without CUDA.
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--subspace_method", default="lda",
                        choices=["lda", "probe"],
                        help="lda = LDA-style between-class + hospital-correlated PCs; "
                             "probe = LR probe row-space (small, often only 2-D)")
    parser.add_argument("--subspace_dim", type=int, default=32,
                        help="Target subspace dim for lda method")
    parser.add_argument("--all_epochs", action="store_true",
                        help="Sweep across all periodic checkpoints")
    args = parser.parse_args()

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    common = dict(
        run_dir=run_dir, data_root=args.data_root, layer=args.layer,
        max_samples=args.max_samples, device=args.device,
        subspace_method=args.subspace_method, subspace_dim=args.subspace_dim,
    )

    if args.all_epochs:
        # run_ablation re-resolves the checkpoint from the epoch, so only the
        # distinct epochs matter; sort them so the trajectory is monotone.
        epochs = sorted({ep for ep, _ in find_checkpoints(str(run_dir))})

        traj = []
        for ep in epochs:
            try:
                r = run_ablation(epoch=ep, **common)
            except Exception as e:  # best-effort sweep: report and continue
                print(f" [skip ep{ep}] {type(e).__name__}: {e}")
                continue
            traj.append(r)

        out = out_dir / f"m4_ablation_{args.layer}_trajectory.json"
        out.write_text(json.dumps(traj, indent=2))
        print(f"\n → {out}")
        if traj:
            plot_trajectory(traj, out.with_suffix(".png"))
            print(f" → {out.with_suffix('.png')}")
        else:
            print(" (no checkpoint succeeded; trajectory plot skipped)")
        return

    result = run_ablation(epoch=args.epoch, **common)

    base = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
    (base.with_suffix(".json")).write_text(json.dumps(result, indent=2))
    plot_ablation(result, base.with_suffix(".png"))
    print(f"\n → {base.with_suffix('.json')}")
    print(f" → {base.with_suffix('.png')}")
|
|
|
|
def plot_trajectory(traj, out_path: Path):
    """Plot the ablation effect across training epochs and save it to ``out_path``.

    Left panel: classifier-head OOD accuracy with raw vs shortcut-ablated
    features, shaded green where ablation helps and red where it hurts.
    Right panel: OOD tumor-probe accuracy with raw vs ablated features.
    Entries are plotted in ascending epoch order whatever the input order;
    the caller's list is not modified.
    """
    traj = sorted(traj, key=lambda r: r["epoch"])
    eps = [r["epoch"] for r in traj]
    head_raw = [r["head_ood_acc_raw"] for r in traj]
    head_abl = [r["head_ood_acc_ablated"] for r in traj]
    tum_raw = [r["tumor_probe_ood_raw"] for r in traj]
    tum_abl = [r["tumor_probe_ood_ablated"] for r in traj]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.plot(eps, head_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
    # interpolate=True extends each shaded region to the exact crossing
    # point; without it the shading is clipped at the neighbouring epochs.
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a > b for a, b in zip(head_abl, head_raw)],
                    interpolate=True,
                    color="seagreen", alpha=0.3, label="ablation helps")
    ax.fill_between(eps, head_raw, head_abl,
                    where=[a < b for a, b in zip(head_abl, head_raw)],
                    interpolate=True,
                    color="salmon", alpha=0.3, label="ablation hurts")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("OOD (H4) head accuracy")
    ax.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)

    ax = axes[1]
    ax.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
    ax.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
    ax.set_xlabel("Training epoch")
    ax.set_ylabel("Tumor probe OOD accuracy")
    ax.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
                 fontweight="bold")
    ax.legend(fontsize=9)
    ax.grid(alpha=0.3)
    ax.set_ylim(0.4, 1.0)

    rid = traj[0]["run_id"] if traj else "?"
    layer = traj[0]["layer"] if traj else "?"
    fig.suptitle(f"M4 — Causal Ablation Trajectory: {rid} • layer={layer}",
                 fontsize=11, fontweight="bold")
    fig.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
# Entry point: run the CLI only when this file is executed directly
# (e.g. `python -m experiments.mechinterp_m4_ablation`), not when imported.
if __name__ == "__main__":
    main()
|
|