"""Verify Stage 1 matches Stage 0 F1 exactly on COCO val 2017.""" import os, sys, json, time import torch import numpy as np from PIL import Image from pycocotools.coco import COCO sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from model import Stage1PersonClassifier from transformers import AutoModel COCO_ROOT = '/home/zootest/datasets/coco' CLASSIFIER_JSON = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'stage_0', 'classifier.json') RES = 768 def main(): print('[init] loading Argus + classifier', flush=True) argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True) c = json.load(open(CLASSIFIER_JSON)) m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold']).cuda().eval() coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json') img_ids = sorted(coco.getImgIds()) labels = [any(a['category_id'] == 1 for a in coco.loadAnns(coco.getAnnIds(imgIds=i, iscrowd=False))) for i in img_ids] MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() scores = torch.zeros(len(img_ids), device='cuda') t0 = time.time() for i, img_id in enumerate(img_ids): info = coco.loadImgs(img_id)[0] path = f"{COCO_ROOT}/val2017/{info['file_name']}" img = Image.open(path).convert('RGB').resize((RES, RES), Image.BILINEAR) arr = np.asarray(img, dtype=np.uint8).copy() x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0 x = (x - MEAN) / STD s, _ = m(x) scores[i] = s[0] if (i + 1) % 500 == 0: print(f' {i+1}/{len(img_ids)} {(i+1)/(time.time()-t0):.1f} img/s', flush=True) y = torch.tensor(labels, dtype=torch.bool, device='cuda') pred = scores > m.threshold tp = (pred & y).sum().float() fp = (pred & ~y).sum().float() fn = (~pred & y).sum().float() prec = tp / (tp + fp).clamp(min=1) rec = tp / (tp + fn).clamp(min=1) f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9) print(f'\n[verify] F1={f1:.4f} P={prec:.4f} R={rec:.4f}', flush=True) print(f'[parity] Stage 0 baseline = 0.8886; Stage 1 should be identical', flush=True) out = { 'stage': 1, 'F1': float(f1), 'precision': float(prec), 'recall': float(rec), 'threshold': float(m.threshold.item()), 'stage_0_baseline_F1': 0.8886, 'match': bool(abs(float(f1) - 0.8886) < 1e-3), } with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eval.json'), 'w') as f: json.dump(out, f, indent=2) print(f'[done] wrote eval.json match={out["match"]}', flush=True) if __name__ == '__main__': main()