# cofiber-detection/analytical/scripts/analytical_best_gpu.py
# Last change by phanerozoic: "update repository" (commit dbbceb8)
"""
Build the best analytical head from our findings and save it for full mAP eval.
Classification: 768 raw LayerNorm'd features (69.6% accuracy)
Regression: 768 raw + H^1 vertical + H^1 horizontal boundary features (68.7% quality)
Centerness: 768 raw features
Accumulate normal equations on training data, solve the ridge systems in closed
form, and save a checkpoint; the mAP eval itself is run separately via
eval_coco_map.py.
"""
import json
import os
import sys
import time
import torch
import torch.nn.functional as F
# Absolute directory of this script; put first on sys.path so modules that
# live next to it can be imported.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

# Data locations, each overridable through an environment variable.
CACHE_DIR = os.getenv("ARENA_CACHE_DIR", "feature_cache")
COCO_ROOT = os.getenv("ARENA_COCO_ROOT", "coco")
VAL_CACHE = os.getenv("ARENA_VAL_CACHE", "val_cache/val.pt")

# Input resolution in pixels and number of target classes.
RESOLUTION, NUM_CLASSES = 640, 80
# Device that holds the normal-equation accumulators and runs the solves.
DEVICE = "cuda"
def cofiber_decompose(f, n_scales):
    """Split a feature map into detail ("cofiber") layers plus a coarse residual.

    Each of the first ``n_scales - 1`` steps average-pools the current map by
    2, bilinearly upsamples the pooled map back to the current map's spatial
    size, and keeps the difference (the detail removed by pooling). The pooled
    map then becomes the input of the next step, and the final pooled map is
    appended last.

    Args:
        f: feature tensor; expected (N, C, H, W), since ``shape[2:]`` is used
            as the spatial size for upsampling.
        n_scales: number of returned maps; any value below 2 yields ``[f]``.

    Returns:
        List of tensors ordered finest detail first, coarsest residual last.
    """
    layers = []
    current = f
    remaining = n_scales - 1
    while remaining > 0:
        pooled = F.avg_pool2d(current, 2)
        restored = F.interpolate(
            pooled, size=current.shape[2:], mode="bilinear", align_corners=False
        )
        layers.append(current - restored)
        current = pooled
        remaining -= 1
    return layers + [current]
def make_locations(sizes, strides):
    """Return the pixel-space center of every cell, one tensor per level.

    For a level of size ``(h, w)`` and stride ``s``, cell ``(i, j)`` maps to
    ``((j + 0.5) * s, (i + 0.5) * s)``. Cells are listed in row-major order.

    Args:
        sizes: iterable of ``(h, w)`` grid sizes.
        strides: iterable of strides, paired with ``sizes`` by position.

    Returns:
        List of float32 tensors of shape ``(h * w, 2)`` with columns ``(x, y)``.
    """
    def _cell_centers(h, w, s):
        row_mid = (torch.arange(h, dtype=torch.float32) + 0.5) * s
        col_mid = (torch.arange(w, dtype=torch.float32) + 0.5) * s
        grid_y, grid_x = torch.meshgrid(row_mid, col_mid, indexing="ij")
        return torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=-1)

    return [_cell_centers(h, w, s) for (h, w), s in zip(sizes, strides)]
def assign_targets(loc, boxes, labels, stride, sr):
    """Assign class, ltrb and centerness targets to the locations of one level.

    A location is a candidate for a box when it lies strictly inside the box,
    lies within ``1.5 * stride`` of the box center on both axes, and the
    largest of its four ltrb distances falls in ``sr = (lo, hi)`` (inclusive
    at both ends). Among a location's candidate boxes the smallest-area box
    wins (first index on ties).

    Args:
        loc: (N, 2) tensor of (x, y) locations.
        boxes: (M, 4) tensor of (x1, y1, x2, y2) boxes; may be empty.
        labels: (M,) class indices aligned with ``boxes``.
        stride: stride of this level, in pixels.
        sr: ``(lo, hi)`` range for the largest ltrb distance.

    Returns:
        ct: (N,) long class targets, -1 for negatives.
        rt: (N, 4) float ltrb targets, zeros for negatives.
        ctrt: (N,) float centerness targets, zeros for negatives.
    """
    num_locs = loc.shape[0]
    cls_target = torch.full((num_locs,), -1, dtype=torch.long)
    reg_target = torch.zeros(num_locs, 4)
    ctr_target = torch.zeros(num_locs)
    if boxes.numel() == 0:
        return cls_target, reg_target, ctr_target

    x1, y1, x2, y2 = boxes.unbind(-1)
    area = (x2 - x1) * (y2 - y1)
    px = loc[:, 0].unsqueeze(1)  # (N, 1), broadcasts against (M,)
    py = loc[:, 1].unsqueeze(1)
    # (N, M, 4) distances from each location to each box's left/top/right/bottom edge.
    dist = torch.stack([px - x1, py - y1, x2 - px, y2 - py], dim=-1)

    inside = dist.min(dim=-1).values > 0
    half = 1.5 * stride
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2
    near_center = ((px >= mid_x - half) & (px <= mid_x + half)
                   & (py >= mid_y - half) & (py <= mid_y + half))
    longest = dist.max(dim=-1).values
    lo, hi = sr[0], sr[1]
    right_level = (longest >= lo) & (longest <= hi)
    candidate = inside & near_center & right_level

    # Non-candidates get infinite area so they can never be the argmin.
    masked_area = area.unsqueeze(0).expand_as(candidate).masked_fill(~candidate, float("inf"))
    best = masked_area.argmin(dim=1)
    positive = masked_area.min(dim=1).values < float("inf")

    cls_target[positive] = labels[best[positive]]
    if positive.any():
        rows = positive.nonzero(as_tuple=True)[0]
        reg_target[rows] = dist[rows, best[rows]]
        left, top, right, bottom = reg_target[rows].unbind(-1)
        horiz = torch.minimum(left, right) / torch.maximum(left, right).clamp(min=1e-6)
        vert = torch.minimum(top, bottom) / torch.maximum(top, bottom).clamp(min=1e-6)
        ctr_target[rows] = torch.sqrt(horiz * vert)
    return cls_target, reg_target, ctr_target
def compute_h1(f, B, H, W, C):
    """Sheaf H^1 compact: per-channel vertical and horizontal boundary magnitudes.

    ``f`` is a flat ``(B*H*W, C)`` tensor whose rows are laid out as
    ``(B, H, W)``. For each cell the vertical magnitude is
    ``|f - f[i+1]| + |f - f[i-1]|`` along H and the horizontal magnitude is
    ``|f - f[j+1]| + |f - f[j-1]|`` along W. Neighbours beyond the grid edge
    are zero padding, so an edge cell contributes ``|f|`` for its missing side
    rather than 0.

    Returns:
        ``(v_bound, h_bound)``, each ``(B*H*W, C)`` in the same row order as ``f``.
    """
    grid = f.reshape(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)
    # One zero row above/below and one zero column left/right.
    row_pad = F.pad(grid, (0, 0, 1, 1))
    col_pad = F.pad(grid, (1, 1, 0, 0))
    next_row, prev_row = row_pad[:, :, 2:, :], row_pad[:, :, :-2, :]
    next_col, prev_col = col_pad[:, :, :, 2:], col_pad[:, :, :, :-2]
    vertical = (grid - next_row).abs() + (grid - prev_row).abs()
    horizontal = (grid - next_col).abs() + (grid - prev_col).abs()

    def _to_rows(t):
        return t.permute(0, 2, 3, 1).reshape(-1, C)

    return _to_rows(vertical), _to_rows(horizontal)
def main():
    """Fit closed-form ridge heads on cached features and save them as a checkpoint.

    Reads ``manifest.json`` and ``shard_XXXX.pt`` files from ``CACHE_DIR`` and,
    for up to ``n_images`` cached images:

    * splits each image's ``spatial`` map into 3 cofiber scales,
    * LayerNorms each scale over channels and computes its H^1 boundary features,
    * assigns per-location targets for the matching stride/size range, and
    * accumulates X^T X / X^T Y on positive locations for three heads:
      classification (raw features -> one-hot class), regression
      (raw + H^1 vertical + H^1 horizontal -> log ltrb) and centerness
      (raw features -> centerness).

    Each head is solved as ``(X^T X + lam * n_pos * I) W = X^T Y``. The bias
    column is part of X and is regularized too. Weights are saved to
    ``heads/cofiber_threshold/analytical_h1/analytical_h1_best.pth`` under
    this script's directory.

    Raises:
        ValueError: if a cached feature map does not have the grid size or
            channel count expected for ``RESOLUTION`` / ``feat_dim``.
        RuntimeError: if no positive locations were accumulated, since the
            ridge systems would then be singular.
    """
    print("=" * 60)
    print("Best Analytical Head: 768 cls + H^1 regression")
    print("=" * 60, flush=True)

    # Use a context manager so the manifest handle is closed promptly.
    with open(os.path.join(CACHE_DIR, "manifest.json")) as fh:
        manifest = json.load(fh)
    n_shards = manifest["n_shards"]

    # Three levels, one per cofiber scale (finest first).
    strides = [16, 32, 64]
    H = RESOLUTION // 16
    sizes = [(H, H), (H // 2, H // 2), (H // 4, H // 4)]
    sr = [(-1, 128), (128, 256), (256, float("inf"))]
    locs = make_locations(sizes, strides)

    feat_dim = 768
    reg_dim = feat_dim * 3  # raw + h1v + h1h

    # Normal-equation accumulators; the extra row/column is the bias term.
    # Classification and centerness use the same design matrix [raw, 1], so
    # they share one Gram matrix (cls_XtX) instead of accumulating a copy.
    cls_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE)
    reg_XtX = torch.zeros(reg_dim + 1, reg_dim + 1, device=DEVICE)
    reg_XtY = torch.zeros(reg_dim + 1, 4, device=DEVICE)
    ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE)
    n_pos = 0
    n_images = 20000
    seen = 0
    t0 = time.time()
    for si in range(n_shards):
        if seen >= n_images:
            break
        # NOTE: weights_only=False uses the full pickle loader; only point
        # CACHE_DIR at feature caches produced by a trusted pipeline.
        shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"),
                           map_location="cpu", weights_only=False)
        for item in shard:
            if seen >= n_images:
                break
            sp = item["spatial"].unsqueeze(0).float()
            boxes = item["boxes"]
            labels = item["labels"]
            cofibers = cofiber_decompose(sp, 3)
            for sci, cof in enumerate(cofibers):
                B, C, Hc, Wc = cof.shape
                # Targets live on the fixed grid in `locs`; a cached map of any
                # other size would misalign features and targets.
                if (Hc, Wc) != tuple(sizes[sci]) or C != feat_dim:
                    raise ValueError(
                        f"scale {sci}: got features {C}x{Hc}x{Wc}, expected "
                        f"{feat_dim}x{sizes[sci][0]}x{sizes[sci][1]} "
                        f"(RESOLUTION={RESOLUTION})")
                f = F.layer_norm(cof.permute(0, 2, 3, 1).reshape(-1, C), [C])
                h1v, h1h = compute_h1(f, B, Hc, Wc, C)
                ct, rt, ctrt = assign_targets(locs[sci], boxes, labels, strides[sci], sr[sci])
                pos_mask = ct >= 0
                if not pos_mask.any():
                    continue
                # Classification: raw features only, one-hot targets.
                fp = f[pos_mask].to(DEVICE)
                fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1)
                yc = torch.zeros(fp.shape[0], NUM_CLASSES, device=DEVICE)
                yc[torch.arange(fp.shape[0], device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0
                cls_XtX += fa.T @ fa
                cls_XtY += fa.T @ yc
                # Regression: raw + H^1 features, log-space ltrb targets.
                ltrb = rt[pos_mask]
                valid = (ltrb > 0).all(1)  # keep log() finite
                if valid.any():
                    # Filter on CPU before the transfer so only used rows move.
                    fv = torch.cat([f[pos_mask], h1v[pos_mask], h1h[pos_mask]], 1)[valid].to(DEVICE)
                    fva = torch.cat([fv, torch.ones(fv.shape[0], 1, device=DEVICE)], 1)
                    yt = torch.log(ltrb[valid]).to(DEVICE)
                    reg_XtX += fva.T @ fva
                    reg_XtY += fva.T @ yt
                # Centerness: raw features; its Gram matrix is cls_XtX.
                ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE)
                n_pos += pos_mask.sum().item()
            seen += 1
        del shard
        if (si + 1) % 5 == 0:
            print(f"  shard {si+1}: {seen} imgs, {n_pos} pos, {time.time()-t0:.0f}s", flush=True)
    print(f"\nAccumulated {seen} images, {n_pos} positives", flush=True)
    if n_pos == 0:
        raise RuntimeError(
            "No positive locations were accumulated; the ridge systems are singular.")

    # Solve the ridge systems (regularization scaled by the positive count).
    lam = 0.1
    I_cls = torch.eye(feat_dim + 1, device=DEVICE)
    I_reg = torch.eye(reg_dim + 1, device=DEVICE)
    cls_A = cls_XtX + lam * I_cls * n_pos
    cls_W = torch.linalg.solve(cls_A, cls_XtY)
    reg_W = torch.linalg.solve(reg_XtX + lam * I_reg * n_pos, reg_XtY)
    ctr_W = torch.linalg.solve(cls_A, ctr_XtY)
    print(f"Solved. cls: {feat_dim}->80, reg: {reg_dim}->4, ctr: {feat_dim}->1", flush=True)

    # Save as a state dict: weights transposed to (out, in), bias from the last row.
    n_levels = len(strides)
    state = {
        "cls_weight": cls_W[:feat_dim].T.cpu(),
        "cls_bias": cls_W[feat_dim].cpu(),
        "reg_weight": reg_W[:reg_dim].T.cpu(),
        "reg_bias": reg_W[reg_dim].cpu(),
        "ctr_weight": ctr_W[:feat_dim].T.cpu(),
        "ctr_bias": ctr_W[feat_dim].cpu(),
    }
    # Identity per-scale norms and unit scale parameters.
    for lvl in range(n_levels):
        state[f"scale_norms.{lvl}.weight"] = torch.ones(feat_dim)
        state[f"scale_norms.{lvl}.bias"] = torch.zeros(feat_dim)
    state["scale_params"] = torch.ones(n_levels)
    state["meta"] = {"cls_features": "768_layernorm",
                     "reg_features": "768_layernorm_h1v_h1h",
                     "ctr_features": "768_layernorm",
                     "lambda": lam, "n_images": seen, "n_pos": n_pos}

    out_dir = os.path.join(SCRIPT_DIR, "heads", "cofiber_threshold", "analytical_h1")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "analytical_h1_best.pth")
    torch.save(state, out_path)
    n_params = sum(v.numel() for v in state.values() if isinstance(v, torch.Tensor))
    elapsed = time.time() - t0
    print(f"\nSaved: {out_path}")
    print(f"Total params: {n_params:,}")
    print(f"Construction time: {elapsed:.0f}s")
    print(f"\nClassification: 768 dims, {feat_dim * NUM_CLASSES + NUM_CLASSES:,} params")
    print(f"Regression: {reg_dim} dims (768+768+768), {reg_dim * 4 + 4:,} params")
    print(f"Centerness: 768 dims, {feat_dim + 1:,} params")


if __name__ == "__main__":
    main()