""" AuditEnv Baseline Runner — supports LLM (Groq/OpenAI) and signal-aware heuristic policies. Usage: # Heuristic (no API key needed): python scripts/run_baseline.py --policy heuristic # LLM via Groq: $env:OPENAI_API_KEY="gsk_..." python scripts/run_baseline.py --policy openai --model llama-3.3-70b-versatile """ from __future__ import annotations import argparse import json import os import sys import textwrap from typing import Any import httpx from openai import OpenAI # --------------------------------------------------------------------------- # LLM System Prompt — much richer than before # --------------------------------------------------------------------------- SYSTEM_PROMPT = textwrap.dedent("""\ You are an expert compliance auditor AI. You are reviewing documents inside an automated audit environment. Your goal is to find genuine policy violations and fraud — but ONLY when the evidence is clear. ## CRITICAL RULES 1. **Think before acting.** For EVERY document batch, first ask yourself: "Is there a SPECIFIC fraud signal present in the text?" If no clear signal exists, output `noop`. Do NOT guess or assume fraud. 1.1 **Default to noop.** `submit_finding` is allowed only when the exact fraud signal appears in document text for this task. 2. **Confidence threshold:** If your confidence is below 0.70, output `noop`. 3. **One action per step.** Return exactly one JSON object per step. 4. **No markdown.** Return ONLY raw JSON, no fences, no explanation. ## STEP-BY-STEP REASONING (follow this every step) 1. Read every document in the batch carefully. 2. For each document, check: does it contain a known fraud signal? (see list below) 3. If YES → build a `submit_finding` with the matching document_id & violation_type. 4. If UNSURE → use `flag_human_review` with a note explaining the concern. 5. If NO signals found → use `noop`. ## KNOWN FRAUD SIGNALS Easy tasks: - "DUPLICATE_FLAG=true" or "matches_receipt=" → violation_type: "duplicate_receipt" - "alcohol_amount" exceeding "policy_limit" → violation_type: "alcohol_over_limit" - "LATE=true" or submission_date past policy_deadline → violation_type: "late_submission" Medium tasks: - "sod_conflict" or "segregation_of_duties" → violation_type: "sod_conflict" - "dormant_account" or "DORMANT=true" → violation_type: "dormant_account_reactivation" - "temporal_anomaly" or "off_hours" or "suspicious_hour" → violation_type: "temporal_anomaly" Hard tasks: - "shell_company" or "SHELL=true" → violation_type: "shell_company" - "invoice_splitting" or "split_invoice" → violation_type: "invoice_splitting" - "round_tripping" or "ROUND_TRIP=true" → violation_type: "round_tripping" ## EVIDENCE FORMAT Evidence should be a list of document IDs that support the finding. Use format: ["", ""] ## FEW-SHOT EXAMPLES ### Example 1 — True positive (correct finding) Documents: [{"id": "e-DOC-003", "text": "...amount=$185.00; DUPLICATE_FLAG=true; matches_receipt=RCT-12345..."}] Correct output: {"action_type": "submit_finding", "document_id": "e-DOC-003", "violation_type": "duplicate_receipt", "evidence": ["e-DOC-003", "e-DOC-002"], "confidence": 0.92, "note": "DUPLICATE_FLAG=true present with matching receipt reference"} ### Example 2 — No signal present (correct noop) Documents: [{"id": "e-DOC-007", "text": "employee=Alice; amount=$45.00; expense_type=meals; description=Business lunch"}] Correct output: {"action_type": "noop", "note": "No fraud signals detected in document batch"} ### Example 3 — Incorrect false positive (DO NOT DO THIS) Documents: [{"id": "e-DOC-010", "text": "employee=Bob; amount=$200.00; expense_type=travel"}] WRONG output: {"action_type": "submit_finding", "document_id": "e-DOC-010", "violation_type": "duplicate_receipt", ...} WHY WRONG: No DUPLICATE_FLAG, no alcohol signal, no LATE signal. This is a clean document. ## ACTION FORMATS submit_finding: {"action_type": "submit_finding", "document_id": "", "violation_type": "", "evidence": ["", ""], "confidence": 0.85, "note": "explanation"} flag_human_review: {"action_type": "flag_human_review", "note": "explanation of concern"} noop: {"action_type": "noop", "note": "reason no finding"} """) # --------------------------------------------------------------------------- # Signal-aware heuristic policy (Task 1.1 — FUTURE_PLAN.md) # --------------------------------------------------------------------------- _EASY_SIGNALS: list[tuple[str, str]] = [ ("DUPLICATE_FLAG=true", "duplicate_receipt"), ("is_duplicate_invoice_id=true", "duplicate_receipt"), ("duplicate_invoice_group_size", "duplicate_receipt"), ("matches_receipt=", "duplicate_receipt"), ("alcohol_amount", "alcohol_over_limit"), ("alcohol_over_limit", "alcohol_over_limit"), ("LATE=true", "late_submission"), ("policy_deadline", "late_submission"), ] _MEDIUM_SIGNALS: list[tuple[str, str]] = [ ("sod_conflict", "sod_conflict"), ("segregation_of_duties", "sod_conflict"), ("dormant_account", "dormant_account_reactivation"), ("DORMANT=true", "dormant_account_reactivation"), ("temporal_anomaly", "temporal_anomaly"), ("suspicious_hour", "temporal_anomaly"), ("off_hours", "temporal_anomaly"), ] _HARD_SIGNALS: list[tuple[str, str]] = [ ("shell_company", "shell_company"), ("SHELL=true", "shell_company"), ("vendor_registration_age_days=", "shell_company"), ("invoice_splitting", "invoice_splitting"), ("split_invoice", "invoice_splitting"), ("round_tripping", "round_tripping"), ("ROUND_TRIP=true", "round_tripping"), ] _TASK_SIGNALS: dict[str, list[tuple[str, str]]] = { "easy": _EASY_SIGNALS, "medium": _MEDIUM_SIGNALS, "hard": _HARD_SIGNALS, } _DEFAULT_VIOLATION: dict[str, str] = { "easy": "duplicate_receipt", "medium": "sod_conflict", "hard": "shell_company", } def _detect_violation(text: str, task_id: str) -> str | None: """Return the first matched violation type for the given document text.""" for signal, vtype in _TASK_SIGNALS.get(task_id, []): if signal.lower() in text.lower(): return vtype return None def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]: """Signal-aware heuristic — reads embedded fraud clues from document text. 1. Scans all visible documents for known fraud signals. 2. Falls back to noop if no signals found. """ documents = observation.get("documents", []) if not documents: return {"action_type": "noop", "task_id": task_id, "note": "no_documents"} # Scan all visible documents for fraud signals for doc in documents: doc_id = doc.get("id", "UNKNOWN") text = doc.get("text", "") vtype = _detect_violation(text, task_id) if vtype: idx = documents.index(doc) neighbor_id = documents[max(0, idx - 1)]["id"] evidence = [doc_id] if neighbor_id == doc_id else [doc_id, neighbor_id] return { "action_type": "submit_finding", "task_id": task_id, "finding": { "document_id": doc_id, "violation_type": vtype, "evidence": evidence, "confidence": 0.85, }, "note": f"signal_detected:{vtype}", } # Fallback — safe abstention when no explicit signal is present return { "action_type": "noop", "task_id": task_id, "note": "heuristic_no_signal_noop", } # --------------------------------------------------------------------------- # LLM policy — sends full document context to Groq/OpenAI # --------------------------------------------------------------------------- def _build_llm_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]: """Build an action using chat completions (Groq, OpenAI, etc).""" documents = observation.get("documents", []) # Build rich document context (limit to 10 docs for token budget) doc_lines = [] for doc in documents[:10]: doc_lines.append( f" - ID: {doc.get('id', 'N/A')}, Type: {doc.get('type', 'N/A')}, " f"Text: {doc.get('text', '')[:300]}" ) docs_text = "\n".join(doc_lines) if doc_lines else " (no documents)" findings_submitted = observation.get("findings_submitted", 0) steps_remaining = observation.get("steps_remaining", "?") current_score = observation.get("current_partial_score", 0.0) user_prompt = textwrap.dedent(f"""\ Task difficulty: {task_id} Findings submitted so far: {findings_submitted} Steps remaining: {steps_remaining} Current score: {current_score:.2f} Documents to review: {docs_text} Analyze ALL documents carefully. Look for violations matching {task_id} difficulty. Return a single JSON action object. """) try: completion = client.chat.completions.create( model=model, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=0.2, max_tokens=400, ) text = (completion.choices[0].message.content or "").strip() # Strip markdown fences if present if text.startswith("```"): lines = text.split("\n") lines = [l for l in lines if not l.strip().startswith("```")] text = "\n".join(lines).strip() if not text.startswith("{"): return _build_heuristic_action(task_id, observation) payload = json.loads(text) except (json.JSONDecodeError, Exception) as exc: print(f" [WARN] LLM parse/request failed: {exc}", file=sys.stderr) return _build_heuristic_action(task_id, observation) # Sanitize the LLM output action_type = payload.get("action_type", "noop") if action_type not in {"submit_finding", "flag_human_review", "noop"}: action_type = "noop" if action_type != "submit_finding": return {"action_type": action_type, "task_id": task_id, "note": str(payload.get("note", ""))[:200]} # Build structured finding from LLM output doc_id = payload.get("document_id", documents[0]["id"] if documents else "UNKNOWN") violation_type = payload.get("violation_type", _DEFAULT_VIOLATION.get(task_id, "duplicate_receipt")) evidence = payload.get("evidence", [doc_id]) if not isinstance(evidence, list): evidence = [evidence] confidence = float(payload.get("confidence", 0.5)) confidence = max(0.0, min(1.0, confidence)) return { "action_type": "submit_finding", "task_id": task_id, "finding": { "document_id": doc_id, "violation_type": violation_type, "evidence": evidence, "confidence": confidence, }, "note": str(payload.get("note", "llm_action"))[:200], } def _safe_reward_fields(result: dict[str, Any]) -> tuple[float, str]: """Extract normalized reward and reason without raising on malformed payloads.""" reward = result.get("reward") if not isinstance(reward, dict): return 0.0, "missing_reward_payload" reason = str(reward.get("reason", "")) try: reward_norm = float(reward.get("normalized", 0.0)) except (TypeError, ValueError): return 0.0, f"{reason}|invalid_reward_value" if reason else "invalid_reward_value" return reward_norm, reason # --------------------------------------------------------------------------- # Episode runner # --------------------------------------------------------------------------- def run_task( env_url: str, task_id: str, client: OpenAI | None, model: str, seed: int, policy: str, ) -> dict[str, Any]: """Run a complete episode. Returns a result dict with score, steps, and per-step details.""" with httpx.Client(timeout=30.0) as http: try: reset_resp = http.post(f"{env_url}/reset", json={"task_id": task_id, "seed": seed}) reset_resp.raise_for_status() obs = reset_resp.json() session_id = obs.get("session_id") if isinstance(obs, dict) else None except Exception as exc: print(f" [WARN] reset failed for task={task_id}: {exc}", file=sys.stderr) return { "task_id": task_id, "score": 0.0, "steps": 0, "log": [], "completed": False, "error": f"reset_failed:{type(exc).__name__}", } total_reward = 0.0 steps = 0 step_log: list[dict[str, Any]] = [] done = False hard_step_cap = 40 if isinstance(obs, dict): raw_cap = obs.get("steps_remaining") if isinstance(raw_cap, int): # Keep a bounded safety margin while allowing full hard episodes to finish. hard_step_cap = max(8, min(64, raw_cap + 4)) task_error = "" while not done and steps < hard_step_cap: if policy == "heuristic": action = _build_heuristic_action(task_id=task_id, observation=obs) else: if client is None: raise RuntimeError("OPENAI_API_KEY is required for policy=openai") action = _build_llm_action(task_id=task_id, observation=obs, client=client, model=model) if session_id and "session_id" not in action: action["session_id"] = session_id try: step_resp = http.post(f"{env_url}/step", json=action) step_resp.raise_for_status() result = step_resp.json() except Exception as exc: task_error = f"step_failed:{type(exc).__name__}" step_log.append( { "step": steps + 1, "action_type": action.get("action_type"), "reward_norm": 0.0, "reward_reason": task_error, "done": False, } ) print(f" [WARN] step failed for task={task_id}: {exc}", file=sys.stderr) break reward_norm, reward_reason = _safe_reward_fields(result) total_reward += reward_norm steps += 1 done = bool(result.get("done", False)) obs = result.get("observation", obs) # Log this step entry = { "step": steps, "action_type": action.get("action_type"), "reward_norm": reward_norm, "reward_reason": reward_reason, "done": done, } if action.get("finding"): entry["doc_id"] = action["finding"]["document_id"] entry["violation"] = action["finding"]["violation_type"] step_log.append(entry) print( f" Step {steps:2d} │ {action.get('action_type'):18s} │ " f"reward={reward_norm:.3f} │ reason={reward_reason} │ " f"done={done}" ) if not done and not task_error and steps >= hard_step_cap: task_error = "max_steps_reached" mean_score = round(total_reward / steps, 6) if steps else 0.0 return { "task_id": task_id, "score": mean_score, "steps": steps, "log": step_log, "completed": done, "error": task_error, } # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.") parser.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000")) parser.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "llama-3.3-70b-versatile")) parser.add_argument("--base-url", default=os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")) parser.add_argument( "--policy", choices=["openai", "heuristic"], default="heuristic", help="Action policy: 'openai' uses Groq/OpenAI API, 'heuristic' is free local fallback.", ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--save-log", default="", help="Path to save per-step JSONL log.") parser.add_argument( "--include-partial-log", action="store_true", help="Include incomplete task episodes in --save-log output.", ) args = parser.parse_args() print(f"╔══════════════════════════════════════════════╗") print(f"║ AuditEnv Baseline Runner ║") print(f"║ Policy: {args.policy:10s} Seed: {args.seed:<10d} ║") print(f"╚══════════════════════════════════════════════╝") client: OpenAI | None = None if args.policy == "openai": api_key = os.getenv("OPENAI_API_KEY") if not api_key: print("ERROR: Set OPENAI_API_KEY env var for --policy openai", file=sys.stderr) sys.exit(1) client = OpenAI(api_key=api_key, base_url=args.base_url) print(f" Model: {args.model}") print(f" API: {args.base_url}") print() results: list[dict[str, Any]] = [] for task_id in ["easy", "medium", "hard"]: print(f"━━━ Task: {task_id} ━━━") res = run_task(args.env_url, task_id, client, args.model, args.seed, args.policy) results.append(res) if res.get("completed"): print(f" → Score: {res['score']:.6f} ({res['steps']} steps)\n") else: print( f" → Score: {res['score']:.6f} ({res['steps']} steps) [INCOMPLETE: {res.get('error','')}]\n" ) # Summary print("┌──────────────────────────────────────┐") print("│ BASELINE SCORE SUMMARY │") print("├───────────┬────────────┬──────────────┬────────────┤") print("│ Task │ Score │ Steps │ Status │") print("├───────────┼────────────┼──────────────┼────────────┤") for r in results: status = "ok" if r.get("completed") else "incomplete" print(f"│ {r['task_id']:9s} │ {r['score']:.6f} │ {r['steps']:4d} │ {status:10s} │") avg = sum(r["score"] for r in results) / len(results) if results else 0.0 print("├───────────┼────────────┼──────────────┼────────────┤") print(f"│ AVERAGE │ {avg:.6f} │ │ │") print("└───────────┴────────────┴──────────────┴────────────┘") # Optionally save step log if args.save_log: skipped = 0 written = 0 with open(args.save_log, "w", encoding="utf-8") as f: for r in results: if not args.include_partial_log and not r.get("completed"): skipped += 1 continue for entry in r["log"]: payload = dict(entry) payload["task_id"] = r["task_id"] f.write(json.dumps(payload) + "\n") written += 1 print(f"\nStep log saved to: {args.save_log} ({written} rows, skipped {skipped} incomplete task logs)") if __name__ == "__main__": main()