""" Morse Field Detection (MFD). Project 768-dim patch tokens to K scalar potential fields via a linear projection. Objects are peaks in the potential landscape. Bounding boxes from Hessian eigenvalues via the Morse lemma. No box regression — boxes emerge from curvature. Learned: W_psi (768 x K) projection + class prototypes Fixed: Hessian (finite-difference kernels), peak detection, box extraction """ import json, os, sys, time, math import torch import torch.nn as nn import torch.nn.functional as F SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, SCRIPT_DIR) COCO_ROOT = os.environ.get("ARENA_COCO_ROOT") VAL_CACHE = os.environ.get("ARENA_VAL_CACHE") CACHE_DIR = os.environ.get("ARENA_CACHE_DIR") DEVICE = "cuda" RESOLUTION = 640 NUM_CLASSES = 80 STRIDE = 16 def cofiber_decompose(f, n_scales): cofibers = []; residual = f for _ in range(n_scales - 1): omega = F.avg_pool2d(residual, 2) sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False) cofibers.append(residual - sigma_omega); residual = omega cofibers.append(residual); return cofibers class MorseFieldDetector(nn.Module): """Morse theory detection head. Boxes from curvature, not regression.""" def __init__(self, feat_dim=768, n_fields=3, num_classes=80): super().__init__() # The only learned spatial component: project features to scalar fields self.field_proj = nn.Linear(feat_dim, n_fields, bias=False) # Class prototypes for classification at detected peaks self.cls_prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01) self.cls_bias = nn.Parameter(torch.zeros(num_classes)) # Fixed finite-difference Hessian kernels (Sobel-style) # d²f/dx² hxx = torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # d²f/dy² hyy = torch.tensor([[0, 1, 0], [0, -2, 0], [0, 1, 0]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) # d²f/dxdy hxy = torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) * 0.25 self.register_buffer("hxx_kernel", hxx) self.register_buffer("hyy_kernel", hyy) self.register_buffer("hxy_kernel", hxy) def compute_hessian(self, field): """Compute Hessian components of a scalar field. field: (B, 1, H, W).""" fxx = F.conv2d(field, self.hxx_kernel, padding=1) fyy = F.conv2d(field, self.hyy_kernel, padding=1) fxy = F.conv2d(field, self.hxy_kernel, padding=1) return fxx, fyy, fxy def forward_detect(self, spatial, scale=1.0): """Run detection on one image. Returns boxes, scores, classes.""" B, C, H, W = spatial.shape assert B == 1 f = F.layer_norm(spatial.permute(0, 2, 3, 1).reshape(-1, C), [C]) # (H*W, 768) # Project to scalar potential fields fields = self.field_proj(f).reshape(1, H, W, -1).permute(0, 3, 1, 2) # (1, K, H, W) # Classification scores at every location cls_scores = (f @ self.cls_prototypes.T + self.cls_bias).reshape(1, H, W, -1).permute(0, 3, 1, 2) # (1, 80, H, W) all_boxes = [] all_scores = [] all_classes = [] n_fields = fields.shape[1] for k in range(n_fields): field = fields[:, k:k+1] # (1, 1, H, W) # Compute Hessian fxx, fyy, fxy = self.compute_hessian(field) # Determinant and trace of Hessian det_H = fxx * fyy - fxy * fxy # (1, 1, H, W) tr_H = fxx + fyy # Objectness: peak = det(H) > 0 AND tr(H) < 0 (local maximum) objectness = torch.sigmoid(det_H * 10) * torch.sigmoid(-tr_H * 10) objectness = objectness.squeeze(0).squeeze(0) # (H, W) # Field values at each location psi = field.squeeze(0).squeeze(0) # (H, W) # Find peaks: local maxima of objectness # Use max_pool to find locations that are local maxima obj_padded = objectness.unsqueeze(0).unsqueeze(0) local_max = F.max_pool2d(obj_padded, 3, stride=1, padding=1).squeeze() is_peak = (objectness == local_max) & (objectness > 0.3) peak_locs = is_peak.nonzero(as_tuple=False) # (M, 2) — row, col for pi in range(min(len(peak_locs), 50)): r, c = peak_locs[pi] ri, ci = r.item(), c.item() # Hessian eigenvalues at this peak h11 = fxx[0, 0, ri, ci].item() h22 = fyy[0, 0, ri, ci].item() h12 = fxy[0, 0, ri, ci].item() psi_val = max(psi[ri, ci].item(), 0.01) # Eigenvalues of -H (should be positive at a maximum) neg_tr = -(h11 + h22) discriminant = (h11 - h22) ** 2 + 4 * h12 ** 2 sqrt_disc = math.sqrt(max(discriminant, 0)) lam1 = (neg_tr + sqrt_disc) / 2 lam2 = (neg_tr - sqrt_disc) / 2 if lam1 <= 0 or lam2 <= 0: continue # Morse lemma: box dimensions from curvature # w = 2 * sqrt(psi / lam2), h = 2 * sqrt(psi / lam1) box_w = 2 * math.sqrt(psi_val / lam2) * STRIDE box_h = 2 * math.sqrt(psi_val / lam1) * STRIDE # Box center in pixel coords cx = (ci + 0.5) * STRIDE cy = (ri + 0.5) * STRIDE x1 = (cx - box_w / 2) / scale y1 = (cy - box_h / 2) / scale w = box_w / scale h = box_h / scale if w < 1 or h < 1: continue # Classification at this location cls = cls_scores[0, :, ri, ci] cls_score, cls_idx = cls.sigmoid().max(0) score = objectness[ri, ci].item() * cls_score.item() if score < 0.01: continue all_boxes.append([x1, y1, w, h]) all_scores.append(score) all_classes.append(cls_idx.item()) return all_boxes, all_scores, all_classes def main(): print("=" * 60) print("Morse Field Detection (MFD)") print("=" * 60, flush=True) head = MorseFieldDetector().to(DEVICE) n_params = sum(p.numel() for p in head.parameters()) print(f" {n_params:,} params ({head.field_proj.weight.numel()} projection + " f"{head.cls_prototypes.numel() + head.cls_bias.numel()} classification)") # Initialize class prototypes from analytical solution analytical_path = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "analytical_70k", "analytical_head_70k.pth") if os.path.isfile(analytical_path): ckpt = torch.load(analytical_path, map_location=DEVICE, weights_only=False) head.cls_prototypes.data = ckpt["cls_weight"].to(DEVICE) head.cls_bias.data = ckpt["cls_bias"].to(DEVICE) print(" Loaded analytical class prototypes") # Initialize field projection from PCA of features # Use first 3 PCA components — the "smooth scalar fields" the paper shows print(" Computing PCA for field projection init...", flush=True) manifest = json.load(open(os.path.join(CACHE_DIR, "manifest.json"))) shard = torch.load(os.path.join(CACHE_DIR, "shard_0000.pt"), map_location="cpu", weights_only=False) sample_features = [] for item in shard[:100]: sp = item["spatial"].unsqueeze(0).float() f = F.layer_norm(sp.permute(0, 2, 3, 1).reshape(-1, 768), [768]) sample_features.append(f) sample_f = torch.cat(sample_features) _, _, Vh = torch.linalg.svd(sample_f[:10000] - sample_f[:10000].mean(0), full_matrices=False) head.field_proj.weight.data = Vh[:3].to(DEVICE) print(f" PCA-initialized field projection") del shard, sample_features, sample_f # Eval (no training — pure derived detection) val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False) from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json")) cat_ids = sorted(coco_gt.getCatIds()) idx_to_cat = {i: c for i, c in enumerate(cat_ids)} max_images = 1000 all_results = [] t0 = time.time() head.eval() with torch.no_grad(): for idx in range(min(max_images, len(val))): item = val[idx] spatial = item["spatial"].unsqueeze(0).float().to(DEVICE) img_id = int(item["img_id"]) scale = item["scale"] boxes, scores, classes = head.forward_detect(spatial, scale) for b, s, c in zip(boxes, scores, classes): all_results.append({ "image_id": img_id, "category_id": idx_to_cat[c], "bbox": b, "score": s, }) if (idx + 1) % 200 == 0: elapsed = time.time() - t0 print(f" {idx+1}/{max_images} ({elapsed:.0f}s, {len(all_results)} dets)", flush=True) print(f"\n{len(all_results)} total detections ({time.time()-t0:.0f}s)") if all_results: coco_dt = coco_gt.loadRes(all_results) coco_eval = COCOeval(coco_gt, coco_dt, "bbox") coco_eval.params.imgIds = sorted(coco_gt.getImgIds())[:min(max_images, len(val))] coco_eval.evaluate(); coco_eval.accumulate(); coco_eval.summarize() print(f"\nMFD (zero training): mAP@[.5:.95]={coco_eval.stats[0]:.4f} " f"mAP@.50={coco_eval.stats[1]:.4f} mAP@.75={coco_eval.stats[2]:.4f}") else: print("No detections") print(f"\n Field projection: {head.field_proj.weight.numel()} params (PCA-initialized)") print(f" Classification: {head.cls_prototypes.numel()} params (analytical)") print(f" Hessian + peak detection + Morse box extraction: 0 params") if __name__ == "__main__": main()