jimnoneill committed on
Commit
e462642
·
verified ·
1 Parent(s): ba415d6

Upload scripts/train_poster_sentry.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_poster_sentry.py +402 -0
scripts/train_poster_sentry.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Train PosterSentry on the real posters.science corpus.
4
+
5
+ Data sources (all real, zero synthetic):
6
+ Positive (poster):
7
+ 28K+ verified scientific posters from Zenodo & Figshare
8
+ /home/joneill/Nextcloud/vaults/jmind/calmi2/poster_science/poster-pdf-meta/downloads/
9
+
10
+ Negative (non_poster):
11
+ 2,036 verified non-posters (multi-page docs, proceedings, abstracts)
12
+ Listed in: poster_classifier/non_posters_20251208_152217.txt
13
+
14
+ Plus: single pages extracted from armanc/scientific_papers (real papers)
15
+ Plus: ag_news articles (real junk text, rendered to match)
16
+
17
+ Usage:
18
+ cd /home/joneill/pubverse_brett/poster_sentry
19
+ pip install -e ".[train]"
20
+ python scripts/train_poster_sentry.py --n-per-class 5000
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ import logging
26
+ import os
27
+ import random
28
+ import sys
29
+ import time
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ import numpy as np
34
+
35
# Configure root logging once for the whole script: timestamped, INFO level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
41
+
42
# Single seed shared by Python's and NumPy's global RNGs so that path
# shuffling, class balancing and the train/test split are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
45
+
46
# ── Paths ────────────────────────────────────────────────────────

# Root of the posters.science corpus. Defaults to the original author's
# Nextcloud vault; set the POSTER_SCIENCE_BASE environment variable to point
# the script at a corpus stored elsewhere.
POSTER_SCIENCE_BASE = Path(
    os.environ.get(
        "POSTER_SCIENCE_BASE",
        "/home/joneill/Nextcloud/vaults/jmind/calmi2/poster_science",
    )
)
# Directory holding the downloaded PDFs.
DOWNLOADS_DIR = POSTER_SCIENCE_BASE / "poster-pdf-meta" / "downloads"
# Text file listing non-poster PDF paths, one per line.
NON_POSTERS_LIST = (
    POSTER_SCIENCE_BASE
    / "poster_classifier"
    / "non_posters_20251208_152217.txt"
)
# JSON file with classification results; its "posters" entries carry "pdf_path".
CLASSIFICATION_JSON = (
    POSTER_SCIENCE_BASE
    / "poster_classifier"
    / "classification_results_20251208_152217.json"
)
62
+
63
+
64
+ def _fix_path(p: str) -> str:
65
+ """Fix paths from classification JSON β€” they use /home/joneill/vaults/
66
+ but the actual Nextcloud mount is /home/joneill/Nextcloud/vaults/."""
67
+ if "/joneill/vaults/" in p and "/Nextcloud/" not in p:
68
+ return p.replace("/joneill/vaults/", "/joneill/Nextcloud/vaults/")
69
+ return p
70
+
71
+
72
def collect_poster_paths(max_n: int = 10000) -> List[str]:
    """Collect verified poster PDF paths from the corpus.

    When ``CLASSIFICATION_JSON`` exists, the ``pdf_path`` of every entry under
    its ``"posters"`` key is passed through ``_fix_path`` and kept only if the
    resulting file exists. Otherwise every PDF under ``DOWNLOADS_DIR`` is used.

    Args:
        max_n: Maximum number of paths to return.

    Returns:
        Up to ``max_n`` paths, shuffled with the module-level ``random`` state.
    """
    if CLASSIFICATION_JSON.exists():
        # Load the classification results to get confirmed poster paths.
        logger.info(f"Loading classification results from {CLASSIFICATION_JSON}")
        with open(CLASSIFICATION_JSON, encoding="utf-8") as f:
            data = json.load(f)
        paths = []
        for entry in data.get("posters", []):
            raw = entry.get("pdf_path")
            if not raw:
                # Entry without a usable path: nothing to verify, skip it.
                continue
            fixed = _fix_path(raw)
            if Path(fixed).exists():
                paths.append(fixed)
        logger.info(f"  Found {len(paths)} verified poster paths")
    else:
        # Fallback: glob the downloads directory.
        logger.info(f"Globbing {DOWNLOADS_DIR} for PDFs...")
        # Match the suffix case-insensitively in a single walk. Separate
        # "*.pdf" and "*.PDF" globs return every file twice on
        # case-insensitive filesystems. Sorting makes the seeded shuffle
        # below independent of directory iteration order.
        paths = sorted(
            {str(p) for p in DOWNLOADS_DIR.rglob("*") if p.suffix.lower() == ".pdf"}
        )
        logger.info(f"  Found {len(paths)} PDFs")

    random.shuffle(paths)
    return paths[:max_n]
91
+
92
+
93
def collect_non_poster_paths(max_n: int = 2000) -> List[str]:
    """Collect verified non-poster PDF paths.

    The non-posters were separated into:
        poster-pdf-meta/separated_non_posters/downloads/{zenodo,figshare}/

    If that directory exists, every PDF below it is used. Otherwise the paths
    listed in ``NON_POSTERS_LIST`` are passed through ``_fix_path`` and kept
    when the file exists.

    Args:
        max_n: Maximum number of paths to return.

    Returns:
        Up to ``max_n`` paths, shuffled with the module-level ``random`` state.
    """
    paths: List[str] = []

    # Primary: glob the separated_non_posters directory.
    sep_dir = POSTER_SCIENCE_BASE / "poster-pdf-meta" / "separated_non_posters" / "downloads"
    if sep_dir.exists():
        # One case-insensitive walk avoids the duplicates that separate
        # "*.pdf" / "*.PDF" globs produce on case-insensitive filesystems.
        # Sorting keeps the seeded shuffle below reproducible.
        paths = sorted(
            {str(pdf) for pdf in sep_dir.rglob("*") if pdf.suffix.lower() == ".pdf"}
        )
        logger.info(f"  Found {len(paths)} non-poster PDFs in {sep_dir}")
    else:
        # Fallback: try the original list with path fixing.
        logger.info("  Separated dir not found, trying original list...")
        if NON_POSTERS_LIST.exists():
            with open(NON_POSTERS_LIST, encoding="utf-8") as f:
                for line in f:
                    p = _fix_path(line.strip())
                    if p and Path(p).exists():
                        paths.append(p)
        logger.info(f"  Found {len(paths)} verified non-poster paths from list")

    random.shuffle(paths)
    return paths[:max_n]
122
+
123
+
124
def extract_features_from_pdfs(
    pdf_paths: List[str],
    label: int,
    text_model,
    visual_ext,
    structural_ext,
    max_text_chars: int = 4000,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Extract multimodal features from a list of PDFs.

    For each PDF, the first page's text is whitespace-normalised and truncated
    to ``max_text_chars``. PDFs with no pages, with fewer than 20 characters
    of text, or that raise during processing are skipped.

    Args:
        pdf_paths: Paths of the PDFs to process.
        label: Class label assigned to every sample (1 is shown as "poster"
            in the progress bar, anything else as "non_poster").
        text_model: Embedding model exposing ``encode(texts, show_progress_bar=...)``.
        visual_ext: Extractor exposing ``pdf_to_image``, ``extract``,
            ``to_vector`` and ``FEATURE_NAMES``.
        structural_ext: Extractor exposing ``extract`` and ``to_vector``.
        max_text_chars: Maximum number of characters of page text to keep.

    Returns:
        ``(X, y, extracted_texts)`` where
            X: (N, D) float32 matrix holding the L2-normalised text embedding
               followed by the visual and structural feature columns
               (542 columns in the documented 512 + 15 + 15 layout)
            y: (N,) labels
            extracted_texts: list of extracted text strings (for PubGuard reuse)
        When nothing could be extracted, two empty arrays and an empty list.
    """
    if not pdf_paths:
        # Nothing to process; return before importing the heavy PDF stack.
        return np.array([]), np.array([]), []

    from tqdm import tqdm
    import fitz
    import re

    whitespace = re.compile(r"\s+")  # hoisted: reused for every document

    visual_vecs = []
    struct_vecs = []
    texts_out = []
    labels = []
    skipped = 0

    for pdf_path in tqdm(pdf_paths, desc="poster" if label == 1 else "non_poster"):
        try:
            # Extract text from the first page; the context manager closes
            # the document even if text extraction raises.
            with fitz.open(pdf_path) as doc:
                if len(doc) == 0:
                    skipped += 1
                    continue
                text = doc[0].get_text()
            text = whitespace.sub(" ", text).strip()[:max_text_chars]

            if len(text) < 20:
                # Too little text to embed meaningfully.
                skipped += 1
                continue

            # Visual features; fall back to an all-zero feature dict when the
            # page could not be rendered.
            img = visual_ext.pdf_to_image(pdf_path)
            if img is not None:
                vf = visual_ext.extract(img)
            else:
                vf = {n: 0.0 for n in visual_ext.FEATURE_NAMES}

            # Structural features
            sf = structural_ext.extract(pdf_path)

            # Append only after every extractor succeeded so the four lists
            # stay aligned row for row.
            visual_vec = visual_ext.to_vector(vf)
            struct_vec = structural_ext.to_vector(sf)
            texts_out.append(text)
            visual_vecs.append(visual_vec)
            struct_vecs.append(struct_vec)
            labels.append(label)

        except Exception as e:
            # Best effort: a single unreadable PDF must not abort the run.
            skipped += 1
            logger.debug(f"Skipping {pdf_path}: {e}")
            continue

    if skipped:
        logger.info(f"Skipped {skipped} of {len(pdf_paths)} PDFs (empty, too little text, or unreadable)")

    if not texts_out:
        return np.array([]), np.array([]), []

    # Embed all texts at once
    logger.info(f"Embedding {len(texts_out)} texts...")
    emb = text_model.encode(texts_out, show_progress_bar=True)
    # L2-normalise each embedding row; zero rows are left untouched.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)
    emb = (emb / norms).astype("float32")

    visual_arr = np.array(visual_vecs, dtype="float32")
    struct_arr = np.array(struct_vecs, dtype="float32")

    X = np.concatenate([emb, visual_arr, struct_arr], axis=1)
    y = np.array(labels)

    return X, y, texts_out
200
+
201
+
202
def main():
    """Train the PosterSentry classifier head end to end.

    Parses the CLI arguments, loads (or downloads and caches) the model2vec
    embedding model, collects poster and non-poster PDFs, extracts features,
    balances the classes, optionally exports the extracted texts as NDJSON,
    fits a logistic regression on standardised features, reports held-out
    metrics, saves the head together with the scaler statistics to
    ``poster_sentry_head.npz`` and finally runs a small smoke test.

    Raises:
        SystemExit: If either class produced no usable samples.
    """
    parser = argparse.ArgumentParser(description="Train PosterSentry")
    parser.add_argument("--n-per-class", type=int, default=5000,
                        help="Max samples per class (poster/non_poster)")
    parser.add_argument("--test-size", type=float, default=0.15)
    parser.add_argument("--models-dir", default=None)
    parser.add_argument("--export-texts", default=None,
                        help="Export extracted texts as NDJSON for PubGuard retraining")
    args = parser.parse_args()

    # Heavy third-party imports are deferred so that --help stays fast.
    from model2vec import StaticModel
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import classification_report
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    from poster_sentry.features import VisualFeatureExtractor, PDFStructuralExtractor

    # Models dir
    if args.models_dir:
        models_dir = Path(args.models_dir)
    else:
        models_dir = Path.home() / ".poster_sentry" / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    # Load embedding model (cached locally after the first download)
    logger.info("Loading model2vec...")
    emb_cache = models_dir / "poster-sentry-embedding"
    if emb_cache.exists():
        text_model = StaticModel.from_pretrained(str(emb_cache))
    else:
        text_model = StaticModel.from_pretrained("minishlab/potion-base-32M")
        emb_cache.parent.mkdir(parents=True, exist_ok=True)
        text_model.save_pretrained(str(emb_cache))

    visual_ext = VisualFeatureExtractor()
    structural_ext = PDFStructuralExtractor()

    # ── Collect data ─────────────────────────────────────────────
    logger.info("=" * 60)
    logger.info("Collecting training data...")
    logger.info("=" * 60)

    poster_paths = collect_poster_paths(max_n=args.n_per_class)
    non_poster_paths = collect_non_poster_paths(max_n=args.n_per_class)

    logger.info(f"Poster PDFs to process: {len(poster_paths)}")
    logger.info(f"Non-poster PDFs to process: {len(non_poster_paths)}")

    # ── Extract features ─────────────────────────────────────────
    logger.info("=" * 60)
    logger.info("Extracting features from poster PDFs...")
    logger.info("=" * 60)

    X_pos, y_pos, texts_pos = extract_features_from_pdfs(
        poster_paths, label=1, text_model=text_model,
        visual_ext=visual_ext, structural_ext=structural_ext,
    )

    logger.info(f"Poster features: {X_pos.shape}")

    logger.info("=" * 60)
    logger.info("Extracting features from non-poster PDFs...")
    logger.info("=" * 60)

    X_neg, y_neg, texts_neg = extract_features_from_pdfs(
        non_poster_paths, label=0, text_model=text_model,
        visual_ext=visual_ext, structural_ext=structural_ext,
    )

    logger.info(f"Non-poster features: {X_neg.shape}")

    # An empty class comes back as a 1-d empty array, which would make the
    # vstack below fail with an unhelpful shape error. Stop with a clear
    # message instead.
    if len(y_pos) == 0 or len(y_neg) == 0:
        logger.error(
            f"No usable samples for at least one class "
            f"(poster={len(y_pos)}, non_poster={len(y_neg)}); check the corpus paths."
        )
        raise SystemExit(1)

    # ── Balance classes ──────────────────────────────────────────
    min_count = min(len(y_pos), len(y_neg))
    logger.info(f"Balancing: {min_count} samples per class")

    if len(y_pos) > min_count:
        idx = np.random.choice(len(y_pos), min_count, replace=False)
        X_pos = X_pos[idx]
        y_pos = y_pos[idx]
        texts_pos = [texts_pos[i] for i in idx]

    if len(y_neg) > min_count:
        idx = np.random.choice(len(y_neg), min_count, replace=False)
        X_neg = X_neg[idx]
        y_neg = y_neg[idx]
        texts_neg = [texts_neg[i] for i in idx]

    X = np.vstack([X_pos, X_neg])
    y = np.concatenate([y_pos, y_neg])

    n_poster = int(y.sum())
    logger.info(f"Total training data: {X.shape} (poster={n_poster}, non_poster={len(y) - n_poster})")

    # ── Export texts for PubGuard ────────────────────────────────
    if args.export_texts:
        export_path = Path(args.export_texts)
        export_path.parent.mkdir(parents=True, exist_ok=True)
        with open(export_path, "w", encoding="utf-8") as f:
            for text in texts_pos:
                f.write(json.dumps({"text": text, "label": "poster"}) + "\n")
            for text in texts_neg:
                f.write(json.dumps({"text": text, "label": "non_poster"}) + "\n")
        logger.info(f"Exported {len(texts_pos) + len(texts_neg)} texts to {export_path}")

    # ── Feature layout ───────────────────────────────────────────
    # Columns are [text embedding | visual | structural]. The widths are
    # derived rather than hard-coded so a different embedding model still
    # works. Assumes each extractor's to_vector() yields one value per
    # FEATURE_NAMES entry (the feature-importance table below relies on the
    # same assumption).
    n_visual = len(VisualFeatureExtractor.FEATURE_NAMES)
    n_struct = len(PDFStructuralExtractor.FEATURE_NAMES)
    emb_dim = X.shape[1] - n_visual - n_struct

    # ── Split, then scale ────────────────────────────────────────
    # Critical: the text embedding drowns out the structural/visual features
    # if we don't scale. StandardScaler normalizes each column to zero mean
    # and unit variance, giving structural signals fair weight. The scaler is
    # fitted on the training split only, so the held-out metrics are not
    # contaminated by test-set statistics.
    logger.info("=" * 60)
    logger.info("Scaling features (StandardScaler)")
    logger.info("=" * 60)

    X_tr_raw, X_te_raw, y_tr, y_te = train_test_split(
        X, y, test_size=args.test_size, stratify=y, random_state=SEED,
    )

    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X_tr_raw)
    X_te = scaler.transform(X_te_raw)

    # Log feature variance to confirm structural features are alive
    emb_var = np.mean(np.var(X_tr[:, :emb_dim], axis=0))
    vis_var = np.mean(np.var(X_tr[:, emb_dim:emb_dim + n_visual], axis=0))
    str_var = np.mean(np.var(X_tr[:, emb_dim + n_visual:], axis=0))
    logger.info(f"  Mean variance — text: {emb_var:.3f}  visual: {vis_var:.3f}  structural: {str_var:.3f}")

    # ── Train ────────────────────────────────────────────────────
    logger.info("=" * 60)
    logger.info("Training PosterSentry classifier")
    logger.info("=" * 60)

    logger.info(f"Train: {X_tr.shape[0]:,} | Test: {X_te.shape[0]:,}")
    logger.info(
        f"Features: {X_tr.shape[1]} ({emb_dim} text + {n_visual} visual + {n_struct} structural)"
    )

    clf = LogisticRegression(
        C=1.0, max_iter=1000, class_weight="balanced",
        solver="lbfgs", n_jobs=1, random_state=SEED,
    )

    t0 = time.time()
    clf.fit(X_tr, y_tr)
    elapsed = time.time() - t0
    logger.info(f"Trained in {elapsed:.1f}s")

    y_pred = clf.predict(X_te)
    labels = ["non_poster", "poster"]
    report = classification_report(y_te, y_pred, target_names=labels, digits=4)
    logger.info(f"\n{report}")

    # Show top feature importances
    coef = clf.coef_[0]
    all_names = (
        [f"emb_{i}" for i in range(emb_dim)]
        + list(VisualFeatureExtractor.FEATURE_NAMES)
        + list(PDFStructuralExtractor.FEATURE_NAMES)
    )
    top_idx = np.argsort(np.abs(coef))[-15:][::-1]
    logger.info("Top 15 features by |coefficient|:")
    for idx in top_idx:
        logger.info(f"  {all_names[idx]:30s} coef={coef[idx]:+.4f}")

    # ── Save head as .npz ────────────────────────────────────────
    if clf.coef_.shape[0] == 1:
        # Binary model: expand the single coefficient row into one column per
        # class (negated for class 0) so W has shape (features, 2).
        W = np.vstack([-clf.coef_[0], clf.coef_[0]]).T.astype("float32")
        b = np.array([-clf.intercept_[0], clf.intercept_[0]], dtype="float32")
    else:
        W = clf.coef_.T.astype("float32")
        b = clf.intercept_.astype("float32")

    head_path = models_dir / "poster_sentry_head.npz"
    np.savez(
        head_path, W=W, b=b, labels=np.array(labels),
        scaler_mean=scaler.mean_.astype("float32"),
        scaler_scale=scaler.scale_.astype("float32"),
    )
    logger.info(f"Saved classifier head + scaler → {head_path}")

    # ── Smoke test ───────────────────────────────────────────────
    logger.info("\n" + "=" * 60)
    logger.info("SMOKE TEST")
    logger.info("=" * 60)

    from poster_sentry import PosterSentry

    sentry = PosterSentry(models_dir=models_dir)
    sentry.initialize()

    # Test with some real PDFs
    test_pdfs = poster_paths[:2] + non_poster_paths[:2]
    for p in test_pdfs:
        try:
            result = sentry.classify(p)
            icon = "📋" if result["is_poster"] else "📄"
            print(f"  {icon} {Path(p).name[:60]:60s} poster={result['is_poster']} conf={result['confidence']:.3f}")
        except Exception as e:
            # Best effort: report the failure and keep checking the others.
            print(f"  ⚠️ {Path(p).name[:60]}: {e}")

    logger.info(f"\nDone! Model saved to: {models_dir}")
399
+
400
+
401
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()