Spaces:
Sleeping
Sleeping
File size: 6,187 Bytes
007fbdd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 | from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from .case_factory import generate_benign_twin, generate_case_batch, generate_holdout_suite
from .schema import normalize_id, normalize_text
# Directory containing this module, with symlinks resolved.
BASE_DIR = Path(__file__).resolve().parents[0]
# JSON fixtures read by load_json() live in a "fixtures" subdirectory next to this module.
FIXTURE_DIR = BASE_DIR.joinpath("fixtures")
def load_json(name: str) -> Any:
    """Parse and return the JSON fixture file *name* from ``FIXTURE_DIR``.

    Raises whatever ``open``/``json.load`` raise (e.g. ``FileNotFoundError`` for a
    missing fixture, ``json.JSONDecodeError`` for malformed content).
    """
    fixture_path = FIXTURE_DIR / name
    with fixture_path.open("r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed
def _index_by(rows: list[dict[str, Any]], key: str) -> dict[str, dict[str, Any]]:
output: dict[str, dict[str, Any]] = {}
for row in rows:
value = row.get(key)
if value is None:
continue
output[str(value)] = row
return output
def _vendor_index(vendors: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Index vendors under every normalized alias they carry.

    Each vendor is reachable via the normalized form of its ``vendor_key``,
    ``canonical_name`` and ``vendor_name`` fields; empty aliases are skipped.
    A later vendor overwrites an earlier one sharing the same alias.
    """
    index: dict[str, dict[str, Any]] = {}
    for vendor in vendors:
        aliases = {
            normalize_text(vendor.get(field))
            for field in ("vendor_key", "canonical_name", "vendor_name")
        }
        index.update({alias: vendor for alias in aliases if alias})
    return index
def _ledger_vendor_index(ledger_index: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
    """Group ledger rows by their normalized ``vendor_key``, preserving row order.

    Rows are grouped under whatever ``normalize_text`` returns for their key,
    including for rows that have no ``vendor_key`` at all.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for entry in ledger_index:
        bucket_key = normalize_text(entry.get("vendor_key"))
        if bucket_key not in grouped:
            grouped[bucket_key] = []
        grouped[bucket_key].append(entry)
    return grouped
def _case_index(cases: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
return {str(case["case_id"]): case for case in cases if "case_id" in case}
def _case_defaults(case: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *case* with benchmark defaults filled in.

    Existing keys are never overwritten, except that an explicit
    ``"documents": None`` is replaced by ``[]``. Defaults applied:

    - ``budget_total`` 15.0, ``max_steps`` 20, ``difficulty`` "medium",
      ``benchmark_split`` "benchmark".
    - ``due_date_days`` from the normalized difficulty: 3 for "easy",
      30 for "hard"/"expert", 14 otherwise.
    - ``documents`` [], ``gold`` {}, ``task_label`` = ``task_type`` (or "").
    - ``contrastive_pair_id`` / ``contrastive_role`` "".
    - ``initial_visible_doc_ids``: every truthy ``doc_id`` in ``documents``.

    The copy is shallow: nested ``documents``/``gold`` objects are shared with *case*.
    """
    cloned = dict(case)
    cloned.setdefault("budget_total", 15.0)
    cloned.setdefault("max_steps", 20)
    cloned.setdefault("difficulty", "medium")
    cloned.setdefault("benchmark_split", "benchmark")

    if "due_date_days" not in cloned:
        difficulty = normalize_text(cloned.get("difficulty"))
        if difficulty == "easy":
            cloned["due_date_days"] = 3
        elif difficulty in {"hard", "expert"}:
            cloned["due_date_days"] = 30
        else:
            cloned["due_date_days"] = 14

    # setdefault() would keep an explicit ``None`` and the iteration below would then fail.
    if cloned.get("documents") is None:
        cloned["documents"] = []
    cloned.setdefault("gold", {})
    cloned.setdefault("task_label", cloned.get("task_type", ""))
    cloned.setdefault("contrastive_pair_id", "")
    cloned.setdefault("contrastive_role", "")
    # Built only when missing: setdefault() evaluates its default argument eagerly.
    if "initial_visible_doc_ids" not in cloned:
        cloned["initial_visible_doc_ids"] = [
            doc.get("doc_id") for doc in cloned["documents"] if doc.get("doc_id")
        ]
    return cloned
def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean switch.

    Returns *default* when the variable is unset; otherwise ``True`` only if its
    normalized value is one of "1", "true", "yes" or "on".
    """
    raw = os.getenv(name)
    return default if raw is None else normalize_text(raw) in {"1", "true", "yes", "on"}
def _env_int(name: str, default: int) -> int:
    """Read integer environment variable *name*; unset or blank values yield *default*.

    Raises:
        ValueError: if the variable is set to a non-integer value (the message names it).
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as exc:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from exc


def _approved_bank_account(
    case: dict[str, Any],
    gold: dict[str, Any],
    vendors_by_key: dict[str, dict[str, Any]],
) -> Any:
    """Return the indexed vendor's ``bank_account`` for *case*, or ``None`` if no vendor matches.

    The case-level ``vendor_key`` is tried before the gold ``vendor_key``. A
    tuple is used (not a set) so the choice is stable across interpreter runs:
    str hashing is randomized per process, so set iteration order is not.
    """
    candidates = (
        normalize_text(case.get("vendor_key")),
        normalize_text(gold.get("vendor_key")),
    )
    for vendor_key_candidate in candidates:
        if vendor_key_candidate and vendor_key_candidate in vendors_by_key:
            return vendors_by_key[vendor_key_candidate].get("bank_account")
    return None


def load_all() -> dict[str, Any]:
    """Load every fixture and assemble the case list plus lookup indexes.

    Cases start from ``cases.json`` (with defaults applied) and are optionally
    extended, controlled by environment variables:

    - ``LEDGERSHIELD_INCLUDE_CHALLENGE`` (default on): generated "challenge"
      variants of task_c/d/e cases; ``LEDGERSHIELD_CHALLENGE_VARIANTS`` (2)
      and ``LEDGERSHIELD_CHALLENGE_SEED`` (2026).
    - ``LEDGERSHIELD_INCLUDE_HOLDOUT`` (default off): holdout suite of the same
      cases; ``LEDGERSHIELD_HOLDOUT_VARIANTS`` (1) and
      ``LEDGERSHIELD_HOLDOUT_SEED`` (31415).
    - ``LEDGERSHIELD_INCLUDE_TWINS`` (default off): one benign twin for each
      task_d/task_e base case whose gold marks ``unsafe_if_pay``.

    Raises:
        ValueError: if an integer environment variable is not an integer.
    """
    vendors = load_json("vendors.json")
    vendors_by_key = _vendor_index(vendors)
    vendor_history = load_json("vendor_history.json")
    base_cases = [_case_defaults(case) for case in load_json("cases.json")]
    hard_cases = [
        case
        for case in base_cases
        if normalize_text(case.get("task_type")) in {"task_c", "task_d", "task_e"}
    ]

    include_challenge = _env_flag("LEDGERSHIELD_INCLUDE_CHALLENGE", True)
    include_holdout = _env_flag("LEDGERSHIELD_INCLUDE_HOLDOUT", False)
    include_twins = _env_flag("LEDGERSHIELD_INCLUDE_TWINS", False)
    challenge_variants = max(0, _env_int("LEDGERSHIELD_CHALLENGE_VARIANTS", 2))
    challenge_seed = _env_int("LEDGERSHIELD_CHALLENGE_SEED", 2026)
    holdout_variants = max(0, _env_int("LEDGERSHIELD_HOLDOUT_VARIANTS", 1))
    holdout_seed = _env_int("LEDGERSHIELD_HOLDOUT_SEED", 31415)

    cases = list(base_cases)
    if include_challenge and hard_cases and challenge_variants > 0:
        challenge_cases = generate_case_batch(
            base_cases=hard_cases,
            variants_per_case=challenge_variants,
            seed=challenge_seed,
            split="challenge",
        )
        cases.extend(_case_defaults(case) for case in challenge_cases)
    if include_holdout and hard_cases and holdout_variants > 0:
        holdout_cases = generate_holdout_suite(
            base_cases=hard_cases,
            variants_per_case=holdout_variants,
            seed=holdout_seed,
        )
        cases.extend(_case_defaults(case) for case in holdout_cases)
    if include_twins:
        for idx, case in enumerate(base_cases):
            gold = case.get("gold", {}) or {}
            if normalize_text(case.get("task_type")) not in {"task_d", "task_e"} or not gold.get("unsafe_if_pay"):
                continue
            approved_bank_account = _approved_bank_account(case, gold, vendors_by_key)
            # Seed offset by position so each twin is reproducible yet distinct.
            twin = generate_benign_twin(case, seed=holdout_seed + idx, approved_bank_account=approved_bank_account)
            cases.append(_case_defaults(twin))

    po_records = load_json("po_records.json")
    receipts = load_json("receipts.json")
    ledger_index = load_json("ledger_index.json")
    email_threads = load_json("email_threads.json")
    policy_rules = load_json("policy_rules.json")
    return {
        "vendors": vendors,
        "vendor_history": vendor_history,
        "cases": cases,
        "po_records": po_records,
        "receipts": receipts,
        "ledger_index": ledger_index,
        "email_threads": email_threads,
        "policy_rules": policy_rules,
        "cases_by_id": _case_index(cases),
        "vendors_by_key": vendors_by_key,
        "po_by_id": _index_by(po_records, "po_id"),
        "receipt_by_id": _index_by(receipts, "receipt_id"),
        "thread_by_id": _index_by(email_threads, "thread_id"),
        "policy_by_id": _index_by(policy_rules, "rule_id"),
        "ledger_by_vendor": _ledger_vendor_index(ledger_index),
    }
|