File size: 10,164 Bytes
dbbceb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""
Morse Field Detection (MFD).

Project 768-dim patch tokens to K scalar potential fields via a linear projection.
Objects are peaks in the potential landscape. Bounding boxes from Hessian eigenvalues
via the Morse lemma. No box regression — boxes emerge from curvature.

Learned: W_psi (768 x K) projection + class prototypes
Fixed: Hessian (finite-difference kernels), peak detection, box extraction
"""

import json, os, sys, time, math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Directory that contains this script. It is pushed to the front of sys.path so
# that modules living next to the script are importable regardless of the CWD.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Data locations are taken from the environment. os.environ.get returns None
# for an unset variable, so each of these may be None at import time.
# COCO dataset root; main() reads annotations/instances_val2017.json under it.
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT")
# Path of the cached validation items that main() passes to torch.load.
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE")
# Directory holding the cached feature shards (main() reads shard_0000.pt).
CACHE_DIR = os.environ.get("ARENA_CACHE_DIR")
# Torch device string used for the detection head and the per-image features.
DEVICE = "cuda"
# NOTE(review): RESOLUTION and NUM_CLASSES are not referenced in the visible
# code; presumably the input image size and the COCO class count — confirm.
RESOLUTION = 640
NUM_CLASSES = 80
# Pixels per feature-grid cell; used to map grid coordinates to pixel coordinates.
STRIDE = 16


def cofiber_decompose(f, n_scales):
    """Decompose a feature map into a coarse-to-fine stack of residual bands.

    At every level the current map is halved with 2x2 average pooling, the
    pooled map is bilinearly upsampled back to the current resolution, and the
    difference (current - upsampled) is stored as that level's band. The pooled
    map becomes the input of the next level.

    Args:
        f: feature map; the bilinear interpolation over ``shape[2:]`` requires
            a 4-D (N, C, H, W) tensor.
        n_scales: total number of bands to return. Values <= 1 yield ``[f]``.

    Returns:
        A list of ``max(n_scales, 1)`` tensors: ``n_scales - 1`` residual bands
        at successively halved resolutions, followed by the final pooled map.
    """
    bands = []
    current = f
    level = 1
    while level < n_scales:
        pooled = F.avg_pool2d(current, 2)
        upsampled = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - upsampled)
        current = pooled
        level += 1
    # Whatever is left after the last pooling step is the coarsest band.
    bands.append(current)
    return bands


class MorseFieldDetector(nn.Module):
    """Morse theory detection head. Boxes from curvature, not regression.

    Patch features are layer-normalised and linearly projected to ``n_fields``
    scalar fields. For every field the 2x2 Hessian is estimated with fixed
    finite-difference kernels; grid cells whose Hessian looks like a local
    maximum (det > 0, trace < 0) become object candidates, and the box size is
    read off the local curvature.

    Learned parameters: ``field_proj`` (feat_dim x n_fields), ``cls_prototypes``
    and ``cls_bias``. Everything else is fixed.

    Args:
        feat_dim: channel dimension of the incoming patch features.
        n_fields: number of scalar potential fields to project to.
        num_classes: number of classification prototypes.
        stride: pixels per feature-grid cell. ``None`` (default) uses the
            module-level ``STRIDE`` constant.
    """

    def __init__(self, feat_dim=768, n_fields=3, num_classes=80, stride=None):
        super().__init__()
        # Resolved lazily from the module constant so existing callers that do
        # not pass ``stride`` keep their behaviour.
        self.stride = STRIDE if stride is None else stride

        # The only learned spatial component: project features to scalar fields
        self.field_proj = nn.Linear(feat_dim, n_fields, bias=False)
        # Class prototypes for classification at detected peaks
        self.cls_prototypes = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
        self.cls_bias = nn.Parameter(torch.zeros(num_classes))

        # Fixed central finite-difference Hessian kernels, shaped (1, 1, 3, 3)
        # so they can be applied to a single-channel field with F.conv2d.
        # d²f/dx²
        hxx = torch.tensor([[0, 0, 0], [1, -2, 1], [0, 0, 0]], dtype=torch.float32)
        # d²f/dy²
        hyy = torch.tensor([[0, 1, 0], [0, -2, 0], [0, 1, 0]], dtype=torch.float32)
        # d²f/dxdy
        hxy = torch.tensor([[1, 0, -1], [0, 0, 0], [-1, 0, 1]], dtype=torch.float32) * 0.25
        self.register_buffer("hxx_kernel", hxx.reshape(1, 1, 3, 3))
        self.register_buffer("hyy_kernel", hyy.reshape(1, 1, 3, 3))
        self.register_buffer("hxy_kernel", hxy.reshape(1, 1, 3, 3))

    def compute_hessian(self, field):
        """Compute Hessian components of a scalar field.

        Args:
            field: tensor of shape (B, 1, H, W).

        Returns:
            ``(fxx, fyy, fxy)``, each of shape (B, 1, H, W). The convolutions
            use zero padding, so values on the grid border are computed as if
            the field were 0 outside the grid.
        """
        fxx = F.conv2d(field, self.hxx_kernel, padding=1)
        fyy = F.conv2d(field, self.hyy_kernel, padding=1)
        fxy = F.conv2d(field, self.hxy_kernel, padding=1)
        return fxx, fyy, fxy

    def forward_detect(self, spatial, scale=1.0, *, max_peaks=50,
                       peak_thresh=0.3, score_thresh=0.01):
        """Run detection on one image.

        Args:
            spatial: feature map of shape (1, C, H, W).
            scale: factor by which box coordinates are divided before being
                returned (a Python number or a one-element tensor).
            max_peaks: per-field cap on the number of candidate peaks; when
                more peaks are found the ones with the highest objectness are
                kept.
            peak_thresh: minimum objectness for a cell to count as a peak.
            score_thresh: detections scoring below this are dropped.

        Returns:
            ``(boxes, scores, classes)`` — three parallel Python lists. Each
            box is ``[x, y, w, h]`` (top-left corner plus size), each score is
            ``objectness * max class probability`` and each class is the index
            of the best-scoring prototype.

        Raises:
            ValueError: if ``spatial`` is not a 4-D tensor with batch size 1.
        """
        if spatial.dim() != 4 or spatial.shape[0] != 1:
            raise ValueError(
                f"forward_detect expects a (1, C, H, W) tensor, got {tuple(spatial.shape)}"
            )
        _, C, H, W = spatial.shape
        # Guarantees plain Python floats in the returned boxes even when the
        # caller hands in a 0-d / one-element tensor.
        scale = float(scale)
        stride = self.stride

        f = F.layer_norm(spatial.permute(0, 2, 3, 1).reshape(-1, C), [C])  # (H*W, C)

        # Project to scalar potential fields
        fields = self.field_proj(f).reshape(1, H, W, -1).permute(0, 3, 1, 2)  # (1, K, H, W)

        # Best class and its probability at every grid cell, computed once.
        cls_prob = (f @ self.cls_prototypes.T + self.cls_bias).sigmoid()  # (H*W, num_classes)
        best_prob, best_cls = cls_prob.max(dim=1)
        best_prob = best_prob.reshape(H, W)
        best_cls = best_cls.reshape(H, W)

        all_boxes = []
        all_scores = []
        all_classes = []

        for k in range(fields.shape[1]):
            field = fields[:, k:k + 1]  # (1, 1, H, W)
            fxx, fyy, fxy = self.compute_hessian(field)

            det_h = fxx * fyy - fxy * fxy
            tr_h = fxx + fyy

            # Objectness: peak = det(H) > 0 AND tr(H) < 0 (local maximum)
            objectness = (torch.sigmoid(det_h * 10) * torch.sigmoid(-tr_h * 10))[0, 0]  # (H, W)
            psi = field[0, 0]  # (H, W)

            # A cell is a peak when it equals the max of its 3x3 neighbourhood.
            # Index [0, 0] rather than squeeze(): squeeze() would also drop H
            # or W when they are 1 and make the comparison broadcast wrongly.
            local_max = F.max_pool2d(objectness[None, None], 3, stride=1, padding=1)[0, 0]
            is_peak = (objectness == local_max) & (objectness > peak_thresh)

            rows, cols = is_peak.nonzero(as_tuple=True)  # raster order
            if rows.numel() == 0:
                continue
            peak_obj = objectness[rows, cols]
            if rows.numel() > max_peaks:
                # Over budget: keep the strongest peaks instead of whichever
                # happen to come first in raster order.
                keep = torch.sort(peak_obj, descending=True, stable=True).indices[:max_peaks]
                rows, cols, peak_obj = rows[keep], cols[keep], peak_obj[keep]

            # One gather + one host transfer per field instead of several
            # .item() calls per peak.
            per_peak = torch.stack(
                (
                    fxx[0, 0][rows, cols],
                    fyy[0, 0][rows, cols],
                    fxy[0, 0][rows, cols],
                    psi[rows, cols],
                    peak_obj,
                    best_prob[rows, cols],
                ),
                dim=1,
            ).detach().tolist()
            peak_rows = rows.tolist()
            peak_cols = cols.tolist()
            peak_cls = best_cls[rows, cols].tolist()

            for (h11, h22, h12, psi_raw, obj, cls_p), ri, ci, cls_id in zip(
                    per_peak, peak_rows, peak_cols, peak_cls):
                psi_val = max(psi_raw, 0.01)

                # A = -H must be positive definite at a maximum. By Sylvester's
                # criterion that is a11 > 0 and det(A) > 0, which is equivalent
                # to both eigenvalues of -H being positive.
                a11, a22, a12 = -h11, -h22, -h12
                det_a = a11 * a22 - a12 * a12
                if a11 <= 0 or det_a <= 0:
                    continue

                # Morse lemma: near the peak the level set is the ellipse
                # x^T A x = psi. Its axis-aligned half-extents are
                # sqrt(psi * (A^-1)_11) along x and sqrt(psi * (A^-1)_22)
                # along y, with A^-1 = [[a22, -a12], [-a12, a11]] / det(A).
                # Using sorted eigenvalues here instead would force w >= h.
                box_w = 2 * math.sqrt(psi_val * a22 / det_a) * stride
                box_h = 2 * math.sqrt(psi_val * a11 / det_a) * stride

                # Box center in pixel coords (cell centre of the peak)
                cx = (ci + 0.5) * stride
                cy = (ri + 0.5) * stride

                x1 = (cx - box_w / 2) / scale
                y1 = (cy - box_h / 2) / scale
                w = box_w / scale
                h = box_h / scale

                if w < 1 or h < 1:
                    continue

                score = obj * cls_p
                if score < score_thresh:
                    continue

                all_boxes.append([x1, y1, w, h])
                all_scores.append(score)
                all_classes.append(cls_id)

        return all_boxes, all_scores, all_classes


def main():
    """Evaluate the untrained Morse Field Detector on cached COCO val features.

    Steps:
      1. Build the head and, if the analytical checkpoint exists next to the
         script, load its class prototypes and biases.
      2. Initialise the field projection with the leading PCA directions of
         layer-normalised features from the first cache shard.
      3. Run ``forward_detect`` on up to 1000 cached validation items and score
         the detections with pycocotools.

    Requires the ARENA_COCO_ROOT, ARENA_VAL_CACHE and ARENA_CACHE_DIR
    environment variables.

    Raises:
        SystemExit: if any of the required environment variables is unset.
    """
    print("=" * 60)
    print("Morse Field Detection (MFD)")
    print("=" * 60, flush=True)

    # Fail early with a readable message instead of a TypeError from
    # os.path.join(None, ...) several steps later.
    required = {
        "ARENA_COCO_ROOT": COCO_ROOT,
        "ARENA_VAL_CACHE": VAL_CACHE,
        "ARENA_CACHE_DIR": CACHE_DIR,
    }
    missing = sorted(name for name, value in required.items() if not value)
    if missing:
        raise SystemExit("Missing required environment variable(s): " + ", ".join(missing))

    head = MorseFieldDetector().to(DEVICE)
    n_params = sum(p.numel() for p in head.parameters())
    print(f"  {n_params:,} params ({head.field_proj.weight.numel()} projection + "
          f"{head.cls_prototypes.numel() + head.cls_bias.numel()} classification)")

    # Initialize class prototypes from analytical solution
    analytical_path = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold",
                                    "analytical_70k", "analytical_head_70k.pth")
    loaded_analytical = os.path.isfile(analytical_path)
    if loaded_analytical:
        # NOTE(security): weights_only=False unpickles arbitrary objects. Only
        # point this at checkpoints and caches you trust (applies to every
        # torch.load call in this function).
        ckpt = torch.load(analytical_path, map_location=DEVICE, weights_only=False)
        head.cls_prototypes.data = ckpt["cls_weight"].to(DEVICE)
        head.cls_bias.data = ckpt["cls_bias"].to(DEVICE)
        print("  Loaded analytical class prototypes")

    # Initialize field projection from PCA of features: the leading principal
    # directions serve as the scalar fields.
    print("  Computing PCA for field projection init...", flush=True)
    feat_dim = head.field_proj.in_features
    n_fields = head.field_proj.out_features
    shard = torch.load(os.path.join(CACHE_DIR, "shard_0000.pt"), map_location="cpu", weights_only=False)
    sample_f = torch.cat([
        F.layer_norm(
            item["spatial"].unsqueeze(0).float().permute(0, 2, 3, 1).reshape(-1, feat_dim),
            [feat_dim],
        )
        for item in shard[:100]
    ])
    # Cap the SVD input at 10k rows to keep it cheap.
    sample = sample_f[:10000]
    _, _, Vh = torch.linalg.svd(sample - sample.mean(0), full_matrices=False)
    head.field_proj.weight.data = Vh[:n_fields].to(DEVICE)
    print("  PCA-initialized field projection")
    del shard, sample_f, sample

    # Eval (no training — pure derived detection)
    val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
    cat_ids = sorted(coco_gt.getCatIds())
    idx_to_cat = dict(enumerate(cat_ids))

    max_images = 1000
    n_eval = min(max_images, len(val))
    all_results = []
    # Image ids that were actually run through the detector; the evaluation is
    # restricted to exactly these so unprocessed images do not count as misses.
    eval_img_ids = []
    t0 = time.time()

    head.eval()
    with torch.no_grad():
        for idx in range(n_eval):
            item = val[idx]
            spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
            img_id = int(item["img_id"])
            scale = item["scale"]
            eval_img_ids.append(img_id)

            boxes, scores, classes = head.forward_detect(spatial, scale)

            all_results.extend(
                {
                    "image_id": img_id,
                    "category_id": idx_to_cat[c],
                    "bbox": b,
                    "score": s,
                }
                for b, s, c in zip(boxes, scores, classes)
            )

            if (idx + 1) % 200 == 0:
                elapsed = time.time() - t0
                print(f"  {idx+1}/{n_eval} ({elapsed:.0f}s, {len(all_results)} dets)", flush=True)

    print(f"\n{len(all_results)} total detections ({time.time()-t0:.0f}s)")

    if all_results:
        coco_dt = coco_gt.loadRes(all_results)
        coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
        coco_eval.params.imgIds = sorted(set(eval_img_ids))
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        print(f"\nMFD (zero training): mAP@[.5:.95]={coco_eval.stats[0]:.4f} "
              f"mAP@.50={coco_eval.stats[1]:.4f} mAP@.75={coco_eval.stats[2]:.4f}")
    else:
        print("No detections")

    cls_source = "analytical" if loaded_analytical else "random init"
    print(f"\n  Field projection: {head.field_proj.weight.numel()} params (PCA-initialized)")
    print(f"  Classification: {head.cls_prototypes.numel()} params ({cls_source})")
    print("  Hessian + peak detection + Morse box extraction: 0 params")


if __name__ == "__main__":
    main()