Spaces:
Sleeping
Sleeping
File size: 4,146 Bytes
08a3b81 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 | #!/usr/bin/env python3
"""
Run ablation experiments for ESCTR reward/risk analysis.
Variants:
1) base_env -> no distractors, no risk shaping
2) distractors_only -> distractors enabled, risk shaping off
3) distractors_risk_shaping -> distractors enabled, risk shaping on
"""
import json
import os
import re
from statistics import mean
from server.environment import ESCTREnvironment
from server.models import ESCTRAction
# Matches one whole line (MULTILINE, so ^/$ anchor at each line) of the form
#   LI-<digits> <free text> <integer> $<amount> $<amount>
# where each amount is digits/commas followed by a dot and two decimals.
# Capture groups: 1 = the "LI-<digits>" identifier, 2 = the integer column,
# 3 and 4 = the two amounts without their "$" prefix.
# NOTE(review): presumably the columns are quantity, unit price and line total
# of a PO/invoice line item -- confirm against the document format.
LINE_RE = re.compile(
r"^(LI-\d+)\s+.*?\s+(\d+)\s+\$([0-9,]+\.\d{2})\s+\$([0-9,]+\.\d{2})$",
re.MULTILINE,
)
def _to_float(text: str) -> float:
return float(text.replace(",", ""))
def scripted_procurement_episode(seed: int) -> tuple[float, dict]:
    """Run one scripted procurement-reconciliation episode and return its outcome.

    The script queries the ``purchase_orders`` and ``invoices`` tables, reads the
    two documents tagged ``[PRIMARY]`` in those listings, and then submits an
    adjustment equal to 90% of the scenario's correct adjustment.

    Args:
        seed: Seed passed to ``ESCTREnvironment.reset``.

    Returns:
        ``(reward, metadata)`` taken from the result of the final step.

    Raises:
        RuntimeError: If the primary PO or invoice ID cannot be parsed from the
            query output.
    """
    env = ESCTREnvironment()
    env.reset(task_name="procurement_reconciliation", seed=seed)

    po_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "purchase_orders"})
    )
    inv_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "invoices"})
    )

    po_id = re.search(r"\[PRIMARY\] PO Number: ([A-Z]+-\d{4}-\d{4})", po_summary.system_response)
    inv_id = re.search(r"\[PRIMARY\] Invoice: ([A-Z]+-\d{4}-\d{4})", inv_summary.system_response)
    if not po_id or not inv_id:
        raise RuntimeError("Could not parse primary PO/Invoice IDs from query output.")

    # Read both documents so the episode follows the full investigation path.
    # The responses are deliberately discarded: the submitted amount below is
    # derived from the scenario's oracle value, not from parsed line items.
    env.step(ESCTRAction(action_type="read_document", document_id=po_id.group(1)))
    env.step(ESCTRAction(action_type="read_document", document_id=inv_id.group(1)))

    # Slightly biased adjustment to simulate realistic model error and expose risk shaping effects.
    oracle_adjustment = env._scenario.correct_adjustment  # noqa: SLF001
    target_adjustment = round(oracle_adjustment * 0.90, 2)

    final = env.step(
        ESCTRAction(
            action_type="submit_financial_decision",
            adjustment_amount=target_adjustment,
            adjustment_reason=(
                "Noisy adjustment for ablation measurement after full investigation path "
                f"(oracle={oracle_adjustment:.2f}, submitted={target_adjustment:.2f})"
            ),
        )
    )
    return final.reward, final.metadata
def run_variant(name: str, distractors: bool, risk_shaping: bool, seeds: range) -> dict:
    """Run the scripted episode once per seed under one configuration and summarise it.

    The configuration is communicated through the process environment:
    ``ESCTR_ENABLE_DISTRACTORS`` and ``ESCTR_ENABLE_RISK_SHAPING`` are set to
    ``"1"`` or ``"0"`` and remain set after this function returns.

    Args:
        name: Label stored under ``"variant"`` in the result.
        distractors: Whether to set the distractor flag to ``"1"``.
        risk_shaping: Whether to set the risk-shaping flag to ``"1"``.
        seeds: Seeds to run, one episode each.

    Returns:
        A dict with the variant name, the episode count, and the means
        (rounded to 4 decimals) of the reward and of each risk metric.

    Raises:
        statistics.StatisticsError: If ``seeds`` is empty.
    """
    toggles = {
        "ESCTR_ENABLE_DISTRACTORS": distractors,
        "ESCTR_ENABLE_RISK_SHAPING": risk_shaping,
    }
    for variable, enabled in toggles.items():
        os.environ[variable] = "1" if enabled else "0"

    # One row per episode: (reward, over, under, shortcut flag, reliance flag).
    rows = []
    for seed in seeds:
        reward, info = scripted_procurement_episode(seed)
        rows.append(
            (
                reward,
                float(info.get("risk_over_penalization", 0.0)),
                float(info.get("risk_under_penalization", 0.0)),
                float(bool(info.get("risk_procedural_shortcut", False))),
                float(bool(info.get("risk_vendor_reliance", False))),
            )
        )

    def average(column: int) -> float:
        # Mean of a single column across all episodes, rounded for reporting.
        return round(mean([row[column] for row in rows]), 4)

    return {
        "variant": name,
        "episodes": len(rows),
        "mean_reward": average(0),
        "mean_over_penalization_risk": average(1),
        "mean_under_penalization_risk": average(2),
        "procedural_shortcut_rate": average(3),
        "vendor_reliance_rate": average(4),
    }
def main():
    """Run the three ablation variants over seeds 0-29, then save and print the results.

    The results are written as indented JSON to
    ``artifacts/ablation_results.json`` (the directory is created if missing)
    and the same JSON is printed to stdout.
    """
    seeds = range(0, 30)
    # (variant name, distractors enabled, risk shaping enabled)
    variants = (
        ("base_env", False, False),
        ("distractors_only", True, False),
        ("distractors_risk_shaping", True, True),
    )
    results = [
        run_variant(label, distractors=use_distractors, risk_shaping=use_shaping, seeds=seeds)
        for label, use_distractors, use_shaping in variants
    ]

    os.makedirs("artifacts", exist_ok=True)
    with open("artifacts/ablation_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
|