# Corp_AI / inference.py
# Arpit Deep
# feat: initial AuditEnv submission
# a617acd
"""
AuditEnv — Hackathon inference entrypoint.
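Mandatory environment variables:
- API_BASE_URL : The API endpoint for the LLM (default: https://router.huggingface.co/v1).
- MODEL_NAME   : The model identifier to use for inference (default: Qwen/Qwen2.5-72B-Instruct).
- HF_TOKEN     : Your Hugging Face / API key (API_KEY is read as a fallback when unset).
Optional environment variables:
- AUDITENV_BASE_URL : URL of the running AuditEnv server (default: http://127.0.0.1:8000).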
This script emits only three stdout line types, in order:
[START], [STEP], [END]
"""
from __future__ import annotations

import json
import math
import os
import sys
import textwrap
from typing import Any, Dict, List, Optional

import httpx
from openai import OpenAI
# ---------------------------------------------------------------------------
# Environment variables (hackathon mandatory)
# ---------------------------------------------------------------------------
# LLM endpoint and model; both fall back to public defaults when unset.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN takes precedence; an unset or empty HF_TOKEN falls through to
# API_KEY, and finally to the empty string.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
# Base URL of the AuditEnv FastAPI server this script drives.
ENV_BASE_URL = os.environ.get("AUDITENV_BASE_URL", "http://127.0.0.1:8000")

# Inference config.
# Per-task step budgets; TASK_IDS follows this dict's insertion order so the
# two can never drift apart.
MAX_STEPS_MAP = {"easy": 12, "medium": 20, "hard": 28}
TASK_IDS = list(MAX_STEPS_MAP)
SEED = 42
TEMPERATURE = 0.3
MAX_TOKENS = 400
BENCHMARK = "auditenv"
# ---------------------------------------------------------------------------
# Logging helpers — strict [START], [STEP], [END] format
# ---------------------------------------------------------------------------
def _bool_str(value: bool) -> str:
return "true" if value else "false"
def _single_line(value: str) -> str:
return " ".join(str(value).splitlines()).strip()
def log_start(task: str, env: str, model: str) -> None:
    """Emit the single [START] line that opens a task's log on stdout."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line on stdout.

    ``action`` and ``error`` are flattened to a single line; a falsy ``error``
    is rendered as the literal ``null``. ``reward`` is printed with two
    decimals.
    """
    parts = [
        f"step={step}",
        f"action={_single_line(action)}",
        f"reward={reward:.2f}",
        f"done={_bool_str(done)}",
        f"error={_single_line(error) if error else 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the closing [END] line on stdout.

    ``rewards`` is rendered as a comma-separated list with two decimals per
    entry (empty when no steps ran).
    """
    formatted = [format(r, ".2f") for r in rewards]
    line = (
        f"[END] success={_bool_str(success)} steps={steps} "
        f"score={score:.2f} rewards={','.join(formatted)}"
    )
    print(line, flush=True)
# ---------------------------------------------------------------------------
# System prompt for the audit agent
# ---------------------------------------------------------------------------
# System message sent as the first chat message of every completion request
# (see get_model_action). It asks the model for a single JSON object whose
# "action_type" is one of submit_finding / flag_human_review / noop.
# NOTE(review): the violation-type strings below presumably must match the
# labels the AuditEnv grader expects -- confirm against the server before
# editing them.
SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert compliance auditor AI agent. You are reviewing documents in a
simulated audit environment. Your job is to identify compliance violations,
fraud, and policy breaches.
For each step, you will receive a list of documents. Analyze them and decide on
one of three actions:
1. **submit_finding** — You found a violation. Return JSON with:
- "action_type": "submit_finding"
- "document_id": the ID of the suspicious document
- "violation_type": the type of violation you detected
- "evidence": list of document IDs as supporting evidence
- "confidence": float between 0.0 and 1.0
- "note": brief explanation
2. **flag_human_review** — You want a human to look at something. Return JSON with:
- "action_type": "flag_human_review"
- "note": explanation of concern
3. **noop** — Nothing suspicious in the current batch. Return JSON with:
- "action_type": "noop"
- "note": reason for no finding
VIOLATION TYPES by task difficulty:
- Easy: "duplicate_receipt", "alcohol_over_limit", "late_submission"
- Medium: "sod_conflict", "dormant_account_reactivation", "temporal_anomaly"
- Hard: "shell_company", "invoice_splitting", "round_tripping"
IMPORTANT: Respond with ONLY a valid JSON object. No markdown, no explanation.
""")
# ---------------------------------------------------------------------------
# LLM interaction
# ---------------------------------------------------------------------------
def build_user_prompt(
    task_id: str,
    step: int,
    observation: Dict[str, Any],
    history: List[str],
    *,
    max_docs: int = 10,
    max_text_chars: int = 200,
    history_window: int = 5,
) -> str:
    """Render the per-step user message for the LLM.

    The prompt is assembled line by line instead of with ``textwrap.dedent``
    on an interpolated template: dedent runs *after* interpolation, so any
    multi-line document text or history entry would change the common margin
    and leave the template lines indented.

    Args:
        task_id: Task difficulty label ("easy", "medium", "hard", ...).
        step: 1-based step number, shown to the model.
        observation: Latest environment observation. Missing or null
            ``documents`` / ``current_partial_score`` are tolerated.
        history: Short textual summaries of earlier steps.
        max_docs: At most this many documents are listed; the remainder is
            summarised as a count.
        max_text_chars: Each document's text is truncated to this length.
        history_window: Number of most recent history entries to include.

    Returns:
        The prompt text, ending with a newline.
    """
    docs = observation.get("documents") or []
    doc_lines: List[str] = []
    for doc in docs[:max_docs]:
        # Flatten newlines so each document stays on its own list line.
        text = " ".join(str(doc.get("text") or "").splitlines())[:max_text_chars]
        doc_lines.append(
            f"  - ID: {doc.get('id', 'N/A')}, Type: {doc.get('type', 'N/A')}, Text: {text}"
        )
    if len(docs) > max_docs:
        doc_lines.append(f"  ... ({len(docs) - max_docs} more not shown)")
    if not doc_lines:
        doc_lines.append("  (none)")

    raw_score = observation.get("current_partial_score")
    current_score = float(raw_score) if raw_score is not None else 0.0

    # history[-0:] would be the whole list, so a non-positive window means "none".
    recent = history[-history_window:] if history_window > 0 else []
    history_block = "\n".join(recent) if recent else "None"

    sections = [
        f"Task: {task_id} (Step {step})",
        f"Findings submitted so far: {observation.get('findings_submitted', 0)}",
        f"Steps remaining: {observation.get('steps_remaining', 0)}",
        f"Current partial score: {current_score:.2f}",
        "Documents to review:",
        *doc_lines,
        "Recent history:",
        history_block,
        "Analyze the documents and return a JSON action. "
        "Look for violations relevant to this task difficulty.",
    ]
    return "\n".join(sections) + "\n"
def _extract_json_payload(text: str) -> Dict[str, Any]:
    """Decode the first JSON object in a model reply.

    Tolerates markdown code fences and leading/trailing prose by decoding
    from the first "{" and ignoring whatever follows the object.

    Raises:
        ValueError: if no "{" is present, the JSON is malformed
            (json.JSONDecodeError is a ValueError), or it is not an object.
    """
    start = text.find("{")
    if start == -1:
        raise ValueError(f"no JSON object in model reply: {text[:80]!r}")
    payload, _end = json.JSONDecoder().raw_decode(text, start)
    if not isinstance(payload, dict):
        raise ValueError("model reply JSON is not an object")
    return payload


def get_model_action(
    client: OpenAI,
    task_id: str,
    step: int,
    observation: Dict[str, Any],
    history: List[str],
) -> Dict[str, Any]:
    """Ask the LLM for the next action, falling back to the heuristic policy.

    Request failures and unusable replies both fall back to
    ``_build_heuristic_action``. Diagnostics go to stderr so that stdout
    carries only the [START]/[STEP]/[END] lines promised by the module
    docstring.

    Returns:
        An action dict ready to POST to the environment's /step endpoint.
    """
    user_prompt = build_user_prompt(task_id, step, observation, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
    except Exception as exc:  # API boundary: network, auth, rate-limit, empty choices
        print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
        return _build_heuristic_action(task_id, observation)

    try:
        payload = _extract_json_payload(text)
        return _build_action_from_llm(task_id, payload, observation)
    except (ValueError, TypeError, KeyError, AttributeError, IndexError) as exc:
        print(f"[DEBUG] Model reply unusable: {exc}", file=sys.stderr, flush=True)
        return _build_heuristic_action(task_id, observation)
def _build_action_from_llm(task_id: str, payload: Dict[str, Any], observation: Dict[str, Any]) -> Dict[str, Any]:
action_type = payload.get("action_type", "noop")
if action_type not in {"submit_finding", "flag_human_review", "noop"}:
action_type = "noop"
action: Dict[str, Any] = {
"action_type": action_type,
"task_id": task_id,
"note": str(payload.get("note", ""))[:200],
}
if action_type == "submit_finding":
documents = observation.get("documents", [])
doc_id = payload.get("document_id", documents[0]["id"] if documents else "UNKNOWN")
violation_type = payload.get("violation_type", "duplicate_receipt")
evidence = payload.get("evidence", [doc_id])
confidence = float(payload.get("confidence", 0.5))
confidence = max(0.0, min(1.0, confidence))
action["finding"] = {
"document_id": doc_id,
"violation_type": violation_type,
"evidence": evidence if isinstance(evidence, list) else [evidence],
"confidence": confidence,
}
return action
def _build_heuristic_action(task_id: str, observation: Dict[str, Any]) -> Dict[str, Any]:
"""Fallback heuristic policy when LLM call fails."""
documents = observation.get("documents", [])
doc_id = documents[0]["id"] if documents else "UNKNOWN"
violation_map = {
"easy": "duplicate_receipt",
"medium": "sod_conflict",
"hard": "shell_company",
}
return {
"action_type": "submit_finding",
"task_id": task_id,
"finding": {
"document_id": doc_id,
"violation_type": violation_map.get(task_id, "duplicate_receipt"),
"evidence": [doc_id],
"confidence": 0.5,
},
"note": "heuristic_fallback",
}
# ---------------------------------------------------------------------------
# Main inference loop
# ---------------------------------------------------------------------------
def _extract_reward(reward_obj: Any) -> float:
    """Return the step reward from a /step response's "reward" field.

    Accepts a dict carrying a "normalized" value (the shape this script was
    written against) or, presumably for simpler servers, a bare number. A
    null "normalized" and anything else count as 0.0.
    """
    if isinstance(reward_obj, dict):
        return float(reward_obj.get("normalized") or 0.0)
    if isinstance(reward_obj, (int, float)) and not isinstance(reward_obj, bool):
        return float(reward_obj)
    return 0.0


def run_task(task_id: str, client: OpenAI, http: httpx.Client) -> tuple[float, bool]:
    """Run one task episode against the AuditEnv server and return (score, success).

    Emits one [START] line, one [STEP] line per environment step, and one
    [END] line, including when an error aborts the episode early.

    The score is the mean per-step reward over the steps that completed,
    clamped to [0, 1]. Success requires score >= 0.3 and an episode that was
    not aborted by an error. Diagnostics go to stderr so stdout keeps the
    strict [START]/[STEP]/[END] format.
    """
    success_threshold = 0.3
    max_steps = MAX_STEPS_MAP.get(task_id, 12)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    failed = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        reset_resp = http.post(f"{ENV_BASE_URL}/reset", json={"task_id": task_id, "seed": SEED})
        reset_resp.raise_for_status()
        observation: Dict[str, Any] = reset_resp.json() or {}

        for step in range(1, max_steps + 1):
            action = get_model_action(client, task_id, step, observation, history)
            action_summary = action["action_type"]
            finding = action.get("finding")
            if finding:
                action_summary += f"({finding['document_id']}:{finding['violation_type']})"

            step_resp = http.post(f"{ENV_BASE_URL}/step", json=action)
            step_resp.raise_for_status()
            result = step_resp.json() or {}

            # `or {}` guards against explicit nulls, which .get(key, {}) would pass through.
            reward = _extract_reward(result.get("reward"))
            done = bool(result.get("done", False))
            observation = result.get("observation") or {}
            error = (result.get("info") or {}).get("reason")

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_summary, reward=reward, done=done, error=error)
            history.append(f"Step {step}: {action_summary} -> reward {reward:.2f} reason={error}")
            if done:
                break
    except Exception as exc:  # task boundary: report, then still emit [END]
        failed = True
        print(f"[DEBUG] Task {task_id} failed: {exc}", file=sys.stderr, flush=True)

    # Score the steps that did complete, so the [END] score agrees with the
    # rewards it lists; an aborted episode is never counted as a success.
    score = sum(rewards) / len(rewards) if rewards else 0.0
    score = min(max(score, 0.0), 1.0)
    success = score >= success_threshold and not failed
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score, success
def main() -> None:
    """Health-check the AuditEnv server, then run every task in TASK_IDS.

    All diagnostic output, including the final score summary, goes to stderr.
    This keeps stdout to the [START]/[STEP]/[END] lines emitted by run_task,
    as the module docstring promises. If the health check fails, no task is
    run.
    """
    if not API_KEY:
        print(
            "[DEBUG] Neither HF_TOKEN nor API_KEY is set; LLM requests will probably "
            "be rejected and the heuristic fallback used instead.",
            file=sys.stderr,
            flush=True,
        )
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    all_scores: Dict[str, float] = {}
    with httpx.Client(timeout=30.0) as http:
        # Verify the environment is running before starting any task.
        try:
            health = http.get(f"{ENV_BASE_URL}/health")
            health.raise_for_status()
            print(f"[DEBUG] Environment healthy: {health.json()}", file=sys.stderr, flush=True)
        except Exception as exc:  # top-level boundary: report and stop
            print(f"[DEBUG] Environment health check failed: {exc}", file=sys.stderr, flush=True)
            print("[DEBUG] Make sure the AuditEnv server is running.", file=sys.stderr, flush=True)
            return

        for task_id in TASK_IDS:
            score, success = run_task(task_id, client, http)
            all_scores[task_id] = score
            print(f"[DEBUG] Task {task_id}: score={score:.4f} success={success}", file=sys.stderr, flush=True)

    print("\n--- Final Scores ---", file=sys.stderr, flush=True)
    for tid, sc in all_scores.items():
        print(f"  {tid}: {sc:.4f}", file=sys.stderr, flush=True)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()