File size: 4,340 Bytes
dbbceb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Eval mask regression head on COCO val2017."""
import os, sys, time, torch, json
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from train_mask_regression import (
    MaskRegressionHead, make_locations, decode_mask_to_box, K
)

# Run on the GPU when one is present; otherwise fall back to CPU so the script
# still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Required environment variables (KeyError if unset):
#   ARENA_COCO_ROOT - COCO root; must contain annotations/instances_val2017.json
#   ARENA_VAL_CACHE - path to the preprocessed validation cache read via torch.load
COCO_ROOT = os.environ["ARENA_COCO_ROOT"]
VAL_CACHE = os.environ["ARENA_VAL_CACHE"]

import argparse

# Architecture flags must match the values used at training time, otherwise
# load_state_dict below fails on missing/mismatched keys or shapes.
parser = argparse.ArgumentParser()
parser.add_argument("--hidden", type=int, default=192)
parser.add_argument("--std-layers", type=int, default=5)
parser.add_argument("--dw-layers", type=int, default=4)
parser.add_argument("--checkpoint", required=True)
args = parser.parse_args()

head = MaskRegressionHead(hidden=args.hidden, n_std_layers=args.std_layers,
                          n_dw_layers=args.dw_layers).to(DEVICE)
# NOTE(review): weights_only=False unpickles arbitrary objects from the file;
# only point --checkpoint at checkpoints from a trusted source.
ckpt = torch.load(args.checkpoint, map_location=DEVICE, weights_only=False)
if isinstance(ckpt, dict) and "head" in ckpt:
    # Bundle form: the head's state_dict lives under "head", optionally with a
    # "step" counter. A bundle saved without "step" is labelled "final"
    # instead of crashing with KeyError.
    head.load_state_dict(ckpt["head"])
    step = ckpt.get("step", "final")
else:
    # Otherwise the file is presumably a bare state_dict for the head.
    head.load_state_dict(ckpt)
    step = "final"
head.eval()
print(f"Loaded step {step}, {sum(p.numel() for p in head.parameters()):,} params")

# Preprocessed validation set. The eval loop indexes it by position and reads
# "spatial", "img_id" and "scale" from each item.
val = torch.load(VAL_CACHE, map_location="cpu", weights_only=False)

from pycocotools.coco import COCO
coco_gt = COCO(os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"))
# Contiguous model class index -> COCO category id (ids taken in sorted order).
cat_ids = sorted(coco_gt.getCatIds())
idx_to_cat = dict(enumerate(cat_ids))

# Pyramid geometry, assuming a 640x640 network input: each level is a square
# grid with 640 // stride cells per side (80, 40, 20, 10).
H = 640 // 16
strides = [8, 16, 32, 64]
grid_sizes = [(640 // stride_px, 640 // stride_px) for stride_px in strides]

locs_per_level = make_locations(grid_sizes, strides, torch.device(DEVICE))
# One float stride value per location, level by level.
strides_per_level = [
    torch.full((level_locs.shape[0],), level_stride,
               device=DEVICE, dtype=torch.float32)
    for level_locs, level_stride in zip(locs_per_level, strides)
]

# Flattened lookup tables. Their row order must match the order in which the
# eval loop concatenates the head's per-level outputs (presumably the head
# returns levels in stride order — confirm against MaskRegressionHead).
all_locs = torch.cat(locs_per_level)
all_strides = torch.cat(strides_per_level)

# Run the head over every cached val image and collect COCO-format detections:
# {"image_id", "category_id", "bbox": [x, y, w, h], "score"}.
all_results = []
t0 = time.time()
with torch.no_grad():
    for idx in range(len(val)):
        item = val[idx]
        spatial = item["spatial"].unsqueeze(0).float().to(DEVICE)
        img_id = int(item["img_id"])
        img_scale = item["scale"]
        cls_l, mask_l, ctr_l = head(spatial)

        # Flatten every pyramid level to one row per location. The order must
        # line up with all_locs / all_strides. 80 = number of COCO classes.
        cls_s = torch.cat([c.permute(0, 2, 3, 1).reshape(-1, 80) for c in cls_l]).sigmoid()
        mask_s = torch.cat([m.permute(0, 2, 3, 1).reshape(-1, K, K) for m in mask_l]).clamp(0, 1)
        ctr_s = torch.cat([c.permute(0, 2, 3, 1).reshape(-1) for c in ctr_l]).sigmoid()

        # Score = class probability weighted by centerness; keep the best class
        # per location, then the (up to) 100 highest-scoring locations.
        scores = cls_s * ctr_s.unsqueeze(1)
        max_s, max_c = scores.max(1)
        topk = min(100, max_s.shape[0])
        top_s, top_i = max_s.topk(topk)
        top_c = max_c[top_i]
        top_masks = mask_s[top_i]
        top_locs = all_locs[top_i]
        top_strides = all_strides[top_i]

        # decode_mask_to_box is called with one scalar stride per group, so
        # decode each stride's detections separately.
        boxes = torch.zeros(topk, 4, device=DEVICE)
        for s_val in strides:
            sel = top_strides == s_val
            if not sel.any():
                continue
            boxes[sel] = decode_mask_to_box(
                top_masks[sel], s_val, top_locs[sel, 1], top_locs[sel, 0])

        # Boxes come back as (y0, x0, y1, x1); divide by the cached per-image
        # scale and convert to COCO [x, y, w, h] in one vectorised step.
        y0, x0, y1, x1 = (boxes / img_scale).unbind(1)
        w_box = (x1 - x0).clamp(min=0)
        h_box = (y1 - y0).clamp(min=0)
        xywh = torch.stack([x0, y0, w_box, h_box], dim=1)

        # One device->host copy per tensor instead of six .item() syncs per
        # detection (up to ~600 per image).
        for score, cls_idx, bbox in zip(top_s.tolist(), top_c.tolist(), xywh.tolist()):
            if score < 0.01:
                continue
            all_results.append({
                "image_id": img_id,
                "category_id": idx_to_cat[cls_idx],
                "bbox": bbox,
                "score": score,
            })
        if (idx + 1) % 1000 == 0:
            print(f"  {idx+1}/{len(val)} ({time.time()-t0:.0f}s)", flush=True)

# Persist raw detections first so they survive a failed evaluation.
print(f"\n{len(all_results)} detections")
results_path = f"mask_reg_step{step}_results.json"
with open(results_path, "w") as f:
    json.dump(all_results, f)
print(f"Saved: {results_path}")

# Score exactly the images that were predicted on. Taking the first len(val)
# ids of the sorted annotation list only works if the cache holds that exact
# prefix; any other subset would be compared against the wrong ground truth.
eval_img_ids = sorted({int(val[i]["img_id"]) for i in range(len(val))})

if not all_results:
    # pycocotools' loadRes indexes the first result, so it fails on an empty
    # list; report that directly instead of as an opaque exception.
    print("No detections above threshold; skipping COCOeval.")
else:
    try:
        from pycocotools.cocoeval import COCOeval
        coco_dt = coco_gt.loadRes(all_results)
        ev = COCOeval(coco_gt, coco_dt, "bbox")
        ev.params.imgIds = eval_img_ids
        ev.evaluate(); ev.accumulate(); ev.summarize()
        print(f"\nMask Reg {args.hidden}h step{step}: "
              f"mAP={ev.stats[0]:.4f} mAP50={ev.stats[1]:.4f} mAP75={ev.stats[2]:.4f}")
    except Exception as e:
        # Best effort: results are already on disk and can be scored offline.
        print(f"pycocotools failed: {e}. Use eval_from_results.py")