#!/usr/bin/env python3
"""Audit train-test leakage, internal duplicates, and entity memorization.

Sections of the report:
  1. exact text duplicates between train files and test files,
  2. near-duplicates (same first PREFIX_LEN characters, different full text),
  3. duplicates within each train file and across train files,
  4. text overlap between every pair of test files,
  5. share of test entity surface forms (lower-cased) already seen in train.

Records are read from JSONL; each record must have a "text" field and may have
a "spans" mapping whose keys look like "CLASS: surface form".
"""
import json
import sys
from collections import defaultdict, Counter
from itertools import chain, combinations
from pathlib import Path

DATA = Path("/home/ubuntu/alkyline/data/processed")
TRAIN_FILES = {
    "enriched_train": DATA / "enriched_5class_train_cleaned.jsonl",
    "aptner_train": DATA / "aptner_5class_train.jsonl",
    "securebert2_train": DATA / "securebert2_5class_train.jsonl",
    "defanged_aug": DATA / "defanged_augmented.jsonl",
}
TEST_FILES = {
    "enriched_test": DATA / "enriched_5class_test.jsonl",
    "enriched_valid": DATA / "enriched_5class_valid_cleaned.jsonl",
    "cyner_test": DATA / "cyner_test.jsonl",
    "sb2_test": DATA / "securebert2_5class_test.jsonl",
    "aptner_test": DATA / "aptner_5class_test.jsonl",
    "aptner_dev": DATA / "aptner_5class_dev.jsonl",
}

RULE = "=" * 80
PREFIX_LEN = 80  # characters compared for near-duplicate detection
SNIPPET_LEN = 100  # characters of a text shown in examples
MAX_EXACT_EXAMPLES = 15
MAX_NEAR_EXAMPLES = 10
MAX_CROSS_EXAMPLES = 5
MAX_OVERLAP_EXAMPLES = 3
CLASS_WARN_PCT = 80  # flag a class when more than this % of test forms are in train
OVERALL_WARN_PCT = 70  # flag a test split when more than this % of forms are in train


def load_jsonl(path):
    """Return the parsed records of a UTF-8 JSONL file, or [] if *path* does not exist.

    Blank lines are skipped. Malformed lines propagate json.JSONDecodeError.
    """
    if not path.exists():
        return []
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def extract_entities(record):
    """Return (class, surface) pairs parsed from the keys of record["spans"].

    Each key is split on its first ": "; keys without that separator are
    ignored. A missing, None, or empty "spans" yields [].
    """
    ents = []
    for key in record.get("spans") or {}:
        if ": " in key:
            cls, surface = key.split(": ", 1)
            ents.append((cls, surface))
    return ents


def _header(title):
    """Print a section banner."""
    print("\n" + RULE)
    print(title)
    print(RULE)


def _load_splits(files):
    """Load every file in *files*; return {name: records} for the non-empty ones."""
    loaded = {}
    for name, path in files.items():
        recs = load_jsonl(path)
        if recs:
            loaded[name] = recs
            print(f" {name}: {len(recs)} records")
        else:
            # Report skipped splits: silently dropping one would make the audit look clean.
            print(f" {name}: SKIPPED (missing or empty: {path})")
    return loaded


def _entity_vocab(records):
    """Map entity class -> set of lower-cased surface forms found in *records*."""
    vocab = defaultdict(set)
    for r in records:
        for cls, surface in extract_entities(r):
            vocab[cls].add(surface.lower())
    return vocab


def _report_exact_leaks(train_data, test_data):
    """Section 1: count train records whose full text occurs verbatim in a test split.

    Returns the {test text: set of test split names} map for reuse by section 2.
    """
    _header("1. EXACT TEXT DUPLICATES (train → test)")
    test_text_map = defaultdict(set)
    for tname, recs in test_data.items():
        for r in recs:
            test_text_map[r["text"]].add(tname)

    leak_counts = Counter()
    leak_examples = []
    for tname, recs in train_data.items():
        for r in recs:
            # sorted(): deterministic example order across runs.
            for tsrc in sorted(test_text_map.get(r["text"], ())):
                leak_counts[(tname, tsrc)] += 1
                if len(leak_examples) < MAX_EXACT_EXAMPLES:
                    leak_examples.append((tname, tsrc, r["text"][:SNIPPET_LEN]))

    total_leaks = sum(leak_counts.values())
    if total_leaks == 0:
        print("✓ No exact text leakage found.")
    else:
        print(f"✗ {total_leaks} exact leaks found!")
        for (tr, te), c in leak_counts.most_common():
            print(f" {tr} → {te}: {c} duplicates")
        print("\nExamples:")
        for tr, te, txt in leak_examples:
            print(f" [{tr}→{te}] {txt}")
    return test_text_map


def _report_near_duplicates(train_data, test_data, test_text_map):
    """Section 2: count train records sharing a text prefix with a test split.

    A (train record, test split) pair already counted as an exact leak for that
    same split is excluded, so exact hits are never double-counted.
    """
    _header(f"2. NEAR-DUPLICATES (text[:{PREFIX_LEN}] match, train → test, excluding exact)")
    test_prefix_map = defaultdict(set)
    for tname, recs in test_data.items():
        for r in recs:
            test_prefix_map[r["text"][:PREFIX_LEN]].add(tname)

    fuzzy_counts = Counter()
    fuzzy_examples = []
    for tname, recs in train_data.items():
        for r in recs:
            text = r["text"]
            pfx = text[:PREFIX_LEN]
            # Exclude only the splits where this text is an exact match; an
            # exact hit in one split must not hide a near-duplicate in another.
            exact_srcs = test_text_map.get(text, set())
            for tsrc in sorted(test_prefix_map.get(pfx, set()) - exact_srcs):
                fuzzy_counts[(tname, tsrc)] += 1
                if len(fuzzy_examples) < MAX_NEAR_EXAMPLES:
                    fuzzy_examples.append((tname, tsrc, pfx))

    total_fuzzy = sum(fuzzy_counts.values())
    if total_fuzzy == 0:
        print("✓ No additional near-duplicates beyond exact matches.")
    else:
        print(f"✗ {total_fuzzy} near-duplicate leaks (not counting exact)!")
        for (tr, te), c in fuzzy_counts.most_common():
            print(f" {tr} → {te}: {c}")
        print("\nExamples:")
        for tr, te, pfx in fuzzy_examples:
            print(f" [{tr}→{te}] {pfx}...")


def _report_train_duplicates(train_data):
    """Section 3: report duplicate texts within each train file and across train files."""
    _header("3. INTERNAL TRAIN DUPLICATES")
    train_text_sources = defaultdict(list)
    for tname, recs in train_data.items():
        for r in recs:
            train_text_sources[r["text"]].append(tname)

    for tname, recs in train_data.items():
        texts = [r["text"] for r in recs]
        n_unique = len(set(texts))
        n_dup = len(texts) - n_unique
        pct = n_dup / len(texts) * 100 if texts else 0
        print(f" {tname}: {n_dup} within-source dups "
              f"({len(texts)} total, {n_unique} unique, {pct:.1f}%)")

    cross_dup = 0
    cross_examples = []
    for txt, sources in train_text_sources.items():
        if len(set(sources)) > 1:
            cross_dup += 1
            if len(cross_examples) < MAX_CROSS_EXAMPLES:
                cross_examples.append((sources, txt[:SNIPPET_LEN]))
    print(f"\n Cross-source duplicates: {cross_dup} texts in multiple train files")
    for sources, txt in cross_examples:
        print(f" {dict(Counter(sources))}: {txt}")

    total_train = sum(len(r) for r in train_data.values())
    total_unique = len(train_text_sources)
    # Guard: with no train records loaded the rate is 0, not a ZeroDivisionError.
    dup_rate = (total_train - total_unique) / total_train * 100 if total_train else 0.0
    print(f"\n Total: {total_train} records, {total_unique} unique, "
          f"dup rate: {dup_rate:.1f}%")


def _report_test_overlap(test_data):
    """Section 4: report shared texts for every pair of test splits."""
    _header("4. CROSS-TEST-SET OVERLAP")
    # Build each split's text set once rather than once per pair.
    text_sets = {name: {r["text"] for r in recs} for name, recs in test_data.items()}
    for n1, n2 in combinations(text_sets, 2):
        overlap = text_sets[n1] & text_sets[n2]
        if overlap:
            print(f" ✗ {n1} ∩ {n2}: {len(overlap)} shared texts")
            for txt in sorted(overlap)[:MAX_OVERLAP_EXAMPLES]:
                print(f" {txt[:SNIPPET_LEN]}")
        else:
            print(f" ✓ {n1} ∩ {n2}: 0")


def _report_entity_memorization(train_data, test_data):
    """Section 5: per test split and class, the % of test entity forms seen in train."""
    _header("5. ENTITY MEMORIZATION RISK (% test entity forms seen in train)")
    train_entities = _entity_vocab(chain.from_iterable(train_data.values()))
    print(f"Train entity vocab: {sum(len(v) for v in train_entities.values())} unique forms, "
          f"{len(train_entities)} classes")
    for cls in sorted(train_entities):
        print(f" {cls}: {len(train_entities[cls])}")
    print()

    for tname, recs in test_data.items():
        test_ents = _entity_vocab(recs)
        if not test_ents:
            continue
        print(f" {tname}:")
        total_t = total_s = 0
        for cls in sorted(test_ents):
            forms = test_ents[cls]
            seen = forms & train_entities.get(cls, set())
            pct = len(seen) / len(forms) * 100 if forms else 0
            total_t += len(forms)
            total_s += len(seen)
            flag = " ⚠" if pct > CLASS_WARN_PCT else ""
            print(f" {cls}: {len(seen)}/{len(forms)} ({pct:.0f}%){flag}")
        overall = total_s / total_t * 100 if total_t else 0
        flag = " ⚠⚠" if overall > OVERALL_WARN_PCT else ""
        print(f" OVERALL: {total_s}/{total_t} ({overall:.0f}%){flag}")
        print()


def run_audit(train_data, test_data):
    """Print all five audit sections for already-loaded splits.

    Args:
        train_data: {split name: list of records} for the training files.
        test_data: {split name: list of records} for the held-out files.
    Either mapping may be empty; every section then reports zero counts.
    """
    test_text_map = _report_exact_leaks(train_data, test_data)
    _report_near_duplicates(train_data, test_data, test_text_map)
    _report_train_duplicates(train_data)
    _report_test_overlap(test_data)
    _report_entity_memorization(train_data, test_data)
    print(RULE)
    print("AUDIT COMPLETE")


def main():
    """Load the configured TRAIN_FILES / TEST_FILES and print the audit report."""
    print("Loading data...")
    train_data = _load_splits(TRAIN_FILES)
    test_data = _load_splits(TEST_FILES)
    run_audit(train_data, test_data)


if __name__ == "__main__":
    main()