# detection-heads/utils/cache.py
"""Feature caching and dataset loading utilities."""
import json
import os
import random
from typing import List, Tuple
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
# Side length (pixels) of the square letterboxed input fed to the backbone.
RESOLUTION = 640
# Backbone block indices whose outputs are captured via forward hooks in
# cache_features (presumably evenly spaced layers of a 12-block ViT —
# TODO confirm against the backbone actually used).
HOOK_BLOCKS = [2, 5, 8, 11]
# Number of non-patch prefix tokens (e.g. CLS/register tokens) emitted by the
# backbone — NOTE(review): not used in this file; verify against the heads.
N_PREFIX = 5
# COCO's 80 "thing" category ids listed in contiguous order. Raw COCO ids are
# sparse (gaps at 12, 26, 29, 30, ...), so contiguous label i maps to raw id
# COCO_CONTIG_TO_CAT[i].
COCO_CONTIG_TO_CAT = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
    27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50,
    51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75,
    76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
]
# Inverse mapping: raw COCO category id -> contiguous [0, 80) label index.
COCO_CAT_TO_CONTIG = {cat: i for i, cat in enumerate(COCO_CONTIG_TO_CAT)}
def letterbox(image: Image.Image, res: int):
    """Scale *image* to fit inside a ``res`` x ``res`` square and pad it.

    The aspect ratio is preserved; the resized image is anchored at the
    top-left of a black canvas. Returns ``(canvas, scale)`` where ``scale``
    is the factor applied to the original dimensions.
    """
    width, height = image.size
    factor = res / max(height, width)
    target = (int(round(width * factor)), int(round(height * factor)))
    scaled = image.resize(target, Image.BILINEAR)
    padded = Image.new("RGB", (res, res), (0, 0, 0))
    padded.paste(scaled, (0, 0))
    return padded, factor
def load_coco_subset(split, n, coco_root, min_anns=0):
    """Load up to *n* annotated images from a COCO 2017 split.

    Crowd annotations, unknown category ids, and degenerate boxes
    (width or height < 1 px) are dropped. For the ``val`` split, images with
    the most annotations are taken first (deterministic); otherwise the
    candidate pool is shuffled with a fixed seed for reproducibility.

    Args:
        split: COCO split name, e.g. ``"train"`` or ``"val"``.
        n: Maximum number of items to return.
        coco_root: Root directory containing ``{split}2017/`` and
            ``annotations/``.
        min_anns: Minimum number of valid annotations an image must have
            to be considered.

    Returns:
        List of dicts with ``path``, ``boxes`` (xyxy, pixel coords),
        ``labels`` (contiguous ids), ``width``, ``height``.
    """
    img_dir = os.path.join(coco_root, f"{split}2017")
    ann_file = os.path.join(coco_root, "annotations", f"instances_{split}2017.json")
    with open(ann_file) as f:
        coco = json.load(f)
    id_to_anns = {}
    for a in coco["annotations"]:
        if a["iscrowd"]:
            continue
        if a["category_id"] not in COCO_CAT_TO_CONTIG:
            continue
        id_to_anns.setdefault(a["image_id"], []).append(a)
    id_to_info = {img["id"]: img for img in coco["images"]}
    candidates = [(iid, anns) for iid, anns in id_to_anns.items() if len(anns) >= min_anns]
    if split == "val":
        # Deterministic order: most-annotated images first for evaluation.
        candidates.sort(key=lambda x: -len(x[1]))
    else:
        random.seed(42)
        random.shuffle(candidates)
    items = []
    # Fix: iterate the whole candidate list instead of slicing candidates[:n]
    # up front — images whose boxes are all degenerate are skipped, so the old
    # slice could return fewer than n items even when valid candidates remain.
    for iid, anns in candidates:
        if len(items) >= n:
            break
        info = id_to_info[iid]
        path = os.path.join(img_dir, info["file_name"])
        boxes, labels = [], []
        for a in anns:
            x, y, w, h = a["bbox"]
            if w < 1 or h < 1:
                continue  # degenerate box
            boxes.append([x, y, x + w, y + h])
            labels.append(COCO_CAT_TO_CONTIG[a["category_id"]])
        if boxes:
            items.append({"path": path, "boxes": boxes, "labels": labels,
                          "width": info["width"], "height": info["height"]})
    return items
def cache_features(backbone, items, cache_path, resolution=RESOLUTION):
    """Cache backbone features for a list of image items.

    Runs each image through *backbone* on CUDA (bf16 autocast), capturing the
    final patch tokens plus the intermediate outputs of ``HOOK_BLOCKS`` via
    forward hooks, and saves everything (fp16, CPU) together with the
    letterbox-scaled boxes and labels to *cache_path*. A no-op if the cache
    file already exists.

    Args:
        backbone: Model exposing ``blocks`` (hookable) and
            ``forward_features`` returning ``x_norm_patchtokens``.
        items: Dicts from ``load_coco_subset`` (``path``/``boxes``/``labels``).
        cache_path: Destination file for ``torch.save``.
        resolution: Letterbox square size in pixels.
    """
    if os.path.isfile(cache_path):
        return
    normalize = v2.Compose([
        v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    hooks, intermediates = [], {}

    def _make_hook(block_idx):
        # Factory captures block_idx by value so each hook writes its own slot.
        def fn(mod, inp, out):
            intermediates[block_idx] = (out[0] if isinstance(out, list) else out).detach()
        return fn

    for idx in HOOK_BLOCKS:
        hooks.append(backbone.blocks[idx].register_forward_hook(_make_hook(idx)))
    cached = []
    try:
        for item in items:
            # Fix: context manager so the image file handle is always closed.
            with Image.open(item["path"]) as raw:
                img = raw.convert("RGB")
            canvas, scale = letterbox(img, resolution)
            x = normalize(canvas).unsqueeze(0).cuda()
            intermediates.clear()
            with torch.no_grad():
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    out = backbone.forward_features(x)
            patches = out["x_norm_patchtokens"].float()
            B, N, D = patches.shape
            # Fix: name the grid side "side" instead of "h" to avoid shadowing
            # the hook handles removed below. Assumes a square patch grid.
            side = int(N ** 0.5)
            spatial = patches[0].permute(1, 0).reshape(D, side, side).half().cpu()
            inter = [intermediates[idx][0].half().cpu() for idx in HOOK_BLOCKS]
            boxes = torch.tensor(item["boxes"], dtype=torch.float32) * scale
            labels = torch.tensor(item["labels"], dtype=torch.long)
            cached.append({
                "spatial": spatial, "intermediates": inter,
                "boxes": boxes, "labels": labels,
            })
    finally:
        # Fix: remove hooks even if an image fails, so the backbone is not
        # left permanently mutated by a partial run.
        for hook in hooks:
            hook.remove()
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:  # Fix: makedirs("") raises when cache_path has no directory.
        os.makedirs(cache_dir, exist_ok=True)
    torch.save(cached, cache_path)