# 1-parameter-classifier / stage_4 / prepare_targets.py
# Author: phanerozoic
# Commit message: Stage 4: specialist student (3.27M params, F1 0.710 vs 0.894 baseline)
# Commit 864ba61 (verified)
"""Build the per-image 100-D teacher target tensor from the existing ViT-B
COCO train feature cache. One-time job, ~5 min.

For each image in /home/zootest/datasets/coco/feature_cache_768 (the 419 GB
train cache of EUPE-ViT-B at 768 px), apply LayerNorm (no affine parameters)
across the 768-channel axis, max-pool across the 2304 (48x48) patches, and
select the 100 classifier-relevant dims (pos + neg). Save the result as a flat
(N, 100) float16 tensor plus an image-id list.

Output: /home/zootest/datasets/coco/stage4_teacher_targets/targets.pt
containing {'targets': (N, 100) float16, 'img_ids': list, 'dims': (100,) long}.
Entries without a usable numeric img_id are recorded by their ordinal position.
"""
import os, glob, json, time
import torch
import torch.nn.functional as F
# Root of the local COCO checkout; the feature cache and the output live here.
COCO_ROOT = '/home/zootest/datasets/coco'
# Directory of teacher feature shards (shard_*.pt) consumed by main().
CACHE = COCO_ROOT + '/feature_cache_768'
# Stage-0 classifier spec; main() reads its 'pos_dims' and 'neg_dims' lists.
CLASSIFIER = '/mnt/d/_tmp/1pc_repo/stage_0/classifier.json'
# Destination directory for targets.pt.
OUT_DIR = COCO_ROOT + '/stage4_teacher_targets'
# Teacher channel width; LayerNorm in main() is taken over this many channels.
D = 768
def _resolve_img_id(raw, ordinal):
    """Coerce a cached ``img_id`` to a Python int, falling back to ``ordinal``.

    Args:
        raw: The entry's ``img_id`` value, or ``None`` when the key is absent.
        ordinal: Position of the entry in the output list; used when ``raw``
            is missing or cannot be converted to an int.

    Returns:
        ``(img_id, used_fallback)`` where ``used_fallback`` is True when
        ``ordinal`` was substituted.
    """
    if raw is None:
        return ordinal, True
    # int() rather than isinstance(raw, (int, float)): NumPy integer scalars
    # and 0-d torch tensors do not subclass int, and rejecting them would
    # silently replace real COCO ids with ordinals.
    try:
        return int(raw), False
    except (TypeError, ValueError, OverflowError):
        return ordinal, True


def _entry_target(spatial, dims):
    """Return the float16 teacher target for one cached feature map.

    Args:
        spatial: Feature map expected channels-first as (D, H, W).
        dims: 1-D long tensor of channel indices to keep.

    Returns:
        Tensor of shape ``(len(dims),)`` with dtype float16.

    Raises:
        ValueError: if ``spatial`` is not a 3-D (D, H, W) map.
    """
    sp = spatial.float()
    # A channels-last (H, W, D) map would still reshape to (-1, D) without
    # error and produce meaningless targets, so check the layout explicitly.
    if sp.dim() != 3 or sp.shape[0] != D:
        raise ValueError(
            f'expected spatial of shape ({D}, H, W), got {tuple(sp.shape)}')
    # (D, H, W) -> (H*W, D): one row per patch, LayerNorm over channels
    # (no affine weight/bias).
    ln = F.layer_norm(sp.permute(1, 2, 0).reshape(-1, D), [D])
    pooled = ln.max(dim=0).values  # (D,) max over patches
    return pooled[dims].half()


def main():
    """Build ``targets.pt`` from every teacher shard under ``CACHE``.

    Reads ``pos_dims`` and ``neg_dims`` from the ``CLASSIFIER`` JSON, converts
    every cached entry with ``_entry_target``, and saves
    ``{'targets': (N, K) float16, 'img_ids': list[int], 'dims': (K,) long}``
    to ``OUT_DIR/targets.pt`` with K = len(pos_dims) + len(neg_dims).

    Raises:
        ValueError: if a classifier dim lies outside ``[0, D)`` or a cached
            feature map is not laid out as (D, H, W).
        RuntimeError: if no cached entries are found under ``CACHE``.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    with open(CLASSIFIER) as f:
        clf = json.load(f)
    dims = torch.tensor(clf['pos_dims'] + clf['neg_dims'], dtype=torch.long)
    # Negative indices would silently wrap around in pooled[dims]; reject bad
    # dims before reading hundreds of GB of shards.
    if dims.numel() and (int(dims.min()) < 0 or int(dims.max()) >= D):
        raise ValueError(f'classifier dims must lie in [0, {D}), got '
                         f'min={int(dims.min())} max={int(dims.max())}')
    print(f'[init] using {len(dims)} target dims', flush=True)
    shards = sorted(glob.glob(f'{CACHE}/shard_*.pt'))
    print(f'[init] {len(shards)} teacher shards to process', flush=True)

    all_targets = []
    all_img_ids = []
    n_fallback = 0
    t0 = time.time()
    for si, spath in enumerate(shards):
        # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
        # because the cache is produced locally. Never point CACHE at
        # untrusted files.
        shard = torch.load(spath, map_location='cpu', weights_only=False)
        for entry in shard:
            all_targets.append(_entry_target(entry['spatial'], dims))
            # Some caches may not have img_id; use the ordinal if missing.
            img_id, fell_back = _resolve_img_id(entry.get('img_id'),
                                                len(all_img_ids))
            all_img_ids.append(img_id)
            n_fallback += fell_back
        # Release this shard so it can be freed before the next one loads,
        # instead of holding two shards in memory at once.
        del shard
        elapsed = time.time() - t0
        rate = (si + 1) / max(elapsed, 1e-6)
        eta = (len(shards) - si - 1) / max(rate, 1e-6)
        print(f' shard {si+1}/{len(shards)} cumulative n={len(all_targets)} '
              f'{elapsed:.0f}s ETA {eta:.1f}s', flush=True)

    if not all_targets:
        raise RuntimeError(
            f'no cached entries found under {CACHE} ({len(shards)} shard files)')
    if n_fallback:
        print(f'[warn] {n_fallback} entries had no usable img_id; stored their '
              f'ordinal position instead (may collide with real COCO ids)',
              flush=True)

    targets = torch.stack(all_targets)
    out_path = f'{OUT_DIR}/targets.pt'
    tmp_path = out_path + '.tmp'
    # Save to a temp file and rename so an interrupted save never leaves a
    # truncated targets.pt behind.
    torch.save({
        'targets': targets,
        'img_ids': all_img_ids,
        'dims': dims,
    }, tmp_path)
    os.replace(tmp_path, out_path)
    print(f'[done] {targets.shape[0]} targets shape {tuple(targets.shape)} '
          f'-> {OUT_DIR}/targets.pt', flush=True)
if __name__ == '__main__':
    # Script entry point: run the one-time teacher-target extraction.
    main()