| |
| """Unified dataset builder for Arcspan cybersecurity NER. |
| |
| Usage: |
| python scripts/build_dataset.py --output-dir data/processed --tag r7 |
| python scripts/build_dataset.py --output-dir data/processed --tag r8 --include-stucco --stucco-limit 5000 |
| |
| Outputs: |
| data/processed/{tag}_5class_train.jsonl |
| data/processed/{tag}_5class_valid.jsonl |
| data/processed/{tag}_stats.json |
| """ |
| import argparse |
| import json |
| import random |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| |
| |
| |
# Directory containing this script; every path below is resolved from it, so
# the builder works regardless of the current working directory.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Directory holding the pre-processed JSONL inputs (<script dir>/../data/processed).
DATA_DIR = _SCRIPT_DIR.parent / "data" / "processed"

# Logical source name -> input JSONL file under DATA_DIR. Key order is kept
# stable; main() warns and skips optional sources whose file is missing.
SOURCES = {
    name: DATA_DIR / filename
    for name, filename in (
        ("base", "enriched_5class_train_cleaned_deleaked.jsonl"),
        ("base_valid", "enriched_5class_valid_cleaned_trimmed.jsonl"),
        ("aptner_train", "aptner_5class_train_deleaked.jsonl"),
        ("aptner_dev", "aptner_5class_dev.jsonl"),
        ("defanged", "defanged_augmented.jsonl"),
        ("securebert2", "securebert2_5class_train_deleaked.jsonl"),
        ("stucco", "stucco_nvd_5class.jsonl"),
    )
}

# Optional post-processing script expected to sit next to this file.
ENTITY_PROPAGATION_SCRIPT = _SCRIPT_DIR / "entity_propagation.py"
|
|
|
|
| |
| |
| |
def load_jsonl(path: Path) -> list[dict]:
    """Load a JSONL file into a list of records, skipping blank lines.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) instead of the platform's locale-dependent default, so
    non-ASCII text loads identically on every OS.

    Args:
        path: JSONL file with one JSON value per non-blank line.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message is prefixed with ``<path>:<1-based line number>``.
    """
    records: list[dict] = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type, but point at the offending file/line.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records
|
|
|
|
def write_jsonl(path: Path, records: list[dict]) -> None:
    """Write *records* to *path* as JSONL, one compact JSON object per line.

    ``ensure_ascii=False`` keeps non-ASCII characters literal, so the file
    must be opened explicitly as UTF-8. Relying on the locale default
    encoding could raise ``UnicodeEncodeError`` or emit non-UTF-8 bytes on
    some platforms. Any existing file at *path* is overwritten.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)
|
|
|
|
def dedup_key(text: str) -> str:
    """Return the deduplication key for a record's text.

    The key is ``text[:80] + "||" + text``. Because the complete text is part
    of the key, two texts share a key only when they are exactly equal. The
    80-character prefix does NOT make matching fuzzy, and no hashing is
    involved. The format is kept as-is so keys remain stable for any caller
    that stores or compares them.
    """
    prefix = text[:80]
    return f"{prefix}||{text}"
|
|
|
|
def deduplicate(records: list[dict]) -> list[dict]:
    """Drop records whose text duplicates an earlier record's text.

    Matching uses ``dedup_key``, whose key embeds the full text, so this is
    exact-text deduplication. Near-duplicates that merely share their first
    80 characters are all kept. The first occurrence wins: its spans and
    metadata are retained, and later copies are discarded even if their
    annotations differ. Input order is preserved and *records* is not
    modified.

    Raises:
        KeyError: if a record has no ``"text"`` field.
    """
    seen: set[str] = set()
    unique: list[dict] = []
    for record in records:
        key = dedup_key(record["text"])
        if key in seen:
            continue
        seen.add(key)
        unique.append(record)
    return unique
|
|
|
|
def count_entities(records: list[dict]) -> tuple[int, dict[str, int]]:
    """Tally annotated entity mentions across *records*.

    For every key of a record's ``"spans"`` mapping (a missing mapping counts
    as empty), the label is the part before the first ``":"``, with
    whitespace stripped. Each element of the key's value counts as one
    mention.

    Returns:
        ``(total, per_label)``: the overall mention count and a dict mapping
        each label to its count, in first-seen order. Labels whose value is
        empty still appear, with a count of 0.
    """
    per_label: dict[str, int] = {}
    for record in records:
        for span_key, positions in record.get("spans", {}).items():
            label = span_key.partition(":")[0].strip()
            per_label[label] = per_label.get(label, 0) + len(positions)
    return sum(per_label.values()), per_label
|
|
|
|
def source_breakdown(records: list[dict]) -> dict[str, int]:
    """Return how many records come from each ``info["source"]`` value.

    Records that lack an ``"info"`` entry, or whose ``info`` has no
    ``"source"``, are counted under ``"unknown"``. Keys appear in first-seen
    order.
    """
    tally: dict[str, int] = {}
    for record in records:
        origin = record.get("info", {}).get("source", "unknown")
        tally[origin] = tally.get(origin, 0) + 1
    return tally
|
|
|
|
def train_valid_split(
    records: list[dict], valid_frac: float = 0.1, seed: int = 42
) -> tuple[list[dict], list[dict]]:
    """Shuffle a copy of *records* with a seeded RNG and cut it into two parts.

    The train part receives the first ``int(len * (1 - valid_frac))`` shuffled
    records, but never fewer than one slot, so a non-empty input always
    yields a non-empty train part. The valid part receives the rest. With the
    default ``valid_frac=0.1`` this is a 90/10 split. A given *seed* always
    produces the same split, and *records* itself is left untouched.
    """
    pool = list(records)
    random.Random(seed).shuffle(pool)

    cut = int(len(pool) * (1 - valid_frac))
    if cut < 1:
        cut = 1

    train_part = pool[:cut]
    valid_part = pool[cut:]
    return train_part, valid_part
|
|
|
|
| |
| |
| |
def _non_negative_int(value: str) -> int:
    """argparse ``type=`` callback that accepts only integers >= 0."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 0:
        raise argparse.ArgumentTypeError(f"must be >= 0, got {number}")
    return number


def build_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the dataset builder's command line.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, as before. Passing a list lets
            tests and other scripts drive the parser directly.

    Returns:
        The parsed namespace. ``stucco_limit`` is guaranteed to be >= 0, so
        it can be passed straight to ``random.sample``.

    Raises:
        SystemExit: on invalid arguments or ``--help`` (standard argparse
            behavior).
    """
    p = argparse.ArgumentParser(
        description="Build merged Arcspan cybersecurity NER dataset."
    )
    p.add_argument(
        "--output-dir",
        required=True,
        type=Path,
        help="Directory that receives the {tag}_* output files",
    )
    p.add_argument("--tag", required=True, help="Release tag, e.g. r7, r8")
    p.add_argument(
        "--base-data",
        type=Path,
        default=None,
        help="Override base training data path",
    )

    # All source/processing toggles use BooleanOptionalAction, so each one
    # also has a --no-... form. The existing --include-x spellings are
    # unchanged.
    p.add_argument(
        "--include-aptner",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Add the APTNER train split to train and its dev split to valid",
    )
    p.add_argument(
        "--include-defanged",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Add the defanged augmentation data to train",
    )
    p.add_argument(
        "--include-securebert2",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Add the SecureBERT2 training data to train",
    )
    p.add_argument(
        "--include-stucco",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Add a seeded random sample of the STUCCO data to train",
    )
    p.add_argument(
        "--stucco-limit",
        type=_non_negative_int,
        default=5000,
        help="Maximum number of STUCCO records to keep (>= 0, default: %(default)s)",
    )
    p.add_argument(
        "--apply-propagation",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Run entity_propagation.propagate_entities on both splits",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for sampling, splitting and shuffling (default: %(default)s)",
    )
    return p.parse_args(argv)
|
|
|
|
def _load_optional(
    label: str, path: Path, prefix: str = "", suffix: str = ""
) -> list[dict]:
    """Load an optional source, or print a warning and return [] if it is missing."""
    if not path.exists():
        print(f"[{label}] WARNING: {path} not found, skipping")
        return []
    print(f"[{label}] Loading {prefix}{path}{suffix}")
    return load_jsonl(path)


def _drop_valid_overlap(
    train: list[dict], valid: list[dict]
) -> tuple[list[dict], int]:
    """Remove train records whose text also occurs in *valid*; return (kept, n_removed)."""
    valid_keys = {dedup_key(r["text"]) for r in valid}
    kept = [r for r in train if dedup_key(r["text"]) not in valid_keys]
    return kept, len(train) - len(kept)


def _print_report(stats: dict) -> None:
    """Print the console summary for the stats dict assembled in main()."""
    n_train = stats["train_examples"]
    n_valid = stats["valid_examples"]
    t_by_class = stats["entity_counts_train"]
    v_by_class = stats["entity_counts_valid"]

    print(f"\n{'=' * 60}")
    print(f" {stats['tag'].upper()} DATASET STATISTICS")
    print(f"{'=' * 60}")
    print(f"\n Train: {n_train:>7} examples, {stats['train_entities']:>7} entities")
    print(f" Valid: {n_valid:>7} examples, {stats['valid_entities']:>7} entities")
    print(f" Total: {n_train + n_valid:>7} examples")

    print("\n --- Entity counts by class ---")
    print(f" {'Class':<20} {'Train':>8} {'Valid':>8} {'Total':>8}")
    for label in sorted(set(t_by_class) | set(v_by_class)):
        t = t_by_class.get(label, 0)
        v = v_by_class.get(label, 0)
        print(f" {label:<20} {t:>8} {v:>8} {t + v:>8}")

    for split in ("train", "valid"):
        print(f"\n --- Source breakdown ({split}) ---")
        breakdown = stats[f"source_breakdown_{split}"]
        for src, n in sorted(breakdown.items(), key=lambda item: -item[1]):
            print(f" {src:<30} {n:>7}")


def main() -> None:
    """Build the merged train/valid JSONL files and a stats JSON for one release tag.

    Pipeline:
      1. Load the base data plus the optional sources enabled on the command line.
      2. Deduplicate each split by exact text.
      3. Drop train records whose text also appears in the validation data.
      4. Optionally run entity propagation on both splits.
      5. If no validation data was loaded, fall back to a seeded 90/10 split.
      6. Shuffle train with the seed.
      7. Write ``{tag}_5class_train.jsonl``, ``{tag}_5class_valid.jsonl`` and
         ``{tag}_stats.json`` into ``--output-dir``.
    """
    args = build_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    all_train: list[dict] = []
    all_valid: list[dict] = []

    # --- Base data: required, so a missing file raises instead of warning. ---
    base_path = args.base_data or SOURCES["base"]
    print(f"[base] Loading {base_path}")
    all_train.extend(load_jsonl(base_path))

    base_valid_path = SOURCES["base_valid"]
    if base_valid_path.exists():
        print(f"[base] Loading validation split {base_valid_path}")
        all_valid.extend(load_jsonl(base_valid_path))

    # --- Optional sources (each warns and is skipped if its file is missing). ---
    if args.include_aptner:
        all_train.extend(
            _load_optional("aptner", SOURCES["aptner_train"], prefix="train: ")
        )
        all_valid.extend(
            _load_optional("aptner", SOURCES["aptner_dev"], prefix="dev: ")
        )

    if args.include_defanged:
        all_train.extend(_load_optional("defanged", SOURCES["defanged"]))

    if args.include_securebert2:
        all_train.extend(_load_optional("securebert2", SOURCES["securebert2"]))

    if args.include_stucco:
        limit = args.stucco_limit
        if limit < 0:
            # random.sample rejects negative sizes; fail with a clear message.
            sys.exit(f"--stucco-limit must be >= 0 (got {limit})")
        stucco = _load_optional(
            "stucco", SOURCES["stucco"], suffix=f" (limit={limit})"
        )
        if len(stucco) > limit:
            # Seeded subsample keeps the release reproducible for a given --seed.
            stucco = random.Random(args.seed).sample(stucco, limit)
        all_train.extend(stucco)

    # --- Deduplicate each split (exact text, first occurrence wins). ---
    pre_dedup = len(all_train)
    all_train = deduplicate(all_train)
    print(
        f"\n[dedup] Train: {pre_dedup} → {len(all_train)} "
        f"({pre_dedup - len(all_train)} removed)"
    )

    pre_dedup_v = len(all_valid)
    all_valid = deduplicate(all_valid)
    print(
        f"[dedup] Valid: {pre_dedup_v} → {len(all_valid)} "
        f"({pre_dedup_v - len(all_valid)} removed)"
    )

    # --- Cross-split leakage guard. ---
    # Only some inputs are pre-deleaked (the *_deleaked files). Others, such as
    # the defanged augmentation and STUCCO data, are added as-is and could
    # repeat a validation text, which would inflate validation scores.
    all_train, n_overlap = _drop_valid_overlap(all_train, all_valid)
    print(f"[dedup] Train records also in valid: {n_overlap} removed")

    # --- Optional entity propagation. ---
    if args.apply_propagation:
        if not ENTITY_PROPAGATION_SCRIPT.exists():
            print(
                f"[propagation] WARNING: {ENTITY_PROPAGATION_SCRIPT} not found, "
                "skipping"
            )
        else:
            print("[propagation] Applying entity propagation...")
            # The helper is a sibling script, not an installed package, so its
            # directory has to be made importable first.
            sys.path.insert(0, str(ENTITY_PROPAGATION_SCRIPT.parent))
            from entity_propagation import propagate_entities

            all_train = propagate_entities(all_train)
            all_valid = propagate_entities(all_valid)

    # --- Fallback split when no validation data was loaded. ---
    if not all_valid:
        print(
            f"\n[split] No pre-existing valid data — splitting train 90/10 "
            f"(seed={args.seed})"
        )
        all_train, all_valid = train_valid_split(
            all_train, valid_frac=0.1, seed=args.seed
        )

    random.Random(args.seed).shuffle(all_train)

    # --- Write outputs. ---
    train_path = args.output_dir / f"{args.tag}_5class_train.jsonl"
    valid_path = args.output_dir / f"{args.tag}_5class_valid.jsonl"
    write_jsonl(train_path, all_train)
    write_jsonl(valid_path, all_valid)

    t_total, t_by_class = count_entities(all_train)
    v_total, v_by_class = count_entities(all_valid)
    stats = {
        "tag": args.tag,
        "seed": args.seed,
        "train_examples": len(all_train),
        "valid_examples": len(all_valid),
        "train_entities": t_total,
        "valid_entities": v_total,
        "entity_counts_train": t_by_class,
        "entity_counts_valid": v_by_class,
        "source_breakdown_train": source_breakdown(all_train),
        "source_breakdown_valid": source_breakdown(all_valid),
        "train_removed_valid_overlap": n_overlap,
    }
    _print_report(stats)

    stats_path = args.output_dir / f"{args.tag}_stats.json"
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)

    print(f"\n Written: {train_path}")
    print(f" Written: {valid_path}")
    print(f" Written: {stats_path}")


if __name__ == "__main__":
    main()
|
|