# Baseline runner for AuditEnv.
#
# Plays every task ("easy", "medium", "hard") against an AuditEnv HTTP server
# using either an OpenAI-backed policy or a free heuristic policy, and prints
# the mean normalized reward per step for each task.
from __future__ import annotations

import argparse
import json
import math
import os
from typing import Any

import httpx
from openai import OpenAI

SYSTEM_PROMPT = (
    "You are an audit agent. Return strict JSON with keys: action_type, violation_type, confidence, note. "
    "Choose action_type from submit_finding, flag_human_review, noop."
)

# Tasks are always run (and reported) in this order.
_TASK_IDS = ("easy", "medium", "hard")

# Action types the model is allowed to choose; anything else degrades to "noop".
_ALLOWED_ACTION_TYPES = frozenset({"submit_finding", "flag_human_review", "noop"})

# Violation reported when the model (or the task table) does not supply one.
_DEFAULT_VIOLATION = "duplicate_receipt"

# Fixed violation guess per task for the heuristic policy.
_HEURISTIC_VIOLATIONS = {
    "easy": "duplicate_receipt",
    "medium": "sod_conflict",
    "hard": "shell_company",
}

# Confidence used when the model omits it or returns something non-numeric.
_DEFAULT_CONFIDENCE = 0.5


def _first_document_id(observation: dict[str, Any]) -> Any:
    """Return the ``id`` of the first document in *observation*, or ``"UNKNOWN"``.

    ``"UNKNOWN"`` is returned when the observation has no ``documents`` entry
    or the entry is empty/``None``.
    """
    documents = observation.get("documents", [])
    return documents[0]["id"] if documents else "UNKNOWN"


def _noop_action(task_id: str, note: str) -> dict[str, Any]:
    """Build the safe do-nothing action used when model output is unusable."""
    return {"action_type": "noop", "task_id": task_id, "note": note}


def _strip_markdown_fences(text: str) -> str:
    """Drop every ``` fence line from *text* if it starts with a fence."""
    if not text.startswith("```"):
        return text
    kept = [line for line in text.split("\n") if not line.strip().startswith("```")]
    return "\n".join(kept).strip()


def _coerce_confidence(value: Any) -> float:
    """Convert a model-supplied confidence into a float clamped to [0.0, 1.0].

    ``None``, non-numeric values, NaN and out-of-range integers fall back to
    ``_DEFAULT_CONFIDENCE`` instead of raising, because the value comes
    straight from model output and must not abort the run.
    """
    try:
        confidence = float(value)
    except (TypeError, ValueError, OverflowError):
        return _DEFAULT_CONFIDENCE
    if math.isnan(confidence):
        return _DEFAULT_CONFIDENCE
    return max(0.0, min(1.0, confidence))


def _text_or_default(value: Any, default: str) -> str:
    """Return *value* if it is a string, otherwise *default*."""
    return value if isinstance(value, str) else default


def _parse_model_action(task_id: str, doc_id: Any, raw_text: str) -> dict[str, Any]:
    """Turn raw model text into an environment action.

    Unparseable output yields a ``noop`` action whose ``note`` records why
    (``fallback_no_json`` or ``fallback_parse_error``).
    """
    text = _strip_markdown_fences(raw_text.strip())

    # Safe fallback if model output is not parseable JSON.
    if not text.startswith("{"):
        return _noop_action(task_id, "fallback_no_json")
    try:
        payload = json.loads(text)
    except (json.JSONDecodeError, RecursionError):
        return _noop_action(task_id, "fallback_parse_error")
    if not isinstance(payload, dict):
        return _noop_action(task_id, "fallback_parse_error")

    action_type = payload.get("action_type", "noop")
    # The isinstance check comes first: an unhashable value (list/dict) would
    # make the set-membership test raise TypeError.
    if not isinstance(action_type, str) or action_type not in _ALLOWED_ACTION_TYPES:
        action_type = "noop"

    if action_type != "submit_finding":
        return {
            "action_type": action_type,
            "task_id": task_id,
            "note": _text_or_default(payload.get("note"), ""),
        }

    violation_type = payload.get("violation_type")
    if not isinstance(violation_type, str) or not violation_type.strip():
        violation_type = _DEFAULT_VIOLATION

    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": violation_type,
            "evidence": [doc_id],
            "confidence": _coerce_confidence(payload.get("confidence", _DEFAULT_CONFIDENCE)),
        },
        "note": _text_or_default(payload.get("note"), "baseline_action"),
    }


def _build_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]:
    """Build an action using the OpenAI Chat Completions API.

    Only the id of the first document in *observation* is shown to the model.
    The reply is parsed as JSON; any malformed reply degrades to a ``noop``
    action rather than raising.

    Args:
        task_id: Task identifier echoed into the prompt and the action.
        observation: Latest environment observation.
        client: OpenAI client used for the completion request.
        model: Chat model name.

    Returns:
        An action dict with ``action_type``, ``task_id`` and ``note``, plus a
        ``finding`` entry when the action is ``submit_finding``.
    """
    doc_id = _first_document_id(observation)
    user_prompt = (
        f"Task: {task_id}\n"
        "Given this sample document id, propose one conservative action.\n"
        f"document_id: {doc_id}\n"
        "Return JSON only."
    )

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
        max_tokens=200,
    )

    # An empty choices list is treated like empty content (-> noop fallback).
    choices = completion.choices
    content = choices[0].message.content if choices else None
    return _parse_model_action(task_id, doc_id, content or "")


def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]:
    """Free fallback policy for local validation when API credits are unavailable.

    Always submits a finding against the first document, using a fixed
    violation type per task id and a confidence of 0.5.
    """
    doc_id = _first_document_id(observation)
    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": _HEURISTIC_VIOLATIONS.get(task_id, _DEFAULT_VIOLATION),
            "evidence": [doc_id],
            "confidence": 0.5,
        },
        "note": "heuristic_fallback_policy",
    }


def _post_json(http: httpx.Client, url: str, payload: dict[str, Any]) -> Any:
    """POST *payload* as JSON and return the decoded JSON response body.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
    """
    response = http.post(url, json=payload)
    # Fail here with the real HTTP status instead of a confusing KeyError later.
    response.raise_for_status()
    return response.json()


def run_task(
    env_url: str,
    task_id: str,
    client: OpenAI | None,
    model: str,
    seed: int,
    policy: str,
) -> float:
    """Play one episode of *task_id* and return its mean normalized reward.

    The environment is reset via ``POST {env_url}/reset`` and stepped via
    ``POST {env_url}/step`` until a step result reports ``done``.

    Args:
        env_url: Base URL of the environment server (trailing slash allowed).
        task_id: Task to run.
        client: OpenAI client; required unless *policy* is ``"heuristic"``.
        model: Chat model name for the OpenAI policy.
        seed: Seed forwarded to the environment reset.
        policy: ``"heuristic"`` for the free policy; any other value uses OpenAI.

    Returns:
        Mean of ``reward.normalized`` over all steps, rounded to 6 decimals.

    Raises:
        RuntimeError: If the OpenAI policy is selected without a client.
        httpx.HTTPStatusError: If the environment answers with an error status.
    """
    use_heuristic = policy == "heuristic"
    # Checked up front so a misconfigured run fails before touching the server.
    if not use_heuristic and client is None:
        raise RuntimeError("OPENAI_API_KEY is required for policy=openai")

    base_url = env_url.rstrip("/")
    total = 0.0
    steps = 0
    with httpx.Client(timeout=20.0) as http:
        obs = _post_json(http, f"{base_url}/reset", {"task_id": task_id, "seed": seed})
        done = False
        while not done:
            if use_heuristic:
                action = _build_heuristic_action(task_id=task_id, observation=obs)
            else:
                action = _build_action(task_id=task_id, observation=obs, client=client, model=model)

            result = _post_json(http, f"{base_url}/step", action)
            total += float(result["reward"]["normalized"])
            steps += 1
            done = bool(result["done"])
            obs = result["observation"]

    # Mean normalized reward per step. The loop body runs at least once, so
    # ``steps`` is never zero here.
    return round(total / steps, 6)


def main() -> None:
    """Parse CLI arguments, run every task, and print the baseline scores."""
    parser = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.")
    parser.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000"))
    parser.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "gpt-4.1-mini"))
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai",
        help="Action policy: 'openai' uses API key, 'heuristic' is free local fallback.",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    client: OpenAI | None = None
    if args.policy == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is required for --policy openai")
        client = OpenAI(api_key=api_key)

    scores: dict[str, float] = {
        task_id: run_task(args.env_url, task_id, client, args.model, args.seed, args.policy)
        for task_id in _TASK_IDS
    }

    print("Baseline scores (normalized):")
    for task_id in _TASK_IDS:
        print(f"- {task_id}: {scores[task_id]:.6f}")


if __name__ == "__main__":
    main()