# Corp_AI / scripts / run_baseline.py
# Arpit Deep
# feat: initial AuditEnv submission
# a617acd
from __future__ import annotations

import argparse
import json
import math
import os
from typing import Any

import httpx
from openai import OpenAI
# System message sent with every chat-completion request made by
# `_build_action`. It asks the model for a strict JSON object and names the
# keys and the three `action_type` values that `_build_action` validates.
SYSTEM_PROMPT = (
    "You are an audit agent. Return strict JSON with keys: action_type, violation_type, confidence, note. "
    "Choose action_type from submit_finding, flag_human_review, noop."
)
def _coerce_confidence(raw: Any, default: float = 0.5) -> float:
    """Convert a model-supplied confidence into a float clamped to [0.0, 1.0].

    Returns ``default`` when ``raw`` cannot be interpreted as a number
    (``None``, a word such as ``"high"``, a list, an integer too large for a
    float) or when it is NaN.
    """
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN compares false against everything, so the clamp below would
    # silently turn it into 1.0; treat it as "no usable confidence" instead.
    if math.isnan(value):
        return default
    return max(0.0, min(1.0, value))


def _build_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]:
    """Build an action using the OpenAI Chat Completions API.

    The first document id in ``observation["documents"]`` (or ``"UNKNOWN"``
    when there are none) is sent to the model, and the model's JSON reply is
    validated field by field. Anything the model gets wrong degrades to a
    safe value instead of raising, so one malformed reply cannot abort a
    whole baseline run.

    Args:
        task_id: Task identifier, echoed back in the returned action.
        observation: Environment observation; only its ``"documents"`` list
            is read here.
        client: Object exposing ``chat.completions.create(...)``.
        model: Model name forwarded to the API.

    Returns:
        An action dict. For ``submit_finding`` it carries a ``finding`` with
        ``document_id``, ``violation_type``, ``evidence`` and a confidence in
        [0.0, 1.0]. For every other outcome it is
        ``{"action_type", "task_id", "note"}``; unusable model output yields
        a ``noop`` whose note is ``fallback_no_json`` or
        ``fallback_parse_error``.
    """
    documents = observation.get("documents", [])
    doc_id = documents[0]["id"] if documents else "UNKNOWN"

    user_prompt = (
        "Task: " + task_id + "\n"
        "Given this sample document id, propose one conservative action.\n"
        f"document_id: {doc_id}\n"
        "Return JSON only."
    )
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
        max_tokens=200,
    )

    # An empty `choices` list is treated like an empty reply rather than
    # letting an IndexError escape.
    choices = completion.choices
    text = (choices[0].message.content or "").strip() if choices else ""

    # Strip markdown fences if present.
    if text.startswith("```"):
        kept = [line for line in text.split("\n") if not line.strip().startswith("```")]
        text = "\n".join(kept).strip()

    # Safe fallback if model output is not parseable JSON.
    if not text.startswith("{"):
        return {"action_type": "noop", "task_id": task_id, "note": "fallback_no_json"}
    try:
        payload = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return {"action_type": "noop", "task_id": task_id, "note": "fallback_parse_error"}

    # Check the type first: an unhashable value (e.g. a list) would raise
    # TypeError on the set-membership test.
    action_type = payload.get("action_type", "noop")
    if not isinstance(action_type, str) or action_type not in {"submit_finding", "flag_human_review", "noop"}:
        action_type = "noop"

    if action_type != "submit_finding":
        return {"action_type": action_type, "task_id": task_id, "note": payload.get("note", "")}

    # A present-but-null (or otherwise non-string / empty) violation type
    # falls back to the same default as a missing key.
    violation_type = payload.get("violation_type")
    if not isinstance(violation_type, str) or not violation_type:
        violation_type = "duplicate_receipt"

    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": violation_type,
            "evidence": [doc_id],
            "confidence": _coerce_confidence(payload.get("confidence", 0.5)),
        },
        "note": payload.get("note", "baseline_action"),
    }
def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]:
    """Return a fixed, model-free ``submit_finding`` action for local validation.

    The finding always targets the first document in
    ``observation["documents"]`` (``"UNKNOWN"`` when the list is missing or
    empty), uses a violation type chosen purely from ``task_id``, and carries
    a constant confidence of 0.5. No API call is made.
    """
    docs = observation.get("documents", [])
    if docs:
        target_id = docs[0]["id"]
    else:
        target_id = "UNKNOWN"

    # One hard-coded guess per known task; unknown tasks share the default.
    per_task_violation = dict(
        easy="duplicate_receipt",
        medium="sod_conflict",
        hard="shell_company",
    )
    finding = {
        "document_id": target_id,
        "violation_type": per_task_violation.get(task_id, "duplicate_receipt"),
        "evidence": [target_id],
        "confidence": 0.5,
    }
    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": finding,
        "note": "heuristic_fallback_policy",
    }
def run_task(
    env_url: str,
    task_id: str,
    client: OpenAI | None,
    model: str,
    seed: int,
    policy: str,
) -> float:
    """Run one episode of ``task_id`` against the environment and score it.

    The environment is reset with ``POST {env_url}/reset``, then stepped with
    ``POST {env_url}/step`` until a step response reports ``done``. Each
    action comes from the heuristic policy when ``policy == "heuristic"``,
    otherwise from the OpenAI-backed policy.

    Args:
        env_url: Base URL of the environment server; a trailing slash is
            tolerated.
        task_id: Task to reset and act on.
        client: OpenAI client; required unless ``policy == "heuristic"``.
        model: Model name forwarded to the OpenAI-backed policy.
        seed: Seed sent with the reset request.
        policy: ``"heuristic"`` for the local policy; any other value uses
            the OpenAI-backed policy.

    Returns:
        The mean of ``reward["normalized"]`` over all steps, rounded to six
        decimal places.

    Raises:
        RuntimeError: If the OpenAI-backed policy is selected without a
            client. Raised before the environment is contacted.
        httpx.HTTPStatusError: If the environment answers ``/reset`` or
            ``/step`` with a 4xx/5xx status.
    """
    use_heuristic = policy == "heuristic"
    # Validate up front: failing after /reset would leave the environment
    # mid-episode for a run that can never take a step.
    if not use_heuristic and client is None:
        raise RuntimeError("OPENAI_API_KEY is required for policy=openai")

    base_url = env_url.rstrip("/")  # avoid "//reset" when the URL ends in "/"
    total = 0.0
    steps = 0

    with httpx.Client(timeout=20.0) as http:
        reset_response = http.post(f"{base_url}/reset", json={"task_id": task_id, "seed": seed})
        # Surface server errors directly instead of as a KeyError on the
        # error payload further down.
        reset_response.raise_for_status()
        obs = reset_response.json()

        done = False
        while not done:
            if use_heuristic:
                action = _build_heuristic_action(task_id=task_id, observation=obs)
            else:
                action = _build_action(task_id=task_id, observation=obs, client=client, model=model)

            step_response = http.post(f"{base_url}/step", json=action)
            step_response.raise_for_status()
            result = step_response.json()

            total += float(result["reward"]["normalized"])
            steps += 1
            done = bool(result["done"])
            obs = result["observation"]

    # Mean normalized reward per step. The loop body runs at least once, so
    # `steps` is never zero here. NOTE(review): the mean is presumably within
    # [0, 1], but that depends on the environment's normalization - confirm.
    return round(total / steps, 6)
def main() -> None:
    """Parse CLI options, score every task, and print the results.

    Defaults for ``--env-url`` and ``--model`` come from the
    ``AUDITENV_BASE_URL`` and ``AUDITENV_BASELINE_MODEL`` environment
    variables. With ``--policy openai`` (the default) an ``OPENAI_API_KEY``
    environment variable is required; its absence raises ``RuntimeError``.
    All tasks are run before any score is printed.
    """
    cli = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.")
    cli.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000"))
    cli.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "gpt-4.1-mini"))
    cli.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai",
        help="Action policy: 'openai' uses API key, 'heuristic' is free local fallback.",
    )
    cli.add_argument("--seed", type=int, default=42)
    opts = cli.parse_args()

    # Only the OpenAI-backed policy needs a client (and therefore a key).
    client: OpenAI | None = None
    if opts.policy == "openai":
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise RuntimeError("OPENAI_API_KEY is required for --policy openai")
        client = OpenAI(api_key=key)

    task_ids = ("easy", "medium", "hard")
    scores = {
        name: run_task(opts.env_url, name, client, opts.model, opts.seed, opts.policy)
        for name in task_ids
    }

    print("Baseline scores (normalized):")
    for name, score in scores.items():
        print(f"- {name}: {score:.6f}")


if __name__ == "__main__":
    main()