| """ |
| Morse Field Detection (MFD). |
| |
| Project 768-dim patch tokens to K scalar potential fields via a linear projection. |
| Objects are peaks in the potential landscape. Bounding boxes from Hessian eigenvalues |
| via the Morse lemma. No box regression — boxes emerge from curvature. |
| |
| Learned: W_psi (768 x K) projection + class prototypes |
| Fixed: Hessian (finite-difference kernels), peak detection, box extraction |
| """ |
|
|
| import json, os, sys, time, math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
# Make sibling modules in this script's directory importable.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Dataset / cache locations come from the environment (None when unset).
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")
# Fall back to CPU so the script does not crash on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RESOLUTION = 640  # presumably the square input size in pixels; not used in this file's visible code
NUM_CLASSES = 80  # number of object categories (matches the detector's default num_classes)
STRIDE = 16  # pixels per feature-map cell; used to map peak (row, col) to pixel coordinates
|
|
|
|
def cofiber_decompose(f, n_scales):
    """Split ``f`` into band-pass detail levels plus a low-pass residual.

    Each level 2x average-pools the current map and keeps the detail that the
    pooling removed, i.e. the current map minus the bilinear upsampling of its
    pooled version. The final list entry is the coarsest pooled map itself.
    For ``n_scales <= 1`` the result is ``[f]`` (the same tensor object).

    Args:
        f: 4-D tensor (N, C, H, W), as required by bilinear ``F.interpolate``.
        n_scales: number of tensors returned when ``n_scales >= 1``.

    Returns:
        List of tensors ordered from finest detail to coarsest residual.
    """
    if n_scales <= 1:
        return [f]
    coarse = F.avg_pool2d(f, 2)
    upsampled = F.interpolate(coarse, size=f.shape[2:], mode="bilinear", align_corners=False)
    detail = f - upsampled
    return [detail] + cofiber_decompose(coarse, n_scales - 1)
|
|
|
|
class MorseFieldDetector(nn.Module):
    """Morse theory detection head. Boxes from curvature, not regression.

    Patch features are layer-normalised per token and linearly projected to
    ``n_fields`` scalar potential fields. For every field, a finite-difference
    Hessian scores each location for being a local maximum. The strongest
    maxima become detections, and box sizes come from the local quadratic
    (Morse) approximation of the field around each maximum.

    Args:
        feat_dim: channel count C of the features given to ``forward_detect``.
        n_fields: number K of scalar potential fields.
        num_classes: number of classes scored by ``cls_prototypes``.
    """

    def __init__(self, feat_dim=768, n_fields=3, num_classes=80):
        super().__init__()
        # W_psi: per-token features -> K scalar potentials (no bias term).
        self.field_proj = nn.Linear(feat_dim, n_fields, bias=False)
        # Linear classifier: one prototype vector plus one bias per class.
        self.cls_prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.cls_bias = nn.Parameter(torch.zeros(num_classes))

        # Fixed 3x3 central-difference stencils. They are registered as buffers
        # so they follow .to(device) without being trainable. F.conv2d is a
        # cross-correlation, so "x" is the last (W) axis and "y" is the H axis.
        hxx = torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        hyy = torch.tensor([[0, 1, 0], [0, -2, 0], [0, 1, 0]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        # d2f/dxdy ~ (f(x+1,y+1) - f(x+1,y-1) - f(x-1,y+1) + f(x-1,y-1)) / 4
        hxy = torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0) * 0.25
        self.register_buffer("hxx_kernel", hxx)
        self.register_buffer("hyy_kernel", hyy)
        self.register_buffer("hxy_kernel", hxy)

    def compute_hessian(self, field):
        """Return the finite-difference Hessian ``(fxx, fyy, fxy)`` of ``field``.

        ``field`` has shape (B, 1, H, W), and each output has the same shape.
        Zero padding is used, so the outermost ring is computed as if the
        field were 0 outside the map.
        """
        fxx = F.conv2d(field, self.hxx_kernel, padding=1)
        fyy = F.conv2d(field, self.hyy_kernel, padding=1)
        fxy = F.conv2d(field, self.hxy_kernel, padding=1)
        return fxx, fyy, fxy

    @staticmethod
    def _morse_half_extent(h11, h22, h12, psi_val):
        """Return the box half-extents ``(half_w, half_h)`` in cells, or None.

        Let A = -Hessian = [[-h11, -h12], [-h12, -h22]]. When A is positive
        definite (a strict local maximum), the box is the axis-aligned bounding
        box of the ellipse d^T A d = psi_val. Its principal semi-axes are
        sqrt(psi_val / lambda_i), and its extents along x and y are
        sqrt(psi_val * (A^-1)_xx) and sqrt(psi_val * (A^-1)_yy). This makes
        them follow the ellipse's real orientation instead of always putting
        the long axis on x. Returns None if A is not positive definite,
        including NaN input.
        """
        a_xx, a_yy = -h11, -h22
        trace = a_xx + a_yy
        root = math.sqrt(max((h11 - h22) ** 2 + 4 * h12 ** 2, 0.0))
        lam_small = (trace - root) / 2
        # Written as "not > 0" so that NaN is rejected as well.
        if not lam_small > 0:
            return None
        lam_big = (trace + root) / 2
        det_a = lam_big * lam_small
        # (A^-1)_xx = a_yy / det(A), (A^-1)_yy = a_xx / det(A); clamp guards rounding.
        half_w = math.sqrt(psi_val * max(a_yy, 0.0) / det_a)
        half_h = math.sqrt(psi_val * max(a_xx, 0.0) / det_a)
        return half_w, half_h

    def forward_detect(self, spatial, scale=1.0, *, max_peaks=50, peak_thresh=0.3,
                       min_score=0.01, stride=None):
        """Detect objects in one image's patch features.

        Args:
            spatial: feature map of shape (1, C, H, W). Only batch size 1 is supported.
            scale: divisor applied to all returned box coordinates
                (presumably the image resize factor; confirm against caller).
            max_peaks: maximum number of peaks kept per field, strongest first.
            peak_thresh: minimum objectness for a location to count as a peak.
            min_score: detections with objectness * class probability below
                this are dropped.
            stride: pixels per feature-map cell; defaults to module ``STRIDE``.

        Returns:
            ``(boxes, scores, classes)`` as parallel lists. Boxes are
            ``[x1, y1, w, h]`` (top-left corner plus size) divided by ``scale``,
            scores are floats, and classes are 0-based class indices.

        Raises:
            ValueError: if the batch dimension of ``spatial`` is not 1.
        """
        if stride is None:
            stride = STRIDE
        B, C, H, W = spatial.shape
        if B != 1:
            raise ValueError(f"forward_detect expects a batch of 1 image, got {B}")

        # One row per patch token, normalised over the channel axis.
        f = F.layer_norm(spatial.permute(0, 2, 3, 1).reshape(-1, C), [C])
        # (1, K, H, W) potential fields and (num_classes, H, W) class logits.
        fields = self.field_proj(f).reshape(1, H, W, -1).permute(0, 3, 1, 2)
        cls_logits = (f @ self.cls_prototypes.T + self.cls_bias).reshape(1, H, W, -1).permute(0, 3, 1, 2)[0]

        all_boxes, all_scores, all_classes = [], [], []
        # NOTE: detections from different fields are not de-duplicated against each other.
        for k in range(fields.shape[1]):
            field = fields[:, k:k + 1]
            fxx, fyy, fxy = self.compute_hessian(field)
            det_h = fxx * fyy - fxy * fxy
            tr_h = fxx + fyy
            # Soft test for a local maximum (negative-definite Hessian): det > 0 and trace < 0.
            objectness = (torch.sigmoid(det_h * 10) * torch.sigmoid(-tr_h * 10))[0, 0]

            # 3x3 non-maximum suppression on the (H, W) objectness map.
            # Explicit [0, 0] instead of squeeze() keeps (H, W) even when H or W == 1.
            local_max = F.max_pool2d(objectness[None, None], 3, stride=1, padding=1)[0, 0]
            is_peak = (objectness == local_max) & (objectness > peak_thresh)
            rows, cols = is_peak.nonzero(as_tuple=True)
            if rows.numel() == 0:
                continue
            # Keep the strongest peaks rather than the first ones in raster order.
            peak_obj = objectness[rows, cols]
            keep = peak_obj.argsort(descending=True)[:max_peaks]
            rows, cols, peak_obj = rows[keep], cols[keep], peak_obj[keep]

            cls_prob, cls_idx = cls_logits[:, rows, cols].sigmoid().max(0)
            # One host transfer per quantity instead of several .item() calls per peak.
            peaks = zip(
                rows.tolist(), cols.tolist(),
                fxx[0, 0][rows, cols].tolist(),
                fyy[0, 0][rows, cols].tolist(),
                fxy[0, 0][rows, cols].tolist(),
                field[0, 0][rows, cols].tolist(),
                peak_obj.tolist(), cls_prob.tolist(), cls_idx.tolist(),
            )
            for ri, ci, h11, h22, h12, psi_raw, obj, cls_p, cls_i in peaks:
                extent = self._morse_half_extent(h11, h22, h12, max(psi_raw, 0.01))
                if extent is None:
                    continue
                box_w = 2 * extent[0] * stride
                box_h = 2 * extent[1] * stride
                # Peak sits at the centre of its feature-map cell.
                cx = (ci + 0.5) * stride
                cy = (ri + 0.5) * stride

                x1 = (cx - box_w / 2) / scale
                y1 = (cy - box_h / 2) / scale
                w = box_w / scale
                h = box_h / scale
                if w < 1 or h < 1:
                    continue

                score = obj * cls_p
                if score < min_score:
                    continue
                all_boxes.append([x1, y1, w, h])
                all_scores.append(score)
                all_classes.append(cls_i)

        return all_boxes, all_scores, all_classes
|
|
|
|
def _load_analytical_prototypes(head):
    """Load class prototypes and bias from the analytical head checkpoint, if present.

    Args:
        head: detector whose ``cls_prototypes`` / ``cls_bias`` are overwritten.

    Returns:
        True if the checkpoint file existed and was loaded, otherwise False.
    """
    path = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold",
                        "analytical_70k", "analytical_head_70k.pth")
    if not os.path.isfile(path):
        return False
    # weights_only=False unpickles arbitrary objects: only point this at trusted files.
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    head.cls_prototypes.data = ckpt["cls_weight"].to(DEVICE)
    head.cls_bias.data = ckpt["cls_bias"].to(DEVICE)
    print(" Loaded analytical class prototypes")
    return True


def _pca_init_field_projection(head, n_images=100, max_rows=10000, seed=0):
    """Set ``head.field_proj`` to the top principal directions of cached features.

    Layer-normalised tokens from the first ``n_images`` entries of
    ``CACHE_DIR/shard_0000.pt`` are pooled and uniformly subsampled to at most
    ``max_rows`` rows with a fixed seed. The rows are then mean-centred and
    decomposed by SVD, and the first ``n_fields`` right-singular vectors
    become the projection rows.
    NOTE(review): singular vectors have arbitrary sign, so a field may come out
    as the negated potential.
    """
    # weights_only=False unpickles arbitrary objects: only point this at trusted caches.
    shard = torch.load(os.path.join(CACHE_DIR, "shard_0000.pt"), map_location="cpu", weights_only=False)
    tokens = []
    for item in shard[:n_images]:
        sp = item["spatial"].unsqueeze(0).float()
        channels = sp.shape[1]
        tokens.append(F.layer_norm(sp.permute(0, 2, 3, 1).reshape(-1, channels), [channels]))
    del shard
    sample_f = torch.cat(tokens)
    del tokens
    # Sample rows across *all* loaded images. A plain [:max_rows] prefix would
    # only ever cover the first few images.
    if sample_f.shape[0] > max_rows:
        gen = torch.Generator().manual_seed(seed)
        sample_f = sample_f[torch.randperm(sample_f.shape[0], generator=gen)[:max_rows]]
    _, _, vh = torch.linalg.svd(sample_f - sample_f.mean(0), full_matrices=False)
    weight = head.field_proj.weight
    weight.data = vh[:weight.shape[0]].to(weight.device)


def _evaluate_on_coco(head, max_images):
    """Run ``head`` on up to ``max_images`` cached val images and print COCO bbox metrics."""
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    # weights_only=False unpickles arbitrary objects: only point this at trusted caches.
    val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
    coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
    idx_to_cat = dict(enumerate(sorted(coco_gt.getCatIds())))

    n_eval = min(max_images, len(val))
    all_results = []
    evaluated_ids = []
    t0 = time.time()

    head.eval()
    with torch.no_grad():
        for idx in range(n_eval):
            item = val[idx]
            spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
            img_id = int(item["img_id"])
            evaluated_ids.append(img_id)

            boxes, scores, classes = head.forward_detect(spatial, item["scale"])
            all_results.extend(
                {"image_id": img_id, "category_id": idx_to_cat[c], "bbox": b, "score": s}
                for b, s, c in zip(boxes, scores, classes)
            )

            if (idx + 1) % 200 == 0:
                elapsed = time.time() - t0
                print(f" {idx+1}/{n_eval} ({elapsed:.0f}s, {len(all_results)} dets)", flush=True)

    print(f"\n{len(all_results)} total detections ({time.time()-t0:.0f}s)")

    if not all_results:
        print("No detections")
        return
    coco_dt = coco_gt.loadRes(all_results)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    # Score exactly the images that were run. The cache order need not match
    # sorted COCO ids, so "first N sorted ids" would evaluate the wrong images.
    coco_eval.params.imgIds = sorted(set(evaluated_ids))
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    print(f"\nMFD (zero training): mAP@[.5:.95]={coco_eval.stats[0]:.4f} "
          f"mAP@.50={coco_eval.stats[1]:.4f} mAP@.75={coco_eval.stats[2]:.4f}")


def main():
    """Build the MFD head, initialise it without training, and evaluate on COCO val.

    Raises:
        SystemExit: if a required ARENA_* environment variable is unset.
    """
    missing = [name for name, value in (("ARENA_COCO_ROOT", COCO_ROOT),
                                        ("ARENA_VAL_CACHE", VAL_CACHE),
                                        ("ARENA_CACHE_DIR", CACHE_DIR)) if not value]
    if missing:
        raise SystemExit(f"Missing required environment variable(s): {', '.join(missing)}")

    print("=" * 60)
    print("Morse Field Detection (MFD)")
    print("=" * 60, flush=True)

    head = MorseFieldDetector().to(DEVICE)
    n_params = sum(p.numel() for p in head.parameters())
    print(f" {n_params:,} params ({head.field_proj.weight.numel()} projection + "
          f"{head.cls_prototypes.numel() + head.cls_bias.numel()} classification)")

    _load_analytical_prototypes(head)

    print(" Computing PCA for field projection init...", flush=True)
    _pca_init_field_projection(head)
    print(" PCA-initialized field projection")

    _evaluate_on_coco(head, max_images=1000)

    print(f"\n Field projection: {head.field_proj.weight.numel()} params (PCA-initialized)")
    print(f" Classification: {head.cls_prototypes.numel()} params (analytical)")
    print(" Hessian + peak detection + Morse box extraction: 0 params")


if __name__ == "__main__":
    main()
|
|