| """ | |
| AuditEnv Baseline Runner: supports LLM (Groq/OpenAI) and signal-aware heuristic policies. | |
| Usage: | |
| # Heuristic (no API key needed): | |
| python scripts/run_baseline.py --policy heuristic | |
| # LLM via Groq (PowerShell syntax shown; in bash, use `export OPENAI_API_KEY=...` instead): | |
| $env:OPENAI_API_KEY="gsk_..." | |
| python scripts/run_baseline.py --policy openai --model llama-3.3-70b-versatile | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| import textwrap | |
| from typing import Any | |
| import httpx | |
| from openai import OpenAI | |
| # --------------------------------------------------------------------------- | |
| # LLM system prompt: rules, reasoning steps, known fraud signals, few-shot examples, action formats | |
| # --------------------------------------------------------------------------- | |
| SYSTEM_PROMPT = textwrap.dedent("""\ | |
| You are an expert compliance auditor AI. You are reviewing documents inside an | |
| automated audit environment. Your goal is to find genuine policy violations and | |
| fraud, but ONLY when the evidence is clear. | |
| ## CRITICAL RULES | |
| 1. **Think before acting.** For EVERY document batch, first ask yourself: | |
| "Is there a SPECIFIC fraud signal present in the text?" If no clear signal | |
| exists, output `noop`. Do NOT guess or assume fraud. | |
| 1.1 **Default to noop.** `submit_finding` is allowed only when the exact fraud | |
| signal appears in document text for this task. | |
| 2. **Confidence threshold:** If your confidence is below 0.70, output `noop`. | |
| 3. **One action per step.** Return exactly one JSON object per step. | |
| 4. **No markdown.** Return ONLY raw JSON, no fences, no explanation. | |
| ## STEP-BY-STEP REASONING (follow this every step) | |
| 1. Read every document in the batch carefully. | |
| 2. For each document, check: does it contain a known fraud signal? (see list below) | |
| 3. If YES → build a `submit_finding` with the matching document_id & violation_type. | |
| 4. If UNSURE → use `flag_human_review` with a note explaining the concern. | |
| 5. If NO signals found → use `noop`. | |
| ## KNOWN FRAUD SIGNALS | |
| Easy tasks: | |
| - "DUPLICATE_FLAG=true" or "matches_receipt=" → violation_type: "duplicate_receipt" | |
| - "alcohol_amount" exceeding "policy_limit" → violation_type: "alcohol_over_limit" | |
| - "LATE=true" or submission_date past policy_deadline → violation_type: "late_submission" | |
| Medium tasks: | |
| - "sod_conflict" or "segregation_of_duties" → violation_type: "sod_conflict" | |
| - "dormant_account" or "DORMANT=true" → violation_type: "dormant_account_reactivation" | |
| - "temporal_anomaly" or "off_hours" or "suspicious_hour" → violation_type: "temporal_anomaly" | |
| Hard tasks: | |
| - "shell_company" or "SHELL=true" → violation_type: "shell_company" | |
| - "invoice_splitting" or "split_invoice" → violation_type: "invoice_splitting" | |
| - "round_tripping" or "ROUND_TRIP=true" → violation_type: "round_tripping" | |
| ## EVIDENCE FORMAT | |
| Evidence should be a list of document IDs that support the finding. | |
| Use format: ["<flagged_doc_id>", "<neighboring_doc_id>"] | |
| ## FEW-SHOT EXAMPLES | |
| ### Example 1: True positive (correct finding) | |
| Documents: [{"id": "e-DOC-003", "text": "...amount=$185.00; DUPLICATE_FLAG=true; matches_receipt=RCT-12345..."}] | |
| Correct output: | |
| {"action_type": "submit_finding", "document_id": "e-DOC-003", "violation_type": "duplicate_receipt", "evidence": ["e-DOC-003", "e-DOC-002"], "confidence": 0.92, "note": "DUPLICATE_FLAG=true present with matching receipt reference"} | |
| ### Example 2: No signal present (correct noop) | |
| Documents: [{"id": "e-DOC-007", "text": "employee=Alice; amount=$45.00; expense_type=meals; description=Business lunch"}] | |
| Correct output: | |
| {"action_type": "noop", "note": "No fraud signals detected in document batch"} | |
| ### Example 3: Incorrect false positive (DO NOT DO THIS) | |
| Documents: [{"id": "e-DOC-010", "text": "employee=Bob; amount=$200.00; expense_type=travel"}] | |
| WRONG output: {"action_type": "submit_finding", "document_id": "e-DOC-010", "violation_type": "duplicate_receipt", ...} | |
| WHY WRONG: No DUPLICATE_FLAG, no alcohol signal, no LATE signal. This is a clean document. | |
| ## ACTION FORMATS | |
| submit_finding: | |
| {"action_type": "submit_finding", "document_id": "<doc_id>", "violation_type": "<type>", "evidence": ["<doc_id>", "<neighbor_id>"], "confidence": 0.85, "note": "explanation"} | |
| flag_human_review: | |
| {"action_type": "flag_human_review", "note": "explanation of concern"} | |
| noop: | |
| {"action_type": "noop", "note": "reason no finding"} | |
| """) | |
| # --------------------------------------------------------------------------- | |
| # Signal-aware heuristic policy (Task 1.1, FUTURE_PLAN.md) | |
| # --------------------------------------------------------------------------- | |
| _EASY_SIGNALS: list[tuple[str, str]] = [ | |
| ("DUPLICATE_FLAG=true", "duplicate_receipt"), | |
| ("is_duplicate_invoice_id=true", "duplicate_receipt"), | |
| ("duplicate_invoice_group_size", "duplicate_receipt"), | |
| ("matches_receipt=", "duplicate_receipt"), | |
| ("alcohol_amount", "alcohol_over_limit"), | |
| ("alcohol_over_limit", "alcohol_over_limit"), | |
| ("LATE=true", "late_submission"), | |
| ("policy_deadline", "late_submission"), | |
| ] | |
| _MEDIUM_SIGNALS: list[tuple[str, str]] = [ | |
| ("sod_conflict", "sod_conflict"), | |
| ("segregation_of_duties", "sod_conflict"), | |
| ("dormant_account", "dormant_account_reactivation"), | |
| ("DORMANT=true", "dormant_account_reactivation"), | |
| ("temporal_anomaly", "temporal_anomaly"), | |
| ("suspicious_hour", "temporal_anomaly"), | |
| ("off_hours", "temporal_anomaly"), | |
| ] | |
| _HARD_SIGNALS: list[tuple[str, str]] = [ | |
| ("shell_company", "shell_company"), | |
| ("SHELL=true", "shell_company"), | |
| ("vendor_registration_age_days=", "shell_company"), | |
| ("invoice_splitting", "invoice_splitting"), | |
| ("split_invoice", "invoice_splitting"), | |
| ("round_tripping", "round_tripping"), | |
| ("ROUND_TRIP=true", "round_tripping"), | |
| ] | |
| _TASK_SIGNALS: dict[str, list[tuple[str, str]]] = { | |
| "easy": _EASY_SIGNALS, | |
| "medium": _MEDIUM_SIGNALS, | |
| "hard": _HARD_SIGNALS, | |
| } | |
| _DEFAULT_VIOLATION: dict[str, str] = { | |
| "easy": "duplicate_receipt", | |
| "medium": "sod_conflict", | |
| "hard": "shell_company", | |
| } | |
| def _detect_violation(text: str, task_id: str) -> str | None: | |
| """Return the first matched violation type for the given document text.""" | |
| for signal, vtype in _TASK_SIGNALS.get(task_id, []): | |
| if signal.lower() in text.lower(): | |
| return vtype | |
| return None | |
| def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]: | |
| """Signal-aware heuristic: reads embedded fraud clues from document text. | |
| 1. Scans all visible documents for known fraud signals. | |
| 2. Falls back to noop if no signals found. | |
| """ | |
| documents = observation.get("documents", []) | |
| if not documents: | |
| return {"action_type": "noop", "task_id": task_id, "note": "no_documents"} | |
| # Scan all visible documents for fraud signals | |
| for idx, doc in enumerate(documents): | |
| doc_id = doc.get("id", "UNKNOWN") | |
| text = doc.get("text", "") | |
| vtype = _detect_violation(text, task_id) | |
| if vtype: | |
| # Pair the flagged document with its predecessor as supporting evidence. | |
| # enumerate() avoids documents.index(doc), which returns the first equal dict. | |
| neighbor_id = documents[max(0, idx - 1)].get("id", doc_id) | |
| evidence = [doc_id] if neighbor_id == doc_id else [doc_id, neighbor_id] | |
| return { | |
| "action_type": "submit_finding", | |
| "task_id": task_id, | |
| "finding": { | |
| "document_id": doc_id, | |
| "violation_type": vtype, | |
| "evidence": evidence, | |
| "confidence": 0.85, | |
| }, | |
| "note": f"signal_detected:{vtype}", | |
| } | |
| # Fallback: safe abstention when no explicit signal is present | |
| return { | |
| "action_type": "noop", | |
| "task_id": task_id, | |
| "note": "heuristic_no_signal_noop", | |
| } | |
| # --------------------------------------------------------------------------- | |
| # LLM policy: sends up to 10 documents (first 300 characters of each) to Groq/OpenAI | |
| # --------------------------------------------------------------------------- | |
| def _build_llm_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]: | |
| """Build an action using chat completions (Groq, OpenAI, etc).""" | |
| documents = observation.get("documents", []) | |
| # Build rich document context (limit to 10 docs for token budget) | |
| doc_lines = [] | |
| for doc in documents[:10]: | |
| doc_lines.append( | |
| f" - ID: {doc.get('id', 'N/A')}, Type: {doc.get('type', 'N/A')}, " | |
| f"Text: {doc.get('text', '')[:300]}" | |
| ) | |
| docs_text = "\n".join(doc_lines) if doc_lines else " (no documents)" | |
| findings_submitted = observation.get("findings_submitted", 0) | |
| steps_remaining = observation.get("steps_remaining", "?") | |
| current_score = observation.get("current_partial_score", 0.0) | |
| user_prompt = textwrap.dedent(f"""\ | |
| Task difficulty: {task_id} | |
| Findings submitted so far: {findings_submitted} | |
| Steps remaining: {steps_remaining} | |
| Current score: {current_score:.2f} | |
| Documents to review: | |
| {docs_text} | |
| Analyze ALL documents carefully. Look for violations matching {task_id} difficulty. | |
| Return a single JSON action object. | |
| """) | |
| try: | |
| completion = client.chat.completions.create( | |
| model=model, | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_prompt}, | |
| ], | |
| temperature=0.2, | |
| max_tokens=400, | |
| ) | |
| text = (completion.choices[0].message.content or "").strip() | |
| # Strip markdown fences if present | |
| if text.startswith("```"): | |
| lines = text.split("\n") | |
| lines = [line for line in lines if not line.strip().startswith("```")] | |
| text = "\n".join(lines).strip() | |
| if not text.startswith("{"): | |
| return _build_heuristic_action(task_id, observation) | |
| payload = json.loads(text) | |
| except Exception as exc:  # also covers json.JSONDecodeError | |
| print(f" [WARN] LLM parse/request failed: {exc}", file=sys.stderr) | |
| return _build_heuristic_action(task_id, observation) | |
| # Sanitize the LLM output | |
| action_type = payload.get("action_type", "noop") | |
| if action_type not in {"submit_finding", "flag_human_review", "noop"}: | |
| action_type = "noop" | |
| if action_type != "submit_finding": | |
| return {"action_type": action_type, "task_id": task_id, "note": str(payload.get("note", ""))[:200]} | |
| # Build structured finding from LLM output | |
| doc_id = payload.get("document_id", documents[0]["id"] if documents else "UNKNOWN") | |
| violation_type = payload.get("violation_type", _DEFAULT_VIOLATION.get(task_id, "duplicate_receipt")) | |
| evidence = payload.get("evidence", [doc_id]) | |
| if not isinstance(evidence, list): | |
| evidence = [evidence] | |
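| try: | |
| confidence = float(payload.get("confidence", 0.5)) | |
| except (TypeError, ValueError): | |
| # A non-numeric confidence from the LLM should not crash the episode. | |
| confidence = 0.5 | |
| confidence = max(0.0, min(1.0, confidence)) | |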
| return { | |
| "action_type": "submit_finding", | |
| "task_id": task_id, | |
| "finding": { | |
| "document_id": doc_id, | |
| "violation_type": violation_type, | |
| "evidence": evidence, | |
| "confidence": confidence, | |
| }, | |
| "note": str(payload.get("note", "llm_action"))[:200], | |
| } | |
| def _safe_reward_fields(result: dict[str, Any]) -> tuple[float, str]: | |
| """Extract normalized reward and reason without raising on malformed payloads.""" | |
| reward = result.get("reward") | |
| if not isinstance(reward, dict): | |
| return 0.0, "missing_reward_payload" | |
| reason = str(reward.get("reason", "")) | |
| try: | |
| reward_norm = float(reward.get("normalized", 0.0)) | |
| except (TypeError, ValueError): | |
| return 0.0, (f"{reason}|invalid_reward_value" if reason else "invalid_reward_value") | |
| return reward_norm, reason | |
| # --------------------------------------------------------------------------- | |
| # Episode runner | |
| # --------------------------------------------------------------------------- | |
| def run_task( | |
| env_url: str, | |
| task_id: str, | |
| client: OpenAI | None, | |
| model: str, | |
| seed: int, | |
| policy: str, | |
| ) -> dict[str, Any]: | |
| """Run a complete episode. Returns a result dict with score, steps, and per-step details.""" | |
| with httpx.Client(timeout=30.0) as http: | |
| try: | |
| reset_resp = http.post(f"{env_url}/reset", json={"task_id": task_id, "seed": seed}) | |
| reset_resp.raise_for_status() | |
| obs = reset_resp.json() | |
| session_id = obs.get("session_id") if isinstance(obs, dict) else None | |
| except Exception as exc: | |
| print(f" [WARN] reset failed for task={task_id}: {exc}", file=sys.stderr) | |
| return { | |
| "task_id": task_id, | |
| "score": 0.0, | |
| "steps": 0, | |
| "log": [], | |
| "completed": False, | |
| "error": f"reset_failed:{type(exc).__name__}", | |
| } | |
| total_reward = 0.0 | |
| steps = 0 | |
| step_log: list[dict[str, Any]] = [] | |
| done = False | |
| hard_step_cap = 40 | |
| if isinstance(obs, dict): | |
| raw_cap = obs.get("steps_remaining") | |
| if isinstance(raw_cap, int): | |
| # Keep a bounded safety margin while allowing full hard episodes to finish. | |
| hard_step_cap = max(8, min(64, raw_cap + 4)) | |
| task_error = "" | |
| while not done and steps < hard_step_cap: | |
| if policy == "heuristic": | |
| action = _build_heuristic_action(task_id=task_id, observation=obs) | |
| else: | |
| if client is None: | |
| raise RuntimeError("OPENAI_API_KEY is required for policy=openai") | |
| action = _build_llm_action(task_id=task_id, observation=obs, client=client, model=model) | |
| if session_id and "session_id" not in action: | |
| action["session_id"] = session_id | |
| try: | |
| step_resp = http.post(f"{env_url}/step", json=action) | |
| step_resp.raise_for_status() | |
| result = step_resp.json() | |
| except Exception as exc: | |
| task_error = f"step_failed:{type(exc).__name__}" | |
| step_log.append( | |
| { | |
| "step": steps + 1, | |
| "action_type": action.get("action_type"), | |
| "reward_norm": 0.0, | |
| "reward_reason": task_error, | |
| "done": False, | |
| } | |
| ) | |
| print(f" [WARN] step failed for task={task_id}: {exc}", file=sys.stderr) | |
| break | |
| reward_norm, reward_reason = _safe_reward_fields(result) | |
| total_reward += reward_norm | |
| steps += 1 | |
| done = bool(result.get("done", False)) | |
| obs = result.get("observation", obs) | |
| # Log this step | |
| entry = { | |
| "step": steps, | |
| "action_type": action.get("action_type"), | |
| "reward_norm": reward_norm, | |
| "reward_reason": reward_reason, | |
| "done": done, | |
| } | |
| if action.get("finding"): | |
| entry["doc_id"] = action["finding"]["document_id"] | |
| entry["violation"] = action["finding"]["violation_type"] | |
| step_log.append(entry) | |
| print( | |
| f" Step {steps:2d} │ {action.get('action_type'):18s} │ " | |
| f"reward={reward_norm:.3f} │ reason={reward_reason} │ " | |
| f"done={done}" | |
| ) | |
| if not done and not task_error and steps >= hard_step_cap: | |
| task_error = "max_steps_reached" | |
| mean_score = round(total_reward / steps, 6) if steps else 0.0 | |
| return { | |
| "task_id": task_id, | |
| "score": mean_score, | |
| "steps": steps, | |
| "log": step_log, | |
| "completed": done, | |
| "error": task_error, | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
| def main() -> None: | |
| parser = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.") | |
| parser.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000")) | |
| parser.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "llama-3.3-70b-versatile")) | |
| parser.add_argument("--base-url", default=os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")) | |
| parser.add_argument( | |
| "--policy", | |
| choices=["openai", "heuristic"], | |
| default="heuristic", | |
| help="Action policy: 'openai' uses the Groq/OpenAI API, 'heuristic' is a free local fallback.", | |
| ) | |
| parser.add_argument("--seed", type=int, default=42) | |
| parser.add_argument("--save-log", default="", help="Path to save per-step JSONL log.") | |
| parser.add_argument( | |
| "--include-partial-log", | |
| action="store_true", | |
| help="Include incomplete task episodes in --save-log output.", | |
| ) | |
| args = parser.parse_args() | |
| print("┌" + "─" * 46 + "┐") | |
| print("│" + "AuditEnv Baseline Runner".center(46) + "│") | |
| print("│" + f"  Policy: {args.policy:10s}  Seed: {args.seed:<10d}".ljust(46) + "│") | |
| print("└" + "─" * 46 + "┘") | |
| client: OpenAI | None = None | |
| if args.policy == "openai": | |
| api_key = os.getenv("OPENAI_API_KEY") | |
| if not api_key: | |
| print("ERROR: Set OPENAI_API_KEY env var for --policy openai", file=sys.stderr) | |
| sys.exit(1) | |
| client = OpenAI(api_key=api_key, base_url=args.base_url) | |
| print(f" Model: {args.model}") | |
| print(f" API: {args.base_url}") | |
| print() | |
| results: list[dict[str, Any]] = [] | |
| for task_id in ["easy", "medium", "hard"]: | |
| print(f"━━━ Task: {task_id} ━━━") | |
| res = run_task(args.env_url, task_id, client, args.model, args.seed, args.policy) | |
| results.append(res) | |
| if res.get("completed"): | |
| print(f" ✓ Score: {res['score']:.6f} ({res['steps']} steps)\n") | |
| else: | |
| print( | |
| f" ⚠ Score: {res['score']:.6f} ({res['steps']} steps) [INCOMPLETE: {res.get('error','')}]\n" | |
| ) | |
| # Summary | |
| print("┌" + "─" * 48 + "┐") | |
| print("│" + "BASELINE SCORE SUMMARY".center(48) + "│") | |
| print("├" + "─" * 11 + "┬" + "─" * 10 + "┬" + "─" * 12 + "┬" + "─" * 12 + "┤") | |
| print(f"│ {'Task':9s} │ {'Score':8s} │ {'Steps':10s} │ {'Status':10s} │") | |
| print("├" + "─" * 11 + "┼" + "─" * 10 + "┼" + "─" * 12 + "┼" + "─" * 12 + "┤") | |
| for r in results: | |
| status = "ok" if r.get("completed") else "incomplete" | |
| print(f"│ {r['task_id']:9s} │ {r['score']:.6f} │ {r['steps']:<10d} │ {status:10s} │") | |
| avg = sum(r["score"] for r in results) / len(results) if results else 0.0 | |
| print("├" + "─" * 11 + "┼" + "─" * 10 + "┼" + "─" * 12 + "┼" + "─" * 12 + "┤") | |
| print(f"│ {'AVERAGE':9s} │ {avg:.6f} │ {'':10s} │ {'':10s} │") | |
| print("└" + "─" * 11 + "┴" + "─" * 10 + "┴" + "─" * 12 + "┴" + "─" * 12 + "┘") | |
| # Optionally save step log | |
| if args.save_log: | |
| skipped = 0 | |
| written = 0 | |
| with open(args.save_log, "w", encoding="utf-8") as f: | |
| for r in results: | |
| if not args.include_partial_log and not r.get("completed"): | |
| skipped += 1 | |
| continue | |
| for entry in r["log"]: | |
| payload = dict(entry) | |
| payload["task_id"] = r["task_id"] | |
| f.write(json.dumps(payload) + "\n") | |
| written += 1 | |
| print(f"\nStep log saved to: {args.save_log} ({written} rows, skipped {skipped} incomplete task logs)") | |
| if __name__ == "__main__": | |
| main() | |
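
The signal-aware heuristic above can be tried without a running AuditEnv server. The following is a minimal standalone sketch, not the runner's own code: `SIGNALS`, `detect_violation`, and `build_action` are hypothetical stand-ins that mirror `_TASK_SIGNALS`, `_detect_violation`, and `_build_heuristic_action`, with a trimmed-down signal table assumed for illustration. It shows the action shape the heuristic is expected to send to `/step`: a flagged document paired with its predecessor as evidence, or a `noop` when no signal matches.

```python
from __future__ import annotations

from typing import Any

# Hypothetical, trimmed-down signal table; the real runner defines more entries.
SIGNALS: dict[str, list[tuple[str, str]]] = {
    "easy": [
        ("DUPLICATE_FLAG=true", "duplicate_receipt"),
        ("LATE=true", "late_submission"),
    ],
}


def detect_violation(text: str, task_id: str) -> str | None:
    """Return the first violation type whose signal occurs in the text (case-insensitive)."""
    lowered = text.lower()
    for signal, vtype in SIGNALS.get(task_id, []):
        if signal.lower() in lowered:
            return vtype
    return None


def build_action(task_id: str, documents: list[dict[str, Any]]) -> dict[str, Any]:
    """Submit a finding for the first flagged document, otherwise abstain with noop."""
    for idx, doc in enumerate(documents):
        vtype = detect_violation(doc.get("text", ""), task_id)
        if vtype is None:
            continue
        doc_id = doc.get("id", "UNKNOWN")
        # The previous document (or the document itself at index 0) is the neighbor.
        neighbor_id = documents[max(0, idx - 1)].get("id", doc_id)
        evidence = [doc_id] if neighbor_id == doc_id else [doc_id, neighbor_id]
        return {
            "action_type": "submit_finding",
            "task_id": task_id,
            "finding": {
                "document_id": doc_id,
                "violation_type": vtype,
                "evidence": evidence,
                "confidence": 0.85,
            },
            "note": f"signal_detected:{vtype}",
        }
    return {"action_type": "noop", "task_id": task_id, "note": "heuristic_no_signal_noop"}


# Sample documents modeled on the few-shot examples in the system prompt.
docs = [
    {"id": "e-DOC-002", "text": "employee=Alice; amount=$45.00; expense_type=meals"},
    {"id": "e-DOC-003", "text": "amount=$185.00; DUPLICATE_FLAG=true; matches_receipt=RCT-12345"},
]
action = build_action("easy", docs)
print(action["action_type"], action["finding"]["evidence"])
```

With these two sample documents the second one is flagged and the first is attached as its neighbor; passing only the first document yields a `noop`. Matching is a plain substring test, so a signal such as `policy_deadline` in the full table would fire on any document that merely mentions the field name.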