#!/usr/bin/env python3
"""Remove train-test leakage from Arcspan cybersecurity NER training data."""

import json
from pathlib import Path
from typing import Iterator, Set, Tuple

DATA = Path("/home/ubuntu/alkyline/data/processed")

# Number of leading characters compared when looking for near-duplicates.
PREFIX_LEN = 80

# Texts whose stripped length is at most this many characters are dropped.
TRIVIAL_MAX_LEN = 2

# --- Test/valid files (DO NOT modify) ---
TEST_VALID_FILES = [
    "enriched_5class_test.jsonl",
    "enriched_5class_valid_cleaned.jsonl",
    "enriched_5class_valid.jsonl",
    "cyner_test.jsonl",
    "securebert2_5class_test.jsonl",
    "securebert2_test.jsonl",
]

# --- Training files to clean ---
TRAIN_FILES = [
    "enriched_5class_train_cleaned.jsonl",
    "aptner_5class_train.jsonl",
    "securebert2_5class_train.jsonl",
]


def load_texts(path: Path) -> Iterator[str]:
    """Yield text field from each line of a JSONL file.

    Blank lines are skipped. The file is decoded as UTF-8 regardless of the
    locale's default encoding.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no "text" field.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)["text"]


def _collect_eval_texts(data_dir: Path) -> Tuple[Set[str], Set[str]]:
    """Gather the full texts and PREFIX_LEN-character prefixes of all test/valid files."""
    exact_texts: Set[str] = set()
    prefix_texts: Set[str] = set()

    for fname in TEST_VALID_FILES:
        p = data_dir / fname
        if not p.exists():
            print(f" SKIP (not found): {fname}")
            continue
        count = 0
        for text in load_texts(p):
            exact_texts.add(text)
            prefix_texts.add(text[:PREFIX_LEN])
            count += 1
        print(f" Loaded {count:,} test/valid examples from {fname}")

    return exact_texts, prefix_texts


def _clean_training_file(path: Path, exact_texts: Set[str], prefix_texts: Set[str]) -> None:
    """Write a `<stem>_deleaked.jsonl` copy of *path* without trivial or leaked records."""
    kept = []
    removed_exact = 0
    removed_prefix = 0
    removed_trivial = 0

    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            text = json.loads(line)["text"]

            # The checks run in this order so the three counters stay
            # disjoint: every exact match is also a prefix match, so the
            # exact test has to come before the prefix test.
            if len(text.strip()) <= TRIVIAL_MAX_LEN:
                removed_trivial += 1
            elif text in exact_texts:
                removed_exact += 1
            elif text[:PREFIX_LEN] in prefix_texts:
                removed_prefix += 1
            else:
                # Keep the original serialized record, not a re-dump, so
                # surviving lines are carried over unchanged.
                kept.append(line)

    total_removed = removed_exact + removed_prefix + removed_trivial

    # Path.stem drops only the final suffix, unlike str.replace(".jsonl", "")
    # which would also strip the substring from the middle of a name.
    out_path = path.with_name(f"{path.stem}_deleaked.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in kept)

    print(f"{path.name}:")
    print(
        f" Removed: {total_removed:,} "
        f"(exact={removed_exact}, prefix={removed_prefix}, trivial={removed_trivial})"
    )
    print(f" Kept: {len(kept):,}")
    print(f" Written: {out_path.name}\n")


def main(data_dir: Path = DATA) -> None:
    """De-leak every training file found in *data_dir*.

    A training record is dropped when its text is trivial (stripped length
    of at most TRIVIAL_MAX_LEN characters), is identical to a test/valid
    text, or shares its first PREFIX_LEN characters with one. Cleaned copies
    are written next to the inputs as `<stem>_deleaked.jsonl`; the input
    files themselves are never modified. Missing files are reported and
    skipped.

    Args:
        data_dir: Directory holding the JSONL files named in
            TEST_VALID_FILES and TRAIN_FILES. Defaults to DATA.
    """
    # 1. Collect all test/valid texts and prefixes
    exact_texts, prefix_texts = _collect_eval_texts(data_dir)
    print(f"\nTotal unique test/valid texts: {len(exact_texts):,}")
    print(f"Total unique test/valid prefixes: {len(prefix_texts):,}\n")

    # 2. Clean each training file
    for fname in TRAIN_FILES:
        p = data_dir / fname
        if not p.exists():
            print(f" SKIP (not found): {fname}")
            continue
        _clean_training_file(p, exact_texts, prefix_texts)


if __name__ == "__main__":
    main()