jimnoneill committed on
Commit
cd3adb9
·
verified ·
1 Parent(s): 0b39aef

Upload src/pubguard/cli.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pubguard/cli.py +197 -0
src/pubguard/cli.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Command-line interface for PubGuard.
3
+
4
+ Usage:
5
+ # Download datasets and train
6
+ pubguard train --data-dir ./data
7
+
8
+ # Download datasets only
9
+ pubguard prepare --data-dir ./data
10
+
11
+ # Screen a text file
12
+ pubguard screen input.txt
13
+
14
+ # Screen extracted PDF text from stdin
15
+ cat extracted_text.txt | pubguard screen -
16
+
17
+ # Batch screen NDJSON
18
+ pubguard batch input.ndjson output.ndjson
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import logging
24
+ import sys
25
+ import time
26
+ from pathlib import Path
27
+
28
+ from .classifier import PubGuard
29
+ from .config import PubGuardConfig
30
+
31
+
32
def cmd_prepare(args):
    """Handle the ``prepare`` subcommand: download and prepare the datasets.

    Args:
        args: Parsed CLI namespace; ``data_dir`` and ``n_per_class`` are read.
    """
    # The data module is imported at call time rather than at module import.
    from .data import prepare_all

    target_dir = Path(args.data_dir)
    prepare_all(target_dir, n_per_class=args.n_per_class)
37
+
38
+
39
def cmd_train(args):
    """Handle the ``train`` subcommand: optionally fetch data, then train.

    Args:
        args: Parsed CLI namespace; ``data_dir``, ``download``,
            ``n_per_class``, ``models_dir`` and ``test_size`` are read.
    """
    from .data import prepare_all
    from .train import train_all

    dataset_root = Path(args.data_dir)

    # ``--download`` defaults to on; ``--no-download`` skips this step.
    if args.download:
        prepare_all(dataset_root, n_per_class=args.n_per_class)

    cfg = PubGuardConfig()
    # Only override the configured models directory when one was given.
    if args.models_dir:
        cfg.models_dir = Path(args.models_dir)

    train_all(dataset_root, config=cfg, test_size=args.test_size)
54
+
55
+
56
def cmd_screen(args):
    """Handle the ``screen`` subcommand: screen one document and print the verdict.

    The document is read from standard input when ``args.input`` is ``"-"``,
    otherwise from the file at that path. The verdict is printed either as
    indented JSON (``args.json``) or in the human-readable layout produced by
    ``_print_verdict``.

    Args:
        args: Parsed CLI namespace; ``models_dir``, ``input`` and ``json``
            are read.
    """
    config = PubGuardConfig()
    if args.models_dir:
        config.models_dir = Path(args.models_dir)

    guard = PubGuard(config=config)
    guard.initialize()

    if args.input == "-":
        text = sys.stdin.read()
    else:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding; undecodable bytes are still replaced
        # rather than raising.
        text = Path(args.input).read_text(encoding="utf-8", errors="replace")

    verdict = guard.screen(text)

    if args.json:
        print(json.dumps(verdict, indent=2))
    else:
        _print_verdict(verdict)
76
+
77
+
78
def _screen_and_write(guard, records, texts, fout):
    """Screen one batch and append the annotated records to *fout*.

    Each record receives its verdict under the ``"pubguard"`` key and is
    written as a single JSON line.

    Returns:
        The number of texts in the batch.
    """
    verdicts = guard.screen_batch(texts)
    for rec, verd in zip(records, verdicts):
        rec["pubguard"] = verd
        fout.write(json.dumps(rec) + "\n")
    return len(texts)


def cmd_batch(args):
    """Handle the ``batch`` subcommand: screen every record of an NDJSON file.

    Each non-blank input line is parsed as a JSON object. Its ``"text"``
    field (falling back to ``"abstract"``, then to an empty string) is
    screened, and the record is written to the output file with the verdict
    added under ``"pubguard"``. Records are processed in groups of
    ``config.batch_size``. A throughput summary is printed at the end.

    Args:
        args: Parsed CLI namespace; ``models_dir``, ``input`` and ``output``
            are read.

    Raises:
        SystemExit: If ``input`` and ``output`` resolve to the same path.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    # Opening the output in "w" mode truncates it immediately; if it is also
    # the input, every record would be lost before being read.
    if Path(args.input).resolve() == Path(args.output).resolve():
        sys.exit(f"error: input and output are the same file: {args.input}")

    config = PubGuardConfig()
    if args.models_dir:
        config.models_dir = Path(args.models_dir)

    guard = PubGuard(config=config)
    guard.initialize()

    # perf_counter is monotonic, so the reported rate is unaffected by
    # wall-clock adjustments during a long run.
    start = time.perf_counter()
    processed = 0

    # Explicit UTF-8 keeps reading and writing independent of the locale.
    with open(args.input, encoding="utf-8") as fin, \
            open(args.output, "w", encoding="utf-8") as fout:
        batch_texts = []
        batch_records = []

        for line in fin:
            if not line.strip():
                continue
            record = json.loads(line)
            text = record.get("text") or record.get("abstract") or ""
            batch_texts.append(text)
            batch_records.append(record)

            if len(batch_texts) >= config.batch_size:
                processed += _screen_and_write(
                    guard, batch_records, batch_texts, fout
                )
                batch_texts, batch_records = [], []

        # Final, possibly short, batch.
        if batch_texts:
            processed += _screen_and_write(
                guard, batch_records, batch_texts, fout
            )

    elapsed = time.perf_counter() - start
    rate = processed / elapsed if elapsed > 0 else 0
    print(f"Screened {processed:,} records in {elapsed:.1f}s ({rate:,.0f} rec/s)")
    print(f"Output: {args.output}")
122
+
123
+
124
def _print_verdict(v: dict) -> None:
    """Pretty-print a verdict to standard output.

    Args:
        v: Verdict mapping. The keys read are ``"pass"`` (truthy means
            PASS) and ``"doc_type"``, ``"ai_generated"`` and ``"toxicity"``,
            each a mapping with a string ``"label"`` and a numeric
            ``"score"``.

    Raises:
        KeyError: If any of the keys above is missing.
    """
    passed = v["pass"]
    pass_icon = "✅" if passed else "❌"
    print(f"\n{pass_icon} PubGuard Verdict: {'PASS' if passed else 'FAIL'}")
    # Labels are padded to 20 characters; scores are shown to 3 decimals.
    print(f" Document type: {v['doc_type']['label']:20s} (score: {v['doc_type']['score']:.3f})")
    print(f" AI detection: {v['ai_generated']['label']:20s} (score: {v['ai_generated']['score']:.3f})")
    print(f" Toxicity: {v['toxicity']['label']:20s} (score: {v['toxicity']['score']:.3f})")
    print()
132
+
133
+
134
def _build_parser():
    """Build the top-level argument parser and its four subcommands."""
    parser = argparse.ArgumentParser(
        description="PubGuard — Scientific Publication Gatekeeper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--models-dir", type=str, default=None,
        help="Override models directory",
    )

    subparsers = parser.add_subparsers(dest="command")

    # prepare
    p_prepare = subparsers.add_parser("prepare", help="Download and prepare datasets")
    p_prepare.add_argument("--data-dir", default="./pubguard_data")
    p_prepare.add_argument("--n-per-class", type=int, default=15000)

    # train
    p_train = subparsers.add_parser("train", help="Train classification heads")
    p_train.add_argument("--data-dir", default="./pubguard_data")
    # argparse copies every attribute of the subparser's namespace over the
    # parent's, so a plain default of None here would erase a --models-dir
    # given before the subcommand. SUPPRESS leaves the attribute unset unless
    # the option is actually passed after "train".
    p_train.add_argument("--models-dir", default=argparse.SUPPRESS)
    p_train.add_argument("--download", action="store_true", default=True,
                         help="Download datasets before training")
    p_train.add_argument("--no-download", action="store_false", dest="download")
    p_train.add_argument("--n-per-class", type=int, default=15000)
    p_train.add_argument("--test-size", type=float, default=0.15)

    # screen
    p_screen = subparsers.add_parser("screen", help="Screen a single document")
    p_screen.add_argument("input", help="Text file to screen (or - for stdin)")
    p_screen.add_argument("--json", action="store_true", help="JSON output")

    # batch
    p_batch = subparsers.add_parser("batch", help="Batch screen NDJSON")
    p_batch.add_argument("input", help="Input NDJSON file")
    p_batch.add_argument("output", help="Output NDJSON file")

    return parser


def main(argv=None):
    """Parse the command line, configure logging and run the chosen subcommand.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the console entry point uses.

    With no subcommand, the help text is printed and the function returns.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    if args.command == "prepare":
        cmd_prepare(args)
    elif args.command == "train":
        cmd_train(args)
    elif args.command == "screen":
        cmd_screen(args)
    elif args.command == "batch":
        cmd_batch(args)
    else:
        parser.print_help()
194
+
195
+
196
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()