File size: 2,376 Bytes
3dac39e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python3
"""Merge all LLM annotations into enriched dataset. Re-run anytime new files appear."""
import json
from pathlib import Path
from collections import Counter
import shutil

# Directory holding the aggregated datasets and the LLM annotation dumps.
DATA = Path("/home/ubuntu/alkyline/data/processed")

# LLM outputs to merge: every annotated file (sorted by name), followed by
# every generated file (sorted by name).
LLM_FILES = [
    path
    for pattern in ("llm_annotated_*.jsonl", "llm_generated_*.jsonl")
    for path in sorted(DATA.glob(pattern))
]

# 13-class label -> 5-class label. A value of None means spans carrying that
# label are not carried over into the 5-class dataset.
LABEL_MAP_5 = {
    "MALWARE": "Malware",
    "THREAT_ACTOR": None,
    "TOOL": None,
    "VULNERABILITY": "Vulnerability",
    "SYSTEM": "System",
    "ORGANIZATION": "Organization",
    # Observable artefacts all collapse into the single "Indicator" class.
    "IP_ADDRESS": "Indicator",
    "DOMAIN": "Indicator",
    "URL": "Indicator",
    "HASH": "Indicator",
    "EMAIL": "Indicator",
    "CVE_ID": "Vulnerability",
    "FILEPATH": None,
}

# Load the aggregated 13-class training lines verbatim, then collect every
# record from the LLM JSONL files and tally span offsets per label (the part
# of each span key before the first ": ").
with open(DATA / "aggregated_13class_train.jsonl", encoding="utf-8") as src:
    agg_13 = src.readlines()

llm_lines = []  # stripped, non-blank JSON lines from all LLM files, in order
totals = Counter()  # label -> number of span offsets across all LLM files

for path in LLM_FILES:
    n = 0
    with open(path, encoding="utf-8") as src:
        for line in src:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing empty line) are not JSON and
                # would make json.loads raise, aborting the whole merge.
                continue
            llm_lines.append(line)
            for key, offsets in json.loads(line)["spans"].items():
                totals[key.split(": ", 1)[0]] += len(offsets)
            n += 1
    print(f"{path.name}: {n} examples")

print(f"\nTotal LLM: {len(llm_lines)} examples, {sum(totals.values())} spans")
# most_common() orders by count, highest first; equal counts keep first-seen
# order, exactly like a stable sort on -count.
for label, count in totals.most_common():
    print(f"  {label}: {count}")

# Write the enriched 13-class training set: aggregated lines first, then the
# LLM records unchanged. Encoding is explicit because the records may hold
# non-ASCII text and must not depend on the platform's locale.
with open(DATA / "enriched_13class_train.jsonl", "w", encoding="utf-8") as out:
    # Normalise each aggregated line to exactly one trailing newline so a
    # missing final newline cannot glue it to the first LLM record.
    out.writelines(line.rstrip("\n") + "\n" for line in agg_13)
    out.writelines(line + "\n" for line in llm_lines)

# Write the enriched 5-class training set: the aggregated 5-class lines, then
# each LLM record with its 13-class span keys remapped via LABEL_MAP_5.
with open(DATA / "enriched_5class_train.jsonl", "w", encoding="utf-8") as out:
    with open(DATA / "aggregated_5class_train.jsonl", encoding="utf-8") as src:
        for line in src:
            # Normalise to exactly one trailing newline (as the 13-class
            # writer does): if the source lacks a final newline, a verbatim
            # copy would glue its last record to the first LLM record.
            out.write(line.rstrip("\n") + "\n")
    for line in llm_lines:
        rec = json.loads(line)
        new_spans = {}
        for key, offsets in rec["spans"].items():
            # Span keys look like "<LABEL>: <text>"; split once and reuse.
            parts = key.split(": ", 1)
            l5 = LABEL_MAP_5.get(parts[0])
            if l5:  # labels mapped to None (or unknown labels) are dropped
                # Several 13-class labels can map to the same 5-class label,
                # so offsets for the same text are merged under one key.
                new_spans.setdefault(f"{l5}: {parts[1]}", []).extend(offsets)
        # Records left with no spans after remapping are omitted entirely.
        if new_spans:
            rec["spans"] = new_spans
            out.write(json.dumps(rec, ensure_ascii=False) + "\n")

# Validation and test splits receive no LLM data: copy them unchanged so the
# enriched_* file set is complete for both label schemes.
for split_name in ("valid", "test"):
    copies = (
        (f"aggregated_13class_{split_name}.jsonl", f"enriched_13class_{split_name}.jsonl"),
        (f"aggregated_5class_{split_name}.jsonl", f"enriched_5class_{split_name}.jsonl"),
    )
    for src_name, dst_name in copies:
        shutil.copy(DATA / src_name, DATA / dst_name)

# Re-read both enriched training files and report their line counts. The
# files are opened in a with-block so they are closed deterministically
# rather than whenever the garbage collector reclaims the handles.
with open(DATA / "enriched_13class_train.jsonl", encoding="utf-8") as fh13, \
        open(DATA / "enriched_5class_train.jsonl", encoding="utf-8") as fh5:
    n13 = sum(1 for _ in fh13)
    n5 = sum(1 for _ in fh5)
print(f"\nEnriched 13-class: {n13} | 5-class: {n5}")