esctr-environment / ablation.py
musharraf7's picture
Upload folder using huggingface_hub
08a3b81 verified
#!/usr/bin/env python3
"""
Run ablation experiments for ESCTR reward/risk analysis.
Variants:
1) base_env -> no distractors, no risk shaping
2) distractors_only -> distractors enabled, risk shaping off
3) distractors_risk_shaping -> distractors enabled, risk shaping on
"""
import json
import os
import re
from statistics import mean
from server.environment import ESCTREnvironment
from server.models import ESCTRAction
# Matches one line-item row of a rendered PO / invoice document, one row per
# line (multi-line mode so ``^``/``$`` anchor at each line).
# Captured groups:
#   1 -> line-item id such as "LI-3"
#   2 -> the integer column that precedes the two dollar amounts
#        (presumably quantity -- confirm against the document renderer)
#   3 -> first dollar amount, commas kept (e.g. "1,250.00")
#   4 -> second dollar amount, commas kept
# NOTE(review): the meaning of the two amounts (unit price vs. line total?)
# is not visible here; confirm against the server's document format.
LINE_RE = re.compile(
    r"^(LI-\d+)\s+"            # line-item identifier
    r".*?\s+"                  # free-text description, matched lazily
    r"(\d+)\s+"                # integer column
    r"\$([0-9,]+\.\d{2})\s+"   # first dollar amount
    r"\$([0-9,]+\.\d{2})$",    # second dollar amount, end of line
    flags=re.M,
)
def _to_float(text: str) -> float:
return float(text.replace(",", ""))
def scripted_procurement_episode(seed: int, *, bias: float = 0.90) -> tuple[float, dict]:
    """Play one scripted procurement-reconciliation episode and return its outcome.

    The script queries the purchase-order and invoice tables, reads the
    ``[PRIMARY]`` PO and invoice documents, and then submits an adjustment
    equal to the scenario's oracle adjustment scaled by ``bias`` (rounded to
    cents). The deliberate error gives the risk-shaping terms something to
    react to.

    Args:
        seed: Seed passed to ``env.reset`` for the procurement scenario.
        bias: Multiplier applied to the oracle adjustment before submission.
            Defaults to 0.90 (a 10% under-adjustment), matching the original
            hard-coded behaviour.

    Returns:
        ``(reward, metadata)`` taken from the final
        ``submit_financial_decision`` step.

    Raises:
        RuntimeError: If the primary PO number or invoice ID cannot be found
            in the table query responses.
    """
    env = ESCTREnvironment()
    env.reset(task_name="procurement_reconciliation", seed=seed)

    po_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "purchase_orders"})
    )
    inv_summary = env.step(
        ESCTRAction(action_type="query_database", query_parameters={"table": "invoices"})
    )
    po_id = re.search(r"\[PRIMARY\] PO Number: ([A-Z]+-\d{4}-\d{4})", po_summary.system_response)
    inv_id = re.search(r"\[PRIMARY\] Invoice: ([A-Z]+-\d{4}-\d{4})", inv_summary.system_response)
    if not po_id or not inv_id:
        raise RuntimeError("Could not parse primary PO/Invoice IDs from query output.")

    # Read both primary documents so the episode covers the investigation path
    # that the submission reason below claims was followed. Their contents are
    # not needed: the submitted amount is derived from the oracle, not parsed.
    for document_id in (po_id.group(1), inv_id.group(1)):
        env.step(ESCTRAction(action_type="read_document", document_id=document_id))

    # Deliberately biased adjustment to simulate realistic model error and
    # expose risk shaping effects.
    oracle_adjustment = env._scenario.correct_adjustment  # noqa: SLF001
    target_adjustment = round(oracle_adjustment * bias, 2)
    final = env.step(
        ESCTRAction(
            action_type="submit_financial_decision",
            adjustment_amount=target_adjustment,
            adjustment_reason=(
                "Noisy adjustment for ablation measurement after full investigation path "
                f"(oracle={oracle_adjustment:.2f}, submitted={target_adjustment:.2f})"
            ),
        )
    )
    return final.reward, final.metadata
def run_variant(name: str, distractors: bool, risk_shaping: bool, seeds: range) -> dict:
    """Run the scripted episode for every seed under one ablation configuration.

    The configuration is exported through the ``ESCTR_ENABLE_DISTRACTORS`` and
    ``ESCTR_ENABLE_RISK_SHAPING`` environment variables ("1" / "0"),
    presumably read by ``ESCTREnvironment``. Their previous values are
    restored (or the variables are removed, if they were unset) once the
    episodes finish, even if an episode raises. This keeps the settings from
    leaking into later code in the same process.

    Args:
        name: Label stored under ``"variant"`` in the summary.
        distractors: Value exported as ``ESCTR_ENABLE_DISTRACTORS``.
        risk_shaping: Value exported as ``ESCTR_ENABLE_RISK_SHAPING``.
        seeds: Episode seeds; any iterable of ints works at runtime.

    Returns:
        Summary dict containing the episode count, the mean reward, the mean
        over/under-penalization risk, and the fraction of episodes whose
        procedural-shortcut / vendor-reliance flags were truthy. All means
        are rounded to 4 decimals.

    Raises:
        statistics.StatisticsError: If ``seeds`` is empty (nothing to average).
    """
    overrides = {
        "ESCTR_ENABLE_DISTRACTORS": "1" if distractors else "0",
        "ESCTR_ENABLE_RISK_SHAPING": "1" if risk_shaping else "0",
    }
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)

    rewards = []
    over = []
    under = []
    shortcut = []
    reliance = []
    try:
        for seed in seeds:
            score, meta = scripted_procurement_episode(seed)
            rewards.append(score)
            over.append(float(meta.get("risk_over_penalization", 0.0)))
            under.append(float(meta.get("risk_under_penalization", 0.0)))
            # Boolean flags become 0/1 so their mean is a rate.
            shortcut.append(1.0 if meta.get("risk_procedural_shortcut", False) else 0.0)
            reliance.append(1.0 if meta.get("risk_vendor_reliance", False) else 0.0)
    finally:
        # Undo the overrides so the configuration does not outlive this call.
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    return {
        "variant": name,
        "episodes": len(rewards),
        "mean_reward": round(mean(rewards), 4),
        "mean_over_penalization_risk": round(mean(over), 4),
        "mean_under_penalization_risk": round(mean(under), 4),
        "procedural_shortcut_rate": round(mean(shortcut), 4),
        "vendor_reliance_rate": round(mean(reliance), 4),
    }
def main():
    """Run all three ablation variants on seeds 0-29 and report the summaries.

    The list of per-variant summaries is written as indented JSON to
    ``artifacts/ablation_results.json`` (the directory is created if needed)
    and echoed to stdout.
    """
    # (label, distractors enabled, risk shaping enabled)
    variant_specs = (
        ("base_env", False, False),
        ("distractors_only", True, False),
        ("distractors_risk_shaping", True, True),
    )
    episode_seeds = range(30)
    summaries = [
        run_variant(label, distractors=use_distractors, risk_shaping=use_shaping, seeds=episode_seeds)
        for label, use_distractors, use_shaping in variant_specs
    ]

    os.makedirs("artifacts", exist_ok=True)
    with open("artifacts/ablation_results.json", "w", encoding="utf-8") as out_file:
        json.dump(summaries, out_file, indent=2)
    print(json.dumps(summaries, indent=2))


if __name__ == "__main__":
    main()