#!/usr/bin/env python3
# Source: arcspan/scripts/fix_label_consistency.py
# Uploaded by chairulridjal ("Add files using upload-large-folder tool"), commit 3dac39e.
"""Fix label consistency issues in Arcspan cybersecurity NER datasets."""
import json
import re
import sys
from collections import defaultdict
from pathlib import Path
# --- Rules ---
# Known threat-actor / APT group names (lower-cased surface forms).
apt_groups = [
    "apt28", "apt29", "apt30", "apt32", "apt33", "apt34", "apt37", "apt38", "apt41",
    "fin7", "fin8", "turla", "lazarus", "lazarus group", "kimsuky",
    "oceanlotus", "ocean lotus", "winnti", "fancy bear", "cozy bear",
    "equation group", "sandworm", "darkhotel", "pawn storm", "sofacy",
    "carbanak group", "cobalt group", "ta505", "ta551", "muddywater", "charming kitten",
]
# Vendor / company names (lower-cased surface forms).
companies = [
    "facebook", "github", "vmware", "cisco", "apple", "google", "microsoft",
    "amazon", "oracle", "ibm", "samsung", "huawei", "intel", "adobe", "citrix",
    "fortinet", "palo alto", "palo alto networks", "fireeye", "mandiant",
    "crowdstrike", "kaspersky", "symantec", "mcafee", "trend micro", "sophos", "eset",
]
# Software products / platforms (lower-cased surface forms).
products = [
    "powershell", "windows", "linux", "macos", "ios", "android",
    "chrome", "firefox", "safari", "office", "outlook", "exchange",
    "iis", "apache", "nginx", "docker", "kubernetes",
]

# Lower-cased surface text -> canonical label. APT groups and companies are
# both "Organization"; products are "System". Insertion order follows the
# lists above.
ENTITY_TO_LABEL = {
    **dict.fromkeys(apt_groups + companies, "Organization"),
    **dict.fromkeys(products, "System"),
}

# Whole-string CVE identifier such as "CVE-2021-44228" (case-insensitive).
CVE_RE = re.compile(r"^CVE-\d{4}-\d+$", re.IGNORECASE)
def get_correct_label(surface_text):
    """Return the canonical label for an entity surface string, or None.

    The text is stripped and lower-cased, then looked up in ENTITY_TO_LABEL.
    If there is no table entry, a string that is a whole CVE identifier maps
    to "Vulnerability". Anything else yields None (no rule applies).
    """
    normalized = surface_text.strip().lower()
    table_label = ENTITY_TO_LABEL.get(normalized)
    if table_label is not None:
        return table_label
    return "Vulnerability" if CVE_RE.match(normalized) else None
def _relabel_spans(spans, labeler, stats):
    """Relabel one record's span dict; return (new_spans, changed, n_relabeled).

    Keys are expected to look like "Label: entity text". Keys without a colon
    cannot be parsed and are kept unchanged. ``stats`` is updated in place with
    "Old -> New" transition counts (one per relabeled offset).
    """
    new_spans = {}
    changed = False
    n_relabeled = 0
    for span_key, offsets in spans.items():
        raw_label, sep, raw_text = span_key.partition(":")
        # Strip both halves so "Organization : X" is not treated as a different label.
        old_label = raw_label.strip()
        entity_text = raw_text.strip()
        correct_label = labeler(entity_text) if sep else None

        if correct_label and correct_label != old_label:
            new_key = f"{correct_label}: {entity_text}"
            stats[f"{old_label} -> {correct_label}"] += len(offsets)
            n_relabeled += len(offsets)
            changed = True
        else:
            new_key = span_key

        if new_key in new_spans:
            # Two keys collapsed onto the same label: merge them, skipping
            # offsets already present so the same span is not emitted twice.
            merged = new_spans[new_key]
            for off in offsets:
                if off not in merged:
                    merged.append(off)
        else:
            new_spans[new_key] = list(offsets)
    return new_spans, changed, n_relabeled


def fix_file(filepath, labeler=None):
    """Relabel known entities in a JSONL NER file, rewriting it in place.

    Each non-blank line must be a JSON object. Its optional "spans" field maps
    "Label: entity text" keys to lists of offsets. A key whose entity text has
    a canonical label that differs from its current label is rewritten to that
    label. When the rewritten key already exists in the record, the offset
    lists are merged and exact duplicate offsets are dropped. Blank lines are
    skipped.

    Args:
        filepath: Path (str or Path) of the UTF-8 JSONL file to fix.
        labeler: Optional callable mapping entity text to a canonical label, or
            to None/"" when no rule applies. Defaults to get_correct_label.

    Returns:
        (total_relabeled, stats): total_relabeled is the number of offsets whose
        label changed. stats maps "Old -> New" transition strings to offset
        counts.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The file
            is not modified in that case.
    """
    if labeler is None:
        labeler = get_correct_label
    path = Path(filepath)
    # Explicit UTF-8. Output is written with ensure_ascii=False, so relying on
    # the platform default encoding could mangle or reject non-ASCII text.
    text = path.read_text(encoding="utf-8")

    stats = defaultdict(int)
    total_relabeled = 0
    fixed_lines = []
    for line in text.splitlines():
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        rec = json.loads(line)
        new_spans, changed, n_relabeled = _relabel_spans(
            rec.get("spans") or {}, labeler, stats
        )
        if changed:
            rec["spans"] = new_spans
        total_relabeled += n_relabeled
        fixed_lines.append(json.dumps(rec, ensure_ascii=False))

    # Write to a sibling temp file, then swap it in with a single rename, so a
    # crash mid-write cannot leave a truncated dataset behind.
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text("".join(f"{out}\n" for out in fixed_lines), encoding="utf-8")
    tmp_path.replace(path)
    return total_relabeled, dict(stats)
# Default datasets processed when no paths are given on the command line.
FILES = [
    "/home/ubuntu/alkyline/data/processed/enriched_5class_train_cleaned.jsonl",
    "/home/ubuntu/alkyline/data/processed/enriched_5class_valid_cleaned.jsonl",
    "/home/ubuntu/alkyline/data/processed/aptner_5class_train.jsonl",
    "/home/ubuntu/alkyline/data/processed/defanged_augmented.jsonl",
]
if __name__ == "__main__":
    # Usage: fix_label_consistency.py [FILE ...]
    # Explicit paths override the hard-coded FILES list; with no arguments the
    # defaults are used, as before.
    targets = sys.argv[1:] or FILES
    for f in targets:
        p = Path(f)
        if not p.exists():
            print(f"SKIP (not found): {f}")
            continue
        total, breakdown = fix_file(f)
        print(f"\n{'='*60}")
        print(f"FILE: {p.name}")
        print(f"Total span relabelings: {total}")
        # Most frequent label transitions first.
        for transition, count in sorted(breakdown.items(), key=lambda x: -x[1]):
            print(f"  {transition}: {count}")