Upload scripts/train_pubguard.py with huggingface_hub
Browse files- scripts/train_pubguard.py +108 -0
scripts/train_pubguard.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Full training pipeline: download data → train heads → evaluate.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
cd /home/joneill/pubverse_brett/pub_check
|
| 7 |
+
source ~/myenv/bin/activate
|
| 8 |
+
pip install -e ".[train]"
|
| 9 |
+
python scripts/train_pubguard.py [--data-dir ./pubguard_data] [--n-per-class 15000]
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import logging
|
| 14 |
+
import sys
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
# Configure the root logger at import time. This call sits ahead of the
# pubguard imports in this script; NOTE(review): presumably so records
# emitted while those modules import already use this format -- confirm.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
|
| 22 |
+
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from pubguard.config import PubGuardConfig
|
| 25 |
+
from pubguard.data import prepare_all
|
| 26 |
+
from pubguard.train import train_all
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _parse_args(argv=None):
    """Parse and validate the command-line options.

    Args:
        argv: Sequence of argument strings, or None to read ``sys.argv[1:]``
            (argparse's default behaviour).

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: With status 2 (via ``parser.error``) when an option value
            is unusable, so the failure happens before any download/training.
    """
    parser = argparse.ArgumentParser(description="Train PubGuard")
    parser.add_argument("--data-dir", default="./pubguard_data",
                        help="Directory for training data")
    parser.add_argument("--models-dir", default=None,
                        help="Override models output directory")
    parser.add_argument("--n-per-class", type=int, default=15000,
                        help="Samples per class per head")
    parser.add_argument("--test-size", type=float, default=0.15,
                        help="Held-out test fraction")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip dataset download (use existing data)")
    args = parser.parse_args(argv)

    # --n-per-class is only consumed by the download/prepare step, so it is
    # only validated when that step will actually run.
    if not args.skip_download and args.n_per_class <= 0:
        parser.error("--n-per-class must be a positive integer")

    # A held-out *fraction* outside [0, 1) can never be honoured. The chained
    # comparison is also False for NaN, so NaN is rejected too.
    # NOTE(review): 0 is still accepted because train_all's handling of an
    # empty test split is not visible from this script -- confirm.
    if not 0.0 <= args.test_size < 1.0:
        parser.error("--test-size must be a fraction in the range [0, 1)")

    # With --skip-download nothing creates the data directory, so a missing
    # one means there is no existing data to train on.
    if args.skip_download and not Path(args.data_dir).is_dir():
        parser.error(
            f"--skip-download was given but data directory does not exist: {args.data_dir}"
        )

    return args


def _run_smoke_test(config):
    """Screen a few hand-written texts with the freshly trained heads.

    Builds a ``PubGuard`` from *config*, initializes it, and prints one report
    per sample: expected vs. predicted document type with its score, followed
    by the AI-generated, toxicity and pass fields of the verdict. A mismatch
    is only flagged with a warning marker; it does not raise or change the
    exit status.

    Args:
        config: The ``PubGuardConfig`` instance that was used for training.
    """
    print("\n" + "=" * 60)
    print("SMOKE TEST")
    print("=" * 60)

    # Imported here rather than at module top, as in the original script.
    from pubguard import PubGuard

    guard = PubGuard(config=config)
    guard.initialize()

    # (text, expected doc_type label) pairs.
    test_cases = [
        (
            "Introduction: We present a novel deep learning approach for protein "
            "structure prediction. Methods: We trained a transformer model on 50,000 "
            "protein sequences from the PDB database. Results: Our model achieves "
            "state-of-the-art accuracy with an RMSD of 1.2 Å on the CASP14 benchmark. "
            "Discussion: These results demonstrate the potential of attention mechanisms "
            "for structural biology. References: [1] AlphaFold (2021) [2] ESMFold (2022)",
            "scientific_paper",
        ),
        (
            "🎉 POOL PARTY THIS SATURDAY! 🏊 Come join us at the community center "
            "pool. Bring snacks and sunscreen. RSVP to poolparty@gmail.com by Thursday!",
            "junk",
        ),
        (
            "TITLE: Deep Learning for Medical Imaging\nAUTHORS: J. Smith, A. Lee\n"
            "AFFILIATION: MIT\n\nKey Findings:\n• 95% accuracy on chest X-rays\n"
            "• Novel attention mechanism\n\nContact: jsmith@mit.edu",
            "poster",
        ),
        (
            "We investigate the role of microRNAs in hepatocellular carcinoma "
            "progression. Using RNA-seq data from 200 patient samples collected at "
            "three clinical sites, we identified 15 differentially expressed miRNAs "
            "associated with tumor stage (FDR < 0.01).",
            "abstract_only",
        ),
    ]

    for text, expected_type in test_cases:
        verdict = guard.screen(text)
        status = "✅" if verdict["doc_type"]["label"] == expected_type else "⚠️"
        print(f" {status} Expected: {expected_type:20s} Got: {verdict['doc_type']['label']:20s} "
              f"(score={verdict['doc_type']['score']:.3f})")
        print(f" AI: {verdict['ai_generated']['label']} ({verdict['ai_generated']['score']:.3f}) "
              f"Toxic: {verdict['toxicity']['label']} ({verdict['toxicity']['score']:.3f}) "
              f"Pass: {verdict['pass']}")


def main(argv=None):
    """Run the full pipeline: prepare data, train all heads, smoke-test.

    Steps:
        1. Unless ``--skip-download`` is given, call ``prepare_all`` to
           download/prepare the datasets into the data directory.
        2. Call ``train_all`` on that directory with the chosen test fraction.
        3. Screen a handful of sample texts and print the verdicts.

    Args:
        argv: Optional list of command-line arguments. Defaults to None, in
            which case ``sys.argv[1:]`` is parsed, so the existing ``main()``
            call keeps working unchanged.

    Raises:
        SystemExit: With status 2 when the command-line options are invalid.
    """
    args = _parse_args(argv)

    data_dir = Path(args.data_dir)
    config = PubGuardConfig()
    if args.models_dir:
        config.models_dir = Path(args.models_dir)

    # Step 1: Download and prepare datasets
    if not args.skip_download:
        prepare_all(data_dir, n_per_class=args.n_per_class)

    # Step 2: Train all heads
    train_all(data_dir, config=config, test_size=args.test_size)

    # Step 3: Quick smoke test
    _run_smoke_test(config)

    print(f"\n✅ Training complete! Heads saved to: {config.models_dir}")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|