CausalGrok / code /experiments /mechinterp_m6_neuron_ablation.py
nileshsarkar-ai's picture
Upload code/experiments
50fa85c verified
"""M6 — Neuron-level Ablation (the textbook reviewer-asked intervention).
Pipeline:
1. At a chosen checkpoint (default: peak_ood_epoch), extract avgpool
features for train (H0-H2) and OOD (H4) splits.
2. Score each of the 512 avgpool channels by *how predictive its activation
is of hospital ID*: fit one multiclass logistic regression on the
standardized training features and use the L2 norm of each channel's
coefficient column as its per-neuron shortcut score:
score_c = ||β_{:,c}||_2 (β = (n_classes × D) LR coefficient matrix)
↑ score_c → channel c is more strongly stain-shortcut-aligned. The same
scoring with tumor labels ranks channels for the morphology control.
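3. Sweep top-K ∈ {0, 4, 8, 16, 32, 64, 128, 256} (default; override with
--ks) ablated neurons — their *standardized* activations are zeroed, i.e.
replaced by the training-set mean — and measure:
- head OOD acc (raw vs ablated)
- hospital-probe acc on raw vs ablated features
- tumor-probe acc on raw vs ablated features
- controls: head acc after ablating K random neurons (averaged over
several random draws) and after ablating the top-K morphology
(tumor-predictive) neurons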
4. Strong mechanistic claim:
- hospital-probe acc collapses sharply with K (these neurons are
carrying hospital info)
- head OOD acc *improves* (or at least preserves) at small K (the
model was using shortcut neurons to harm OOD)
- tumor-probe acc stays flat (causal info is distributed elsewhere)
Usage
-----
python -m experiments.mechinterp_m6_neuron_ablation \\
--run_dir experiments/runs/<id> \\
--data_root data/wilds \\
[--epoch 50] [--max_samples 1000] \\
[--ks "0,4,8,16,32,64,128,256"]
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Dict, List
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from experiments.mechinterp_m1 import (
register_hooks, extract_features, load_model_from_checkpoint,
)
from experiments.mechinterp_m4_ablation import (
_select_epoch, _TransformWrapper,
_classifier_logits_from_features, _accuracy,
)
from utils.camelyon_data import get_camelyon_subsets
def _build_loaders_with_id(data_root: str, max_samples: int, seed: int = 42):
    """Build train / ID-validation / OOD-test loaders over random subsets.

    Variant of M4's ``_build_loaders`` that additionally returns a loader over
    the ID validation split, so head accuracy can be tracked on ID data as
    well as on OOD data.

    Args:
        data_root: Root directory forwarded to ``get_camelyon_subsets``
            (called with ``download=False``, so the data must already exist).
        max_samples: Size of the train subset; the ID and OOD subsets each
            receive ``max_samples // 2`` samples (fewer if a split is smaller).
        seed: Value passed to ``torch.manual_seed`` before drawing the three
            permutations. NOTE: this reseeds torch's *global* RNG.

    Returns:
        ``(train_loader, id_loader, ood_loader)`` — each unshuffled, with
        batch size 128 and ``num_workers=0``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_raw, id_raw, ood_raw, _ = get_camelyon_subsets(
        root_dir=data_root, download=False
    )
    wrapped = [_TransformWrapper(ds, preprocess) for ds in (train_raw, id_raw, ood_raw)]
    budgets = (max_samples, max_samples // 2, max_samples // 2)

    # Seed once, then draw the permutations in a fixed order (train, ID, OOD)
    # so the chosen subsets are reproducible for a given seed.
    torch.manual_seed(seed)
    chosen = [torch.randperm(len(ds))[:budget] for ds, budget in zip(wrapped, budgets)]

    train_loader, id_loader, ood_loader = (
        DataLoader(Subset(ds, picked), batch_size=128, shuffle=False, num_workers=0)
        for ds, picked in zip(wrapped, chosen)
    )
    return train_loader, id_loader, ood_loader
def _per_neuron_shortcut_scores(X_n: np.ndarray, hosp: np.ndarray) -> np.ndarray:
    """Score each feature channel by how strongly a linear probe relies on it.

    Fitting a separate logistic regression per channel and comparing ``|coef|``
    would be confounded by per-channel scale. Instead, a single logistic
    regression is fit on all channels jointly. Channel ``c`` is scored by the
    L2 norm of its coefficient column, ``||W[:, c]||_2``.

    Despite the parameter name, the function is label-agnostic. The caller
    also uses it with tumor labels to obtain morphology scores.

    Args:
        X_n: ``(N, D)`` feature matrix. The caller passes ``StandardScaler``
            output, so coefficient magnitudes are comparable across channels.
        hosp: ``(N,)`` class labels (hospital IDs or tumor labels). These must
            contain at least two distinct classes for the fit to succeed.

    Returns:
        ``(D,)`` array of non-negative scores; larger = more label-predictive.
    """
    # ``multi_class`` is left at its default ("auto": multinomial for
    # multiclass labels with lbfgs, a single coefficient row for binary ones).
    # Passing it explicitly is deprecated since scikit-learn 1.5.
    clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
                             n_jobs=-1).fit(X_n, hosp)
    W = clf.coef_  # (n_classes, D), or (1, D) for binary labels
    # Column norm: large when many class logits depend on this channel.
    return np.linalg.norm(W, axis=0)  # (D,)
def _ablate_and_eval(
    X_n, mask, scaler, head_target, model, layer, device,
    hosp_clf, tumor_clf, hosp_target, tumor_target,
):
    """Ablate channels in standardized space, then score the head and probes.

    ``X_n`` is multiplied column-wise by ``mask`` (a length-D vector; 0 =
    ablate). With a ``StandardScaler``, as the caller uses, zeroing a
    standardized column and calling ``inverse_transform`` sets that channel to
    the scaler's fitted mean. The effect is mean-ablation in raw feature space.

    Returns:
        ``(head_acc, hosp_acc, tumor_acc)``.

        * ``head_acc`` comes from the model head applied to the un-scaled
          ablated features.
        * The two probe accuracies are computed on the standardized ablated
          features.
        * ``hosp_acc`` is NaN when ``hosp_clf`` is None or ``hosp_target``
          contains a single class.
    """
    ablated_std = X_n * mask[None, :]
    # The head presumably expects raw-scale features (hence the inverse
    # transform); the probes are scored on the standardized features.
    ablated_raw = scaler.inverse_transform(ablated_std)
    head_logits = _classifier_logits_from_features(model, ablated_raw, layer, device)
    head_acc = _accuracy(head_logits, head_target)

    hosp_scorable = hosp_clf is not None and len(np.unique(hosp_target)) > 1
    if hosp_scorable:
        hosp_acc = hosp_clf.score(ablated_std, hosp_target)
    else:
        hosp_acc = float("nan")

    tumor_acc = tumor_clf.score(ablated_std, tumor_target)
    return head_acc, hosp_acc, tumor_acc
def run_neuron_ablation(
run_dir: Path,
data_root: str,
epoch: int | None = None,
max_samples: int = 1000,
device: str = "cuda",
ks: List[int] = None,
n_random_samples: int = 5,
include_morphology: bool = True,
include_id: bool = True,
) -> Dict:
if ks is None:
# Dose-response curve emphasizing small K (per reviewer guidance)
ks = [0, 4, 8, 16, 32, 64, 128, 256]
epoch, ckpt_path = _select_epoch(run_dir, epoch)
print(f"\n M6 — Neuron Ablation (with random + morphology controls)")
print(f" run_dir : {run_dir.name}")
print(f" epoch : {epoch} ({ckpt_path.name})")
print(f" ks : {ks}")
print(f" random ablation: {n_random_samples} samplings per K")
model = load_model_from_checkpoint(str(ckpt_path), n_classes=2, device=device)
model.eval()
register_hooks(model)
cfg_path = run_dir / "config.json"
seed = 42
if cfg_path.exists():
seed = json.loads(cfg_path.read_text()).get("seed", 42)
if include_id:
train_loader, id_loader, ood_loader = _build_loaders_with_id(data_root, max_samples, seed=seed)
else:
from experiments.mechinterp_m4_ablation import _build_loaders as _bl
train_loader, ood_loader = _bl(data_root, max_samples, seed=seed)
id_loader = None
print(f" Extracting features (train + id + ood)...")
feats_train, hosp_train, tumor_train = extract_features(
model, train_loader, device, max_samples=max_samples
)
feats_ood, hosp_ood, tumor_ood = extract_features(
model, ood_loader, device, max_samples=max_samples // 2
)
feats_id, hosp_id, tumor_id = (None, None, None)
if id_loader is not None:
feats_id, hosp_id, tumor_id = extract_features(
model, id_loader, device, max_samples=max_samples // 2
)
layer = "avgpool"
def _to_2d(arr):
a = np.asarray(arr); return a.reshape(a.shape[0], -1)
X_tr = _to_2d(feats_train[layer])
X_ood = _to_2d(feats_ood[layer])
X_id = _to_2d(feats_id[layer]) if feats_id is not None else None
hosp_train = np.asarray(hosp_train); hosp_ood = np.asarray(hosp_ood)
tumor_train = np.asarray(tumor_train); tumor_ood = np.asarray(tumor_ood)
if X_id is not None:
hosp_id = np.asarray(hosp_id); tumor_id = np.asarray(tumor_id)
scaler = StandardScaler().fit(X_tr)
X_tr_n = scaler.transform(X_tr)
X_ood_n = scaler.transform(X_ood)
X_id_n = scaler.transform(X_id) if X_id is not None else None
# 1. Per-neuron scores: shortcut (hospital) and morphology (tumor)
print(f" Scoring {X_tr.shape[1]} avgpool channels...")
shortcut_scores = _per_neuron_shortcut_scores(X_tr_n, hosp_train)
morphology_scores = _per_neuron_shortcut_scores(X_tr_n, tumor_train) if include_morphology else None
rank_shortcut = np.argsort(-shortcut_scores)
rank_morphology = np.argsort(-morphology_scores) if morphology_scores is not None else None
hosp_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, hosp_train)
tumor_clf = LogisticRegression(max_iter=500, C=1.0, solver="lbfgs",
multi_class="auto", n_jobs=-1).fit(X_tr_n, tumor_train)
rng = np.random.default_rng(seed)
D = X_tr.shape[1]
sweep = []
for k in ks:
row = {"k": int(k)}
# Mask helpers
def make_mask(indices):
m = np.ones(D)
if k > 0:
m[indices[:k]] = 0.0
return m
# ── A: top-K SHORTCUT neurons (the targeted ablation) ──
mask_s = make_mask(rank_shortcut)
h_ood, hp_ood, tp_ood = _ablate_and_eval(
X_ood_n, mask_s, scaler, tumor_ood, model, layer, device,
hosp_clf, tumor_clf, hosp_ood, tumor_ood,
)
row["shortcut_head_ood"] = float(h_ood)
row["shortcut_hosp_probe"] = float(hp_ood)
row["shortcut_tumor_probe"] = float(tp_ood)
if X_id_n is not None:
h_id, _, _ = _ablate_and_eval(
X_id_n, mask_s, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
row["shortcut_head_id"] = float(h_id)
# ── B: top-K MORPHOLOGY neurons (control: ablate the causal neurons) ──
if include_morphology and rank_morphology is not None:
mask_m = make_mask(rank_morphology)
h_ood_m, _, _ = _ablate_and_eval(
X_ood_n, mask_m, scaler, tumor_ood, model, layer, device,
None, tumor_clf, hosp_ood, tumor_ood,
)
row["morphology_head_ood"] = float(h_ood_m)
if X_id_n is not None:
h_id_m, _, _ = _ablate_and_eval(
X_id_n, mask_m, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
row["morphology_head_id"] = float(h_id_m)
# ── C: K RANDOM neurons (control: damage uniformly) ──
if k > 0:
r_oods, r_ids = [], []
for s_ in range(n_random_samples):
idx = rng.permutation(D)[:k]
m = np.ones(D); m[idx] = 0.0
h_ood_r, _, _ = _ablate_and_eval(
X_ood_n, m, scaler, tumor_ood, model, layer, device,
None, tumor_clf, hosp_ood, tumor_ood,
)
r_oods.append(h_ood_r)
if X_id_n is not None:
h_id_r, _, _ = _ablate_and_eval(
X_id_n, m, scaler, tumor_id, model, layer, device,
None, tumor_clf, hosp_id, tumor_id,
)
r_ids.append(h_id_r)
row["random_head_ood_mean"] = float(np.mean(r_oods))
row["random_head_ood_std"] = float(np.std(r_oods))
if r_ids:
row["random_head_id_mean"] = float(np.mean(r_ids))
row["random_head_id_std"] = float(np.std(r_ids))
else:
row["random_head_ood_mean"] = row["shortcut_head_ood"] # K=0 same as baseline
row["random_head_ood_std"] = 0.0
if X_id_n is not None:
row["random_head_id_mean"] = row.get("shortcut_head_id", float("nan"))
row["random_head_id_std"] = 0.0
sweep.append(row)
# Concise log line
print(f" K={k:>4} shortcut={row['shortcut_head_ood']:.3f} "
f"random={row.get('random_head_ood_mean', float('nan')):.3f}±"
f"{row.get('random_head_ood_std', 0):.3f} "
+ (f"morphology={row.get('morphology_head_ood', float('nan')):.3f}"
if include_morphology else ""))
return {
"run_id": run_dir.name,
"epoch": epoch,
"layer": layer,
"max_samples": max_samples,
"feature_dim": int(X_tr.shape[1]),
"shortcut_scores_top10": [int(i) for i in rank_shortcut[:10]],
"morphology_scores_top10": ([int(i) for i in rank_morphology[:10]]
if rank_morphology is not None else []),
"n_random_samples": n_random_samples,
"include_id": include_id,
"include_morphology": include_morphology,
"sweep": sweep,
}
def plot_neuron_ablation(result: Dict, out_path: Path):
sweep = result["sweep"]
ks = [r["k"] for r in sweep]
has_id = result.get("include_id", False)
has_morph = result.get("include_morphology", False)
fig, axes = plt.subplots(1, 2 if has_id else 1, figsize=(13, 5)) if has_id else \
plt.subplots(1, 1, figsize=(8, 5))
if not has_id:
axes = [axes]
# Panel A — Head OOD: shortcut vs random (vs morphology)
ax = axes[0]
shortcut_ood = [r.get("shortcut_head_ood") for r in sweep]
random_ood_mu = [r.get("random_head_ood_mean") for r in sweep]
random_ood_sd = [r.get("random_head_ood_std", 0) for r in sweep]
morphology_ood = [r.get("morphology_head_ood") for r in sweep] if has_morph else None
ax.plot(ks, shortcut_ood, "r-o", lw=2.2, ms=7, label="top-K shortcut neurons (targeted)")
ax.plot(ks, random_ood_mu, "k-s", lw=1.8, ms=6, label="K random neurons (control)")
ax.fill_between(ks,
[m - s for m, s in zip(random_ood_mu, random_ood_sd)],
[m + s for m, s in zip(random_ood_mu, random_ood_sd)],
color="black", alpha=0.15)
if has_morph and morphology_ood is not None:
ax.plot(ks, morphology_ood, "g-^", lw=1.8, ms=6,
label="top-K morphology neurons (control)")
base = shortcut_ood[0]
ax.axhline(base, color="gray", ls=":", lw=1, alpha=0.5,
label=f"K=0 baseline ({base:.3f})")
ax.set_xlabel("K (neurons zeroed at avgpool)")
ax.set_ylabel("Head OOD (H4) accuracy")
ax.set_xscale("symlog", linthresh=4)
ax.set_title("Targeted vs random ablation — OOD effect\n"
"(separation = shortcut neurons selectively hurt OOD)",
fontweight="bold", fontsize=10)
ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)
# Panel B — ID/OOD tradeoff
if has_id:
ax = axes[1]
shortcut_id = [r.get("shortcut_head_id") for r in sweep]
random_id_mu = [r.get("random_head_id_mean") for r in sweep]
random_id_sd = [r.get("random_head_id_std", 0) for r in sweep]
ax.plot(ks, shortcut_id, "r--o", lw=2, ms=7, alpha=0.85, label="ID (shortcut ablation)")
ax.plot(ks, shortcut_ood, "r-o", lw=2, ms=7, label="OOD (shortcut ablation)")
ax.plot(ks, random_id_mu, "k--s", lw=1.6, ms=5, alpha=0.7, label="ID (random ablation)")
ax.plot(ks, random_ood_mu, "k-s", lw=1.6, ms=5, alpha=0.7, label="OOD (random ablation)")
ax.set_xlabel("K (neurons zeroed at avgpool)")
ax.set_ylabel("Head accuracy")
ax.set_xscale("symlog", linthresh=4)
ax.set_title("ID vs OOD degradation tradeoff\n"
"(targeted: OOD steady or ↑ while ID slowly ↓ = good)",
fontweight="bold", fontsize=10)
ax.legend(fontsize=8, loc="best"); ax.grid(alpha=0.3)
fig.suptitle(f"M6 — Targeted Neuron Ablation vs Random Control: {result['run_id']} "
f"• ep{result['epoch']}",
fontsize=11, fontweight="bold")
plt.tight_layout()
fig.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)
def main():
p = argparse.ArgumentParser()
p.add_argument("--run_dir", required=True)
p.add_argument("--data_root", default="data/wilds")
p.add_argument("--epoch", type=int, default=None)
p.add_argument("--max_samples", type=int, default=1000)
p.add_argument("--device", default="cuda")
p.add_argument("--ks", default=None,
help="Comma-separated K values, e.g. '0,4,8,16,32,64,128,256'")
p.add_argument("--n_random_samples", type=int, default=5,
help="Random ablation: averages over this many random K-subsets")
p.add_argument("--no_morphology", action="store_true",
help="Skip the morphology-targeted ablation control")
p.add_argument("--no_id", action="store_true",
help="Skip ID accuracy evaluation (faster but loses ID/OOD ratio)")
args = p.parse_args()
ks = None
if args.ks is not None:
ks = [int(x) for x in args.ks.split(",")]
run_dir = Path(args.run_dir)
out_dir = run_dir / "mechinterp"
out_dir.mkdir(parents=True, exist_ok=True)
result = run_neuron_ablation(
run_dir=run_dir, data_root=args.data_root, epoch=args.epoch,
max_samples=args.max_samples, device=args.device, ks=ks,
n_random_samples=args.n_random_samples,
include_morphology=not args.no_morphology,
include_id=not args.no_id,
)
base = out_dir / f"m6_neuron_ablation_ep{result['epoch']:05d}"
base.with_suffix(".json").write_text(json.dumps(result, indent=2))
plot_neuron_ablation(result, base.with_suffix(".png"))
print(f"\n → {base.with_suffix('.json')}")
print(f" → {base.with_suffix('.png')}")
if __name__ == "__main__":
main()