jimnoneill committed on
Commit
4c8eee0
·
verified ·
1 Parent(s): f3101a4

Upload scripts/train_pubguard.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_pubguard.py +108 -0
scripts/train_pubguard.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Full training pipeline: download data → train heads → evaluate.

Usage:
    cd /home/joneill/pubverse_brett/pub_check
    source ~/myenv/bin/activate
    pip install -e ".[train]"
    python scripts/train_pubguard.py [--data-dir ./pubguard_data] [--n-per-class 15000]
"""

import argparse
import logging
import sys
import os
# NOTE(review): `sys` and `os` are not referenced anywhere in this script's
# visible code — confirm they are needed before removing.

# Root-logger configuration happens *before* the pubguard imports below.
# NOTE(review): presumably intentional so that any import-time log records
# from pubguard modules already use this format — confirm before reordering
# these imports to the top of the file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

from pathlib import Path  # noqa: E402 (deliberately after logging setup)
from pubguard.config import PubGuardConfig  # noqa: E402
from pubguard.data import prepare_all  # noqa: E402
from pubguard.train import train_all  # noqa: E402
28
+
29
def main():
    """Run the full PubGuard pipeline: parse CLI options, prepare the
    datasets, train every classifier head, then smoke-test the result on a
    handful of hand-written documents.
    """
    ap = argparse.ArgumentParser(description="Train PubGuard")
    ap.add_argument("--data-dir", default="./pubguard_data",
                    help="Directory for training data")
    ap.add_argument("--models-dir", default=None,
                    help="Override models output directory")
    ap.add_argument("--n-per-class", type=int, default=15000,
                    help="Samples per class per head")
    ap.add_argument("--test-size", type=float, default=0.15,
                    help="Held-out test fraction")
    ap.add_argument("--skip-download", action="store_true",
                    help="Skip dataset download (use existing data)")
    opts = ap.parse_args()

    dataset_dir = Path(opts.data_dir)
    cfg = PubGuardConfig()
    if opts.models_dir:
        cfg.models_dir = Path(opts.models_dir)

    # Step 1: Download and prepare datasets (skippable when data already exists).
    if not opts.skip_download:
        prepare_all(dataset_dir, n_per_class=opts.n_per_class)

    # Step 2: Train all classifier heads on the prepared data.
    train_all(dataset_dir, config=cfg, test_size=opts.test_size)

    # Step 3: Quick smoke test against a few known document types.
    banner = "=" * 60
    print("\n" + banner)
    print("SMOKE TEST")
    print(banner)

    # Imported here rather than at module level; presumably deferred so that
    # training can start without loading inference machinery — confirm.
    from pubguard import PubGuard

    guard = PubGuard(config=cfg)
    guard.initialize()

    # Each entry: (document text, expected doc_type label).
    samples = [
        (
            "Introduction: We present a novel deep learning approach for protein "
            "structure prediction. Methods: We trained a transformer model on 50,000 "
            "protein sequences from the PDB database. Results: Our model achieves "
            "state-of-the-art accuracy with an RMSD of 1.2 Å on the CASP14 benchmark. "
            "Discussion: These results demonstrate the potential of attention mechanisms "
            "for structural biology. References: [1] AlphaFold (2021) [2] ESMFold (2022)",
            "scientific_paper",
        ),
        (
            "🎉 POOL PARTY THIS SATURDAY! 🏊 Come join us at the community center "
            "pool. Bring snacks and sunscreen. RSVP to poolparty@gmail.com by Thursday!",
            "junk",
        ),
        (
            "TITLE: Deep Learning for Medical Imaging\nAUTHORS: J. Smith, A. Lee\n"
            "AFFILIATION: MIT\n\nKey Findings:\n• 95% accuracy on chest X-rays\n"
            "• Novel attention mechanism\n\nContact: jsmith@mit.edu",
            "poster",
        ),
        (
            "We investigate the role of microRNAs in hepatocellular carcinoma "
            "progression. Using RNA-seq data from 200 patient samples collected at "
            "three clinical sites, we identified 15 differentially expressed miRNAs "
            "associated with tumor stage (FDR < 0.01).",
            "abstract_only",
        ),
    ]

    for doc, want in samples:
        verdict = guard.screen(doc)
        mark = "✅" if verdict["doc_type"]["label"] == want else "⚠️"
        print(f" {mark} Expected: {want:20s} Got: {verdict['doc_type']['label']:20s} "
              f"(score={verdict['doc_type']['score']:.3f})")
        print(f" AI: {verdict['ai_generated']['label']} ({verdict['ai_generated']['score']:.3f}) "
              f"Toxic: {verdict['toxicity']['label']} ({verdict['toxicity']['score']:.3f}) "
              f"Pass: {verdict['pass']}")

    print(f"\n✅ Training complete! Heads saved to: {cfg.models_dir}")
105
+
106
+
107
# Script entry point: run the full train-and-evaluate pipeline.
if __name__ == "__main__":
    main()