"""Feature caching and dataset loading utilities.""" import json import os import random from typing import List, Tuple import numpy as np import torch from PIL import Image from torchvision.transforms import v2 RESOLUTION = 640 HOOK_BLOCKS = [2, 5, 8, 11] N_PREFIX = 5 COCO_CONTIG_TO_CAT = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, ] COCO_CAT_TO_CONTIG = {cat: i for i, cat in enumerate(COCO_CONTIG_TO_CAT)} def letterbox(image: Image.Image, res: int): W0, H0 = image.size scale = res / max(H0, W0) new_w, new_h = int(round(W0 * scale)), int(round(H0 * scale)) resized = image.resize((new_w, new_h), Image.BILINEAR) canvas = Image.new("RGB", (res, res), (0, 0, 0)) canvas.paste(resized, (0, 0)) return canvas, scale def load_coco_subset(split, n, coco_root, min_anns=0): img_dir = os.path.join(coco_root, f"{split}2017") ann_file = os.path.join(coco_root, "annotations", f"instances_{split}2017.json") with open(ann_file) as f: coco = json.load(f) id_to_anns = {} for a in coco["annotations"]: if a["iscrowd"]: continue cat = a["category_id"] if cat not in COCO_CAT_TO_CONTIG: continue id_to_anns.setdefault(a["image_id"], []).append(a) id_to_info = {img["id"]: img for img in coco["images"]} candidates = [(iid, anns) for iid, anns in id_to_anns.items() if len(anns) >= min_anns] if split == "val": candidates.sort(key=lambda x: -len(x[1])) else: random.seed(42) random.shuffle(candidates) items = [] for iid, anns in candidates[:n]: info = id_to_info[iid] path = os.path.join(img_dir, info["file_name"]) boxes, labels = [], [] for a in anns: x, y, w, h = a["bbox"] if w < 1 or h < 1: continue boxes.append([x, y, x + w, y + h]) labels.append(COCO_CAT_TO_CONTIG[a["category_id"]]) if boxes: items.append({"path": path, "boxes": boxes, "labels": labels, "width": info["width"], "height": info["height"]}) return items def cache_features(backbone, items, cache_path, resolution=RESOLUTION): """Cache backbone features for a list of image items.""" if os.path.isfile(cache_path): return normalize = v2.Compose([ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) hooks, intermediates = [], {} for idx in HOOK_BLOCKS: def _hook(block_idx): def fn(mod, inp, out): intermediates[block_idx] = (out[0] if isinstance(out, list) else out).detach() return fn hooks.append(backbone.blocks[idx].register_forward_hook(_hook(idx))) cached = [] for item in items: img = Image.open(item["path"]).convert("RGB") canvas, scale = letterbox(img, resolution) x = normalize(canvas).unsqueeze(0).cuda() intermediates.clear() with torch.no_grad(): with torch.autocast("cuda", dtype=torch.bfloat16): out = backbone.forward_features(x) patches = out["x_norm_patchtokens"].float() B, N, D = patches.shape h = w = int(N ** 0.5) spatial = patches[0].permute(1, 0).reshape(D, h, w).half().cpu() inter = [intermediates[idx][0].half().cpu() for idx in HOOK_BLOCKS] boxes = torch.tensor(item["boxes"], dtype=torch.float32) * scale labels = torch.tensor(item["labels"], dtype=torch.long) cached.append({ "spatial": spatial, "intermediates": inter, "boxes": boxes, "labels": labels, }) for h in hooks: h.remove() os.makedirs(os.path.dirname(cache_path), exist_ok=True) torch.save(cached, cache_path)