File size: 4,093 Bytes
4c8eee0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
#!/usr/bin/env python3
"""
Full training pipeline: download data → train heads → evaluate.
Usage:
    cd /home/joneill/pubverse_brett/pub_check
    source ~/myenv/bin/activate
    pip install -e ".[train]"
    python scripts/train_pubguard.py [--data-dir ./pubguard_data] [--models-dir DIR]
        [--n-per-class 15000] [--test-size 0.15] [--skip-download]
"""
import argparse
import logging
import sys
import os
# Root-logger configuration for the whole pipeline: INFO and above, one line
# per record with a timestamp, level and message.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
from pathlib import Path
from pubguard.config import PubGuardConfig
from pubguard.data import prepare_all
from pubguard.train import train_all
def main():
parser = argparse.ArgumentParser(description="Train PubGuard")
parser.add_argument("--data-dir", default="./pubguard_data",
help="Directory for training data")
parser.add_argument("--models-dir", default=None,
help="Override models output directory")
parser.add_argument("--n-per-class", type=int, default=15000,
help="Samples per class per head")
parser.add_argument("--test-size", type=float, default=0.15,
help="Held-out test fraction")
parser.add_argument("--skip-download", action="store_true",
help="Skip dataset download (use existing data)")
args = parser.parse_args()
data_dir = Path(args.data_dir)
config = PubGuardConfig()
if args.models_dir:
config.models_dir = Path(args.models_dir)
# Step 1: Download and prepare datasets
if not args.skip_download:
prepare_all(data_dir, n_per_class=args.n_per_class)
# Step 2: Train all heads
train_all(data_dir, config=config, test_size=args.test_size)
# Step 3: Quick smoke test
print("\n" + "=" * 60)
print("SMOKE TEST")
print("=" * 60)
from pubguard import PubGuard
guard = PubGuard(config=config)
guard.initialize()
test_cases = [
(
"Introduction: We present a novel deep learning approach for protein "
"structure prediction. Methods: We trained a transformer model on 50,000 "
"protein sequences from the PDB database. Results: Our model achieves "
"state-of-the-art accuracy with an RMSD of 1.2 Å on the CASP14 benchmark. "
"Discussion: These results demonstrate the potential of attention mechanisms "
"for structural biology. References: [1] AlphaFold (2021) [2] ESMFold (2022)",
"scientific_paper",
),
(
"🎉 POOL PARTY THIS SATURDAY! 🏊 Come join us at the community center "
"pool. Bring snacks and sunscreen. RSVP to poolparty@gmail.com by Thursday!",
"junk",
),
(
"TITLE: Deep Learning for Medical Imaging\nAUTHORS: J. Smith, A. Lee\n"
"AFFILIATION: MIT\n\nKey Findings:\n• 95% accuracy on chest X-rays\n"
"• Novel attention mechanism\n\nContact: jsmith@mit.edu",
"poster",
),
(
"We investigate the role of microRNAs in hepatocellular carcinoma "
"progression. Using RNA-seq data from 200 patient samples collected at "
"three clinical sites, we identified 15 differentially expressed miRNAs "
"associated with tumor stage (FDR < 0.01).",
"abstract_only",
),
]
for text, expected_type in test_cases:
verdict = guard.screen(text)
status = "✅" if verdict["doc_type"]["label"] == expected_type else "⚠️"
print(f" {status} Expected: {expected_type:20s} Got: {verdict['doc_type']['label']:20s} "
f"(score={verdict['doc_type']['score']:.3f})")
print(f" AI: {verdict['ai_generated']['label']} ({verdict['ai_generated']['score']:.3f}) "
f"Toxic: {verdict['toxicity']['label']} ({verdict['toxicity']['score']:.3f}) "
f"Pass: {verdict['pass']}")
print(f"\n✅ Training complete! Heads saved to: {config.models_dir}")
# Script entry point: run the pipeline and hand its result to the interpreter
# as the exit status (main returns None, which maps to status 0).
if __name__ == "__main__":
    sys.exit(main())
|