| """Eval mask regression head on COCO val2017.""" |
| import os, sys, time, torch, json |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from train_mask_regression import ( |
| MaskRegressionHead, make_locations, decode_mask_to_box, K |
| ) |
|
|
# Inference device, plus dataset locations taken from the environment.
# Both variables are mandatory: a missing one raises KeyError right here,
# before any model or data is loaded.
DEVICE = "cuda"
COCO_ROOT, VAL_CACHE = (
    os.environ[var_name] for var_name in ("ARENA_COCO_ROOT", "ARENA_VAL_CACHE")
)
|
|
import argparse

# Command-line interface. The three integer flags configure the head's
# architecture and presumably must match the checkpoint being evaluated,
# otherwise load_state_dict (strict by default) will reject the weights.
parser = argparse.ArgumentParser()
for flag, default in (("--hidden", 192), ("--std-layers", 5), ("--dw-layers", 4)):
    parser.add_argument(flag, type=int, default=default)
parser.add_argument("--checkpoint", required=True)
args = parser.parse_args()
|
|
# Build the head with the CLI-selected architecture and load its weights.
head = MaskRegressionHead(
    hidden=args.hidden,
    n_std_layers=args.std_layers,
    n_dw_layers=args.dw_layers,
).to(DEVICE)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only pass checkpoints from a trusted source.
ckpt = torch.load(args.checkpoint, map_location=DEVICE, weights_only=False)
if isinstance(ckpt, dict) and "head" in ckpt:
    # Training-style checkpoint: weights are nested under "head". Tolerate a
    # missing "step" entry instead of crashing after the weights have loaded.
    head.load_state_dict(ckpt["head"])
    step = ckpt.get("step", "final")
else:
    # Otherwise treat the object itself as the head's state_dict.
    head.load_state_dict(ckpt)
    step = "final"
head.eval()

n_params = sum(p.numel() for p in head.parameters())
print(f"Loaded step {step}, {n_params:,} params")
|
|
# Cached validation inputs, indexed 0..len(val)-1 by the loop below.
# NOTE(review): weights_only=False unpickles the cache; ARENA_VAL_CACHE must
# point at a trusted file.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)

from pycocotools.coco import COCO

# Ground-truth annotations, used for category mapping and for scoring.
coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))

# Maps the head's contiguous class index to a COCO category id; classes are
# ordered by ascending category id.
cat_ids = sorted(coco_gt.getCatIds())
idx_to_cat = dict(enumerate(cat_ids))
|
|
# Pyramid geometry: one square location grid per stride. The 640 constant
# presumably matches the (padded) input size used when building the cache —
# TODO confirm against the cache builder.
H = 640 // 16
strides = [8, 16, 32, 64]
grid_sizes = [(640 // s, 640 // s) for s in strides]  # 80, 40, 20, 10

locs_per_level = make_locations(grid_sizes, strides, torch.device(DEVICE))
# Per-location stride, aligned element-for-element with locs_per_level.
strides_per_level = [
    torch.full((level_locs.shape[0],), level_stride,
               device=DEVICE, dtype=torch.float32)
    for level_locs, level_stride in zip(locs_per_level, strides)
]

# Flattened across levels, in the same level order the head outputs are
# concatenated in during inference.
all_locs, all_strides = torch.cat(locs_per_level), torch.cat(strides_per_level)
|
|
MAX_DETS = 100        # detections kept per image (matches COCO's maxDets=100)
SCORE_THRESH = 0.01   # detections below this score are not written out

# Run the head over every cached image and collect COCO-format detections.
all_results = []
t0 = time.time()
with torch.no_grad():
    for idx in range(len(val)):
        item = val[idx]
        spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
        img_id = int(item["img_id"])
        img_scale = item["scale"]
        cls_l, mask_l, ctr_l = head(spatial)

        # Flatten each level's rank-4 map (channels in dim 1) to one row per
        # location, concatenating levels in the same order as all_locs.
        # The class count is taken from the head output instead of a literal 80.
        cls_s = torch.cat(
            [c.permute(0, 2, 3, 1).reshape(-1, c.shape[1]) for c in cls_l]
        ).sigmoid()
        mask_s = torch.cat(
            [m.permute(0, 2, 3, 1).reshape(-1, K, K) for m in mask_l]
        ).clamp(0, 1)
        ctr_s = torch.cat(
            [c.permute(0, 2, 3, 1).reshape(-1) for c in ctr_l]
        ).sigmoid()

        # Score = class probability x centerness; keep the best class per
        # location, then the top MAX_DETS locations (sorted descending).
        scores = cls_s * ctr_s.unsqueeze(1)
        max_s, max_c = scores.max(1)
        topk = min(MAX_DETS, max_s.shape[0])
        top_s, top_i = max_s.topk(topk)
        top_c = max_c[top_i]
        top_masks = mask_s[top_i]
        top_locs = all_locs[top_i]
        top_strides = all_strides[top_i]

        # Decode masks to boxes one stride at a time, reusing the shared
        # `strides` list rather than a second hard-coded copy.
        boxes = torch.zeros(topk, 4, device=DEVICE)
        for s_val in strides:
            sel = top_strides == s_val
            if not sel.any():
                continue
            boxes[sel] = decode_mask_to_box(
                top_masks[sel], s_val, top_locs[sel, 1], top_locs[sel, 0]
            )

        # Columns are read as (y0, x0, y1, x1). Dividing by img_scale
        # presumably undoes the resize applied when the cache was built.
        # Convert to COCO's [x, y, w, h].
        y0, x0, y1, x1 = (boxes / img_scale).unbind(1)
        w_box = (x1 - x0).clamp(min=0)
        h_box = (y1 - y0).clamp(min=0)
        xywh = torch.stack([x0, y0, w_box, h_box], dim=1)

        # Transfer to host once per tensor instead of issuing several
        # .item() device syncs for every detection.
        for score, cls_idx, bbox in zip(top_s.tolist(), top_c.tolist(), xywh.tolist()):
            if score < SCORE_THRESH:
                # topk returns scores in descending order, so all that follow
                # are below the threshold as well.
                break
            all_results.append({
                "image_id": img_id,
                "category_id": idx_to_cat[cls_idx],
                "bbox": bbox,
                "score": score,
            })
        if (idx + 1) % 1000 == 0:
            print(f"  {idx+1}/{len(val)} ({time.time()-t0:.0f}s)", flush=True)
|
|
# Persist the raw detections so they can be re-scored offline
# (e.g. with eval_from_results.py if the in-process evaluation fails).
print(f"\n{len(all_results)} detections")
results_path = f"mask_reg_step{step}_results.json"
with open(results_path, "w") as out_file:
    out_file.write(json.dumps(all_results))
print(f"Saved: {results_path}")
|
|
# Score exactly the images that were run through the head. Taking the first
# len(val) sorted GT ids scores the wrong images whenever the cache is a
# subset of val2017 or is not stored in sorted-id order.
eval_img_ids = sorted({int(val[i]["img_id"]) for i in range(len(val))})

if not all_results:
    # COCO.loadRes cannot build a result set from an empty list.
    print("No detections to evaluate; skipping COCOeval.")
else:
    # Best-effort: any pycocotools failure is reported rather than raised,
    # because the results file has already been written above.
    try:
        from pycocotools.cocoeval import COCOeval
        coco_dt = coco_gt.loadRes(all_results)
        ev = COCOeval(coco_gt, coco_dt, "bbox")
        ev.params.imgIds = eval_img_ids
        ev.evaluate()
        ev.accumulate()
        ev.summarize()
        print(f"\nMask Reg {args.hidden}h step{step}: "
              f"mAP={ev.stats[0]:.4f} mAP50={ev.stats[1]:.4f} mAP75={ev.stats[2]:.4f}")
    except Exception as e:
        print(f"pycocotools failed: {e}. Use eval_from_results.py")
|
|