cofiber-detection / analytical /scripts /analytical_greedy_gpu.py
phanerozoic's picture
update repository
dbbceb8
"""
GPU-accelerated greedy forward construction of a minimal detection head.
Scores every remaining candidate dimension with a ridge-regression solve on the GPU (CPU when CUDA is unavailable).
50 greedy steps in seconds instead of minutes.
"""
import argparse
import json
import os
import sys
import time
import torch
import torch.nn.functional as F
# Directory that contains this script. It is pushed to the front of sys.path,
# presumably so sibling modules in the same directory can be imported.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)
# Root of the COCO dataset (the annotation JSON is looked up beneath it);
# override with the ARENA_COCO_ROOT environment variable.
COCO_ROOT = os.environ.get("ARENA_COCO_ROOT", "coco")
# Path of the cached validation file read with torch.load;
# override with the ARENA_VAL_CACHE environment variable.
VAL_CACHE = os.environ.get("ARENA_VAL_CACHE", "val_cache/val.pt")
# Number of object classes; width of the one-hot classification targets.
NUM_CLASSES = 80
def cofiber_decompose(f, n_scales):
    """Split a feature map into ``n_scales`` multi-resolution bands.

    At every level the current map is 2x2 average-pooled, the pooled map is
    bilinearly upsampled back to the current resolution, and the difference
    (current minus upsampled) is emitted as that level's band. The pooled map
    then becomes the current map for the next level.

    Args:
        f: feature map; the spatial size is taken from ``shape[2:]``.
        n_scales: number of bands to return.

    Returns:
        A list whose first ``n_scales - 1`` entries are the per-level
        differences (finest first) and whose last entry is the coarsest
        pooled map itself.
    """
    bands = []
    current = f
    for _level in range(1, n_scales):
        pooled = F.avg_pool2d(current, 2)
        restored = F.interpolate(
            pooled,
            size=current.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
        bands.append(current - restored)
        current = pooled
    # Whatever is left after the last pooling is the final (coarsest) band.
    return bands + [current]
def make_locations(sizes, strides):
    """Return the cell-centre coordinates of every feature-map level.

    Args:
        sizes: iterable of ``(height, width)`` pairs, one per level.
        strides: iterable of strides, paired with ``sizes`` by position.

    Returns:
        A list with one ``(height * width, 2)`` float32 tensor per level.
        Each row is ``(x, y)`` with ``x = (col + 0.5) * stride`` and
        ``y = (row + 0.5) * stride``; rows are ordered row-major.
    """

    def _level_centres(height, width, stride):
        # Centres along each axis, then the full grid in (row, col) order.
        row_centres = (torch.arange(height, dtype=torch.float32) + 0.5) * stride
        col_centres = (torch.arange(width, dtype=torch.float32) + 0.5) * stride
        grid_y, grid_x = torch.meshgrid(row_centres, col_centres, indexing="ij")
        return torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=-1)

    return [_level_centres(h, w, s) for (h, w), s in zip(sizes, strides)]
def assign_targets(loc, boxes, labels, stride, sr):
    """Assign a class label and box offsets to every location of one level.

    A location is matched to a box when all of the following hold:
      * it lies strictly inside the box (all four side distances > 0);
      * it lies within ``1.5 * stride`` of the box centre on both axes;
      * its largest side distance falls in ``[sr[0], sr[1]]``.
    When several boxes qualify, the one with the smallest area wins.

    Args:
        loc: ``(n, 2)`` tensor of ``(x, y)`` locations.
        boxes: ``(m, 4)`` tensor of ``(x0, y0, x1, y1)`` boxes; may be empty.
        labels: ``(m,)`` integer tensor of class indices for ``boxes``.
        stride: stride of this level, used for the centre radius.
        sr: ``(low, high)`` bounds on the largest side distance.

    Returns:
        ``(cls, reg)`` where ``cls`` is an ``(n,)`` long tensor holding the
        matched label or ``-1``, and ``reg`` is an ``(n, 4)`` tensor holding
        the (left, top, right, bottom) distances to the matched box, zeros
        for unmatched locations.
    """
    num_locs = loc.shape[0]
    cls_target = torch.full((num_locs,), -1, dtype=torch.long)
    reg_target = torch.zeros(num_locs, 4)
    if boxes.numel() == 0:
        return cls_target, reg_target

    # Column views shaped to broadcast locations (rows) against boxes (cols).
    px = loc[:, None, 0]
    py = loc[:, None, 1]
    x0 = boxes[None, :, 0]
    y0 = boxes[None, :, 1]
    x1 = boxes[None, :, 2]
    y1 = boxes[None, :, 3]

    side_dist = torch.stack([px - x0, py - y0, x1 - px, y1 - py], -1)
    inside_box = side_dist.min(-1).values > 0

    half_window = stride * 1.5
    centre_x = (boxes[:, 0] + boxes[:, 2]) / 2
    centre_y = (boxes[:, 1] + boxes[:, 3]) / 2
    near_centre = (
        (px >= centre_x - half_window)
        & (px <= centre_x + half_window)
        & (py >= centre_y - half_window)
        & (py <= centre_y + half_window)
    )

    farthest_side = side_dist.max(-1).values
    on_this_level = (farthest_side >= sr[0]) & (farthest_side <= sr[1])

    candidate = inside_box & near_centre & on_this_level

    # Area acts as the matching cost; non-candidates get an infinite cost so
    # that argmin only ever prefers a real candidate.
    box_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    cost = box_area[None, :].expand_as(candidate).masked_fill(~candidate, float("inf"))
    best_box = cost.argmin(1)
    rows = torch.arange(num_locs)
    matched = cost[rows, best_box] < float("inf")

    matched_rows = rows[matched]
    matched_boxes = best_box[matched]
    cls_target[matched] = labels[matched_boxes]
    if matched_rows.numel() > 0:
        reg_target[matched] = side_dist[matched_rows, matched_boxes]
    return cls_target, reg_target
def build_val_data(val_path, n_images=500, device="cuda"):
    """Build the per-location feature matrix and class targets.

    For each of the first ``min(n_images, len(val))`` cached items, the
    spatial feature map is split into three bands with
    ``cofiber_decompose``; every band is flattened to one row per location,
    layer-normalised over channels, and paired with class targets produced
    by ``assign_targets`` from the image's COCO boxes.

    Args:
        val_path: path of the cache file read with ``torch.load``.
        n_images: upper bound on the number of cached items used.
        device: device the concatenated results are moved to.

    Returns:
        ``(features, cls_targets)``: features with one row per location and
        one column per channel, and the matching per-location class targets
        (``-1`` where no box was assigned), both on ``device``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; only
    # point this at cache files you trust.
    val = torch.load(val_path, map_location="cpu", weights_only=False)

    from pycocotools.coco import COCO

    coco = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
    # Map COCO category ids to contiguous indices in sorted-id order.
    cat_to_idx = {cat_id: index for index, cat_id in enumerate(sorted(coco.getCatIds()))}

    # Fixed three-level layout for a 640-pixel input.
    # NOTE(review): assumes every item's spatial map is 40x40 so the band
    # sizes line up with these location grids - confirm against the cache.
    base = 640 // 16
    strides = [16, 32, 64]
    sizes = [(base, base), (base // 2, base // 2), (base // 4, base // 4)]
    sr = [(-1, 128), (128, 256), (256, float("inf"))]
    locs = make_locations(sizes, strides)

    feature_chunks = []
    target_chunks = []
    for idx in range(min(n_images, len(val))):
        item = val[idx]
        spatial = item["spatial"].unsqueeze(0).float()
        img_id = item["img_id"]
        scale = item["scale"]

        anns = coco.loadAnns(coco.getAnnIds(imgIds=int(img_id), iscrowd=False))
        boxes = []
        labels = []
        for ann in anns:
            x, y, w, h = ann["bbox"]
            # Keep only boxes at least one pixel wide and high (pre-scaling).
            if not (w < 1 or h < 1):
                boxes.append([x * scale, y * scale, (x + w) * scale, (y + h) * scale])
                labels.append(cat_to_idx[ann["category_id"]])

        if boxes:
            boxes_t = torch.tensor(boxes, dtype=torch.float32)
        else:
            boxes_t = torch.zeros(0, 4)
        if labels:
            labels_t = torch.tensor(labels, dtype=torch.long)
        else:
            labels_t = torch.zeros(0, dtype=torch.long)

        for level, band in enumerate(cofiber_decompose(spatial, 3)):
            _, channels, _, _ = band.shape
            # Channels-last, one row per location, normalised over channels.
            rows = band.permute(0, 2, 3, 1).reshape(-1, channels)
            feature_chunks.append(F.layer_norm(rows, [channels]))
            level_cls, _ = assign_targets(
                locs[level], boxes_t, labels_t, strides[level], sr[level]
            )
            target_chunks.append(level_cls)

    features = torch.cat(feature_chunks).to(device)
    cls_targets = torch.cat(target_chunks).to(device)
    return features, cls_targets
def greedy_step_gpu(features, cls_targets, selected, remaining, lam=0.1):
    """Score every remaining candidate dimension and return the best one.

    For each candidate ``d`` in ``remaining`` a ridge-regularised
    least-squares classifier is fitted on the positive locations
    (``cls_targets >= 0``), using feature columns ``selected + [d]`` plus a
    bias column against one-hot class targets. The candidate's score is the
    argmax accuracy of that classifier on the same positive locations.

    Candidates are evaluated one after another; each evaluation is a few
    matmuls and one linear solve on ``features.device``.

    Args:
        features: 2-D tensor, one row per location, one column per dimension.
        cls_targets: integer tensor with one entry per location; negative
            means "no object", otherwise a class index below ``NUM_CLASSES``.
        selected: list of column indices already chosen.
        remaining: list of candidate column indices to try.
        lam: ridge strength; ``lam * n_pos`` is added to the diagonal of the
            normal equations (the bias term is regularised too).

    Returns:
        ``(best_dim, best_acc)``. Ties go to the earliest candidate in
        ``remaining``. A candidate whose system cannot be solved scores
        ``0.0``. Returns ``(-1, 0.0)`` when there are no positive locations
        and ``(-1, -1.0)`` when ``remaining`` is empty.
    """
    pos = cls_targets >= 0
    n_pos = pos.sum().item()
    if n_pos == 0:
        return -1, 0.0

    device = features.device
    f_pos = features[pos]
    gt = cls_targets[pos]

    # One-hot regression targets for the positive locations.
    y_cls = torch.zeros(n_pos, NUM_CLASSES, device=device)
    y_cls[torch.arange(n_pos, device=device), gt] = 1.0

    # Every candidate system has the same width, so the bias column and the
    # ridge term are built once instead of once per candidate.
    fd = len(selected) + 1
    bias_col = torch.ones(n_pos, 1, device=device)
    ridge = lam * torch.eye(fd + 1, device=device) * n_pos

    # Only a failed solve (singular system) is treated as "candidate scores
    # zero"; any other error is a real problem and is allowed to propagate.
    solve_error = getattr(torch.linalg, "LinAlgError", RuntimeError)

    best_dim = -1
    best_acc = -1.0
    for d in remaining:
        fp = f_pos[:, selected + [d]]
        fa = torch.cat([fp, bias_col], 1)
        try:
            W = torch.linalg.solve(fa.T @ fa + ridge, fa.T @ y_cls)
        except solve_error:
            acc = 0.0
        else:
            # Last row of W is the bias; the rest weight the feature columns.
            scores = fp @ W[:fd] + W[fd]  # (n_pos, NUM_CLASSES)
            acc = (scores.argmax(1) == gt).float().mean().item()
        # Strict '>' keeps the first candidate among equal scores.
        if acc > best_acc:
            best_dim = d
            best_acc = acc
    return best_dim, best_acc
def main():
    """Run greedy forward selection of feature dimensions and save the result.

    Command-line options:
        --max-dims: maximum number of dimensions to select (default 100).
        --n-eval:   number of cached validation images to use (default 500).
        --lam:      ridge strength passed to ``greedy_step_gpu`` (default 0.1).

    Each step adds the dimension that gives the highest classification
    accuracy. Selection stops when ``--max-dims`` is reached, when no
    candidate can be chosen, or when the rounded accuracy has gained less
    than 0.005 relative to the history entry nine steps earlier. The
    selection history is written to
    ``<script dir>/analytical_variants/greedy_forward_gpu.json``.
    Nothing is written when no dimension was selected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-dims", type=int, default=100)
    parser.add_argument("--n-eval", type=int, default=500)
    parser.add_argument("--lam", type=float, default=0.1)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print("=" * 60)
    print(f"GPU Greedy Forward Construction (max {args.max_dims} dims)")
    print("=" * 60, flush=True)

    print("Building val data...", flush=True)
    features, cls_targets = build_val_data(VAL_CACHE, args.n_eval, device)
    pos = cls_targets >= 0
    print(f" {features.shape[0]} locations, {pos.sum().item()} positives, "
          f"{features.shape[1]} dims", flush=True)

    selected = []
    # Candidates are all feature columns actually present, rather than a
    # hard-coded width.
    remaining = list(range(features.shape[1]))
    history = []
    t0 = time.time()

    for step in range(args.max_dims):
        t_step = time.time()
        best_dim, best_acc = greedy_step_gpu(features, cls_targets, selected, remaining, args.lam)
        # A negative index means no candidate could be chosen (no positives
        # or no candidates left).
        if best_dim < 0:
            break
        selected.append(best_dim)
        remaining.remove(best_dim)
        step_time = time.time() - t_step

        n_params = len(selected) * NUM_CLASSES + NUM_CLASSES  # cls only for now
        entry = {"step": step + 1, "dim": best_dim, "cls_acc": round(best_acc, 4),
                 "n_params": n_params, "step_ms": round(step_time * 1000)}
        history.append(entry)
        print(f" step {step+1:3d}: +dim{best_dim:3d} -> cls_acc={best_acc:.4f} "
              f"({len(selected)} dims, {n_params} params, {step_time*1000:.0f}ms)", flush=True)

        # Early stopping: compare against the entry nine steps back.
        if len(history) >= 10:
            recent_gain = history[-1]["cls_acc"] - history[-10]["cls_acc"]
            if recent_gain < 0.005:
                print(f" Converged: <0.5% gain in 10 steps", flush=True)
                break

    elapsed = time.time() - t0
    print(f"\n{'='*60}")
    print(f"Selected {len(selected)} dimensions in {elapsed:.1f}s")

    # Without this guard the summary below would fail on history[-1] when
    # nothing was selected (no positive locations, or --max-dims < 1).
    if not history:
        print("No dimensions were selected; nothing to save.")
        return

    final = history[-1]
    print(f"Final cls_acc: {final['cls_acc']:.4f}")
    print(f"Final params: {final['n_params']}")
    print(f"\nTop 20 dimensions (most to least important):")
    for h in history[:20]:
        print(f" step {h['step']:2d}: dim{h['dim']:3d} cumul_acc={h['cls_acc']:.4f} ({h['step_ms']}ms)")

    # Save
    result = {"selected_dims": selected, "history": history,
              "final_cls_acc": final["cls_acc"], "final_params": final["n_params"],
              "total_time_s": round(elapsed, 1)}
    out = os.path.join(SCRIPT_DIR, "analytical_variants", "greedy_forward_gpu.json")
    os.makedirs(os.path.dirname(out), exist_ok=True)
    with open(out, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    print(f"\nSaved: {out}")


if __name__ == "__main__":
    main()