#!/usr/bin/env python3
"""Unified dataset builder for Arcspan cybersecurity NER.

Usage:
    python scripts/build_dataset.py --output-dir data/processed --tag r7
    python scripts/build_dataset.py --output-dir data/processed --tag r8 --include-stucco --stucco-limit 5000

Outputs:
    data/processed/{tag}_5class_train.jsonl
    data/processed/{tag}_5class_valid.jsonl
    data/processed/{tag}_stats.json
"""

import argparse
import json
import random
import sys
from collections import defaultdict
from pathlib import Path

# ---------------------------------------------------------------------------
# Default source paths (relative to repo root)
# ---------------------------------------------------------------------------
DATA_DIR = Path(__file__).resolve().parent.parent / "data" / "processed"

SOURCES = {
    "base": DATA_DIR / "enriched_5class_train_cleaned_deleaked.jsonl",
    "base_valid": DATA_DIR / "enriched_5class_valid_cleaned_trimmed.jsonl",
    "aptner_train": DATA_DIR / "aptner_5class_train_deleaked.jsonl",
    "aptner_dev": DATA_DIR / "aptner_5class_dev.jsonl",
    "defanged": DATA_DIR / "defanged_augmented.jsonl",
    "securebert2": DATA_DIR / "securebert2_5class_train_deleaked.jsonl",
    "stucco": DATA_DIR / "stucco_nvd_5class.jsonl",
}

ENTITY_PROPAGATION_SCRIPT = (
    Path(__file__).resolve().parent / "entity_propagation.py"
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_jsonl(path: Path) -> list[dict]:
    """Load a JSONL file as a list of records, skipping blank lines.

    The file is always decoded as UTF-8 so the result does not depend on the
    platform's default encoding.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message is prefixed with the file path and 1-based line number.
    """
    records = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Same exception type as before, but say where the bad line is.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {e.msg}", e.doc, e.pos
                ) from e
    return records


def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write *records* to *path*, one JSON object per line, as UTF-8.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), which
    is why the file encoding is pinned to UTF-8 instead of the locale default.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(r, ensure_ascii=False) + "\n" for r in records
        )


def dedup_key(text: str) -> str:
    """Return the deduplication key for *text*.

    The key is the first 80 characters, a ``"||"`` separator, and then the
    full text. Because the full text is part of the key, two texts share a
    key only when they are identical: this is an exact match, not a fuzzy one.
    """
    return text[:80] + "||" + text


def deduplicate(records: list[dict]) -> list[dict]:
    """Drop records whose ``"text"`` was already seen, keeping the first.

    Matching is exact (see ``dedup_key``) and the input order is preserved.
    """
    seen: set[str] = set()
    out = []
    for r in records:
        key = dedup_key(r["text"])
        if key not in seen:
            seen.add(key)
            out.append(r)
    return out


def count_entities(records: list[dict]) -> tuple[int, dict[str, int]]:
    """Return ``(total_entities, {label: count})`` over all records.

    Each record's optional ``"spans"`` mapping is read; the label is the part
    of each key before the first ``":"`` (stripped), and the number of
    entities is the length of the associated value.
    """
    counts: dict[str, int] = defaultdict(int)
    total = 0
    for r in records:
        for key, positions in r.get("spans", {}).items():
            label = key.split(":")[0].strip()
            n = len(positions)
            counts[label] += n
            total += n
    return total, dict(counts)


def source_breakdown(records: list[dict]) -> dict[str, int]:
    """Count records by ``info.source``; missing values count as "unknown"."""
    counts: dict[str, int] = defaultdict(int)
    for r in records:
        src = r.get("info", {}).get("source", "unknown")
        counts[src] += 1
    return dict(counts)


def train_valid_split(
    records: list[dict], valid_frac: float = 0.1, seed: int = 42
) -> tuple[list[dict], list[dict]]:
    """Randomly split *records* into ``(train, valid)`` deterministically.

    A shuffled copy is split so that roughly ``valid_frac`` of the records go
    to the validation part (90/10 by default). The input list is not modified.
    At least one record is always assigned to train, so very small inputs may
    produce an empty validation part.
    """
    rng = random.Random(seed)
    shuffled = list(records)
    rng.shuffle(shuffled)
    split_idx = max(1, int(len(shuffled) * (1 - valid_frac)))
    return shuffled[:split_idx], shuffled[split_idx:]


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def build_args() -> argparse.Namespace:
    """Parse the command line and return the resulting namespace.

    Exits with a usage error if ``--stucco-limit`` is negative.
    """
    p = argparse.ArgumentParser(
        description="Build merged Arcspan cybersecurity NER dataset."
    )
    p.add_argument("--output-dir", required=True, type=Path)
    p.add_argument("--tag", required=True, help="Release tag, e.g. r7, r8")
    p.add_argument(
        "--base-data",
        type=Path,
        default=None,
        help="Override base training data path",
    )
    # Source toggles (aptner & defanged on by default)
    p.add_argument(
        "--include-aptner",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    p.add_argument(
        "--include-defanged",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    p.add_argument("--include-securebert2", action="store_true", default=False)
    p.add_argument("--include-stucco", action="store_true", default=False)
    p.add_argument("--stucco-limit", type=int, default=5000)
    p.add_argument("--apply-propagation", action="store_true", default=False)
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()
    if args.stucco_limit < 0:
        # random.sample() rejects negative sizes; fail early with a clear
        # message instead of a traceback after all sources are loaded.
        p.error("--stucco-limit must be non-negative")
    return args


def _source_available(label: str, path: Path) -> bool:
    """Return True if *path* exists; otherwise print a skip warning."""
    if path.exists():
        return True
    print(f"[{label}] WARNING: {path} not found, skipping")
    return False


def _collect_sources(
    args: argparse.Namespace,
) -> tuple[list[dict], list[dict]]:
    """Load every enabled source and return ``(train, valid)`` record lists."""
    all_train: list[dict] = []
    all_valid: list[dict] = []

    # --- 1. Base data (already has a separate valid split) -----------------
    base_path = args.base_data or SOURCES["base"]
    print(f"[base] Loading {base_path}")
    all_train.extend(load_jsonl(base_path))

    base_valid_path = SOURCES["base_valid"]
    if base_valid_path.exists():
        print(f"[base] Loading validation split {base_valid_path}")
        all_valid.extend(load_jsonl(base_valid_path))

    # --- 2. APTNER (has its own train/dev split) ---------------------------
    if args.include_aptner:
        for split, key, target in (
            ("train", "aptner_train", all_train),
            ("dev", "aptner_dev", all_valid),
        ):
            path = SOURCES[key]
            if _source_available("aptner", path):
                print(f"[aptner] Loading {split}: {path}")
                target.extend(load_jsonl(path))

    # --- 3. Defanged augmentation ------------------------------------------
    if args.include_defanged:
        path = SOURCES["defanged"]
        if _source_available("defanged", path):
            print(f"[defanged] Loading {path}")
            all_train.extend(load_jsonl(path))

    # --- 4. SecureBERT2 (no separate valid split — merged into train) ------
    if args.include_securebert2:
        path = SOURCES["securebert2"]
        if _source_available("securebert2", path):
            print(f"[securebert2] Loading {path}")
            all_train.extend(load_jsonl(path))

    # --- 5. Stucco NVD (capped) --------------------------------------------
    if args.include_stucco:
        path = SOURCES["stucco"]
        if _source_available("stucco", path):
            print(f"[stucco] Loading {path} (limit={args.stucco_limit})")
            data = load_jsonl(path)
            rng = random.Random(args.seed)
            if len(data) > args.stucco_limit:
                data = rng.sample(data, args.stucco_limit)
            all_train.extend(data)

    return all_train, all_valid


def _print_stats(
    tag: str,
    n_train: int,
    n_valid: int,
    t_total: int,
    v_total: int,
    t_by_class: dict[str, int],
    v_by_class: dict[str, int],
    t_sources: dict[str, int],
    v_sources: dict[str, int],
) -> None:
    """Print the human-readable dataset statistics report."""
    all_labels = sorted(t_by_class.keys() | v_by_class.keys())

    print("\n" + "=" * 60)
    print(f" {tag.upper()} DATASET STATISTICS")
    print("=" * 60)
    print(f"\n Train: {n_train:>7} examples, {t_total:>7} entities")
    print(f" Valid: {n_valid:>7} examples, {v_total:>7} entities")
    print(f" Total: {n_train + n_valid:>7} examples")

    print("\n --- Entity counts by class ---")
    print(f" {'Class':<20} {'Train':>8} {'Valid':>8} {'Total':>8}")
    for label in all_labels:
        t = t_by_class.get(label, 0)
        v = v_by_class.get(label, 0)
        print(f" {label:<20} {t:>8} {v:>8} {t+v:>8}")

    print("\n --- Source breakdown (train) ---")
    for src, n in sorted(t_sources.items(), key=lambda x: -x[1]):
        print(f" {src:<30} {n:>7}")

    print("\n --- Source breakdown (valid) ---")
    for src, n in sorted(v_sources.items(), key=lambda x: -x[1]):
        print(f" {src:<30} {n:>7}")


def main() -> None:
    """Build the merged dataset and write train/valid JSONL plus stats JSON."""
    args = build_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # --- 1-5. Load all enabled sources -------------------------------------
    all_train, all_valid = _collect_sources(args)

    # --- 6. Deduplication (train and valid are deduplicated separately) ----
    pre_dedup = len(all_train)
    all_train = deduplicate(all_train)
    print(
        f"\n[dedup] Train: {pre_dedup} → {len(all_train)} "
        f"({pre_dedup - len(all_train)} removed)"
    )

    pre_dedup_v = len(all_valid)
    all_valid = deduplicate(all_valid)
    print(
        f"[dedup] Valid: {pre_dedup_v} → {len(all_valid)} "
        f"({pre_dedup_v - len(all_valid)} removed)"
    )

    # --- 7. Entity propagation (optional post-processing) ------------------
    if args.apply_propagation:
        if not ENTITY_PROPAGATION_SCRIPT.exists():
            print(
                f"[propagation] WARNING: {ENTITY_PROPAGATION_SCRIPT} not found, "
                "skipping"
            )
        else:
            print("[propagation] Applying entity propagation...")
            # The script sits next to this file; make it importable.
            # NOTE(review): propagate_entities is assumed to return the
            # updated list of records — confirm against that module.
            sys.path.insert(0, str(ENTITY_PROPAGATION_SCRIPT.parent))
            from entity_propagation import propagate_entities  # type: ignore

            all_train = propagate_entities(all_train)
            all_valid = propagate_entities(all_valid)

    # --- 8. If no pre-existing valid split, create one from train ----------
    if not all_valid:
        print(
            f"\n[split] No pre-existing valid data — splitting train 90/10 "
            f"(seed={args.seed})"
        )
        all_train, all_valid = train_valid_split(
            all_train, valid_frac=0.1, seed=args.seed
        )

    # --- 9. Shuffle train --------------------------------------------------
    rng = random.Random(args.seed)
    rng.shuffle(all_train)

    # --- 10. Write outputs -------------------------------------------------
    train_path = args.output_dir / f"{args.tag}_5class_train.jsonl"
    valid_path = args.output_dir / f"{args.tag}_5class_valid.jsonl"
    write_jsonl(train_path, all_train)
    write_jsonl(valid_path, all_valid)

    # --- 11. Compute & print stats -----------------------------------------
    t_total, t_by_class = count_entities(all_train)
    v_total, v_by_class = count_entities(all_valid)
    t_sources = source_breakdown(all_train)
    v_sources = source_breakdown(all_valid)

    _print_stats(
        args.tag,
        len(all_train),
        len(all_valid),
        t_total,
        v_total,
        t_by_class,
        v_by_class,
        t_sources,
        v_sources,
    )

    # --- 12. Save stats JSON -----------------------------------------------
    stats = {
        "tag": args.tag,
        "seed": args.seed,
        "train_examples": len(all_train),
        "valid_examples": len(all_valid),
        "train_entities": t_total,
        "valid_entities": v_total,
        "entity_counts_train": t_by_class,
        "entity_counts_valid": v_by_class,
        "source_breakdown_train": t_sources,
        "source_breakdown_valid": v_sources,
    }
    stats_path = args.output_dir / f"{args.tag}_stats.json"
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"\n Written: {train_path}")
    print(f" Written: {valid_path}")
    print(f" Written: {stats_path}")


if __name__ == "__main__":
    main()