File size: 4,110 Bytes
fba140f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# -*- coding: utf-8 -*-
"""

augment_ops_coverage.py

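Create op-augmented plans from a cleaned AttackPlan JSONL.

Every input plan is copied to the output unchanged. Then, for each item in a
plan's "plan" list, candidate variants are built from that item's point spec in
the semantics file (numeric points: increase/decrease/scale; enum points:
open/close/trip). The candidates are shuffled with a seeded RNG, and up to
--max_aug_per_item of them are each written out as a new single-item plan.
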
Usage:

  %run scripts/augment_ops_coverage.py --src scripts/train_attackplan.filtered.jsonl --sem kb/specs/point_semantics.json --out scripts/train_attackplan.aug.jsonl

"""
from __future__ import annotations

import argparse
import json
import math
import random
from pathlib import Path

def load_sem(p):
    """Load and parse the point-semantics JSON file at *p*.

    Decoded as ``utf-8-sig`` (matching how ``main`` reads the source JSONL) so a
    leading UTF-8 byte-order mark is stripped instead of making ``json.loads``
    fail with "Unexpected UTF-8 BOM"; BOM-less files decode exactly as before.
    """
    return json.loads(Path(p).read_text(encoding="utf-8-sig"))

def coerce_num(x):
    """Convert *x* to a finite float, or return None if that is not possible.

    Accepts ints, floats and numeric strings (e.g. "3.5", " 12 "). Returns None
    for booleans (flags, not magnitudes), non-numeric values, ints too large for
    a float, and non-finite results (NaN / +-inf). Non-finite values cannot be
    meaningfully scaled, and ``json.dumps`` would emit them as the non-standard
    tokens ``NaN`` / ``Infinity``.
    """
    if isinstance(x, bool):  # bool is an int subclass; don't scale True/False
        return None
    try:
        num = float(x) if isinstance(x, (int, float)) else float(str(x))
    except (TypeError, ValueError, OverflowError):
        return None
    return num if math.isfinite(num) else None

def make_numeric_variants(it, spec, rng):
    """Build op-augmented copies of a plan item whose attack_value is numeric.

    Args:
        it: one plan item (a dict). It is shallow-copied per variant and never
            mutated.
        spec: merged point-semantics spec. Reads ``ops`` and the optional
            ``numeric_increase_pct`` / ``numeric_decrease_pct`` (default 0.10
            each) and ``numeric_scale_factors`` (default [0.5, 1.5]).
        rng: accepted for interface compatibility; not used here.

    Returns:
        A list of new item dicts in the order increase, decrease, then one
        "scale" item per factor. Each has its ``op`` set and ``attack_value``
        rounded to 3 decimals. Empty if attack_value is not numeric.
    """
    num = coerce_num(it.get("attack_value"))
    if num is None:
        return []
    ops = spec.get("ops", [])
    inc = spec.get("numeric_increase_pct", 0.10)
    dec = spec.get("numeric_decrease_pct", 0.10)
    scales = spec.get("numeric_scale_factors", [0.5, 1.5])

    def variant(op, value):
        # Shallow copy of the item with the op and the new (rounded) value.
        j = dict(it)
        j["op"] = op
        j["attack_value"] = round(value, 3)
        return j

    # For a negative value, multiplying by (1 + inc) moves it *down*, which
    # would label a lower value as "increase". Flip the factors so "increase"
    # always yields a larger number and "decrease" a smaller one.
    up, down = (1 + inc, 1 - dec) if num >= 0 else (1 - inc, 1 + dec)

    variants = []
    if "increase" in ops:
        variants.append(variant("increase", num * up))
    if "decrease" in ops:
        variants.append(variant("decrease", num * down))
    if "scale" in ops:
        variants.extend(variant("scale", num * s) for s in scales)
    return variants

def make_enum_variants(it, spec):
    """Build op-augmented copies of a plan item whose point is an enum.

    Args:
        it: one plan item (a dict). It is shallow-copied per variant and never
            mutated.
        spec: merged point-semantics spec. Reads ``values`` (the allowed enum
            values, compared case-insensitively) and ``ops``.

    Returns:
        A list of new item dicts in the order open, close, trip ("open" and
        "trip" target OPEN, "close" targets CLOSED). A variant is produced only
        when its op is listed in ``ops``, its target is one of ``values``, and
        the target differs from the item's current attack_value. Setting a
        value it already holds would be a no-op. Empty if ``values`` is empty.
    """
    allowed = {str(v).upper() for v in spec.get("values", [])}
    if not allowed:
        return []
    ops = spec.get("ops", [])
    cur = str(it.get("attack_value", "")).upper()

    variants = []
    for op, target in (("open", "OPEN"), ("close", "CLOSED"), ("trip", "OPEN")):
        if op in ops and target in allowed and target != cur:
            j = dict(it)
            j["op"] = op
            j["attack_value"] = target
            variants.append(j)
    return variants

def _item_variants(it, props, defaults, rng):
    """Return all candidate variants for one plan item, chosen by its point's semantic type."""
    # Per-point properties override the file-wide defaults.
    spec = dict(defaults)
    spec.update(props.get(it.get("point", ""), {}))
    t = spec.get("type")
    if t in {"number", "number_or_complex"}:
        return make_numeric_variants(it, spec, rng)
    if t == "enum":
        return make_enum_variants(it, spec)
    return []


def _augmented_plan(plan, item):
    """Wrap one variant item in a single-item plan that inherits *plan*'s header fields."""
    return {
        "version": plan.get("version", "1.1"),
        "time": plan.get("time", {"start_s": 0, "end_s": 60}),
        "mim": plan.get("mim", {"active": True, "selected": ["MIM1", "MIM2", "MIM3", "MIM4"]}),
        "plan": [item],
        "compile_hints": plan.get("compile_hints", {"scenario_id": "a"}),
    }


def main():
    """CLI entry point: write the original plans plus op-augmented single-item plans.

    Reads ``--src`` (JSONL, one plan per non-blank line) and ``--sem`` (point
    semantics with ``properties`` / ``defaults``). Every original line is kept
    verbatim. For each plan item, candidate variants are shuffled with a
    ``--seed``-ed RNG, and at most ``--max_aug_per_item`` are appended, each as
    its own plan. The result is written to ``--out``.

    Raises:
        ValueError: if a non-blank source line is not valid JSON; the message
            includes the file and line number.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--src", required=True)
    ap.add_argument("--sem", required=True)
    ap.add_argument("--out", required=True)
    ap.add_argument("--seed", type=int, default=7)
    ap.add_argument("--max_aug_per_item", type=int, default=2, help="cap how many variants per original item")
    args = ap.parse_args()
    # A negative value would slice off the *last* variant(s) instead of capping.
    if args.max_aug_per_item < 0:
        ap.error("--max_aug_per_item must be >= 0")

    rng = random.Random(args.seed)
    sem = load_sem(args.sem)
    props = sem.get("properties", {})
    defaults = sem.get("defaults", {})

    src_lines = Path(args.src).read_text(encoding="utf-8-sig").splitlines()
    out = []
    n_src = 0
    n_aug_plans = 0

    for lineno, ln in enumerate(src_lines, start=1):
        if not ln.strip():
            continue
        try:
            plan = json.loads(ln)
        except json.JSONDecodeError as err:
            # JSONDecodeError is a ValueError; re-raise with the location so bad rows are findable.
            raise ValueError(f"{args.src}:{lineno}: invalid JSON: {err}") from err
        n_src += 1
        out.append(ln)  # keep original

        # Make a simple one-item plan per augmented variant for clarity
        for it in plan.get("plan", []):
            if not isinstance(it, dict):
                continue  # malformed item: nothing to augment
            variants = _item_variants(it, props, defaults, rng)
            rng.shuffle(variants)
            for v in variants[:args.max_aug_per_item]:
                out.append(json.dumps(_augmented_plan(plan, v), ensure_ascii=False))
                n_aug_plans += 1

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Newline-terminate every record; an empty result yields an empty file, not "\n".
    out_path.write_text("".join(f"{line}\n" for line in out), encoding="utf-8")
    print(f"[done] plans_in={n_src}, aug_plans_added={n_aug_plans}, total_out={len(out)}")
    print(f"[wrote] {out_path.resolve()}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script (including IPython's `%run`), not on import.
    main()