"""M4 — Representation Ablation: causal intervention on the shortcut subspace.

Pipeline:
  1. Pick a checkpoint (peak-OOD epoch by default).
  2. Extract features at avgpool (or `--layer`) for train (H0-H2) + OOD (H4) splits.
  3. Fit a hospital-classification logistic-regression probe on (standardized)
     train features and build a *shortcut subspace* in feature space:
       - `--subspace_method probe`: the row space of the probe's weights;
       - `--subspace_method lda` (default): per-hospital mean directions plus
         hospital-correlated principal components.
  4. Build the orthogonal projector P onto that subspace (for a full-row-rank
     basis W this is P = W^T (W W^T)^{-1} W) and define `ablate(h) = h - P h`.
  5. Re-classify OOD images with the *same* trained classifier head, fed:
       (a) raw features h          → baseline OOD accuracy
       (b) ablated features h - Ph → post-intervention OOD accuracy
  6. Also report:
       (c) shortcut accuracy (hospital probe score on h vs h - Ph); on the OOD
           split this is reported as NaN when it contains a single hospital
       (d) tumor probe accuracy on h vs h - Ph (sanity check: the causal
           feature should survive the intervention)
       (e) the head's tumor classification accuracy on H4 with raw vs ablated
           features

If the intervention is causal:
  - shortcut probe accuracy collapses;
  - OOD accuracy improves (or at least decays less);
  - tumor probe accuracy is largely preserved.

Usage
-----
    python -m experiments.mechinterp_m4_ablation \\
        --run_dir experiments/runs/<id> \\
        --data_root data/wilds \\
        --layer avgpool \\
        [--epoch 50]          # default: peak_ood_epoch from summary.json
        [--max_samples 1000]
        [--subspace_method lda|probe] [--subspace_dim 32]
        [--all_epochs]        # sweep every periodic checkpoint

Output:
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.json   (<E> zero-padded to 5 digits)
    <run_dir>/mechinterp/m4_ablation_<layer>_ep<E>.png
    with --all_epochs: <run_dir>/mechinterp/m4_ablation_<layer>_trajectory.{json,png}
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Dict, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
# Re-use M1 helpers — hooks, model loader, feature extraction, ckpt discovery.
from experiments.mechinterp_m1 import (
register_hooks,
extract_features,
load_model_from_checkpoint,
find_checkpoints,
)
from utils.camelyon_data import get_camelyon_subsets
class _TransformWrapper:
    """Apply ``transform`` lazily to the image of each dataset item.

    The wrapped dataset must yield ``(img, label, metadata)`` triples; only
    ``img`` is transformed, ``label`` and ``metadata`` pass through unchanged.
    """

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, target, meta = self.dataset[idx]
        transformed = self.transform(image)
        return transformed, target, meta
def _build_loaders(data_root: str, max_samples: int, seed: int = 42):
    """Build unshuffled train and OOD-test loaders over seeded random subsets.

    Up to ``max_samples`` train items and up to ``max_samples // 2`` OOD items
    are chosen with ``torch.randperm`` after seeding the global torch RNG with
    ``seed``. Images are resized to 96x96, converted to tensors and normalized
    with the ImageNet mean/std values.

    Returns:
        ``(train_loader, ood_loader)``; both use batch_size=128,
        shuffle=False, num_workers=0.
    """
    preprocess = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_raw, _id_val, ood_raw, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    wrapped_train = _TransformWrapper(train_raw, preprocess)
    wrapped_ood = _TransformWrapper(ood_raw, preprocess)

    # Seed once, then draw the train permutation before the OOD one: the
    # draw order determines which samples each split receives.
    torch.manual_seed(seed)
    picked_train = torch.randperm(len(wrapped_train))[:max_samples]
    picked_ood = torch.randperm(len(wrapped_ood))[:max_samples // 2]

    def _loader(ds, indices):
        return DataLoader(Subset(ds, indices), batch_size=128,
                          shuffle=False, num_workers=0)

    return _loader(wrapped_train, picked_train), _loader(wrapped_ood, picked_ood)
def _select_epoch(run_dir: Path, requested: int | None) -> Tuple[int, Path]:
    """Resolve which checkpoint of ``run_dir`` to analyse.

    Priority:
      1. ``requested``, which must match a discovered checkpoint epoch exactly
         (``ValueError`` otherwise);
      2. the checkpoint whose epoch is nearest to ``peak_ood_epoch`` from
         ``<run_dir>/results/summary.json``, used only when that value is > 0;
      3. the last entry returned by ``find_checkpoints``.

    Returns:
        ``(epoch, checkpoint_path)``.

    Raises:
        FileNotFoundError: if ``find_checkpoints`` returns nothing.
        ValueError: if ``requested`` is not among the checkpoint epochs.
    """
    available = find_checkpoints(str(run_dir))
    if not available:
        raise FileNotFoundError(f"No checkpoints in {run_dir}/checkpoints/")

    if requested is not None:
        match = next(((ep, p) for ep, p in available if ep == requested), None)
        if match is None:
            raise ValueError(f"Requested epoch {requested} not in checkpoints "
                             f"({[ep for ep, _ in available]})")
        return match[0], Path(match[1])

    summary_file = run_dir / "results" / "summary.json"
    peak_epoch = (
        json.loads(summary_file.read_text()).get("peak_ood_epoch", None)
        if summary_file.exists()
        else None
    )
    if peak_epoch is not None and peak_epoch > 0:
        # Checkpoints are only saved periodically: snap to the closest one
        # (ties go to the earliest-listed checkpoint).
        chosen = min(available, key=lambda item: abs(item[0] - peak_epoch))
    else:
        chosen = available[-1]
    return chosen[0], Path(chosen[1])
def _build_projector(W: np.ndarray) -> np.ndarray:
    """Return the (d, d) orthogonal projector onto the row space of ``W``.

    ``W`` has shape (k, d); its rows need not be independent or orthonormal.
    An orthonormal row-space basis is taken from the SVD. Directions whose
    singular value does not exceed the ``matrix_rank``-style tolerance
    ``max(k, d) * eps * s_max`` are discarded, so rank-deficient or all-zero
    inputs are handled (an all-zero ``W`` yields the zero matrix).
    """
    _, sing, vh = np.linalg.svd(W, full_matrices=False)
    s_max = sing.max() if sing.size else 0.0
    cutoff = max(W.shape) * np.finfo(sing.dtype).eps * s_max
    q = vh[sing > cutoff]  # (k', d) orthonormal rows spanning rowspace(W)
    return q.T @ q
def _max_abs_corr(scores: np.ndarray, indicators: np.ndarray) -> np.ndarray:
    """For each column of ``scores`` (N, m), return max |Pearson r| over the
    columns of ``indicators`` (N, c). Zero-variance columns count as r = 0."""
    def _zscore(a: np.ndarray) -> np.ndarray:
        a = a - a.mean(axis=0)
        sd = a.std(axis=0)
        return np.divide(a, sd, out=np.zeros_like(a), where=sd > 0)

    r = _zscore(scores).T @ _zscore(indicators) / scores.shape[0]  # (m, c)
    return np.abs(r).max(axis=1)


def _build_shortcut_subspace(
    X: np.ndarray, hospital_ids: np.ndarray,
    method: str = "lda", subspace_dim: int = 32
) -> np.ndarray:
    """Return a (k, d) matrix whose row span is the 'shortcut subspace'.

    Args:
        X: (N, d) feature matrix (``run_ablation`` passes standardized
            features).
        hospital_ids: (N,) hospital label per row.
        method:
            'probe' -> the ``coef_`` rows of a logistic-regression probe fit
                       to predict hospital from ``X`` (a small subspace).
            'lda'   -> one row per hospital: the hospital mean minus the global
                       mean. If ``subspace_dim`` exceeds the number of
                       hospitals, these rows are topped up with the principal
                       components of ``X`` whose scores correlate most (max
                       |Pearson r| against the one-hot hospital indicators)
                       with hospital membership. The result then has at most
                       ``subspace_dim`` rows.
        subspace_dim: target row count for method='lda'. If it does not exceed
            the number of hospitals, only the mean-offset rows are returned.

    Returns:
        (k, d) array; rows need not be orthonormal or independent.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method == "probe":
        # `multi_class` is deliberately not passed: "auto" is the default, and
        # the argument is deprecated since scikit-learn 1.5.
        clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1)
        clf.fit(X, hospital_ids)
        return clf.coef_
    if method != "lda":
        raise ValueError(f"Unknown method: {method}")

    X = np.asarray(X)
    hospital_ids = np.asarray(hospital_ids).reshape(-1)
    classes = np.unique(hospital_ids)
    global_mean = X.mean(axis=0, keepdims=True)
    between = np.vstack([
        X[hospital_ids == c].mean(axis=0, keepdims=True) - global_mean
        for c in classes
    ])  # (n_classes, d)

    n_extra = subspace_dim - between.shape[0]
    if n_extra <= 0:
        return between

    centered = X - global_mean
    _, sing, vh = np.linalg.svd(centered, full_matrices=False)
    # Drop numerically-null directions (e.g. constant/dead features): their
    # scores are rounding noise and would give meaningless correlations.
    s_max = sing.max() if sing.size else 0.0
    cutoff = max(centered.shape) * np.finfo(sing.dtype).eps * s_max
    pcs = vh[sing > cutoff]
    if pcs.shape[0] == 0:
        return between

    one_hot = (hospital_ids[:, None] == classes[None, :]).astype(float)
    hosp_corr = _max_abs_corr(centered @ pcs.T, one_hot)
    # Rank *all* candidate PCs by hospital correlation (not just the
    # top-variance ones) and keep the best `n_extra`, so the total row count
    # stays within `subspace_dim`.
    order = np.argsort(-hosp_corr, kind="stable")
    return np.vstack([between, pcs[order[:n_extra]]])
def _classifier_logits_from_features(
    model: nn.Module, features: np.ndarray, layer: str, device: str
) -> np.ndarray:
    """Feed pre-extracted (possibly ablated) features through the model's head.

    Only ``layer == 'avgpool'`` is supported: the pooled features (expected
    to be (N, C) for a ResNet) go straight into the classifier head. The head
    is located as ``model.get_classifier()`` if that method exists, else
    ``model.fc``, else ``model.classifier``. It is moved to ``device`` and put
    in eval mode in place, then applied under ``torch.no_grad()``.

    Returns:
        The head's logits as a NumPy array.

    Raises:
        NotImplementedError: for any ``layer`` other than 'avgpool'.
        RuntimeError: if no classifier head can be found on ``model``.
    """
    if layer != "avgpool":
        raise NotImplementedError(
            "Re-applying the classifier head from intermediate spatial layers "
            "is not yet supported. Use --layer avgpool for the head-level "
            "ablation."
        )

    if hasattr(model, "get_classifier"):
        head = model.get_classifier()
    else:
        attr = next((name for name in ("fc", "classifier") if hasattr(model, name)), None)
        if attr is None:
            raise RuntimeError("Could not locate classifier head on the model.")
        head = getattr(model, attr)

    head = head.to(device).eval()
    with torch.no_grad():
        inputs = torch.tensor(features, dtype=torch.float32, device=device)
        return head(inputs).cpu().numpy()
def _accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Return the fraction of rows whose predicted class equals ``labels``.

    Logits of shape (N,) or (N, 1) are treated as a single binary logit
    (predict 1 when > 0); otherwise the prediction is the argmax over axis 1.
    ``labels`` is flattened to (N,) first, so column-vector labels are not
    broadcast against the (N,) predictions into an (N, N) comparison.
    Returns NaN for empty input.

    Raises:
        ValueError: if the number of labels differs from the number of rows.
    """
    logits = np.asarray(logits)
    labels = np.asarray(labels).reshape(-1)
    if logits.ndim == 1 or logits.shape[1] == 1:
        pred = (logits.reshape(-1) > 0).astype(int)
    else:
        pred = logits.argmax(axis=1)
    if pred.shape[0] != labels.shape[0]:
        raise ValueError(
            f"got {pred.shape[0]} predictions but {labels.shape[0]} labels"
        )
    if labels.size == 0:
        return float("nan")
    return float((pred == labels).mean())
def _fit_lr_probe(X: np.ndarray, y: np.ndarray) -> LogisticRegression:
    """Fit the logistic-regression probe used for the hospital and tumor readouts.

    ``multi_class`` is deliberately not passed: "auto" is the default, and the
    argument is deprecated since scikit-learn 1.5.
    """
    return LogisticRegression(
        max_iter=500, C=1.0, solver="lbfgs", n_jobs=-1
    ).fit(X, y)


def run_ablation(
    run_dir: Path,
    data_root: str,
    layer: str = "avgpool",
    epoch: int | None = None,
    max_samples: int = 1000,
    device: str = "cuda",
    subspace_method: str = "lda",
    subspace_dim: int = 32,
) -> Dict:
    """Run the M4 shortcut-subspace ablation for one checkpoint of ``run_dir``.

    Steps:
      1. Pick the checkpoint (``_select_epoch``).
      2. Extract ``layer`` features for a seeded subsample of the train
         (``max_samples``) and OOD (``max_samples // 2``) splits.
      3. Standardize them with a scaler fit on train.
      4. Fit hospital and tumor probes on train.
      5. Build the shortcut projector ``P`` with ``subspace_method``.
      6. Compare probe and classifier-head accuracy on raw vs ``h - P h``
         features.

    Returns:
        A dict of scalar metrics. The OOD hospital-probe entries are NaN when
        the OOD subsample contains a single hospital. In that case
        ``hospital_probe_train_ablated`` and
        ``intervention_effect['shortcut_collapse_id']`` still give an
        in-distribution (in-sample) measure of shortcut removal.

    Raises:
        KeyError: if ``layer`` is not among the extracted features.
    """
    epoch, ckpt_path = _select_epoch(run_dir, epoch)
    print(f"\n M4 β Representation Ablation")
    print(f" run_dir : {run_dir.name}")
    print(f" epoch : {epoch} ({ckpt_path.name})")
    print(f" layer : {layer}")

    # Load the model (with feature hooks) and the seeded subsample loaders.
    model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
    model.eval()
    register_hooks(model)
    cfg_path = run_dir / "config.json"
    seed = None
    if cfg_path.exists():
        seed = json.loads(cfg_path.read_text()).get("seed", 42)
    if seed is None:  # no config.json, or an explicit `"seed": null`
        seed = 42
    train_loader, ood_loader = _build_loaders(data_root, max_samples, seed=seed)

    n_ood = max_samples // 2
    print(f" Extracting features ({max_samples} train / {n_ood} OOD samples)...")
    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples
    )
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=n_ood
    )
    if layer not in feats_train:
        raise KeyError(f"Layer '{layer}' not in extracted features "
                       f"({list(feats_train.keys())})")
    hosp_train, tumor_train = np.asarray(hosp_train), np.asarray(tumor_train)
    hosp_ood, tumor_ood = np.asarray(hosp_ood), np.asarray(tumor_ood)

    X_tr = np.asarray(feats_train[layer])    # (N_tr, D)
    X_ood = np.asarray(feats_ood[layer])     # (N_ood, D)
    if X_tr.ndim > 2:  # spatial map; flatten
        X_tr = X_tr.reshape(X_tr.shape[0], -1)
        X_ood = X_ood.reshape(X_ood.shape[0], -1)

    # Probes and the projection work in standardized space (probes are
    # scale-sensitive); the classifier head was trained on raw features, so
    # ablated features are un-scaled before being fed back to it.
    scaler = StandardScaler().fit(X_tr)
    X_tr_n = scaler.transform(X_tr)
    X_ood_n = scaler.transform(X_ood)

    # ── 1. Hospital probe + shortcut subspace
    print(f" Fitting hospital probe on H0/H1/H2 train features...")
    hosp_clf = _fit_lr_probe(X_tr_n, hosp_train)
    hosp_acc_train = hosp_clf.score(X_tr_n, hosp_train)
    W = _build_shortcut_subspace(X_tr_n, hosp_train,
                                 method=subspace_method,
                                 subspace_dim=subspace_dim)
    P = _build_projector(W)  # (D, D)
    rank_subspace = int(np.linalg.matrix_rank(P, tol=1e-8))
    print(f" Shortcut subspace: dim={rank_subspace} method={subspace_method} "
          f"(probe train acc {hosp_acc_train:.3f})")

    # ── 2. Ablated features: remove the projection onto the shortcut subspace
    def ablate_norm(X_n):
        return X_n - X_n @ P.T

    X_tr_ablated_n = ablate_norm(X_tr_n)
    X_ood_ablated_n = ablate_norm(X_ood_n)
    X_ood_ablated = scaler.inverse_transform(X_ood_ablated_n)

    print(f" Re-fitting tumor probe on train features...")
    tumor_clf = _fit_lr_probe(X_tr_n, tumor_train)
    tumor_acc_train = tumor_clf.score(X_tr_n, tumor_train)

    # In-distribution (in-sample) shortcut measure: always defined, unlike the
    # OOD one below, which needs more than one hospital in the OOD subsample.
    hosp_acc_train_ablated = hosp_clf.score(X_tr_ablated_n, hosp_train)
    ood_multi_hospital = len(np.unique(hosp_ood)) > 1
    nan = float("nan")
    hosp_acc_ood_raw = hosp_clf.score(X_ood_n, hosp_ood) if ood_multi_hospital else nan
    hosp_acc_ood_ablated = (
        hosp_clf.score(X_ood_ablated_n, hosp_ood) if ood_multi_hospital else nan
    )
    tumor_acc_ood_raw = tumor_clf.score(X_ood_n, tumor_ood)
    tumor_acc_ood_ablated = tumor_clf.score(X_ood_ablated_n, tumor_ood)

    # ── 3. Head-level OOD classification accuracy
    print(f" Re-classifying OOD with model head (raw vs ablated features)...")
    logits_raw = _classifier_logits_from_features(model, X_ood, layer, device)
    logits_ablated = _classifier_logits_from_features(model, X_ood_ablated, layer, device)
    head_acc_raw = _accuracy(logits_raw, tumor_ood)
    head_acc_ablated = _accuracy(logits_ablated, tumor_ood)

    # ── 4. Pack + report
    result = {
        "run_id": run_dir.name,
        "epoch": epoch,
        "layer": layer,
        "max_samples": max_samples,
        "shortcut_subspace_dim": rank_subspace,
        "hospital_probe_train_acc": hosp_acc_train,
        "hospital_probe_train_ablated": hosp_acc_train_ablated,
        "tumor_probe_train_acc": tumor_acc_train,
        "hospital_probe_ood_raw": hosp_acc_ood_raw,
        "hospital_probe_ood_ablated": hosp_acc_ood_ablated,
        "tumor_probe_ood_raw": tumor_acc_ood_raw,
        "tumor_probe_ood_ablated": tumor_acc_ood_ablated,
        "head_ood_acc_raw": head_acc_raw,
        "head_ood_acc_ablated": head_acc_ablated,
        "intervention_effect": {
            "shortcut_collapse": hosp_acc_ood_raw - hosp_acc_ood_ablated,
            "shortcut_collapse_id": hosp_acc_train - hosp_acc_train_ablated,
            "ood_improvement": head_acc_ablated - head_acc_raw,
            "tumor_preservation": tumor_acc_ood_ablated - tumor_acc_ood_raw,
        },
    }
    effect = result["intervention_effect"]
    print(f"\n RESULTS")
    print(f" hospital probe (ID) : {hosp_acc_train:.3f} → {hosp_acc_train_ablated:.3f} "
          f"(Δ {effect['shortcut_collapse_id']:+.3f})")
    print(f" hospital probe (OOD): {hosp_acc_ood_raw:.3f} β {hosp_acc_ood_ablated:.3f} "
          f"(Ξ {effect['shortcut_collapse']:+.3f})")
    print(f" tumor probe (OOD) : {tumor_acc_ood_raw:.3f} β {tumor_acc_ood_ablated:.3f} "
          f"(Ξ {effect['tumor_preservation']:+.3f})")
    print(f" head OOD acc : {head_acc_raw:.3f} β {head_acc_ablated:.3f} "
          f"(Ξ {effect['ood_improvement']:+.3f})")
    return result
def plot_ablation(result: Dict, out_path: Path):
    """Save a grouped bar chart of raw vs shortcut-ablated OOD accuracies.

    Three metric groups are plotted, read from the ``*_raw`` / ``*_ablated``
    keys of ``result``: hospital probe, tumor probe and classifier-head
    accuracy on the OOD split. Non-finite values are labelled "n/a" at the
    baseline instead of placing a text label at a NaN height. For example,
    the hospital-probe OOD metrics are NaN when the OOD split has a single
    hospital.
    """
    raw_keys = ["hospital_probe_ood_raw", "tumor_probe_ood_raw", "head_ood_acc_raw"]
    ablated_keys = ["hospital_probe_ood_ablated", "tumor_probe_ood_ablated", "head_ood_acc_ablated"]
    labels = ["Hospital probe\n(β = causal effect)",
              "Tumor probe\n(stable = good)",
              "Head OOD acc\n(β = causal effect)"]
    raws = [result[k] for k in raw_keys]
    ablateds = [result[k] for k in ablated_keys]

    fig, ax = plt.subplots(figsize=(9, 5))
    x = np.arange(len(raw_keys))
    w = 0.35
    b1 = ax.bar(x - w / 2, raws, w, label="raw features", color="#444")
    b2 = ax.bar(x + w / 2, ablateds, w, label="shortcut-ablated", color="#c33")
    for bars in (b1, b2):
        for b in bars:
            height = b.get_height()
            cx = b.get_x() + b.get_width() / 2
            if np.isfinite(height):
                ax.text(cx, height + 0.005, f"{height:.3f}",
                        ha="center", va="bottom", fontsize=9)
            else:
                ax.text(cx, 0.005, "n/a", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=9)
    ax.set_ylim(0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.set_title(f"M4 β Causal Ablation of Shortcut Subspace\n"
                 f"{result['run_id']} β’ ep{result['epoch']} β’ layer={result['layer']} "
                 f"β’ subspace dim={result['shortcut_subspace_dim']}",
                 fontsize=10, fontweight="bold")
    ax.legend(loc="upper right")
    ax.grid(alpha=0.3, axis="y")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
def main():
    """Command-line entry point for the M4 ablation.

    Without ``--all_epochs``: ablate one checkpoint and write
    ``<run_dir>/mechinterp/m4_ablation_<layer>_ep<EEEEE>.{json,png}``.
    With ``--all_epochs``: ablate every distinct checkpoint epoch (failing
    epochs are printed and skipped) and write
    ``m4_ablation_<layer>_trajectory.{json,png}``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_dir", required=True)
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--layer", default="avgpool",
                        choices=["avgpool"])  # head-level intervention only at avgpool
    parser.add_argument("--epoch", type=int, default=None,
                        help="Specific checkpoint epoch; default = peak_ood_epoch from summary.json")
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--subspace_method", default="lda",
                        choices=["lda", "probe"],
                        help="lda = LDA-style between-class + hospital-correlated PCs; "
                             "probe = LR probe row-space (small, often only 2-D)")
    parser.add_argument("--subspace_dim", type=int, default=32,
                        help="Target subspace dim for lda method")
    parser.add_argument("--all_epochs", action="store_true",
                        help="Sweep across all periodic checkpoints")
    args = parser.parse_args()

    run_dir = Path(args.run_dir)
    out_dir = run_dir / "mechinterp"
    out_dir.mkdir(parents=True, exist_ok=True)

    common = dict(
        run_dir=run_dir,
        data_root=args.data_root,
        layer=args.layer,
        max_samples=args.max_samples,
        device=args.device,
        subspace_method=args.subspace_method,
        subspace_dim=args.subspace_dim,
    )

    if not args.all_epochs:
        result = run_ablation(epoch=args.epoch, **common)
        stem = out_dir / f"m4_ablation_{args.layer}_ep{result['epoch']:05d}"
        json_path = stem.with_suffix(".json")
        png_path = stem.with_suffix(".png")
        json_path.write_text(json.dumps(result, indent=2))
        plot_ablation(result, png_path)
        print(f"\n β {json_path}")
        print(f" β {png_path}")
        return

    # Sweep every periodic checkpoint. final.pt may share its epoch with the
    # last ep*.pt, so keep only the first checkpoint seen for each epoch.
    epochs = list(dict.fromkeys(ep for ep, _ in find_checkpoints(str(run_dir))))
    traj = []
    for ep in epochs:
        try:
            traj.append(run_ablation(epoch=ep, **common))
        except Exception as e:
            print(f" [skip ep{ep}] {e}")

    traj_json = out_dir / f"m4_ablation_{args.layer}_trajectory.json"
    traj_json.write_text(json.dumps(traj, indent=2))
    plot_trajectory(traj, traj_json.with_suffix(".png"))
    print(f"\n β {traj_json}")
    print(f" β {traj_json.with_suffix('.png')}")
def plot_trajectory(traj, out_path: Path):
    """Save a two-panel figure of the ablation effect across training epochs.

    ``traj`` is a list of ``run_ablation`` result dicts.
    Left panel: classifier-head OOD accuracy on raw vs shortcut-ablated
    features, shaded green where ablation helps and salmon where it hurts.
    Right panel: tumor-probe OOD accuracy on raw vs ablated features.
    """
    def series(key):
        return [r[key] for r in traj]

    eps = series("epoch")
    head_raw, head_abl = series("head_ood_acc_raw"), series("head_ood_acc_ablated")
    tum_raw, tum_abl = series("tumor_probe_ood_raw"), series("tumor_probe_ood_ablated")
    helps = [abl > raw for abl, raw in zip(head_abl, head_raw)]
    hurts = [abl < raw for abl, raw in zip(head_abl, head_raw)]

    fig, (ax_head, ax_tumor) = plt.subplots(1, 2, figsize=(14, 5))

    # Panel A: head OOD accuracy, raw vs ablated
    ax_head.plot(eps, head_raw, "k-o", lw=2, label="raw features")
    ax_head.plot(eps, head_abl, "r-s", lw=2, label="shortcut-ablated features")
    ax_head.fill_between(eps, head_raw, head_abl, where=helps,
                         color="seagreen", alpha=0.3, label="ablation helps")
    ax_head.fill_between(eps, head_raw, head_abl, where=hurts,
                         color="salmon", alpha=0.3, label="ablation hurts")
    ax_head.set_xlabel("Training epoch")
    ax_head.set_ylabel("OOD (H4) head accuracy")
    ax_head.set_title("Head OOD accuracy: raw vs shortcut-ablated", fontweight="bold")
    ax_head.legend(fontsize=9)
    ax_head.grid(alpha=0.3)

    # Panel B: does the tumor signal survive the ablation?
    ax_tumor.plot(eps, tum_raw, "k-o", lw=2, label="raw features")
    ax_tumor.plot(eps, tum_abl, "g-s", lw=2, label="shortcut-ablated features")
    ax_tumor.set_xlabel("Training epoch")
    ax_tumor.set_ylabel("Tumor probe OOD accuracy")
    ax_tumor.set_title("Tumor probe survival under ablation\n(stable line = causal feature preserved)",
                       fontweight="bold")
    ax_tumor.legend(fontsize=9)
    ax_tumor.grid(alpha=0.3)
    ax_tumor.set_ylim(0.4, 1.0)

    if traj:
        rid, layer = traj[0]["run_id"], traj[0]["layer"]
    else:
        rid = layer = "?"
    fig.suptitle(f"M4 β Causal Ablation Trajectory: {rid} β’ layer={layer}",
                 fontsize=11, fontweight="bold")
    plt.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight")
    plt.close(fig)
if __name__ == "__main__":
    # Run the CLI only when executed as a script/module, not when imported.
    main()