File size: 3,529 Bytes
"""Evaluate a trained split-tower head checkpoint on COCO val2017.

Runs the head over a cached validation set (path taken from ARENA_VAL_CACHE),
writes COCO-format detections to ``split_tower_step<step>_results.json`` and
scores them with pycocotools against the annotations under ARENA_COCO_ROOT.
"""
import argparse
import json
import os
import sys
import time

import torch

# Make the sibling training module importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_split_tower_5scale import FCOSLiteHead, cofiber_decompose, make_locations  # noqa: E402

DEVICE = "cuda"
# Both variables are required; a missing one fails fast with KeyError.
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]

# Architecture flags are forwarded to FCOSLiteHead and must match the values the
# checkpoint was trained with, otherwise load_state_dict (strict) rejects it.
parser = argparse.ArgumentParser(
    description="Evaluate a split-tower head checkpoint on COCO val2017."
)
parser.add_argument("--hidden", type=int, default=224,
                    help="passed to FCOSLiteHead as hidden")
parser.add_argument("--std-layers", type=int, default=3,
                    help="passed to FCOSLiteHead as n_std_layers")
parser.add_argument("--dw-layers", type=int, default=6,
                    help="passed to FCOSLiteHead as n_dw_layers")
parser.add_argument("--checkpoint", required=True,
                    help="checkpoint dict with a 'head' key, or a bare state_dict")
args = parser.parse_args()
# Build the head from the CLI architecture flags and load its weights.
# A dict containing a "head" key is treated as a training checkpoint (weights
# under "head", optional step under "step"); anything else is loaded directly
# as a state_dict.
head = FCOSLiteHead(
    hidden=args.hidden,
    n_std_layers=args.std_layers,
    n_dw_layers=args.dw_layers,
    # NOTE(review): 4 levels here, while the location grid is built for 5
    # strides (8..128) -- confirm this matches the trained model.
    n_scales=4,
).to(DEVICE)
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only load
# checkpoints from a trusted source.
ckpt = torch.load(args.checkpoint, map_location=DEVICE, weights_only=False)
if isinstance(ckpt, dict) and "head" in ckpt:
    head.load_state_dict(ckpt["head"])
    # Same label as the bare-state_dict path when no step was recorded.
    step = ckpt.get("step", "final")
else:
    head.load_state_dict(ckpt)
    step = "final"
head.eval()
n_params = sum(p.numel() for p in head.parameters())
print(f"Loaded step {step}, {n_params:,} params")
# Validation cache, kept on CPU.
# NOTE(review): weights_only=False unpickles arbitrary objects; the cache must be trusted.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
# Contiguous class index -> COCO category id (COCO ids are not contiguous).
cat_ids = sorted(coco_gt.getCatIds())
idx_to_cat = dict(enumerate(cat_ids))

# Location grid for every pyramid level, concatenated level by level.
# Assumes a 640x640 model input. Each level's grid side is derived from its
# stride (80, 40, 20, 10, 5), so sizes and strides cannot drift apart.
IMG_SIZE = 640
H = IMG_SIZE // 16  # grid side at stride 16
strides = [8, 16, 32, 64, 128]
level_sizes = [(IMG_SIZE // s, IMG_SIZE // s) for s in strides]
all_locs = torch.cat(make_locations(level_sizes, strides, torch.device(DEVICE)))
# Run the head over every cached validation sample and decode the dense
# predictions into COCO-format detection dicts collected in all_results.
TOPK = 100           # max detections kept per image
SCORE_THRESH = 0.01  # detections scoring below this are dropped
num_classes = len(idx_to_cat)  # classification channels per location
all_results = []
t0 = time.time()
with torch.no_grad():
    for idx in range(len(val)):
        item = val[idx]
        spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
        img_id = int(item["img_id"])
        scale = item["scale"]  # presumably model-input / original-image resize factor -- confirm
        cls_l, reg_l, ctr_l = head(spatial)
        # Flatten each level from (1, C, h, w) to (h*w, C) and concatenate the
        # levels; row order must match the order all_locs was built in.
        cls_s = torch.cat(
            [c.permute(0, 2, 3, 1).reshape(-1, num_classes) for c in cls_l]
        ).sigmoid()
        reg_s = torch.cat([r.permute(0, 2, 3, 1).reshape(-1, 4) for r in reg_l])
        ctr_s = torch.cat([c.permute(0, 2, 3, 1).reshape(-1) for c in ctr_l]).sigmoid()
        if cls_s.shape[0] > all_locs.shape[0]:
            raise RuntimeError(
                f"head produced {cls_s.shape[0]} locations but only "
                f"{all_locs.shape[0]} grid points were built"
            )
        # Score = class probability x centerness; keep the best class per location.
        scores = cls_s * ctr_s.unsqueeze(1)
        max_s, max_c = scores.max(1)
        top_s, top_i = max_s.topk(min(TOPK, max_s.shape[0]))
        tc = max_c[top_i]
        tr = reg_s[top_i]
        tl = all_locs[top_i]
        # Regression channels are distances from the location to the
        # (left, top, right, bottom) box edges; dividing by scale rescales the
        # box out of model-input coordinates.
        x1 = (tl[:, 0] - tr[:, 0]) / scale
        y1 = (tl[:, 1] - tr[:, 1]) / scale
        x2 = (tl[:, 0] + tr[:, 2]) / scale
        y2 = (tl[:, 1] + tr[:, 3]) / scale
        w = (x2 - x1).clamp(min=0)
        h = (y2 - y1).clamp(min=0)
        # One device->host copy per tensor instead of several .item() calls
        # (each a GPU sync) per detection; values are identical.
        boxes = torch.stack([x1, y1, w, h], dim=1).tolist()
        for s, c, box in zip(top_s.tolist(), tc.tolist(), boxes):
            if s < SCORE_THRESH:
                continue
            all_results.append({"image_id": img_id, "category_id": idx_to_cat[c],
                                "bbox": box,
                                "score": s})
        if (idx + 1) % 1000 == 0:
            print(f" {idx+1}/{len(val)} ({time.time()-t0:.0f}s)", flush=True)
# Persist the raw detections first so they survive an evaluation failure.
print(f"\n{len(all_results)} detections")
results_path = f"split_tower_step{step}_results.json"
with open(results_path, "w") as f:
    json.dump(all_results, f)
print(f"Saved: {results_path}")

# Score exactly the images present in the val cache, rather than assuming the
# cache holds the first len(val) ids of the sorted ground-truth list.
eval_img_ids = sorted({int(val[i]["img_id"]) for i in range(len(val))})
if not all_results:
    # COCO.loadRes indexes the first result, so an empty list cannot be loaded.
    print("pycocotools skipped: no detections to evaluate")
else:
    try:
        coco_dt = coco_gt.loadRes(all_results)
        ev = COCOeval(coco_gt, coco_dt, "bbox")
        ev.params.imgIds = eval_img_ids
        ev.evaluate()
        ev.accumulate()
        ev.summarize()
        print(f"\nSplit Tower {args.hidden}h step{step}: "
              f"mAP={ev.stats[0]:.4f} mAP50={ev.stats[1]:.4f} mAP75={ev.stats[2]:.4f}")
    except Exception as e:
        # Best-effort: the detections were already written to results_path.
        print(f"pycocotools failed: {e}")
|