jimnoneill committed on
Commit
f3101a4
·
verified ·
1 Parent(s): 85de43a

Upload src/pubguard/classifier.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pubguard/classifier.py +264 -0
src/pubguard/classifier.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PubGuard — Multi-head Publication Gatekeeper
3
+ =============================================
4
+
5
+ Architecture
6
+ ~~~~~~~~~~~~
7
+
8
+     ┌─────────────┐
9
+     │  PDF text   │
10
+     └──────┬──────┘
11
+            │
12
+     ┌──────▼──────┐     ┌───────────────────┐
13
+     │ clean_text  │────►│ model2vec encode  │──► emb ∈ R^512
14
+     └─────────────┘     └───────────────────┘
15
+                                   │
16
+                 ┌─────────────────┼─────────────────┐
17
+                 ▼                 ▼                 ▼
18
+        ┌─────────────────┐ ┌──────────────┐ ┌──────────────┐
19
+        │ doc_type head   │ │ ai_detect    │ │ toxicity     │
20
+        │ (concat struct) │ │ head         │ │ head         │
21
+        │ W·[emb;feat]+b  │ │ W·emb + b    │ │ W·emb + b    │
22
+        │ → softmax(4)    │ │ → softmax(2) │ │ → softmax(2) │
23
+        └─────────────────┘ └──────────────┘ └──────────────┘
24
+
25
+ Each head is a single linear layer stored as a numpy .npz file
26
+ (weights W and bias b). Inference is pure numpy — no torch needed
27
+ at prediction time, matching the openalex classifier's deployment
28
+ philosophy.
29
+
30
+ The doc_type head additionally receives 14 structural features
31
+ (section headings present, citation density, etc.) concatenated
32
+ with the embedding — these are powerful priors that cost ~0 compute.
33
+
34
+ Performance target: ≥2,000 records/sec on CPU (same ballpark as
35
+ openalex classifier at ~3,000/sec).
36
+ """
37
+
38
+ import logging
39
+ import time
40
+ from pathlib import Path
41
+ from typing import Any, Dict, List, Optional, Union
42
+
43
+ import numpy as np
44
+
45
+ from .config import PubGuardConfig, DOC_TYPE_LABELS, AI_DETECT_LABELS, TOXICITY_LABELS
46
+ from .text import clean_text, extract_structural_features, STRUCTURAL_FEATURE_NAMES, N_STRUCTURAL_FEATURES
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
class LinearHead:
    """
    Single linear classifier head: probs = softmax(X @ W + b).

    Parameters are persisted as a numpy ``.npz`` archive with the keys
    ``'W'``, ``'b'`` and ``'labels'``.

    Attributes:
        labels: Class names; position ``i`` names output column ``i``.
        n_classes: Number of classes (``len(labels)``).
        W: Weight matrix, ``None`` until loaded or assigned.
        b: Bias vector, ``None`` until loaded or assigned.
    """

    def __init__(self, labels: List[str]):
        self.labels = labels
        self.n_classes = len(labels)
        self.W: Optional[np.ndarray] = None  # (input_dim, n_classes)
        self.b: Optional[np.ndarray] = None  # (n_classes,)

    def load(self, path: Path) -> bool:
        """
        Load weights, bias and (if present) labels from an ``.npz`` file.

        Args:
            path: Location of the archive.

        Returns:
            ``False`` if *path* does not exist (the head is left untouched),
            ``True`` once the parameters have been loaded.
        """
        if not path.exists():
            return False
        # SECURITY NOTE(review): allow_pickle=True lets numpy unpickle object
        # arrays, which can execute arbitrary code if the archive comes from
        # an untrusted source. It is kept so that archives whose labels were
        # stored as object arrays still load; only load head files you trust.
        #
        # The context manager closes the underlying zip file; the arrays are
        # fully read into memory on item access, so they outlive the handle.
        with np.load(path, allow_pickle=True) as data:
            self.W = data["W"]
            self.b = data["b"]
            # `files` lists the archive's keys and is available on every
            # NpzFile, unlike the Mapping-style `.get()`.
            if "labels" in data.files:
                # tolist() converts numpy string scalars to plain Python str.
                self.labels = data["labels"].tolist()
                self.n_classes = len(self.labels)
        return True

    def save(self, path: Path):
        """
        Write ``W``, ``b`` and ``labels`` to *path* as an ``.npz`` archive.

        Missing parent directories are created first.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        np.savez(path, W=self.W, b=self.b, labels=np.array(self.labels))

    def predict(self, X: np.ndarray) -> tuple:
        """
        Classify a batch of feature vectors.

        Args:
            X: Array of shape (batch, input_dim).

        Returns:
            A 3-tuple ``(pred_labels, pred_scores, probs)``:
            ``pred_labels`` is a list with the arg-max label per row,
            ``pred_scores`` holds the probability of that label per row
            (shape (batch,)), and ``probs`` is the full softmax output
            (shape (batch, n_classes)).
        """
        logits = X @ self.W + self.b              # (batch, n_classes)
        probs = _softmax(logits)                  # (batch, n_classes)
        pred_idx = np.argmax(probs, axis=1)       # (batch,)
        pred_scores = probs[np.arange(probs.shape[0]), pred_idx]
        pred_labels = [self.labels[i] for i in pred_idx]
        return pred_labels, pred_scores, probs
92
+
93
+
94
+ def _softmax(x: np.ndarray) -> np.ndarray:
95
+ """Numerically stable softmax."""
96
+ e = np.exp(x - x.max(axis=-1, keepdims=True))
97
+ return e / e.sum(axis=-1, keepdims=True)
98
+
99
+
100
class PubGuard:
    """
    Multi-head publication screening classifier.

    Each document is cleaned, embedded once, and scored by three linear
    heads (doc_type, ai_detect, toxicity). The per-head labels are then
    combined with the config flags into a single pass/fail gate.

    Usage:
        guard = PubGuard()
        guard.initialize()

        # Single document
        verdict = guard.screen("Introduction: We present a novel ...")

        # Batch
        verdicts = guard.screen_batch(["text1", "text2", ...])
    """

    def __init__(self, config: Optional[PubGuardConfig] = None):
        """Create an uninitialised guard; no model or weights are loaded yet."""
        self.config = config or PubGuardConfig()
        self.model = None
        self.head_doc_type = LinearHead(DOC_TYPE_LABELS)
        self.head_ai_detect = LinearHead(AI_DETECT_LABELS)
        self.head_toxicity = LinearHead(TOXICITY_LABELS)
        self._initialized = False

    # ── Initialisation ──────────────────────────────────────────

    def initialize(self) -> bool:
        """
        Load embedding model + all classification heads.

        Safe to call repeatedly: once initialised, later calls return
        immediately. Always returns ``True``.
        """
        if self._initialized:
            return True

        logger.info("Initializing PubGuard...")
        # perf_counter is monotonic, so the elapsed time cannot be skewed
        # by system clock adjustments.
        start = time.perf_counter()

        self._load_model()
        self._load_heads()

        self._initialized = True
        logger.info("PubGuard initialized in %.1fs", time.perf_counter() - start)
        return True

    def _load_model(self):
        """
        Load model2vec StaticModel (same as openalex classifier).

        Uses the local copy at ``config.distilled_model_path`` when it
        exists; otherwise fetches ``config.model_name`` and saves it to
        that path for subsequent runs.
        """
        # Imported lazily so merely importing this module does not require
        # model2vec to be installed.
        from model2vec import StaticModel

        cache = self.config.distilled_model_path
        if cache.exists():
            logger.info("Loading embedding model from %s", cache)
            self.model = StaticModel.from_pretrained(str(cache))
        else:
            logger.info("Downloading model: %s", self.config.model_name)
            self.model = StaticModel.from_pretrained(self.config.model_name)
            cache.parent.mkdir(parents=True, exist_ok=True)
            self.model.save_pretrained(str(cache))
            logger.info("Cached to %s", cache)

    def _load_heads(self):
        """
        Load each classification head from .npz files.

        A missing file is not fatal: a warning is logged and that head is
        left without weights, which ``screen_batch`` reports as "unknown".
        """
        heads = [
            ("doc_type", self.head_doc_type, self.config.doc_type_head_path),
            ("ai_detect", self.head_ai_detect, self.config.ai_detect_head_path),
            ("toxicity", self.head_toxicity, self.config.toxicity_head_path),
        ]
        for name, head, path in heads:
            if head.load(path):
                logger.info(" Loaded %s head: %s", name, path)
            else:
                logger.warning(
                    " %s head not found at %s — "
                    "run `python -m pubguard.train` first",
                    name,
                    path,
                )

    # ── Inference ───────────────────────────────────────────────

    def screen(self, text: str) -> Dict[str, Any]:
        """Screen a single document. Returns verdict dict."""
        return self.screen_batch([text])[0]

    @staticmethod
    def _predict_or_unknown(head: "LinearHead", features: np.ndarray, n_docs: int):
        """Return (labels, scores) from *head*, or "unknown"/0.0 placeholders if it has no weights."""
        if head.W is None:
            return ["unknown"] * n_docs, [0.0] * n_docs
        labels, scores, _ = head.predict(features)
        return labels, scores

    def screen_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Screen a batch of documents.

        Initialises the guard on first use. An empty *texts* list yields
        an empty result list.

        Returns list of verdict dicts, each containing:
            doc_type:     {label, score}
            ai_generated: {label, score}
            toxicity:     {label, score}
            pass:         bool (overall gate decision)

        A head whose weights were not loaded reports label "unknown" with
        score 0.0. Note that with ``require_scientific`` enabled, an
        "unknown" doc_type fails the gate, while "unknown" ai/toxicity
        labels do not.
        """
        if not self._initialized:
            self.initialize()

        if not texts:
            return []

        cfg = self.config
        n_docs = len(texts)

        # ── Preprocess ──────────────────────────────────────────
        cleaned = [clean_text(t, cfg.max_text_chars) for t in texts]

        # ── Embed ───────────────────────────────────────────────
        # Row-wise L2 normalisation of the embeddings.
        embeddings = self.model.encode(cleaned)
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)  # avoid div-by-zero
        embeddings = (embeddings / norms).astype("float32")

        # ── Structural features (for doc_type head) ─────────────
        # NOTE(review): the column order is whatever order
        # extract_structural_features() returns its values in; presumably
        # this matches the order used at training time — confirm.
        struct_feats = np.array(
            [list(extract_structural_features(t).values()) for t in cleaned],
            dtype="float32",
        )
        doc_type_input = np.concatenate([embeddings, struct_feats], axis=1)

        # ── Per-head predictions ────────────────────────────────
        dt_labels, dt_scores = self._predict_or_unknown(
            self.head_doc_type, doc_type_input, n_docs
        )
        ai_labels, ai_scores = self._predict_or_unknown(
            self.head_ai_detect, embeddings, n_docs
        )
        tx_labels, tx_scores = self._predict_or_unknown(
            self.head_toxicity, embeddings, n_docs
        )

        # ── Gate logic + verdict assembly ───────────────────────
        results = []
        per_doc = zip(dt_labels, dt_scores, ai_labels, ai_scores, tx_labels, tx_scores)
        for dt_label, dt_score, ai_label, ai_score, tx_label, tx_score in per_doc:
            # A document is rejected as soon as any enabled rule trips.
            blocked = (
                (cfg.require_scientific and dt_label != "scientific_paper")
                or (cfg.block_ai_generated and ai_label == "ai_generated")
                or (cfg.block_toxic and tx_label == "toxic")
            )
            results.append({
                "doc_type": {
                    "label": dt_label,
                    "score": round(float(dt_score), 4),
                },
                "ai_generated": {
                    "label": ai_label,
                    "score": round(float(ai_score), 4),
                },
                "toxicity": {
                    "label": tx_label,
                    "score": round(float(tx_score), 4),
                },
                "pass": not blocked,
            })

        return results

    # ── File-level convenience ──────────────────────────────────

    def screen_file(self, path: Union[str, Path]) -> Dict[str, Any]:
        """
        Read a text file and screen it.

        The file is decoded as UTF-8 regardless of the platform default;
        undecodable bytes are replaced rather than raising.
        """
        text = Path(path).read_text(encoding="utf-8", errors="replace")
        return self.screen(text)