arcspan / scripts /cleanup_data.py
chairulridjal's picture
Add files using upload-large-folder tool
038e086 verified
#!/usr/bin/env python3
"""
Comprehensive data cleanup for Arcspan cybersecurity NER.
Fixes all P0/P1 issues from the audit. Idempotent — safe to run multiple times.
Usage: python scripts/cleanup_data.py
"""
import json
import re
import shutil
from pathlib import Path
from collections import Counter, defaultdict
from copy import deepcopy
# Directory holding the processed *.jsonl datasets; main() reads and rewrites
# the files here in place.
DATA = Path("/home/ubuntu/alkyline/data/processed")
# Snapshot directory for the original files. backup_file() never overwrites an
# existing backup, so re-runs keep the very first copy.
BACKUP = DATA / "backup"
# ─── Constants ───────────────────────────────────────────────────────────────
# Security vendor names that fix_vendor_labels() moves from the SYSTEM label to
# ORGANIZATION.
SECURITY_VENDORS = {
    "ESET", "Trend Micro", "Kaspersky", "Symantec", "SentinelOne",
    "Avast", "Fortinet", "Bitdefender", "Sophos", "Palo Alto", "McAfee",
}
# False positive "at" context patterns
# NOTE(review): not referenced anywhere in this script (fix_tool_at does its own
# context check). The final alternative matches almost every standalone "at",
# so as written the pattern cannot separate false positives from real uses.
AT_FALSE_POSITIVE_RE = re.compile(
    r'\bat\s+(least|the|a|an|this|that|once|any|all|one|times?|some|which|various)\b'
    r'|(?:aimed|look(?:ing)?|looked|arrive[ds]?|point(?:ed|ing)?|direct(?:ed|ing)?)\s+at\b'
    r'|\bat\b(?!\s+command|\s+utility|\s+scheduler|\s+job)',
    re.IGNORECASE
)
# Whole-entity match (anchored ^...$) for strings shaped like "/12/2023":
# "/", 1-2 digits, "/", 2-4 digits. Presumably the tail of a date that was
# mis-tagged as a FILEPATH; used by fix_filepath_dates().
FILEPATH_DATE_RE = re.compile(r'^/\d{1,2}/\d{2,4}$')
# HTML tags to strip (real markup, not cybersec terms like <payload>)
# Opening, closing or self-closing tags for the listed element names only, with
# optional attributes; matched case-insensitively.
HTML_TAG_RE = re.compile(
    r'</?(?:p|br|div|span|a|b|i|em|strong|ul|ol|li|td|tr|th|table|thead|tbody|'
    r'h[1-6]|img|hr|blockquote|pre|code|dl|dt|dd|sup|sub|font|center|'
    r'section|article|header|footer|nav|main|aside|figure|figcaption|caption|'
    r'small|big|u|s|strike|del|ins|abbr|cite|q|mark|ruby|rt|rp|wbr)'
    r'(?:\s[^>]*)?\s*/?>',
    re.IGNORECASE
)
# Also strip HTML entities
# Named entities nbsp/amp/lt/gt/quot/apos plus decimal and hex numeric
# references. clean_html_str() deletes matches outright rather than decoding
# them, so e.g. "a&nbsp;b" becomes "ab".
HTML_ENTITY_RE = re.compile(r'&(?:nbsp|amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);')
# 13-class label -> 5-class label, used by main() to build the 5-class enriched
# training data. A None value means spans with that label are dropped.
LABEL_MAP_5 = {
    "MALWARE": "Malware", "THREAT_ACTOR": None, "TOOL": None,
    "VULNERABILITY": "Vulnerability", "SYSTEM": "System", "ORGANIZATION": "Organization",
    "IP_ADDRESS": "Indicator", "DOMAIN": "Indicator", "URL": "Indicator",
    "HASH": "Indicator", "EMAIL": "Indicator", "CVE_ID": "Vulnerability", "FILEPATH": None,
}
# ─── Stats tracker ───────────────────────────────────────────────────────────
# Global counters incremented during main(); entries with a positive count are
# printed in the final summary.
stats = Counter()
# ─── Helpers ─────────────────────────────────────────────────────────────────
def load_jsonl(path):
    """Read a JSON Lines file into a list of records.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped.

    Args:
        path: Path (or str) of the .jsonl file.

    Returns:
        List of decoded JSON values, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    # Explicit UTF-8: save_jsonl writes with ensure_ascii=False, so files hold
    # raw non-ASCII text that the platform default codec may not decode.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records
def save_jsonl(path, records):
    """Write records to path as JSON Lines, overwriting any existing file.

    Each record is serialised with json.dumps(..., ensure_ascii=False) and
    followed by a newline. Because non-ASCII characters are written verbatim,
    the file is explicitly encoded as UTF-8 rather than with the platform
    default, which may not be able to represent them.

    Args:
        path: Destination Path (or str).
        records: Iterable of JSON-serialisable values.
    """
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
def backup_file(path):
    """Copy path into BACKUP under the same file name, at most once.

    Nothing happens if path does not exist or a backup with that name is
    already present, so repeated runs keep the first snapshot intact.
    copy2 also preserves file metadata such as modification time.
    """
    target = BACKUP / path.name
    if not path.exists() or target.exists():
        return
    shutil.copy2(path, target)
def get_span_entity(key):
    """Split a span key 'LABEL: entity' into the tuple (label, entity).

    Only the first ': ' acts as the separator, so the entity may itself contain
    ': '. A key without the separator yields (key, "").
    """
    label, _separator, entity = key.partition(": ")
    return label, entity
def spans_overlap(a_start, a_end, b_start, b_end):
    """Return True when each span starts before the other one ends.

    Offsets are half-open [start, end), so spans that merely touch
    (a_end == b_start) do not overlap.
    """
    # Disjoint exactly when one span starts at or after the other's end.
    return not (a_start >= b_end or b_start >= a_end)
# ─── Fix functions ───────────────────────────────────────────────────────────
# Phrases showing that a nearby "at" refers to the Unix at(1) job scheduler.
# The leading \b stops words such as "that"/"what" from supplying the "at",
# e.g. "that command" must not count as "at command". No trailing boundary, so
# plurals ("at jobs") and the at suite ("/usr/bin/atq") still match.
_AT_COMMAND_CONTEXT_RE = re.compile(
    r"\bat\s+(?:command|utility|scheduler|job)"
    r"|\bthe\s+at\s+tool"
    r"|\busing\s+at\s+to\s+schedule"
    r"|/usr/bin/at",
    re.IGNORECASE,
)


def fix_tool_at(rec):
    """Remove 'TOOL: at' false positives (P0-2).

    An offset of the "TOOL: at" key is kept only when one of the Unix-at
    context phrases (_AT_COMMAND_CONTEXT_RE) lies entirely within the window
    from 40 characters before the span to 40 characters after it. All other
    offsets are dropped, and the key is deleted once it has none left.
    rec["spans"] is modified in place.

    Args:
        rec: Record dict with "text" and an optional "spans" mapping.

    Returns:
        Number of offsets removed (0 when the record has no "TOOL: at" key).
    """
    text = rec["text"]
    spans = rec.get("spans", {})
    key = "TOOL: at"
    if key not in spans:
        return 0
    offsets = spans[key]
    kept = []
    for off in offsets:
        start, end = off[0], off[1]
        ctx_start = max(0, start - 40)
        ctx_end = min(len(text), end + 40)
        # pos/endpos restrict the search to the context window without slicing.
        if _AT_COMMAND_CONTEXT_RE.search(text, ctx_start, ctx_end):
            kept.append(off)
    removed = len(offsets) - len(kept)
    if removed:
        if kept:
            spans[key] = kept
        else:
            del spans[key]
    return removed
def fix_filepath_dates(rec):
    """Remove FILEPATH spans matching date patterns (P0-3).

    Every "FILEPATH: <entity>" key whose entity fully matches FILEPATH_DATE_RE
    is deleted from rec["spans"] in place.

    Returns:
        Total number of individual offsets that were dropped with those keys.
    """
    spans = rec.get("spans", {})
    date_keys = []
    for key in spans:
        label, entity = get_span_entity(key)
        if label == "FILEPATH" and FILEPATH_DATE_RE.match(entity):
            date_keys.append(key)
    dropped = sum(len(spans[key]) for key in date_keys)
    for key in date_keys:
        del spans[key]
    return dropped
def fix_overlapping_spans(rec):
    """Resolve overlapping spans, keeping the longer one (P0-4).

    Every offset of every span key is compared with the others. When two
    overlap (per spans_overlap), one of them is dropped:
      * a "MALWARE: Play..." span loses to a key containing "Google Play";
      * otherwise the shorter span loses; on equal length the span that sorts
        later (by start, then original key/offset order) loses.
    Keys left without offsets are removed. rec["spans"] is replaced only when
    something was removed.

    Args:
        rec: Record dict with an optional "spans" mapping of
            "LABEL: entity" -> list of [start, end] offsets.

    Returns:
        Number of individual offsets removed (0 when nothing overlaps).
    """
    spans = rec.get("spans", {})
    if not spans:
        return 0
    # (start, end, key, index of the offset within spans[key])
    flat = [
        (off[0], off[1], key, idx)
        for key, offsets in spans.items()
        for idx, off in enumerate(offsets)
    ]
    if len(flat) < 2:
        return 0
    # By start, then longest first, so the inner sweep can stop at the first
    # span that starts after the current one ends.
    flat.sort(key=lambda item: (item[0], -(item[1] - item[0])))
    to_remove = set()  # {(key, offset index)}
    for i, (i_start, i_end, i_key, i_idx) in enumerate(flat):
        current = (i_key, i_idx)
        if current in to_remove:
            continue
        for j in range(i + 1, len(flat)):
            j_start, j_end, j_key, j_idx = flat[j]
            if j_start >= i_end:
                break  # sorted by start: no later span can overlap span i
            other = (j_key, j_idx)
            if other in to_remove:
                continue
            if not spans_overlap(i_start, i_end, j_start, j_end):
                continue
            # Special case: MALWARE: Play overlapping SYSTEM: Google Play
            if i_key.startswith("MALWARE: Play") and "Google Play" in j_key:
                loser = current
            elif j_key.startswith("MALWARE: Play") and "Google Play" in i_key:
                loser = other
            elif i_end - i_start >= j_end - j_start:
                loser = other
            else:
                loser = current
            to_remove.add(loser)
            if loser == current:
                # Span i is gone. Comparing it with later spans would only
                # discard them for overlapping a span that no longer exists.
                break
    if not to_remove:
        return 0
    new_spans = {}
    for key, offsets in spans.items():
        kept = [off for idx, off in enumerate(offsets) if (key, idx) not in to_remove]
        if kept:
            new_spans[key] = kept
    rec["spans"] = new_spans
    return len(to_remove)
# (old label, new label) pairs for vendor relabelling: the 13-class names, then
# the 5-class names produced by LABEL_MAP_5 ("System"/"Organization").
_VENDOR_LABEL_PAIRS = (("SYSTEM", "ORGANIZATION"), ("System", "Organization"))


def fix_vendor_labels(rec, vendors=None, label_pairs=_VENDOR_LABEL_PAIRS):
    """Relabel security vendors from SYSTEM to ORGANIZATION (P1-6).

    main() applies this fix to the 13-class files (SYSTEM/ORGANIZATION) and
    to the 5-class aggregated files (System/Organization), so both label
    spellings are handled by default. Offsets moved onto an ORGANIZATION key
    that already exists are appended to it. rec["spans"] is modified in place.

    Args:
        rec: Record dict with an optional "spans" mapping of
            "LABEL: entity" keys to offset lists.
        vendors: Entity names to relabel; defaults to SECURITY_VENDORS.
        label_pairs: Iterable of (old_label, new_label) pairs to apply.

    Returns:
        Number of individual offsets relabelled.
    """
    if vendors is None:
        vendors = SECURITY_VENDORS
    spans = rec.get("spans", {})
    fixed = 0
    for old_label, new_label in label_pairs:
        for vendor in vendors:
            old_key = f"{old_label}: {vendor}"
            if old_key not in spans:
                continue
            offsets = spans.pop(old_key)
            spans.setdefault(f"{new_label}: {vendor}", []).extend(offsets)
            fixed += len(offsets)
    return fixed
def clean_html_str(s):
    """Return s with HTML markup deleted.

    HTML_TAG_RE matches are removed first, then HTML_ENTITY_RE matches are
    removed from that result. Both are deleted outright; entities are not
    decoded into characters.
    """
    for markup in (HTML_TAG_RE, HTML_ENTITY_RE):
        s = markup.sub("", s)
    return s
def _strip_matches(pattern, s):
    """Delete every match of pattern from s.

    Returns (stripped string, list of deleted (start, end) spans of s in
    ascending order). Equivalent to pattern.sub("", s) for patterns that
    cannot match the empty string, such as HTML_TAG_RE and HTML_ENTITY_RE.
    """
    deleted = [m.span() for m in pattern.finditer(s)]
    pieces = []
    prev = 0
    for start, end in deleted:
        pieces.append(s[prev:start])
        prev = end
    pieces.append(s[prev:])
    return "".join(pieces), deleted


def _map_offset(pos, deleted):
    """Translate index pos of a string into the string with deleted spans removed.

    A pos that falls inside a deleted span maps to the point where it was cut.
    """
    shift = 0
    for start, end in deleted:
        if start >= pos:
            break
        shift += min(end, pos) - start
    return pos - shift


def fix_html(rec):
    """Strip HTML tags from text and recalculate span offsets (P1-8).

    rec["text"] is replaced by clean_html_str(text), and every span offset is
    carried through the two deletion passes clean_html_str performs (tags,
    then entities). Repeated mentions of the same entity therefore keep their
    own positions. If the mapped slice does not equal the cleaned entity text
    (e.g. a span boundary fell inside a tag), the first occurrence in the
    cleaned text is used instead, matched case-sensitively and then
    case-insensitively. Span keys are cleaned the same way. Offsets that
    cannot be located, entities that clean to whitespace, and keys left with
    no offsets are dropped.

    Args:
        rec: Record dict with "text" and an optional "spans" mapping; updated
            in place.

    Returns:
        1 if the text changed, else 0.
    """
    text = rec["text"]
    if not HTML_TAG_RE.search(text) and not HTML_ENTITY_RE.search(text):
        return 0
    cleaned = clean_html_str(text)
    if cleaned == text:
        return 0
    # Replay clean_html_str's two passes, recording what each one deletes.
    no_tags, tag_spans = _strip_matches(HTML_TAG_RE, text)
    staged, entity_spans = _strip_matches(HTML_ENTITY_RE, no_tags)
    mappable = staged == cleaned

    spans = rec.get("spans", {})
    new_spans = {}
    for key, offsets in spans.items():
        label, entity = get_span_entity(key)
        # Clean the entity in the key too
        clean_entity = clean_html_str(entity)
        if not clean_entity.strip():
            continue
        clean_key = f"{label}: {clean_entity}" if clean_entity != entity else key
        new_offsets = []
        for off in offsets:
            ce = clean_html_str(text[off[0]:off[1]])
            if not ce.strip():
                continue
            if mappable:
                new_start = _map_offset(_map_offset(off[0], tag_spans), entity_spans)
                if new_start >= 0 and cleaned[new_start:new_start + len(ce)] == ce:
                    new_offsets.append([new_start, new_start + len(ce)])
                    continue
            # Fallback: first occurrence in the cleaned text.
            idx = cleaned.find(ce)
            if idx == -1:
                idx = cleaned.lower().find(ce.lower())
            if idx != -1:
                new_offsets.append([idx, idx + len(ce)])
        if new_offsets:
            new_spans.setdefault(clean_key, []).extend(new_offsets)
    rec["text"] = cleaned
    rec["spans"] = new_spans
    return 1
def fix_dirty_span_keys(rec):
    """Clean HTML remnants out of span keys after the text has been stripped.

    Keys whose entity is unchanged by clean_html_str are copied through as-is,
    and keys whose entity cleans to whitespace are dropped. For the remaining
    keys, the entity is re-keyed to its cleaned form, and each offset is:
      * kept if the text slice already equals the cleaned entity;
      * otherwise moved to the first occurrence of the cleaned entity within
        10 characters either side of the old span;
      * otherwise dropped.
    rec["spans"] is always replaced by the rebuilt mapping.

    Returns:
        Number of offsets that were moved to a new position.
    """
    text = rec["text"]
    rebuilt = {}
    relocated = 0
    for key, offsets in rec.get("spans", {}).items():
        label, entity = get_span_entity(key)
        clean_entity = clean_html_str(entity)
        if not clean_entity.strip():
            continue
        if clean_entity == entity:
            # Nothing to clean in this key: carry it over untouched.
            rebuilt.setdefault(key, []).extend(offsets)
            continue
        repaired = []
        for off in offsets:
            begin, finish = off[0], off[1]
            if text[begin:finish] == clean_entity:
                repaired.append(off)
                continue
            hit = text.find(clean_entity, max(0, begin - 10), min(len(text), finish + 10))
            if hit != -1:
                repaired.append([hit, hit + len(clean_entity)])
                relocated += 1
        if repaired:
            rebuilt.setdefault(f"{label}: {clean_entity}", []).extend(repaired)
    rec["spans"] = rebuilt
    return relocated
def verify_offsets(rec):
    """Check that every span offset points at its entity in rec["text"].

    Offsets that are negative, past the end of the text, or empty/inverted are
    reported as out of bounds. In-bounds slices that differ from the entity are
    reported only if they still differ after stripping whitespace and
    lower-casing both sides.

    Returns:
        List of human-readable error strings (empty when all offsets check out).
    """
    text = rec.get("text", "")
    n = len(text)
    problems = []
    for key, offsets in rec.get("spans", {}).items():
        entity = get_span_entity(key)[1]
        for off in offsets:
            lo, hi = off[0], off[1]
            if lo < 0 or hi > n or lo >= hi:
                problems.append(f"{key}: [{lo},{hi}] out of bounds (len={n})")
                continue
            actual = text[lo:hi]
            # Allow minor mismatches (whitespace, case)
            if actual != entity and actual.strip().lower() != entity.strip().lower():
                problems.append(f"{key}: expected '{entity}' got '{actual}' at [{lo},{hi}]")
    return problems
def dedup_offsets(rec):
    """Remove duplicate offsets within each span key, in place.

    Offsets are compared by their (start, end) values. The first occurrence
    of each pair is kept, in original order. Returns None.
    """
    spans = rec.get("spans", {})
    for key, offsets in spans.items():
        first_seen = {}
        for off in offsets:
            first_seen.setdefault((off[0], off[1]), off)
        # Reassigning an existing key does not change the dict's size, so this
        # is safe while iterating.
        spans[key] = list(first_seen.values())
# ─── Main cleanup pipeline ──────────────────────────────────────────────────
# Texts shorter than this many characters are dropped (P1-7).
_MIN_TEXT_LEN = 20
_VARIANTS = ("13class", "5class")
# Cross-split leakage priority (P0-1): a text kept in an earlier split is
# removed from every later one.
_SPLIT_PRIORITY = ("test", "valid", "train")
# P1-5a: older LLM file -> (v2 text set it is deduplicated against, stats key).
_V2_DEDUP_TARGETS = {
    "llm_annotated_mitre.jsonl": ("mitre", "mitre_deduped"),
    "llm_annotated_apt.jsonl": ("mitre", "apt_deduped_vs_mitre"),
    "llm_annotated_nvd.jsonl": ("nvd", "nvd_deduped"),
}


def _print_section(title):
    """Print a phase header framed by horizontal rules."""
    print("\n" + "─" * 70)
    print(title)
    print("─" * 70)


def _llm_files():
    """Return the sorted llm_annotated_* files followed by the sorted llm_generated_* files."""
    return sorted(DATA.glob("llm_annotated_*.jsonl")) + sorted(DATA.glob("llm_generated_*.jsonl"))


def _load_v2_texts(fname):
    """Return the raw and HTML-stripped texts of a v2 LLM file (empty set if missing)."""
    texts = set()
    path = DATA / fname
    if path.exists():
        for rec in load_jsonl(path):
            texts.add(rec["text"])
            texts.add(clean_html_str(rec["text"]))
    return texts


def _drop_seen_texts(records, seen):
    """Keep records whose text is not in seen yet (first occurrence wins); updates seen."""
    kept = []
    for rec in records:
        if rec["text"] not in seen:
            seen.add(rec["text"])
            kept.append(rec)
    return kept


def _drop_v2_duplicates(fname, records, v2_texts):
    """P1-5a: drop records of an older LLM file whose text exists in its v2 counterpart."""
    target = _V2_DEDUP_TARGETS.get(fname)
    if target is None:
        return records
    source, stat_key = target
    kept = [r for r in records if r["text"] not in v2_texts[source]]
    stats[stat_key] += len(records) - len(kept)
    return kept


def _fix_llm_record(rec):
    """Apply every per-record fix to one LLM record in place, updating stats."""
    stats["tool_at_removed"] += fix_tool_at(rec)  # P0-2
    stats["filepath_date_removed"] += fix_filepath_dates(rec)  # P0-3
    stats["vendor_relabeled"] += fix_vendor_labels(rec)  # P1-6
    stats["html_stripped"] += fix_html(rec)  # P1-8
    fix_dirty_span_keys(rec)
    dedup_offsets(rec)
    # P0-4: overlaps are resolved last, after every transform that moves offsets.
    while True:
        n = fix_overlapping_spans(rec)
        if n == 0:
            break
        stats["overlaps_fixed"] += n


def _fix_aggregated_record(rec):
    """Apply the per-record fixes used for aggregated data, in place (no stats)."""
    fix_vendor_labels(rec)
    fix_html(rec)
    fix_filepath_dates(rec)
    fix_tool_at(rec)
    fix_dirty_span_keys(rec)
    dedup_offsets(rec)
    while fix_overlapping_spans(rec):
        pass


def _without_eval_texts(records, variant):
    """Drop records whose text occurs in the variant's aggregated valid/test splits (P0-1)."""
    held_out = set()
    for split in ("valid", "test"):
        path = DATA / f"aggregated_{variant}_{split}.jsonl"
        if path.exists():
            held_out.update(rec["text"] for rec in load_jsonl(path))
    kept = [rec for rec in records if rec["text"] not in held_out]
    stats[f"llm_leakage_removed_{variant}"] += len(records) - len(kept)
    return kept


def _to_5class(rec):
    """Return a copy of a 13-class record relabelled via LABEL_MAP_5, or None if no span survives."""
    new_spans = {}
    for key, offsets in rec["spans"].items():
        label, entity = get_span_entity(key)
        l5 = LABEL_MAP_5.get(label)
        if l5:
            new_spans.setdefault(f"{l5}: {entity}", []).extend(offsets)
    if not new_spans:
        return None
    new_rec = deepcopy(rec)
    new_rec["spans"] = new_spans
    return new_rec


def _backup_all():
    """Snapshot every top-level *.jsonl file into BACKUP (existing backups are kept)."""
    BACKUP.mkdir(exist_ok=True)
    all_files = sorted(DATA.glob("*.jsonl"))
    for f in all_files:
        backup_file(f)
    print(f"\nβœ“ Backed up {len(all_files)} files to {BACKUP}/")


def _phase1_clean_llm_files():
    """Phase 1 (P0-2,3,4 + P1-5,6,7,8): clean every LLM annotation/generation file in place."""
    _print_section("PHASE 1: Clean LLM annotation/generation files")
    llm_files = _llm_files()
    v2_texts = {
        "mitre": _load_v2_texts("llm_annotated_mitre_v2.jsonl"),
        "nvd": _load_v2_texts("llm_annotated_nvd_v2.jsonl"),
    }
    for fpath in llm_files:
        fname = fpath.name
        records = load_jsonl(fpath)
        orig_count = len(records)
        # P1-5a: remove texts that exist in v2 files (pre-fix pass).
        records = _drop_v2_duplicates(fname, records, v2_texts)
        for rec in records:
            _fix_llm_record(rec)
        # HTML stripping can create new matches against v2, so check again.
        records = _drop_v2_duplicates(fname, records, v2_texts)
        # P1-5b: exact duplicate texts within the file (after fixes).
        unique = _drop_seen_texts(records, set())
        stats[f"intra_dedup_{fname}"] += len(records) - len(unique)
        records = unique
        # P1-7: short texts.
        before = len(records)
        records = [r for r in records if len(r["text"]) >= _MIN_TEXT_LEN]
        stats["short_removed"] += before - len(records)
        # Records left with no spans.
        before = len(records)
        records = [r for r in records if r.get("spans")]
        stats["empty_spans_removed"] += before - len(records)
        save_jsonl(fpath, records)
        print(f" {fname}: {orig_count} β†’ {len(records)}")


def _phase2_clean_aggregated_files():
    """Phase 2 (P0-1,4 + P1-6,7,8): clean the aggregated splits and remove cross-split leakage.

    Per-record fixes run BEFORE the leakage check because HTML stripping can
    make distinct texts identical. Checking first would let such duplicates
    leak between train and valid/test.
    """
    _print_section("PHASE 2: Clean aggregated files & fix train/test leakage")
    agg_data = {}
    loaded_counts = {}
    for variant in _VARIANTS:
        for split in _SPLIT_PRIORITY:
            key = f"aggregated_{variant}_{split}.jsonl"
            fpath = DATA / key
            if not fpath.exists():
                continue
            records = load_jsonl(fpath)
            loaded_counts[key] = len(records)
            for rec in records:
                _fix_aggregated_record(rec)
            agg_data[key] = [r for r in records if len(r["text"]) >= _MIN_TEXT_LEN]
    # P0-1: deduplicate across splits (priority: test > valid > train).
    for variant in _VARIANTS:
        seen_texts = set()
        total_removed = 0
        for split in _SPLIT_PRIORITY:
            key = f"aggregated_{variant}_{split}.jsonl"
            if key not in agg_data:
                continue
            before = len(agg_data[key])
            agg_data[key] = _drop_seen_texts(agg_data[key], seen_texts)
            total_removed += before - len(agg_data[key])
        stats[f"leakage_removed_{variant}"] += total_removed
    for key, records in agg_data.items():
        print(f" {key}: {loaded_counts[key]} β†’ {len(records)}")
        save_jsonl(DATA / key, records)


def _phase3_regenerate_enriched_files():
    """Phase 3: rebuild enriched_* files from the cleaned aggregated and LLM data.

    Enriched train = aggregated train + LLM records, excluding LLM records
    whose text appears in that variant's valid/test split (train/test
    leakage). Enriched valid/test are copies of the aggregated splits.
    """
    _print_section("PHASE 3: Regenerate enriched files")
    llm_records = []
    for f in _llm_files():
        llm_records.extend(load_jsonl(f))
    print(f" LLM records: {len(llm_records)}")

    # Enriched 13-class train = aggregated 13-class train + LLM records.
    agg_13_train = load_jsonl(DATA / "aggregated_13class_train.jsonl")
    enriched_13_train = agg_13_train + _without_eval_texts(llm_records, "13class")
    save_jsonl(DATA / "enriched_13class_train.jsonl", enriched_13_train)
    print(f" enriched_13class_train: {len(enriched_13_train)}")

    # Enriched 5-class train = aggregated 5-class train + LLM records (mapped).
    agg_5_train = load_jsonl(DATA / "aggregated_5class_train.jsonl")
    llm_5class = []
    for rec in _without_eval_texts(llm_records, "5class"):
        mapped = _to_5class(rec)
        if mapped is not None:
            llm_5class.append(mapped)
    enriched_5_train = agg_5_train + llm_5class
    save_jsonl(DATA / "enriched_5class_train.jsonl", enriched_5_train)
    print(f" enriched_5class_train: {len(enriched_5_train)}")

    # Valid/test: copy from aggregated.
    for split in ["valid", "test"]:
        for variant in _VARIANTS:
            src = DATA / f"aggregated_{variant}_{split}.jsonl"
            dst = DATA / f"enriched_{variant}_{split}.jsonl"
            shutil.copy2(src, dst)
            with open(dst, encoding="utf-8") as fh:
                n = sum(1 for _ in fh)
            print(f" enriched_{variant}_{split}: {n}")


def _phase4_verify_offsets():
    """Phase 4: check every span offset in every top-level *.jsonl file and report errors."""
    _print_section("PHASE 4: Offset verification")
    total_checked = 0
    total_errors = 0
    # glob("*.jsonl") is non-recursive, so files under BACKUP are not checked.
    for fpath in sorted(DATA.glob("*.jsonl")):
        records = load_jsonl(fpath)
        errors_in_file = sum(len(verify_offsets(rec)) for rec in records)
        total_checked += len(records)
        if errors_in_file:
            print(f" ⚠ {fpath.name}: {errors_in_file} offset errors")
            total_errors += errors_in_file
    if total_errors == 0:
        print(f" βœ“ All {total_checked} records pass offset verification")
    else:
        print(f" ⚠ {total_errors} total offset errors across {total_checked} records")


def _print_summary():
    """Print every stats counter with a positive value, sorted by name."""
    print("\n" + "=" * 70)
    print("CLEANUP SUMMARY")
    print("=" * 70)
    for k, v in sorted(stats.items()):
        if v > 0:
            print(f" {k}: {v}")
    print("=" * 70)
    print("Done.")


def main():
    """Run the full cleanup pipeline over DATA, rewriting files in place.

    Steps: back up originals, clean LLM files (phase 1), clean aggregated
    splits and remove cross-split leakage (phase 2), rebuild enriched files
    (phase 3), verify offsets (phase 4), then print the stats summary.
    """
    print("=" * 70)
    print("ARCSPAN DATA CLEANUP")
    print("=" * 70)
    _backup_all()
    _phase1_clean_llm_files()
    _phase2_clean_aggregated_files()
    _phase3_regenerate_enriched_files()
    _phase4_verify_offsets()
    _print_summary()


if __name__ == "__main__":
    main()