phanerozoic committed
Commit 55f3b2e · verified · 1 Parent(s): 1564941

Add root infer.py dispatcher for all stages (Stage 0/1/2a/2b/4/4b)
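All stages expose the same decision rule, implemented in PersonDetector.predict: the pooled feature vector is indexed by the classifier's positive and negative dimensions, the two sums are differenced, and the result is compared to a per-stage threshold. A minimal standalone sketch of that rule (the indices and threshold below are made up for illustration; the real values ship with each stage in classifier.json or training_log.json):

    import numpy as np

    # Illustration only: made-up dims/threshold; real values come from each stage's files.
    pos_dims = [3, 17, 42]          # dimensions that vote "person"
    neg_dims = [5, 99, 210]         # dimensions that vote "no person"
    threshold = 0.0

    pooled = np.random.randn(768)   # stand-in for the max-pooled 768-D feature
    score = float(pooled[pos_dims].sum() - pooled[neg_dims].sum())
    present = score > threshold
    print(f'score={score:+.3f} present={present}')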

Files changed (1)
  1. infer.py +255 -0
infer.py ADDED
@@ -0,0 +1,255 @@
+ """Repo-level inference dispatcher.
+
+ Loads the weights of any stage in this repo and builds a callable person
+ detector; predict() returns:
+
+     score: float   (positive = person in scene, negative = no person)
+     present: bool  (score > threshold)
+
+ Examples:
+
+     # Baseline, via the Argus HF repo for the backbone
+     det = PersonDetector.from_stage('stage_0')
+
+     # Tight-FPR variant of the same
+     det = PersonDetector.from_stage('stage_0_tight_fpr')
+
+     # Head-pruned backbone
+     det = PersonDetector.from_stage('stage_2b')
+
+     # Specialist student (no Argus backbone needed)
+     det = PersonDetector.from_stage('stage_4b')
+
+     score, present = det.predict('path/to/image.jpg')
+
+ Stage 3 (depth reduction) and Stages 5/5b (circuit-level synthesis) are not
+ loadable at the Python level: Stage 3 is an ablation study, and Stages 5/5b
+ are Verilog.
+ """
+ import json, os, sys, io
+ from pathlib import Path
+ from typing import Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+
+ HERE = Path(__file__).parent
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ RES = 768
+ D = 768
+
+
+ def _norm_input(image: Union[str, Path, Image.Image, np.ndarray, torch.Tensor],
+                 resolution: int = RES) -> torch.Tensor:
+     if isinstance(image, (str, Path)):
+         img = Image.open(image).convert('RGB')
+     elif isinstance(image, Image.Image):
+         img = image.convert('RGB')
+     elif isinstance(image, np.ndarray):
+         img = Image.fromarray(image).convert('RGB')
+     elif isinstance(image, torch.Tensor):
+         arr = image.cpu().numpy() if image.ndim == 3 else image[0].cpu().numpy()
+         if arr.shape[0] == 3:
+             arr = arr.transpose(1, 2, 0)
+         img = Image.fromarray((arr * 255).astype('uint8')).convert('RGB')
+     else:
+         raise TypeError(f'unsupported image type: {type(image)}')
+     img = img.resize((resolution, resolution), Image.BILINEAR)
+     arr = np.asarray(img, dtype=np.uint8).copy()
+     x = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(DEVICE).float() / 255.0
+     mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(DEVICE)
+     std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(DEVICE)
+     return (x - mean) / std
+
+
+ def _load_classifier(path: Path) -> dict:
+     with open(path) as f:
+         return json.load(f)
+
+
+ class PersonDetector:
+     def __init__(self, forward_fn, pos_dims, neg_dims, threshold):
+         self._forward = forward_fn
+         self._pos = torch.tensor(pos_dims, dtype=torch.long, device=DEVICE)
+         self._neg = torch.tensor(neg_dims, dtype=torch.long, device=DEVICE)
+         self._thr = float(threshold)
+
+     @torch.inference_mode()
+     def predict(self, image) -> Tuple[float, bool]:
+         x = _norm_input(image)
+         pooled = self._forward(x)  # (D,) float
+         score = (pooled[self._pos].sum() - pooled[self._neg].sum()).item()
+         return float(score), bool(score > self._thr)
+
+     @classmethod
+     def from_stage(cls, stage: str, argus_repo: str = 'phanerozoic/argus',
+                    repo_local: Union[str, Path, None] = None):
+         """Load one of the stages by name.
+
+         stage ∈ {
+             'stage_0', 'stage_0_tight_fpr', 'stage_1',
+             'stage_2a',   # heads masked
+             'stage_2b',   # backbone structurally pruned
+             'stage_4',    # 3.27M student
+             'stage_4b',   # 15.67M student
+         }
+
+         argus_repo: HF repo for the EUPE-ViT-B backbone. Used by stage_0,
+             stage_0_tight_fpr, stage_1, stage_2a. Stage 2b bundles its own
+             pruned backbone. Stages 4 and 4b don't use Argus.
+
+         repo_local: local path to this repo (contains stage_*/ directories).
+             Defaults to the directory containing this file.
+         """
+         root = Path(repo_local) if repo_local else HERE
+
+         if stage in ('stage_0', 'stage_1'):
+             return cls._build_argus_variant(root / 'stage_0' / 'classifier.json', argus_repo)
+         if stage == 'stage_0_tight_fpr':
+             return cls._build_argus_variant(root / 'stage_0_tight_fpr' / 'classifier.json', argus_repo)
+         if stage in ('stage_2a', 'stage_2'):
+             return cls._build_stage2a(root, argus_repo)
+         if stage == 'stage_2b':
+             return cls._build_stage2b(root)
+         if stage == 'stage_4':
+             return cls._build_stage4(root, root / 'stage_4' / 'student_final.safetensors',
+                                      student_out_dim=40, student_dim=192, student_depth=6, heads=3)
+         if stage == 'stage_4b':
+             return cls._build_stage4(root, root / 'stage_4b' / 'student_final.safetensors',
+                                      student_out_dim=768, student_dim=384, student_depth=8, heads=6)
+         raise ValueError(f'unknown stage: {stage}')
+
+     # --------------- stage-specific builders ---------------
+
+     @classmethod
+     def _build_argus_variant(cls, classifier_json, argus_repo):
+         from transformers import AutoModel
+         model = AutoModel.from_pretrained(argus_repo, trust_remote_code=True).to(DEVICE).eval()
+         c = _load_classifier(classifier_json)
+
+         def fwd(x):
+             with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
+                 out = model.backbone.forward_features(x)
+                 patches = out['x_norm_patchtokens'].float().squeeze(0)
+                 ln = F.layer_norm(patches, [D])
+                 return ln.max(dim=0).values
+
+         return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
+
+     @classmethod
+     def _build_stage2a(cls, root, argus_repo):
+         from transformers import AutoModel
+         model = AutoModel.from_pretrained(argus_repo, trust_remote_code=True).to(DEVICE).eval()
+         c = _load_classifier(root / 'stage_0' / 'classifier.json')
+         # Apply head mask from stage_2 head_importance.json (top 10 most prunable)
+         with open(root / 'stage_2' / 'head_importance.json') as f:
+             imp = json.load(f)
+         HEAD_DIM = 64
+         with torch.no_grad():
+             for (b, h, _drop) in imp['ranked_most_prunable_first'][:10]:
+                 model.backbone.blocks[b].attn.proj.weight.data[:, h*HEAD_DIM:(h+1)*HEAD_DIM] = 0.0
+
+         def fwd(x):
+             with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
+                 out = model.backbone.forward_features(x)
+                 patches = out['x_norm_patchtokens'].float().squeeze(0)
+                 ln = F.layer_norm(patches, [D])
+                 return ln.max(dim=0).values
+
+         return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
+
+     @classmethod
+     def _build_stage2b(cls, root):
+         sys.path.insert(0, str(root / 'stage_2b'))
+         from load_pruned_backbone import load_stage2b_backbone
+         backbone = load_stage2b_backbone(
+             str(root / 'stage_2b' / 'pruned_state_dict.safetensors'),
+             str(root / 'stage_2b' / 'head_config.json'),
+         ).to(DEVICE).eval()
+         c = _load_classifier(root / 'stage_0' / 'classifier.json')
+
+         def fwd(x):
+             with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
+                 out = backbone.forward_features(x)
+                 patches = out['x_norm_patchtokens'].float().squeeze(0)
+                 ln = F.layer_norm(patches, [D])
+                 return ln.max(dim=0).values
+
+         return cls(fwd, c['pos_dims'], c['neg_dims'], c['threshold'])
+
+     @classmethod
+     def _build_stage4(cls, root, weights_path, student_out_dim, student_dim, student_depth, heads):
+         from safetensors.torch import load_file
+
+         class _Block(nn.Module):
+             def __init__(self, dim, h, ratio=4.0):
+                 super().__init__()
+                 self.norm1 = nn.LayerNorm(dim)
+                 self.attn = nn.MultiheadAttention(dim, h, batch_first=True)
+                 self.norm2 = nn.LayerNorm(dim)
+                 hidden = int(dim * ratio)
+                 self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
+
+             def forward(self, x):
+                 h_, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x), need_weights=False)
+                 x = x + h_
+                 return x + self.mlp(self.norm2(x))
+
+         class _Student(nn.Module):
+             def __init__(self, out_dim, dim, depth, h, patch=16, img=RES):
+                 super().__init__()
+                 self.patch = nn.Conv2d(3, dim, patch, stride=patch)
+                 self.pos = nn.Parameter(torch.zeros(1, (img // patch) ** 2, dim))
+                 self.blocks = nn.ModuleList([_Block(dim, h) for _ in range(depth)])
+                 self.norm = nn.LayerNorm(dim)
+                 self.head = nn.Linear(dim, out_dim)
+
+             def forward(self, x):
+                 t = self.patch(x).flatten(2).transpose(1, 2)
+                 t = t + self.pos[:, :t.shape[1]]
+                 for blk in self.blocks:
+                     t = blk(t)
+                 t = self.norm(t)
+                 return self.head(t.max(dim=1).values)
+
+         student = _Student(student_out_dim, student_dim, student_depth, heads).to(DEVICE).eval()
+         student.load_state_dict(load_file(str(weights_path)))
+
+         # Classifier indexing depends on student output layout:
+         #   - stage_4:  student emits the 40 classifier-relevant dims directly
+         #               (pos at [0:20], neg at [20:40])
+         #   - stage_4b: student emits a 768-D vector matching teacher layout;
+         #               use Stage 0's pos/neg dims directly.
+         if student_out_dim == 40:
+             pos, neg = list(range(20)), list(range(20, 40))
+             with open(root / 'stage_4' / 'training_log.json') as f:
+                 log = json.load(f)
+             thr = log['epochs'][-1].get('threshold', 0.0)
+         else:
+             c = _load_classifier(root / 'stage_0' / 'classifier.json')
+             pos, neg = c['pos_dims'], c['neg_dims']
+             with open(root / 'stage_4b' / 'training_log.json') as f:
+                 log = json.load(f)
+             thr = log['epochs'][-1].get('threshold', 0.0)
+
+         def fwd(x):
+             with torch.autocast('cuda' if DEVICE == 'cuda' else 'cpu', dtype=torch.bfloat16):
+                 out = student(x)
+                 return out.float().squeeze(0)
+
+         return cls(fwd, pos, neg, thr)
+
+
+ if __name__ == '__main__':
+     if len(sys.argv) < 3:
+         print('usage: python infer.py <stage> <image> [image ...]')
+         print('stages: stage_0, stage_0_tight_fpr, stage_1, stage_2a, stage_2b, stage_4, stage_4b')
+         sys.exit(1)
+     stage = sys.argv[1]
+     det = PersonDetector.from_stage(stage)
+     for path in sys.argv[2:]:
+         score, present = det.predict(path)
+         print(f'{path} score={score:+.3f} present={present}')
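
For reference, a minimal usage sketch of the dispatcher added above (the image path is a placeholder; stage_0, stage_0_tight_fpr, stage_1, and stage_2a additionally need the phanerozoic/argus backbone to be downloadable):

    from infer import PersonDetector

    # Stage 4b specialist student: self-contained, no Argus backbone required.
    det = PersonDetector.from_stage('stage_4b')
    score, present = det.predict('path/to/image.jpg')  # placeholder path
    print(f'score={score:+.3f} present={present}')

Or, equivalently, from the command line: python infer.py stage_4b path/to/image.jpg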