File size: 4,146 Bytes
ca63835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Feature caching and dataset loading utilities."""

import json
import os
import random
from typing import List, Tuple

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2

# Side length, in pixels, of the square letterboxed canvas fed to the backbone.
RESOLUTION = 640
# Indices into ``backbone.blocks`` whose outputs are captured via forward hooks.
HOOK_BLOCKS = [2, 5, 8, 11]
# NOTE(review): not referenced in this section; presumably the number of non-patch
# prefix tokens emitted by the backbone — confirm against its users.
N_PREFIX = 5

# Raw COCO ``category_id`` values in contiguous-label order: entry ``i`` is the
# category id of contiguous label ``i`` (80 entries, ids are non-contiguous in 1..90).
COCO_CONTIG_TO_CAT = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
    11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
    22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
    35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
    46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
    56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
    67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
    80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
]
# Inverse lookup: raw COCO category id -> contiguous label index.
COCO_CAT_TO_CONTIG = dict(zip(COCO_CONTIG_TO_CAT, range(len(COCO_CONTIG_TO_CAT))))


def letterbox(image: Image.Image, res: int) -> Tuple[Image.Image, float]:
    """Resize ``image`` to fit a ``res`` x ``res`` square and pad it with black.

    The longer side is scaled to ``res`` (aspect ratio preserved, bilinear
    resampling) and the resized image is pasted at the top-left corner of a
    black RGB canvas, so padding only appears on the right and/or bottom.

    Args:
        image: Source PIL image.
        res: Side length of the square output canvas, in pixels.

    Returns:
        ``(canvas, scale)``: the padded ``res`` x ``res`` RGB image and the
        factor applied to both axes. Because the image sits at the origin,
        original pixel coordinates map to canvas coordinates as ``xy * scale``.
    """
    src_w, src_h = image.size
    scale = res / max(src_h, src_w)
    # Clamp to >= 1 px: for very elongated images the short side could round to
    # 0, which PIL's resize rejects with a ValueError.
    new_w = max(1, int(round(src_w * scale)))
    new_h = max(1, int(round(src_h * scale)))
    resized = image.resize((new_w, new_h), Image.BILINEAR)
    canvas = Image.new("RGB", (res, res), (0, 0, 0))
    canvas.paste(resized, (0, 0))
    return canvas, scale


def load_coco_subset(split, n, coco_root, min_anns=0):
    """Load up to ``n`` annotated images from a COCO 2017 split.

    Reads ``<coco_root>/annotations/instances_<split>2017.json``. Annotations
    with a truthy ``iscrowd`` or a ``category_id`` missing from
    ``COCO_CAT_TO_CONTIG`` are dropped; images keeping at least ``min_anns``
    annotations become candidates.

    For ``split == "val"`` candidates are ordered by descending annotation count
    (stable sort, so ties keep file order). Otherwise they are shuffled with a
    private RNG seeded with 42, which is reproducible and leaves the global
    ``random`` module state untouched.

    The first ``n`` candidates are converted. Boxes with width or height < 1 are
    skipped, and an image left with no boxes is dropped. As a result, fewer than
    ``n`` items may be returned, and ``min_anns`` is checked before that size
    filter.

    Args:
        split: Split name used in the file layout, e.g. ``"train"`` or ``"val"``.
        n: Maximum number of candidate images to take (``None`` takes all).
        coco_root: Directory containing ``<split>2017/`` and ``annotations/``.
        min_anns: Minimum number of kept annotations an image needs.

    Returns:
        List of dicts with keys ``"path"``, ``"boxes"`` (``[x1, y1, x2, y2]``
        lists converted from COCO ``[x, y, w, h]``, original pixel coordinates),
        ``"labels"`` (contiguous class indices), ``"width"`` and ``"height"``.
    """
    img_dir = os.path.join(coco_root, f"{split}2017")
    ann_file = os.path.join(coco_root, "annotations", f"instances_{split}2017.json")
    with open(ann_file, encoding="utf-8") as f:
        coco = json.load(f)

    id_to_anns = {}
    for a in coco["annotations"]:
        if a["iscrowd"] or a["category_id"] not in COCO_CAT_TO_CONTIG:
            continue
        id_to_anns.setdefault(a["image_id"], []).append(a)
    id_to_info = {img["id"]: img for img in coco["images"]}

    candidates = [(iid, anns) for iid, anns in id_to_anns.items() if len(anns) >= min_anns]
    if split == "val":
        candidates.sort(key=lambda x: -len(x[1]))
    else:
        # A private generator yields the same permutation as seeding the global
        # ``random`` module with 42, without resetting the caller's RNG state.
        random.Random(42).shuffle(candidates)

    items = []
    for iid, anns in candidates[:n]:
        info = id_to_info[iid]
        path = os.path.join(img_dir, info["file_name"])
        boxes, labels = [], []
        for a in anns:
            x, y, w, h = a["bbox"]
            if w < 1 or h < 1:
                continue  # degenerate box
            boxes.append([x, y, x + w, y + h])
            labels.append(COCO_CAT_TO_CONTIG[a["category_id"]])
        if boxes:
            items.append({"path": path, "boxes": boxes, "labels": labels,
                          "width": info["width"], "height": info["height"]})
    return items


def cache_features(backbone, items, cache_path, resolution=RESOLUTION):
    """Cache backbone features for a list of image items.

    Does nothing if ``cache_path`` already exists. Otherwise, each item's image
    goes through the following steps:

    - it is letterboxed to ``resolution`` x ``resolution``;
    - it is normalized with mean ``(0.485, 0.456, 0.406)`` and std
      ``(0.229, 0.224, 0.225)``;
    - it is run through ``backbone.forward_features`` on CUDA under bfloat16
      autocast.

    Forward hooks on ``backbone.blocks[i]``, for each ``i`` in ``HOOK_BLOCKS``,
    capture those blocks' outputs during the same pass.

    Each cached entry is a dict:
        ``"spatial"``: ``out["x_norm_patchtokens"]`` for the single batch
            element, reshaped to ``(D, side, side)``, float16, on CPU.
        ``"intermediates"``: list (in ``HOOK_BLOCKS`` order) of each hooked
            block's output for the single batch element, float16, on CPU.
        ``"boxes"``: float32 tensor of ``item["boxes"]`` multiplied by the
            letterbox scale, i.e. in canvas pixel coordinates.
        ``"labels"``: int64 tensor of ``item["labels"]``.

    The list is saved with ``torch.save`` to a temporary file, which is then
    renamed onto ``cache_path``. An interrupted run therefore never leaves a
    partial file that a later call would mistake for a finished cache.

    Args:
        backbone: Model exposing indexable ``blocks`` and a ``forward_features``
            method returning a dict with ``"x_norm_patchtokens"``.
        items: Iterable of dicts with ``"path"``, ``"boxes"`` and ``"labels"``.
        cache_path: Destination file for the cached list.
        resolution: Side length of the letterboxed square input.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square.
    """
    if os.path.isfile(cache_path):
        return

    normalize = v2.Compose([
        v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    intermediates = {}

    def _make_hook(block_idx):
        # Factory so each hook binds its own ``block_idx`` (avoids late binding).
        def fn(mod, inp, out):
            # Some block implementations return a list/tuple whose first
            # element is the token tensor; keep only that element.
            if isinstance(out, (list, tuple)):
                out = out[0]
            intermediates[block_idx] = out.detach()
        return fn

    hooks = []
    cached = []
    try:
        for idx in HOOK_BLOCKS:
            hooks.append(backbone.blocks[idx].register_forward_hook(_make_hook(idx)))

        for item in items:
            with Image.open(item["path"]) as img:
                canvas, scale = letterbox(img.convert("RGB"), resolution)
            x = normalize(canvas).unsqueeze(0).cuda()

            intermediates.clear()
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
                out = backbone.forward_features(x)
            patches = out["x_norm_patchtokens"].float()
            _, n_tokens, dim = patches.shape
            side = int(round(n_tokens ** 0.5))
            if side * side != n_tokens:
                raise ValueError(
                    f"expected a square patch grid, got {n_tokens} patch tokens")
            spatial = patches[0].permute(1, 0).reshape(dim, side, side).half().cpu()
            inter = [intermediates[idx][0].half().cpu() for idx in HOOK_BLOCKS]

            boxes = torch.tensor(item["boxes"], dtype=torch.float32) * scale
            labels = torch.tensor(item["labels"], dtype=torch.long)
            cached.append({
                "spatial": spatial, "intermediates": inter,
                "boxes": boxes, "labels": labels,
            })
    finally:
        # Always detach the hooks so a failure mid-loop does not leave them
        # installed on the caller's backbone.
        for handle in hooks:
            handle.remove()

    cache_dir = os.path.dirname(cache_path)
    if cache_dir:  # os.makedirs("") raises, e.g. for a bare filename
        os.makedirs(cache_dir, exist_ok=True)
    # Write-then-rename so the ``isfile`` early return above never sees a
    # truncated cache from an interrupted save.
    tmp_path = f"{cache_path}.tmp"
    torch.save(cached, tmp_path)
    os.replace(tmp_path, cache_path)