File size: 2,183 Bytes
fba140f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
"""

split_attackplan_jsonl.py

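Shuffle AttackPlan JSONL lines (seeded via --seed, default 7) and split them
into datasets/train.jsonl, datasets/val.jsonl and datasets/test.jsonl
(default fractions: 70% train, 15% val, remainder test). Blank input lines
are dropped. After writing, the script prints the split sizes and how many
plan signatures appear in more than one split.
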
Usage:

  %run scripts/split_attackplan_jsonl.py --src "C:/Users/adetu/Dropbox/Ire_Research/my_code/scripts/train_attackplan.jsonl"

"""
from __future__ import annotations
import argparse, json, random, hashlib
from pathlib import Path

def plan_sig(plan: dict) -> str:
  """Return a hex MD5 fingerprint of ``plan["plan"]`` (``[]`` when absent).

  The value is serialized with ``sort_keys=True``, so dicts that differ only
  in key order produce the same fingerprint. This is a crude content
  signature for spotting duplicates across splits, not a security hash.
  """
  payload = json.dumps(plan.get("plan", []), sort_keys=True).encode("utf-8")
  hasher = hashlib.md5()
  hasher.update(payload)
  return hasher.hexdigest()

def _write_jsonl(path: Path, rows: list[str]) -> None:
  """Write *rows* (raw JSON lines) to *path*, each terminated by a newline.

  An empty *rows* yields an empty file rather than a lone newline, so
  line-by-line JSONL readers never encounter a blank record.
  """
  path.write_text("".join(row + "\n" for row in rows), encoding="utf-8")

def _plan_signatures(rows: list[str]) -> tuple[set[str], int]:
  """Return ``(signatures, n_malformed)`` for a list of raw JSON lines.

  ``signatures`` holds ``plan_sig`` of every row that parses to a dict with a
  ``"plan"`` key; other valid JSON rows are skipped. Rows that fail to parse
  are counted in ``n_malformed`` so the caller can report them.
  """
  sigs: set[str] = set()
  malformed = 0
  for row in rows:
    try:
      obj = json.loads(row)
    except json.JSONDecodeError:
      malformed += 1
      continue
    if isinstance(obj, dict) and "plan" in obj:
      sigs.add(plan_sig(obj))
  return sigs, malformed

def main(argv=None) -> None:
  """Shuffle a JSONL file and write train/val/test splits.

  Command-line options:
    --src     input JSONL path (required). Blank lines are dropped and a
              leading UTF-8 BOM is tolerated.
    --seed    shuffle seed (default 7), so splits are reproducible.
    --train   fraction of lines written to train.jsonl (default 0.70).
    --val     fraction of lines written to val.jsonl (default 0.15); the
              remaining lines go to test.jsonl.
    --outdir  output directory (default "datasets", relative to the current
              working directory; created if missing).

  Args:
    argv: argument list to parse instead of ``sys.argv[1:]`` (e.g. when
      called from a notebook or a test). ``None`` keeps command-line behavior.

  Raises:
    SystemExit: via ``argparse`` when arguments are missing or the fractions
      are outside [0, 1] or sum to more than 1.

  Prints the actual split sizes and the number of distinct plan signatures
  that occur in more than one split.
  """
  ap = argparse.ArgumentParser()
  ap.add_argument("--src", type=str, required=True)
  ap.add_argument("--seed", type=int, default=7)
  ap.add_argument("--train", type=float, default=0.70)
  ap.add_argument("--val", type=float, default=0.15)
  ap.add_argument("--outdir", type=str, default="datasets")
  args = ap.parse_args(argv)

  # Reject fractions that would silently truncate val/test and make the
  # reported sizes disagree with what is written.
  if not (0.0 <= args.train <= 1.0 and 0.0 <= args.val <= 1.0):
    ap.error("--train and --val must each be within [0, 1]")
  if args.train + args.val > 1.0 + 1e-9:  # small tolerance for float rounding
    ap.error("--train + --val must not exceed 1")

  src = Path(args.src)
  # utf-8-sig strips a leading BOM if present.
  lines = [ln for ln in src.read_text(encoding="utf-8-sig").splitlines() if ln.strip()]
  random.Random(args.seed).shuffle(lines)

  n = len(lines)
  ntr = int(args.train * n)
  nv = int(args.val * n)
  test_start = ntr + nv
  splits = {
    "train": lines[:ntr],
    "val": lines[ntr:test_start],
    "test": lines[test_start:],
  }

  outdir = Path(args.outdir)
  outdir.mkdir(parents=True, exist_ok=True)
  for name, rows in splits.items():
    _write_jsonl(outdir / f"{name}.jsonl", rows)

  # Duplicate signature report: a plan counts once if it appears in any two
  # (or all three) splits.
  sigs = {}
  malformed = 0
  for name, rows in splits.items():
    sigs[name], bad = _plan_signatures(rows)
    malformed += bad
  tr, va, te = sigs["train"], sigs["val"], sigs["test"]
  shared = (tr & va) | (tr & te) | (va & te)

  if malformed:
    print(f"[warn] {malformed} line(s) are not valid JSON; they were written to the splits "
          f"but excluded from the duplicate check")
  sizes = "/".join(str(len(splits[k])) for k in ("train", "val", "test"))
  print(f"[done] train/val/test = {sizes}. duplicate_plans_across_splits={len(shared)}")

# Entry point: run main() only when this file is executed as a script
# (e.g. `python ...` or IPython `%run`), not when it is imported.
if __name__ == "__main__":
  main()