# Author: phanerozoic
# Commit 81d1bef (verified): Initialize repo with Stage 0 baseline
"""Reference inference for the Stage 0 baseline.
Loads Argus (EUPE-ViT-B backbone), reads the classifier config, and scores one
or more images. Prints the raw score and the binary decision.
Usage: python infer.py image1.jpg [image2.jpg ...]
"""
import json, sys, os
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from transformers import AutoModel
def load_classifier(path='classifier.json'):
    """Load the classifier configuration from a JSON file.

    Args:
        path: Path to the JSON config. Relative paths resolve against the
            current working directory.

    Returns:
        The decoded JSON value. The callers in this script treat it as a dict
        and read ``threshold``, ``input_resolution``, ``feature_dim``,
        ``pos_dims`` and ``neg_dims`` from it.
    """
    # JSON text is UTF-8 by specification (RFC 8259); decode it explicitly
    # rather than relying on the platform's locale-dependent default encoding.
    with open(path, encoding='utf-8') as f:
        return json.load(f)
def load_argus(repo_or_path='phanerozoic/argus'):
    """Load the Argus model through ``transformers.AutoModel``.

    Args:
        repo_or_path: Hugging Face Hub repo id or a local checkpoint directory.

    Returns:
        Whatever ``AutoModel.from_pretrained`` returns for that checkpoint.
    """
    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # checkpoint; only point this at repos/paths you trust.
    loader_kwargs = {'trust_remote_code': True}
    model = AutoModel.from_pretrained(repo_or_path, **loader_kwargs)
    return model
def preprocess(image_path, resolution=768, device='cuda'):
    """Load an image file and turn it into a normalized model input tensor.

    The image is converted to RGB and resized (bilinear) to a square
    ``resolution`` x ``resolution``; the aspect ratio is NOT preserved.
    Pixels are scaled to [0, 1] and normalized per channel with
    mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225), the common
    ImageNet statistics.

    Args:
        image_path: Path to the image file.
        resolution: Side length, in pixels, of the square model input.
        device: Torch device the returned tensor is placed on.

    Returns:
        A float tensor of shape (1, 3, resolution, resolution) on ``device``.
    """
    # Use the context manager so the underlying file handle is closed
    # deterministically; Pillow may keep it open after loading (e.g. for
    # multi-frame formats). convert()/resize() return independent images.
    with Image.open(image_path) as src:
        img = src.convert('RGB').resize((resolution, resolution), Image.BILINEAR)
    # copy() yields a writable array; torch.from_numpy warns on read-only buffers.
    arr = np.asarray(img, dtype=np.uint8).copy()
    # HWC uint8 -> NCHW; move as uint8 (smaller transfer), then scale to [0, 1].
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device).float() / 255.0
    # Build the normalization constants directly on the target device.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    return (x - mean) / std
@torch.inference_mode()
def score(model, x, classifier):
    """Compute the raw linear-readout score for one preprocessed image.

    The backbone's ``x_norm_patchtokens`` are layer-normed over the feature
    dimension (no affine parameters) and max-pooled over patches. The score
    is the sum of the pooled values at ``classifier['pos_dims']`` minus the
    sum at ``classifier['neg_dims']``.

    Args:
        model: Object exposing ``model.backbone.forward_features(x)``, which
            returns a dict containing ``'x_norm_patchtokens'``. That tensor
            is expected as (1, num_patches, feature_dim).
        x: Input tensor for a single image, as produced by ``preprocess``.
        classifier: Classifier config; reads ``feature_dim``, ``pos_dims`` and
            ``neg_dims``.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: If the patch tokens are not for exactly one image
            (e.g. a batch larger than 1). Pooling over the wrong axis would
            otherwise silently give a meaningless score.
    """
    import contextlib

    # bfloat16 autocast applies to CUDA only. On other devices, run in full
    # precision; requesting autocast('cuda') for a CPU tensor would only warn
    # and disable itself.
    if x.device.type == 'cuda':
        amp = torch.autocast('cuda', dtype=torch.bfloat16)
    else:
        amp = contextlib.nullcontext()
    with amp:
        out = model.backbone.forward_features(x)

    patches = out['x_norm_patchtokens'].float().squeeze(0)
    if patches.dim() != 2:
        raise ValueError(
            'expected patch tokens for a single image, got shape '
            f'{tuple(patches.shape)} after squeeze(0)'
        )
    D = classifier['feature_dim']
    ln = F.layer_norm(patches, [D])
    pooled = ln.max(dim=0).values  # per-feature max over all patches
    pos = pooled[classifier['pos_dims']].sum()
    neg = pooled[classifier['neg_dims']].sum()
    return float(pos - neg)
def main():
    """Score every image path given on the command line.

    Loads ``classifier.json`` from this script's directory and the Argus
    model (on CUDA when available, otherwise CPU). Prints one line per image
    with its score, the threshold and whether the score exceeds it. With no
    image arguments, prints a usage message to stderr and exits with status 1.
    """
    image_paths = sys.argv[1:]
    if not image_paths:
        # Usage errors go to stderr so stdout carries only results.
        print('usage: python infer.py <image1> [image2 ...]', file=sys.stderr)
        sys.exit(1)
    here = os.path.dirname(os.path.abspath(__file__))
    classifier = load_classifier(os.path.join(here, 'classifier.json'))
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_argus().to(device).eval()
    thr = classifier['threshold']
    resolution = classifier['input_resolution']
    for image_path in image_paths:
        x = preprocess(image_path, resolution, device)
        s = score(model, x, classifier)
        print(f'{image_path} score={s:+.3f} threshold={thr:+.3f} person={s > thr}')
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not on import.
    main()