# cofiber-detection/scripts/morse_detector.py
# (repository page residue) phanerozoic's picture / update repository / dbbceb8
"""
Morse Field Detection (MFD).
Project 768-dim patch tokens to K scalar potential fields via a linear projection.
Objects are peaks in the potential landscape. Bounding boxes from Hessian eigenvalues
via the Morse lemma. No box regression — boxes emerge from curvature.
Learned: W_psi (768 x K) projection + class prototypes
Fixed: Hessian (finite-difference kernels), peak detection, box extraction
"""
import json, os, sys, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Directory holding this script; placed first on sys.path so modules that live
# next to it can be imported by name.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Dataset / cache locations are taken from the environment. Each one is None
# when the corresponding variable is not set.
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

RESOLUTION = 640
NUM_CLASSES = 80
# Pixels per feature-grid cell; converts grid coordinates to pixel coordinates.
STRIDE = 16
def cofiber_decompose(f, n_scales):
    """Split a feature map into per-scale residuals plus a coarsest remainder.

    For each of the first ``n_scales - 1`` levels the current map is
    average-pooled by 2, the pooled map is bilinearly resized back to the
    current spatial size, and the difference (current minus resized) is stored.
    The pooled map then becomes the current map for the next level. The final
    current map is appended last.

    Args:
        f: input tensor; its last two dimensions are treated as spatial.
        n_scales: number of entries wanted in the output list.

    Returns:
        A list of tensors ordered from finest to coarsest.
    """
    bands = []
    current = f
    for _level in range(1, n_scales):
        pooled = F.avg_pool2d(current, 2)
        restored = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - restored)
        current = pooled
    bands.append(current)
    return bands
class MorseFieldDetector(nn.Module):
    """Morse theory detection head. Boxes from curvature, not regression.

    Patch features are layer-normalised and linearly projected to ``n_fields``
    scalar fields. For every field the Hessian is estimated with fixed 3x3
    finite-difference stencils; locations whose soft "local maximum" score is
    a 3x3 local maximum above 0.3 become candidate detections, and the box
    size at each candidate is derived from the Hessian eigenvalues.

    Learned parameters: ``field_proj`` (feat_dim -> n_fields, no bias),
    ``cls_prototypes`` (num_classes x feat_dim) and ``cls_bias``.

    Args:
        feat_dim: channel count of the incoming feature map.
        n_fields: number of scalar potential fields to project to.
        num_classes: number of classification prototypes.
        stride: pixels per feature-grid cell. ``None`` (default) uses the
            module-level ``STRIDE`` constant at call time.
        max_peaks: maximum number of candidate peaks kept per field.
    """

    def __init__(self, feat_dim=768, n_fields=3, num_classes=80, stride=None, max_peaks=50):
        super().__init__()
        # The only learned spatial component: project features to scalar fields.
        self.field_proj = nn.Linear(feat_dim, n_fields, bias=False)
        # Class prototypes for classification at detected peaks.
        self.cls_prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.cls_bias = nn.Parameter(torch.zeros(num_classes))
        # Plain attributes (not parameters/buffers), so state_dict keys are unchanged.
        self.stride = stride
        self.max_peaks = max_peaks

        # Fixed central finite-difference stencils for the second derivatives,
        # shaped (1, 1, 3, 3) so they can be used directly as conv2d weights.
        # d²f/dx²
        hxx = torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], dtype=torch.float32)
        # d²f/dy²
        hyy = torch.tensor([[0, 1, 0], [0, -2, 0], [0, 1, 0]], dtype=torch.float32)
        # d²f/dxdy
        hxy = torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=torch.float32) * 0.25
        self.register_buffer("hxx_kernel", hxx.unsqueeze(0).unsqueeze(0))
        self.register_buffer("hyy_kernel", hyy.unsqueeze(0).unsqueeze(0))
        self.register_buffer("hxy_kernel", hxy.unsqueeze(0).unsqueeze(0))

    def compute_hessian(self, field):
        """Compute Hessian components of a scalar field.

        Args:
            field: tensor of shape (B, 1, H, W).

        Returns:
            Tuple ``(fxx, fyy, fxy)``, each of shape (B, 1, H, W).

        Note:
            ``padding=1`` zero-pads the field, so values on the border are
            computed as if the field were 0 outside the grid.
        """
        fxx = F.conv2d(field, self.hxx_kernel, padding=1)
        fyy = F.conv2d(field, self.hyy_kernel, padding=1)
        fxy = F.conv2d(field, self.hxy_kernel, padding=1)
        return fxx, fyy, fxy

    @staticmethod
    def _curvature_extent(h11, h22, h12, psi_val):
        """Turn a 2x2 Hessian and peak height into two box extents (grid cells).

        Returns ``(extent_w, extent_h)`` or ``None`` when the negated Hessian
        is not positive definite (i.e. the point is not a strict maximum).
        """
        # Eigenvalues of -H (both positive at a maximum); lam1 >= lam2.
        neg_tr = -(h11 + h22)
        discriminant = (h11 - h22) ** 2 + 4 * h12 ** 2
        sqrt_disc = math.sqrt(max(discriminant, 0))
        lam1 = (neg_tr + sqrt_disc) / 2
        lam2 = (neg_tr - sqrt_disc) / 2
        if lam1 <= 0 or lam2 <= 0:
            return None
        # Morse lemma: w = 2 * sqrt(psi / lam2), h = 2 * sqrt(psi / lam1).
        # NOTE(review): width always gets the smaller eigenvalue, so
        # extent_w >= extent_h regardless of the eigenvector orientation --
        # confirm this pairing is intended.
        return 2 * math.sqrt(psi_val / lam2), 2 * math.sqrt(psi_val / lam1)

    def forward_detect(self, spatial, scale=1.0):
        """Run detection on one image.

        Args:
            spatial: feature map of shape (1, C, H, W).
            scale: divisor applied to pixel coordinates and sizes of the boxes.

        Returns:
            Tuple ``(boxes, scores, classes)`` of equal-length Python lists:
            boxes are ``[x1, y1, w, h]``, scores are objectness times the best
            class probability, classes are prototype indices.
        """
        B, C, H, W = spatial.shape
        assert B == 1, "forward_detect handles exactly one image per call"
        stride = STRIDE if self.stride is None else self.stride

        f = F.layer_norm(spatial.permute(0, 2, 3, 1).reshape(-1, C), [C])  # (H*W, C)
        # Project to scalar potential fields.
        fields = self.field_proj(f).reshape(1, H, W, -1).permute(0, 3, 1, 2)  # (1, K, H, W)
        # Classification scores at every location.
        cls_scores = (
            (f @ self.cls_prototypes.T + self.cls_bias)
            .reshape(1, H, W, -1)
            .permute(0, 3, 1, 2)
        )  # (1, num_classes, H, W)

        all_boxes = []
        all_scores = []
        all_classes = []
        for k in range(fields.shape[1]):
            field = fields[:, k:k + 1]  # (1, 1, H, W)
            fxx, fyy, fxy = self.compute_hessian(field)

            # Objectness: peak = det(H) > 0 AND tr(H) < 0 (local maximum).
            det_H = fxx * fyy - fxy * fxy
            tr_H = fxx + fyy
            objectness = (torch.sigmoid(det_H * 10) * torch.sigmoid(-tr_H * 10))[0, 0]  # (H, W)
            psi = field[0, 0]  # (H, W)

            # Peaks are 3x3 local maxima of objectness. Index the two leading
            # singleton dims explicitly: a bare .squeeze() would also drop H
            # or W when either is 1 and break the comparison below.
            local_max = F.max_pool2d(objectness[None, None], 3, stride=1, padding=1)[0, 0]
            is_peak = (objectness == local_max) & (objectness > 0.3)
            peak_locs = is_peak.nonzero(as_tuple=False)  # (M, 2) -- row, col
            if peak_locs.shape[0] == 0:
                continue

            rows, cols = peak_locs[:, 0], peak_locs[:, 1]
            peak_obj = objectness[rows, cols]
            if peak_locs.shape[0] > self.max_peaks:
                # Keep the strongest peaks rather than the first ones in
                # raster order; a stable sort keeps ties deterministic.
                _, order = torch.sort(peak_obj, descending=True, stable=True)
                keep = order[: self.max_peaks]
                rows, cols, peak_obj = rows[keep], cols[keep], peak_obj[keep]

            # Gather everything needed per peak in one go so the Python loop
            # below works on plain floats instead of syncing once per element.
            best_prob, best_cls = cls_scores[0][:, rows, cols].sigmoid().max(dim=0)
            per_peak = zip(
                rows.tolist(),
                cols.tolist(),
                fxx[0, 0, rows, cols].detach().tolist(),
                fyy[0, 0, rows, cols].detach().tolist(),
                fxy[0, 0, rows, cols].detach().tolist(),
                psi[rows, cols].detach().tolist(),
                peak_obj.detach().tolist(),
                best_prob.detach().tolist(),
                best_cls.tolist(),
            )
            for ri, ci, h11, h22, h12, psi_raw, obj, cls_prob, cls_idx in per_peak:
                # Floor the peak height so the box size stays positive.
                psi_val = max(psi_raw, 0.01)
                extent = self._curvature_extent(h11, h22, h12, psi_val)
                if extent is None:
                    continue
                box_w = extent[0] * stride
                box_h = extent[1] * stride

                # Box center in pixel coords (centre of the grid cell).
                cx = (ci + 0.5) * stride
                cy = (ri + 0.5) * stride
                x1 = (cx - box_w / 2) / scale
                y1 = (cy - box_h / 2) / scale
                w = box_w / scale
                h = box_h / scale
                if w < 1 or h < 1:
                    continue

                score = obj * cls_prob
                if score < 0.01:
                    continue
                all_boxes.append([x1, y1, w, h])
                all_scores.append(score)
                all_classes.append(cls_idx)
        return all_boxes, all_scores, all_classes
def main():
    """Build the detector, initialise it without training, and evaluate on COCO.

    Steps:
      1. Build ``MorseFieldDetector`` on ``DEVICE``.
      2. If the analytical checkpoint exists next to this script, copy its
         ``cls_weight`` / ``cls_bias`` into the class prototypes.
      3. Initialise the field projection with the leading right-singular
         vectors (PCA directions) of layer-normalised cached features.
      4. Run ``forward_detect`` on up to 1000 cached validation items and
         score the detections with pycocotools.

    Raises:
        RuntimeError: if any of the ARENA_* environment variables is unset.
    """
    print("=" * 60)
    print("Morse Field Detection (MFD)")
    print("=" * 60, flush=True)

    # Fail with a readable message instead of a TypeError from os.path.join(None, ...).
    required_env = (
        ("ARENA_COCO_ROOT", COCO_ROOT),
        ("ARENA_VAL_CACHE", VAL_CACHE),
        ("ARENA_CACHE_DIR", CACHE_DIR),
    )
    missing = [name for name, value in required_env if not value]
    if missing:
        raise RuntimeError("Missing required environment variable(s): " + ", ".join(missing))

    head = MorseFieldDetector().to(DEVICE)
    n_params = sum(p.numel() for p in head.parameters())
    print(f" {n_params:,} params ({head.field_proj.weight.numel()} projection + "
          f"{head.cls_prototypes.numel() + head.cls_bias.numel()} classification)")

    # Initialize class prototypes from analytical solution.
    analytical_path = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold",
                                   "analytical_70k", "analytical_head_70k.pth")
    if os.path.isfile(analytical_path):
        # NOTE: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoints and caches that come from a trusted source.
        ckpt = torch.load(analytical_path, map_location=DEVICE, weights_only=False)
        head.cls_prototypes.data = ckpt["cls_weight"].to(DEVICE)
        head.cls_bias.data = ckpt["cls_bias"].to(DEVICE)
        print(" Loaded analytical class prototypes")

    # Initialize field projection from PCA of features: one leading principal
    # direction per scalar field.
    print(" Computing PCA for field projection init...", flush=True)
    feat_dim = head.field_proj.in_features
    n_fields = head.field_proj.out_features
    # The manifest content is not used below; it is still parsed so that a
    # missing or malformed cache manifest fails before the heavier loads.
    with open(os.path.join(CACHE_DIR, "manifest.json")) as manifest_file:
        json.load(manifest_file)
    shard = torch.load(os.path.join(CACHE_DIR, "shard_0000.pt"), map_location="cpu", weights_only=False)
    sample_features = []
    for item in shard[:100]:
        sp = item["spatial"].unsqueeze(0).float()
        sample_features.append(
            F.layer_norm(sp.permute(0, 2, 3, 1).reshape(-1, feat_dim), [feat_dim])
        )
    sample_f = torch.cat(sample_features)[:10000]
    _, _, Vh = torch.linalg.svd(sample_f - sample_f.mean(0), full_matrices=False)
    head.field_proj.weight.data = Vh[:n_fields].to(DEVICE)
    print(f" PCA-initialized field projection")
    del shard, sample_features, sample_f

    # Eval (no training -- pure derived detection).
    val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
    # Map contiguous class indices onto the sorted COCO category ids.
    idx_to_cat = dict(enumerate(sorted(coco_gt.getCatIds())))

    max_images = 1000
    n_eval = min(max_images, len(val))
    all_results = []
    evaluated_img_ids = []
    t0 = time.time()
    head.eval()
    with torch.no_grad():
        for idx in range(n_eval):
            item = val[idx]
            spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
            img_id = int(item["img_id"])
            evaluated_img_ids.append(img_id)
            boxes, scores, classes = head.forward_detect(spatial, item["scale"])
            all_results.extend(
                {
                    "image_id": img_id,
                    "category_id": idx_to_cat[c],
                    "bbox": b,
                    "score": s,
                }
                for b, s, c in zip(boxes, scores, classes)
            )
            if (idx + 1) % 200 == 0:
                elapsed = time.time() - t0
                print(f" {idx+1}/{n_eval} ({elapsed:.0f}s, {len(all_results)} dets)", flush=True)
    print(f"\n{len(all_results)} total detections ({time.time()-t0:.0f}s)")

    if all_results:
        coco_dt = coco_gt.loadRes(all_results)
        coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
        # Score exactly the images that were run through the detector; the
        # first N sorted COCO ids need not match the order of the cache.
        coco_eval.params.imgIds = sorted(set(evaluated_img_ids))
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        print(f"\nMFD (zero training): mAP@[.5:.95]={coco_eval.stats[0]:.4f} "
              f"mAP@.50={coco_eval.stats[1]:.4f} mAP@.75={coco_eval.stats[2]:.4f}")
    else:
        print("No detections")

    print(f"\n Field projection: {head.field_proj.weight.numel()} params (PCA-initialized)")
    print(f" Classification: {head.cls_prototypes.numel()} params (analytical)")
    print(f" Hessian + peak detection + Morse box extraction: 0 params")


if __name__ == "__main__":
    main()