File size: 6,915 Bytes
5e21013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python3
"""scripts/load_eval_cases.py — load Bee Security Eval Harness YAMLs into Postgres.

Reads every `eval/bee_security_harness/cases/*.yaml`, validates the
schema, and UPSERTs each case into `public.eval_cases` keyed on
`case_key`. Idempotent: re-run after editing YAMLs and only changed
rows are touched.

Usage:
    python3 scripts/load_eval_cases.py
    python3 scripts/load_eval_cases.py --dry-run

Reads POSTGRES_URL_NON_POOLING from `.env`.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

try:
    from dotenv import load_dotenv

    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
except ImportError:
    pass

import psycopg
import yaml

# Directory holding the harness case files: eval/bee_security_harness/cases/*.yaml,
# resolved relative to this script's parent repo root.
CASES_DIR = Path(__file__).resolve().parent.parent / "eval" / "bee_security_harness" / "cases"

# Closed set of valid top-level `category` values for a YAML case file.
CATEGORIES = {
    "insecure_code_generation",
    "prompt_injection",
    "agent_tool_abuse",
    "tenant_isolation",
    "authz_authn_failures",
    "cloud_iam_misconfigurations",
    "dependency_vulnerabilities",
    "secret_leakage",
    "unsafe_cyber_responses",
    "hallucinated_security_claims",
}

# Valid `severity` values for an individual case.
SEVERITIES = {"blocking", "warning", "info"}

# Recognized rule kinds inside a case's `grader` mapping.
GRADER_RULE_KINDS = {
    "must_contain",
    "must_not_contain",
    "must_refuse",
    "must_redirect",
    "min_length",
    "max_length",
}


def validate_case(case: dict, file: str, idx: int) -> list[str]:
    """Validate one eval case and return its errors (empty list if valid).

    Args:
        case: One entry from a YAML file's `cases` list.
        file: Source file name, used only in error messages.
        idx: Zero-based position of the case within the file.

    Returns:
        Human-readable error strings; an empty list means the case is valid.
    """
    errors: list[str] = []
    where = f"{file}[{idx}]"

    # Bug fix: a non-mapping case (e.g. a bare string in the YAML list)
    # used to crash on `case.get(...)`; report it as a validation error.
    if not isinstance(case, dict):
        return [f"{where}: case must be a mapping, got {type(case).__name__}"]

    for required in ("case_key", "severity", "prompt", "grader"):
        if required not in case:
            errors.append(f"{where}: missing required field `{required}`")

    sev = case.get("severity")
    if sev and sev not in SEVERITIES:
        errors.append(f"{where}: severity `{sev}` not in {sorted(SEVERITIES)}")

    prompt = case.get("prompt")
    if prompt is not None and not isinstance(prompt, dict):
        errors.append(f"{where}: prompt must be a dict (system + user)")
    elif isinstance(prompt, dict) and "user" not in prompt:
        errors.append(f"{where}: prompt missing `user` field")

    grader = case.get("grader") or {}
    # Bug fix: a non-mapping grader (e.g. a YAML list) used to raise
    # AttributeError on `.keys()`; report it instead of crashing.
    if not isinstance(grader, dict):
        errors.append(f"{where}: grader must be a mapping of rule kinds")
        return errors
    for rule_kind in grader:
        if rule_kind not in GRADER_RULE_KINDS:
            errors.append(
                f"{where}: grader rule `{rule_kind}` not recognised. "
                f"Valid: {sorted(GRADER_RULE_KINDS)}"
            )
    if "must_contain" in grader and not isinstance(grader["must_contain"], list):
        errors.append(f"{where}: grader.must_contain must be a list of regex strings")
    if "must_not_contain" in grader and not isinstance(grader["must_not_contain"], list):
        errors.append(f"{where}: grader.must_not_contain must be a list of regex strings")

    return errors


def serialize_prompt(prompt: dict) -> str:
    """Collapse a {system, user} prompt mapping into the single text the runner consumes."""
    system_text = (prompt.get("system") or "").strip()
    user_text = (prompt.get("user") or "").strip()
    # No system prompt: the user text alone is the whole prompt.
    if not system_text:
        return user_text
    return f"[SYSTEM]\n{system_text}\n\n[USER]\n{user_text}"


def main() -> int:
    """Load, validate, and UPSERT every eval-case YAML into public.eval_cases.

    Returns:
        Process exit code: 0 on success, 1 on any environment or
        validation error (all validation errors are reported together).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    if not CASES_DIR.exists():
        print(f"ERROR: cases dir not found: {CASES_DIR}", file=sys.stderr)
        return 1

    pg_url = (os.environ.get("POSTGRES_URL_NON_POOLING") or "").strip()
    # The URL is only needed when we actually write; --dry-run works without it.
    if not pg_url and not args.dry_run:
        print("ERROR: POSTGRES_URL_NON_POOLING not set", file=sys.stderr)
        return 1

    yaml_files = sorted(CASES_DIR.glob("*.yaml"))
    if not yaml_files:
        print(f"ERROR: no YAML files in {CASES_DIR}", file=sys.stderr)
        return 1

    all_cases: list[dict] = []
    all_errors: list[str] = []
    expected_categories: set[str] = set()
    seen_keys: dict[str, str] = {}  # case_key -> file that first defined it

    for yf in yaml_files:
        # Bug fix: a malformed YAML file used to abort the whole run with an
        # uncaught traceback; collect it as a validation error and keep going.
        try:
            with open(yf, encoding="utf-8") as f:
                doc = yaml.safe_load(f)
        except yaml.YAMLError as e:
            all_errors.append(f"{yf.name}: YAML parse error: {e}")
            continue
        if not isinstance(doc, dict) or "category" not in doc or "cases" not in doc:
            all_errors.append(f"{yf.name}: top-level must have `category` + `cases`")
            continue
        category = doc["category"]
        if category not in CATEGORIES:
            all_errors.append(
                f"{yf.name}: category `{category}` not in {sorted(CATEGORIES)}"
            )
            continue
        expected_categories.add(category)
        cases = doc.get("cases") or []
        for idx, case in enumerate(cases):
            errs = validate_case(case, yf.name, idx)
            all_errors.extend(errs)
            if errs:
                continue
            key = case["case_key"]
            # Bug fix: duplicate case_keys used to silently last-write-win in
            # the UPSERT; surface them as validation errors instead.
            if key in seen_keys:
                all_errors.append(
                    f"{yf.name}[{idx}]: duplicate case_key `{key}` "
                    f"(first defined in {seen_keys[key]})"
                )
                continue
            seen_keys[key] = yf.name
            case["__category"] = category
            case["__file"] = yf.name
            all_cases.append(case)

    if all_errors:
        print("Validation errors:")
        for e in all_errors:
            print(f"  - {e}")
        return 1

    print(f"Validated {len(all_cases)} cases across {len(expected_categories)} categories.")
    missing = CATEGORIES - expected_categories
    if missing:
        print(f"  ! categories with NO cases yet: {sorted(missing)}")

    if args.dry_run:
        for c in all_cases[:3]:
            print(f"  example: {c['case_key']}  severity={c['severity']}  "
                  f"category={c['__category']}")
        print("dry-run; not writing to DB")
        return 0

    upserted = 0
    # Single transaction: either every case lands or none does.
    with psycopg.connect(pg_url, autocommit=False) as conn:
        with conn.cursor() as cur:
            for c in all_cases:
                tags = c.get("tags") or []
                cur.execute(
                    """
                    INSERT INTO public.eval_cases
                      (case_key, category, severity, prompt_text, grader,
                       rationale, tags, enabled, updated_at)
                    VALUES (%s, %s, %s, %s, %s::jsonb, %s, %s, TRUE, now())
                    ON CONFLICT (case_key) DO UPDATE
                      SET category = EXCLUDED.category,
                          severity = EXCLUDED.severity,
                          prompt_text = EXCLUDED.prompt_text,
                          grader = EXCLUDED.grader,
                          rationale = EXCLUDED.rationale,
                          tags = EXCLUDED.tags,
                          enabled = EXCLUDED.enabled,
                          updated_at = now()
                    """,
                    (
                        c["case_key"],
                        c["__category"],
                        c["severity"],
                        serialize_prompt(c["prompt"]),
                        json.dumps(c["grader"] or {}),
                        c.get("rationale"),
                        tags,
                    ),
                )
                upserted += 1
        conn.commit()

    print(f"Upserted {upserted} cases into public.eval_cases.")
    return 0


# Script entry point: propagate main()'s exit code to the shell.
if __name__ == "__main__":
    sys.exit(main())