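"""Run reproducible baseline scores on all AuditEnv tasks.

Supports two action policies: 'openai' (requires OPENAI_API_KEY) and
'heuristic' (a free local fallback). Example invocation, assuming this
file is saved as run_baselines.py and an AuditEnv server is listening
on the default URL below:

    python run_baselines.py --policy heuristic
"""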
from __future__ import annotations

import argparse
import json
import os
from typing import Any

import httpx
from openai import OpenAI

SYSTEM_PROMPT = (
    "You are an audit agent. Return strict JSON with keys: action_type, violation_type, confidence, note. "
    "Choose action_type from submit_finding, flag_human_review, noop."
)


def _build_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]:
    """Build an action using the OpenAI Chat Completions API."""
    documents = observation.get("documents", [])
    doc_id = documents[0]["id"] if documents else "UNKNOWN"
    user_prompt = (
        "Task: " + task_id + "\n"
        "Given this sample document id, propose one conservative action.\n"
        f"document_id: {doc_id}\n"
        "Return JSON only."
    )
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0,
        max_tokens=200,
    )
    text = (completion.choices[0].message.content or "").strip()
    # Strip markdown fences if present.
    if text.startswith("```"):
        lines = text.split("\n")
        lines = [line for line in lines if not line.strip().startswith("```")]
        text = "\n".join(lines).strip()
    # Safe fallback if model output is not parseable JSON.
    if not text.startswith("{"):
        return {"action_type": "noop", "task_id": task_id, "note": "fallback_no_json"}
    try:
        payload = json.loads(text)
    except Exception:
        return {"action_type": "noop", "task_id": task_id, "note": "fallback_parse_error"}
    action_type = payload.get("action_type", "noop")
    if action_type not in {"submit_finding", "flag_human_review", "noop"}:
        action_type = "noop"
    if action_type != "submit_finding":
        return {"action_type": action_type, "task_id": task_id, "note": payload.get("note", "")}
    violation_type = payload.get("violation_type", "duplicate_receipt")
    confidence = float(payload.get("confidence", 0.5))
    confidence = max(0.0, min(1.0, confidence))
    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": violation_type,
            "evidence": [doc_id],
            "confidence": confidence,
        },
        "note": payload.get("note", "baseline_action"),
    }


def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]:
    """Free fallback policy for local validation when API credits are unavailable."""
    documents = observation.get("documents", [])
    doc_id = documents[0]["id"] if documents else "UNKNOWN"
    violation_map = {
        "easy": "duplicate_receipt",
        "medium": "sod_conflict",
        "hard": "shell_company",
    }
    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": violation_map.get(task_id, "duplicate_receipt"),
            "evidence": [doc_id],
            "confidence": 0.5,
        },
        "note": "heuristic_fallback_policy",
    }


def run_task(
    env_url: str,
    task_id: str,
    client: OpenAI | None,
    model: str,
    seed: int,
    policy: str,
) -> float:
    """Roll out one task against the environment and return the mean normalized reward per step."""
    with httpx.Client(timeout=20.0) as http:
        obs = http.post(f"{env_url}/reset", json={"task_id": task_id, "seed": seed}).json()
        total = 0.0
        steps = 0
        done = False
        while not done:
            if policy == "heuristic":
                action = _build_heuristic_action(task_id=task_id, observation=obs)
            else:
                if client is None:
                    raise RuntimeError("OPENAI_API_KEY is required for policy=openai")
                action = _build_action(task_id=task_id, observation=obs, client=client, model=model)
            result = http.post(f"{env_url}/step", json=action).json()
            total += float(result["reward"]["normalized"])
            steps += 1
            done = bool(result["done"])
            obs = result["observation"]
    # Mean normalized reward per step (bounded [0, 1] by construction).
    return round(total / steps, 6)


def main() -> None:
    parser = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.")
    parser.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000"))
    parser.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "gpt-4.1-mini"))
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai",
        help="Action policy: 'openai' calls the OpenAI API, 'heuristic' is a free local fallback.",
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    client: OpenAI | None = None
    if args.policy == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is required for --policy openai")
        client = OpenAI(api_key=api_key)

    scores = {}
    for task_id in ["easy", "medium", "hard"]:
        scores[task_id] = run_task(args.env_url, task_id, client, args.model, args.seed, args.policy)

    print("Baseline scores (normalized):")
    for task_id in ["easy", "medium", "hard"]:
        print(f"- {task_id}: {scores[task_id]:.6f}")


if __name__ == "__main__":
    main()