arcspan / scripts /fix_leakage.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Remove train-test leakage from Arcspan cybersecurity NER training data."""
import json
from pathlib import Path
# Directory that every input split is read from and every *_deleaked.jsonl
# output is written to (see main()).
DATA = Path("/home/ubuntu").joinpath("alkyline", "data", "processed")

# --- Test/valid files (DO NOT modify) ---
# Held-out splits: main() only reads these, as the reference set that
# training examples are checked against.
TEST_VALID_FILES = [
    "enriched_5class_test.jsonl",
    "enriched_5class_valid_cleaned.jsonl",
    "enriched_5class_valid.jsonl",
    "cyner_test.jsonl",
    "securebert2_5class_test.jsonl",
    "securebert2_test.jsonl",
]

# --- Training files to clean ---
# Each is filtered and re-written alongside itself as <stem>_deleaked.jsonl;
# the originals are left in place.
TRAIN_FILES = [
    "enriched_5class_train_cleaned.jsonl",
    "aptner_5class_train.jsonl",
    "securebert2_5class_train.jsonl",
]
def load_texts(path: Path):
    """Yield the ``"text"`` field of every non-blank line in a JSONL file.

    Args:
        path: Path (or path string) of a JSONL file, decoded as UTF-8.

    Yields:
        str: each record's ``"text"`` value, in file order. Blank and
        whitespace-only lines are skipped.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"text"`` key.
    """
    # Explicit UTF-8: without it open() falls back to the locale's preferred
    # encoding, which breaks on non-ASCII text under e.g. cp1252 or C locale.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)["text"]
# Number of leading characters compared for near-duplicate ("prefix") leakage.
_PREFIX_LEN = 80


def _collect_heldout(data_dir, fnames):
    """Load all held-out texts under *data_dir*; return (exact_set, prefix_set).

    Missing files are reported and skipped. Each text contributes itself to
    the exact set and its first ``_PREFIX_LEN`` characters to the prefix set.
    """
    exact_texts = set()
    prefix_texts = set()
    for fname in fnames:
        p = data_dir / fname
        if not p.exists():
            print(f" SKIP (not found): {fname}")
            continue
        count = 0
        for text in load_texts(p):
            exact_texts.add(text)
            prefix_texts.add(text[:_PREFIX_LEN])
            count += 1
        print(f" Loaded {count:,} test/valid examples from {fname}")
    return exact_texts, prefix_texts


def _deleak_file(src, exact_texts, prefix_texts):
    """Filter one training JSONL file against the held-out sets.

    Returns:
        (kept_lines, removed_exact, removed_prefix, removed_trivial), where
        kept_lines are the surviving stripped JSON lines in file order.
    """
    kept = []
    removed_exact = removed_prefix = removed_trivial = 0
    with open(src, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            text = json.loads(line)["text"]
            # Checked first, so a near-empty text is counted as trivial even
            # when it also appears in the held-out data.
            if len(text.strip()) <= 2:
                removed_trivial += 1
            elif text in exact_texts:
                removed_exact += 1
            elif text[:_PREFIX_LEN] in prefix_texts:
                removed_prefix += 1
            else:
                kept.append(line)
    return kept, removed_exact, removed_prefix, removed_trivial


def main():
    """Remove training examples that overlap any test/valid split.

    Every file in TEST_VALID_FILES (under DATA) is read only. Each file in
    TRAIN_FILES is filtered and the survivors are written, as the original
    stripped JSON lines, to ``<stem>_deleaked.jsonl`` in DATA. An example is
    dropped when its text is at most 2 characters after stripping, equals a
    held-out text exactly, or shares its first ``_PREFIX_LEN`` characters
    with one.
    """
    # 1. Collect all test/valid texts and prefixes
    exact_texts, prefix_texts = _collect_heldout(DATA, TEST_VALID_FILES)
    print(f"\nTotal unique test/valid texts: {len(exact_texts):,}")
    print(f"Total unique test/valid prefixes: {len(prefix_texts):,}\n")
    if not exact_texts:
        # Without held-out texts the outputs would look "deleaked" while only
        # trivial examples were removed; make that explicit.
        print("WARNING: no test/valid texts loaded; only trivial examples will be removed.\n")

    # 2. Clean each training file
    for fname in TRAIN_FILES:
        p = DATA / fname
        if not p.exists():
            print(f" SKIP (not found): {fname}")
            continue
        kept, removed_exact, removed_prefix, removed_trivial = _deleak_file(
            p, exact_texts, prefix_texts
        )
        total_removed = removed_exact + removed_prefix + removed_trivial

        out_path = p.with_name(f"{p.stem}_deleaked.jsonl")
        with open(out_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in kept)

        print(f"{fname}:")
        print(f" Removed: {total_removed:,} (exact={removed_exact}, prefix={removed_prefix}, trivial={removed_trivial})")
        print(f" Kept: {len(kept):,}")
        print(f" Written: {out_path.name}\n")
# Script entry point: run the de-leaking pass only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()