File size: 6,187 Bytes
007fbdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Any

from .case_factory import generate_benign_twin, generate_case_batch, generate_holdout_suite
from .schema import normalize_id, normalize_text

# Absolute directory that contains this module; other paths are derived from it.
BASE_DIR = Path(__file__).resolve().parent
# Directory holding the JSON fixture files that load_json() reads.
FIXTURE_DIR = BASE_DIR / "fixtures"


def load_json(name: str) -> Any:
    """Parse and return the JSON fixture called *name*.

    Args:
        name: File name, resolved relative to ``FIXTURE_DIR``.

    Returns:
        The decoded JSON document.
    """
    raw_text = (FIXTURE_DIR / name).read_text(encoding="utf-8")
    return json.loads(raw_text)


def _index_by(rows: list[dict[str, Any]], key: str) -> dict[str, dict[str, Any]]:
    """Build a lookup table of *rows* keyed by the string form of ``row[key]``.

    Rows where *key* is missing or ``None`` are left out. When several rows
    share the same key, the last one wins.
    """
    return {
        str(record.get(key)): record
        for record in rows
        if record.get(key) is not None
    }


def _vendor_index(vendors: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Map every normalized alias of a vendor to its vendor record.

    The aliases are the normalized ``vendor_key``, ``canonical_name`` and
    ``vendor_name`` fields; aliases that normalize to a falsy value are
    skipped. A later vendor replaces an earlier one that shares an alias.
    """
    alias_fields = ("vendor_key", "canonical_name", "vendor_name")
    lookup: dict[str, dict[str, Any]] = {}
    for record in vendors:
        for field in alias_fields:
            alias = normalize_text(record.get(field))
            if not alias:
                continue
            lookup[alias] = record
    return lookup


def _ledger_vendor_index(ledger_index: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
    """Group ledger rows by their normalized ``vendor_key``.

    Rows keep their input order inside each group. Every row is grouped,
    including rows whose ``vendor_key`` is absent.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for entry in ledger_index:
        bucket_key = normalize_text(entry.get("vendor_key"))
        if bucket_key not in grouped:
            grouped[bucket_key] = []
        grouped[bucket_key].append(entry)
    return grouped


def _case_index(cases: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Return the cases keyed by the string form of their ``case_id``.

    Cases that have no ``case_id`` key are omitted; on duplicate ids the last
    case wins.
    """
    by_id: dict[str, dict[str, Any]] = {}
    for entry in cases:
        if "case_id" not in entry:
            continue
        by_id[str(entry["case_id"])] = entry
    return by_id


def _case_defaults(case: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *case* with missing optional fields filled in.

    Values already present in *case* are never overwritten. When
    ``due_date_days`` is missing it is derived from the normalized
    difficulty: 3 for ``easy``, 30 for ``hard``/``expert``, otherwise 14.
    ``initial_visible_doc_ids`` defaults to the truthy ``doc_id`` of every
    document, and ``task_label`` defaults to ``task_type`` (or ``""``).
    """
    filled = dict(case)

    scalar_defaults = (
        ("budget_total", 15.0),
        ("max_steps", 20),
        ("difficulty", "medium"),
        ("benchmark_split", "benchmark"),
    )
    for field, fallback in scalar_defaults:
        filled.setdefault(field, fallback)

    level = normalize_text(filled.get("difficulty"))
    if "due_date_days" not in filled:
        days_by_level = {"easy": 3, "hard": 30, "expert": 30}
        filled["due_date_days"] = days_by_level.get(level, 14)

    # Fresh containers per call so copies never share a default list/dict.
    filled.setdefault("documents", [])
    filled.setdefault("gold", {})
    filled.setdefault("task_label", filled.get("task_type", ""))
    filled.setdefault("contrastive_pair_id", "")
    filled.setdefault("contrastive_role", "")

    # Computed unconditionally, before checking whether the key already exists.
    visible_ids = []
    for document in filled.get("documents", []):
        if document.get("doc_id"):
            visible_ids.append(document.get("doc_id"))
    filled.setdefault("initial_visible_doc_ids", visible_ids)
    return filled


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean feature flag from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        ``default`` when the variable is unset, empty or whitespace-only.
        Otherwise ``True`` if the normalized value is one of ``1``, ``true``,
        ``yes`` or ``on``, and ``False`` for anything else.
    """
    value = os.getenv(name)
    # A blank value (e.g. ``NAME=`` in a shell or .env file) carries no
    # intent, so it falls back to the default instead of silently turning a
    # default-on flag off. The numeric settings read in load_all() treat
    # blank values the same way via ``or <default>``.
    if value is None or not value.strip():
        return default
    return normalize_text(value) in {"1", "true", "yes", "on"}


def _env_int(name: str, default: int) -> int:
    """Return the integer value of environment variable *name*.

    An unset or empty variable yields *default*. Any other value is passed to
    ``int`` and therefore raises ``ValueError`` when it is not an integer.
    """
    return int(os.getenv(name) or default)


def _approved_bank_account(
    case: dict[str, Any],
    gold: dict[str, Any],
    vendors_by_key: dict[str, dict[str, Any]],
) -> Any:
    """Return the ``bank_account`` of the first known vendor a case refers to.

    The case-level ``vendor_key`` is tried before the gold-label one. The
    candidates are held in a tuple rather than a set because set iteration
    order for strings varies with hash randomization, which would make the
    result differ between runs when both keys are known but distinct.

    Returns ``None`` when neither key resolves to a vendor.
    """
    for raw_key in (case.get("vendor_key"), gold.get("vendor_key")):
        vendor_key = normalize_text(raw_key)
        if vendor_key and vendor_key in vendors_by_key:
            return vendors_by_key[vendor_key].get("bank_account")
    return None


def load_all() -> dict[str, Any]:
    """Load every fixture and assemble the case list plus lookup indexes.

    Base cases come from ``cases.json``. Depending on environment settings,
    generated cases are appended after them:

    * ``LEDGERSHIELD_INCLUDE_CHALLENGE`` (default on) adds challenge variants
      of the ``task_c``/``task_d``/``task_e`` base cases, controlled by
      ``LEDGERSHIELD_CHALLENGE_VARIANTS`` (default 2) and
      ``LEDGERSHIELD_CHALLENGE_SEED`` (default 2026).
    * ``LEDGERSHIELD_INCLUDE_HOLDOUT`` (default off) adds a holdout suite for
      the same base cases, controlled by ``LEDGERSHIELD_HOLDOUT_VARIANTS``
      (default 1) and ``LEDGERSHIELD_HOLDOUT_SEED`` (default 31415).
    * ``LEDGERSHIELD_INCLUDE_TWINS`` (default off) adds a benign twin for each
      ``task_d``/``task_e`` base case whose gold data sets ``unsafe_if_pay``.

    Returns:
        A dict with the raw fixture lists (``vendors``, ``vendor_history``,
        ``cases``, ``po_records``, ``receipts``, ``ledger_index``,
        ``email_threads``, ``policy_rules``) and the derived indexes
        (``cases_by_id``, ``vendors_by_key``, ``po_by_id``, ``receipt_by_id``,
        ``thread_by_id``, ``policy_by_id``, ``ledger_by_vendor``).

    Raises:
        ValueError: If a variants/seed environment variable is not an integer.
    """
    vendors = load_json("vendors.json")
    vendors_by_key = _vendor_index(vendors)
    vendor_history = load_json("vendor_history.json")
    base_cases = [_case_defaults(case) for case in load_json("cases.json")]
    hard_cases = [
        case
        for case in base_cases
        if normalize_text(case.get("task_type")) in {"task_c", "task_d", "task_e"}
    ]

    include_challenge = _env_flag("LEDGERSHIELD_INCLUDE_CHALLENGE", True)
    include_holdout = _env_flag("LEDGERSHIELD_INCLUDE_HOLDOUT", False)
    include_twins = _env_flag("LEDGERSHIELD_INCLUDE_TWINS", False)
    # Negative variant counts are clamped to zero, which disables generation.
    challenge_variants = max(0, _env_int("LEDGERSHIELD_CHALLENGE_VARIANTS", 2))
    challenge_seed = _env_int("LEDGERSHIELD_CHALLENGE_SEED", 2026)
    holdout_variants = max(0, _env_int("LEDGERSHIELD_HOLDOUT_VARIANTS", 1))
    holdout_seed = _env_int("LEDGERSHIELD_HOLDOUT_SEED", 31415)

    cases = list(base_cases)
    if include_challenge and hard_cases and challenge_variants > 0:
        challenge_cases = generate_case_batch(
            base_cases=hard_cases,
            variants_per_case=challenge_variants,
            seed=challenge_seed,
            split="challenge",
        )
        cases.extend(_case_defaults(case) for case in challenge_cases)

    if include_holdout and hard_cases and holdout_variants > 0:
        holdout_cases = generate_holdout_suite(
            base_cases=hard_cases,
            variants_per_case=holdout_variants,
            seed=holdout_seed,
        )
        cases.extend(_case_defaults(case) for case in holdout_cases)

    if include_twins:
        for idx, case in enumerate(base_cases):
            gold = case.get("gold", {}) or {}
            if normalize_text(case.get("task_type")) not in {"task_d", "task_e"} or not gold.get("unsafe_if_pay"):
                continue
            # The twin seed is offset by the case position so each twin gets
            # its own seed derived from the holdout seed.
            twin = generate_benign_twin(
                case,
                seed=holdout_seed + idx,
                approved_bank_account=_approved_bank_account(case, gold, vendors_by_key),
            )
            cases.append(_case_defaults(twin))

    po_records = load_json("po_records.json")
    receipts = load_json("receipts.json")
    ledger_index = load_json("ledger_index.json")
    email_threads = load_json("email_threads.json")
    policy_rules = load_json("policy_rules.json")

    return {
        "vendors": vendors,
        "vendor_history": vendor_history,
        "cases": cases,
        "po_records": po_records,
        "receipts": receipts,
        "ledger_index": ledger_index,
        "email_threads": email_threads,
        "policy_rules": policy_rules,
        "cases_by_id": _case_index(cases),
        "vendors_by_key": vendors_by_key,
        "po_by_id": _index_by(po_records, "po_id"),
        "receipt_by_id": _index_by(receipts, "receipt_id"),
        "thread_by_id": _index_by(email_threads, "thread_id"),
        "policy_by_id": _index_by(policy_rules, "rule_id"),
        "ledger_by_vendor": _ledger_vendor_index(ledger_index),
    }