| """Build the per-image 100-D teacher target tensor from the existing ViT-B |
| COCO train feature cache. One-time, ~5 min. |
| |
| For each image in /home/zootest/datasets/coco/feature_cache_768 (the 419 GB |
| train cache of EUPE-ViT-B at 768 px), apply LayerNorm across the 768 channel |
| axis, max-pool across 2304 patches, and select the 100 classifier-relevant |
| dims (pos + neg). Save as a flat (N, 100) float16 tensor + index list. |
| |
| Output: /home/zootest/datasets/coco/stage4_teacher_targets/targets.pt |
| containing {'targets': (N, 100) float16, 'img_ids': list, 'dims': (100,) long} |
| """ |
| import os, glob, json, time |
| import torch |
| import torch.nn.functional as F |
|
|
# Root of the local COCO dataset tree; CACHE and OUT_DIR are derived from it.
COCO_ROOT = '/home/zootest/datasets/coco'
# Directory that main() globs for ``shard_*.pt`` teacher-feature shards.
CACHE = f'{COCO_ROOT}/feature_cache_768'
# JSON file whose 'pos_dims' and 'neg_dims' lists select the target channels.
CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Directory that receives the resulting ``targets.pt``.
OUT_DIR = f'{COCO_ROOT}/stage4_teacher_targets'
# Channel width of each cached feature map (LayerNorm / reshape size).
D = 768
|
|
|
|
def _coerce_img_id(raw):
    """Return ``raw`` as a Python int, or ``None`` if it cannot be converted.

    ``int()`` is tried directly so that ids stored as numpy integers,
    single-element tensors or numeric strings are kept, not only values
    that are instances of the builtin ``int``/``float``.
    """
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError, RuntimeError):
        return None


def _pool_entry(spatial, dims, feat_dim):
    """LayerNorm over channels, max-pool over positions, select ``dims``.

    Args:
        spatial: channel-first 3-D feature map whose axis 0 has size
            ``feat_dim``.
        dims: long tensor of channel indices to keep.
        feat_dim: expected channel width.

    Returns:
        A 1-D float16 tensor with ``len(dims)`` elements.

    Raises:
        ValueError: if ``spatial`` is not 3-D with ``feat_dim`` channels on
            axis 0.
    """
    feats = spatial.float()
    # reshape(-1, feat_dim) would silently succeed on e.g. a channel-last map
    # and mix positions into channels, so the layout is checked explicitly.
    if feats.dim() != 3 or feats.shape[0] != feat_dim:
        raise ValueError(
            f'expected spatial of shape ({feat_dim}, H, W), '
            f'got {tuple(feats.shape)}')
    tokens = feats.permute(1, 2, 0).reshape(-1, feat_dim)
    normed = F.layer_norm(tokens, [feat_dim])
    pooled = normed.max(dim=0).values
    return pooled[dims].half()


def main(cache_dir=None, classifier_path=None, out_dir=None, feat_dim=None):
    """Build the per-image teacher target tensor and save it as ``targets.pt``.

    Every ``shard_*.pt`` under ``cache_dir`` is loaded in sorted order; each
    entry's ``'spatial'`` feature map is layer-normalised over channels,
    max-pooled over positions and reduced to the classifier's
    ``pos_dims + neg_dims`` channels. The stacked float16 targets, the image
    ids and the selected dims are written to ``{out_dir}/targets.pt``.

    Args:
        cache_dir: directory holding the shards (defaults to ``CACHE``).
        classifier_path: classifier JSON path (defaults to ``CLASSIFIER``).
        out_dir: output directory (defaults to ``OUT_DIR``).
        feat_dim: channel width of the cached features (defaults to ``D``).

    Raises:
        RuntimeError: if no shard files are found or the shards hold no
            entries.
        ValueError: if an entry's feature map does not have ``feat_dim``
            channels on axis 0.
    """
    # Module-level constants are only read when an argument is omitted, so a
    # plain ``main()`` call keeps using the configured paths.
    cache_dir = CACHE if cache_dir is None else cache_dir
    classifier_path = CLASSIFIER if classifier_path is None else classifier_path
    out_dir = OUT_DIR if out_dir is None else out_dir
    feat_dim = D if feat_dim is None else feat_dim

    os.makedirs(out_dir, exist_ok=True)
    with open(classifier_path) as f:
        spec = json.load(f)
    dims = torch.tensor(spec['pos_dims'] + spec['neg_dims'], dtype=torch.long)
    print(f'[init] using {len(dims)} target dims', flush=True)

    shards = sorted(glob.glob(f'{cache_dir}/shard_*.pt'))
    print(f'[init] {len(shards)} teacher shards to process', flush=True)
    if not shards:
        # Fail with a clear message instead of an opaque torch.stack([]) error.
        raise RuntimeError(f'no shard_*.pt files found under {cache_dir}')

    all_targets = []
    all_img_ids = []
    n_fallback = 0
    t0 = time.time()
    for si, spath in enumerate(shards):
        # NOTE: weights_only=False unpickles arbitrary objects; only point
        # this at caches you produced yourself.
        shard = torch.load(spath, map_location='cpu', weights_only=False)
        for entry in shard:
            all_targets.append(_pool_entry(entry['spatial'], dims, feat_dim))

            img_id = _coerce_img_id(entry.get('img_id'))
            if img_id is None:
                # No usable id stored: fall back to the running index.
                img_id = len(all_img_ids)
                n_fallback += 1
            all_img_ids.append(img_id)
        # Release this shard before the next torch.load so that two shards are
        # never resident at once.
        del shard
        elapsed = time.time() - t0
        rate = (si + 1) / max(elapsed, 1e-6)
        eta = (len(shards) - si - 1) / max(rate, 1e-6)
        print(f' shard {si+1}/{len(shards)} cumulative n={len(all_targets)} '
              f'{elapsed:.0f}s ETA {eta:.1f}s', flush=True)

    if not all_targets:
        raise RuntimeError(f'shards under {cache_dir} contained no entries')
    if n_fallback:
        print(f'[warn] {n_fallback} entries had no integer-convertible img_id; '
              f'used running index instead', flush=True)

    targets = torch.stack(all_targets)
    torch.save({
        'targets': targets,
        'img_ids': all_img_ids,
        'dims': dims,
    }, f'{out_dir}/targets.pt')
    print(f'[done] {targets.shape[0]} targets shape {tuple(targets.shape)} '
          f'-> {out_dir}/targets.pt', flush=True)
|
|
|
|
# Run the build only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|