# ledgershield/server/data_loader.py
# Hugging Face file-viewer header (page residue, not part of the module):
#   avatar: king673134
#   commit: "Upload folder using huggingface_hub" (007fbdd, verified)
#   size: 6.19 kB
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from .case_factory import generate_benign_twin, generate_case_batch, generate_holdout_suite
from .schema import normalize_id, normalize_text
# Absolute directory that contains this module file.
BASE_DIR = Path(__file__).resolve().parent
# Sub-directory next to this module from which load_json() reads its files.
FIXTURE_DIR = BASE_DIR / "fixtures"
def load_json(name: str) -> Any:
    """Parse and return the JSON document stored in the fixture file *name*.

    Args:
        name: File name, resolved relative to ``FIXTURE_DIR``.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    fixture_path = FIXTURE_DIR / name
    with fixture_path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
def _index_by(rows: list[dict[str, Any]], key: str) -> dict[str, dict[str, Any]]:
output: dict[str, dict[str, Any]] = {}
for row in rows:
value = row.get(key)
if value is None:
continue
output[str(value)] = row
return output
def _vendor_index(vendors: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Map every normalized alias of a vendor to that vendor's record.

    The aliases are the ``vendor_key``, ``canonical_name`` and ``vendor_name``
    fields, each passed through ``normalize_text``. Aliases that normalize to
    a falsy value are skipped. If two vendors share an alias, the vendor that
    appears later in *vendors* replaces the earlier one.
    """
    lookup: dict[str, dict[str, Any]] = {}
    for record in vendors:
        # A set collapses identical aliases of one vendor into a single entry.
        aliases = {
            normalize_text(record.get(field))
            for field in ("vendor_key", "canonical_name", "vendor_name")
        }
        lookup.update((alias, record) for alias in aliases if alias)
    return lookup
def _ledger_vendor_index(ledger_index: list[dict[str, Any]]) -> dict[str, list[dict[str, Any]]]:
    """Group ledger rows by their normalized ``vendor_key``.

    Each row is appended, in input order, to the list stored under
    ``normalize_text(row.get("vendor_key"))``. Rows without a ``vendor_key``
    are not dropped; they are grouped under whatever ``normalize_text``
    returns for ``None``.
    """
    grouped: dict[str, list[dict[str, Any]]] = {}
    for entry in ledger_index:
        bucket_key = normalize_text(entry.get("vendor_key"))
        if bucket_key not in grouped:
            grouped[bucket_key] = []
        grouped[bucket_key].append(entry)
    return grouped
def _case_index(cases: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
return {str(case["case_id"]): case for case in cases if "case_id" in case}
def _case_defaults(case: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *case* with every optional field filled in.

    Existing keys are never overwritten. Missing keys receive:

    * ``budget_total`` 15.0, ``max_steps`` 20, ``difficulty`` "medium",
      ``benchmark_split`` "benchmark"
    * ``due_date_days`` 3 for "easy", 30 for "hard"/"expert", otherwise 14
      (compared against the normalized difficulty)
    * ``documents`` ``[]`` and ``gold`` ``{}``
    * ``task_label`` mirrors ``task_type`` (or "" when that is absent too)
    * ``contrastive_pair_id`` and ``contrastive_role`` ""
    * ``initial_visible_doc_ids`` the truthy ``doc_id`` of every document

    The input mapping itself is not modified; nested values are shared with it.
    """
    merged = dict(case)

    for field, fallback in (
        ("budget_total", 15.0),
        ("max_steps", 20),
        ("difficulty", "medium"),
        ("benchmark_split", "benchmark"),
    ):
        merged.setdefault(field, fallback)

    level = normalize_text(merged.get("difficulty"))
    if "due_date_days" not in merged:
        due_by_level = {"easy": 3, "hard": 30, "expert": 30}
        merged["due_date_days"] = due_by_level.get(level, 14)

    merged.setdefault("documents", [])
    merged.setdefault("gold", {})
    merged.setdefault("task_label", merged.get("task_type", ""))
    for field in ("contrastive_pair_id", "contrastive_role"):
        merged.setdefault(field, "")

    # Computed unconditionally (even if the key already exists) so that a
    # malformed ``documents`` value surfaces here in every case.
    visible_ids = []
    for document in merged.get("documents", []):
        doc_id = document.get("doc_id")
        if doc_id:
            visible_ids.append(doc_id)
    merged.setdefault("initial_visible_doc_ids", visible_ids)

    return merged
def _env_flag(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean switch.

    Returns *default* when the variable is unset. Otherwise the value is run
    through ``normalize_text`` and the result is ``True`` only if it is one of
    ``"1"``, ``"true"``, ``"yes"`` or ``"on"``; any other value gives ``False``.
    """
    raw = os.getenv(name)
    return default if raw is None else normalize_text(raw) in {"1", "true", "yes", "on"}
def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int; unset or empty gives *default*.

    Raises:
        ValueError: if the variable is set to text that ``int`` cannot parse.
    """
    raw = os.getenv(name)
    return int(raw) if raw else default


def _approved_bank_account(
    case: dict[str, Any],
    gold: dict[str, Any],
    vendors_by_key: dict[str, dict[str, Any]],
) -> Any:
    """Return the ``bank_account`` of the first vendor matching the case, else None.

    The case-level ``vendor_key`` is tried first, then the one in *gold*.
    """
    # An ordered tuple (not a set) keeps the precedence fixed. Iterating a set
    # of strings follows hash order, which changes between interpreter runs
    # under hash randomization and would make the chosen account
    # non-reproducible when both keys resolve to different vendors.
    candidates = (
        normalize_text(case.get("vendor_key")),
        normalize_text(gold.get("vendor_key")),
    )
    for candidate in candidates:
        if candidate and candidate in vendors_by_key:
            return vendors_by_key[candidate].get("bank_account")
    return None


def load_all() -> dict[str, Any]:
    """Load every fixture file and assemble the data set with lookup indexes.

    Base cases come from ``cases.json`` and are completed by ``_case_defaults``.
    Cases whose normalized ``task_type`` is ``task_c``, ``task_d`` or ``task_e``
    serve as seeds for generated variants, controlled by environment variables:

    * ``LEDGERSHIELD_INCLUDE_CHALLENGE`` (default on), with
      ``LEDGERSHIELD_CHALLENGE_VARIANTS`` (default 2) and
      ``LEDGERSHIELD_CHALLENGE_SEED`` (default 2026)
    * ``LEDGERSHIELD_INCLUDE_HOLDOUT`` (default off), with
      ``LEDGERSHIELD_HOLDOUT_VARIANTS`` (default 1) and
      ``LEDGERSHIELD_HOLDOUT_SEED`` (default 31415)
    * ``LEDGERSHIELD_INCLUDE_TWINS`` (default off): adds one benign twin for
      each ``task_d``/``task_e`` base case whose gold has a truthy
      ``unsafe_if_pay``; twin *i* is seeded with the holdout seed plus *i*.

    Negative variant counts are clamped to 0, which disables that suite.

    Returns:
        A dict holding the raw collections (``vendors``, ``vendor_history``,
        ``cases``, ``po_records``, ``receipts``, ``ledger_index``,
        ``email_threads``, ``policy_rules``) and the derived indexes
        (``cases_by_id``, ``vendors_by_key``, ``po_by_id``, ``receipt_by_id``,
        ``thread_by_id``, ``policy_by_id``, ``ledger_by_vendor``).

    Raises:
        ValueError: if one of the integer environment variables is malformed.
        OSError, json.JSONDecodeError: propagated from ``load_json``.
    """
    vendors = load_json("vendors.json")
    vendors_by_key = _vendor_index(vendors)
    vendor_history = load_json("vendor_history.json")

    base_cases = [_case_defaults(case) for case in load_json("cases.json")]
    hard_task_types = {"task_c", "task_d", "task_e"}
    hard_cases = [
        case
        for case in base_cases
        if normalize_text(case.get("task_type")) in hard_task_types
    ]

    include_challenge = _env_flag("LEDGERSHIELD_INCLUDE_CHALLENGE", True)
    include_holdout = _env_flag("LEDGERSHIELD_INCLUDE_HOLDOUT", False)
    include_twins = _env_flag("LEDGERSHIELD_INCLUDE_TWINS", False)
    # All numeric settings are parsed up front, so a malformed value is
    # reported even when the corresponding suite is switched off.
    challenge_variants = max(0, _env_int("LEDGERSHIELD_CHALLENGE_VARIANTS", 2))
    challenge_seed = _env_int("LEDGERSHIELD_CHALLENGE_SEED", 2026)
    holdout_variants = max(0, _env_int("LEDGERSHIELD_HOLDOUT_VARIANTS", 1))
    holdout_seed = _env_int("LEDGERSHIELD_HOLDOUT_SEED", 31415)

    cases = list(base_cases)

    if include_challenge and hard_cases and challenge_variants > 0:
        challenge_cases = generate_case_batch(
            base_cases=hard_cases,
            variants_per_case=challenge_variants,
            seed=challenge_seed,
            split="challenge",
        )
        cases.extend(_case_defaults(case) for case in challenge_cases)

    if include_holdout and hard_cases and holdout_variants > 0:
        holdout_cases = generate_holdout_suite(
            base_cases=hard_cases,
            variants_per_case=holdout_variants,
            seed=holdout_seed,
        )
        cases.extend(_case_defaults(case) for case in holdout_cases)

    if include_twins:
        for idx, case in enumerate(base_cases):
            gold = case.get("gold", {}) or {}
            is_twin_task = normalize_text(case.get("task_type")) in {"task_d", "task_e"}
            if not is_twin_task or not gold.get("unsafe_if_pay"):
                continue
            twin = generate_benign_twin(
                case,
                seed=holdout_seed + idx,
                approved_bank_account=_approved_bank_account(case, gold, vendors_by_key),
            )
            cases.append(_case_defaults(twin))

    po_records = load_json("po_records.json")
    receipts = load_json("receipts.json")
    ledger_index = load_json("ledger_index.json")
    email_threads = load_json("email_threads.json")
    policy_rules = load_json("policy_rules.json")

    return {
        "vendors": vendors,
        "vendor_history": vendor_history,
        "cases": cases,
        "po_records": po_records,
        "receipts": receipts,
        "ledger_index": ledger_index,
        "email_threads": email_threads,
        "policy_rules": policy_rules,
        "cases_by_id": _case_index(cases),
        "vendors_by_key": vendors_by_key,
        "po_by_id": _index_by(po_records, "po_id"),
        "receipt_by_id": _index_by(receipts, "receipt_id"),
        "thread_by_id": _index_by(email_threads, "thread_id"),
        "policy_by_id": _index_by(policy_rules, "rule_id"),
        "ledger_by_vendor": _ledger_vendor_index(ledger_index),
    }