| """ |
| Full Spectrum Probe |
| ==================== |
| Colab Cell 4 of 4 - depends on Cells 1-3 namespace. |
| |
| Downloads trained checkpoint from HF and runs comprehensive evaluation: |
| - Per-class precision/recall/F1 across all 27 shapes |
| - Confusion matrix (predicted vs actual) |
| - Accuracy by scene complexity (2, 3, 4 shapes) |
| - Per-patch breakdown: dims, curvature, topology, gates |
| - Global vs patch agreement |
| - Occupancy-binned accuracy (sparse vs dense) |
| - Per-gate accuracy (rigid/curved/combined/open/closed) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from torch.utils.data import DataLoader |
| from collections import defaultdict |
| import json |
|
|
| |
| |
| |
|
|
| |
# ---- Probe configuration ----
PROBE_SAMPLES = 20000  # number of freshly generated evaluation scenes
PROBE_SEED = 9999      # RNG seed for the test set (distinct from training, per generate_test_data)
PROBE_BATCH = 512      # inference batch size
# Trained checkpoint on the Hugging Face Hub, and where to cache it locally.
CHECKPOINT_URL = "https://huggingface.co/AbstractPhil/grid-geometric-multishape/resolve/main/checkpoints/best_model_epoch200.pt"
CHECKPOINT_LOCAL = "/tmp/probe_checkpoint.pt"
|
|
|
|
| |
def download_checkpoint():
    """Fetch the trained checkpoint from HF, caching it at CHECKPOINT_LOCAL.

    Skips the network round-trip when the cached file already exists.
    Returns the local checkpoint path.
    """
    import urllib.request
    from pathlib import Path

    target = Path(CHECKPOINT_LOCAL)
    if not target.exists():
        print(f"Downloading checkpoint...")
        urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_LOCAL)
    print(f"✓ Checkpoint ready: {CHECKPOINT_LOCAL}")
    return CHECKPOINT_LOCAL
|
|
|
|
| |
def load_model(device):
    """Load the trained classifier from the HF checkpoint.

    Downloads the checkpoint if needed, rebuilds the architecture from the
    stored config (falling back to training defaults), loads weights, and
    returns the model in eval mode on `device`.
    """
    ckpt_path = download_checkpoint()
    # weights_only=False: the checkpoint carries a plain-dict config alongside
    # the state dict, which the weights-only loader would reject.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    cfg = ckpt.get("config", {})
    model = SuperpositionPatchClassifier(
        embed_dim=cfg.get("embed_dim", 128),
        patch_dim=cfg.get("patch_dim", 64),
        n_layers=cfg.get("n_layers", 4),
        n_heads=cfg.get("n_heads", 4),
        dropout=0.0,  # inference only — no dropout
    ).to(device)

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    # BUG FIX: the old code formatted ckpt.get('recall', '?') with :.4f, which
    # raises ValueError ("Unknown format code 'f'") when 'recall' is missing
    # and the str fallback is used. Format only when the value is numeric.
    recall = ckpt.get("recall")
    recall_str = f"{recall:.4f}" if isinstance(recall, (int, float)) else "?"
    print(f"✓ Model loaded: {n_params:,} params, epoch {ckpt.get('epoch', '?')}, train recall {recall_str}")
    return model
|
|
|
|
| |
def generate_test_data(device):
    """Generate fresh test set with distinct seed.

    Builds PROBE_SAMPLES scenes, computes patch features on `device`, then
    moves everything to CPU and wraps it in a DataLoader for batched eval.
    Returns (loader, n_shapes array, membership tensor).
    """
    print(f"\nGenerating {PROBE_SAMPLES} test samples (seed={PROBE_SEED})...")
    import time
    start = time.time()
    data = generate_dataset(PROBE_SAMPLES, seed=PROBE_SEED, num_workers=MAX_WORKERS)
    print(f"Generated in {time.time()-start:.1f}s")

    grids = torch.from_numpy(data["grids"]).float().to(device)
    memberships = torch.from_numpy(data["memberships"]).float().to(device)
    n_shapes = data["n_shapes"]

    # Patch analysis runs on the accelerator; no grads needed for eval.
    with torch.no_grad():
        patch_data = analyze_patches_torch(grids)

    # Back to CPU so DataLoader workers can pin and stream batches.
    grids = grids.cpu()
    memberships = memberships.cpu()
    patch_data = {key: tensor.cpu() for key, tensor in patch_data.items()}

    dataset = ShapeDataset(grids, memberships, patch_data)
    loader = DataLoader(
        dataset,
        batch_size=PROBE_BATCH,
        shuffle=False,  # order must match n_shapes indexing in run_probe
        collate_fn=collate_fn,
        num_workers=MAX_WORKERS // 2,
        pin_memory=True,
    )

    print(f"✓ Test set ready: {len(dataset)} samples")
    return loader, n_shapes, memberships
|
|
|
|
| |
|
|
def run_probe(model, loader, n_shapes_arr, memberships_all, device):
    """Run all probes and collect results.

    Performs one inference pass over `loader`, accumulating global and
    per-patch predictions, then prints eight evaluation sections and returns
    a nested dict of metrics (also used for JSON export by the caller).
    Assumes `loader` preserves sample order so `n_shapes_arr` can be sliced
    sequentially (shuffle=False upstream).
    """

    # --- Accumulators: one list of per-batch CPU tensors per probed quantity ---
    all_pred_global = []
    all_true_global = []
    all_pred_patch = []
    all_true_patch = []
    all_patch_occ = []
    all_patch_dims_pred = []
    all_patch_dims_true = []
    all_patch_curv_pred = []
    all_patch_curv_true = []
    all_patch_topo_pred = []
    all_patch_topo_true = []
    all_patch_gate_pred = []
    all_patch_gate_true = []
    all_global_gate_pred = []
    all_global_gate_true = []
    all_n_shapes = []

    sample_idx = 0  # running offset into n_shapes_arr (order-aligned with loader)
    print("\nRunning inference...")
    with torch.no_grad():
        for batch in loader:
            batch_dev = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch_dev["grid"])
            B = batch["grid"].shape[0]

            # Global multi-label shape predictions (sigmoid @ 0.5 threshold).
            pred_global = (torch.sigmoid(outputs["global_shapes"]) > 0.5).cpu()
            true_global = batch["global_shapes"]
            all_pred_global.append(pred_global)
            all_true_global.append(true_global)

            # Per-patch shape membership predictions + patch occupancy fractions.
            pred_patch = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).cpu()
            true_patch = batch["patch_shape_membership"]
            occ = batch["patch_occupancy"]
            all_pred_patch.append(pred_patch)
            all_true_patch.append(true_patch)
            all_patch_occ.append(occ)

            # Per-patch categorical heads: argmax over class logits.
            all_patch_dims_pred.append(outputs["patch_dim_logits"].argmax(dim=-1).cpu())
            all_patch_dims_true.append(batch["patch_dims"])

            all_patch_curv_pred.append(outputs["patch_curv_logits"].argmax(dim=-1).cpu())
            all_patch_curv_true.append(batch["patch_curvature"])

            all_patch_topo_pred.append(outputs["patch_topo_logits"].argmax(dim=-1).cpu())
            all_patch_topo_true.append(batch["patch_topology"])

            # Gate heads: soft outputs binarized at 0.5 on both sides.
            all_patch_gate_pred.append((outputs["patch_gate_soft"] > 0.5).cpu())
            all_patch_gate_true.append((batch["patch_labels"] > 0.5))

            all_global_gate_pred.append((torch.sigmoid(outputs["global_gates"]) > 0.5).cpu())
            all_global_gate_true.append((batch["global_gates"] > 0.5))

            # Slice the matching per-sample shape counts for this batch.
            bs = B
            all_n_shapes.append(n_shapes_arr[sample_idx:sample_idx + bs])
            sample_idx += bs

    # --- Concatenate batch lists into full-dataset tensors ---
    pred_global = torch.cat(all_pred_global).float()
    true_global = torch.cat(all_true_global).float()
    pred_patch = torch.cat(all_pred_patch).float()
    true_patch = torch.cat(all_true_patch).float()
    occ_all = torch.cat(all_patch_occ)
    occ_mask = occ_all > 0.01  # "occupied" = at least 1% of patch cells filled
    dims_pred = torch.cat(all_patch_dims_pred)
    dims_true = torch.cat(all_patch_dims_true)
    curv_pred = torch.cat(all_patch_curv_pred)
    curv_true = torch.cat(all_patch_curv_true)
    topo_pred = torch.cat(all_patch_topo_pred)
    topo_true = torch.cat(all_patch_topo_true)
    gate_pred = torch.cat(all_patch_gate_pred).float()
    gate_true = torch.cat(all_patch_gate_true).float()
    ggate_pred = torch.cat(all_global_gate_pred).float()
    ggate_true = torch.cat(all_global_gate_true).float()
    n_shapes_all = np.concatenate(all_n_shapes)

    results = {}

    # --- Section 1: per-class precision/recall/F1 on global shape detection ---
    print("\n" + "="*70)
    print("1. PER-CLASS GLOBAL SHAPE DETECTION")
    print("="*70)
    print(f"{'Class':<15} {'Prec':>7} {'Recall':>7} {'F1':>7} {'Support':>8} {'Pred+':>7}")

    class_results = {}
    for i, name in enumerate(CLASS_NAMES):
        # Binary counts from 0/1 float tensors via elementwise products.
        tp = (pred_global[:, i] * true_global[:, i]).sum().item()
        fp = (pred_global[:, i] * (1 - true_global[:, i])).sum().item()
        fn = ((1 - pred_global[:, i]) * true_global[:, i]).sum().item()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        support = int(true_global[:, i].sum().item())
        pred_pos = int(pred_global[:, i].sum().item())
        print(f" {name:<13} {prec:>7.3f} {rec:>7.3f} {f1:>7.3f} {support:>8} {pred_pos:>7}")
        class_results[name] = {"precision": prec, "recall": rec, "f1": f1, "support": support}

    # Macro averages: unweighted mean over classes.
    macro_prec = np.mean([v["precision"] for v in class_results.values()])
    macro_rec = np.mean([v["recall"] for v in class_results.values()])
    macro_f1 = np.mean([v["f1"] for v in class_results.values()])
    print(f" {'MACRO AVG':<13} {macro_prec:>7.3f} {macro_rec:>7.3f} {macro_f1:>7.3f}")
    results["per_class"] = class_results
    results["macro"] = {"precision": macro_prec, "recall": macro_rec, "f1": macro_f1}

    # --- Section 2: cross-class confusion via conditional false-positive rates ---
    print("\n" + "="*70)
    print("2. TOP CONFUSION PAIRS (false positive rate)")
    print("="*70)

    confusion = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.float64)
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            if i == j:
                continue
            # Rate at which class i is (falsely) predicted when class j is
            # present and class i truly absent.
            mask = (true_global[:, j] == 1) & (true_global[:, i] == 0)
            n = mask.sum().item()
            if n > 0:
                confusion[i, j] = (pred_global[mask, i]).sum().item() / n

    # Collect pairs above a 1% FP-rate floor, highest first.
    pairs = []
    for i in range(NUM_CLASSES):
        for j in range(NUM_CLASSES):
            if i != j and confusion[i, j] > 0.01:
                pairs.append((confusion[i, j], CLASS_NAMES[i], CLASS_NAMES[j]))
    pairs.sort(reverse=True)

    print(f" {'Predicted':<15} {'When present':<15} {'FP Rate':>8}")
    for rate, pred_name, true_name in pairs[:20]:
        print(f" {pred_name:<15} {true_name:<15} {rate:>8.3f}")
    results["top_confusions"] = [(p, t, r) for r, p, t in pairs[:20]]

    # --- Section 3: metrics bucketed by number of shapes in the scene ---
    print("\n" + "="*70)
    print("3. ACCURACY BY SCENE COMPLEXITY")
    print("="*70)

    complexity_results = {}
    for ns in sorted(set(n_shapes_all)):
        mask = n_shapes_all == ns
        n = mask.sum()
        if n == 0:
            continue
        p, t = pred_global[mask], true_global[mask]
        # exact_match: whole label vector correct; elem_acc: per-label accuracy.
        exact_match = (p == t).all(dim=-1).float().mean().item()
        tp = (p * t).sum().item()
        total_true = t.sum().item()
        rec = tp / total_true if total_true > 0 else 0.0
        acc = (p == t).float().mean().item()
        print(f" {ns} shapes: n={n:>5}, exact_match={exact_match:.3f}, recall={rec:.3f}, elem_acc={acc:.3f}")
        complexity_results[int(ns)] = {"n": int(n), "exact_match": exact_match, "recall": rec, "elem_acc": acc}
    results["by_complexity"] = complexity_results

    # --- Section 4: categorical patch heads, restricted to occupied patches ---
    print("\n" + "="*70)
    print("4. PER-PATCH PROPERTY ACCURACY (occupied patches only)")
    print("="*70)

    n_occ = occ_mask.sum().item()
    if n_occ > 0:
        # Targets are clamped into valid class range before comparison —
        # presumably to guard against out-of-range sentinel labels (TODO confirm
        # against the dataset generator in an earlier cell).
        dim_acc = ((dims_pred == dims_true.clamp(0, NUM_DIMS-1)) & occ_mask).sum().item() / n_occ
        curv_acc = ((curv_pred == curv_true.clamp(0, NUM_CURVS-1)) & occ_mask).sum().item() / n_occ
        topo_acc = ((topo_pred == topo_true.clamp(0, NUM_TOPOS-1)) & occ_mask).sum().item() / n_occ

        print(f" Dimensionality: {dim_acc:.4f}")
        print(f" Curvature: {curv_acc:.4f}")
        print(f" Topology: {topo_acc:.4f}")

        # Per-class breakdown for the dimensionality head.
        print(f"\n Dimensionality breakdown:")
        for d in range(NUM_DIMS):
            d_mask = (dims_true == d) & occ_mask
            n_d = d_mask.sum().item()
            if n_d > 0:
                d_acc = (dims_pred[d_mask] == d).float().mean().item()
                print(f" dim={d}: acc={d_acc:.3f} (n={int(n_d)})")

        # Per-class breakdown for the curvature head.
        print(f"\n Curvature breakdown:")
        curv_names = ["rigid (fill>0.6)", "curved (fill<0.3)", "combined"]
        for c in range(NUM_CURVS):
            c_mask = (curv_true == c) & occ_mask
            n_c = c_mask.sum().item()
            if n_c > 0:
                c_acc = (curv_pred[c_mask] == c).float().mean().item()
                print(f" {curv_names[c]}: acc={c_acc:.3f} (n={int(n_c)})")

        results["patch_props"] = {"dim_acc": dim_acc, "curv_acc": curv_acc, "topo_acc": topo_acc}
    else:
        print(" No occupied patches found!")
        results["patch_props"] = {}

    # --- Section 5: binary gate heads, patch-level then global-level ---
    print("\n" + "="*70)
    print("5. PER-GATE ACCURACY")
    print("="*70)

    gate_results = {}
    # Patch-level gates: accuracy over occupied patches only, with the
    # ground-truth positive rate shown as a chance baseline.
    print(" Patch-level:")
    for g, gname in enumerate(GATES):
        g_mask = occ_mask
        n_g = g_mask.sum().item()
        if n_g > 0:
            g_acc = ((gate_pred[:, :, g] == gate_true[:, :, g]).float() * g_mask.float()).sum().item() / n_g
            g_true_rate = (gate_true[:, :, g] * g_mask.float()).sum().item() / n_g
            print(f" {gname:<12}: acc={g_acc:.3f}, base_rate={g_true_rate:.3f}")
            gate_results[f"patch_{gname}"] = {"acc": g_acc, "base_rate": g_true_rate}

    # Global gates: plain per-sample accuracy, no occupancy masking.
    print(" Global-level:")
    for g, gname in enumerate(GATES):
        g_acc = (ggate_pred[:, g] == ggate_true[:, g]).float().mean().item()
        g_true_rate = ggate_true[:, g].float().mean().item()
        print(f" {gname:<12}: acc={g_acc:.3f}, base_rate={g_true_rate:.3f}")
        gate_results[f"global_{gname}"] = {"acc": g_acc, "base_rate": g_true_rate}
    results["gates"] = gate_results

    # --- Section 6: consistency between global head and patch head ---
    print("\n" + "="*70)
    print("6. GLOBAL vs PATCH AGREEMENT")
    print("="*70)
    print(" (Global says shape present, but no patch claims it)")

    agreement = {}
    for i, name in enumerate(CLASS_NAMES):
        global_pos = pred_global[:, i] == 1
        if global_pos.sum() == 0:
            continue
        # "patch_any": at least one occupied patch predicts this class.
        patch_any = (pred_patch[:, :, i] * occ_mask.float()).sum(dim=1) > 0
        agree = (patch_any[global_pos]).float().mean().item()
        disagree_n = int((~patch_any[global_pos]).sum().item())
        if agree < 0.99:  # only report classes with notable disagreement
            print(f" {name:<15}: global-patch agree={agree:.3f}, orphan_globals={disagree_n}")
            agreement[name] = {"agreement": agree, "orphan_globals": disagree_n}
    if not agreement:
        print(" All classes: full agreement (>99%)")
    results["global_patch_agreement"] = agreement

    # --- Section 7: patch accuracy binned by occupancy fraction ---
    print("\n" + "="*70)
    print("7. ACCURACY BY PATCH OCCUPANCY")
    print("="*70)

    # NOTE(review): occ_flat is computed but never used below — dead local.
    occ_flat = occ_all[occ_mask].numpy()
    bins = [(0.01, 0.1, "sparse (1-10%)"), (0.1, 0.3, "low (10-30%)"),
            (0.3, 0.6, "medium (30-60%)"), (0.6, 1.01, "dense (60-100%)")]

    occ_results = {}
    for lo, hi, label in bins:
        bin_mask = occ_mask & (occ_all >= lo) & (occ_all < hi)
        n_bin = bin_mask.sum().item()
        if n_bin > 0:
            # Per-patch shape accuracy = fraction of class labels matched,
            # averaged only over patches in this occupancy bin.
            shape_match = (pred_patch == true_patch).float().mean(dim=-1)
            bin_acc = (shape_match * bin_mask.float()).sum().item() / n_bin

            # Dimensionality accuracy within the same bin.
            dim_bin_acc = ((dims_pred == dims_true.clamp(0, NUM_DIMS-1)) & bin_mask).sum().item() / n_bin

            print(f" {label:<20}: n={int(n_bin):>7}, shape_acc={bin_acc:.3f}, dim_acc={dim_bin_acc:.3f}")
            occ_results[label] = {"n": int(n_bin), "shape_acc": bin_acc, "dim_acc": dim_bin_acc}
    results["by_occupancy"] = occ_results

    # --- Section 8: per-class P/R/F1 at the patch level (occupied only) ---
    print("\n" + "="*70)
    print("8. PER-CLASS PATCH-LEVEL DETECTION (occupied patches)")
    print("="*70)
    print(f" {'Class':<15} {'Prec':>7} {'Recall':>7} {'F1':>7} {'Support':>8}")

    patch_class_results = {}
    for i, name in enumerate(CLASS_NAMES):
        tp = (pred_patch[:, :, i] * true_patch[:, :, i] * occ_mask.float()).sum().item()
        fp = (pred_patch[:, :, i] * (1 - true_patch[:, :, i]) * occ_mask.float()).sum().item()
        fn = ((1 - pred_patch[:, :, i]) * true_patch[:, :, i] * occ_mask.float()).sum().item()
        prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        rec = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0
        support = int((true_patch[:, :, i] * occ_mask.float()).sum().item())
        print(f" {name:<13} {prec:>7.3f} {rec:>7.3f} {f1:>7.3f} {support:>8}")
        patch_class_results[name] = {"precision": prec, "recall": rec, "f1": f1, "support": support}
    results["patch_per_class"] = patch_class_results

    # --- Summary: micro recall/precision, exact-match, macro F1 ---
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)
    # clamp(min=1) avoids division by zero when there are no positives.
    overall_recall = (pred_global * true_global).sum().item() / true_global.sum().clamp(min=1).item()
    overall_prec = (pred_global * true_global).sum().item() / pred_global.sum().clamp(min=1).item()
    exact_match = (pred_global == true_global).all(dim=-1).float().mean().item()
    print(f" Global shape recall: {overall_recall:.4f}")
    print(f" Global shape precision: {overall_prec:.4f}")
    print(f" Global exact match: {exact_match:.4f}")
    print(f" Macro F1: {macro_f1:.4f}")
    if n_occ > 0:
        print(f" Patch dim accuracy: {dim_acc:.4f}")
        print(f" Patch curv accuracy: {curv_acc:.4f}")
        print(f" Patch topo accuracy: {topo_acc:.4f}")
    results["summary"] = {
        "global_recall": overall_recall, "global_precision": overall_prec,
        "exact_match": exact_match, "macro_f1": macro_f1,
    }

    return results
|
|
|
|
| |
def run():
    """Entry point: load model, build the test set, run the probe, persist results.

    Saves the metric dict as JSON to /tmp and best-effort uploads it to the
    HF repo (upload failure is reported but not fatal). Returns the results.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    model = load_model(device)
    loader, n_shapes, memberships = generate_test_data(device)
    results = run_probe(model, loader, n_shapes, memberships, device)

    results_path = "/tmp/probe_results.json"

    def make_serializable(obj):
        # Recursively coerce numpy scalars/arrays (and containers holding
        # them) into builtin types that json.dump accepts.
        if isinstance(obj, dict):
            return {k: make_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [make_serializable(x) for x in obj]
        elif isinstance(obj, np.ndarray):
            # FIX: arrays previously fell through unchanged and crashed
            # json.dump with "Object of type ndarray is not JSON serializable".
            return make_serializable(obj.tolist())
        elif isinstance(obj, np.bool_):
            # FIX: np.bool_ is not a np.integer/np.floating subclass and was
            # also unhandled before.
            return bool(obj)
        elif isinstance(obj, (np.integer,)):
            return int(obj)
        elif isinstance(obj, (np.floating, np.float64)):
            return float(obj)
        return obj

    with open(results_path, "w") as f:
        json.dump(make_serializable(results), f, indent=2)
    print(f"\n✓ Results saved to {results_path}")

    # Best-effort upload; never let a network failure kill the probe run.
    try:
        upload_file(
            path_or_fileobj=results_path,
            path_in_repo="probe/probe_results.json",
            repo_id=REPO_ID,
            token=HF_TOKEN,
            commit_message=f"Full spectrum probe: {PROBE_SAMPLES} samples"
        )
        print(f"✓ Results uploaded to HF: {REPO_ID}/probe/probe_results.json")
    except Exception as e:
        print(f"✗ Upload failed: {e}")

    return results
|
|
|
|
# Colab cell style: execute immediately on cell run (no __main__ guard),
# keeping `results` in the notebook namespace for interactive inspection.
results = run()
print("\n✓ Probe complete")