"""Repo-level inference dispatcher.

Loads the weights of any stage in this repo and returns a person detector
whose ``predict`` method yields:

    score:   float  (+ = person scene, - = no person)
    present: bool   (score > threshold)

Examples::

    # Baseline, via the Argus HF repo for the backbone
    det = PersonDetector.from_stage('stage_0')

    # Tight-FPR variant of the same
    det = PersonDetector.from_stage('stage_0_tight_fpr')

    # Head-pruned backbone
    det = PersonDetector.from_stage('stage_2b')

    # Specialist student (no Argus backbone needed)
    det = PersonDetector.from_stage('stage_4b')

    # Direct-scalar supervision student (same 3.27M as Stage 4)
    det = PersonDetector.from_stage('stage_4c')

    score, present = det.predict('path/to/image.jpg')

Stage 3 (depth reduction) and Stages 5/5b (circuit-level synthesis) are not
loadable from Python: Stage 3 is an ablation study and Stages 5/5b are
Verilog.

Command line::

    python infer.py <stage> <image> [image ...]
"""
import io
import json
import os
import sys
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

HERE = Path(__file__).parent
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RES = 768  # side length images are resized to before inference
D = 768    # width of the backbone patch tokens normalised in _pool_patch_tokens

# Per-channel normalisation constants (the standard ImageNet mean/std).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Student architectures, keyed by stage name; passed as kwargs to _build_stage4.
_STUDENT_CONFIGS = {
    'stage_4': dict(student_out_dim=40, student_dim=192, student_depth=6, heads=3),
    'stage_4b': dict(student_out_dim=768, student_dim=384, student_depth=8, heads=6),
    'stage_4c': dict(student_out_dim=40, student_dim=192, student_depth=6, heads=3),
}


def _autocast():
    """Return the bfloat16 autocast context wrapped around every model forward."""
    return torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16)


def _tensor_to_pil(image: torch.Tensor) -> Image.Image:
    """Convert an image tensor to an RGB PIL image.

    Accepts (H, W), (3, H, W), (H, W, C), or a 4-D batch of which only the
    first image is used. uint8 tensors are taken as 0..255 pixel values. Every
    other dtype is treated as [0, 1] intensities. Those values are clipped to
    [0, 1] before scaling by 255, so out-of-range values saturate instead of
    wrapping modulo 256.
    """
    t = image.detach().cpu()  # detach: .numpy() refuses tensors that require grad
    if t.ndim == 4:
        t = t[0]
    if t.dtype == torch.bfloat16:
        t = t.float()  # NumPy has no bfloat16
    arr = t.numpy()
    if arr.ndim == 3 and arr.shape[0] == 3:
        arr = arr.transpose(1, 2, 0)  # CHW -> HWC
    if arr.dtype != np.uint8:
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float32)
        arr = (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(arr).convert('RGB')


def _norm_input(image: Union[str, Path, Image.Image, np.ndarray, torch.Tensor],
                resolution: int = RES) -> torch.Tensor:
    """Load/convert ``image`` and return a normalised model input.

    The image is converted to RGB, bilinearly resized to
    ``resolution`` x ``resolution``, scaled to [0, 1] and standardised with
    the per-channel mean/std above.

    Returns:
        float32 tensor of shape (1, 3, resolution, resolution) on DEVICE.

    Raises:
        TypeError: if ``image`` is not a path, PIL image, ndarray or tensor.
    """
    if isinstance(image, (str, Path)):
        # Context manager closes the file handle once the pixels are decoded.
        with Image.open(image) as src:
            img = src.convert('RGB')
    elif isinstance(image, Image.Image):
        img = image.convert('RGB')
    elif isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert('RGB')
    elif isinstance(image, torch.Tensor):
        img = _tensor_to_pil(image)
    else:
        raise TypeError(f'unsupported image type: {type(image)}')
    img = img.resize((resolution, resolution), Image.BILINEAR)
    # copy(): np.asarray on a PIL image may be read-only, which torch.from_numpy warns about.
    arr = np.asarray(img, dtype=np.uint8).copy()
    x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
    mean = torch.tensor(_IMAGENET_MEAN, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD, device=x.device).view(1, 3, 1, 1)
    return (x - mean) / std


def _read_json(path: Union[str, Path]) -> dict:
    """Parse the UTF-8 JSON file at ``path``."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _load_classifier(path: Path) -> dict:
    """Load a classifier spec (read for its 'pos_dims', 'neg_dims', 'threshold' keys)."""
    return _read_json(path)


def _load_argus(argus_repo: str):
    """Download/load the Argus model from the HF hub and put it in eval mode on DEVICE."""
    from transformers import AutoModel  # optional heavy dependency, only for Argus stages
    # trust_remote_code=True executes Python shipped in the hub repo: only
    # point ``argus_repo`` at a repo you trust.
    return AutoModel.from_pretrained(argus_repo, trust_remote_code=True).to(DEVICE).eval()


def _pool_patch_tokens(backbone, x: torch.Tensor) -> torch.Tensor:
    """Shared forward for the backbone stages (0, 0_tight_fpr, 1, 2a, 2b).

    Runs ``backbone.forward_features`` under bf16 autocast. It then takes the
    'x_norm_patchtokens' output in float32 and squeezes the batch axis. Finally
    it layer-normalises over the last (D) axis and max-pools over the token
    axis.
    """
    with _autocast():
        out = backbone.forward_features(x)
    patches = out['x_norm_patchtokens'].float().squeeze(0)
    ln = F.layer_norm(patches, [D])
    return ln.max(dim=0).values


class _StudentBlock(nn.Module):
    """Pre-norm transformer block (attention + MLP) of the Stage 4 students.

    Attribute names fix the state-dict keys of student_final.safetensors.
    """

    def __init__(self, dim, h, ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, h, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        # NOTE(review): norm1 is evaluated once per argument, so query/key/value
        # are equal but distinct tensors. Passing one shared tensor would be
        # cheaper, but nn.MultiheadAttention may then take its fused inference
        # fast path (which requires query is key is value) and that is not
        # guaranteed to be bit-identical — kept as-is.
        h_, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x), need_weights=False)
        x = x + h_
        return x + self.mlp(self.norm2(x))


class _Student(nn.Module):
    """Patch-embedding ViT student: blocks -> final norm -> token max-pool -> linear head.

    Attribute names fix the state-dict keys of student_final.safetensors.
    """

    def __init__(self, out_dim, dim, depth, h, patch=16, img=RES):
        super().__init__()
        self.patch = nn.Conv2d(3, dim, patch, stride=patch)
        self.pos = nn.Parameter(torch.zeros(1, (img // patch) ** 2, dim))
        self.blocks = nn.ModuleList([_StudentBlock(dim, h) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, out_dim)

    def forward(self, x):
        t = self.patch(x).flatten(2).transpose(1, 2)  # (B, tokens, dim)
        t = t + self.pos[:, :t.shape[1]]
        for blk in self.blocks:
            t = blk(t)
        t = self.norm(t)
        return self.head(t.max(dim=1).values)


class PersonDetector:
    """Person-scene detector: a feature forward pass plus a signed-sum readout.

    score = sum(features[pos_dims]) - sum(features[neg_dims]); an image is
    reported as containing a person when the score exceeds ``threshold``.
    """

    def __init__(self, forward_fn, pos_dims, neg_dims, threshold):
        """
        Args:
            forward_fn: maps the normalised input from ``_norm_input`` to a
                1-D feature tensor on DEVICE.
            pos_dims: indices whose features are added to the score.
            neg_dims: indices whose features are subtracted from the score.
            threshold: decision threshold on the score.
        """
        self._forward = forward_fn
        self._pos = torch.tensor(pos_dims, dtype=torch.long, device=DEVICE)
        self._neg = torch.tensor(neg_dims, dtype=torch.long, device=DEVICE)
        self._thr = float(threshold)

    @torch.inference_mode()
    def predict(self, image) -> Tuple[float, bool]:
        """Score one image and return ``(score, score > threshold)``.

        ``image`` may be a path, PIL image, NumPy array or tensor (see ``_norm_input``).
        """
        x = _norm_input(image)
        pooled = self._forward(x)  # 1-D: (D,) for backbone stages, (40,)/(768,) for students
        score = (pooled[self._pos].sum() - pooled[self._neg].sum()).item()
        return float(score), bool(score > self._thr)

    @classmethod
    def from_stage(cls, stage: str, argus_repo: str = 'phanerozoic/argus',
                   repo_local: Union[str, Path, None] = None):
        """Load one of the stages by name.

        stage ∈ {
            'stage_0',
            'stage_0_tight_fpr',
            'stage_1',            # same backbone and classifier as stage_0
            'stage_2a',           # heads masked ('stage_2' is accepted as an alias)
            'stage_2b',           # backbone structurally pruned
            'stage_4',            # 3.27M student, per-dim MSE
            'stage_4b',           # 15.67M student, cosine on 768-D
            'stage_4c',           # 3.27M student, scalar-MSE
        }

        Args:
            argus_repo: HF repo for the EUPE-ViT-B backbone. Used by stage_0,
                stage_0_tight_fpr, stage_1 and stage_2a. Stage 2b bundles its
                own pruned backbone; stages 4, 4b and 4c don't use Argus.
            repo_local: local path to this repo (contains the stage_*/
                directories). Defaults to the directory containing this file.

        Raises:
            ValueError: if ``stage`` is not one of the names above.
        """
        root = Path(repo_local) if repo_local else HERE
        if stage in ('stage_0', 'stage_1'):
            return cls._build_argus_variant(root / 'stage_0' / 'classifier.json', argus_repo)
        if stage == 'stage_0_tight_fpr':
            return cls._build_argus_variant(
                root / 'stage_0_tight_fpr' / 'classifier.json', argus_repo)
        if stage in ('stage_2a', 'stage_2'):
            return cls._build_stage2a(root, argus_repo)
        if stage == 'stage_2b':
            return cls._build_stage2b(root)
        if stage in _STUDENT_CONFIGS:
            return cls._build_stage4(root, root / stage / 'student_final.safetensors',
                                     **_STUDENT_CONFIGS[stage])
        raise ValueError(f'unknown stage: {stage}')

    # --------------- stage-specific builders ---------------

    @classmethod
    def _build_argus_variant(cls, classifier_json, argus_repo):
        """Unmodified Argus backbone + the classifier spec at ``classifier_json``."""
        model = _load_argus(argus_repo)
        c = _load_classifier(classifier_json)

        def fwd(x):
            return _pool_patch_tokens(model.backbone, x)

        return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])

    @classmethod
    def _build_stage2a(cls, root, argus_repo):
        """Argus backbone with the 10 most prunable heads masked + Stage 0 classifier."""
        model = _load_argus(argus_repo)
        c = _load_classifier(root / 'stage_0' / 'classifier.json')
        imp = _read_json(root / 'stage_2' / 'head_importance.json')
        n_masked = 10  # top-10 entries of the most-prunable-first ranking
        head_dim = 64  # per-head width; presumably 768 / 12 heads for ViT-B — confirm
        with torch.no_grad():
            for (b, h, _drop) in imp['ranked_most_prunable_first'][:n_masked]:
                # Zero the output-projection input columns fed by head h of block b,
                # removing that head's contribution to the block output.
                proj = model.backbone.blocks[b].attn.proj
                proj.weight.data[:, h * head_dim:(h + 1) * head_dim] = 0.0

        def fwd(x):
            return _pool_patch_tokens(model.backbone, x)

        return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])

    @classmethod
    def _build_stage2b(cls, root):
        """Structurally pruned backbone bundled in stage_2b/ + Stage 0 classifier."""
        stage_dir = root / 'stage_2b'
        # load_pruned_backbone.py lives in stage_2b/; add it to sys.path only once
        # so repeated loads don't keep growing the import path.
        stage_dir_str = str(stage_dir)
        if stage_dir_str not in sys.path:
            sys.path.insert(0, stage_dir_str)
        from load_pruned_backbone import load_stage2b_backbone
        backbone = load_stage2b_backbone(
            str(stage_dir / 'pruned_state_dict.safetensors'),
            str(stage_dir / 'head_config.json'),
        ).to(DEVICE).eval()
        c = _load_classifier(root / 'stage_0' / 'classifier.json')

        def fwd(x):
            return _pool_patch_tokens(backbone, x)

        return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])

    @classmethod
    def _build_stage4(cls, root, weights_path, student_out_dim, student_dim,
                      student_depth, heads):
        """Stand-alone student network (no Argus backbone) loaded from ``weights_path``.

        The threshold comes from the highest-F1 epoch in the training_log.json
        next to ``weights_path``; it defaults to 0.0 if that epoch has no
        'threshold' entry.
        """
        from safetensors.torch import load_file

        student = _Student(student_out_dim, student_dim, student_depth, heads).to(DEVICE).eval()
        student.load_state_dict(load_file(str(weights_path)))

        # Classifier indexing depends on student output layout:
        # - stage_4 / stage_4c: student emits the 40 classifier-relevant dims
        #   directly (pos at [0:20], neg at [20:40]).
        # - stage_4b: student emits a 768-D vector matching teacher layout;
        #   use Stage 0's pos/neg dims directly.
        if student_out_dim == 40:
            pos, neg = list(range(20)), list(range(20, 40))
        else:
            c = _load_classifier(root / 'stage_0' / 'classifier.json')
            pos, neg = c['pos_dims'], c['neg_dims']

        # student_final.safetensors is the peak-F1 epoch (ep3 for Stage 4,
        # ep10 for Stage 4b/4c). Pull that epoch's threshold, not the last.
        log = _read_json(Path(weights_path).parent / 'training_log.json')
        best = max(log['epochs'], key=lambda e: e.get('F1', 0.0))
        thr = best.get('threshold', 0.0)

        def fwd(x):
            with _autocast():
                out = student(x)
            return out.float().squeeze(0)

        return cls(fwd, pos, neg, thr)


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('usage: python infer.py <stage> <image> [image ...]', file=sys.stderr)
        print('stages: stage_0, stage_0_tight_fpr, stage_1, stage_2a, stage_2b, '
              'stage_4, stage_4b, stage_4c', file=sys.stderr)
        sys.exit(1)
    stage = sys.argv[1]
    det = PersonDetector.from_stage(stage)
    for path in sys.argv[2:]:
        score, present = det.predict(path)
        print(f'{path} score={score:+.3f} present={present}')