arcspan / scripts /merge_llm_annotations.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Merge all LLM annotations into enriched dataset. Re-run anytime new files appear."""
import json
from pathlib import Path
from collections import Counter
import shutil
# Directory holding the processed JSONL datasets read and written by this script.
DATA = Path("/home/ubuntu/alkyline/data/processed")

# Every LLM-produced file: annotated files first, then generated ones,
# each group in sorted order.
LLM_FILES = [
    path
    for pattern in ("llm_annotated_*.jsonl", "llm_generated_*.jsonl")
    for path in sorted(DATA.glob(pattern))
]

# 13-class label -> 5-class label. A value of None means the label has no
# 5-class counterpart and its spans are left out of the 5-class dataset.
LABEL_MAP_5 = {
    "MALWARE": "Malware",
    "THREAT_ACTOR": None,
    "TOOL": None,
    "VULNERABILITY": "Vulnerability",
    "SYSTEM": "System",
    "ORGANIZATION": "Organization",
    "IP_ADDRESS": "Indicator",
    "DOMAIN": "Indicator",
    "URL": "Indicator",
    "HASH": "Indicator",
    "EMAIL": "Indicator",
    "CVE_ID": "Vulnerability",
    "FILEPATH": None,
}
# Raw lines (line terminators included) of the aggregated 13-class training set.
# Read inside a context manager so the file handle is closed promptly.
with open(DATA / "aggregated_13class_train.jsonl") as src:
    agg_13 = src.readlines()

# Stripped JSON lines from every LLM file, plus a per-label span tally.
llm_lines = []
totals = Counter()
for f in LLM_FILES:
    n = 0
    with open(f) as src:
        for line in src:
            line = line.strip()
            # Skip blank lines (e.g. a trailing empty line): json.loads would
            # raise on them and they are not examples.
            if not line:
                continue
            llm_lines.append(line)
            # Each record carries a "spans" mapping; the part of each key
            # before the first ": " is used as the label, and the length of
            # the value is the number of spans for that key.
            for key, offsets in json.loads(line)["spans"].items():
                totals[key.split(": ", 1)[0]] += len(offsets)
            n += 1
    print(f"{f.name}: {n} examples")

print(f"\nTotal LLM: {len(llm_lines)} examples, {sum(totals.values())} spans")
# most_common() orders by descending count and keeps first-seen order on ties,
# matching a stable sort on the negated count.
for label, count in totals.most_common():
    print(f" {label}: {count}")
# Write the enriched 13-class training set: aggregated records first, then the
# LLM records, with exactly one "\n" terminating every line.
with open(DATA / "enriched_13class_train.jsonl", "w") as out:
    out.writelines(record.rstrip("\n") + "\n" for record in agg_13)
    out.writelines(record + "\n" for record in llm_lines)
# Write the enriched 5-class training set: the aggregated 5-class records
# followed by the LLM records with their span labels remapped via LABEL_MAP_5.
with open(DATA / "enriched_5class_train.jsonl", "w") as f:
    # Close the source file deterministically, and normalise the line
    # terminator so a final aggregated line lacking "\n" cannot fuse with the
    # first LLM record appended below.
    with open(DATA / "aggregated_5class_train.jsonl") as src:
        for line in src:
            f.write(line.rstrip("\n") + "\n")
    for line in llm_lines:
        rec = json.loads(line)
        new_spans = {}
        for key, offsets in rec["spans"].items():
            # Keys are split on the first ": " into a label and the remainder;
            # split once and reuse both parts.
            parts = key.split(": ", 1)
            l5 = LABEL_MAP_5.get(parts[0])
            # Labels that are unmapped or mapped to None are dropped.
            if l5:
                # Several source labels can map to the same 5-class label, so
                # offsets for an identical remapped key are merged.
                new_spans.setdefault(f"{l5}: {parts[1]}", []).extend(offsets)
        rec["spans"] = new_spans
        # Records left without any span after remapping are not written.
        if new_spans:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
# The validation and test splits receive no LLM data: copy the aggregated
# files to their "enriched" names unchanged.
for split in ("valid", "test"):
    copies = (
        (f"aggregated_13class_{split}.jsonl", f"enriched_13class_{split}.jsonl"),
        (f"aggregated_5class_{split}.jsonl", f"enriched_5class_{split}.jsonl"),
    )
    for src_name, dst_name in copies:
        shutil.copy(DATA / src_name, DATA / dst_name)
# Report the size of both enriched training files by counting their lines.
# Context managers ensure the handles are closed instead of being left to the
# garbage collector.
with open(DATA / "enriched_13class_train.jsonl") as fh:
    n13 = sum(1 for _ in fh)
with open(DATA / "enriched_5class_train.jsonl") as fh:
    n5 = sum(1 for _ in fh)
print(f"\nEnriched 13-class: {n13} | 5-class: {n5}")