# Repository page header (not part of the program) -- author: phanerozoic
# Stage 1: output-channel pruning (Path A, exact parity with Stage 0)
# Commit d9418d2 (verified)
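"""Verify that Stage 1 reproduces the Stage 0 F1 score on COCO val 2017.

The Stage 1 F1 is compared against the recorded Stage 0 baseline (0.8886)
with an absolute tolerance of 1e-3; the result is written to eval.json.
"""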
import os, sys, json, time
import torch
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import Stage1PersonClassifier
from transformers import AutoModel
# Side length, in pixels, that every image is resized to before inference.
RES = 768

# Root of the local COCO dataset checkout (annotations/ and val2017/ live here).
COCO_ROOT = '/home/zootest/datasets/coco'

# Classifier parameters produced by Stage 0, resolved relative to this script.
CLASSIFIER_JSON = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    '..',
    'stage_0',
    'classifier.json',
)
def main():
    """Score COCO val 2017 with the Stage 1 classifier and report F1.

    Steps:
      1. Load the Argus backbone and the Stage 0 classifier parameters
         (``pos_dims``, ``neg_dims``, ``threshold``) from ``CLASSIFIER_JSON``.
      2. Build image-level labels: True when an image has at least one
         non-crowd annotation with ``category_id == 1``.
      3. Run every image through the model at ``RES`` x ``RES`` and collect
         its score.
      4. Threshold the scores at ``m.threshold``, compute precision / recall /
         F1, print them, and write the summary to ``eval.json`` next to this
         script.

    Requires a CUDA device: the model and all tensors are moved to the GPU.
    """
    # Reference F1 recorded for Stage 0; Stage 1 is expected to reproduce it.
    stage_0_f1 = 0.8886
    here = os.path.dirname(os.path.abspath(__file__))

    print('[init] loading Argus + classifier', flush=True)
    argus = AutoModel.from_pretrained('/mnt/d/Argus', trust_remote_code=True)
    # Use a context manager so the JSON file handle is closed deterministically.
    with open(CLASSIFIER_JSON) as f:
        c = json.load(f)
    m = Stage1PersonClassifier(argus, c['pos_dims'], c['neg_dims'], c['threshold']).cuda().eval()

    coco = COCO(f'{COCO_ROOT}/annotations/instances_val2017.json')
    img_ids = sorted(coco.getImgIds())
    labels = [
        any(a['category_id'] == 1
            for a in coco.loadAnns(coco.getAnnIds(imgIds=i, iscrowd=False)))
        for i in img_ids
    ]

    # Per-channel normalisation constants, shaped to broadcast over (1, 3, H, W).
    MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
    STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

    scores = torch.zeros(len(img_ids), device='cuda')
    t0 = time.time()
    # Inference only. Without no_grad, if the model's parameters require grad,
    # each score stored into `scores` would keep its autograd graph alive and
    # memory would grow for the whole run.
    with torch.no_grad():
        for i, img_id in enumerate(img_ids):
            info = coco.loadImgs(img_id)[0]
            path = f"{COCO_ROOT}/val2017/{info['file_name']}"
            # Close the underlying file as soon as the pixels are converted.
            with Image.open(path) as im:
                img = im.convert('RGB').resize((RES, RES), Image.BILINEAR)
            # Copy so torch gets a writable array.
            arr = np.asarray(img, dtype=np.uint8).copy()
            # HWC uint8 -> 1xCxHxW float in [0, 1], then normalise.
            x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).cuda().float() / 255.0
            x = (x - MEAN) / STD
            s, _ = m(x)
            scores[i] = s[0]
            if (i + 1) % 500 == 0:
                print(f'  {i+1}/{len(img_ids)}  {(i+1)/(time.time()-t0):.1f} img/s', flush=True)

    y = torch.tensor(labels, dtype=torch.bool, device='cuda')
    pred = scores > m.threshold
    tp = (pred & y).sum().float()
    fp = (pred & ~y).sum().float()
    fn = (~pred & y).sum().float()
    # Clamp denominators so an empty prediction/label set yields 0, not NaN.
    prec = tp / (tp + fp).clamp(min=1)
    rec = tp / (tp + fn).clamp(min=1)
    f1 = 2 * prec * rec / (prec + rec).clamp(min=1e-9)

    print(f'\n[verify] F1={f1:.4f}  P={prec:.4f}  R={rec:.4f}', flush=True)
    print(f'[parity] Stage 0 baseline = {stage_0_f1:.4f}; Stage 1 should be identical', flush=True)

    out = {
        'stage': 1,
        'F1': float(f1),
        'precision': float(prec),
        'recall': float(rec),
        'threshold': float(m.threshold.item()),
        'stage_0_baseline_F1': stage_0_f1,
        # The baseline is only known to 4 decimals, hence the 1e-3 tolerance.
        'match': bool(abs(float(f1) - stage_0_f1) < 1e-3),
    }
    with open(os.path.join(here, 'eval.json'), 'w') as f:
        json.dump(out, f, indent=2)
    print(f'[done] wrote eval.json  match={out["match"]}', flush=True)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()