| """Verify Stage 1 matches Stage 0 F1 exactly on COCO val 2017.""" |
| import os, sys, json, time |
| import torch |
| import numpy as np |
| from PIL import Image |
| from pycocotools.coco import COCO |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model import Stage1PersonClassifier |
| from transformers import AutoModel |
|
|
_HERE = os.path.dirname(os.path.abspath(__file__))  # directory holding this script

# Dataset root; the script reads annotations/instances_val2017.json and val2017/ under it.
COCO_ROOT = '/home/zootest/datasets/coco'
# Classifier spec (pos_dims / neg_dims / threshold) from the sibling stage_0 directory.
CLASSIFIER_JSON = os.path.join(_HERE, '..', 'stage_0', 'classifier.json')
# Side length, in pixels, of the square each image is resized to before scoring.
RES = 768
|
|
|
|
def main():
    """Score COCO val2017 with the Stage 1 classifier and check F1 parity with Stage 0.

    Builds ``Stage1PersonClassifier`` from the Argus backbone and the spec in
    ``CLASSIFIER_JSON`` (``pos_dims``, ``neg_dims``, ``threshold``). It then
    scores every val2017 image resized to ``RES`` x ``RES``, thresholds the
    scores at ``m.threshold`` and computes precision / recall / F1 against
    "has a non-crowd category_id == 1 annotation" labels. Finally it writes
    the result, plus a parity flag versus the Stage 0 baseline, to
    ``eval.json`` next to this script.
    """
    # Stage 0 reference F1 (known to 4 decimals) and the tolerance used to
    # declare a match.
    # NOTE(review): 1e-3 is ~20x looser than the baseline's 4-decimal
    # rounding; tighten it if bit-exact parity is really required.
    stage0_f1 = 0.8886
    parity_tol = 1e-3

    print('[init] loading Argus + classifier', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True)
    # Context manager so the JSON file handle is closed deterministically.
    with open(CLASSIFIER_JSON) as fh:
        c = json.load(fh)
    m = Stage1PersonClassifier(
        argus, c['pos_dims'], c['neg_dims'], c['threshold']
    ).cuda().eval()

    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    img_ids = sorted(coco.getImgIds())
    # Image-level ground truth: positive iff at least one non-crowd
    # annotation has category_id == 1.
    labels = [
        any(a['category_id'] == 1
            for a in coco.loadAnns(coco.getAnnIds(imgIds=i, iscrowd=False)))
        for i in img_ids
    ]

    # Per-channel normalisation constants shaped (1, 3, 1, 1) to broadcast
    # over the (1, 3, RES, RES) input batch.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    scores = torch.zeros(len(img_ids), device='cuda')
    t0 = time.time()
    # Evaluation only: with autograd on, the in-place write into `scores`
    # would keep every forward graph alive (if the model's parameters
    # require grad), so memory would grow over the whole loop.
    with torch.no_grad():
        for i, img_id in enumerate(img_ids):
            info = coco.loadImgs(img_id)[0]
            path = f"{COCO_ROOT}/val2017/{info['file_name']}"
            with Image.open(path) as im:
                img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
            # .copy() hands torch.from_numpy a writable array.
            arr = np.asarray(img, dtype=np.uint8).copy()
            # HWC uint8 -> 1xCxHxW float in [0, 1] on the GPU, then normalise.
            x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
            x = (x - mean) / std
            s, _ = m(x)
            scores[i] = s[0]
            if (i + 1) % 500 == 0:
                print(f' {i+1}/{len(img_ids)} {(i+1)/(time.time()-t0):.1f} img/s', flush=True)

    y = torch.tensor(labels, dtype=torch.bool, device='cuda')
    pred = scores > m.threshold
    tp = (pred & y).sum().float()
    fp = (pred & ~y).sum().float()
    fn = (~pred & y).sum().float()
    # clamp() avoids 0/0 when there are no predicted or no true positives.
    prec = tp / (tp + fp).clamp(min=1)
    rec = tp / (tp + fn).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)
    # Pull the scalars off the GPU once; everything below uses Python floats.
    f1_v, prec_v, rec_v = f1.item(), prec.item(), rec.item()
    print(f'\n[verify] F1={f1_v:.4f} P={prec_v:.4f} R={rec_v:.4f}', flush=True)
    print(f'[parity] Stage 0 baseline = {stage0_f1}; Stage 1 should be identical', flush=True)

    out = {
        'stage': 1,
        'F1': f1_v,
        'precision': prec_v,
        'recall': rec_v,
        'threshold': float(m.threshold.item()),
        'stage_0_baseline_F1': stage0_f1,
        'match': bool(abs(f1_v - stage0_f1) < parity_tol),
    }
    out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eval.json')
    with open(out_path, 'w') as f:
        json.dump(out, f, indent=2)
    print(f'[done] wrote eval.json match={out["match"]}', flush=True)
|
|
|
|
# Script entry point: run the Stage 1 vs Stage 0 parity check only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
|
|