arcspan / scripts /audit_leakage.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Audit train-test leakage, internal duplicates, and entity memorization."""
import json
import os
import sys
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
# Directory holding the processed JSONL splits. Defaults to the original
# location; set the AUDIT_DATA_DIR environment variable to audit a copy that
# lives elsewhere (another machine, CI, a snapshot) without editing the script.
DATA = Path(os.environ.get("AUDIT_DATA_DIR", "/home/ubuntu/alkyline/data/processed"))

# Training-side splits: anything in these files is treated as "seen" data.
TRAIN_FILES = {
    "enriched_train": DATA / "enriched_5class_train_cleaned.jsonl",
    "aptner_train": DATA / "aptner_5class_train.jsonl",
    "securebert2_train": DATA / "securebert2_5class_train.jsonl",
    "defanged_aug": DATA / "defanged_augmented.jsonl",
}

# Evaluation-side splits (test and dev/valid) checked against the training
# side for leakage and against each other for overlap.
TEST_FILES = {
    "enriched_test": DATA / "enriched_5class_test.jsonl",
    "enriched_valid": DATA / "enriched_5class_valid_cleaned.jsonl",
    "cyner_test": DATA / "cyner_test.jsonl",
    "sb2_test": DATA / "securebert2_5class_test.jsonl",
    "aptner_test": DATA / "aptner_5class_test.jsonl",
    "aptner_dev": DATA / "aptner_5class_dev.jsonl",
}
def load_jsonl(path):
if not path.exists():
return []
with open(path) as f:
return [json.loads(line) for line in f if line.strip()]
def extract_entities(record):
ents = []
for key in (record.get("spans") or {}):
if ": " in key:
cls, surface = key.split(": ", 1)
ents.append((cls, surface))
return ents
# ── Load data ──
# load_jsonl returns [] for a missing path, and an empty file also yields [];
# either way the split is skipped. Report the skip explicitly: a silently
# dropped split would make every check below look "clean" for data it never
# examined.
print("Loading data...")
train_data = {}   # split name -> list of records (only non-empty splits)
for name, path in TRAIN_FILES.items():
    recs = load_jsonl(path)
    if recs:
        train_data[name] = recs
        print(f" {name}: {len(recs)} records")
    else:
        print(f" {name}: SKIPPED (missing or empty: {path})")
test_data = {}    # split name -> list of records (only non-empty splits)
for name, path in TEST_FILES.items():
    recs = load_jsonl(path)
    if recs:
        test_data[name] = recs
        print(f" {name}: {len(recs)} records")
    else:
        print(f" {name}: SKIPPED (missing or empty: {path})")
# ── 1. Exact text duplicates (train → test) ──
print("\n" + "=" * 80)
print("1. EXACT TEXT DUPLICATES (train → test)")
print("=" * 80)
# text -> names of every test/dev split containing exactly that text.
# Section 2 reuses this map to exclude exact matches.
test_text_map = {}
for tname, recs in test_data.items():
    for r in recs:
        test_text_map.setdefault(r["text"], set()).add(tname)
leak_counts = defaultdict(int)   # (train split, test split) -> duplicate count
leak_examples = []               # (train split, test split, text[:100])
for tname, recs in train_data.items():
    for r in recs:
        # A train record counts once for every test split it collides with.
        # Sorted so example selection doesn't depend on the hash seed.
        for tsrc in sorted(test_text_map.get(r["text"], ())):
            leak_counts[(tname, tsrc)] += 1
            if len(leak_examples) < 15:   # only 15 are ever printed
                leak_examples.append((tname, tsrc, r["text"][:100]))
total_leaks = sum(leak_counts.values())
if total_leaks == 0:
    print("✓ No exact text leakage found.")
else:
    print(f"✗ {total_leaks} exact leaks found!")
    for (tr, te), c in sorted(leak_counts.items(), key=lambda x: -x[1]):
        print(f" {tr} → {te}: {c} duplicates")
    print("\nExamples:")
    for tr, te, txt in leak_examples[:15]:
        print(f" [{tr}→{te}] {txt}")
# ── 2. Fuzzy (prefix-80) duplicates ──
# Cheap near-duplicate proxy: a train text and a test text collide when their
# first 80 characters are identical. Train texts that equal some test text
# exactly (test_text_map from section 1) are skipped so they aren't counted twice.
print("\n" + "=" * 80)
print("2. NEAR-DUPLICATES (text[:80] match, train → test, excluding exact)")
print("=" * 80)
test_prefix_map = {}   # text[:80] -> names of test splits having that prefix
for tname, recs in test_data.items():
    for r in recs:
        test_prefix_map.setdefault(r["text"][:80], set()).add(tname)
fuzzy_counts = defaultdict(int)   # (train split, test split) -> count
fuzzy_examples = []               # (train split, test split, prefix)
for tname, recs in train_data.items():
    for r in recs:
        if r["text"] in test_text_map:
            continue  # already reported as an exact leak
        pfx = r["text"][:80]
        # Sorted so example selection doesn't depend on the hash seed.
        for tsrc in sorted(test_prefix_map.get(pfx, ())):
            fuzzy_counts[(tname, tsrc)] += 1
            if len(fuzzy_examples) < 10:
                fuzzy_examples.append((tname, tsrc, pfx))
total_fuzzy = sum(fuzzy_counts.values())
if total_fuzzy == 0:
    print("✓ No additional near-duplicates beyond exact matches.")
else:
    print(f"✗ {total_fuzzy} near-duplicate leaks (not counting exact)!")
    for (tr, te), c in sorted(fuzzy_counts.items(), key=lambda x: -x[1]):
        print(f" {tr} → {te}: {c}")
    print("\nExamples:")
    for tr, te, pfx in fuzzy_examples[:10]:
        print(f" [{tr}→{te}] {pfx}...")
# ── 3. Internal train duplicates ──
print("\n" + "=" * 80)
print("3. INTERNAL TRAIN DUPLICATES")
print("=" * 80)
# text -> list of train split names, one entry per occurrence (so a text
# repeated inside one split appears several times).
train_text_sources = defaultdict(list)
for tname, recs in train_data.items():
    for r in recs:
        train_text_sources[r["text"]].append(tname)
for tname, recs in train_data.items():
    texts = [r["text"] for r in recs]
    n_unique = len(set(texts))
    n_dup = len(texts) - n_unique
    pct = n_dup / len(texts) * 100 if texts else 0
    print(f" {tname}: {n_dup} within-source dups ({len(texts)} total, {n_unique} unique, {pct:.1f}%)")
cross_dup = 0
cross_examples = []
for txt, sources in train_text_sources.items():
    if len(set(sources)) > 1:   # occurs in at least two different train files
        cross_dup += 1
        if len(cross_examples) < 5:
            cross_examples.append((sources, txt[:100]))
print(f"\n Cross-source duplicates: {cross_dup} texts in multiple train files")
for sources, txt in cross_examples:
    print(f" {dict(Counter(sources))}: {txt}")
total_train = sum(len(r) for r in train_data.values())
total_unique = len(train_text_sources)
# Guard: if every train file was missing/empty, total_train is 0.
dup_rate = (total_train - total_unique) / total_train * 100 if total_train else 0
print(f"\n Total: {total_train} records, {total_unique} unique, "
      f"dup rate: {dup_rate:.1f}%")
# ── 4. Cross-test-set overlap ──
print("\n" + "=" * 80)
print("4. CROSS-TEST-SET OVERLAP")
print("=" * 80)
test_names = list(test_data.keys())
# Build each split's text set once rather than once per pair.
test_text_sets = {name: {r["text"] for r in recs} for name, recs in test_data.items()}
# combinations() yields each unordered pair once, in load order.
for n1, n2 in combinations(test_names, 2):
    overlap = test_text_sets[n1] & test_text_sets[n2]
    if overlap:
        print(f" ✗ {n1} ∩ {n2}: {len(overlap)} shared texts")
        # Sorted so the shown samples are reproducible across runs.
        for txt in sorted(overlap)[:3]:
            print(f" {txt[:100]}")
    else:
        print(f" ✓ {n1} ∩ {n2}: 0")
# ── 5. Entity memorization risk ──
# For each test split: the share of its distinct entity forms (lowercased
# surface, compared only within the same class) that also appear in any
# training split.
CLASS_WARN_PCT = 80     # per-class overlap above this is flagged " ⚠"
OVERALL_WARN_PCT = 70   # split-wide overlap above this is flagged " ⚠⚠"
print("\n" + "=" * 80)
print("5. ENTITY MEMORIZATION RISK (% test entity forms seen in train)")
print("=" * 80)
train_entities = defaultdict(set)   # class -> lowercased surfaces seen in train
for recs in train_data.values():
    for r in recs:
        for cls, surface in extract_entities(r):
            train_entities[cls].add(surface.lower())
print(f"Train entity vocab: {sum(len(v) for v in train_entities.values())} unique forms, {len(train_entities)} classes")
for cls in sorted(train_entities):
    print(f" {cls}: {len(train_entities[cls])}")
print()
for tname, recs in test_data.items():
    test_ents = defaultdict(set)    # class -> lowercased surfaces in this split
    for r in recs:
        for cls, surface in extract_entities(r):
            test_ents[cls].add(surface.lower())
    if not test_ents:
        continue  # no parseable "CLASS: surface" span keys in this split
    print(f" {tname}:")
    total_t, total_s = 0, 0
    for cls in sorted(test_ents):
        forms = test_ents[cls]
        seen = forms & train_entities.get(cls, set())
        pct = len(seen) / len(forms) * 100 if forms else 0
        total_t += len(forms)
        total_s += len(seen)
        flag = " ⚠" if pct > CLASS_WARN_PCT else ""
        print(f" {cls}: {len(seen)}/{len(forms)} ({pct:.0f}%){flag}")
    # NOTE(review): an earlier comment here promised to list unseen entities
    # for low-overlap classes; no code does that.
    # OVERALL pools forms across classes (a micro-average), so the largest
    # classes dominate it.
    overall = total_s / total_t * 100 if total_t else 0
    flag = " ⚠⚠" if overall > OVERALL_WARN_PCT else ""
    print(f" OVERALL: {total_s}/{total_t} ({overall:.0f}%){flag}")
    print()
print("="*80)
print("AUDIT COMPLETE")