| """Feature caching and dataset loading utilities.""" |
|
|
| import json |
| import os |
| import random |
| from typing import List, Tuple |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision.transforms import v2 |
|
|
RESOLUTION = 640  # side length (px) of the square letterboxed model input
HOOK_BLOCKS = [2, 5, 8, 11]  # backbone block indices whose outputs are captured via forward hooks
N_PREFIX = 5  # NOTE(review): unused in this chunk -- presumably a prefix/register token count; confirm against callers


# Maps contiguous class index (0..79) -> original COCO category id.
# COCO category ids are sparse (several ids in 1..90 are unassigned), so
# classifier outputs use the contiguous index instead.
COCO_CONTIG_TO_CAT = [
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
    27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50,
    51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75,
    76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
]
# Inverse mapping: original COCO category id -> contiguous class index.
COCO_CAT_TO_CONTIG = {cat: i for i, cat in enumerate(COCO_CONTIG_TO_CAT)}
|
|
|
|
def letterbox(image: Image.Image, res: int):
    """Fit *image* into a black ``res`` x ``res`` square, keeping aspect ratio.

    The image is scaled so its longer side equals ``res`` and pasted at the
    top-left corner of a black canvas (padding goes right/bottom).

    Returns:
        (canvas, scale): the padded square image and the scale factor applied
        to the original — useful for rescaling box coordinates.
    """
    orig_w, orig_h = image.size
    factor = res / max(orig_w, orig_h)
    target = (int(round(orig_w * factor)), int(round(orig_h * factor)))
    shrunk = image.resize(target, Image.BILINEAR)
    board = Image.new("RGB", (res, res), (0, 0, 0))
    board.paste(shrunk, (0, 0))
    return board, factor
|
|
|
|
def load_coco_subset(split, n, coco_root, min_anns=0):
    """Load up to *n* COCO images together with their usable box annotations.

    Crowd annotations and categories outside ``COCO_CAT_TO_CONTIG`` are
    dropped; boxes narrower/shorter than 1px are skipped. For the "val"
    split the most densely annotated images are preferred; otherwise a
    deterministic (seed 42) random subset is taken.

    Args:
        split: dataset split name, e.g. "train" or "val".
        n: maximum number of images to return.
        coco_root: directory containing ``{split}2017/`` and ``annotations/``.
        min_anns: keep only images with at least this many usable annotations.

    Returns:
        List of dicts with keys "path", "boxes" (xyxy, pixel coords),
        "labels" (contiguous class ids), "width", "height".
    """
    ann_path = os.path.join(coco_root, "annotations", f"instances_{split}2017.json")
    image_dir = os.path.join(coco_root, f"{split}2017")
    with open(ann_path) as fh:
        data = json.load(fh)

    # Group usable annotations by image id.
    anns_by_image = {}
    for ann in data["annotations"]:
        if ann["iscrowd"] or ann["category_id"] not in COCO_CAT_TO_CONTIG:
            continue
        anns_by_image.setdefault(ann["image_id"], []).append(ann)

    infos = {rec["id"]: rec for rec in data["images"]}
    pool = [pair for pair in anns_by_image.items() if len(pair[1]) >= min_anns]

    # val: annotation-rich images first; other splits: seeded shuffle.
    if split == "val":
        pool.sort(key=lambda pair: len(pair[1]), reverse=True)
    else:
        random.seed(42)
        random.shuffle(pool)

    out = []
    for image_id, anns in pool[:n]:
        rec = infos[image_id]
        xyxy, classes = [], []
        for ann in anns:
            x0, y0, bw, bh = ann["bbox"]
            if bw >= 1 and bh >= 1:
                xyxy.append([x0, y0, x0 + bw, y0 + bh])
                classes.append(COCO_CAT_TO_CONTIG[ann["category_id"]])
        if not xyxy:
            continue
        out.append({
            "path": os.path.join(image_dir, rec["file_name"]),
            "boxes": xyxy,
            "labels": classes,
            "width": rec["width"],
            "height": rec["height"],
        })
    return out
|
|
|
|
def cache_features(backbone, items, cache_path, resolution=RESOLUTION, device="cuda"):
    """Run *backbone* over every image in *items* and cache features to disk.

    For each image this stores the final normalized patch tokens (reshaped to
    a ``D x grid x grid`` map, fp16 on CPU), the raw outputs of the blocks in
    ``HOOK_BLOCKS``, and the letterbox-scaled boxes/labels. No-op when
    ``cache_path`` already exists.

    Args:
        backbone: ViT-style model exposing ``blocks`` (indexable) and
            ``forward_features`` whose output dict has "x_norm_patchtokens".
        items: dicts as produced by ``load_coco_subset``.
        cache_path: destination file for ``torch.save``.
        resolution: letterbox target size in pixels.
        device: inference device string (default "cuda", matching the
            previous hard-coded behavior).
    """
    if os.path.isfile(cache_path):
        return  # already cached

    normalize = v2.Compose([
        v2.ToImage(), v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    hooks, intermediates = [], {}

    def _make_hook(block_idx):
        # Factory binds block_idx by value (avoids late-binding closure bug).
        def fn(mod, inp, out):
            intermediates[block_idx] = (out[0] if isinstance(out, list) else out).detach()
        return fn

    for idx in HOOK_BLOCKS:
        hooks.append(backbone.blocks[idx].register_forward_hook(_make_hook(idx)))

    cached = []
    try:
        for item in items:
            # Context manager closes the image file handle promptly; a bare
            # Image.open would leak one descriptor per image until GC.
            with Image.open(item["path"]) as im:
                img = im.convert("RGB")
            canvas, scale = letterbox(img, resolution)
            x = normalize(canvas).unsqueeze(0).to(device)

            intermediates.clear()
            with torch.no_grad():
                with torch.autocast(device, dtype=torch.bfloat16):
                    out = backbone.forward_features(x)
            patches = out["x_norm_patchtokens"].float()
            B, N, D = patches.shape
            grid = int(N ** 0.5)  # assumes a square patch grid
            spatial = patches[0].permute(1, 0).reshape(D, grid, grid).half().cpu()
            inter = [intermediates[idx][0].half().cpu() for idx in HOOK_BLOCKS]

            boxes = torch.tensor(item["boxes"], dtype=torch.float32) * scale
            labels = torch.tensor(item["labels"], dtype=torch.long)
            cached.append({
                "spatial": spatial, "intermediates": inter,
                "boxes": boxes, "labels": labels,
            })
    finally:
        # Always detach hooks, even if a forward pass raises mid-loop.
        for hook in hooks:
            hook.remove()

    # dirname is "" for a bare filename; makedirs("") would raise.
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    torch.save(cached, cache_path)
|
|