# Corp_AI_Auditor / scripts/run_baseline.py
# Author: minato1718
# Last change: "Fix: Serve Gradio UI + lightweight runtime deps" (commit e882ca4, verified)
"""
AuditEnv Baseline Runner — supports LLM (Groq/OpenAI) and signal-aware heuristic policies.
Usage:
# Heuristic (no API key needed):
python scripts/run_baseline.py --policy heuristic
# LLM via Groq (PowerShell syntax shown for setting the key):
$env:OPENAI_API_KEY="gsk_..."
python scripts/run_baseline.py --policy openai --model llama-3.3-70b-versatile
"""
from __future__ import annotations

import argparse
import json
import math
import os
import sys
import textwrap
from typing import Any

import httpx
from openai import OpenAI
# ---------------------------------------------------------------------------
# LLM system prompt — decision rules, known fraud signals and few-shot examples
# ---------------------------------------------------------------------------
# System prompt sent with every chat-completions request made by
# _build_llm_action. It lists the recognised fraud signals, the decision rules
# and the exact JSON action shapes the model is asked to emit.
# The dashes/arrows below are real "—"/"→" characters; they previously appeared
# as mis-decoded UTF-8 byte sequences and were sent to the model verbatim.
# NOTE: the 0.70 confidence threshold is only an instruction to the model;
# _build_llm_action does not enforce it.
SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert compliance auditor AI. You are reviewing documents inside an
automated audit environment. Your goal is to find genuine policy violations and
fraud — but ONLY when the evidence is clear.
## CRITICAL RULES
1. **Think before acting.** For EVERY document batch, first ask yourself:
"Is there a SPECIFIC fraud signal present in the text?" If no clear signal
exists, output `noop`. Do NOT guess or assume fraud.
1.1 **Default to noop.** `submit_finding` is allowed only when the exact fraud
signal appears in document text for this task.
2. **Confidence threshold:** If your confidence is below 0.70, output `noop`.
3. **One action per step.** Return exactly one JSON object per step.
4. **No markdown.** Return ONLY raw JSON, no fences, no explanation.
## STEP-BY-STEP REASONING (follow this every step)
1. Read every document in the batch carefully.
2. For each document, check: does it contain a known fraud signal? (see list below)
3. If YES → build a `submit_finding` with the matching document_id & violation_type.
4. If UNSURE → use `flag_human_review` with a note explaining the concern.
5. If NO signals found → use `noop`.
## KNOWN FRAUD SIGNALS
Easy tasks:
- "DUPLICATE_FLAG=true" or "matches_receipt=" → violation_type: "duplicate_receipt"
- "alcohol_amount" exceeding "policy_limit" → violation_type: "alcohol_over_limit"
- "LATE=true" or submission_date past policy_deadline → violation_type: "late_submission"
Medium tasks:
- "sod_conflict" or "segregation_of_duties" → violation_type: "sod_conflict"
- "dormant_account" or "DORMANT=true" → violation_type: "dormant_account_reactivation"
- "temporal_anomaly" or "off_hours" or "suspicious_hour" → violation_type: "temporal_anomaly"
Hard tasks:
- "shell_company" or "SHELL=true" → violation_type: "shell_company"
- "invoice_splitting" or "split_invoice" → violation_type: "invoice_splitting"
- "round_tripping" or "ROUND_TRIP=true" → violation_type: "round_tripping"
## EVIDENCE FORMAT
Evidence should be a list of document IDs that support the finding.
Use format: ["<flagged_doc_id>", "<neighboring_doc_id>"]
## FEW-SHOT EXAMPLES
### Example 1 — True positive (correct finding)
Documents: [{"id": "e-DOC-003", "text": "...amount=$185.00; DUPLICATE_FLAG=true; matches_receipt=RCT-12345..."}]
Correct output:
{"action_type": "submit_finding", "document_id": "e-DOC-003", "violation_type": "duplicate_receipt", "evidence": ["e-DOC-003", "e-DOC-002"], "confidence": 0.92, "note": "DUPLICATE_FLAG=true present with matching receipt reference"}
### Example 2 — No signal present (correct noop)
Documents: [{"id": "e-DOC-007", "text": "employee=Alice; amount=$45.00; expense_type=meals; description=Business lunch"}]
Correct output:
{"action_type": "noop", "note": "No fraud signals detected in document batch"}
### Example 3 — Incorrect false positive (DO NOT DO THIS)
Documents: [{"id": "e-DOC-010", "text": "employee=Bob; amount=$200.00; expense_type=travel"}]
WRONG output: {"action_type": "submit_finding", "document_id": "e-DOC-010", "violation_type": "duplicate_receipt", ...}
WHY WRONG: No DUPLICATE_FLAG, no alcohol signal, no LATE signal. This is a clean document.
## ACTION FORMATS
submit_finding:
{"action_type": "submit_finding", "document_id": "<doc_id>", "violation_type": "<type>", "evidence": ["<doc_id>", "<neighbor_id>"], "confidence": 0.85, "note": "explanation"}
flag_human_review:
{"action_type": "flag_human_review", "note": "explanation of concern"}
noop:
{"action_type": "noop", "note": "reason no finding"}
""")
# ---------------------------------------------------------------------------
# Signal-aware heuristic policy (Task 1.1 — FUTURE_PLAN.md)
# ---------------------------------------------------------------------------
# Signal tables: (substring, violation_type) pairs, one table per task difficulty.
# _detect_violation lower-cases both the signal and the document text and does a
# plain substring test, trying entries in list order, so the FIRST matching entry
# wins. Keep more specific signals ahead of more generic ones.
# NOTE(review): several entries are bare field names ("alcohol_amount",
# "policy_deadline", "duplicate_invoice_group_size",
# "vendor_registration_age_days="). Because matching is substring presence, any
# document that merely carries such a field is flagged regardless of its value
# (SYSTEM_PROMPT, by contrast, asks for alcohol_amount *exceeding* policy_limit).
# Confirm the environment only emits these fields on violating documents.
_EASY_SIGNALS: list[tuple[str, str]] = [
    ("DUPLICATE_FLAG=true", "duplicate_receipt"),
    ("is_duplicate_invoice_id=true", "duplicate_receipt"),
    ("duplicate_invoice_group_size", "duplicate_receipt"),
    ("matches_receipt=", "duplicate_receipt"),
    ("alcohol_amount", "alcohol_over_limit"),
    ("alcohol_over_limit", "alcohol_over_limit"),
    ("LATE=true", "late_submission"),
    ("policy_deadline", "late_submission"),
]
_MEDIUM_SIGNALS: list[tuple[str, str]] = [
    ("sod_conflict", "sod_conflict"),
    ("segregation_of_duties", "sod_conflict"),
    ("dormant_account", "dormant_account_reactivation"),
    ("DORMANT=true", "dormant_account_reactivation"),
    ("temporal_anomaly", "temporal_anomaly"),
    ("suspicious_hour", "temporal_anomaly"),
    ("off_hours", "temporal_anomaly"),
]
_HARD_SIGNALS: list[tuple[str, str]] = [
    ("shell_company", "shell_company"),
    ("SHELL=true", "shell_company"),
    # NOTE(review): fires whenever the registration-age field is present,
    # independent of the age value itself.
    ("vendor_registration_age_days=", "shell_company"),
    ("invoice_splitting", "invoice_splitting"),
    ("split_invoice", "invoice_splitting"),
    ("round_tripping", "round_tripping"),
    ("ROUND_TRIP=true", "round_tripping"),
]
# task_id -> signal table; _detect_violation treats an unknown task_id as
# having no signals at all.
_TASK_SIGNALS: dict[str, list[tuple[str, str]]] = {
    "easy": _EASY_SIGNALS,
    "medium": _MEDIUM_SIGNALS,
    "hard": _HARD_SIGNALS,
}
# violation_type that _build_llm_action falls back to when the model's
# submit_finding does not provide one.
_DEFAULT_VIOLATION: dict[str, str] = {
    "easy": "duplicate_receipt",
    "medium": "sod_conflict",
    "hard": "shell_company",
}
def _detect_violation(text: str, task_id: str) -> str | None:
"""Return the first matched violation type for the given document text."""
for signal, vtype in _TASK_SIGNALS.get(task_id, []):
if signal.lower() in text.lower():
return vtype
return None
def _build_heuristic_action(task_id: str, observation: dict[str, Any]) -> dict[str, Any]:
    """Build one action by scanning the visible documents for embedded fraud clues.

    The first document, in observation order, whose text contains a signal for
    ``task_id`` (see ``_detect_violation``) becomes a ``submit_finding``. Its
    evidence is that document plus the document immediately before it. A
    document that is first in the batch, or whose predecessor has the same id,
    is cited alone. When no document matches, the policy abstains with ``noop``.

    Args:
        task_id: Difficulty bucket ("easy", "medium", "hard").
        observation: Environment observation. Only its "documents" list is read.

    Returns:
        An action dict for the environment's ``/step`` endpoint.
    """
    documents = observation.get("documents", [])
    if not documents:
        return {"action_type": "noop", "task_id": task_id, "note": "no_documents"}

    # enumerate() yields each document's true position. list.index(doc) would
    # be O(n) per hit and returns the first *equal* dict, which picks the wrong
    # neighbour when a batch contains identical entries.
    for idx, doc in enumerate(documents):
        vtype = _detect_violation(doc.get("text") or "", task_id)
        if not vtype:
            continue
        doc_id = doc.get("id", "UNKNOWN")
        # Use .get() for the neighbour too, so a predecessor without an "id"
        # cannot raise KeyError.
        neighbor_id = documents[idx - 1].get("id", "UNKNOWN") if idx > 0 else doc_id
        evidence = [doc_id] if neighbor_id == doc_id else [doc_id, neighbor_id]
        return {
            "action_type": "submit_finding",
            "task_id": task_id,
            "finding": {
                "document_id": doc_id,
                "violation_type": vtype,
                "evidence": evidence,
                "confidence": 0.85,
            },
            "note": f"signal_detected:{vtype}",
        }

    # Safe abstention when no explicit signal is present.
    return {
        "action_type": "noop",
        "task_id": task_id,
        "note": "heuristic_no_signal_noop",
    }
# ---------------------------------------------------------------------------
# LLM policy — sends truncated document context (first 10 docs, 300 chars each) to Groq/OpenAI
# ---------------------------------------------------------------------------
def _build_llm_action(task_id: str, observation: dict[str, Any], client: OpenAI, model: str) -> dict[str, Any]:
    """Ask a chat-completions model (Groq, OpenAI, ...) for the next action.

    The prompt holds episode progress plus at most the first 10 documents, each
    truncated to 300 characters of text. The reply must be a single JSON
    object; surrounding markdown code fences are tolerated. If the request
    fails or the reply is not a parseable JSON object, the signal heuristic
    (``_build_heuristic_action``) decides instead and a warning goes to stderr.

    Args:
        task_id: Difficulty bucket ("easy", "medium", "hard").
        observation: Environment observation. Reads "documents",
            "findings_submitted", "steps_remaining" and "current_partial_score".
        client: OpenAI-compatible client used for ``chat.completions.create``.
        model: Model name passed to the API.

    Returns:
        An action dict for ``/step``. A ``submit_finding`` carries a nested
        "finding" dict; ``flag_human_review``/``noop`` carry only a note
        (at most 200 chars).
    """
    documents = observation.get("documents") or []

    # Cap the batch at 10 documents x 300 chars to bound the prompt size.
    doc_lines = [
        f" - ID: {doc.get('id', 'N/A')}, Type: {doc.get('type', 'N/A')}, "
        f"Text: {str(doc.get('text') or '')[:300]}"
        for doc in documents[:10]
    ]
    docs_text = "\n".join(doc_lines) if doc_lines else " (no documents)"

    findings_submitted = observation.get("findings_submitted", 0)
    steps_remaining = observation.get("steps_remaining", "?")
    # A null or non-numeric score would otherwise make the ":.2f" format raise.
    try:
        current_score = float(observation.get("current_partial_score", 0.0))
    except (TypeError, ValueError):
        current_score = 0.0

    # Built line by line rather than with textwrap.dedent(f"""..."""): dedent
    # runs *after* interpolation, so a multi-line docs_text whose lines carry
    # less indentation than the template would silently defeat the dedent.
    user_prompt = "\n".join([
        f"Task difficulty: {task_id}",
        f"Findings submitted so far: {findings_submitted}",
        f"Steps remaining: {steps_remaining}",
        f"Current score: {current_score:.2f}",
        "Documents to review:",
        docs_text,
        f"Analyze ALL documents carefully. Look for violations matching {task_id} difficulty.",
        "Return a single JSON action object.",
    ]) + "\n"

    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.2,
            max_tokens=400,
        )
        text = (completion.choices[0].message.content or "").strip()
        # Strip markdown fences if the model ignored the "no markdown" rule.
        if text.startswith("```"):
            text = "\n".join(
                line for line in text.split("\n") if not line.strip().startswith("```")
            ).strip()
        if not text.startswith("{"):
            print(" [WARN] LLM reply is not a JSON object; using heuristic", file=sys.stderr)
            return _build_heuristic_action(task_id, observation)
        payload = json.loads(text)
    except Exception as exc:  # API/network errors and json.JSONDecodeError alike
        print(f" [WARN] LLM parse/request failed: {exc}", file=sys.stderr)
        return _build_heuristic_action(task_id, observation)

    # Sanitize the LLM output. Tuple membership compares with ==, so an
    # unhashable value (e.g. a list) cannot raise TypeError as a set lookup would.
    action_type = payload.get("action_type", "noop")
    if action_type not in ("submit_finding", "flag_human_review", "noop"):
        action_type = "noop"
    if action_type != "submit_finding":
        return {"action_type": action_type, "task_id": task_id, "note": str(payload.get("note", ""))[:200]}

    # Missing *or* null/empty fields fall back to defaults; the first-document
    # fallback uses .get() so a document without an "id" cannot raise KeyError.
    fallback_doc_id = documents[0].get("id", "UNKNOWN") if documents else "UNKNOWN"
    doc_id = payload.get("document_id") or fallback_doc_id
    violation_type = payload.get("violation_type") or _DEFAULT_VIOLATION.get(task_id, "duplicate_receipt")
    evidence = payload.get("evidence") or [doc_id]
    if not isinstance(evidence, list):
        evidence = [evidence]

    # The model may send "high", null or NaN; none of these may crash the run
    # or (for NaN) be clamped up to full confidence.
    try:
        confidence = float(payload.get("confidence", 0.5))
    except (TypeError, ValueError):
        confidence = 0.5
    if math.isnan(confidence):
        confidence = 0.5
    confidence = max(0.0, min(1.0, confidence))

    return {
        "action_type": "submit_finding",
        "task_id": task_id,
        "finding": {
            "document_id": doc_id,
            "violation_type": violation_type,
            "evidence": evidence,
            "confidence": confidence,
        },
        "note": str(payload.get("note", "llm_action"))[:200],
    }
def _safe_reward_fields(result: dict[str, Any]) -> tuple[float, str]:
"""Extract normalized reward and reason without raising on malformed payloads."""
reward = result.get("reward")
if not isinstance(reward, dict):
return 0.0, "missing_reward_payload"
reason = str(reward.get("reason", ""))
try:
reward_norm = float(reward.get("normalized", 0.0))
except (TypeError, ValueError):
return 0.0, f"{reason}|invalid_reward_value" if reason else "invalid_reward_value"
return reward_norm, reason
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_task(
    env_url: str,
    task_id: str,
    client: OpenAI | None,
    model: str,
    seed: int,
    policy: str,
) -> dict[str, Any]:
    """Run one complete AuditEnv episode for ``task_id`` and summarise it.

    POSTs ``/reset`` once. It then repeatedly asks the selected policy for an
    action and POSTs it to ``/step`` until one of these happens: the env
    reports ``done``, a request fails, or a safety step cap is reached. Each
    step is printed to stdout and recorded in the returned log.

    Args:
        env_url: Base URL of the AuditEnv HTTP server.
        task_id: Task to run ("easy", "medium" or "hard").
        client: OpenAI-compatible client; required unless ``policy`` is "heuristic".
        model: Chat model name forwarded to the LLM policy.
        seed: Seed sent with ``/reset``.
        policy: "heuristic" for the local signal heuristic; any other value uses the LLM.

    Returns:
        A dict with these keys:

        * "task_id"
        * "score": mean normalized reward over the steps that got a response,
          rounded to 6 decimals; 0.0 if there were none.
        * "steps"
        * "log": one dict per attempted step.
        * "completed": True only if the env reported done.
        * "error": "" on success, otherwise "reset_failed:<Exc>",
          "step_failed:<Exc>" or "max_steps_reached".

    Raises:
        RuntimeError: if the LLM policy is selected but ``client`` is None.
    """
    with httpx.Client(timeout=30.0) as http:
        try:
            reset_resp = http.post(f"{env_url}/reset", json={"task_id": task_id, "seed": seed})
            reset_resp.raise_for_status()
            obs = reset_resp.json()
            # Both policies call observation.get(...). Fail the reset cleanly
            # here rather than crash on the first step if the body is not an object.
            if not isinstance(obs, dict):
                raise TypeError(f"reset returned {type(obs).__name__}, expected a JSON object")
        except Exception as exc:
            print(f" [WARN] reset failed for task={task_id}: {exc}", file=sys.stderr)
            return {
                "task_id": task_id,
                "score": 0.0,
                "steps": 0,
                "log": [],
                "completed": False,
                "error": f"reset_failed:{type(exc).__name__}",
            }
        session_id = obs.get("session_id")

        total_reward = 0.0
        steps = 0
        step_log: list[dict[str, Any]] = []
        done = False
        task_error = ""

        hard_step_cap = 40
        raw_cap = obs.get("steps_remaining")
        # bool is a subclass of int; a stray true/false is not a step budget.
        if isinstance(raw_cap, int) and not isinstance(raw_cap, bool):
            # Keep a bounded safety margin while allowing full hard episodes to finish.
            hard_step_cap = max(8, min(64, raw_cap + 4))

        while not done and steps < hard_step_cap:
            if policy == "heuristic":
                action = _build_heuristic_action(task_id=task_id, observation=obs)
            else:
                if client is None:
                    raise RuntimeError("OPENAI_API_KEY is required for policy=openai")
                action = _build_llm_action(task_id=task_id, observation=obs, client=client, model=model)
            if session_id and "session_id" not in action:
                action["session_id"] = session_id

            try:
                step_resp = http.post(f"{env_url}/step", json=action)
                step_resp.raise_for_status()
                result = step_resp.json()
                # Everything below calls result.get(...), so a non-object body
                # is recorded as a failed step instead of crashing the runner.
                if not isinstance(result, dict):
                    raise TypeError(f"step returned {type(result).__name__}, expected a JSON object")
            except Exception as exc:
                task_error = f"step_failed:{type(exc).__name__}"
                step_log.append(
                    {
                        "step": steps + 1,
                        "action_type": action.get("action_type"),
                        "reward_norm": 0.0,
                        "reward_reason": task_error,
                        "done": False,
                    }
                )
                print(f" [WARN] step failed for task={task_id}: {exc}", file=sys.stderr)
                break

            reward_norm, reward_reason = _safe_reward_fields(result)
            total_reward += reward_norm
            steps += 1
            done = bool(result.get("done", False))
            # Adopt only a well-formed observation. A missing or null one keeps
            # the previous observation, so the next policy call still gets a dict.
            next_obs = result.get("observation")
            if isinstance(next_obs, dict):
                obs = next_obs

            entry = {
                "step": steps,
                "action_type": action.get("action_type"),
                "reward_norm": reward_norm,
                "reward_reason": reward_reason,
                "done": done,
            }
            if action.get("finding"):
                entry["doc_id"] = action["finding"]["document_id"]
                entry["violation"] = action["finding"]["violation_type"]
            step_log.append(entry)
            print(
                f" Step {steps:2d} │ {action.get('action_type'):18s} │ "
                f"reward={reward_norm:.3f} │ reason={reward_reason} │ "
                f"done={done}"
            )

    if not done and not task_error and steps >= hard_step_cap:
        task_error = "max_steps_reached"
    mean_score = round(total_reward / steps, 6) if steps else 0.0
    return {
        "task_id": task_id,
        "score": mean_score,
        "steps": steps,
        "log": step_log,
        "completed": done,
        "error": task_error,
    }
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _print_score_summary(results: list[dict[str, Any]]) -> None:
    """Print a box-drawn per-task score table followed by the average score.

    Column widths are defined once, so every border and row lines up.
    """
    # (header, content width) per column; each cell gets one space of padding per side.
    columns = (("Task", 9), ("Score", 9), ("Steps", 5), ("Status", 10))
    widths = [width for _, width in columns]
    inner = sum(width + 2 for width in widths) + len(widths) - 1

    def rule(left: str, mid: str, right: str) -> str:
        return left + mid.join("─" * (width + 2) for width in widths) + right

    def row(cells: list[str]) -> str:
        return "│" + "│".join(f" {cell:<{width}} " for cell, width in zip(cells, widths)) + "│"

    print("┌" + "─" * inner + "┐")
    print("│" + " BASELINE SCORE SUMMARY".ljust(inner) + "│")
    print(rule("├", "┬", "┤"))
    print(row([name for name, _ in columns]))
    print(rule("├", "┼", "┤"))
    for r in results:
        status = "ok" if r.get("completed") else "incomplete"
        print(row([r["task_id"], f"{r['score']:9.6f}", f"{r['steps']:5d}", status]))
    avg = sum(r["score"] for r in results) / len(results) if results else 0.0
    print(rule("├", "┼", "┤"))
    print(row(["AVERAGE", f"{avg:9.6f}", "", ""]))
    print(rule("└", "┴", "┘"))


def main() -> None:
    """CLI entry point: run every AuditEnv task with the chosen policy and report scores.

    Runs "easy", "medium" and "hard" in that order against --env-url. It prints
    a result line per task and then a summary table. With --save-log, it writes
    one JSON object per logged step, tagged with its "task_id", to that path.
    Episodes that did not reach done are left out of the log unless
    --include-partial-log is set. Exits with status 1 when --policy openai is
    requested without OPENAI_API_KEY.

    Console output uses real box-drawing/arrow characters. They previously
    appeared as mis-decoded UTF-8 byte sequences.
    """
    parser = argparse.ArgumentParser(description="Run reproducible baseline scores on all AuditEnv tasks.")
    parser.add_argument("--env-url", default=os.getenv("AUDITENV_BASE_URL", "http://127.0.0.1:8000"))
    parser.add_argument("--model", default=os.getenv("AUDITENV_BASELINE_MODEL", "llama-3.3-70b-versatile"))
    parser.add_argument("--base-url", default=os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1"))
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="heuristic",
        help="Action policy: 'openai' uses Groq/OpenAI API, 'heuristic' is free local fallback.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save-log", default="", help="Path to save per-step JSONL log.")
    parser.add_argument(
        "--include-partial-log",
        action="store_true",
        help="Include incomplete task episodes in --save-log output.",
    )
    args = parser.parse_args()

    banner_width = 46
    print("╔" + "═" * banner_width + "╗")
    print("║" + " AuditEnv Baseline Runner".ljust(banner_width) + "║")
    print("║" + f" Policy: {args.policy:10s} Seed: {args.seed:<10d}".ljust(banner_width) + "║")
    print("╚" + "═" * banner_width + "╝")

    client: OpenAI | None = None
    if args.policy == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            print("ERROR: Set OPENAI_API_KEY env var for --policy openai", file=sys.stderr)
            sys.exit(1)
        client = OpenAI(api_key=api_key, base_url=args.base_url)
        print(f" Model: {args.model}")
        print(f" API: {args.base_url}")
    print()

    results: list[dict[str, Any]] = []
    for task_id in ["easy", "medium", "hard"]:
        print(f"━━━ Task: {task_id} ━━━")
        res = run_task(args.env_url, task_id, client, args.model, args.seed, args.policy)
        results.append(res)
        suffix = "" if res.get("completed") else f" [INCOMPLETE: {res.get('error', '')}]"
        print(f" → Score: {res['score']:.6f} ({res['steps']} steps){suffix}\n")

    _print_score_summary(results)

    # Optionally save the per-step log as JSON Lines.
    if args.save_log:
        skipped = 0
        written = 0
        with open(args.save_log, "w", encoding="utf-8") as f:
            for r in results:
                if not args.include_partial_log and not r.get("completed"):
                    skipped += 1
                    continue
                for entry in r["log"]:
                    f.write(json.dumps({**entry, "task_id": r["task_id"]}) + "\n")
                    written += 1
        print(f"\nStep log saved to: {args.save_log} ({written} rows, skipped {skipped} incomplete task logs)")


if __name__ == "__main__":
    main()