jimnoneill committed on
Commit
d37e06a
·
verified ·
1 Parent(s): 28142d9

Upload train_abstract_archon.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_abstract_archon.py +431 -0
train_abstract_archon.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Abstract Archon — binary classifier: "Is this text a real research abstract?"
4
+
5
+ Uses Potion-base-32M (512-dim) + LogisticRegression, distilled from SVM-RBF.
6
+ Applied as a quality gate to every publication in the database.
7
+
8
+ Usage:
9
+ python train_abstract_archon.py --export # Export training data from PG
10
+ python train_abstract_archon.py --train # Train and save model
11
+ python train_abstract_archon.py --validate # Validate on held-out data
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+ import time
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import psycopg2
23
+ import psycopg2.extras
24
+
25
# Keyword arguments handed verbatim to psycopg2.connect() by get_conn().
# NOTE(review): credentials are hard-coded in source — confirm this is a
# local/dev-only database before sharing the script.
DB_PARAMS = {
    'host': 'localhost',
    'port': 5434,
    'dbname': 'pubverse',
    'user': 'pubverse',
    'password': 'pubverse123',
}

# Every artifact (exported samples, trained head) lives in a folder that sits
# next to this script.
DATA_DIR = Path(__file__).parent.joinpath('abstract_archon_data')
EXPORT_PATH = DATA_DIR.joinpath('training_export.ndjson')
MODEL_PATH = DATA_DIR.joinpath('abstract_archon_head.npz')

# Number of label-1 samples (real abstracts) to export.
N_POSITIVES = 2000

# Number of label-0 samples to export per cleanup reason. Keys are matched
# against the `reason` column of `_quality_cleanup_ids`; dict order is the
# export order.
N_NEGATIVES_PER_REASON = {
    'html_heavy': 250,
    'html_heavy_text': 250,
    'supplementary_content': 250,
    'author_byline': 200,
    'figure_table_caption': 250,
    'journal_article_scrape': 250,
    'moesm_title': 200,
    'taxonomy_stub': 200,
}

# Number of additional short (20-100 character) label-0 samples to export.
N_BORDERLINE_SHORT = 150
44
+
45
+
46
def get_conn():
    """Open and return a new psycopg2 connection built from ``DB_PARAMS``.

    The caller owns the returned connection and is responsible for closing it.
    """
    connection = psycopg2.connect(**DB_PARAMS)
    return connection
48
+
49
+
50
def _collect_negatives(conn, query, candidate_ids, limit, source, records,
                       seen_ids, batch_size=200):
    """Append up to ``limit`` label-0 records for the given candidate ids.

    Args:
        conn: Open database connection used to create cursors.
        query: SQL returning ``(source_id, text)`` rows; it must contain a
            single ``%s`` placeholder that receives the current id batch.
        candidate_ids: Source ids to look up, processed ``batch_size`` at a time.
        limit: Maximum number of records to append.
        source: Value stored under the ``'source'`` key of each record.
        records: List the new records are appended to (mutated in place).
        seen_ids: Source ids already exported; matching rows are skipped and
            newly exported ids are added (mutated in place).
        batch_size: Number of ids sent per query.

    Returns:
        The number of records appended.
    """
    collected = 0
    for start in range(0, len(candidate_ids), batch_size):
        if collected >= limit:
            break
        batch = candidate_ids[start:start + batch_size]
        with conn.cursor() as cur:
            cur.execute(query, (batch,))
            rows = cur.fetchall()
        for source_id, text in rows:
            if collected >= limit:
                break
            # The same id can be offered by several queries (or several rows);
            # exporting it twice would let one text land in both the train and
            # the test split.
            if source_id in seen_ids:
                continue
            if text and len(text.strip()) > 5:
                seen_ids.add(source_id)
                records.append({
                    'text': text.strip()[:500],
                    'label': 0,
                    'source': source,
                    'source_id': source_id,
                })
                collected += 1
    return collected


def export_data():
    """Export labelled training samples from PostgreSQL to ``EXPORT_PATH``.

    Three groups are written as NDJSON, one record per line with the keys
    ``text``, ``label``, ``source`` and ``source_id``:

    * positives (label 1): randomly sampled publications whose abstract has at
      least 200 characters and whose ``source_id`` is not listed in
      ``_quality_cleanup_ids``;
    * negatives (label 0): rows listed in ``_quality_cleanup_ids`` for each
      reason in ``N_NEGATIVES_PER_REASON``;
    * borderline negatives (label 0): cleanup rows whose abstract is 20-100
      characters long.

    Every ``source_id`` is exported at most once. The database connection is
    closed even when a query fails.
    """
    DATA_DIR.mkdir(exist_ok=True)
    records = []
    seen_ids = set()
    pos_count = 0
    total_neg = 0

    conn = get_conn()
    try:
        # --- Load cleanup IDs into a set for fast lookup ---
        print("Loading cleanup source_ids for exclusion filter...")
        cleanup_ids = set()
        with conn.cursor(name='cleanup_scan') as cur:
            cur.itersize = 500000
            cur.execute("SELECT source_id FROM _quality_cleanup_ids")
            for row in cur:
                cleanup_ids.add(row[0])
        print(f" Loaded {len(cleanup_ids):,} cleanup IDs")

        # --- Positive examples: TABLESAMPLE then filter in Python ---
        print(f"Exporting {N_POSITIVES} positive examples (real abstracts)...")
        with conn.cursor(name='pos_scan') as cur:
            cur.itersize = 10000
            cur.execute("""
                SELECT p.source_id, LEFT(p.abstract, 500) as text
                FROM publications p TABLESAMPLE BERNOULLI(0.005)
                WHERE LENGTH(p.abstract) >= 200
            """)
            for source_id, text in cur:
                if pos_count >= N_POSITIVES:
                    break
                if source_id in cleanup_ids or source_id in seen_ids:
                    continue
                if text and len(text.strip()) >= 50:
                    seen_ids.add(source_id)
                    records.append({
                        'text': text.strip()[:500],
                        'label': 1,
                        'source': 'positive_real_abstract',
                        'source_id': source_id,
                    })
                    pos_count += 1
                    if pos_count % 500 == 0:
                        print(f" {pos_count} positives collected...")
        print(f" Got {pos_count} positives")

        # --- Negative examples: known garbage by reason ---
        # Sample ids from the cleanup table first, then look up their text.
        reason_text_query = """
            SELECT source_id, LEFT(abstract, 500) as text
            FROM publications
            WHERE source_id = ANY(%s)
            AND LENGTH(abstract) > 10
        """
        for reason, n in N_NEGATIVES_PER_REASON.items():
            print(f"Exporting {n} negatives for reason={reason}...")
            with conn.cursor() as cur:
                # Over-sample (3x) because some candidates fail the text filters.
                cur.execute("""
                    SELECT q.source_id
                    FROM _quality_cleanup_ids q
                    WHERE q.reason = %s
                    ORDER BY RANDOM()
                    LIMIT %s
                """, (reason, n * 3))
                candidate_ids = [row[0] for row in cur.fetchall()]

            collected = _collect_negatives(
                conn, reason_text_query, candidate_ids, n,
                f'negative_{reason}', records, seen_ids)
            total_neg += collected
            print(f" Got {collected} for {reason}, running total: {total_neg}")

        # --- Borderline negatives: very short garbage texts ---
        print(f"Exporting {N_BORDERLINE_SHORT} borderline short negatives...")
        with conn.cursor() as cur:
            # Over-sample (5x): only abstracts of 20-100 characters are kept.
            cur.execute("""
                SELECT q.source_id
                FROM _quality_cleanup_ids q
                WHERE q.reason NOT IN ('short_abstract', 'empty_abstract', 'non_english')
                ORDER BY RANDOM()
                LIMIT %s
            """, (N_BORDERLINE_SHORT * 5,))
            candidate_ids = [row[0] for row in cur.fetchall()]

        short_text_query = """
            SELECT source_id, LEFT(abstract, 500) as text
            FROM publications
            WHERE source_id = ANY(%s)
            AND LENGTH(abstract) BETWEEN 20 AND 100
        """
        collected = _collect_negatives(
            conn, short_text_query, candidate_ids, N_BORDERLINE_SHORT,
            'negative_borderline_short', records, seen_ids)
        total_neg += collected
        print(f" Got {collected} borderline, total negatives: {total_neg}")
    finally:
        conn.close()

    print(f"\nTotal: {pos_count} positives, "
          f"{total_neg} negatives")

    with open(EXPORT_PATH, 'w', encoding='utf-8') as f:
        for r in records:
            f.write(json.dumps(r) + '\n')
    print(f"Saved to {EXPORT_PATH}")
184
+
185
+
186
def _select_recall_threshold(proba, y_true, targets=(0.995, 0.99), fallback=0.1):
    """Pick the highest decision threshold that keeps positive recall on target.

    A sample is predicted positive when ``proba >= threshold``, so recall can
    only fall as the threshold rises. Scanning the grid from high to low and
    stopping at the first hit therefore yields the strictest threshold that
    still meets the recall target.

    Args:
        proba: Predicted probability of the positive class, one per sample.
        y_true: Ground-truth labels (1 = positive).
        targets: Recall targets tried in order; the first achievable one wins.
        fallback: Threshold returned when no target is achievable on the grid.

    Returns:
        ``(threshold, recall, tier)`` where ``tier`` is the index into
        ``targets`` that was met, or ``None`` when the fallback was used.
        ``recall`` is NaN when ``y_true`` contains no positives.
    """
    positive_scores = np.asarray(proba)[np.asarray(y_true) == 1]
    if positive_scores.size == 0:
        return fallback, float('nan'), None

    grid = np.arange(0.01, 0.99, 0.001)
    for tier, target in enumerate(targets):
        for t in grid[::-1]:
            recall = float(np.mean(positive_scores >= t))
            if recall >= target:
                return float(t), recall, tier
    return fallback, float(np.mean(positive_scores >= fallback)), None


def train_model():
    """Train the classifier head on the exported data and save it as ``.npz``.

    Steps: load ``EXPORT_PATH``, embed the texts with Potion-base-32M, fit an
    SVM-RBF teacher, fit LogisticRegression both directly (over a grid of C)
    and on the teacher's hard labels, keep whichever LR has the higher test
    ROC-AUC, choose a decision threshold, and write coefficients, scaler
    statistics and threshold to ``MODEL_PATH``.

    NOTE(review): C, the LR variant and the threshold are all selected on the
    same 20% test split that is used for reporting, so the printed metrics are
    optimistic.
    """
    from model2vec import StaticModel
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import (classification_report, confusion_matrix,
                                 roc_auc_score)
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import SVC

    print("Loading training data...")
    with open(EXPORT_PATH, encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    texts = [r['text'] for r in records]
    labels = np.array([r['label'] for r in records])
    sources = [r['source'] for r in records]
    print(f" {len(texts)} samples: {labels.sum()} positive, {(1-labels).sum()} negative")

    print("Embedding with Potion-base-32M...")
    model = StaticModel.from_pretrained('minishlab/potion-base-32M')
    embeddings = model.encode(texts, show_progress_bar=True)
    print(f" Embeddings shape: {embeddings.shape}")

    X_train, X_test, y_train, y_test, src_train, src_test, txt_train, txt_test = \
        train_test_split(embeddings, labels, sources, texts,
                         test_size=0.2, random_state=42, stratify=labels)

    # Scaler statistics come from the training split only; they are saved with
    # the head so inference can apply the same standardisation.
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_test_s = scaler.transform(X_test)

    # --- Train SVM-RBF teacher ---
    print("\nTraining SVM-RBF teacher...")
    svm = SVC(kernel='rbf', probability=True, C=10.0, gamma='scale',
              class_weight='balanced', random_state=42)
    svm.fit(X_train_s, y_train)
    svm_pred = svm.predict(X_test_s)
    svm_proba = svm.predict_proba(X_test_s)[:, 1]
    print("\n=== SVM-RBF Results ===")
    print(classification_report(y_test, svm_pred, target_names=['garbage', 'abstract']))
    print("Confusion matrix:\n", confusion_matrix(y_test, svm_pred))
    print(f"ROC-AUC: {roc_auc_score(y_test, svm_proba):.4f}")

    fn_rate = ((svm_pred == 0) & (y_test == 1)).sum() / (y_test == 1).sum()
    print(f"False negative rate on real abstracts: {fn_rate:.4f}")

    # Show misclassified examples
    fn_mask = (svm_pred == 0) & (y_test == 1)
    fp_mask = (svm_pred == 1) & (y_test == 0)
    print(f"\n--- False Negatives (real abstracts called garbage): {fn_mask.sum()} ---")
    fn_indices = np.where(fn_mask)[0]
    for idx in fn_indices[:10]:
        print(f" [{svm_proba[idx]:.3f}] {txt_test[idx][:120]}")
    print(f"\n--- False Positives (garbage called abstract): {fp_mask.sum()} ---")
    fp_indices = np.where(fp_mask)[0]
    for idx in fp_indices[:10]:
        print(f" [{svm_proba[idx]:.3f}] [{src_test[idx]}] {txt_test[idx][:120]}")

    # --- Train LR directly (often better than distillation for small datasets) ---
    print("\n\nTraining LogisticRegression directly...")
    best_lr = None
    best_auc = 0
    for C in [0.01, 0.1, 1.0, 10.0, 100.0]:
        lr = LogisticRegression(max_iter=5000, C=C, solver='lbfgs',
                                class_weight='balanced', random_state=42)
        lr.fit(X_train_s, y_train)
        lr_proba = lr.predict_proba(X_test_s)[:, 1]
        auc = roc_auc_score(y_test, lr_proba)
        lr_pred = lr.predict(X_test_s)
        fn = ((lr_pred == 0) & (y_test == 1)).sum() / (y_test == 1).sum()
        print(f" C={C:6.2f} → AUC={auc:.4f}, FNR={fn:.4f}")
        if auc > best_auc:
            best_auc = auc
            best_lr = lr

    lr = best_lr
    lr_pred = lr.predict(X_test_s)
    lr_proba = lr.predict_proba(X_test_s)[:, 1]
    print(f"\n=== Best Direct LR Results (C={lr.C}) ===")
    print(classification_report(y_test, lr_pred, target_names=['garbage', 'abstract']))
    print("Confusion matrix:\n", confusion_matrix(y_test, lr_pred))
    print(f"ROC-AUC: {roc_auc_score(y_test, lr_proba):.4f}")

    fn_rate_lr = ((lr_pred == 0) & (y_test == 1)).sum() / (y_test == 1).sum()
    print(f"LR False negative rate: {fn_rate_lr:.4f}")

    # --- Also try SVM distillation for comparison ---
    # The student is fit on the teacher's hard (thresholded) training labels.
    print("\nDistilling SVM → LR...")
    svm_soft = svm.predict_proba(X_train_s)[:, 1]
    lr_distilled = LogisticRegression(max_iter=5000, C=1.0, random_state=42)
    lr_distilled.fit(X_train_s, (svm_soft > 0.5).astype(int))
    dist_proba = lr_distilled.predict_proba(X_test_s)[:, 1]
    dist_auc = roc_auc_score(y_test, dist_proba)
    print(f" Distilled LR AUC: {dist_auc:.4f}")

    # Pick the best LR variant
    if dist_auc > best_auc:
        print(" → Distilled LR wins, using that")
        lr = lr_distilled
        lr_proba = dist_proba
    else:
        print(f" → Direct LR wins (AUC {best_auc:.4f} vs {dist_auc:.4f})")

    # Find the strictest threshold that still keeps ~99.5% recall on real
    # abstracts (relaxing to 99%, then to a fixed 0.1 fallback).
    best_t, recall, tier = _select_recall_threshold(lr_proba, y_test)
    pred_t = (lr_proba >= best_t).astype(int)
    if tier == 0:
        precision_garbage = ((pred_t == 0) & (y_test == 0)).sum() / max((pred_t == 0).sum(), 1)
        print(f"\nAt threshold {best_t:.3f}: recall={recall:.4f}, garbage_precision={precision_garbage:.4f}")
    elif tier is not None:
        print(f"\nRelaxed: threshold {best_t:.3f} gives recall={recall:.4f}")
    else:
        print(f"\nFallback: threshold {best_t:.3f} gives recall={recall:.4f}")

    # Save model
    np.savez(MODEL_PATH,
             coef=lr.coef_,
             intercept=lr.intercept_,
             classes=lr.classes_,
             labels=np.array(['garbage', 'abstract']),
             scaler_mean=scaler.mean_,
             scaler_scale=scaler.scale_,
             embed_model='minishlab/potion-base-32M',
             version='v1',
             threshold=np.array([best_t]))
    print(f"\nSaved model to {MODEL_PATH}")
    print(f"Model size: {MODEL_PATH.stat().st_size / 1024:.1f} KB")
331
+
332
+
333
def _score_abstracts(embeddings, coef, intercept, scaler_mean, scaler_scale):
    """Return the predicted probability of class "abstract" for each row.

    Applies the saved standardisation, then the saved linear head followed by
    a sigmoid. With a single coefficient row the only output column is used;
    otherwise column 1 is used.
    """
    from scipy.special import expit

    standardized = (np.asarray(embeddings) - scaler_mean) / scaler_scale
    probabilities = expit(standardized @ coef.T + intercept)
    column = 0 if coef.shape[0] == 1 else 1
    return probabilities[:, column]


def _fetch_rows(query):
    """Run ``query`` on a fresh connection and return all rows.

    The connection is closed even when the query raises.
    """
    conn = get_conn()
    try:
        with conn.cursor() as cur:
            cur.execute(query)
            return cur.fetchall()
    finally:
        conn.close()


def validate():
    """Score a random sample of publications with the saved model head.

    Loads the head from ``MODEL_PATH``, samples publications that were not
    part of the training export, prints prediction counts, borderline cases
    (probability between 0.3 and 0.7) and sample "garbage" predictions, then
    scores one fixed publication (PMID 39869795) as a sanity check.
    """
    from model2vec import StaticModel

    print("Loading model...")
    # Every array written by train_model() is numeric or unicode, so pickle
    # support is not needed to read the file back.
    with np.load(MODEL_PATH, allow_pickle=False) as data:
        coef = data['coef']
        intercept = data['intercept']
        scaler_mean = data['scaler_mean']
        scaler_scale = data['scaler_scale']
        threshold = float(data['threshold'][0])
    print(f" Threshold: {threshold:.3f}")

    model = StaticModel.from_pretrained('minishlab/potion-base-32M')

    # Load training source_ids to exclude
    training_ids = set()
    with open(EXPORT_PATH, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                training_ids.add(json.loads(line)['source_id'])

    print("Sampling 500 random publications for validation...")
    rows = _fetch_rows("""
        SELECT source_id, LEFT(abstract, 500) as text
        FROM publications TABLESAMPLE BERNOULLI(0.001)
        WHERE LENGTH(abstract) > 10
        LIMIT 1000
    """)

    # Filter out training data
    val_data = [(sid, t) for sid, t in rows if sid not in training_ids][:500]
    texts = [t for _, t in val_data]

    if texts:
        probas = _score_abstracts(model.encode(texts), coef, intercept,
                                  scaler_mean, scaler_scale)
        preds = (probas >= threshold).astype(int)

        print(f"\nResults on {len(texts)} validation samples:")
        print(f" Predicted abstract: {preds.sum()}")
        print(f" Predicted garbage: {(1-preds).sum()}")

        # Show borderline cases
        borderline = [(p, t[:120]) for p, t in zip(probas, texts)
                      if 0.3 <= p <= 0.7]
        if borderline:
            print(f"\n Borderline cases ({len(borderline)}):")
            for p, t in borderline[:10]:
                print(f" [{p:.3f}] {t}")

        # Show confident garbage
        garbage_idx = np.where(preds == 0)[0]
        if len(garbage_idx) > 0:
            print(f"\n Sample 'garbage' predictions:")
            for idx in garbage_idx[:10]:
                print(f" [{probas[idx]:.3f}] {texts[idx][:150]}")
    else:
        # The random sample can come back empty; there is nothing to embed.
        print("\nNo validation samples returned; skipping batch report")

    # Sanity check PMID 39869795
    print("\n Sanity check: PMID 39869795...")
    sanity_rows = _fetch_rows(
        "SELECT LEFT(abstract, 500) FROM publications WHERE source_id LIKE '%39869795%' LIMIT 1")
    if not sanity_rows:
        print(" PMID not found in database")
    elif not sanity_rows[0][0]:
        # LEFT(NULL, 500) is NULL; an empty/NULL abstract cannot be embedded.
        print(" PMID found but its abstract is empty")
    else:
        p = _score_abstracts(model.encode([sanity_rows[0][0]]), coef, intercept,
                             scaler_mean, scaler_scale)[0]
        print(f" Probability(abstract): {p:.4f} → {'PASS' if p >= threshold else 'FAIL'}")
415
+
416
+
417
if __name__ == '__main__':
    # Command-line entry point: exactly one action runs per invocation.
    parser = argparse.ArgumentParser(description='Abstract Archon trainer')
    for flag, help_text in (
        ('--export', 'Export training data from PG'),
        ('--train', 'Train model'),
        ('--validate', 'Validate on held-out data'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)
    args = parser.parse_args()

    # When several flags are given, the first one in this order wins; with no
    # flag at all the usage text is printed.
    dispatch = (
        (args.export, export_data),
        (args.train, train_model),
        (args.validate, validate),
    )
    action = next((handler for selected, handler in dispatch if selected),
                  parser.print_help)
    action()