arcspan / scripts /build_dataset.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
"""Unified dataset builder for Arcspan cybersecurity NER.

Usage:
    python scripts/build_dataset.py --output-dir data/processed --tag r7
    python scripts/build_dataset.py --output-dir data/processed --tag r8 --include-stucco --stucco-limit 5000

Outputs (written under --output-dir):
    <output-dir>/{tag}_5class_train.jsonl
    <output-dir>/{tag}_5class_valid.jsonl
    <output-dir>/{tag}_stats.json
"""
import argparse
import json
import random
import sys
from collections import defaultdict
from pathlib import Path
# ---------------------------------------------------------------------------
# Default source paths (relative to repo root)
# ---------------------------------------------------------------------------
# Directory that contains this script; resolved once and reused below.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Processed data lives in data/processed under the parent of the script directory.
DATA_DIR = _SCRIPT_DIR.parent / "data" / "processed"

# Logical source name -> JSONL path inside DATA_DIR (insertion order is kept).
SOURCES = {
    name: DATA_DIR / filename
    for name, filename in (
        ("base", "enriched_5class_train_cleaned_deleaked.jsonl"),
        ("base_valid", "enriched_5class_valid_cleaned_trimmed.jsonl"),
        ("aptner_train", "aptner_5class_train_deleaked.jsonl"),
        ("aptner_dev", "aptner_5class_dev.jsonl"),
        ("defanged", "defanged_augmented.jsonl"),
        ("securebert2", "securebert2_5class_train_deleaked.jsonl"),
        ("stucco", "stucco_nvd_5class.jsonl"),
    )
}

# Optional post-processing module, expected next to this script.
ENTITY_PROPAGATION_SCRIPT = _SCRIPT_DIR / "entity_propagation.py"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_jsonl(path: Path) -> list[dict]:
    """Load a JSONL file into a list of parsed records.

    Lines that are empty or contain only whitespace are skipped; every other
    line is parsed with ``json.loads``.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII text loads the same way on every machine.

    Args:
        path: Location of the JSONL file.

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(line) for line in stripped_lines if line]
def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write *records* to *path* as JSONL, one JSON object per line.

    Records are serialised with ``ensure_ascii=False``, so non-ASCII
    characters are written verbatim. The file is therefore opened as UTF-8
    explicitly; relying on the locale default could raise
    ``UnicodeEncodeError`` or produce a file that is not UTF-8.

    Args:
        path: Destination file; an existing file is overwritten.
        records: JSON-serialisable dicts to write, in order.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )
def dedup_key(text: str) -> str:
    """Build the key used to detect duplicate records.

    The key is the first 80 characters of *text*, the separator ``"||"``,
    and then the complete text. Because the whole text is part of the key,
    two keys are equal only when the texts are identical: the prefix does
    not make the comparison fuzzy, and no hashing is involved.
    """
    prefix = text[:80]
    return f"{prefix}||{text}"
def deduplicate(records: list[dict]) -> list[dict]:
    """Return *records* without repeats, keeping the first occurrence of each.

    Two records count as duplicates when ``dedup_key`` gives the same key for
    their ``"text"`` values. Order of the surviving records is preserved.
    Every record must have a ``"text"`` entry; a missing one raises
    ``KeyError``.
    """
    seen_keys: set[str] = set()
    unique = []
    for record in records:
        key = dedup_key(record["text"])
        if key in seen_keys:
            continue
        seen_keys.add(key)
        unique.append(record)
    return unique
def count_entities(records: list[dict]) -> tuple[int, dict[str, int]]:
    """Tally entity annotations across *records*.

    Each record's optional ``"spans"`` mapping is walked; the label of an
    entry is the part of its key before the first ``":"`` (whitespace
    stripped), and the entry contributes ``len(positions)`` entities.
    Records without ``"spans"`` contribute nothing.

    Returns:
        ``(total_entities, {label: count})``.
    """
    per_label: dict[str, int] = defaultdict(int)
    for record in records:
        spans = record.get("spans", {})
        for span_key, positions in spans.items():
            label, _, _ = span_key.partition(":")
            per_label[label.strip()] += len(positions)
    # Every position was added to exactly one label, so the grand total is
    # simply the sum of the per-label counts.
    return sum(per_label.values()), dict(per_label)
def source_breakdown(records: list[dict]) -> dict[str, int]:
    """Count how many records come from each ``info.source`` value.

    A record with no ``"info"`` mapping, or whose ``"info"`` has no
    ``"source"`` entry, is counted under ``"unknown"``.
    """
    tally: dict[str, int] = defaultdict(int)
    for record in records:
        info = record.get("info", {})
        tally[info.get("source", "unknown")] += 1
    return dict(tally)
def train_valid_split(
    records: list[dict], valid_frac: float = 0.1, seed: int = 42
) -> tuple[list[dict], list[dict]]:
    """Shuffle a copy of *records* and cut it into (train, valid).

    The shuffle uses its own ``random.Random(seed)``, so the result is
    reproducible and the caller's list is left untouched. Roughly
    ``valid_frac`` of the records go to the validation part (10% with the
    default); the cut point is never below 1, so the train part receives the
    first record whenever there is one.
    """
    pool = [*records]
    random.Random(seed).shuffle(pool)

    cut = int(len(pool) * (1 - valid_frac))
    if cut < 1:
        cut = 1
    return pool[:cut], pool[cut:]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _non_negative_int(value: str) -> int:
    """argparse ``type`` callable: parse *value* as an int, rejecting negatives."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {number}")
    return number


def build_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the dataset builder.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        Namespace with ``output_dir``, ``tag``, ``base_data``,
        ``include_aptner``, ``include_defanged``, ``include_securebert2``,
        ``include_stucco``, ``stucco_limit``, ``apply_propagation`` and
        ``seed``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments.
    """
    parser = argparse.ArgumentParser(
        description="Build merged Arcspan cybersecurity NER dataset."
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        type=Path,
        help="Directory that receives the train/valid JSONL and stats files",
    )
    parser.add_argument("--tag", required=True, help="Release tag, e.g. r7, r8")
    parser.add_argument(
        "--base-data",
        type=Path,
        default=None,
        help="Override base training data path",
    )
    # Source toggles (aptner & defanged on by default; --no-include-* disables)
    parser.add_argument(
        "--include-aptner",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Merge the APTNER train/dev files",
    )
    parser.add_argument(
        "--include-defanged",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Merge the defanged augmentation file into train",
    )
    parser.add_argument(
        "--include-securebert2",
        action="store_true",
        default=False,
        help="Merge the SecureBERT2 file into train",
    )
    parser.add_argument(
        "--include-stucco",
        action="store_true",
        default=False,
        help="Merge a capped sample of the Stucco NVD file into train",
    )
    # The limit is later used as a sample size, which must not be negative,
    # so bad values are rejected here with a normal argparse error.
    parser.add_argument(
        "--stucco-limit",
        type=_non_negative_int,
        default=5000,
        help="Maximum number of Stucco records to keep",
    )
    parser.add_argument(
        "--apply-propagation",
        action="store_true",
        default=False,
        help="Run entity propagation on both splits after deduplication",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for sampling, splitting and shuffling",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Merge the selected sources into train/valid JSONL files plus a stats report.

    Pipeline, in order: load the base data and every enabled optional source,
    de-duplicate train and valid (each on its own), optionally run entity
    propagation, fall back to a seeded 90/10 split when no validation data
    was loaded, shuffle train, write both splits, then print and save the
    statistics.

    Optional sources whose file is missing are skipped with a warning. The
    base training file is required: if it is missing, ``load_jsonl`` raises.
    """
    args = build_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    all_train: list[dict] = []
    all_valid: list[dict] = []

    # --- 1. Base data (already has a separate valid split) -----------------
    base_path = args.base_data or SOURCES["base"]
    print(f"[base] Loading {base_path}")
    all_train.extend(load_jsonl(base_path))
    base_valid_path = SOURCES["base_valid"]
    if base_valid_path.exists():
        print(f"[base] Loading validation split {base_valid_path}")
        all_valid.extend(load_jsonl(base_valid_path))

    # --- 2. APTNER (has its own train/dev split) ---------------------------
    if args.include_aptner:
        for split, key in [("train", "aptner_train"), ("dev", "aptner_dev")]:
            p = SOURCES[key]
            if not p.exists():
                print(f"[aptner] WARNING: {p} not found, skipping")
                continue
            print(f"[aptner] Loading {split}: {p}")
            data = load_jsonl(p)
            if split == "train":
                all_train.extend(data)
            else:
                all_valid.extend(data)

    # --- 3. Defanged augmentation ------------------------------------------
    if args.include_defanged:
        p = SOURCES["defanged"]
        if not p.exists():
            print(f"[defanged] WARNING: {p} not found, skipping")
        else:
            print(f"[defanged] Loading {p}")
            all_train.extend(load_jsonl(p))

    # --- 4. SecureBERT2 (no separate valid split — merged into train) ------
    if args.include_securebert2:
        p = SOURCES["securebert2"]
        if not p.exists():
            print(f"[securebert2] WARNING: {p} not found, skipping")
        else:
            print(f"[securebert2] Loading {p}")
            all_train.extend(load_jsonl(p))

    # --- 5. Stucco NVD (capped) -------------------------------------------
    if args.include_stucco:
        p = SOURCES["stucco"]
        if not p.exists():
            print(f"[stucco] WARNING: {p} not found, skipping")
        else:
            print(f"[stucco] Loading {p} (limit={args.stucco_limit})")
            data = load_jsonl(p)
            # Seeded sampler so the same subset is drawn on every run.
            rng = random.Random(args.seed)
            if len(data) > args.stucco_limit:
                data = rng.sample(data, args.stucco_limit)
            all_train.extend(data)

    # --- 6. Deduplication --------------------------------------------------
    pre_dedup = len(all_train)
    all_train = deduplicate(all_train)
    print(
        f"\n[dedup] Train: {pre_dedup} -> {len(all_train)} "
        f"({pre_dedup - len(all_train)} removed)"
    )
    pre_dedup_v = len(all_valid)
    all_valid = deduplicate(all_valid)
    print(
        f"[dedup] Valid: {pre_dedup_v} -> {len(all_valid)} "
        f"({pre_dedup_v - len(all_valid)} removed)"
    )
    # Train and valid are de-duplicated independently, so a text can still be
    # present in both. Report it (nothing is removed) so leakage is visible.
    valid_keys = {dedup_key(r["text"]) for r in all_valid}
    overlap = sum(1 for r in all_train if dedup_key(r["text"]) in valid_keys)
    if overlap:
        print(
            f"[dedup] WARNING: {overlap} train examples also appear in valid"
        )

    # --- 7. Entity propagation (optional post-processing) ------------------
    if args.apply_propagation:
        if not ENTITY_PROPAGATION_SCRIPT.exists():
            print(
                f"[propagation] WARNING: {ENTITY_PROPAGATION_SCRIPT} not found, "
                "skipping"
            )
        else:
            print("[propagation] Applying entity propagation...")
            # Make the sibling module importable, then import it lazily so the
            # script still runs without it when propagation is not requested.
            sys.path.insert(0, str(ENTITY_PROPAGATION_SCRIPT.parent))
            from entity_propagation import propagate_entities  # type: ignore

            all_train = propagate_entities(all_train)
            all_valid = propagate_entities(all_valid)

    # --- 8. If no pre-existing valid split, create one from train ----------
    if not all_valid:
        print(
            f"\n[split] No pre-existing valid data — splitting train 90/10 "
            f"(seed={args.seed})"
        )
        all_train, all_valid = train_valid_split(
            all_train, valid_frac=0.1, seed=args.seed
        )

    # --- 9. Shuffle train --------------------------------------------------
    rng = random.Random(args.seed)
    rng.shuffle(all_train)

    # --- 10. Write outputs -------------------------------------------------
    train_path = args.output_dir / f"{args.tag}_5class_train.jsonl"
    valid_path = args.output_dir / f"{args.tag}_5class_valid.jsonl"
    write_jsonl(train_path, all_train)
    write_jsonl(valid_path, all_valid)

    # --- 11. Compute & print stats -----------------------------------------
    t_total, t_by_class = count_entities(all_train)
    v_total, v_by_class = count_entities(all_valid)
    t_sources = source_breakdown(all_train)
    v_sources = source_breakdown(all_valid)
    all_labels = sorted(set(t_by_class) | set(v_by_class))
    print(f"\n{'='*60}")
    print(f" {args.tag.upper()} DATASET STATISTICS")
    print(f"{'='*60}")
    print(f"\n Train: {len(all_train):>7} examples, {t_total:>7} entities")
    print(f" Valid: {len(all_valid):>7} examples, {v_total:>7} entities")
    print(f" Total: {len(all_train)+len(all_valid):>7} examples")
    print("\n --- Entity counts by class ---")
    print(f" {'Class':<20} {'Train':>8} {'Valid':>8} {'Total':>8}")
    for label in all_labels:
        t = t_by_class.get(label, 0)
        v = v_by_class.get(label, 0)
        print(f" {label:<20} {t:>8} {v:>8} {t+v:>8}")
    # Largest sources first.
    print("\n --- Source breakdown (train) ---")
    for src, n in sorted(t_sources.items(), key=lambda x: -x[1]):
        print(f" {src:<30} {n:>7}")
    print("\n --- Source breakdown (valid) ---")
    for src, n in sorted(v_sources.items(), key=lambda x: -x[1]):
        print(f" {src:<30} {n:>7}")

    # --- 12. Save stats JSON -----------------------------------------------
    stats = {
        "tag": args.tag,
        "seed": args.seed,
        "train_examples": len(all_train),
        "valid_examples": len(all_valid),
        "train_entities": t_total,
        "valid_entities": v_total,
        "entity_counts_train": t_by_class,
        "entity_counts_valid": v_by_class,
        "source_breakdown_train": t_sources,
        "source_breakdown_valid": v_sources,
    }
    stats_path = args.output_dir / f"{args.tag}_stats.json"
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    print(f"\n Written: {train_path}")
    print(f" Written: {valid_path}")
    print(f" Written: {stats_path}")


if __name__ == "__main__":
    main()