jimnoneill committed on
Commit
0bf5290
·
verified ·
1 Parent(s): 3e68bea

Upload src/pubguard/data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/pubguard/data.py +550 -0
src/pubguard/data.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset preparation for PubGuard training.
3
+
4
+ Downloads publicly available datasets from HuggingFace and assembles
5
+ them into the three labelled corpora needed by the training pipeline.
6
+
7
+ Datasets used (verified available 2026-02)
8
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9
+
10
+ **Head 1 — Document Type** (scientific_paper | poster | abstract_only | junk)
11
+
12
+ Positive (scientific_paper):
13
+ - armanc/scientific_papers (arxiv) ~300 K full-text articles
14
+ cols: article, abstract, section_names
15
+
16
+ Negative (abstract_only):
17
+ - gfissore/arxiv-abstracts-2021 ~2 M abstracts
18
+ cols: abstract (filter length < 600 chars)
19
+
20
+ Negative (junk):
21
+ - ag_news (news articles) + synthetic templates (flyers, invoices, etc.)
22
+
23
+ Negative (poster):
24
+ - Synthetic poster-style structured text
25
+
26
+ **Head 2 — AI-Generated Text Detection**
27
+
28
+ - liamdugan/raid – multi-model generations, domain="abstracts"
29
+ cols: model, domain, generation (model="human" for human text)
30
+ - NicolaiSivesind/ChatGPT-Research-Abstracts – real + GPT-3.5 abstracts
31
+ cols: real_abstract, generated_abstract
32
+
33
+ **Head 3 — Toxicity**
34
+
35
+ - google/civil_comments – 1.8 M comments with toxicity scores (0–1)
36
+ cols: text, toxicity
37
+ - skg/toxigen-data – 274 K annotated toxic/benign statements
38
+ cols: text, toxicity_human (1–5 scale)
39
+ """
40
+
41
+ import json
42
+ import logging
43
+ import random
44
+ from pathlib import Path
45
+ from typing import Dict, List, Tuple
46
+
47
logger = logging.getLogger(__name__)

# ── Constants ────────────────────────────────────────────────────

# Fixed seed so synthetic generation, sampling and shuffling are repeatable.
# NOTE: this seeds the process-wide `random` state as an import side effect.
SEED = 42
random.seed(SEED)

# ── Synthetic templates ──────────────────────────────────────────

# Short non-scientific texts (flyers, ads, notices, invoices, ...) used to
# synthesise the "junk" class. `{name}`-style placeholders are substituted by
# _fill_template(); this is plain str.replace, not str.format, so a literal
# `$` or `#` in front of a placeholder is kept as-is.
JUNK_TEMPLATES = [
    "🎉 Annual {event} at {place}! Join us on {date}. Free food and drinks. RSVP to {email}.",
    "FOR SALE: {item}. Great condition. ${price}. Contact {name} at {phone}.",
    "{company} is hiring! We're looking for a {role}. Apply now at {url}.",
    "NOTICE: The {dept} office will be closed on {date} for {reason}. Questions? Call {phone}.",
    "Don't miss our {event}! {date} from {time}. {place}. Tickets: ${price}.",
    "Weekly newsletter from {company}. This week: {topic1}, {topic2}, and more!",
    "Invoice #{num} from {company}. Amount due: ${price}. Payment due by {date}.",
    "Meeting agenda for {date}. 1) {topic1} 2) {topic2} 3) {topic3}. Location: {place}.",
    "URGENT: Your {account} password expires on {date}. Click here to reset: {url}.",
    "Congratulations {name}! You've been selected for our exclusive {event}. Limited spots!",
    "Thank you for your purchase! Order #{num}. Estimated delivery: {date}.",
    "{company} presents the {event}. Keynote by {name}. Register at {url}.",
    "Garage sale this weekend! {place}. {date} {time}. Everything must go!",
    "Happy Birthday to {name} from all of us at {company}! 🎂",
    "POOL PARTY! 🏊 Come join us at {place} on {date}. Bring your swimsuit and sunscreen!",
    "Menu for this week: Monday: {food1}. Tuesday: {food2}. Wednesday: {food3}.",
    "Building maintenance notice: {reason} on {date}. Please plan accordingly.",
    "Lost & Found: {item} found near {place}. Contact front desk to claim.",
    "Fantasy Football League draft is on {date}! Don't forget to submit your picks.",
    "Book club meeting: We're reading '{book}' by {name}. Discussion on {date}.",
    "Hey everyone! Movie night at {place} on {date}. We're watching '{movie}'. Bring popcorn!",
    "Reminder: Staff meeting {date} at {time}. Attendance mandatory. {dept}.",
    "Lost cat! Orange tabby, answers to '{pet_name}'. Last seen near {place}. Call {phone}.",
    "HOT DEAL! {item} only ${price}! Limited time offer. Visit {url}.",
    "Club registration open! Join the {club} club. Meetings every {day} at {time}. {place}.",
    "Fundraiser bake sale! {date} at {place}. All proceeds go to {charity}.",
    "Apartment for rent: 2BR/1BA near {place}. ${price}/month. Pet friendly. Call {phone}.",
    "Yoga class every {day} at {time}. {place}. All levels welcome. Bring your own mat!",
    "IT Alert: System maintenance scheduled for {date}. Expected downtime: {time}. {dept}.",
    "Carpool needed! Driving from {place} to {place2} daily. Contact {name} at {email}.",
]

# Section-structured layouts (title / authors / headed blocks) used to
# synthesise the "poster" class. Placeholders are filled the same way as above.
POSTER_TEMPLATES = [
    "TITLE: {title}\n\nAUTHORS: {authors}\nAFFILIATION: {affil}\n\nINTRODUCTION\n{intro}\n\nMETHODS\n{methods}\n\nRESULTS\n{results}\n\nCONCLUSIONS\n{conclusions}\n\nACKNOWLEDGMENTS\n{ack}",
    "{title}\n{authors} | {affil}\n\nBackground: {intro}\n\nApproach: {methods}\n\nKey Findings:\n• {finding1}\n• {finding2}\n• {finding3}\n\nFuture Work: {future}\n\nContact: {email}",
    "POSTER PRESENTATION\n\n{title}\n\n{authors}\n{affil}\n\nObjective: {intro}\n\nDesign: {methods}\n\nOutcome: {results}\n\nConclusion: {conclusions}",
    "{title}\n\n{authors} ({affil})\n\nAim: {intro}\nMethod: {methods}\nResult: {results}\nSummary: {conclusions}\n\nCorrespondence: {email}",
    "RESEARCH POSTER\n─────────────────────\n{title}\n{authors}\n{affil}\n\n▸ Background\n{intro}\n\n▸ Methods\n{methods}\n\n▸ Results\n• {finding1}\n• {finding2}\n\n▸ Conclusion\n{conclusions}\n\nFunding: {ack}",
]
96
+
97
+
98
def _fill_template(template: str) -> str:
    """Return *template* with every known ``{placeholder}`` substituted.

    A complete set of replacement values is drawn on every call, whether or
    not the template uses them, so each call consumes the same amount of
    random state. Placeholders that are not in the table are left untouched.
    """
    pick = random.choice
    # Ordered pairs: the draw order below is the order random state is consumed.
    substitutions = [
        ("{event}", pick(["Pool Party", "BBQ Bash", "Career Fair", "Fundraiser Gala", "Open House", "Trivia Night"])),
        ("{place}", pick(["Room 201", "Hilton Downtown", "the Community Center", "Central Park", "Building B Courtyard", "Main Auditorium"])),
        ("{place2}", pick(["Campus North", "Downtown", "Tech Park", "Medical Center"])),
        ("{date}", pick(["March 15", "June 22", "Sept 5", "November 10", "January 30", "Friday the 13th"])),
        ("{email}", "info@example.com"),
        ("{item}", pick(["2019 Honda Civic", "MacBook Pro 16-inch", "Standing Desk", "Mountain Bike", "Vintage Guitar"])),
        ("{price}", str(random.randint(10, 5000))),
        ("{name}", pick(["Dr. Smith", "Jane Doe", "Prof. Chen", "Maria Garcia", "Bob Wilson"])),
        ("{phone}", "555-0123"),
        ("{company}", pick(["TechCorp", "BioGen Inc.", "Global Solutions", "Acme Labs", "DataFlow Systems"])),
        ("{role}", pick(["Data Scientist", "Lab Technician", "Project Manager", "Software Engineer"])),
        ("{url}", "https://example.com/apply"),
        ("{dept}", pick(["HR", "Finance", "Engineering", "Admissions", "IT Support"])),
        ("{reason}", pick(["maintenance", "holiday", "training day", "renovation", "fire drill"])),
        ("{time}", pick(["2-5 PM", "10 AM - 3 PM", "6-9 PM", "All Day", "Noon"])),
        ("{topic1}", pick(["Q3 Review", "Budget Update", "New Hires", "Project Status"])),
        ("{topic2}", pick(["Safety Training", "Holiday Schedule", "IT Migration", "Team Building"])),
        ("{topic3}", pick(["Parking Changes", "Wellness Program", "Open Q&A"])),
        ("{account}", pick(["university", "corporate", "cloud storage"])),
        ("{num}", str(random.randint(10000, 99999))),
        ("{food1}", "Pasta Primavera"),
        ("{food2}", "Chicken Tikka"),
        ("{food3}", "Fish Tacos"),
        ("{book}", pick(["1984", "Sapiens", "The Gene", "Thinking, Fast and Slow"])),
        ("{movie}", pick(["Inception", "The Matrix", "Interstellar"])),
        ("{pet_name}", pick(["Whiskers", "Max", "Luna"])),
        ("{club}", pick(["Chess", "Photography", "Hiking", "Debate"])),
        ("{day}", pick(["Monday", "Wednesday", "Friday"])),
        ("{charity}", pick(["Children's Hospital", "Local Food Bank", "Animal Shelter"])),
        ("{title}", pick([
            "Effects of Temperature on Enzyme Kinetics in Thermophilic Bacteria",
            "Deep Learning for Medical Image Segmentation: A Systematic Review",
            "Novel Biomarkers in Cardiovascular Disease Progression",
            "Metagenomic Analysis of Coral Reef Microbiomes Under Thermal Stress",
            "CRISPR-Cas9 Editing Efficiency in Human iPSC-Derived Neurons",
        ])),
        ("{authors}", pick(["A. Smith, B. Jones, C. Lee", "R. Patel, S. Kim, T. Brown", "M. Wang, L. Davis"])),
        ("{affil}", pick(["University of Example, Dept. of Science", "MIT, CSAIL", "Stanford School of Medicine"])),
        ("{intro}", pick([
            "Background text about the research problem being investigated.",
            "This study addresses the gap in understanding of X in the context of Y.",
            "Recent advances in Z have highlighted the need for improved W.",
        ])),
        ("{methods}", pick([
            "We employed a cross-sectional study design with N=200 participants.",
            "Samples were collected from 5 sites and processed using standard protocols.",
            "We developed a convolutional neural network trained on 50K labeled images.",
        ])),
        ("{results}", pick([
            "Treatment group showed 45% improvement (p<0.01) compared to control.",
            "Our model achieved 94.2% accuracy on the held-out test set.",
            "We identified 23 significantly enriched pathways (FDR < 0.05).",
        ])),
        ("{conclusions}", pick([
            "Our findings support the hypothesis that X leads to improved Y.",
            "These results demonstrate the feasibility of the proposed approach.",
            "Further validation with larger cohorts is warranted.",
        ])),
        ("{finding1}", "Significant reduction in error rate (p<0.001)"),
        ("{finding2}", "Model outperformed baseline by 15%"),
        ("{finding3}", "Robust to distribution shift across domains"),
        ("{future}", "Extend to longitudinal datasets and multi-site validation."),
        ("{ack}", pick(["Funded by NIH Grant R01-ABC123.", "Supported by NSF Award #1234567."])),
    ]

    filled = template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled
167
+
168
+
169
def generate_synthetic_junk(n: int = 5000) -> List[Dict[str, str]]:
    """Build *n* synthetic samples for the ``"junk"`` class.

    Each sample is a randomly chosen entry of ``JUNK_TEMPLATES`` filled in by
    ``_fill_template`` and returned as ``{"text": ..., "label": "junk"}``.
    """
    return [
        {"text": _fill_template(random.choice(JUNK_TEMPLATES)), "label": "junk"}
        for _ in range(n)
    ]
177
+
178
+
179
def generate_synthetic_posters(n: int = 3000) -> List[Dict[str, str]]:
    """Build *n* synthetic samples for the ``"poster"`` class.

    Each sample is a randomly chosen entry of ``POSTER_TEMPLATES`` filled in
    by ``_fill_template`` and returned as ``{"text": ..., "label": "poster"}``.
    """
    return [
        {"text": _fill_template(random.choice(POSTER_TEMPLATES)), "label": "poster"}
        for _ in range(n)
    ]
187
+
188
+
189
+ # ── Head 1: doc_type ────────────────────────────────────────────
190
+
191
def _paper_text(row):
    """Return abstract + article body (max 4000 chars, stripped), or None if too short."""
    abstract = row.get("abstract", "") or ""
    article = row.get("article", "") or ""
    text = (abstract + " " + article)[:4000].strip()
    return text if len(text) > 100 else None


def _short_abstract_text(row):
    """Return the stripped abstract if its raw length is in (50, 600), else None."""
    abstract = row.get("abstract", "")
    if abstract and 50 < len(abstract) < 600:
        return abstract.strip()
    return None


def _news_text(row):
    """Return the stripped news text if longer than 30 chars, else None."""
    text = row.get("text", "") or ""  # tolerate an explicit None value
    return text.strip() if len(text) > 30 else None


def _collect_samples(rows, extract, label, sink, limit):
    """Append ``{"text", "label"}`` samples from *rows* to *sink* until it holds *limit*.

    *extract* maps a row to the sample text, or ``None`` to skip the row.
    Samples are appended in place, so anything collected before *rows* raises
    is kept by the caller. Returns the number of samples added by this call.
    """
    added = 0
    for row in rows:
        if len(sink) >= limit:
            break
        text = extract(row)
        if text is None:
            continue
        sink.append({"text": text, "label": label})
        added += 1
    return added


def prepare_doc_type_dataset(
    output_dir: Path,
    n_per_class: int = 15000,
) -> Path:
    """
    Assemble and save document-type training data.

    Streams public HuggingFace datasets for the ``scientific_paper``,
    ``abstract_only`` and (half of) ``junk`` classes, adds synthetic junk and
    synthetic posters, shuffles everything and writes it as NDJSON with one
    ``{"text", "label"}`` object per line.

    Args:
        output_dir: Directory to write into; created if missing.
        n_per_class: Target number of samples per class. ``junk`` is split
            between ag_news (``n_per_class // 2``) and synthetic samples
            (``n_per_class // 2``).

    Returns:
        Path of the written ``doc_type_train.ndjson`` file.

    Dataset loading failures are logged and never raised; the affected class
    ends up smaller or empty. A fallback source only tops a class up to
    ``n_per_class``, so a primary source that fails part-way cannot push the
    class over its target. Errors while writing the output file propagate.
    """
    from collections import Counter

    from datasets import load_dataset

    def open_stream(*name_and_config):
        # NOTE(review): trust_remote_code=True allows the dataset repository's
        # own loading script to run locally. Kept as in the existing pipeline;
        # confirm it is still required for each source.
        return load_dataset(
            *name_and_config,
            split="train", streaming=True, trust_remote_code=True,
        )

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "doc_type_train.ndjson"

    logger.info("=== Preparing doc_type dataset ===")

    # ── scientific_paper ─────────────────────────────────────────
    papers = []
    logger.info("Loading armanc/scientific_papers (arxiv split)...")
    try:
        _collect_samples(
            open_stream("armanc/scientific_papers", "arxiv"),
            _paper_text, "scientific_paper", papers, n_per_class,
        )
        logger.info(f" scientific_paper: {len(papers)}")
    except Exception as e:
        logger.warning(f"Could not load scientific_papers: {e}")
        logger.info("Falling back to ccdv/arxiv-summarization...")
        try:
            added = _collect_samples(
                open_stream("ccdv/arxiv-summarization"),
                _paper_text, "scientific_paper", papers, n_per_class,
            )
            logger.info(f" scientific_paper (fallback): {added}")
        except Exception as e2:
            logger.error(f"Fallback also failed: {e2}")

    # ── abstract_only ────────────────────────────────────────────
    abstracts = []
    logger.info("Loading gfissore/arxiv-abstracts-2021...")
    try:
        _collect_samples(
            open_stream("gfissore/arxiv-abstracts-2021"),
            _short_abstract_text, "abstract_only", abstracts, n_per_class,
        )
        logger.info(f" abstract_only: {len(abstracts)}")
    except Exception as e:
        logger.warning(f"Could not load arxiv-abstracts: {e}")
        # Fallback: reuse the abstracts shipped with scientific_papers.
        logger.info("Generating abstract_only from scientific_papers abstracts...")
        try:
            added = _collect_samples(
                open_stream("armanc/scientific_papers", "arxiv"),
                _short_abstract_text, "abstract_only", abstracts, n_per_class,
            )
            logger.info(f" abstract_only (fallback): {added}")
        except Exception as e2:
            # Previously swallowed silently; make the missing class visible.
            logger.error(f"abstract_only fallback also failed: {e2}")

    # ── junk ─────────────────────────────────────────────────────
    news = []
    logger.info("Loading ag_news for junk class...")
    try:
        _collect_samples(
            open_stream("ag_news"),
            _news_text, "junk", news, n_per_class // 2,
        )
        logger.info(f" junk (ag_news): {len(news)}")
    except Exception as e:
        logger.warning(f"Could not load ag_news: {e}")

    logger.info("Generating synthetic junk...")
    synth_junk = generate_synthetic_junk(n_per_class // 2)
    logger.info(f" junk (synthetic): {len(synth_junk)}")

    # ── poster ───────────────────────────────────────────────────
    logger.info("Generating synthetic poster data...")
    synth_posters = generate_synthetic_posters(n_per_class)
    logger.info(f" poster (synthetic): {len(synth_posters)}")

    # ── Shuffle and save ─────────────────────────────────────────
    all_samples = papers + abstracts + news + synth_junk + synth_posters
    random.shuffle(all_samples)

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample) + "\n" for sample in all_samples)

    # Report distribution
    dist = Counter(sample["label"] for sample in all_samples)
    logger.info(f"Saved {len(all_samples)} samples to {output_path}")
    for label, count in sorted(dist.items()):
        logger.info(f" {label}: {count}")

    return output_path
332
+
333
+
334
+ # ── Head 2: ai_detect ───────────────────────────────────────────
335
+
336
def prepare_ai_detect_dataset(
    output_dir: Path,
    n_per_class: int = 20000,
) -> Path:
    """
    Build the balanced human-vs-AI text corpus and write it as NDJSON.

    Sources:
    - liamdugan/raid: rows with domain == "abstracts"; model == "human"
      → label "human", any other model → label "ai_generated"
      (at most n_per_class of each from this source)
    - NicolaiSivesind/ChatGPT-Research-Abstracts: real_abstract → "human",
      generated_abstract → "ai_generated"

    Both classes are down-sampled to the same size (never more than
    n_per_class), shuffled, and written to ``ai_detect_train.ndjson`` inside
    output_dir. A source that cannot be loaded is logged and skipped; if
    either class ends up empty, an empty file is written. Returns the output
    path in every case.
    """
    from datasets import load_dataset

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "ai_detect_train.ndjson"
    human_samples, ai_samples = [], []

    logger.info("=== Preparing ai_detect dataset ===")

    # ── RAID (scientific abstracts domain) ───────────────────────
    logger.info("Loading liamdugan/raid (abstracts domain)...")
    try:
        raid_stream = load_dataset(
            "liamdugan/raid",
            split="train", streaming=True, trust_remote_code=True,
        )
        raid_human = raid_ai = 0
        for record in raid_stream:
            if record.get("domain", "") != "abstracts":
                continue
            generation = record.get("generation", "") or ""
            if len(generation) < 50:  # also rejects the empty string
                continue
            written_by_human = record.get("model", "") == "human"
            if written_by_human and raid_human < n_per_class:
                human_samples.append({"text": generation[:4000], "label": "human"})
                raid_human += 1
            elif not written_by_human and raid_ai < n_per_class:
                ai_samples.append({"text": generation[:4000], "label": "ai_generated"})
                raid_ai += 1
            # Stop streaming once both quotas for this source are met.
            if min(raid_human, raid_ai) >= n_per_class:
                break
        logger.info(f" RAID: human={raid_human}, ai={raid_ai}")
    except Exception as e:
        logger.warning(f"Could not load RAID: {e}")

    # ── ChatGPT-Research-Abstracts ───────────────────────────────
    logger.info("Loading NicolaiSivesind/ChatGPT-Research-Abstracts...")
    try:
        pair_stream = load_dataset(
            "NicolaiSivesind/ChatGPT-Research-Abstracts",
            split="train", streaming=True, trust_remote_code=True,
        )
        pairs_human = pairs_ai = 0
        for record in pair_stream:
            real_text = record.get("real_abstract", "")
            if real_text and len(real_text) > 50:
                human_samples.append({"text": real_text[:4000], "label": "human"})
                pairs_human += 1
            generated_text = record.get("generated_abstract", "")
            if generated_text and len(generated_text) > 50:
                ai_samples.append({"text": generated_text[:4000], "label": "ai_generated"})
                pairs_ai += 1
        logger.info(f" ChatGPT-Abstracts: human={pairs_human}, ai={pairs_ai}")
    except Exception as e:
        logger.warning(f"Could not load ChatGPT-Research-Abstracts: {e}")

    # ── Balance and save ─────────────────────────────────────────
    per_class = min(len(human_samples), len(ai_samples), n_per_class)
    if per_class == 0:
        logger.error("No AI detection training data available!")
        # Leave an empty (truncated) file behind so the path always exists.
        open(output_path, "w").close()
        return output_path

    # per_class never exceeds either list length, so sampling is always valid.
    balanced = random.sample(human_samples, per_class)
    balanced += random.sample(ai_samples, per_class)
    random.shuffle(balanced)

    with open(output_path, "w") as f:
        f.writelines(json.dumps(entry) + "\n" for entry in balanced)

    n_h = sum(entry["label"] == "human" for entry in balanced)
    n_a = sum(entry["label"] == "ai_generated" for entry in balanced)
    logger.info(f"Saved {len(balanced)} samples (human={n_h}, ai={n_a}) to {output_path}")
    return output_path
433
+
434
+
435
+ # ── Head 3: toxicity ────────────────────────────────────────────
436
+
437
def _toxicity_score(value):
    """Return *value* as a float, or None when it is missing or not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def prepare_toxicity_dataset(
    output_dir: Path,
    n_per_class: int = 20000,
) -> Path:
    """
    Assemble toxicity detection training data.

    Sources:
    - google/civil_comments – ``toxicity`` float;
      toxic >= 0.5, clean < 0.1, at most n_per_class of each from this source.
      A row without a ``toxicity`` key is treated as 0.0 (clean).
    - skg/toxigen-data – ``toxicity_human`` on a 1–5 scale;
      toxic >= 4.0, clean <= 2.0.

    Rows whose score is ``None`` or not numeric are skipped individually, so
    a single malformed row no longer ends the whole source early. Texts
    shorter than 20 characters are ignored and texts are truncated to 4000
    characters.

    Args:
        output_dir: Directory to write into; created if missing.
        n_per_class: Maximum number of samples per class in the output.

    Returns:
        Path of ``toxicity_train.ndjson``. The file holds the same number of
        ``toxic`` and ``clean`` samples, shuffled, one JSON object per line;
        it is empty when either class has no data.

    A source that cannot be loaded is logged and skipped. Errors while
    writing the output file propagate.
    """
    from datasets import load_dataset

    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / "toxicity_train.ndjson"
    toxic_samples = []
    clean_samples = []

    logger.info("=== Preparing toxicity dataset ===")

    # ── Civil Comments ───────────────────────────────────────────
    logger.info("Loading google/civil_comments...")
    try:
        ds = load_dataset(
            "google/civil_comments",
            split="train", streaming=True, trust_remote_code=True,
        )
        toxic_count = 0
        clean_count = 0
        for row in ds:
            text = row.get("text", "")
            if not text or len(text) < 20:
                continue
            toxicity = _toxicity_score(row.get("toxicity", 0.0))
            if toxicity is None:
                continue
            if toxicity >= 0.5 and toxic_count < n_per_class:
                toxic_samples.append({"text": text[:4000], "label": "toxic"})
                toxic_count += 1
            elif toxicity < 0.1 and clean_count < n_per_class:
                clean_samples.append({"text": text[:4000], "label": "clean"})
                clean_count += 1
            # Stop streaming once both quotas for this source are met.
            if toxic_count >= n_per_class and clean_count >= n_per_class:
                break
        logger.info(f" Civil Comments: toxic={toxic_count}, clean={clean_count}")
    except Exception as e:
        logger.warning(f"Could not load civil_comments: {e}")

    # ── ToxiGen ──────────────────────────────────────────────────
    logger.info("Loading skg/toxigen-data...")
    try:
        ds = load_dataset(
            "skg/toxigen-data",
            split="train", streaming=True, trust_remote_code=True,
        )
        t_count = 0
        c_count = 0
        for row in ds:
            text = row.get("text", "")
            if not text or len(text) < 20:
                continue
            # toxicity_human is 1-5 scale; scores strictly between 2 and 4
            # are ambiguous and deliberately left out.
            tox_score = _toxicity_score(row.get("toxicity_human", None))
            if tox_score is None:
                continue
            if tox_score >= 4.0:
                toxic_samples.append({"text": text[:4000], "label": "toxic"})
                t_count += 1
            elif tox_score <= 2.0:
                clean_samples.append({"text": text[:4000], "label": "clean"})
                c_count += 1
        logger.info(f" ToxiGen: toxic={t_count}, clean={c_count}")
    except Exception as e:
        logger.warning(f"Could not load ToxiGen: {e}")

    # ── Balance and save ─────────────────────────────────────────
    min_count = min(len(toxic_samples), len(clean_samples), n_per_class)
    if min_count == 0:
        logger.error("No toxicity training data available!")
        # Leave an empty (truncated) file behind so the path always exists.
        with open(output_path, "w", encoding="utf-8"):
            pass
        return output_path

    # min_count never exceeds either list length, so sampling is always valid.
    balanced = (
        random.sample(toxic_samples, min_count)
        + random.sample(clean_samples, min_count)
    )
    random.shuffle(balanced)

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(sample) + "\n" for sample in balanced)

    n_t = sum(1 for s in balanced if s["label"] == "toxic")
    n_c = sum(1 for s in balanced if s["label"] == "clean")
    logger.info(f"Saved {len(balanced)} samples (toxic={n_t}, clean={n_c}) to {output_path}")
    return output_path
535
+
536
+
537
+ # ── Orchestrator ─────────────────────────────────────────────────
538
+
539
def prepare_all(output_dir: Path, n_per_class: int = 15000):
    """Build the doc_type, ai_detect and toxicity corpora under *output_dir*.

    *output_dir* is coerced to ``Path`` and the same *n_per_class* is passed
    to every builder. The builders run in the order doc_type, ai_detect,
    toxicity. Returns a dict mapping those three names to the path each
    builder returned.
    """
    target_dir = Path(output_dir)
    logger.info(f"Preparing all datasets in {target_dir}")

    builders = (
        ("doc_type", prepare_doc_type_dataset),
        ("ai_detect", prepare_ai_detect_dataset),
        ("toxicity", prepare_toxicity_dataset),
    )
    paths = {head: build(target_dir, n_per_class) for head, build in builders}

    logger.info("All datasets prepared!")
    return paths