jimnoneill committed on
Commit
ba415d6
·
verified ·
1 Parent(s): 9b4dd95

Upload src/poster_sentry/classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/poster_sentry/classifier.py +252 -0
src/poster_sentry/classifier.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
PosterSentry — Multimodal Scientific Poster Classifier
=======================================================

Architecture:
    ┌──────────┐    ┌──────────────┐    ┌───────────────┐
    │ PDF text │    │ PDF → image  │    │ PDF structure │
    └────┬─────┘    └──────┬───────┘    └───────┬───────┘
         │                 │                    │
     model2vec         15 visual          15 structural
    → 512-d emb        features             features
         │                 │                    │
         └────────┬────────┴────────────────────┘
                  │
        concat → 542-d input
                  │
         LogisticRegression
                  │
         poster / non_poster

Single linear classifier on the concatenated feature vector.
Same paradigm as PubGuard — lightweight, CPU-only, fast.
"""
24
+
25
+ import logging
26
+ import time
27
+ from pathlib import Path
28
+ from typing import Any, Dict, List, Optional
29
+
30
+ import numpy as np
31
+
32
+ from .features import (
33
+ VisualFeatureExtractor,
34
+ PDFStructuralExtractor,
35
+ N_VISUAL_FEATURES,
36
+ N_STRUCTURAL_FEATURES,
37
+ )
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
class PosterSentry:
    """
    Multimodal poster classifier.

    Combines:
      - model2vec text embedding (512-d)
      - 15 visual features (color, edge, FFT, whitespace)
      - 15 structural features (page geometry, fonts, text blocks)

    into a single 542-d feature vector for logistic regression.

    The classifier head (``W``, ``b``, optional feature scaler and labels) is
    loaded from ``<models_dir>/poster_sentry_head.npz``. When that file is
    missing, every classify method returns a negative result carrying an
    ``"error"`` entry instead of raising.
    """

    # File name of the serialized classifier head inside ``models_dir``.
    _HEAD_FILENAME = "poster_sentry_head.npz"

    def __init__(
        self,
        model_name: str = "minishlab/potion-base-32M",
        models_dir: Optional[Path] = None,
    ):
        """
        Create an uninitialized classifier; nothing is loaded until ``initialize()``.

        Args:
            model_name: model2vec model identifier used when no local copy of
                the embedding model exists under ``models_dir``.
            models_dir: directory holding the cached embedding model and the
                classifier head. Falls back to ``_default_models_dir()``.
        """
        self.model_name = model_name
        self.models_dir = Path(models_dir or self._default_models_dir())

        self.text_model = None
        self.W: Optional[np.ndarray] = None
        self.b: Optional[np.ndarray] = None
        self.scaler_mean: Optional[np.ndarray] = None
        self.scaler_scale: Optional[np.ndarray] = None
        self.labels = ["non_poster", "poster"]

        self.visual_extractor = VisualFeatureExtractor()
        self.structural_extractor = PDFStructuralExtractor()
        self._initialized = False

    @staticmethod
    def _default_models_dir() -> Path:
        """
        Return the models directory to use when none is given explicitly.

        ``$POSTER_SENTRY_MODELS_DIR`` wins when set (and is returned without
        being created); otherwise ``~/.poster_sentry/models`` is created if
        needed and returned.
        """
        import os
        if env := os.environ.get("POSTER_SENTRY_MODELS_DIR"):
            return Path(env)
        home = Path.home() / ".poster_sentry" / "models"
        home.mkdir(parents=True, exist_ok=True)
        return home

    # ── Initialization ──────────────────────────────────────────

    def initialize(self) -> bool:
        """Load the text model and classifier head once; always returns True."""
        if self._initialized:
            return True
        logger.info("Initializing PosterSentry...")
        t0 = time.time()
        self._load_text_model()
        self._load_head()
        self._initialized = True
        logger.info("PosterSentry initialized in %.1fs", time.time() - t0)
        return True

    def _load_text_model(self):
        """Load the model2vec model from the local cache, populating the cache on first use."""
        from model2vec import StaticModel
        cache = self.models_dir / "poster-sentry-embedding"
        if cache.exists():
            self.text_model = StaticModel.from_pretrained(str(cache))
        else:
            self.text_model = StaticModel.from_pretrained(self.model_name)
            cache.parent.mkdir(parents=True, exist_ok=True)
            self.text_model.save_pretrained(str(cache))

    def _load_head(self):
        """
        Load ``W``, ``b`` and, when present, ``labels`` and the scaler arrays.

        A missing head file only logs a warning; ``W`` then stays ``None`` and
        the classify methods report "Model not trained".
        """
        path = self.models_dir / self._HEAD_FILENAME
        if not path.exists():
            logger.warning(" Head not found: %s — run training first", path)
            return

        # NOTE(security): allow_pickle=True can execute arbitrary code if the
        # .npz comes from an untrusted source. It is kept so that head files
        # containing object arrays keep loading; only load trusted files.
        # The context manager closes the underlying zip archive, so every
        # array must be read inside the ``with`` block.
        with np.load(path, allow_pickle=True) as data:
            self.W = data["W"]
            self.b = data["b"]
            if "labels" in data:
                self.labels = [str(label) for label in data["labels"]]
            if "scaler_mean" in data and "scaler_scale" in data:
                self.scaler_mean = data["scaler_mean"]
                self.scaler_scale = data["scaler_scale"]
        logger.info(" Loaded classifier head: %s", path)

    def save_head(self, path: Optional[Path] = None):
        """
        Write the classifier head to an ``.npz`` file that ``_load_head`` can read.

        The scaler arrays are stored whenever both are set, so a save/load
        round trip preserves feature scaling.

        Args:
            path: destination file; defaults to the head file in ``models_dir``.

        Raises:
            ValueError: if no head is loaded or trained (``W`` or ``b`` is None).
        """
        if self.W is None or self.b is None:
            raise ValueError("Cannot save classifier head: model not trained")
        path = Path(path) if path else self.models_dir / self._HEAD_FILENAME
        path.parent.mkdir(parents=True, exist_ok=True)

        arrays = {"W": self.W, "b": self.b, "labels": np.array(self.labels)}
        if self.scaler_mean is not None and self.scaler_scale is not None:
            arrays["scaler_mean"] = self.scaler_mean
            arrays["scaler_scale"] = self.scaler_scale
        np.savez(path, **arrays)

    # ── Feature extraction ──────────────────────────────────────

    def extract_text(self, pdf_path: str, max_chars: int = 4000) -> str:
        """
        Extract and clean text from first page of PDF.

        Whitespace runs are collapsed to single spaces and the result is
        truncated to ``max_chars``. Best effort: any failure (including a
        missing ``fitz`` module or an unreadable file) yields ``""``.
        """
        import re
        try:
            import fitz
            doc = fitz.open(pdf_path)
            try:
                if len(doc) == 0:
                    return ""
                text = doc[0].get_text()
            finally:
                # Close the document even when page access raises.
                doc.close()
        except Exception:
            logger.debug("Text extraction failed for %s", pdf_path, exc_info=True)
            return ""
        return re.sub(r"\s+", " ", text).strip()[:max_chars]

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """Encode texts with model2vec, L2-normalize; all-zero rows are left as zeros."""
        embeddings = np.asarray(self.text_model.encode(texts))
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)  # avoid 0/0 for empty texts
        return (embeddings / norms).astype("float32")

    def build_feature_vector(
        self,
        text_emb: np.ndarray,
        visual_feats: np.ndarray,
        structural_feats: np.ndarray,
    ) -> np.ndarray:
        """Concatenate all features: [512 text + 15 visual + 15 structural] = 542."""
        return np.concatenate([text_emb, visual_feats, structural_feats])

    # ── Inference ───────────────────────────────────────────────

    def _poster_index(self) -> int:
        """Return the probability column of the "poster" label (1 if it is not listed)."""
        try:
            return list(self.labels).index("poster")
        except ValueError:
            return 1

    def _predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Apply the optional feature scaler, then the linear head with a stable softmax."""
        if self.scaler_mean is not None and self.scaler_scale is not None:
            # A zero scale is replaced by 1 so such features pass through unscaled.
            X = (X - self.scaler_mean) / np.where(self.scaler_scale == 0, 1, self.scaler_scale)
        logits = X @ self.W + self.b
        e = np.exp(logits - logits.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def classify(self, pdf_path: str) -> Dict[str, Any]:
        """Classify a single PDF as poster or non-poster."""
        if not self._initialized:
            self.initialize()
        return self.classify_batch([pdf_path])[0]

    def classify_batch(self, pdf_paths: List[str]) -> List[Dict[str, Any]]:
        """
        Classify a batch of PDFs.

        Returns one dict per input, in order, with ``path``, ``is_poster``,
        ``confidence`` (poster probability rounded to 4 places) and
        ``text_score``. Without a loaded head each dict instead carries
        ``is_poster=False``, ``confidence=0.0`` and an ``error`` message.
        An empty input returns an empty list.
        """
        if not self._initialized:
            self.initialize()

        pdf_paths = list(pdf_paths)
        if not pdf_paths:
            return []

        # Check for a head before doing any (expensive) feature extraction.
        if self.W is None:
            return [{"path": str(p), "is_poster": False, "confidence": 0.0,
                     "error": "Model not trained"} for p in pdf_paths]

        texts = []
        visual_vecs = []
        structural_vecs = []

        for p in pdf_paths:
            texts.append(self.extract_text(p))

            img = self.visual_extractor.pdf_to_image(p)
            if img is not None:
                vf = self.visual_extractor.extract(img)
            else:
                # No rendered page: fall back to all-zero visual features.
                vf = {n: 0.0 for n in self.visual_extractor.FEATURE_NAMES}
            visual_vecs.append(self.visual_extractor.to_vector(vf))

            sf = self.structural_extractor.extract(p)
            structural_vecs.append(self.structural_extractor.to_vector(sf))

        text_embs = self.embed_texts(texts)
        visual_arr = np.array(visual_vecs, dtype="float32")
        struct_arr = np.array(structural_vecs, dtype="float32")

        # Scaling inside _predict_proba is critical for balanced text vs
        # structural signal.
        X = np.concatenate([text_embs, visual_arr, struct_arr], axis=1)
        probs = self._predict_proba(X)

        poster_col = self._poster_index()
        results = []
        for i, p in enumerate(pdf_paths):
            poster_prob = float(probs[i, poster_col])
            results.append({
                "path": str(p),
                "is_poster": poster_prob > 0.5,
                "confidence": round(poster_prob, 4),
                # Kept for backward compatibility; same value as "confidence".
                "text_score": round(poster_prob, 4),
            })
        return results

    # ── Text-only classification (for PubGuard integration) ─────

    def classify_text(self, text: str) -> Dict[str, Any]:
        """Classify from text alone (no PDF needed). Used by PubGuard."""
        return self.classify_texts([text])[0]

    def classify_texts(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Classify from text alone (batch).

        Visual and structural features are zero-filled before scaling.
        Returns one independent dict per input with ``is_poster`` and
        ``confidence``; without a loaded head each dict also carries an
        ``error`` message. An empty input returns an empty list.
        """
        if not self._initialized:
            self.initialize()

        texts = list(texts)
        if not texts:
            return []

        if self.W is None:
            # Build a fresh dict per entry so callers can mutate results safely.
            return [{"is_poster": False, "confidence": 0.0,
                     "error": "Model not trained"} for _ in texts]

        text_embs = self.embed_texts(texts)
        # Zero-fill visual and structural features
        zeros_visual = np.zeros((len(texts), N_VISUAL_FEATURES), dtype="float32")
        zeros_struct = np.zeros((len(texts), N_STRUCTURAL_FEATURES), dtype="float32")
        X = np.concatenate([text_embs, zeros_visual, zeros_struct], axis=1)

        probs = self._predict_proba(X)
        poster_col = self._poster_index()

        results = []
        for row in probs:
            poster_prob = float(row[poster_col])
            results.append({"is_poster": poster_prob > 0.5,
                            "confidence": round(poster_prob, 4)})
        return results