File size: 3,682 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
"""Fix label consistency issues in Arcspan cybersecurity NER datasets."""

import json
import re
import sys
from collections import defaultdict
from pathlib import Path

# --- Rules ---
# Known surface forms, all lower-case. get_correct_label() lower-cases the
# span text before looking it up in ENTITY_TO_LABEL.
apt_groups = [
    "apt28", "apt29", "apt30", "apt32", "apt33", "apt34", "apt37", "apt38", "apt41",
    "fin7", "fin8", "turla", "lazarus", "lazarus group", "kimsuky",
    "oceanlotus", "ocean lotus", "winnti", "fancy bear", "cozy bear",
    "equation group", "sandworm", "darkhotel", "pawn storm", "sofacy",
    "carbanak group", "cobalt group", "ta505", "ta551", "muddywater", "charming kitten",
]
companies = [
    "facebook", "github", "vmware", "cisco", "apple", "google", "microsoft",
    "amazon", "oracle", "ibm", "samsung", "huawei", "intel", "adobe", "citrix",
    "fortinet", "palo alto", "palo alto networks", "fireeye", "mandiant",
    "crowdstrike", "kaspersky", "symantec", "mcafee", "trend micro", "sophos", "eset",
]
products = [
    "powershell", "windows", "linux", "macos", "ios", "android",
    "chrome", "firefox", "safari", "office", "outlook", "exchange",
    "iis", "apache", "nginx", "docker", "kubernetes",
]

# Surface form -> canonical label. Threat groups and vendors are both
# "Organization"; software/platforms are "System". Products are applied last,
# so they would win if a name ever appeared in both groups.
ENTITY_TO_LABEL = {}
ENTITY_TO_LABEL.update(dict.fromkeys(apt_groups + companies, "Organization"))
ENTITY_TO_LABEL.update(dict.fromkeys(products, "System"))

# Whole-string CVE identifier such as "CVE-2021-44228" (case-insensitive).
CVE_RE = re.compile(r"^CVE-\d{4}-\d+$", re.IGNORECASE)

def get_correct_label(surface_text):
    """Return the canonical label for *surface_text*, or ``None`` if unknown.

    The text is stripped and lower-cased, then looked up in ENTITY_TO_LABEL.
    If it is not listed there but matches CVE_RE, it is labelled
    "Vulnerability". Otherwise ``None`` is returned.
    """
    normalized = surface_text.strip().lower()
    if normalized not in ENTITY_TO_LABEL:
        return "Vulnerability" if CVE_RE.match(normalized) else None
    return ENTITY_TO_LABEL[normalized]

def fix_file(filepath):
    """Relabel known entities in one JSONL NER file, rewriting it in place.

    Every non-blank line is parsed as a JSON object. Its ``"spans"`` mapping is
    expected to have keys of the form ``"Label: entity_text"`` whose values are
    lists of offsets. When get_correct_label() returns a label for
    ``entity_text`` that differs from ``Label``, the key is rewritten as
    ``"<label>: <entity_text>"``. If the rewritten key already exists in the
    record, the offset lists are merged, skipping offsets already present.
    Keys without a colon are left unchanged.

    Args:
        filepath: Path (``str`` or ``Path``) of the JSONL file to fix.

    Returns:
        ``(total_relabeled, stats)``: the number of offsets whose label changed,
        and a dict mapping ``"Old -> New"`` transition strings to offset counts.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    path = Path(filepath)
    stats = defaultdict(int)
    total_relabeled = 0
    fixed_lines = []

    # Explicit UTF-8: records are written with ensure_ascii=False, so the
    # locale's default encoding could fail on or corrupt non-ASCII text.
    # Split on "\n" only: str.splitlines() would also split on U+2028/\x85,
    # which json.dumps(ensure_ascii=False) may leave raw inside strings.
    for line in path.read_text(encoding="utf-8").split("\n"):
        if not line.strip():
            # Skip blank lines (empty file, trailing gaps) rather than
            # crashing in json.loads("").
            continue
        rec = json.loads(line)
        spans = rec.get("spans") or {}
        new_spans = {}
        changed = False

        for span_key, offsets in spans.items():
            # Parse "Label: entity_text" on the first colon; the entity text
            # itself may contain further colons.
            old_label, sep, rest = span_key.partition(":")
            new_key = span_key
            if sep:
                entity_text = rest.strip()
                correct_label = get_correct_label(entity_text)
                if correct_label and correct_label != old_label:
                    new_key = f"{correct_label}: {entity_text}"
                    stats[f"{old_label} -> {correct_label}"] += len(offsets)
                    total_relabeled += len(offsets)
                    changed = True
            # else: malformed key (no colon) -- keep it untouched.

            if new_key in new_spans:
                # Relabeling collided with an existing key: merge, but don't
                # emit the same offset twice for one entity/label pair.
                existing = new_spans[new_key]
                existing.extend([o for o in offsets if o not in existing])
            else:
                new_spans[new_key] = list(offsets)

        if changed:
            rec["spans"] = new_spans
        fixed_lines.append(json.dumps(rec, ensure_ascii=False))

    # Write to a sibling temp file and swap it in, so an interrupted run
    # cannot leave the dataset truncated.
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(
        "".join(f"{fixed}\n" for fixed in fixed_lines), encoding="utf-8"
    )
    tmp_path.replace(path)
    return total_relabeled, dict(stats)

# Default datasets to fix. Paths given on the command line replace this list.
FILES = [
    "/home/ubuntu/alkyline/data/processed/enriched_5class_train_cleaned.jsonl",
    "/home/ubuntu/alkyline/data/processed/enriched_5class_valid_cleaned.jsonl",
    "/home/ubuntu/alkyline/data/processed/aptner_5class_train.jsonl",
    "/home/ubuntu/alkyline/data/processed/defanged_augmented.jsonl",
]

if __name__ == "__main__":
    # Allow `script.py a.jsonl b.jsonl`; with no arguments, fall back to FILES.
    targets = sys.argv[1:] or FILES
    for f in targets:
        p = Path(f)
        # is_file() rather than exists(): a directory would otherwise pass
        # the check and then crash when fix_file() tries to read it.
        if not p.is_file():
            print(f"SKIP (not found): {f}")
            continue
        total, breakdown = fix_file(f)
        print(f"\n{'='*60}")
        print(f"FILE: {p.name}")
        print(f"Total span relabelings: {total}")
        # Most frequent transitions first; ties broken by name so the report
        # is stable across runs.
        for transition, count in sorted(
            breakdown.items(), key=lambda item: (-item[1], item[0])
        ):
            print(f"  {transition}: {count}")