| |
"""Annotate MITRE ATT&CK descriptions for cybersecurity NER training.

Reads intrusion-set, malware, and tool STIX JSON files, strips markdown,
and produces OPF-format JSONL with exact character offsets.
"""
|
|
| import json, os, re, glob |
| from pathlib import Path |
| from collections import defaultdict |
|
|
# Directory whose sub-directories hold the STIX JSON bundles that get loaded.
BASE = Path(
    "/home/ubuntu/alkyline/data/raw/mitre-cti/enterprise-attack"
)
# JSONL file that receives the annotated records.
OUT = Path(
    "/home/ubuntu/alkyline/data/processed/llm_annotated_mitre.jsonl"
)
|
|
| |
|
|
def load_stix_objects(subdir, base=None):
    """Load every active STIX object found under one bundle sub-directory.

    Args:
        subdir: Directory name (relative to ``base``) whose ``*.json``
            bundles are read.
        base: Root directory that ``subdir`` is resolved against. Defaults to
            the module-level ``BASE`` when omitted.

    Returns:
        A list with the dicts from each bundle's ``"objects"`` array, skipping
        any object whose ``revoked`` or ``x_mitre_deprecated`` field is truthy.
        Files are visited in sorted path order.
    """
    root = BASE if base is None else Path(base)
    results = []
    # glob returns paths in arbitrary, filesystem-dependent order; sorting
    # keeps the object order (and the index-based record ids built from it)
    # reproducible between runs and machines.
    for fp in sorted(glob.glob(str(root / subdir / "*.json"))):
        # JSON interchange is UTF-8; do not depend on the locale's default.
        with open(fp, encoding="utf-8") as f:
            data = json.load(f)
        results.extend(
            obj
            for obj in data.get("objects", [])
            if not (obj.get("revoked") or obj.get("x_mitre_deprecated"))
        )
    return results
|
|
print("Loading STIX data...")
# One list of active STIX objects per sub-directory, loaded in this order.
intrusion_sets, malware_objs, tool_objs = (
    load_stix_objects(kind) for kind in ("intrusion-set", "malware", "tool")
)
|
|
| |
# Gazetteer mapping each known entity name or alias to its NER label.
# Groups are inserted in the order listed; a name that appears in more than
# one group ends up with the label of the last group that contains it.
entity_db = {}

for stix_group, ner_label in (
    (intrusion_sets, "THREAT_ACTOR"),
    (malware_objs, "MALWARE"),
    (tool_objs, "TOOL"),
):
    for obj in stix_group:
        entity_db[obj["name"]] = ner_label
        # Use "x_mitre_aliases" when the key is present, else "aliases".
        alias_names = obj.get("x_mitre_aliases", obj.get("aliases", []))
        for alias in alias_names:
            entity_db[alias] = ner_label
|
|
| |
# Organization names that are tagged ORGANIZATION when found in a description.
KNOWN_ORGS = {
    # Vendors, agencies and international bodies.
    "FireEye",
    "Mandiant",
    "CrowdStrike",
    "Kaspersky",
    "Symantec",
    "Microsoft",
    "Palo Alto Networks",
    "ESET",
    "Trend Micro",
    "Cisco",
    "Recorded Future",
    "Proofpoint",
    "SentinelOne",
    "Carbon Black",
    "McAfee",
    "Secureworks",
    "Cylance",
    "Fortinet",
    "Sophos",
    "F-Secure",
    "Bitdefender",
    "Avast",
    "Malwarebytes",
    "Check Point",
    "Zscaler",
    "Unit 42",
    "Dragos",
    "CISA",
    "FBI",
    "NSA",
    "GCHQ",
    "NCSC",
    "DOJ",
    "Accenture",
    "NCC Group",
    "Volexity",
    "Google",
    "Alphabet",
    "Samsung",
    "Apple",
    "Red Canary",
    "Elastic",
    "Splunk",
    "IBM",
    "Dell",
    "Cisco Talos",
    "Group-IB",
    "ThreatConnect",
    "DHS",
    "US-CERT",
    "CERT-UA",
    "NATO",
    "European Union",
    "United Nations",
    # Financial institutions and networks.
    "SWIFT",
    "Bank of Bangladesh",
    "Bancomext",
    "Banco de Chile",
    # State intelligence and security services.
    "Reconnaissance General Bureau",
    "Ministry of State Security",
    "General Staff Main Intelligence Directorate",
    "GRU",
    "Federal Security Service",
    "FSB",
}
|
|
| |
# Platform, product, protocol and utility names that are tagged SYSTEM when
# found in a description.
KNOWN_SYSTEMS = {
    # Operating systems.
    "Windows",
    "Linux",
    "macOS",
    "Android",
    "iOS",
    # Products and services.
    "Microsoft Office",
    "Microsoft Exchange",
    "Microsoft Outlook",
    "Active Directory",
    "PowerShell",
    "Windows Management Instrumentation",
    "IIS",
    "Apache",
    "Nginx",
    "Docker",
    "Kubernetes",
    "VMware",
    "VirtualBox",
    "Hyper-V",
    "Citrix",
    "SolarWinds",
    "SolarWinds Orion",
    "Outlook Web Access",
    "OWA",
    "Internet Explorer",
    "Chrome",
    "Firefox",
    "SharePoint",
    "OneDrive",
    "Dropbox",
    "Google Drive",
    "GitHub",
    "GitLab",
    "Jira",
    "Confluence",
    "Telegram",
    "WhatsApp",
    "Signal",
    # Protocols and remote-access technologies.
    "VPN",
    "RDP",
    "SSH",
    "SMB",
    "DNS",
    "HTTP",
    "HTTPS",
    "FTP",
    # Languages.
    "Visual Basic",
    "JavaScript",
    "Python",
    "Perl",
    "Lua",
    # Built-in command-line utilities.
    "cmd.exe",
    "cmd",
    "rundll32",
    "regsvr32",
    "mshta",
    "certutil",
    "schtasks",
    "at.exe",
    "wmic",
    "bitsadmin",
    "msiexec",
    # Frameworks and component models.
    ".NET",
    "WMI",
    "COM",
    "DCOM",
}
|
|
# Report the gazetteer sizes before annotation starts.
n_entities, n_orgs, n_systems = len(entity_db), len(KNOWN_ORGS), len(KNOWN_SYSTEMS)
print(
    f"Knowledge base: {n_entities} malware/tool/actor names, "
    f"{n_orgs} orgs, {n_systems} systems"
)
|
|
| |
|
|
def strip_markdown(text):
    """Reduce an ATT&CK description to plain text.

    Applies, in order: removal of ``(Citation: ...)`` references, replacement
    of markdown links ``[text](url)`` by their link text, removal of
    ``<code>`` / ``</code>`` tags, collapsing of runs of spaces into one, and
    trimming of leading/trailing whitespace.

    Args:
        text: Raw description string. ``None`` and ``""`` are accepted.

    Returns:
        The cleaned text. The result is always a ``str`` (``""`` for empty
        input); no offset mapping back to the raw text is produced.
    """
    if not text:
        # This path used to return the tuple ("", []) while every other path
        # returned a plain string; callers take len() of the result.
        return ""

    # Citations go first so their parentheses cannot be mistaken for the
    # "(url)" half of a markdown link.
    text = re.sub(r'\(Citation:[^)]+\)', '', text)
    # [link text](url) -> link text
    text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
    # Drop inline <code> / </code> tags but keep what they wrap.
    text = re.sub(r'</?code>', '', text)
    # The removals above can leave doubled spaces behind.
    text = re.sub(r' +', ' ', text)
    return text.strip()
|
|
| |
|
|
# CVE identifiers such as CVE-2017-0199 (four-digit year, 4-7 digit sequence).
CVE_RE = re.compile(r'CVE-\d{4}-\d{4,7}')

# File paths in three shapes: drive-letter paths (C:\...), paths under a fixed
# list of Unix root directories, and %VARIABLE%-prefixed paths.
FILEPATH_RE = re.compile(
    r'(?:[A-Z]:\\[\w\\._-]+'
    r'|/(?:usr|etc|tmp|var|opt|home|proc|dev|bin|sbin)[\w/._-]+'
    r'|%[A-Z_]+%[\w\\._-]*)'
)
|
|
def find_all_occurrences(text, pattern):
    """Return ``(start, end)`` offsets of every accepted match of *pattern*.

    The scan is a case-sensitive, left-to-right literal search and accepted
    matches never overlap. A pattern shorter than four characters is accepted
    only when neither neighbouring character is alphanumeric; a longer pattern
    is accepted wherever it occurs, including inside a longer word.

    Args:
        text: String to search.
        pattern: Literal string to look for.

    Returns:
        A list of ``(start, end)`` tuples in ascending order, each satisfying
        ``text[start:end] == pattern``. Empty when *pattern* is empty.
    """
    if not pattern:
        # str.find("") matches at the current position without advancing, so
        # the scan below would never terminate.
        return []

    spans = []
    plen = len(pattern)
    text_len = len(text)
    # Short names are too ambiguous to accept as fragments of other words.
    needs_boundary = plen < 4
    start = 0
    while True:
        idx = text.find(pattern, start)
        if idx == -1:
            break
        end = idx + plen
        if needs_boundary:
            before_ok = idx == 0 or not text[idx - 1].isalnum()
            after_ok = end >= text_len or not text[end].isalnum()
            if not (before_ok and after_ok):
                # Rejected hit: resume one character further on.
                start = idx + 1
                continue
        spans.append((idx, end))
        # Resume after the match so accepted spans cannot overlap.
        start = end
    return spans
|
|
def annotate_text(text, self_name=None, self_label=None):
    """Locate CVE ids, file paths and gazetteer entities inside *text*.

    Candidates are tried in priority order: ``CVE_RE`` matches, then
    ``FILEPATH_RE`` matches, then every gazetteer name from longest to
    shortest. A candidate is dropped when it overlaps a span accepted earlier.

    Args:
        text: Cleaned description to annotate.
        self_name: Accepted but not used by this function.
        self_label: Accepted but not used by this function.

    Returns:
        A dict mapping ``"LABEL: entity text"`` to a list of ``[start, end]``
        character offsets into *text*.

    Raises:
        AssertionError: If a candidate's offsets do not slice back to its text.
    """
    found = defaultdict(list)
    # Character positions already claimed by an accepted span.
    taken = set()

    def add_span(label, entity_text, start, end):
        # Guard against offset drift before recording anything.
        assert text[start:end] == entity_text, \
            f"Offset mismatch: text[{start}:{end}]={text[start:end]!r} != {entity_text!r}"
        if not taken.isdisjoint(range(start, end)):
            return
        taken.update(range(start, end))
        found[f"{label}: {entity_text}"].append([start, end])

    # Regex-driven entities take precedence over gazetteer names.
    for regex, regex_label in ((CVE_RE, "CVE_ID"), (FILEPATH_RE, "FILEPATH")):
        for match in regex.finditer(text):
            add_span(regex_label, match.group(), match.start(), match.end())

    # STIX-derived names shorter than three characters are left out.
    candidates = [
        (name, label) for name, label in entity_db.items() if len(name) >= 3
    ]
    candidates.extend((name, "ORGANIZATION") for name in KNOWN_ORGS)
    candidates.extend((name, "SYSTEM") for name in KNOWN_SYSTEMS)

    # Longest first, so a multi-word name claims its characters before any
    # shorter name contained in it. The sort is stable, so equal-length names
    # keep the order in which they were collected.
    candidates.sort(key=lambda pair: len(pair[0]), reverse=True)

    for name, label in candidates:
        for start, end in find_all_occurrences(text, name):
            add_span(label, name, start, end)

    return dict(found)
|
|
| |
|
|
def process_objects(objects, obj_type, id_prefix, max_count=None):
    """Turn STIX objects into annotated, JSONL-ready records.

    An object is skipped when its description is missing or shorter than 50
    characters, when the cleaned description is shorter than 30 characters,
    or when no entity span is found in it.

    Args:
        objects: Sequence of STIX object dicts.
        obj_type: ``"intrusion-set"``, ``"malware"`` or ``"tool"``.
        id_prefix: Short tag embedded in each record id.
        max_count: When not ``None``, only objects at index ``< max_count``
            are considered (skipped objects count towards the limit).
            ``0`` therefore yields no records.

    Returns:
        A list of dicts with the keys ``"text"``, ``"spans"`` and ``"info"``.

    Raises:
        KeyError: If an object that passes the length filters has no
            ``"name"``, or *obj_type* is not one of the three known types.
    """
    type_labels = {
        "intrusion-set": "THREAT_ACTOR",
        "malware": "MALWARE",
        "tool": "TOOL",
    }
    records = []
    for i, obj in enumerate(objects):
        # Compare against None explicitly: a plain truthiness test treated
        # max_count=0 as "no limit" and processed every object.
        if max_count is not None and i >= max_count:
            break

        desc = obj.get("description", "")
        if not desc or len(desc) < 50:
            continue

        clean = strip_markdown(desc)
        if len(clean) < 30:
            continue

        name = obj["name"]
        label = type_labels[obj_type]

        spans = annotate_text(clean, self_name=name, self_label=label)
        if not spans:
            continue

        # ATT&CK id = external_id of the first "mitre-attack" reference.
        mitre_id = next(
            (
                ref.get("external_id", "")
                for ref in obj.get("external_references", [])
                if ref.get("source_name") == "mitre-attack"
            ),
            "",
        )

        records.append({
            "text": clean,
            "spans": spans,
            "info": {
                # The index is the object's position in *objects*, so ids stay
                # stable even when earlier objects are skipped.
                "id": f"mitre_{id_prefix}_{i:04d}",
                "source": "mitre_attack",
                "mitre_id": mitre_id,
                "name": name,
                "type": obj_type,
            },
        })
    return records
|
|
# Annotate each object type in turn and report how many records it produced.
# The arrow is written as an escape so the source stays ASCII-only; the
# literal glyph had been mis-encoded.
print("Processing intrusion-sets...")
records_is = process_objects(intrusion_sets, "intrusion-set", "is")
print(f" \u2192 {len(records_is)} records")

print("Processing malware...")
records_mw = process_objects(malware_objs, "malware", "mw")
print(f" \u2192 {len(records_mw)} records")

print("Processing tools...")
records_tl = process_objects(tool_objs, "tool", "tl")
print(f" \u2192 {len(records_tl)} records")

# Combined output keeps the intrusion-set, malware, tool order.
all_records = records_is + records_mw + records_tl
print(f"\nTotal records: {len(all_records)}")
|
|
| |
# Count every [start, end] pair across all records.
total_spans = 0
for record in all_records:
    for offsets in record["spans"].values():
        total_spans += len(offsets)
print(f"Total entity spans: {total_spans}")
|
|
| |
# Per-label span totals, printed from most to least frequent.
label_counts = defaultdict(int)
for record in all_records:
    for span_key, offsets in record["spans"].items():
        # The label is the part of the key before the first colon.
        label_counts[span_key.partition(":")[0]] += len(offsets)

print("Spans by label:")
by_count_desc = sorted(
    label_counts.items(), key=lambda item: item[1], reverse=True
)
for label_name, n_spans in by_count_desc:
    print(f" {label_name}: {n_spans}")
|
|
| |
|
|
# Write one JSON object per line, creating the output directory if needed.
OUT.parent.mkdir(parents=True, exist_ok=True)
# ensure_ascii=False emits non-ASCII characters verbatim, so the file encoding
# must be fixed to UTF-8 instead of following the locale default.
with open(OUT, "w", encoding="utf-8") as f:
    for r in all_records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print(f"\nWritten to {OUT}")
|
|
| |
|
|
# Spot-check a fixed random sample: re-slice each stored offset pair and
# compare it with the entity text recorded in the span key.
# Glyphs are written as escapes so the source stays ASCII-only; the literal
# characters had been mis-encoded, which made the pass and fail markers
# identical.
print("\n\u2500\u2500 Validation sample \u2500\u2500")
import random
random.seed(42)  # fixed seed so the same records are shown on every run
samples = random.sample(all_records, min(5, len(all_records)))
for rec in samples:
    print(f"\n[{rec['info']['name']}] ({rec['info']['type']})")
    print(f" Text: {rec['text'][:120]}...")
    # Show at most five keys per record and two occurrences per key.
    for key, positions in list(rec["spans"].items())[:5]:
        for s, e in positions[:2]:
            actual = rec["text"][s:e]
            label, entity = key.split(": ", 1)
            # Check mark on a match; cross plus the actual slice otherwise.
            ok = "\u2713" if actual == entity else f"\u2717 got={actual!r}"
            print(f" {label}: '{entity}' [{s}:{e}] {ok}")
|
|