# grid-geometric-multishape / legacy/cell4_probe_v1.py
# (HF Hub scrape residue preserved as comments: uploader "AbstractPhil",
#  commit "Rename cell4_probe_v1.py to legacy/cell4_probe_v1.py", 15e8f85 verified)
"""
Full Spectrum Probe
====================
Colab Cell 4 of 4 - depends on Cells 1-3 namespace.
Downloads trained checkpoint from HF and runs comprehensive evaluation:
- Per-class precision/recall/F1 across all 27 shapes
- Confusion matrix (predicted vs actual)
- Accuracy by scene complexity (2, 3, 4 shapes)
- Per-patch breakdown: dims, curvature, topology, gates
- Global vs patch agreement
- Occupancy-binned accuracy (sparse vs dense)
- Per-gate accuracy (rigid/curved/combined/open/closed)
"""
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from collections import defaultdict
import json
# Namespace dependencies — these names are defined by earlier Colab cells and
# are assumed to be present in the interpreter; nothing is imported here:
# Cell 1: constants, generate_dataset, analyze_patches_torch, ShapeDataset, collate_fn
# Cell 2: SuperpositionPatchClassifier, SuperpositionLoss, NUM_DIMS, NUM_CURVS, NUM_TOPOS
# Cell 3: compute_metrics, TrainerConfig, SuperpositionTrainer, upload_checkpoint
# === Config ===================================================================
PROBE_SAMPLES = 20000  # fresh test set (different seed from training)
PROBE_SEED = 9999  # distinct from training seed=42
PROBE_BATCH = 512  # eval-only batches (no grads), so this can be large
CHECKPOINT_URL = "https://huggingface.co/AbstractPhil/grid-geometric-multishape/resolve/main/checkpoints/best_model_epoch200.pt"
CHECKPOINT_LOCAL = "/tmp/probe_checkpoint.pt"  # local cache path for the download
# === Download Checkpoint ======================================================
def download_checkpoint(url=None, local_path=None):
    """Download the trained checkpoint from HF, caching it locally.

    Args:
        url: source URL; defaults to module-level CHECKPOINT_URL.
        local_path: destination file path; defaults to CHECKPOINT_LOCAL.
            (Both are optional for backward compatibility with the original
            zero-argument call sites.)

    Returns:
        str: path to the local checkpoint file.
    """
    import urllib.request
    from pathlib import Path
    if url is None:
        url = CHECKPOINT_URL
    if local_path is None:
        local_path = CHECKPOINT_LOCAL
    dest = Path(local_path)
    if not dest.exists():
        # Was an f-string with no placeholders; plain string is the idiomatic form.
        print("Downloading checkpoint...")
        # Download to a temp name, then rename: an interrupted download can
        # no longer be mistaken for a valid cached checkpoint on the next run.
        part = dest.with_name(dest.name + ".part")
        urllib.request.urlretrieve(url, str(part))
        part.replace(dest)
    print(f"✓ Checkpoint ready: {local_path}")
    return local_path
# === Load Model ===============================================================
def load_model(device):
    """Load the trained model from checkpoint and put it in eval mode.

    Args:
        device: torch device the model (and checkpoint tensors) are mapped to.

    Returns:
        The restored SuperpositionPatchClassifier, in eval mode on `device`.
    """
    ckpt_path = download_checkpoint()
    # weights_only=False because the checkpoint carries a pickled config dict.
    # NOTE(review): only acceptable because the checkpoint comes from our own repo;
    # never load untrusted checkpoints this way.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    cfg = ckpt.get("config", {})
    model = SuperpositionPatchClassifier(
        embed_dim=cfg.get("embed_dim", 128),
        patch_dim=cfg.get("patch_dim", 64),
        n_layers=cfg.get("n_layers", 4),
        n_heads=cfg.get("n_heads", 4),
        dropout=0.0,  # no dropout at eval
    ).to(device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    # Bug fix: the original applied ':.4f' directly to ckpt.get('recall', '?'),
    # which raises ValueError ("Unknown format code 'f'") whenever 'recall'
    # is missing from the checkpoint. Format only when it is numeric.
    recall = ckpt.get("recall")
    recall_str = f"{recall:.4f}" if isinstance(recall, (int, float)) else "?"
    print(f"✓ Model loaded: {n_params:,} params, epoch {ckpt.get('epoch', '?')}, train recall {recall_str}")
    return model
# === Generate Test Data =======================================================
def generate_test_data(device):
    """Build the held-out probe set and return (loader, n_shapes, memberships).

    The seed is distinct from the training seed, so no probe sample was seen
    during training. Patch statistics are computed on `device`, after which
    everything is parked on CPU so the DataLoader workers can pin and ship
    batches themselves.
    """
    print(f"\nGenerating {PROBE_SAMPLES} test samples (seed={PROBE_SEED})...")
    import time
    t0 = time.time()
    raw = generate_dataset(PROBE_SAMPLES, seed=PROBE_SEED, num_workers=MAX_WORKERS)
    print(f"Generated in {time.time()-t0:.1f}s")
    n_shapes = raw["n_shapes"]
    # Stage tensors on the accelerator only for the patch-analysis pass.
    grids = torch.from_numpy(raw["grids"]).float().to(device)
    memberships = torch.from_numpy(raw["memberships"]).float().to(device)
    with torch.no_grad():
        patch_data = analyze_patches_torch(grids)
    # Everything back to CPU for the DataLoader (pin_memory handles transfer).
    grids = grids.cpu()
    memberships = memberships.cpu()
    patch_data = {key: tensor.cpu() for key, tensor in patch_data.items()}
    dataset = ShapeDataset(grids, memberships, patch_data)
    loader = DataLoader(
        dataset,
        batch_size=PROBE_BATCH,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=MAX_WORKERS // 2,
        pin_memory=True,
    )
    print(f"✓ Test set ready: {len(dataset)} samples")
    return loader, n_shapes, memberships
# === Probes ===================================================================
def run_probe(model, loader, n_shapes_arr, memberships_all, device):
    """Run the full evaluation suite and collect results.

    Performs one inference pass over `loader`, accumulates all predictions on
    CPU, then prints eight numbered probe sections plus a summary.

    Args:
        model: trained classifier, already in eval mode on `device`.
        loader: DataLoader over the probe set. Must use shuffle=False —
            `n_shapes_arr` is sliced positionally to align with batches.
        n_shapes_arr: per-sample shape counts, same order as the loader.
        memberships_all: full membership tensor. NOTE(review): currently
            unused in this function; kept for interface compatibility.
        device: torch device used for the forward passes.

    Returns:
        dict: per-section results (the same numbers that are printed).
    """
    # Accumulators — one list per quantity, concatenated after the loop.
    all_pred_global = []
    all_true_global = []
    all_pred_patch = []
    all_true_patch = []
    all_patch_occ = []
    all_patch_dims_pred = []
    all_patch_dims_true = []
    all_patch_curv_pred = []
    all_patch_curv_true = []
    all_patch_topo_pred = []
    all_patch_topo_true = []
    all_patch_gate_pred = []
    all_patch_gate_true = []
    all_global_gate_pred = []
    all_global_gate_true = []
    all_n_shapes = []
    sample_idx = 0
    print("\nRunning inference...")
    with torch.no_grad():
        for batch in loader:
            batch_dev = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch_dev["grid"])
            B = batch["grid"].shape[0]
            # Global shapes: multi-label, sigmoid thresholded at 0.5
            pred_global = (torch.sigmoid(outputs["global_shapes"]) > 0.5).cpu()
            true_global = batch["global_shapes"]
            all_pred_global.append(pred_global)
            all_true_global.append(true_global)
            # Patch shapes
            pred_patch = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).cpu()
            true_patch = batch["patch_shape_membership"]
            occ = batch["patch_occupancy"]
            all_pred_patch.append(pred_patch)
            all_true_patch.append(true_patch)
            all_patch_occ.append(occ)
            # Patch dims (single-label argmax)
            all_patch_dims_pred.append(outputs["patch_dim_logits"].argmax(dim=-1).cpu())
            all_patch_dims_true.append(batch["patch_dims"])
            # Patch curvature
            all_patch_curv_pred.append(outputs["patch_curv_logits"].argmax(dim=-1).cpu())
            all_patch_curv_true.append(batch["patch_curvature"])
            # Patch topology
            all_patch_topo_pred.append(outputs["patch_topo_logits"].argmax(dim=-1).cpu())
            all_patch_topo_true.append(batch["patch_topology"])
            # Patch gates — no sigmoid here: "patch_gate_soft" is presumably
            # already in [0, 1] (TODO confirm against the Cell 2 model head)
            all_patch_gate_pred.append((outputs["patch_gate_soft"] > 0.5).cpu())
            all_patch_gate_true.append((batch["patch_labels"] > 0.5))
            # Global gates
            all_global_gate_pred.append((torch.sigmoid(outputs["global_gates"]) > 0.5).cpu())
            all_global_gate_true.append((batch["global_gates"] > 0.5))
            # N shapes for this batch — positional slice; relies on shuffle=False
            bs = B
            all_n_shapes.append(n_shapes_arr[sample_idx:sample_idx + bs])
            sample_idx += bs
    # Concatenate batch lists into full-dataset tensors
    pred_global = torch.cat(all_pred_global).float()
    true_global = torch.cat(all_true_global).float()
    pred_patch = torch.cat(all_pred_patch).float()
    true_patch = torch.cat(all_true_patch).float()
    occ_all = torch.cat(all_patch_occ)
    occ_mask = occ_all > 0.01  # "occupied" = more than 1% of the patch filled
    dims_pred = torch.cat(all_patch_dims_pred)
    dims_true = torch.cat(all_patch_dims_true)
    curv_pred = torch.cat(all_patch_curv_pred)
    curv_true = torch.cat(all_patch_curv_true)
    topo_pred = torch.cat(all_patch_topo_pred)
    topo_true = torch.cat(all_patch_topo_true)
    gate_pred = torch.cat(all_patch_gate_pred).float()
    gate_true = torch.cat(all_patch_gate_true).float()
    ggate_pred = torch.cat(all_global_gate_pred).float()
    ggate_true = torch.cat(all_global_gate_true).float()
    n_shapes_all = np.concatenate(all_n_shapes)
    results = {}
    # === 1. Per-class precision/recall/F1 (global) ============================
    print("\n" + "="*70)
    print("1. PER-CLASS GLOBAL SHAPE DETECTION")
    print("="*70)
    print(f"{'Class':<15} {'Prec':>7} {'Recall':>7} {'F1':>7} {'Support':>8} {'Pred+':>7}")
    class_results = {}
    for i, name in enumerate(CLASS_NAMES):
        # Binary confusion counts for class i over the whole probe set
        tp = (pred_global[:, i] * true_global[:, i]).sum().item()
        fp = (pred_global[:, i] * (1 - true_global[:, i])).sum().item()
        fn = ((1 - pred_global[:, i]) * true_global[:, i]).sum().item()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        support = int(true_global[:, i].sum().item())
        pred_pos = int(pred_global[:, i].sum().item())
        print(f" {name:<13} {prec:>7.3f} {rec:>7.3f} {f1:>7.3f} {support:>8} {pred_pos:>7}")
        class_results[name] = {"precision": prec, "recall": rec, "f1": f1, "support": support}
    # Unweighted (macro) averages across classes
    macro_prec = np.mean([v["precision"] for v in class_results.values()])
    macro_rec = np.mean([v["recall"] for v in class_results.values()])
    macro_f1 = np.mean([v["f1"] for v in class_results.values()])
    print(f" {'MACRO AVG':<13} {macro_prec:>7.3f} {macro_rec:>7.3f} {macro_f1:>7.3f}")
    results["per_class"] = class_results
    results["macro"] = {"precision": macro_prec, "recall": macro_rec, "f1": macro_f1}
    # === 2. Confusion pairs (top false positives) =============================
    print("\n" + "="*70)
    print("2. TOP CONFUSION PAIRS (false positive rate)")
    print("="*70)
    confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.float64)
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            if i == j:
                continue
            # When class j is present but not i, how often do we predict i?
            mask = (true_global[:, j] == 1) & (true_global[:, i] == 0)
            n = mask.sum().item()
            if n > 0:
                confusion[i, j] = (pred_global[mask, i]).sum().item() / n
    # Top 20 confusion pairs (FP rates below 1% are ignored as noise)
    pairs = []
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            if i != j and confusion[i, j] > 0.01:
                pairs.append((confusion[i, j], CLASS_NAMES[i], CLASS_NAMES[j]))
    pairs.sort(reverse=True)  # highest FP rate first
    print(f" {'Predicted':<15} {'When present':<15} {'FP Rate':>8}")
    for rate, pred_name, true_name in pairs[:20]:
        print(f" {pred_name:<15} {true_name:<15} {rate:>8.3f}")
    # Stored as (predicted, when_present, rate) triples
    results["top_confusions"] = [(p, t, r) for r, p, t in pairs[:20]]
    # === 3. Accuracy by scene complexity ======================================
    print("\n" + "="*70)
    print("3. ACCURACY BY SCENE COMPLEXITY")
    print("="*70)
    complexity_results = {}
    for ns in sorted(set(n_shapes_all)):
        mask = n_shapes_all == ns
        n = mask.sum()
        if n == 0:
            continue
        p, t = pred_global[mask], true_global[mask]
        # exact_match: every class bit correct for the sample
        exact_match = (p == t).all(dim=-1).float().mean().item()
        tp = (p * t).sum().item()
        total_true = t.sum().item()
        rec = tp / total_true if total_true > 0 else 0.0
        acc = (p == t).float().mean().item()  # element-wise bit accuracy
        print(f" {ns} shapes: n={n:>5}, exact_match={exact_match:.3f}, recall={rec:.3f}, elem_acc={acc:.3f}")
        complexity_results[int(ns)] = {"n": int(n), "exact_match": exact_match, "recall": rec, "elem_acc": acc}
    results["by_complexity"] = complexity_results
    # === 4. Per-patch property accuracy =======================================
    print("\n" + "="*70)
    print("4. PER-PATCH PROPERTY ACCURACY (occupied patches only)")
    print("="*70)
    n_occ = occ_mask.sum().item()
    if n_occ > 0:
        # clamp() guards label values outside the class range — TODO confirm
        # whether out-of-range labels actually occur in the generated data
        dim_acc = ((dims_pred == dims_true.clamp(0, NUM_DIMS-1)) & occ_mask).sum().item() / n_occ
        curv_acc = ((curv_pred == curv_true.clamp(0, NUM_CURVS-1)) & occ_mask).sum().item() / n_occ
        topo_acc = ((topo_pred == topo_true.clamp(0, NUM_TOPOS-1)) & occ_mask).sum().item() / n_occ
        print(f" Dimensionality: {dim_acc:.4f}")
        print(f" Curvature: {curv_acc:.4f}")
        print(f" Topology: {topo_acc:.4f}")
        # Per-dim class breakdown
        print(f"\n Dimensionality breakdown:")
        for d in range(NUM_DIMS):
            d_mask = (dims_true == d) & occ_mask
            n_d = d_mask.sum().item()
            if n_d > 0:
                d_acc = (dims_pred[d_mask] == d).float().mean().item()
                print(f" dim={d}: acc={d_acc:.3f} (n={int(n_d)})")
        # Per-curvature breakdown
        print(f"\n Curvature breakdown:")
        curv_names = ["rigid (fill>0.6)", "curved (fill<0.3)", "combined"]
        for c in range(NUM_CURVS):
            c_mask = (curv_true == c) & occ_mask
            n_c = c_mask.sum().item()
            if n_c > 0:
                c_acc = (curv_pred[c_mask] == c).float().mean().item()
                print(f" {curv_names[c]}: acc={c_acc:.3f} (n={int(n_c)})")
        results["patch_props"] = {"dim_acc": dim_acc, "curv_acc": curv_acc, "topo_acc": topo_acc}
    else:
        print(" No occupied patches found!")
        results["patch_props"] = {}
    # === 5. Per-gate accuracy =================================================
    print("\n" + "="*70)
    print("5. PER-GATE ACCURACY")
    print("="*70)
    gate_results = {}
    # Patch-level gates: accuracy over occupied patches only
    print(" Patch-level:")
    for g, gname in enumerate(GATES):
        g_mask = occ_mask
        n_g = g_mask.sum().item()
        if n_g > 0:
            g_acc = ((gate_pred[:, :, g] == gate_true[:, :, g]).float() * g_mask.float()).sum().item() / n_g
            # base_rate: fraction of occupied patches where the gate is true
            g_true_rate = (gate_true[:, :, g] * g_mask.float()).sum().item() / n_g
            print(f" {gname:<12}: acc={g_acc:.3f}, base_rate={g_true_rate:.3f}")
            gate_results[f"patch_{gname}"] = {"acc": g_acc, "base_rate": g_true_rate}
    # Global-level gates: plain per-sample accuracy
    print(" Global-level:")
    for g, gname in enumerate(GATES):
        g_acc = (ggate_pred[:, g] == ggate_true[:, g]).float().mean().item()
        g_true_rate = ggate_true[:, g].float().mean().item()
        print(f" {gname:<12}: acc={g_acc:.3f}, base_rate={g_true_rate:.3f}")
        gate_results[f"global_{gname}"] = {"acc": g_acc, "base_rate": g_true_rate}
    results["gates"] = gate_results
    # === 6. Global vs patch agreement =========================================
    print("\n" + "="*70)
    print("6. GLOBAL vs PATCH AGREEMENT")
    print("="*70)
    print(" (Global says shape present, but no patch claims it)")
    agreement = {}
    for i, name in enumerate(CLASS_NAMES):
        global_pos = pred_global[:, i] == 1 # global predicts shape present
        if global_pos.sum() == 0:
            continue
        # For each sample where global says present, check if any occupied patch also says present
        patch_any = (pred_patch[:, :, i] * occ_mask.float()).sum(dim=1) > 0
        agree = (patch_any[global_pos]).float().mean().item()
        disagree_n = int((~patch_any[global_pos]).sum().item())
        if agree < 0.99:
            # Only classes with real disagreement are reported/recorded
            print(f" {name:<15}: global-patch agree={agree:.3f}, orphan_globals={disagree_n}")
            agreement[name] = {"agreement": agree, "orphan_globals": disagree_n}
    if not agreement:
        print(" All classes: full agreement (>99%)")
    results["global_patch_agreement"] = agreement
    # === 7. Occupancy-binned accuracy =========================================
    print("\n" + "="*70)
    print("7. ACCURACY BY PATCH OCCUPANCY")
    print("="*70)
    occ_flat = occ_all[occ_mask].numpy()  # NOTE(review): computed but unused
    bins = [(0.01, 0.1, "sparse (1-10%)"), (0.1, 0.3, "low (10-30%)"),
            (0.3, 0.6, "medium (30-60%)"), (0.6, 1.01, "dense (60-100%)")]
    occ_results = {}
    for lo, hi, label in bins:
        bin_mask = occ_mask & (occ_all >= lo) & (occ_all < hi)
        n_bin = bin_mask.sum().item()
        if n_bin > 0:
            # Shape classification accuracy for these patches
            shape_match = (pred_patch == true_patch).float().mean(dim=-1)
            bin_acc = (shape_match * bin_mask.float()).sum().item() / n_bin
            # Dim accuracy
            dim_bin_acc = ((dims_pred == dims_true.clamp(0, NUM_DIMS-1)) & bin_mask).sum().item() / n_bin
            print(f" {label:<20}: n={int(n_bin):>7}, shape_acc={bin_acc:.3f}, dim_acc={dim_bin_acc:.3f}")
            occ_results[label] = {"n": int(n_bin), "shape_acc": bin_acc, "dim_acc": dim_bin_acc}
    results["by_occupancy"] = occ_results
    # === 8. Per-class patch-level recall ======================================
    print("\n" + "="*70)
    print("8. PER-CLASS PATCH-LEVEL DETECTION (occupied patches)")
    print("="*70)
    print(f" {'Class':<15} {'Prec':>7} {'Recall':>7} {'F1':>7} {'Support':>8}")
    patch_class_results = {}
    for i, name in enumerate(CLASS_NAMES):
        # Same P/R/F1 as section 1, but per-patch and masked to occupied patches
        tp = (pred_patch[:, :, i] * true_patch[:, :, i] * occ_mask.float()).sum().item()
        fp = (pred_patch[:, :, i] * (1 - true_patch[:, :, i]) * occ_mask.float()).sum().item()
        fn = ((1 - pred_patch[:, :, i]) * true_patch[:, :, i] * occ_mask.float()).sum().item()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        support = int((true_patch[:, :, i] * occ_mask.float()).sum().item())
        print(f" {name:<13} {prec:>7.3f} {rec:>7.3f} {f1:>7.3f} {support:>8}")
        patch_class_results[name] = {"precision": prec, "recall": rec, "f1": f1, "support": support}
    results["patch_per_class"] = patch_class_results
    # === Summary ==============================================================
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    # clamp(min=1) avoids division by zero on degenerate all-negative sets
    overall_recall = (pred_global * true_global).sum().item() / true_global.sum().clamp(min=1).item()
    overall_prec = (pred_global * true_global).sum().item() / pred_global.sum().clamp(min=1).item()
    exact_match = (pred_global == true_global).all(dim=-1).float().mean().item()
    print(f" Global shape recall: {overall_recall:.4f}")
    print(f" Global shape precision: {overall_prec:.4f}")
    print(f" Global exact match: {exact_match:.4f}")
    print(f" Macro F1: {macro_f1:.4f}")
    if n_occ > 0:
        # dim_acc/curv_acc/topo_acc exist only when section 4 found occupied patches
        print(f" Patch dim accuracy: {dim_acc:.4f}")
        print(f" Patch curv accuracy: {curv_acc:.4f}")
        print(f" Patch topo accuracy: {topo_acc:.4f}")
    results["summary"] = {
        "global_recall": overall_recall, "global_precision": overall_prec,
        "exact_match": exact_match, "macro_f1": macro_f1,
    }
    return results
# === Run ======================================================================
def run():
    """Entry point: load the model, build the probe set, run every probe,
    persist the results to JSON, and best-effort upload them to the HF repo.

    Returns:
        dict: the results dictionary produced by run_probe().
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    model = load_model(device)
    loader, n_shapes, memberships = generate_test_data(device)
    results = run_probe(model, loader, n_shapes, memberships, device)
    results_path = "/tmp/probe_results.json"

    def make_serializable(obj):
        # Recursively coerce numpy scalars so json.dump accepts the tree;
        # tuples become lists (JSON has no tuple type).
        if isinstance(obj, (np.integer,)):
            return int(obj)
        if isinstance(obj, (np.floating, np.float64)):
            return float(obj)
        if isinstance(obj, dict):
            return {key: make_serializable(val) for key, val in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [make_serializable(item) for item in obj]
        return obj

    with open(results_path, "w") as fh:
        json.dump(make_serializable(results), fh, indent=2)
    print(f"\n✓ Results saved to {results_path}")
    # Best-effort upload; a failure must not lose the local results.
    # NOTE(review): upload_file/REPO_ID/HF_TOKEN presumably come from the
    # Cell 3 namespace — confirm before running standalone.
    try:
        upload_file(
            path_or_fileobj=results_path,
            path_in_repo="probe/probe_results.json",
            repo_id=REPO_ID,
            token=HF_TOKEN,
            commit_message=f"Full spectrum probe: {PROBE_SAMPLES} samples"
        )
        print(f"✓ Results uploaded to HF: {REPO_ID}/probe/probe_results.json")
    except Exception as e:
        print(f"✗ Upload failed: {e}")
    return results
# Colab-cell style entry: executes at import/cell-run time (no __main__ guard
# on purpose — this file is cell 4 of a notebook, not an importable module).
results = run()
print("\n✓ Probe complete")