"""Build the per-image 100-D teacher target tensor from the existing ViT-B COCO train feature cache. One-time, ~5 min. For each image in /home/zootest/datasets/coco/feature_cache_768 (the 419 GB train cache of EUPE-ViT-B at 768 px), apply LayerNorm across the 768 channel axis, max-pool across 2304 patches, and select the 100 classifier-relevant dims (pos + neg). Save as a flat (N, 100) float16 tensor + index list. Output: /home/zootest/datasets/coco/stage4_teacher_targets/targets.pt containing {'targets': (N, 100) float16, 'img_ids': list, 'dims': (100,) long} """ import os, glob, json, time import torch import torch.nn.functional as F COCO_ROOT = '/home/zootest/datasets/coco' CACHE = f'{COCO_ROOT}/feature_cache_768' CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json' OUT_DIR = f'{COCO_ROOT}/stage4_teacher_targets' D = 768 def main(): os.makedirs(OUT_DIR, exist_ok=True) with open(CLASSIFIER) as f: c = json.load(f) dims = torch.tensor(c['pos_dims'] + c['neg_dims'], dtype=torch.long) print(f'[init] using {len(dims)} target dims', flush=True) shards = sorted(glob.glob(f'{CACHE}/shard_*.pt')) print(f'[init] {len(shards)} teacher shards to process', flush=True) all_targets = [] all_img_ids = [] t0 = time.time() for si, spath in enumerate(shards): shard = torch.load(spath, map_location='cpu', weights_only=False) for entry in shard: sp = entry['spatial'].float() # (768, 48, 48) ln = F.layer_norm(sp.permute(1, 2, 0).reshape(-1, D), [D]) pooled = ln.max(dim=0).values # (768,) target = pooled[dims].half() # (100,) all_targets.append(target) # Some caches may not have img_id; use an ordinal if missing img_id = entry.get('img_id', len(all_img_ids)) all_img_ids.append(int(img_id) if isinstance(img_id, (int, float)) else len(all_img_ids)) elapsed = time.time() - t0 rate = (si + 1) / max(elapsed, 1e-6) eta = (len(shards) - si - 1) / max(rate, 1e-6) print(f' shard {si+1}/{len(shards)} cumulative n={len(all_targets)} ' f'{elapsed:.0f}s ETA {eta:.1f}s', flush=True) targets = torch.stack(all_targets) torch.save({ 'targets': targets, 'img_ids': all_img_ids, 'dims': dims, }, f'{OUT_DIR}/targets.pt') print(f'[done] {targets.shape[0]} targets shape {tuple(targets.shape)} ' f'-> {OUT_DIR}/targets.pt', flush=True) if __name__ == '__main__': main()