Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Run ablation experiments for ESCTR reward/risk analysis. | |
| Variants: | |
| 1) base_env -> no distractors, no risk shaping | |
| 2) distractors_only -> distractors enabled, risk shaping off | |
| 3) distractors_risk_shaping -> distractors enabled, risk shaping on | |
| """ | |
| import json | |
| import os | |
| import re | |
| from statistics import mean | |
| from server.environment import ESCTREnvironment | |
| from server.models import ESCTRAction | |
# Matches one line-item row of a PO/invoice document; the pattern is anchored
# per line, so it is compiled with the MULTILINE flag (re.M).
#   group 1: line-item id, "LI-" followed by digits
#   group 2: the integer immediately preceding the two dollar amounts
#            (presumably the quantity — confirm against the document format)
#   group 3: first dollar amount, without "$" (may contain "," separators)
#   group 4: second dollar amount, without "$" (may contain "," separators)
LINE_RE = re.compile(
    r"^(LI-\d+)\s+.*?\s+(\d+)\s+\$([0-9,]+\.\d{2})\s+\$([0-9,]+\.\d{2})$",
    flags=re.M,
)
| def _to_float(text: str) -> float: | |
| return float(text.replace(",", "")) | |
def scripted_procurement_episode(
    seed: int, *, adjustment_scale: float = 0.90
) -> tuple[float, dict]:
    """Run one scripted ``procurement_reconciliation`` episode.

    The script does the following in order:

    1. Query the ``purchase_orders`` and ``invoices`` tables.
    2. Extract the ``[PRIMARY]`` PO and invoice IDs from the responses.
    3. Read both documents.
    4. Submit a deliberately biased adjustment: the scenario's oracle
       adjustment times ``adjustment_scale``, rounded to cents.

    Args:
        seed: Seed passed to ``env.reset``.
        adjustment_scale: Multiplier applied to the oracle adjustment before
            submission. The default 0.90 (a 10% under-adjustment) matches the
            previously hard-coded bias.

    Returns:
        ``(reward, metadata)`` from the final ``submit_financial_decision`` step.

    Raises:
        RuntimeError: If the primary PO or invoice ID cannot be parsed from the
            query output.
    """
    env = ESCTREnvironment()
    env.reset(task_name="procurement_reconciliation", seed=seed)

    po_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "purchase_orders"})
    )
    inv_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "invoices"})
    )
    po_id = re.search(r"\[PRIMARY\] PO Number: ([A-Z]+-\d{4}-\d{4})", po_summary.system_response)
    inv_id = re.search(r"\[PRIMARY\] Invoice: ([A-Z]+-\d{4}-\d{4})", inv_summary.system_response)
    if not po_id or not inv_id:
        raise RuntimeError("Could not parse primary PO/Invoice IDs from query output.")

    # Read both primary documents so the episode follows the full investigation
    # path before submitting. Their contents are not needed by this script.
    # NOTE(review): presumably the environment's procedural-shortcut risk check
    # depends on these reads — confirm before removing them.
    env.step(ESCTRAction(action_type="read_document", document_id=po_id.group(1)))
    env.step(ESCTRAction(action_type="read_document", document_id=inv_id.group(1)))

    # Biased adjustment to simulate realistic model error and expose risk
    # shaping effects. Reads the scenario's private oracle value directly.
    oracle_adjustment = env._scenario.correct_adjustment  # noqa: SLF001
    target_adjustment = round(oracle_adjustment * adjustment_scale, 2)
    final = env.step(
        ESCTRAction(
            action_type="submit_financial_decision",
            adjustment_amount=target_adjustment,
            adjustment_reason=(
                "Noisy adjustment for ablation measurement after full investigation path "
                f"(oracle={oracle_adjustment:.2f}, submitted={target_adjustment:.2f})"
            ),
        )
    )
    return final.reward, final.metadata
def run_variant(name: str, distractors: bool, risk_shaping: bool, seeds: range) -> dict:
    """Run the scripted episode for every seed under one feature-flag setting.

    ``ESCTR_ENABLE_DISTRACTORS`` and ``ESCTR_ENABLE_RISK_SHAPING`` are set to
    ``"1"``/``"0"`` for the duration of the run. Afterwards they are restored
    to their previous values, or removed if they were unset, so one variant's
    flags do not leak into later code in the same process. Restoration also
    happens when an episode raises.

    NOTE(review): this only has an effect if the environment reads these
    variables after this point (e.g. when ``ESCTREnvironment`` is constructed
    or reset) rather than once at import time — confirm in server.environment.

    Args:
        name: Label stored under ``"variant"`` in the result.
        distractors: Value for ``ESCTR_ENABLE_DISTRACTORS``.
        risk_shaping: Value for ``ESCTR_ENABLE_RISK_SHAPING``.
        seeds: Episode seeds, one episode per seed.

    Returns:
        Summary dict containing:

        - the variant name and episode count;
        - mean reward;
        - mean over- and under-penalization risk;
        - the fraction of episodes in which the procedural-shortcut and
          vendor-reliance flags were truthy.

        All means are rounded to 4 decimals.

    Raises:
        statistics.StatisticsError: If ``seeds`` is empty (nothing to average).
    """
    flags = {
        "ESCTR_ENABLE_DISTRACTORS": "1" if distractors else "0",
        "ESCTR_ENABLE_RISK_SHAPING": "1" if risk_shaping else "0",
    }
    previous = {key: os.environ.get(key) for key in flags}
    os.environ.update(flags)
    try:
        rewards = []
        over = []
        under = []
        shortcut = []
        reliance = []
        for seed in seeds:
            score, meta = scripted_procurement_episode(seed)
            rewards.append(score)
            over.append(float(meta.get("risk_over_penalization", 0.0)))
            under.append(float(meta.get("risk_under_penalization", 0.0)))
            shortcut.append(1.0 if meta.get("risk_procedural_shortcut", False) else 0.0)
            reliance.append(1.0 if meta.get("risk_vendor_reliance", False) else 0.0)
        return {
            "variant": name,
            "episodes": len(rewards),
            "mean_reward": round(mean(rewards), 4),
            "mean_over_penalization_risk": round(mean(over), 4),
            "mean_under_penalization_risk": round(mean(under), 4),
            "procedural_shortcut_rate": round(mean(shortcut), 4),
            "vendor_reliance_rate": round(mean(reliance), 4),
        }
    finally:
        # Put the process environment back exactly as we found it.
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
def main(num_seeds: int = 30, output_path: str = "artifacts/ablation_results.json") -> None:
    """Run the three ablation variants, write the results as JSON, and print them.

    The variants are:

    - ``base_env``: distractors off, risk shaping off.
    - ``distractors_only``: distractors on, risk shaping off.
    - ``distractors_risk_shaping``: distractors on, risk shaping on.

    Args:
        num_seeds: Episodes per variant. Seeds ``0 .. num_seeds - 1`` are used,
            and the same seed range is passed to every variant.
        output_path: Destination of the JSON results. Its parent directory is
            created if it does not exist.
    """
    seeds = range(num_seeds)
    # (variant name, distractors, risk_shaping)
    variants = (
        ("base_env", False, False),
        ("distractors_only", True, False),
        ("distractors_risk_shaping", True, True),
    )
    results = [
        run_variant(name, distractors=distractors, risk_shaping=risk_shaping, seeds=seeds)
        for name, distractors, risk_shaping in variants
    ]
    out_dir = os.path.dirname(output_path)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(json.dumps(results, indent=2))
# Script entry point: run the ablation with default settings only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()