# Source: cofiber-detection / analytical /scripts /analytical_fractal_gpu.py
# Author: phanerozoic
# Commit dbbceb8: "update repository"
"""
Fractal cofiber decomposition — wavelet packet style.
Instead of recursing only on the low-frequency residual (3 bands),
recurse on BOTH the cofiber and residual at each level.
Depth 1: 2 bands (standard single split)
Depth 2: 4 bands
Depth 3: 8 bands
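Each band is 768 dims. Two ways to use the bands are possible: solve
classification and regression independently per band and merge the results
(letting the solver pick which bands matter), or concatenate all bands and
solve one large system. The experiment in this script does neither. It
resamples every band to the stride-16 grid, averages the bands into a single
768-dim map, and fits one shared ridge (regularized least-squares) solve.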
"""
import json, os, sys, time
import torch, torch.nn.functional as F
# This script's own directory, put first on sys.path so modules that live
# beside it take import precedence.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
# Data locations come from the environment; each is None when unset.
COCO_ROOT, VAL_CACHE, CACHE_DIR = (
    os.environ.get(env_var)
    for env_var in ("ARENA_COCO_ROOT", "ARENA_VAL_CACHE", "ARENA_CACHE_DIR")
)
DEVICE = "cuda"
RESOLUTION = 640  # presumably the input image side in pixels; the stride-16 grid is RESOLUTION // 16
NUM_CLASSES = 80  # width of the one-hot class targets (COCO's 80 detection categories)
def fractal_decompose(f, depth):
    """Fractal (wavelet-packet style) cofiber decomposition.

    At each level the map is split into two parts, and BOTH are split again:
    - a high-frequency cofiber ``f - upsample(avg_pool(f))`` at the input
      resolution;
    - a low-frequency residual ``avg_pool(f)`` at half resolution.

    Args:
        f: 4-D tensor laid out as (N, C, H, W). ``f.shape[2:]`` is used as the
            spatial size for bilinear interpolation, which requires 4-D input.
        depth: number of split levels; must be >= 0.

    Returns:
        List of ``2 ** depth`` tensors, ordered depth-first with the cofiber
        (high) branch before the residual (low) branch at every level.

    Raises:
        ValueError: if ``depth`` is negative. Without this check the recursion
            would keep halving until ``avg_pool2d`` fails on a too-small map.
    """
    if depth < 0:
        raise ValueError(f"depth must be >= 0, got {depth}")
    if depth == 0:
        return [f]
    omega = F.avg_pool2d(f, 2)
    sigma_omega = F.interpolate(omega, size=f.shape[2:], mode="bilinear", align_corners=False)
    cofiber = f - sigma_omega  # high frequency at this scale
    # Recurse on BOTH branches (high bands first, then low bands).
    high_bands = fractal_decompose(cofiber, depth - 1)
    low_bands = fractal_decompose(omega, depth - 1)
    return high_bands + low_bands
def standard_decompose(f, n_scales):
    """Standard cofiber decomposition: recurse only on the low-frequency residual.

    Each of the first ``n_scales - 1`` steps emits the cofiber
    ``residual - upsample(avg_pool(residual))`` at the current resolution, then
    continues on the pooled (half-resolution) residual. The last element is the
    coarsest residual.

    Args:
        f: 4-D tensor laid out as (N, C, H, W), as bilinear interpolation on
            ``shape[2:]`` requires.
        n_scales: number of bands to return; must be >= 1.

    Returns:
        List of exactly ``n_scales`` tensors, from finest cofiber to coarsest residual.

    Raises:
        ValueError: if ``n_scales < 1``. Previously such values silently
            returned a single band.
    """
    if n_scales < 1:
        raise ValueError(f"n_scales must be >= 1, got {n_scales}")
    cofibers = []
    residual = f
    for _ in range(n_scales - 1):
        omega = F.avg_pool2d(residual, 2)
        sigma_omega = F.interpolate(omega, size=residual.shape[2:], mode="bilinear", align_corners=False)
        cofibers.append(residual - sigma_omega)  # high-frequency detail at this scale
        residual = omega
    cofibers.append(residual)  # coarsest low-frequency residual
    return cofibers
def make_locations(sizes, strides):
    """Return pixel-space (x, y) centers for each feature-map level.

    For every ``(h, w)`` paired with its stride ``s``, cell (i, j) maps to
    ``((j + 0.5) * s, (i + 0.5) * s)``. The result is a float32 tensor of
    shape (h * w, 2) in row-major order (rows outer, columns inner).
    """
    def _grid_centers(h, w, s):
        center_y = (torch.arange(h, dtype=torch.float32) + 0.5) * s
        center_x = (torch.arange(w, dtype=torch.float32) + 0.5) * s
        mesh_y, mesh_x = torch.meshgrid(center_y, center_x, indexing="ij")
        return torch.stack((mesh_x.reshape(-1), mesh_y.reshape(-1)), dim=-1)

    return [_grid_centers(h, w, s) for (h, w), s in zip(sizes, strides)]
def assign_targets(loc, boxes, labels, stride, sr):
    """Assign per-location class, box-distance and centerness targets.

    A location is a candidate for a box when all three hold:
    - it lies strictly inside the box;
    - it is within ``1.5 * stride`` of the box center along both axes;
    - its largest (l, t, r, b) distance falls in ``[sr[0], sr[1]]``.
    Among a location's candidate boxes the smallest-area one wins, with ties
    going to the earlier box.

    Args:
        loc: (n, 2) tensor of (x, y) location centers.
        boxes: (m, 4) tensor of (x1, y1, x2, y2) boxes.
        labels: tensor of per-box class indices, indexable by a box-index tensor.
        stride: feature stride; sets the center-sampling radius.
        sr: (low, high) bounds on the largest ltrb distance.

    Returns:
        ``(ct, rt, ctrt)``:
        - ``ct``: (n,) long class index, -1 for background;
        - ``rt``: (n, 4) float32 (l, t, r, b) distances, zeros for background;
        - ``ctrt``: (n,) centerness ``sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b))``,
          0 for background.
    """
    num_locs = loc.shape[0]
    cls_target = torch.full((num_locs,), -1, dtype=torch.long)
    reg_target = torch.zeros(num_locs, 4)
    ctr_target = torch.zeros(num_locs)
    if boxes.numel() == 0:
        return cls_target, reg_target, ctr_target

    px = loc[:, None, 0]  # (n, 1)
    py = loc[:, None, 1]
    bx1 = boxes[None, :, 0]  # (1, m)
    by1 = boxes[None, :, 1]
    bx2 = boxes[None, :, 2]
    by2 = boxes[None, :, 3]
    dists = torch.stack([px - bx1, py - by1, bx2 - px, by2 - py], -1)  # (n, m, 4)

    inside = dists.min(-1).values > 0
    radius = stride * 1.5
    mid_x = (boxes[:, 0] + boxes[:, 2]) / 2
    mid_y = (boxes[:, 1] + boxes[:, 3]) / 2
    near_center = ((px >= mid_x - radius) & (px <= mid_x + radius)
                   & (py >= mid_y - radius) & (py <= mid_y + radius))
    reach = dists.max(-1).values
    scale_ok = (reach >= sr[0]) & (reach <= sr[1])
    candidate = inside & near_center & scale_ok

    # Non-candidate pairs get infinite cost so they can never win the argmin.
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cost = area[None, :].expand_as(candidate).clone()
    cost[~candidate] = float("inf")
    best = cost.argmin(1)
    rows = torch.arange(num_locs)
    positive = cost[rows, best] < float("inf")

    cls_target[positive] = labels[best[positive]]
    if positive.any():
        reg_target[positive] = dists[rows[positive], best[positive]]
        # Centerness is computed from the stored (float32) regression targets.
        pos = reg_target[positive]
        lr_ratio = torch.minimum(pos[:, 0], pos[:, 2]) / torch.maximum(pos[:, 0], pos[:, 2]).clamp(min=1e-6)
        tb_ratio = torch.minimum(pos[:, 1], pos[:, 3]) / torch.maximum(pos[:, 1], pos[:, 3]).clamp(min=1e-6)
        ctr_target[positive] = torch.sqrt(lr_ratio * tb_ratio)
    return cls_target, reg_target, ctr_target
def _merged_band_features(spatial, decompose_fn, target_size):
    """Return layer-normed features from the band average of ``decompose_fn(spatial)``.

    Every band is bilinearly resampled to ``target_size`` if needed, then the
    bands are averaged. ``spatial`` is a (1, C, H, W) map as built by the
    callers. The result has shape (target_h * target_w, C) in row-major
    spatial order, the same order ``make_locations`` produces.
    """
    resized = []
    for band in decompose_fn(spatial):
        if band.shape[2:] != target_size:
            band = F.interpolate(band, size=target_size, mode="bilinear", align_corners=False)
        resized.append(band)
    # Average across all bands — the solver sees the mean multi-frequency representation.
    merged = torch.stack(resized).mean(0)  # (1, C, target_h, target_w)
    channels = merged.shape[1]
    return F.layer_norm(merged.permute(0, 2, 3, 1).reshape(-1, channels), [channels])


def eval_decomposition(val, coco_gt, cat_ids, decompose_fn, name, lam=0.1, n_train=10000):
    """Accumulate, solve, and eval a decomposition variant.

    All bands are resampled to stride-16 resolution (RESOLUTION // 16 per side,
    i.e. 40x40) and use the same target assignment. The decomposition separates
    frequencies, not resolutions.

    Three ridge regressions share the band-averaged, layer-normed features
    plus a bias column:
    - one-hot class targets;
    - log(l, t, r, b) box distances;
    - centerness.
    They are fit on the positive locations of up to ``n_train`` cached training
    images, decoded on ``val`` (top-100 per image, score >= 0.01), and scored
    with COCOeval.

    Args:
        val: sequence of dicts with "spatial", "img_id" and "scale" entries.
        coco_gt: pycocotools ``COCO`` ground truth.
        cat_ids: COCO category ids; class index i maps to ``cat_ids[i]``.
        decompose_fn: callable mapping a (1, C, H, W) map to a list of bands.
        name: label used in printed output.
        lam: ridge strength, multiplied by the number of positive locations.
        n_train: maximum number of training images to accumulate.

    Returns:
        COCO bbox AP@[.50:.95] (``stats[0]``), or 0.0 when there are no
        positive training locations or no detections.
    """
    idx_to_cat = dict(enumerate(cat_ids))
    stride = 16
    H = RESOLUTION // stride
    target_size = (H, H)
    sr = (-1, float("inf"))  # single scale, all object sizes
    locs = make_locations([target_size], [stride])[0]
    feat_dim = 768
    with open(os.path.join(CACHE_DIR, "manifest.json")) as fh:
        manifest = json.load(fh)

    # Normal equations with a trailing bias column. The class and centerness
    # systems use exactly the same design matrix, so they share one X^T X.
    XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    cls_XtY = torch.zeros(feat_dim + 1, NUM_CLASSES, device=DEVICE)
    reg_XtX = torch.zeros(feat_dim + 1, feat_dim + 1, device=DEVICE)
    reg_XtY = torch.zeros(feat_dim + 1, 4, device=DEVICE)
    ctr_XtY = torch.zeros(feat_dim + 1, 1, device=DEVICE)
    n_pos = 0
    seen = 0
    t0 = time.time()
    for si in range(manifest["n_shards"]):
        if seen >= n_train:
            break
        # weights_only=False unpickles arbitrary objects: only use trusted caches.
        shard = torch.load(os.path.join(CACHE_DIR, f"shard_{si:04d}.pt"),
                           map_location="cpu", weights_only=False)
        for item in shard:
            if seen >= n_train:
                break
            seen += 1
            ct, rt, ctrt = assign_targets(locs, item["boxes"], item["labels"], stride, sr)
            pos_mask = ct >= 0
            if not pos_mask.any():
                continue  # nothing to fit; skip the decomposition entirely
            sp = item["spatial"].unsqueeze(0).float().to(DEVICE)
            feats = _merged_band_features(sp, decompose_fn, target_size)
            fp = feats[pos_mask]
            fa = torch.cat([fp, torch.ones(fp.shape[0], 1, device=DEVICE)], 1)
            n_rows = fa.shape[0]
            yc = torch.zeros(n_rows, NUM_CLASSES, device=DEVICE)
            yc[torch.arange(n_rows, device=DEVICE), ct[pos_mask].to(DEVICE)] = 1.0
            XtX += fa.T @ fa
            cls_XtY += fa.T @ yc
            ltrb = rt[pos_mask]
            valid = (ltrb > 0).all(1)  # log target needs strictly positive distances
            if valid.any():
                fv = fa[valid]
                reg_XtX += fv.T @ fv
                reg_XtY += fv.T @ torch.log(ltrb[valid]).to(DEVICE)
            ctr_XtY += fa.T @ ctrt[pos_mask].unsqueeze(1).to(DEVICE)
            n_pos += pos_mask.sum().item()
        del shard

    if n_pos == 0:
        # Without positives every normal matrix is zero and the solve would fail.
        print(f" {name}: no positive training locations")
        return 0.0
    I = torch.eye(feat_dim + 1, device=DEVICE)
    ridge = lam * I * n_pos
    # One solve for both systems that share X^T X.
    shared_W = torch.linalg.solve(XtX + ridge, torch.cat([cls_XtY, ctr_XtY], 1))
    cls_W = shared_W[:, :NUM_CLASSES]
    ctr_W = shared_W[:, NUM_CLASSES:]
    reg_W = torch.linalg.solve(reg_XtX + ridge, reg_XtY)
    accum_time = time.time() - t0

    all_locs = locs.to(DEVICE)
    all_results = []
    eval_img_ids = set()
    for item in val:
        img_id = int(item["img_id"])
        scale = item["scale"]
        eval_img_ids.add(img_id)
        spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
        feats = _merged_band_features(spatial, decompose_fn, target_size)
        cls_s = (feats @ cls_W[:feat_dim] + cls_W[feat_dim]).sigmoid()
        reg_s = (feats @ reg_W[:feat_dim] + reg_W[feat_dim]).exp()
        ctr_s = (feats @ ctr_W[:feat_dim] + ctr_W[feat_dim]).sigmoid().squeeze(1)
        scores = cls_s * ctr_s.unsqueeze(1)
        max_s, max_c = scores.max(1)
        topk = min(100, max_s.shape[0])
        top_s, top_i = max_s.topk(topk)
        tc = max_c[top_i]
        tr = reg_s[top_i]
        tl = all_locs[top_i]
        # Decode ltrb distances to boxes in original-image coordinates.
        x1 = (tl[:, 0] - tr[:, 0]) / scale
        y1 = (tl[:, 1] - tr[:, 1]) / scale
        x2 = (tl[:, 0] + tr[:, 2]) / scale
        y2 = (tl[:, 1] + tr[:, 3]) / scale
        bboxes = torch.stack([x1, y1, (x2 - x1).clamp(min=0), (y2 - y1).clamp(min=0)], 1)
        # One device->host transfer per image instead of one .item() per detection.
        for score, cls_idx, bbox in zip(top_s.tolist(), tc.tolist(), bboxes.tolist()):
            if score < 0.01:
                continue
            all_results.append({"image_id": img_id, "category_id": idx_to_cat[cls_idx],
                                "bbox": bbox, "score": score})

    # pycocotools eval
    from pycocotools.cocoeval import COCOeval
    if not all_results:
        print(f" {name}: no detections")
        return 0.0
    coco_dt = coco_gt.loadRes(all_results)
    coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
    # Evaluate exactly the images that were run through the detector, rather
    # than assuming they are the first len(val) ground-truth ids.
    coco_eval.params.imgIds = sorted(eval_img_ids)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    mAP, mAP50, mAP75 = coco_eval.stats[:3]
    print(f" {name}: mAP={mAP:.4f} mAP50={mAP50:.4f} mAP75={mAP75:.4f} "
          f"({accum_time:.0f}s accum, {n_pos} pos)")
    return mAP
def main():
    """Compare standard vs fractal cofiber decompositions and save a summary.

    Loads the validation cache and COCO val2017 annotations from the
    environment-configured locations. Evaluates each decomposition variant
    with ``eval_decomposition`` and writes the results to
    ``analytical_variants/fractal_results.json`` beside this script, creating
    the directory if needed.
    """
    from pycocotools.coco import COCO
    # Fail fast with a clear message instead of a TypeError from os.path.join(None, ...).
    missing = [env_var for env_var, value in (("ARENA_COCO_ROOT", COCO_ROOT),
                                              ("ARENA_VAL_CACHE", VAL_CACHE),
                                              ("ARENA_CACHE_DIR", CACHE_DIR)) if not value]
    if missing:
        raise SystemExit(f"Missing required environment variables: {', '.join(missing)}")
    print("=" * 60)
    print("Fractal vs Standard Cofiber Decomposition")
    print("=" * 60, flush=True)
    # weights_only=False unpickles arbitrary objects: only use trusted caches.
    val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)
    ann_file = os.path.join(COCO_ROOT, "annotations", "instances_val2017.json")
    coco_gt = COCO(ann_file)
    cat_ids = sorted(coco_gt.getCatIds())

    # (result name, printed label, number of bands, decomposition). The first
    # entry is the standard 3-band baseline.
    variants = [
        ("standard_3band", "Standard 3-band cofiber", 3, lambda sp: standard_decompose(sp, 3)),
        ("fractal_depth2", "Fractal depth 2 (4 bands)", 4, lambda sp: fractal_decompose(sp, 2)),
        ("fractal_depth3", "Fractal depth 3 (8 bands)", 8, lambda sp: fractal_decompose(sp, 3)),
    ]
    results = []
    for i, (variant_name, label, n_bands, decompose_fn) in enumerate(variants, 1):
        print(f"\n{i}. {label}:", flush=True)
        mAP = eval_decomposition(val, coco_gt, cat_ids, decompose_fn, variant_name)
        results.append({"name": variant_name, "mAP": mAP, "bands": n_bands})

    print(f"\n{'='*60}")
    print("Summary:")
    for r in results:
        print(f" {r['name']:20s}: mAP={r['mAP']:.4f} ({r['bands']} bands)")
    # Create the output directory so a long run is not lost to FileNotFoundError.
    out_dir = os.path.join(SCRIPT_DIR, "analytical_variants")
    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, "fractal_results.json")
    with open(out, "w") as fh:
        json.dump(results, fh, indent=2)
    print(f"Saved: {out}")


if __name__ == "__main__":
    main()