| """ |
| CausalGrok — M1: Layer-wise Linear Probing |
| Nilesh |
| |
| The mechanistic claim: |
| Before grokking: hospital probe HIGH (model uses stain shortcut), |
| tumor probe LOW in early layers |
| At transition: hospital probe DROPS, tumor probe RISES |
| After grokking: inverted — tumor high, hospital low |
| |
| If OOD acc jump + hospital probe drop + tumor probe rise |
| all happen at the same epoch → mechanistic claim proven. |
| That's Figure 2 of the paper. |
| |
| Usage: |
| # Run on all saved checkpoints from a run |
| python -m experiments.mechinterp_m1 \ |
| --run_dir experiments/runs/<run_id> \ |
| --data_root data/wilds |
| |
| # Run on latest checkpoint only (quick check while training) |
| python -m experiments.mechinterp_m1 \ |
| --run_dir experiments/runs/<run_id> \ |
| --data_root data/wilds \ |
| --latest_only |
| |
| # Run on ALL camelyon_v2 grokking runs |
| python -m experiments.mechinterp_m1 \ |
| --all_runs \ |
| --data_root data/wilds |
| |
| Output per run: |
| experiments/runs/<run_id>/mechinterp/ |
| m1_probe_heatmap.png — epoch × layer heatmaps: hospital probe acc (left panel), tumor probe acc (right panel) |
| m1_probe_curves.png — hospital vs tumor probe over epochs (avgpool and layer2) |
| m1_probe_data.json — raw numbers for paper tables |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import json |
| import os |
| from typing import Dict, List, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Subset |
| from sklearn.linear_model import LogisticRegression |
| from sklearn.preprocessing import StandardScaler |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import timm |
| import warnings |
# Suppress every warning category process-wide so the per-layer probe log
# printed by this script is not interleaved with library warnings.
warnings.filterwarnings("ignore")

# Plot defaults applied to every figure this module creates.
matplotlib.rcParams["font.size"] = 11
matplotlib.rcParams["figure.dpi"] = 150
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Names of the six feature-extraction points, in the order they are probed.
# This order defines the layer axis of every saved probe matrix and plot.
LAYER_NAMES = ["stem", "layer1", "layer2", "layer3", "layer4", "avgpool"]
|
|
|
|
def register_hooks(model):
    """
    Attach forward hooks that record activations at the six probe points.

    Every hook appends one 2-D CPU tensor per forward pass: outputs with four
    dimensions are averaged over dimensions 2 and 3, any other output is
    flattened to ``(output.size(0), -1)``.

    Args:
        model: network exposing ``maxpool``, ``layer1`` .. ``layer4`` and
            ``global_pool`` sub-modules.

    Returns:
        ``(hooks, features)`` - the list of hook handles (call ``.remove()``
        on each when finished) and a dict mapping every name in
        ``LAYER_NAMES`` to the list of tensors recorded so far.

    Raises:
        AttributeError: if the model lacks one of the expected sub-modules.
            Hooks registered before the failure are removed first so the
            model is left untouched.
    """
    # (probe-point name, attribute on the model whose output is recorded)
    hook_points = (
        ("stem", "maxpool"),
        ("layer1", "layer1"),
        ("layer2", "layer2"),
        ("layer3", "layer3"),
        ("layer4", "layer4"),
        ("avgpool", "global_pool"),
    )

    features = {name: [] for name in LAYER_NAMES}
    hooks = []

    def make_hook(name):
        def hook_fn(module, inputs, output):
            if output.dim() == 4:
                feat = output.mean(dim=[2, 3])
            else:
                # reshape (not view) so non-contiguous outputs are accepted.
                feat = output.reshape(output.size(0), -1)
            features[name].append(feat.detach().cpu())
        return hook_fn

    try:
        for name, attr in hook_points:
            module = getattr(model, attr)
            hooks.append(module.register_forward_hook(make_hook(name)))
    except Exception:
        # Do not leave a half-instrumented model behind.
        for handle in hooks:
            handle.remove()
        raise

    return hooks, features
|
|
|
|
def extract_features(model, loader, device, max_samples=1000):
    """
    Run the model over ``loader`` and collect features at all six layers.

    Each batch is read as ``(images, labels, metadata)``; the hospital id is
    taken from ``metadata[:, 0]``. Iteration stops once at least
    ``max_samples`` images have been seen, and every returned array is then
    truncated to at most ``max_samples`` rows.

    Args:
        model: network accepted by ``register_hooks``; switched to eval mode.
        loader: iterable of ``(images, labels, metadata)`` batches.
        device: device the images are moved to before the forward pass.
        max_samples: upper bound on the number of samples returned.

    Returns:
        ``(features, hospital_ids, tumor_labels)`` where ``features`` maps
        each layer name to a NumPy array with one row per sample and the
        other two are 1-D NumPy integer arrays.

    Raises:
        RuntimeError: if the loader yields no batches.
    """
    model.eval()
    hooks, feat_dict = register_hooks(model)

    all_hospital = []
    all_tumor = []
    count = 0

    try:
        with torch.no_grad():
            for batch in loader:
                imgs = batch[0].to(device)
                # reshape(-1) instead of squeeze(): squeeze() collapses a
                # batch of size 1 to a 0-dim tensor, which torch.cat rejects.
                # Assumes one scalar label per sample - TODO confirm.
                labels = batch[1].reshape(-1).long()
                metadata = batch[2]
                model(imgs)
                all_hospital.append(metadata[:, 0].long())
                all_tumor.append(labels)
                count += imgs.size(0)
                if count >= max_samples:
                    break
    finally:
        # Always detach the hooks, even if the forward pass raised, so the
        # model does not keep recording into a discarded dict.
        for handle in hooks:
            handle.remove()

    if not all_hospital:
        raise RuntimeError("extract_features: loader yielded no batches")

    features = {k: torch.cat(v).numpy() for k, v in feat_dict.items()}
    hospital_ids = torch.cat(all_hospital).numpy()
    tumor_labels = torch.cat(all_tumor).numpy()

    # The last batch may overshoot max_samples; trim everything consistently.
    n = min(max_samples, len(hospital_ids))
    features = {k: v[:n] for k, v in features.items()}
    hospital_ids = hospital_ids[:n]
    tumor_labels = tumor_labels[:n]

    return features, hospital_ids, tumor_labels
|
|
|
|
def train_probe(X_train, y_train, X_val, y_val):
    """
    Fit a logistic-regression probe on frozen features and score it.

    Features are standardised with statistics computed on ``X_train`` only
    and the same transform is applied to ``X_val``.

    Args:
        X_train: training feature matrix, one row per sample.
        y_train: training labels.
        X_val: evaluation feature matrix.
        y_val: evaluation labels.

    Returns:
        Accuracy on ``(X_val, y_val)``; ``0.5`` when ``y_train`` contains
        fewer than two distinct labels; ``nan`` when scaling, fitting or
        scoring fails (the failure is printed, not raised).
    """
    if len(np.unique(y_train)) < 2:
        return 0.5

    try:
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)

        # `multi_class` is deliberately not passed: "auto" was its default,
        # and the keyword is deprecated in recent scikit-learn releases, so
        # passing it would eventually raise at construction time.
        clf = LogisticRegression(
            max_iter=500,
            C=1.0,
            solver="lbfgs",
            n_jobs=-1,
        )
        clf.fit(X_train, y_train)
        return clf.score(X_val, y_val)
    except Exception as exc:
        # Best-effort by design: one bad layer must not abort the whole run,
        # but the reason is reported instead of being silently dropped.
        print(f" probe failed ({type(exc).__name__}): {exc}")
        return float("nan")
|
|
|
|
| |
| |
| |
|
|
def find_checkpoints(run_dir: str) -> List[tuple]:
    """
    Find all checkpoints in a run directory.

    Looks in ``<run_dir>/checkpoints`` for files named ``ep<digits>.pt`` and
    for ``final.pt``. The epoch of ``final.pt`` is the ``"epoch"`` of the last
    record in ``<run_dir>/results/history.json``; when that file is missing,
    empty or unreadable the placeholder epoch ``9999`` is used.

    Args:
        run_dir: path of the run directory.

    Returns:
        List of ``(epoch, checkpoint_path)`` tuples sorted by epoch; empty if
        the checkpoints directory does not exist.
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    if not os.path.isdir(ckpt_dir):
        return []

    checkpoints = []

    prefix, suffix = "ep", ".pt"
    for path in sorted(glob.glob(os.path.join(ckpt_dir, "ep*.pt"))):
        name = os.path.basename(path)
        # Strip only the leading "ep" and trailing ".pt"; removing every
        # occurrence would misread names such as "ep1ep2.pt" as epoch 12.
        digits = name[len(prefix):-len(suffix)]
        if digits.isdecimal():
            checkpoints.append((int(digits), path))

    final = os.path.join(ckpt_dir, "final.pt")
    if os.path.isfile(final):
        epoch = 9999  # sorts last when the real final epoch is unknown
        hist_path = os.path.join(run_dir, "results", "history.json")
        if os.path.isfile(hist_path):
            try:
                with open(hist_path) as fh:
                    hist = json.load(fh)
                if hist:
                    epoch = int(hist[-1]["epoch"])
            except (OSError, ValueError, KeyError, IndexError, TypeError):
                epoch = 9999
        checkpoints.append((epoch, final))

    return sorted(checkpoints, key=lambda item: item[0])
|
|
|
|
def load_model_from_checkpoint(ckpt_path: str, n_classes: int = 2,
                               device: str = "cuda") -> nn.Module:
    """
    Build a timm ``resnet18`` and restore its weights from a checkpoint.

    The object stored in the checkpoint file is passed directly to
    ``load_state_dict(..., strict=True)``, so it must be a state dict whose
    keys match the freshly created model exactly.

    Args:
        ckpt_path: path of the checkpoint file.
        n_classes: number of output classes of the classifier head.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The model in eval mode on ``device``.
    """
    net = timm.create_model(
        "resnet18",
        pretrained=False,
        num_classes=n_classes,
    )
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from a trusted source; consider weights_only=True where the installed
    # PyTorch version supports it.
    weights = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(weights, strict=True)
    net.eval()
    net = net.to(device)
    return net
|
|
|
|
| |
| |
| |
|
|
class _TransformWrapper:
    """Dataset view that applies ``transform`` to the image of every sample."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label, metadata = self.dataset[idx]
        return self.transform(img), label, metadata


def _build_probe_loaders(data_root, seed, max_samples):
    """Return (train, id_val, ood_test) loaders over seeded random subsets."""
    # Imported lazily so run directories that are skipped early do not need
    # the project dataset module at all.
    from utils.camelyon_data import get_camelyon_subsets

    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_ds, id_val_ds, ood_test_ds, _full_ds = get_camelyon_subsets(
        root_dir=data_root, download=False)

    id_val_t = _TransformWrapper(id_val_ds, transform)
    ood_test_t = _TransformWrapper(ood_test_ds, transform)
    train_t = _TransformWrapper(train_ds, transform)

    # The three permutations are drawn in this fixed order after seeding so
    # the same samples are probed at every checkpoint and on every rerun.
    torch.manual_seed(seed)
    probe_idx = torch.randperm(len(id_val_t))[:max_samples // 2]
    ood_idx = torch.randperm(len(ood_test_t))[:max_samples // 2]
    train_idx = torch.randperm(len(train_t))[:max_samples]

    def make_loader(dataset, indices):
        return DataLoader(Subset(dataset, indices),
                          batch_size=128, shuffle=False, num_workers=0)

    return (make_loader(train_t, train_idx),
            make_loader(id_val_t, probe_idx),
            make_loader(ood_test_t, ood_idx))


def _probe_all_layers(model, loaders, device, max_samples):
    """Fit hospital and tumor probes at every layer for one checkpoint."""
    train_loader, probe_loader, ood_loader = loaders

    feats_train, hosp_train, tumor_train = extract_features(
        model, train_loader, device, max_samples=max_samples)
    feats_id, hosp_id, tumor_id = extract_features(
        model, probe_loader, device, max_samples=max_samples // 2)
    feats_ood, hosp_ood, tumor_ood = extract_features(
        model, ood_loader, device, max_samples=max_samples // 2)

    rows = {
        "hospital_probe_id": [],
        "hospital_probe_ood": [],
        "tumor_probe_id": [],
        "tumor_probe_ood": [],
    }

    for layer_name in LAYER_NAMES:
        X_train_layer = feats_train[layer_name]
        X_id_layer = feats_id[layer_name]
        X_ood_layer = feats_ood[layer_name]

        # Probes are always fit on the training subset and scored on the
        # id-val and ood-test subsets separately.
        h_acc_id = train_probe(X_train_layer, hosp_train,
                               X_id_layer, hosp_id)
        h_acc_ood = train_probe(X_train_layer, hosp_train,
                                X_ood_layer, hosp_ood)
        t_acc_id = train_probe(X_train_layer, tumor_train,
                               X_id_layer, tumor_id)
        t_acc_ood = train_probe(X_train_layer, tumor_train,
                                X_ood_layer, tumor_ood)

        rows["hospital_probe_id"].append(h_acc_id)
        rows["hospital_probe_ood"].append(h_acc_ood)
        rows["tumor_probe_id"].append(t_acc_id)
        rows["tumor_probe_ood"].append(t_acc_ood)

        print(f" {layer_name:8s}: "
              f"hosp_H3={h_acc_id:.3f} hosp_H4={h_acc_ood:.3f} "
              f"tumor_H3={t_acc_id:.3f} tumor_H4={t_acc_ood:.3f}")

    return rows


def run_probe_analysis(run_dir: str, data_root: str,
                       device: str = "cuda",
                       max_samples: int = 800,
                       latest_only: bool = False) -> Optional[Dict]:
    """
    For each checkpoint in a run, extract features at all 6 layers
    and train hospital + tumor probes.

    Writes ``m1_probe_data.json`` and the two figures into
    ``<run_dir>/mechinterp/``. Checkpoints that fail to load are reported
    and skipped.

    Args:
        run_dir: run directory containing ``config.json`` and checkpoints.
        data_root: dataset root passed to ``get_camelyon_subsets``.
        device: device used for model inference.
        max_samples: number of training samples used to fit the probes; the
            id-val and ood-test evaluation subsets use half of it each.
        latest_only: analyse only the checkpoint with the highest epoch.

    Returns:
        The results dict (run metadata, ``epochs``, ``layers`` and four
        epoch x layer accuracy tables), or ``None`` when the run has no
        ``config.json`` or no checkpoints.
    """
    cfg_path = os.path.join(run_dir, "config.json")
    if not os.path.isfile(cfg_path):
        print(f" No config.json in {run_dir}, skipping")
        return None

    with open(cfg_path) as fh:
        cfg = json.load(fh)
    condition = cfg.get("condition", "unknown")
    n_train = cfg.get("n_train", 300)
    seed = cfg.get("seed", 42)

    # normpath first: a trailing slash would otherwise give an empty run id.
    run_id = os.path.basename(os.path.normpath(run_dir))

    print(f"\n{'='*55}")
    print(f" M1 Probe Analysis: {run_id}")
    print(f" condition={condition}, n_train={n_train}, seed={seed}")
    print(f"{'='*55}")

    checkpoints = find_checkpoints(run_dir)
    if not checkpoints:
        print(f" No checkpoints found β skipping")
        return None

    if latest_only:
        checkpoints = checkpoints[-1:]

    print(f" Found {len(checkpoints)} checkpoints: "
          f"epochs {[e for e,_ in checkpoints]}")
    print(f" Hospital probe: fits on training data (H0-H2), "
          f"evaluates on H3 and H4 separately")

    loaders = _build_probe_loaders(data_root, seed, max_samples)

    results = {
        "run_id": run_id,
        "condition": condition,
        "n_train": n_train,
        "seed": seed,
        "epochs": [],
        "layers": LAYER_NAMES,
        "hospital_probe_id": [],
        "hospital_probe_ood": [],
        "tumor_probe_id": [],
        "tumor_probe_ood": [],
    }

    for epoch, ckpt_path in checkpoints:
        print(f"\n Epoch {epoch} | {os.path.basename(ckpt_path)}")

        try:
            model = load_model_from_checkpoint(
                ckpt_path, n_classes=2, device=device)
        except Exception as e:
            print(f" Failed to load checkpoint: {e}")
            continue

        rows = _probe_all_layers(model, loaders, device, max_samples)

        results["epochs"].append(epoch)
        for key, values in rows.items():
            results[key].append(values)

        # Drop the reference before the next checkpoint is loaded.
        del model

    out_dir = os.path.join(run_dir, "mechinterp")
    os.makedirs(out_dir, exist_ok=True)

    data_path = os.path.join(out_dir, "m1_probe_data.json")
    with open(data_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n Probe data β {data_path}")

    _plot_probe_heatmaps(results, out_dir)
    _plot_probe_curves(results, out_dir)

    print(f" Figures β {out_dir}/")
    return results
|
|
|
|
def _plot_probe_heatmaps(results: Dict, out_dir: str):
    """
    Save a two-panel heatmap: epoch on x, layer on y, colour = probe accuracy.

    Left panel: hospital probe scored on H3 (``hospital_probe_id``). The H4
    hospital probe is not drawn because it is degenerate by construction:
    the probe is fit on the training-hospital class set, which does not
    contain H4.

    Right panel: tumor probe scored on H4 (``tumor_probe_ood``). Tumor labels
    are binary and shared across hospitals, so this panel shows how well the
    causal feature transfers to the unseen hospital.

    Nothing is written when ``results["epochs"]`` is empty. The figure is
    saved as ``m1_probe_heatmap.png`` inside ``out_dir``.
    """
    epoch_ticks = results["epochs"]
    layer_ticks = results["layers"]

    if not epoch_ticks:
        return

    panels = [
        (np.array(results["hospital_probe_id"]),
         "Hospital probe on H3 (shortcut recoverability)\nHigh = stain still encoded = BAD",
         "Reds"),
        (np.array(results["tumor_probe_ood"]),
         "Tumor probe on H4 (causal, OOD)\nHigh = causal feature transfers = GOOD",
         "Greens"),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, (accuracy, panel_title, colormap) in zip(axes, panels):
        # Transposed so layers run along y and epochs along x.
        image = ax.imshow(
            accuracy.T,
            aspect="auto",
            cmap=colormap,
            vmin=0.0, vmax=1.0,
            interpolation="nearest",
            origin="lower",
        )
        ax.set_xticks(range(len(epoch_ticks)))
        ax.set_xticklabels(epoch_ticks, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(range(len(layer_ticks)))
        ax.set_yticklabels(layer_ticks, fontsize=9)
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("ResNet layer")
        ax.set_title(panel_title, fontsize=10, fontweight="bold")
        fig.colorbar(image, ax=ax, label="Probe accuracy")

    heading = (
        f"M1 β Layer-wise Linear Probing: {results['run_id']}\n"
        "Circuit signature: deep-layer hospital-probe drop (Reds) + sustained tumor recoverability (Greens)"
    )
    fig.suptitle(heading, fontsize=10, y=1.02)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "m1_probe_heatmap.png"),
                bbox_inches="tight")
    plt.close(fig)
|
|
|
|
def _load_ood_history(hist_path):
    """Return (epochs, ood_accs) from history.json, or None if unusable."""
    if not os.path.isfile(hist_path):
        return None
    try:
        with open(hist_path) as fh:
            hist = json.load(fh)
        hist_eps = [r["epoch"] for r in hist]
        ood_accs = [r.get("ood_acc", float("nan")) for r in hist]
    except (OSError, ValueError, KeyError, TypeError, AttributeError):
        # The overlay is optional: a missing or malformed history must not
        # prevent the probe curves from being drawn.
        return None
    # JSON null would not plot as a number; treat it like a missing value.
    ood_accs = [float("nan") if v is None else v for v in ood_accs]
    return hist_eps, ood_accs


def _plot_probe_curves(results: Dict, out_dir: str):
    """
    Save line plots of probe accuracy over epochs for two layers.

    One panel is drawn for ``avgpool`` and one for ``layer2``. Each shows the
    hospital probe scored on H3 (``hospital_probe_id``) and the tumor probe
    scored on H4 (``tumor_probe_ood``). When ``results/history.json`` exists
    in the run directory (the parent of ``out_dir``) and can be read, its
    ``ood_acc`` values are overlaid on both panels.

    The figure is saved as ``m1_probe_curves.png`` inside ``out_dir``.
    """
    epochs = results["epochs"]
    run_dir = os.path.join(out_dir, "..")
    layers = results["layers"]
    avgpool_idx = layers.index("avgpool")
    layer2_idx = layers.index("layer2")

    # Read the training history once; it is identical for both panels.
    ood_history = _load_ood_history(
        os.path.join(run_dir, "results", "history.json"))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    for ax, layer_idx, layer_label in [
        (axes[0], avgpool_idx, "avgpool (penultimate)"),
        (axes[1], layer2_idx, "layer2 (early)"),
    ]:
        hosp = [row[layer_idx] for row in results["hospital_probe_id"]]
        tumor = [row[layer_idx] for row in results["tumor_probe_ood"]]

        ax.plot(epochs, hosp, "r-o", markersize=4, lw=2,
                label="Hospital probe on H3 (shortcut recoverability β want)")
        ax.plot(epochs, tumor, "g-s", markersize=4, lw=2,
                label="Tumor probe on H4 (causal transfer β want)")

        if ood_history is not None:
            hist_eps, ood_accs = ood_history
            ax.plot(hist_eps, ood_accs, "b--", lw=1.5, alpha=0.7,
                    label="OOD accuracy (H4)")

        ax.axhline(0.5, color="gray", ls=":", lw=1, alpha=0.5,
                   label="Chance (0.5)")
        ax.set_xlabel("Training epoch")
        ax.set_ylabel("Probe / OOD accuracy")
        ax.set_title(f"Layer: {layer_label}", fontweight="bold")
        ax.legend(fontsize=9)
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

    fig.suptitle(
        f"M1 β Probe Curves: {results['run_id']}\n"
        "Hospital recoverability (H3) drops in deep layers + tumor transfers (H4) β circuit signature",
        fontsize=10, y=1.02
    )
    plt.tight_layout()
    out = os.path.join(out_dir, "m1_probe_curves.png")
    plt.savefig(out, bbox_inches="tight")
    plt.close()
|
|
|
|
def main():
    """
    Command-line entry point.

    With ``--all_runs`` every directory matching
    ``experiments/runs/*camelyon_v2*grokking*`` is analysed and the combined
    results are written to ``paper_figures/m1_all_probes.json``. Otherwise
    the single ``--run_dir`` is analysed; if neither option is given a usage
    hint is printed.
    """
    parser = argparse.ArgumentParser(
        description="M1: Layer-wise linear probing for CausalGrok")
    parser.add_argument("--run_dir", default=None,
                        help="Single run directory to analyze")
    parser.add_argument("--all_runs", action="store_true",
                        help="Analyze all camelyon_v2 grokking runs")
    parser.add_argument("--data_root", default="data/wilds")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--max_samples", type=int, default=800)
    parser.add_argument("--latest_only", action="store_true",
                        help="Analyze only latest checkpoint (quick check)")
    args = parser.parse_args()

    # Options shared by every run_probe_analysis call below.
    probe_options = {
        "device": args.device,
        "max_samples": args.max_samples,
        "latest_only": args.latest_only,
    }

    if not args.all_runs:
        if args.run_dir:
            run_probe_analysis(args.run_dir, args.data_root, **probe_options)
        else:
            print("Specify --run_dir <path> or --all_runs")
        return

    run_dirs = sorted(glob.glob(
        "experiments/runs/*camelyon_v2*grokking*"))
    print(f"Found {len(run_dirs)} grokking runs")

    all_results = []
    for run_path in run_dirs:
        outcome = run_probe_analysis(run_path, args.data_root,
                                     **probe_options)
        # Skipped runs return None and are left out of the combined file.
        if outcome:
            all_results.append(outcome)

    if not all_results:
        return

    os.makedirs("paper_figures", exist_ok=True)
    with open("paper_figures/m1_all_probes.json", "w") as f:
        json.dump(all_results, f, indent=2)
    print("\nCombined β paper_figures/m1_all_probes.json")


if __name__ == "__main__":
    main()
|
|