#!/usr/bin/env python3
"""
Comprehensive data cleanup for Arcspan cybersecurity NER.

Fixes all P0/P1 issues from the audit. Idempotent — safe to run multiple times.

Usage:
    python scripts/cleanup_data.py
"""

import json
import re
import shutil
from pathlib import Path
from collections import Counter, defaultdict
from copy import deepcopy

DATA = Path("/home/ubuntu/alkyline/data/processed")
BACKUP = DATA / "backup"

# ─── Constants ───────────────────────────────────────────────────────────────

SECURITY_VENDORS = {
    "ESET", "Trend Micro", "Kaspersky", "Symantec", "SentinelOne",
    "Avast", "Fortinet", "Bitdefender", "Sophos", "Palo Alto", "McAfee",
}

# False positive "at" context patterns.
# Currently not referenced by the fix functions below; kept for callers that
# may import it.
AT_FALSE_POSITIVE_RE = re.compile(
    r'\bat\s+(least|the|a|an|this|that|once|any|all|one|times?|some|which|various)\b'
    r'|(?:aimed|look(?:ing)?|looked|arrive[ds]?|point(?:ed|ing)?|direct(?:ed|ing)?)\s+at\b'
    r'|\bat\b(?!\s+command|\s+utility|\s+scheduler|\s+job)',
    re.IGNORECASE
)

# Phrases that indicate the token "at" really is the Unix `at` command.
# Word boundaries stop "format command" / "that command" from matching.
AT_COMMAND_CONTEXT_RE = re.compile(
    r'\bat (?:command|utility|scheduler|job)'
    r'|\bthe at tool'
    r'|\busing at to schedule'
    r'|/usr/bin/at\b',
    re.IGNORECASE
)

FILEPATH_DATE_RE = re.compile(r'^/\d{1,2}/\d{2,4}$')

# HTML tags to strip (real presentational markup only).
# NOTE(review): the tag allowlist is a reconstruction; confirm it against the
# audit. Tags such as <script> or <iframe> are deliberately absent on the
# assumption that security text mentions them as content, not as markup.
HTML_TAG_RE = re.compile(
    r'</?(?:a|abbr|b|blockquote|br|code|div|em|h[1-6]|hr|i|img|li|ol|p|pre|'
    r'small|span|strong|sub|sup|table|tbody|td|th|thead|tr|u|ul)'
    r'(?:\s[^>]*)?\s*/?>',
    re.IGNORECASE
)
# Also strip HTML entities
HTML_ENTITY_RE = re.compile(r'&(?:nbsp|amp|lt|gt|quot|apos|#\d+|#x[0-9a-fA-F]+);')

LABEL_MAP_5 = {
    "MALWARE": "Malware",
    "THREAT_ACTOR": None,
    "TOOL": None,
    "VULNERABILITY": "Vulnerability",
    "SYSTEM": "System",
    "ORGANIZATION": "Organization",
    "IP_ADDRESS": "Indicator",
    "DOMAIN": "Indicator",
    "URL": "Indicator",
    "HASH": "Indicator",
    "EMAIL": "Indicator",
    "CVE_ID": "Vulnerability",
    "FILEPATH": None,
}

# ─── Stats tracker ───────────────────────────────────────────────────────────

stats = Counter()


# ─── Helpers ─────────────────────────────────────────────────────────────────

def load_jsonl(path):
    """Read a JSON-lines file and return its records, skipping blank lines."""
    records = []
    # Explicit UTF-8: save_jsonl writes non-ASCII characters verbatim.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
    return records


def save_jsonl(path, records):
    """Write *records* to *path*, one JSON object per line (overwrites)."""
    # ensure_ascii=False emits raw Unicode, so the encoding must not depend
    # on the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def backup_file(path):
    """Copy *path* into BACKUP unless a backup of that name already exists."""
    if path.exists():
        dst = BACKUP / path.name
        # Never overwrite: the first backup holds the pristine data.
        if not dst.exists():
            shutil.copy2(path, dst)


def get_span_entity(key):
    """Extract (label, entity) from 'LABEL: entity'.

    A key without the ': ' separator yields (key, "").
    """
    parts = key.split(": ", 1)
    return (parts[0], parts[1]) if len(parts) == 2 else (parts[0], "")


def spans_overlap(a_start, a_end, b_start, b_end):
    """Return True when half-open ranges [a_start, a_end) and [b_start, b_end) intersect."""
    return a_start < b_end and b_start < a_end


def _strip_with_index_map(pattern, text):
    """Remove every match of *pattern* from *text*, tracking positions.

    Returns (stripped, index_map). index_map has len(text) + 1 entries and
    index_map[i] is the position in *stripped* that original position i maps
    to. Positions inside a removed match map to where that match used to
    start.
    """
    pieces = []
    index_map = []
    out_len = 0
    last = 0
    for m in pattern.finditer(text):
        start, end = m.span()
        if start == end:
            continue
        kept = start - last
        pieces.append(text[last:start])
        index_map.extend(range(out_len, out_len + kept))
        out_len += kept
        # Every character of the removed match collapses onto one position.
        index_map.extend([out_len] * (end - start))
        last = end
    pieces.append(text[last:])
    # +1 so the end-of-text position (an exclusive span end) is mappable too.
    index_map.extend(range(out_len, out_len + len(text) - last + 1))
    return "".join(pieces), index_map


# ─── Fix functions ───────────────────────────────────────────────────────────

def fix_tool_at(rec):
    """Remove 'TOOL: at' false positives (P0-2).

    An offset is kept only when the text within 40 characters either side of
    it contains a phrase that refers to the Unix `at` command. Mutates
    rec["spans"] in place and returns the number of offsets removed.
    """
    text = rec["text"]
    spans = rec.get("spans", {})
    key = "TOOL: at"
    if key not in spans:
        return 0

    kept = []
    removed = 0
    for off in spans[key]:
        start, end = off[0], off[1]
        ctx_start = max(0, start - 40)
        ctx_end = min(len(text), end + 40)
        # Search in place (pos/endpos) rather than on a slice so that \b at
        # the window start still sees the preceding character.
        if AT_COMMAND_CONTEXT_RE.search(text, ctx_start, ctx_end):
            kept.append(off)
        else:
            removed += 1

    if removed:
        if kept:
            spans[key] = kept
        else:
            del spans[key]
    return removed


def fix_filepath_dates(rec):
    """Remove FILEPATH spans matching date patterns (P0-3).

    Returns the number of offsets removed.
    """
    spans = rec.get("spans", {})
    removed = 0
    for key in list(spans):
        label, entity = get_span_entity(key)
        if label == "FILEPATH" and FILEPATH_DATE_RE.match(entity):
            removed += len(spans[key])
            del spans[key]
    return removed


def fix_overlapping_spans(rec):
    """Resolve overlapping spans (P0-4).

    The longer span of an overlapping pair wins, except that 'MALWARE: Play'
    always loses to a key containing 'Google Play'. Replaces rec["spans"]
    when anything is removed and returns the number of offsets removed.
    Callers loop until this returns 0.
    """
    spans = rec.get("spans", {})
    if not spans:
        return 0

    # Flatten all spans into a list of (start, end, key, offset_idx)
    flat = [
        (off[0], off[1], key, idx)
        for key, offsets in spans.items()
        for idx, off in enumerate(offsets)
    ]
    if len(flat) < 2:
        return 0

    # Sort by start, then by length descending
    flat.sort(key=lambda x: (x[0], -(x[1] - x[0])))

    to_remove = set()  # (key, offset_idx)
    for i, (i_start, i_end, i_key, i_idx) in enumerate(flat):
        if (i_key, i_idx) in to_remove:
            continue
        for j in range(i + 1, len(flat)):
            j_start, j_end, j_key, j_idx = flat[j]
            if j_start >= i_end:
                break  # sorted by start: no more overlaps possible
            if (j_key, j_idx) in to_remove:
                continue
            if not spans_overlap(i_start, i_end, j_start, j_end):
                continue

            # Overlap found — decide which to remove
            if i_key.startswith("MALWARE: Play") and "Google Play" in j_key:
                evict_i = True
            elif j_key.startswith("MALWARE: Play") and "Google Play" in i_key:
                evict_i = False
            else:
                evict_i = (i_end - i_start) < (j_end - j_start)

            if evict_i:
                to_remove.add((i_key, i_idx))
                # A span that has just been evicted must not evict others.
                break
            to_remove.add((j_key, j_idx))

    if not to_remove:
        return 0

    # Rebuild spans, removing flagged offsets; keys left empty are dropped.
    new_spans = {}
    for key, offsets in spans.items():
        kept = [off for idx, off in enumerate(offsets) if (key, idx) not in to_remove]
        if kept:
            new_spans[key] = kept

    rec["spans"] = new_spans
    return len(to_remove)


def fix_vendor_labels(rec):
    """Relabel security vendors from SYSTEM → ORGANIZATION (P1-6).

    Offsets are merged into any existing ORGANIZATION key for the same
    vendor. Returns the number of offsets relabelled.
    """
    spans = rec.get("spans", {})
    fixed = 0
    for vendor in SECURITY_VENDORS:
        old_key = f"SYSTEM: {vendor}"
        if old_key in spans:
            offsets = spans.pop(old_key)
            spans.setdefault(f"ORGANIZATION: {vendor}", []).extend(offsets)
            fixed += len(offsets)
    return fixed


def clean_html_str(s):
    """Strip HTML tags and entities from a string."""
    s = HTML_TAG_RE.sub("", s)
    s = HTML_ENTITY_RE.sub("", s)
    return s


def fix_html(rec):
    """Strip HTML tags from text and recalculate span offsets (P1-8).

    Each offset is translated through an exact original→cleaned index map, so
    repeated mentions of one entity keep their own positions. Spans that
    become empty or whitespace-only are dropped. Returns 1 if the record was
    changed, else 0.
    """
    text = rec["text"]
    if not HTML_TAG_RE.search(text) and not HTML_ENTITY_RE.search(text):
        return 0

    # Same two passes, in the same order, as clean_html_str.
    untagged, tag_map = _strip_with_index_map(HTML_TAG_RE, text)
    cleaned, entity_map = _strip_with_index_map(HTML_ENTITY_RE, untagged)
    if cleaned == text:
        return 0

    text_len = len(text)
    new_spans = {}
    for key, offsets in rec.get("spans", {}).items():
        label, entity = get_span_entity(key)
        # Clean the entity in the key too
        clean_entity = clean_html_str(entity)
        if not clean_entity.strip():
            continue
        clean_key = f"{label}: {clean_entity}" if clean_entity != entity else key

        new_offsets = []
        for off in offsets:
            start, end = off[0], min(off[1], text_len)
            if start < 0 or start >= end:
                continue  # corrupt offset: nothing meaningful to remap
            new_start = entity_map[tag_map[start]]
            new_end = entity_map[tag_map[end]]
            if not cleaned[new_start:new_end].strip():
                continue  # the span consisted only of markup/whitespace
            new_offsets.append([new_start, new_end])

        if new_offsets:
            new_spans.setdefault(clean_key, []).extend(new_offsets)

    rec["text"] = cleaned
    rec["spans"] = new_spans
    return 1


def fix_dirty_span_keys(rec):
    """Clean HTML remnants from span keys and fix key↔offset mismatches (post-HTML-strip).

    Keys without HTML are passed through untouched. For keys that change,
    each offset is kept if the text there equals the cleaned entity;
    otherwise the entity is searched for within 10 characters of the offset
    and the offset is dropped if it is not found. Returns the number of
    offsets repaired by that search.
    """
    text = rec["text"]
    spans = rec.get("spans", {})
    new_spans = {}
    fixed = 0

    for key, offsets in spans.items():
        label, entity = get_span_entity(key)
        clean_entity = clean_html_str(entity)
        if not clean_entity.strip():
            continue

        # Only remap if HTML was actually removed from the entity
        if clean_entity == entity:
            new_spans.setdefault(key, []).extend(offsets)
            continue

        clean_key = f"{label}: {clean_entity}"
        new_offsets = []
        for off in offsets:
            if text[off[0]:off[1]] == clean_entity:
                new_offsets.append(off)
                continue
            # Try to find entity near the offset
            search_start = max(0, off[0] - 10)
            search_end = min(len(text), off[1] + 10)
            idx = text[search_start:search_end].find(clean_entity)
            if idx != -1:
                abs_start = search_start + idx
                new_offsets.append([abs_start, abs_start + len(clean_entity)])
                fixed += 1

        if new_offsets:
            new_spans.setdefault(clean_key, []).extend(new_offsets)

    rec["spans"] = new_spans
    return fixed


def verify_offsets(rec):
    """Return list of offset errors.

    An offset is an error when it is out of bounds or empty, or when the text
    it covers differs from the key's entity by more than case and
    surrounding whitespace.
    """
    text = rec.get("text", "")
    errors = []
    for key, offsets in rec.get("spans", {}).items():
        _, entity = get_span_entity(key)
        for off in offsets:
            if off[0] < 0 or off[1] > len(text) or off[0] >= off[1]:
                errors.append(f"{key}: [{off[0]},{off[1]}] out of bounds (len={len(text)})")
                continue
            actual = text[off[0]:off[1]]
            # Allow minor mismatches (whitespace, case)
            if actual != entity and actual.strip().lower() != entity.strip().lower():
                errors.append(f"{key}: expected '{entity}' got '{actual}' at [{off[0]},{off[1]}]")
    return errors


def dedup_offsets(rec):
    """Remove duplicate offsets within each span key (first occurrence wins)."""
    spans = rec.get("spans", {})
    for key in spans:
        seen = set()
        unique = []
        for off in spans[key]:
            t = (off[0], off[1])
            if t not in seen:
                seen.add(t)
                unique.append(off)
        spans[key] = unique


def _apply_record_fixes(rec):
    """Run every per-record fix on *rec* in place; return a Counter of what changed."""
    counts = Counter()
    # P0-2: Remove TOOL: at false positives
    counts["tool_at_removed"] += fix_tool_at(rec)
    # P0-3: Remove FILEPATH date false positives
    counts["filepath_date_removed"] += fix_filepath_dates(rec)
    # P1-6: Relabel security vendors
    counts["vendor_relabeled"] += fix_vendor_labels(rec)
    # P1-8: Strip HTML
    counts["html_stripped"] += fix_html(rec)
    # Post-fix: clean dirty span keys (HTML remnants in keys)
    fix_dirty_span_keys(rec)
    dedup_offsets(rec)
    # P0-4: Fix overlapping spans LAST (after all transforms); one pass can
    # expose no new overlaps but may leave some, so repeat until stable.
    while True:
        n = fix_overlapping_spans(rec)
        if n == 0:
            break
        counts["overlaps_fixed"] += n
    return counts


def _load_reference_texts(path):
    """Return raw and HTML-stripped texts of the records in *path* (empty set if missing)."""
    texts = set()
    if path.exists():
        for rec in load_jsonl(path):
            texts.add(rec["text"])
            texts.add(clean_html_str(rec["text"]))
    return texts


def _dedup_by_text(records, seen_texts):
    """Return records whose text is not yet in *seen_texts*, adding each kept text to it."""
    kept = []
    for rec in records:
        if rec["text"] not in seen_texts:
            seen_texts.add(rec["text"])
            kept.append(rec)
    return kept


# ─── Main cleanup pipeline ──────────────────────────────────────────────────

def main():
    """Run the backup, LLM-file cleanup, aggregated cleanup, enrichment and verification phases."""
    print("=" * 70)
    print("ARCSPAN DATA CLEANUP")
    print("=" * 70)

    # ── Backup ───────────────────────────────────────────────────────────
    BACKUP.mkdir(exist_ok=True)
    all_files = sorted(DATA.glob("*.jsonl"))
    for f in all_files:
        backup_file(f)
    print(f"\n✓ Backed up {len(all_files)} files to {BACKUP}/")

    # ── Phase 1: Clean LLM files (P0-2,3,4 + P1-5,6,7,8) ───────────────
    print("\n" + "─" * 70)
    print("PHASE 1: Clean LLM annotation/generation files")
    print("─" * 70)

    llm_files = sorted(DATA.glob("llm_annotated_*.jsonl")) + sorted(DATA.glob("llm_generated_*.jsonl"))

    # P1-5: Deduplicate LLM files
    # Load mitre_v2 and nvd_v2 texts for dedup
    mitre_v2_texts = _load_reference_texts(DATA / "llm_annotated_mitre_v2.jsonl")
    nvd_v2_texts = _load_reference_texts(DATA / "llm_annotated_nvd_v2.jsonl")
    # file name → (texts that supersede it, stats key)
    v2_dedup = {
        "llm_annotated_mitre.jsonl": (mitre_v2_texts, "mitre_deduped"),
        "llm_annotated_apt.jsonl": (mitre_v2_texts, "apt_deduped_vs_mitre"),
        "llm_annotated_nvd.jsonl": (nvd_v2_texts, "nvd_deduped"),
    }

    for fpath in llm_files:
        records = load_jsonl(fpath)
        orig_count = len(records)
        fname = fpath.name
        known_texts, dedup_stat = v2_dedup.get(fname, (None, None))

        # P1-5a: Remove texts that exist in v2 files (pre-fix pass)
        if known_texts is not None:
            before = len(records)
            records = [r for r in records if r["text"] not in known_texts]
            stats[dedup_stat] += before - len(records)

        # Apply per-record fixes BEFORE dedup (HTML strip can create new dupes)
        for rec in records:
            stats.update(_apply_record_fixes(rec))

        # P1-5a (post-fix): re-check against v2 texts after HTML strip
        if known_texts is not None:
            before = len(records)
            records = [r for r in records if r["text"] not in known_texts]
            stats[dedup_stat] += before - len(records)

        # P1-5b: Remove exact duplicate texts within file (after fixes)
        deduped = _dedup_by_text(records, set())
        stats[f"intra_dedup_{fname}"] += len(records) - len(deduped)
        records = deduped

        # P1-7: Remove short texts
        before = len(records)
        records = [r for r in records if len(r["text"]) >= 20]
        stats["short_removed"] += before - len(records)

        # Remove records with no spans
        before = len(records)
        records = [r for r in records if r.get("spans")]
        stats["empty_spans_removed"] += before - len(records)

        save_jsonl(fpath, records)
        print(f"  {fname}: {orig_count} → {len(records)}")

    # ── Phase 2: Clean aggregated files (P0-1,4 + P1-6,7,8) ────────────
    print("\n" + "─" * 70)
    print("PHASE 2: Clean aggregated files & fix train/test leakage")
    print("─" * 70)

    variants = ["13class", "5class"]
    splits = ["test", "valid", "train"]

    # Load all aggregated files
    agg_data = {}
    for variant in variants:
        for split in splits:
            key = f"aggregated_{variant}_{split}.jsonl"
            fpath = DATA / key
            if fpath.exists():
                agg_data[key] = load_jsonl(fpath)
    orig_counts = {key: len(records) for key, records in agg_data.items()}

    # Apply per-record fixes FIRST: stripping HTML can make two texts equal,
    # and such pairs must be caught by the leakage dedup below.
    for records in agg_data.values():
        for rec in records:
            _apply_record_fixes(rec)

    # P0-1: Deduplicate across splits (priority: test > valid > train)
    for variant in variants:
        seen_texts = set()
        total_removed = 0
        for split in splits:
            key = f"aggregated_{variant}_{split}.jsonl"
            if key not in agg_data:
                continue
            deduped = _dedup_by_text(agg_data[key], seen_texts)
            total_removed += len(agg_data[key]) - len(deduped)
            agg_data[key] = deduped
        stats[f"leakage_removed_{variant}"] += total_removed

    # Remove short texts
    for key in list(agg_data):
        records = [r for r in agg_data[key] if len(r["text"]) >= 20]
        agg_data[key] = records
        print(f"  {key}: {orig_counts[key]} → {len(records)}")

    # Save aggregated files
    for key, records in agg_data.items():
        save_jsonl(DATA / key, records)

    # ── Phase 3: Regenerate enriched files ──────────────────────────────
    print("\n" + "─" * 70)
    print("PHASE 3: Regenerate enriched files")
    print("─" * 70)

    # Reload cleaned LLM files
    llm_records = []
    for f in sorted(DATA.glob("llm_annotated_*.jsonl")) + sorted(DATA.glob("llm_generated_*.jsonl")):
        llm_records.extend(load_jsonl(f))
    print(f"  LLM records: {len(llm_records)}")

    # Enriched 13-class train = aggregated 13-class train + all LLM
    agg_13_train = load_jsonl(DATA / "aggregated_13class_train.jsonl")
    enriched_13_train = agg_13_train + llm_records
    save_jsonl(DATA / "enriched_13class_train.jsonl", enriched_13_train)
    print(f"  enriched_13class_train: {len(enriched_13_train)}")

    # Enriched 5-class train = aggregated 5-class train + LLM (mapped)
    agg_5_train = load_jsonl(DATA / "aggregated_5class_train.jsonl")
    llm_5class = []
    for rec in llm_records:
        new_spans = {}
        for key, offsets in rec["spans"].items():
            label, entity = get_span_entity(key)
            l5 = LABEL_MAP_5.get(label)
            # Labels mapped to None (or unknown labels) have no 5-class equivalent.
            if l5:
                new_spans.setdefault(f"{l5}: {entity}", []).extend(offsets)
        if new_spans:
            new_rec = deepcopy(rec)
            new_rec["spans"] = new_spans
            llm_5class.append(new_rec)
    enriched_5_train = agg_5_train + llm_5class
    save_jsonl(DATA / "enriched_5class_train.jsonl", enriched_5_train)
    print(f"  enriched_5class_train: {len(enriched_5_train)}")

    # Valid/test: copy from aggregated
    for split in ["valid", "test"]:
        for variant in variants:
            src = DATA / f"aggregated_{variant}_{split}.jsonl"
            dst = DATA / f"enriched_{variant}_{split}.jsonl"
            shutil.copy2(src, dst)
            with open(dst, encoding="utf-8") as fh:
                n = sum(1 for _ in fh)
            print(f"  enriched_{variant}_{split}: {n}")

    # ── Phase 4: Verification ───────────────────────────────────────────
    print("\n" + "─" * 70)
    print("PHASE 4: Offset verification")
    print("─" * 70)

    total_checked = 0
    total_errors = 0
    for fpath in sorted(DATA.glob("*.jsonl")):
        if fpath.parent.name == "backup":
            continue
        records = load_jsonl(fpath)
        errors_in_file = sum(len(verify_offsets(rec)) for rec in records)
        total_checked += len(records)
        if errors_in_file:
            print(f"  ⚠ {fpath.name}: {errors_in_file} offset errors")
            total_errors += errors_in_file

    if total_errors == 0:
        print(f"  ✓ All {total_checked} records pass offset verification")
    else:
        print(f"  ⚠ {total_errors} total offset errors across {total_checked} records")

    # ── Summary ─────────────────────────────────────────────────────────
    print("\n" + "=" * 70)
    print("CLEANUP SUMMARY")
    print("=" * 70)
    for k, v in sorted(stats.items()):
        if v > 0:
            print(f"  {k}: {v}")
    print("=" * 70)
    print("Done.")


if __name__ == "__main__":
    main()