"""Repo-level inference dispatcher.
Loads the weights of any stage in this repo and returns a callable person
detector with the shape:
score: float (+ = person-scene, − = no person)
present: bool (score > threshold)
Examples:
# Baseline, via the Argus HF repo for the backbone
det = PersonDetector.from_stage('stage_0')
# Tight-FPR variant of the same
det = PersonDetector.from_stage('stage_0_tight_fpr')
# Head-pruned backbone
det = PersonDetector.from_stage('stage_2b')
# Specialist student (no Argus backbone needed)
det = PersonDetector.from_stage('stage_4b')
# Direct-scalar supervision student (same 3.27M as Stage 4)
det = PersonDetector.from_stage('stage_4c')
score, present = det.predict('path/to/image.jpg')
Stage 3 (depth reduction) and Stage 5/5b (circuit-level synthesis) are not
loadable at Python level. Stage 3 is an ablation study; Stages 5/5b are
Verilog.
"""
import json
import sys
from pathlib import Path
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
HERE = Path(__file__).parent
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RES = 768  # square input resolution fed to every stage
D = 768    # embedding width of the ViT-B teacher backbone
def _norm_input(image: Union[str, Path, Image.Image, np.ndarray, torch.Tensor],
                resolution: int = RES) -> torch.Tensor:
    """Convert any supported image type into a normalized (1, 3, R, R) tensor."""
    if isinstance(image, (str, Path)):
        img = Image.open(image).convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    elif isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert('RGB')
    elif isinstance(image, torch.Tensor):
        # Assumes a float tensor in [0, 1], either CHW (3-D) or NCHW (4-D,
        # in which case only the first item of the batch is used).
        arr = image.cpu().numpy() if image.ndim == 3 else image[0].cpu().numpy()
        if arr.shape[0] == 3:  # CHW -> HWC for PIL
            arr = arr.transpose(1, 2, 0)
        img = Image.fromarray((arr * 255).astype('uint8')).convert('RGB')
    else:
        raise TypeError(f'unsupported image type: {type(image)}')
    img = img.resize((resolution, resolution), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.uint8).copy()
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
    # Standard ImageNet channel statistics.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)
    return (x - mean) / std
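# Illustrative calls covering each accepted input type (shapes and paths
# below are hypothetical):
#   _norm_input('frame.jpg')                          # file on disk
#   _norm_input(np.zeros((480, 640, 3), np.uint8))    # HWC uint8 array
#   _norm_input(torch.rand(3, 480, 640))              # CHW float in [0, 1]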
def _load_classifier(path: Path) -> dict:
    """Read a classifier.json carrying 'pos_dims', 'neg_dims', and 'threshold'."""
    with open(path) as f:
        return json.load(f)
class PersonDetector:
def __init__(self, forward_fn, pos_dims, neg_dims, threshold):
self._forward = forward_fn
self._pos = torch.tensor(pos_dims, dtype=torch.long, device=DEVICE)
self._neg = torch.tensor(neg_dims, dtype=torch.long, device=DEVICE)
self._thr = float(threshold)
@torch.inference_mode()
def predict(self, image) -> Tuple[float, bool]:
x = _norm_input(image)
        pooled = self._forward(x)  # 1-D float features: 768-D for backbone stages, 40-D for the compact students
score = (pooled[self._pos].sum() - pooled[self._neg].sum()).item()
return float(score), bool(score > self._thr)
@classmethod
def from_stage(cls, stage: str, argus_repo: str = 'phanerozoic/argus',
repo_local: Union[str, Path, None] = None):
"""Load one of the stages by name.
stage ∈ {
'stage_0', 'stage_0_tight_fpr', 'stage_1',
'stage_2a', # heads masked
'stage_2b', # backbone structurally pruned
'stage_4', # 3.27M student, per-dim MSE
'stage_4b', # 15.67M student, cosine on 768-D
'stage_4c', # 3.27M student, scalar-MSE
}
argus_repo: HF repo for the EUPE-ViT-B backbone. Used by stage_0,
stage_0_tight_fpr, stage_1, stage_2a. Stage 2b bundles its own
pruned backbone. Stages 4, 4b, 4c don't use Argus.
repo_local: local path to this repo (contains stage_*/ directories).
Defaults to the directory containing this file.
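
        Example (the repo_local path below is illustrative):

            det = PersonDetector.from_stage('stage_4c',
                                            repo_local='/data/argus-stages')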
"""
root = Path(repo_local) if repo_local else HERE
        if stage in ('stage_0', 'stage_1'):
            # stage_0 and stage_1 share the same classifier file.
            return cls._build_argus_variant(root / 'stage_0' / 'classifier.json', argus_repo)
        if stage == 'stage_0_tight_fpr':
            return cls._build_argus_variant(root / 'stage_0_tight_fpr' / 'classifier.json', argus_repo)
        if stage in ('stage_2a', 'stage_2'):  # 'stage_2' accepted as an alias
            return cls._build_stage2a(root, argus_repo)
if stage == 'stage_2b':
return cls._build_stage2b(root)
if stage == 'stage_4':
return cls._build_stage4(root, root / 'stage_4' / 'student_final.safetensors',
student_out_dim=40, student_dim=192, student_depth=6, heads=3)
if stage == 'stage_4b':
return cls._build_stage4(root, root / 'stage_4b' / 'student_final.safetensors',
student_out_dim=768, student_dim=384, student_depth=8, heads=6)
if stage == 'stage_4c':
return cls._build_stage4(root, root / 'stage_4c' / 'student_final.safetensors',
student_out_dim=40, student_dim=192, student_depth=6, heads=3)
raise ValueError(f'unknown stage: {stage}')
# --------------- stage-specific builders ---------------
@classmethod
def _build_argus_variant(cls, classifier_json, argus_repo):
from transformers import AutoModel
model = AutoModel.from_pretrained(argus_repo, trust_remote_code=True).to(DEVICE).eval()
c = _load_classifier(classifier_json)
        def fwd(x):
            with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
                out = model.backbone.forward_features(x)
                # LayerNorm the patch tokens, then max-pool over the token
                # axis; every backbone stage uses this same pooling.
                patches = out['x_norm_patchtokens'].float().squeeze(0)
                ln = F.layer_norm(patches, [D])
                return ln.max(dim=0).values
return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
@classmethod
def _build_stage2a(cls, root, argus_repo):
from transformers import AutoModel
model = AutoModel.from_pretrained(argus_repo, trust_remote_code=True).to(DEVICE).eval()
c = _load_classifier(root / 'stage_0' / 'classifier.json')
        # Mask the 10 most prunable heads (from stage_2 head_importance.json)
        # by zeroing their columns in each block's attention output projection.
        with open(root / 'stage_2' / 'head_importance.json') as f:
            imp = json.load(f)
        HEAD_DIM = 64  # ViT-B: 768 dims / 12 heads
        with torch.no_grad():
            for (b, h, _drop) in imp['ranked_most_prunable_first'][:10]:
                model.backbone.blocks[b].attn.proj.weight[:, h*HEAD_DIM:(h+1)*HEAD_DIM] = 0.0
def fwd(x):
with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
out = model.backbone.forward_features(x)
patches = out['x_norm_patchtokens'].float().squeeze(0)
ln = F.layer_norm(patches, [D])
return ln.max(dim=0).values
return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
@classmethod
def _build_stage2b(cls, root):
        # Make the bundled pruned-backbone loader importable.
        sys.path.insert(0, str(root / 'stage_2b'))
        from load_pruned_backbone import load_stage2b_backbone
backbone = load_stage2b_backbone(
str(root / 'stage_2b' / 'pruned_state_dict.safetensors'),
str(root / 'stage_2b' / 'head_config.json'),
).to(DEVICE).eval()
c = _load_classifier(root / 'stage_0' / 'classifier.json')
def fwd(x):
with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
out = backbone.forward_features(x)
patches = out['x_norm_patchtokens'].float().squeeze(0)
ln = F.layer_norm(patches, [D])
return ln.max(dim=0).values
return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
@classmethod
def _build_stage4(cls, root, weights_path, student_out_dim, student_dim, student_depth, heads):
from safetensors.torch import load_file
        class _Block(nn.Module):
            """Pre-norm transformer encoder block."""
            def __init__(self, dim, h, ratio=4.0):
                super().__init__()
                self.norm1 = nn.LayerNorm(dim)
                self.attn = nn.MultiheadAttention(dim, h, batch_first=True)
                self.norm2 = nn.LayerNorm(dim)
                hidden = int(dim * ratio)
                self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
            def forward(self, x):
                y = self.norm1(x)  # normalize once, reuse for q, k, v
                h_, _ = self.attn(y, y, y, need_weights=False)
                x = x + h_
                return x + self.mlp(self.norm2(x))
        class _Student(nn.Module):
            """Compact ViT: patchify -> pos-embed -> blocks -> max-pool -> linear head."""
            def __init__(self, out_dim, dim, depth, h, patch=16, img=RES):
                super().__init__()
                self.patch = nn.Conv2d(3, dim, patch, stride=patch)
                self.pos = nn.Parameter(torch.zeros(1, (img // patch) ** 2, dim))
                self.blocks = nn.ModuleList([_Block(dim, h) for _ in range(depth)])
                self.norm = nn.LayerNorm(dim)
                self.head = nn.Linear(dim, out_dim)
            def forward(self, x):
                t = self.patch(x).flatten(2).transpose(1, 2)  # (B, N, dim) tokens
                t = t + self.pos[:, :t.shape[1]]
                for blk in self.blocks:
                    t = blk(t)
                t = self.norm(t)
                return self.head(t.max(dim=1).values)  # max-pool tokens, then project
student = _Student(student_out_dim, student_dim, student_depth, heads).to(DEVICE).eval()
student.load_state_dict(load_file(str(weights_path)))
# Classifier indexing depends on student output layout:
# - stage_4 / stage_4c: student emits the 40 classifier-relevant dims
# directly (pos at [0:20], neg at [20:40]).
# - stage_4b: student emits a 768-D vector matching teacher layout;
# use Stage 0's pos/neg dims directly.
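        # Concretely, for the 40-dim students the score computed in predict()
        # reduces to out[0:20].sum() - out[20:40].sum().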
if student_out_dim == 40:
pos, neg = list(range(20)), list(range(20, 40))
else:
c = _load_classifier(root / 'stage_0' / 'classifier.json')
pos, neg = c['pos_dims'], c['neg_dims']
# student_final.safetensors is the peak-F1 epoch (ep3 for Stage 4,
# ep10 for Stage 4b/4c). Pull that epoch's threshold, not the last.
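        # training_log.json is assumed to hold per-epoch records shaped like
        #   {"epochs": [{"F1": 0.93, "threshold": 0.41, ...}, ...]}
        # (field values above are illustrative; only 'F1' and 'threshold' are read).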
with open(Path(weights_path).parent / 'training_log.json') as f:
log = json.load(f)
best = max(log['epochs'], key=lambda e: e.get('F1', 0.0))
thr = best.get('threshold', 0.0)
def fwd(x):
with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
out = student(x)
return out.float().squeeze(0)
return cls(fwd, pos, neg, thr)
if __name__ == '__main__':
if len(sys.argv) < 3:
print('usage: python infer.py <stage> <image> [image ...]')
print('stages: stage_0, stage_0_tight_fpr, stage_1, stage_2a, stage_2b, '
'stage_4, stage_4b, stage_4c')
sys.exit(1)
stage = sys.argv[1]
det = PersonDetector.from_stage(stage)
for path in sys.argv[2:]:
score, present = det.predict(path)
print(f'{path} score={score:+.3f} present={present}')
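# Example session (paths and scores below are illustrative):
#   $ python infer.py stage_4c samples/person.jpg samples/empty_street.jpg
#   samples/person.jpg score=+3.412 present=True
#   samples/empty_street.jpg score=-1.087 present=False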