# financial_audit_env / inference.py
# Last commit by balloonmann (52f5c27):
# Round 2 Implementation: Multi-period campaign, regulatory shocks, adversarial grading,
# and GRPO training infrastructure
#!/usr/bin/env python3
"""
Financial Audit Environment — Inference Script
==============================================
MANDATORY CONSTRAINTS MET:
- Uses API_BASE_URL, MODEL_NAME, and HF_TOKEN from environment.
- Placed in the root directory of the project.
- Uses the standard OpenAI Client for all LLM calls.
- Emits structured stdout logs: [START], [STEP], [END]
Usage:
export API_BASE_URL=https://router.huggingface.co/v1/
export MODEL_NAME=meta-llama/Llama-3.1-8B-Instruct
export HF_TOKEN=your_token_here
python inference.py --env-url http://localhost:8000
"""
import argparse
import json
import logging
import os
import re
import sys
import time
from typing import Any, Dict, List, Optional
import requests
# `openai` is a hard requirement: without it, print an install hint to stderr
# and exit with status 1 before anything else runs.
try:
    from openai import OpenAI
except ImportError:
    print("ERROR: 'openai' package not installed. Run: pip install openai>=1.0.0", file=sys.stderr)
    sys.exit(1)
# Load a local .env file (if python-dotenv is installed) BEFORE the os.getenv
# calls below, so values defined there are visible to them.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass # dotenv is optional — env vars can be set directly
# basicConfig without an explicit stream logs to stderr, which keeps stdout
# reserved for the structured [START]/[STEP]/[END] lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("inference")
# ---------------------------------------------------------------------------
# Configuration from Environment Variables (MANDATORY)
# ---------------------------------------------------------------------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1/")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
# Passed as the OpenAI client's api_key in main().
HF_TOKEN = os.getenv("HF_TOKEN")
# Checked at import time, so even `--help` fails without it. Only an unset
# variable is rejected; an empty string passes this check.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")
# Round-1 default: run the required 3 tasks (easy, medium, hard).
TASK_IDS = ["expense_audit", "invoice_match", "gst_reconciliation"]
# Default seed for both runners and for the --seed CLI flag.
SEED = 42
# Reported as `env=` in every [START] line.
BENCHMARK = "financial_audit_env"
# Sampling settings for the Round-1 LLM call; the campaign path passes its own
# literal temperature/max_tokens.
TEMPERATURE = 0.1
MAX_TOKENS = 4096
# NOTE(review): not referenced anywhere else in the visible code; both runners
# submit once with submit_final=True.
MAX_STEPS = 5 # Max steps per task (matches our env config)
# A Round-1 task counts as successful when its grader score is at least this value.
SUCCESS_SCORE_THRESHOLD = 0.5
# ---------------------------------------------------------------------------
# MANDATORY Structured Logging: [START], [STEP], [END]
# ---------------------------------------------------------------------------
def strict_unit_interval(value: Any, default: float = 0.01) -> float:
    """Coerce *value* to a float clamped into the closed range [0.01, 0.99].

    Scores and rewards in the structured logs must stay strictly inside
    (0, 1), so anything at or beyond the bounds is pinned to them.

    Args:
        value: Anything ``float()`` may accept (number, numeric string, ...).
        default: Fallback used when *value* cannot be converted, is NaN, or
            is an integer too large for a float. The fallback is clamped too,
            and a NaN fallback collapses to the 0.01 floor.

    Returns:
        A float ``x`` with ``0.01 <= x <= 0.99``.
    """
    try:
        num = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int outside the double range (e.g. 10**400)
        # used to escape this helper and crash the logging path.
        num = default
    # NaN is the only value not equal to itself; this avoids importing math.
    if num != num:
        num = default
    # Re-check NaN so a NaN `default` cannot leak through the comparisons below.
    if num != num or num <= 0.01:
        return 0.01
    if num >= 0.99:
        return 0.99
    return num
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` marker that opens one task's structured log.

    The line is written to stdout and flushed immediately so it is not held
    back in the output buffer.
    """
    line = "[START] task={} env={} model={}".format(task, env, model)
    print(line, flush=True)
def log_step(
    step: int,
    action: str,
    reward: float,
    done: bool,
    error: Optional[str] = None,
) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    ``reward`` is clamped with ``strict_unit_interval`` and printed with six
    decimals. ``done`` is rendered as lowercase ``true``/``false``. A falsy
    ``error`` (``None`` or ``""``) is rendered as ``null``.
    """
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={strict_unit_interval(reward):.6f}",
        f"done={'true' if done else 'false'}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(
    success: bool,
    steps: int,
    score: float,
    rewards: List[float],
) -> None:
    """Print the closing ``[END]`` line for one task.

    ``score`` and each reward are clamped with ``strict_unit_interval`` and
    printed with six decimals. When ``rewards`` is empty, the clamped score is
    reported as the only reward so the ``rewards=`` field is never blank.
    """
    final_score = strict_unit_interval(score)
    reward_values = rewards or [final_score]
    rewards_field = ",".join(format(strict_unit_interval(r), ".6f") for r in reward_values)
    success_flag = "true" if success else "false"
    print(
        f"[END] success={success_flag} steps={steps} score={final_score:.6f} rewards={rewards_field}",
        flush=True,
    )
# ---------------------------------------------------------------------------
# LLM Prompts & Parsing
# ---------------------------------------------------------------------------
# System message for the Round-1 path only: run_agent_single_task sends it ahead
# of the task prompt. The campaign path sends a single user message built by
# _build_campaign_prompt instead. It asks for a bare JSON array so that
# parse_llm_findings can decode the reply.
SYSTEM_PROMPT = """You are an expert financial auditor AI. You review financial documents and identify errors, violations, and discrepancies with extreme precision.
You MUST respond ONLY with a valid JSON array of findings. Each finding must have:
- "document_id": the ID of the document with the error
- "error_type": one of the allowed error types listed below
- "description": a clear explanation of the error
- "suggested_fix": what should be done to fix it
Do NOT include any text before or after the JSON array. Do NOT use markdown code blocks. Just output the raw JSON array."""
def build_task_prompt(task_description: str, documents: Dict[str, Any], error_types: List[str]) -> str:
    """Render the Round-1 user prompt for a single audit task.

    Args:
        task_description: Task statement from the environment's reset observation.
        documents: Mapping of document-type name to payload. Each entry gets an
            upper-cased ``## NAME`` heading, followed by:
            - list: one ``Row N: <json>`` line per element (N starts at 1);
            - dict: indented JSON;
            - str: the text as-is;
            - anything else: compact JSON.
            Values JSON cannot encode are stringified.
        error_types: Allowed ``error_type`` labels, echoed to the model as the
            permitted vocabulary.

    Returns:
        The newline-joined prompt text.
    """
    prompt_parts = [
        "# TASK",
        task_description,
        "",
        "# ALLOWED ERROR TYPES",
        json.dumps(error_types, indent=2),
        "",
        "# DOCUMENTS TO AUDIT",
    ]
    for doc_type, doc_data in documents.items():
        prompt_parts.append(f"\n## {str(doc_type).upper()}")
        if isinstance(doc_data, list):
            for row_number, row in enumerate(doc_data, start=1):
                prompt_parts.append(f"Row {row_number}: {json.dumps(row, default=str)}")
        elif isinstance(doc_data, dict):
            prompt_parts.append(json.dumps(doc_data, indent=2, default=str))
        elif isinstance(doc_data, str):
            # Scalar payloads used to be silently dropped, leaving an empty section.
            prompt_parts.append(doc_data)
        else:
            prompt_parts.append(json.dumps(doc_data, default=str))
    prompt_parts.extend([
        "",
        "# INSTRUCTIONS",
        "Analyze ALL documents carefully. Find EVERY error/violation/discrepancy.",
        "Respond with ONLY a JSON array of findings. No other text.",
        'Example: [{"document_id": "EXP-001", "error_type": "over_limit", "description": "...", "suggested_fix": "..."}]',
    ])
    return "\n".join(prompt_parts)
def parse_llm_findings(response_text: str) -> List[Dict[str, str]]:
    """Extract the JSON array of findings from a raw LLM reply.

    The model is told to emit a bare JSON array, but replies often wrap it in
    a Markdown code fence or surround it with prose. This function:

    1. Looks inside the first ```...``` fence (if any) first, then falls back
       to the whole reply.
    2. Tries to decode a JSON value starting at each ``[`` in turn. It returns
       the first value that is a list and ignores any text after it, so
       bracketed prose before or after the array does not break parsing.
    3. Keeps only dict elements, because the caller reads every finding with
       ``.get``. Stray strings and numbers are dropped instead of crashing it.

    Returns:
        The (possibly empty) list of finding dicts. When no JSON array can be
        decoded, a warning is logged and ``[]`` is returned.
    """
    text = (response_text or "").strip()
    candidates = []
    # Fence content is tried first; the old `lines[1:-1]` stripping lost the
    # payload when the closing ``` sat on the same line as the JSON.
    fence = re.search(r"```[\w-]*\s*(.*?)```", text, re.DOTALL)
    if fence:
        candidates.append(fence.group(1))
    candidates.append(text)
    decoder = json.JSONDecoder()
    for candidate in candidates:
        for bracket in re.finditer(r"\[", candidate):
            try:
                parsed, _ = decoder.raw_decode(candidate, bracket.start())
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, list):
                return [item for item in parsed if isinstance(item, dict)]
    logger.warning(f"Failed to parse LLM response as JSON: {text[:200]}...")
    return []
# ---------------------------------------------------------------------------
# Core Inference Execution
# ---------------------------------------------------------------------------
def run_agent_single_task(
    env_url: str,
    task_id: str,
    client: OpenAI,
    seed: int = SEED,
    request_timeout: float = 60.0,
) -> Dict[str, Any]:
    """Run one Round-1 task end to end and emit its [START]/[STEP]/[END] lines.

    The steps are:
    1. POST /reset.
    2. GET /tasks to find the allowed error types.
    3. Make one LLM call.
    4. POST /step with all findings and ``submit_final=True``.
    5. GET /grader.

    Args:
        env_url: Base URL of the running environment server.
        task_id: Task to run; it must appear in the server's ``/tasks`` list.
        client: OpenAI-compatible client used for the single completion.
        seed: Seed forwarded to ``/reset``.
        request_timeout: Seconds to wait on each environment HTTP call, so an
            unresponsive server fails the task instead of hanging forever.

    Returns:
        A dict with ``task_id``, ``task_name``, ``difficulty``, ``score``,
        ``precision`` and ``recall``. Any ``Exception`` is logged rather than
        raised, and placeholder values (0.01 for score/precision/recall) are
        returned. An [END] line is printed in every case.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.01
    success = False
    # Emit [START] log
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    # The context manager closes the session's pooled connections once the task ends.
    with requests.Session() as session:
        try:
            # 1. Reset Environment
            logger.info(f"[{task_id}] Resetting environment...")
            reset_resp = session.post(
                f"{env_url}/reset",
                json={"task_id": task_id, "seed": seed},
                timeout=request_timeout,
            )
            reset_resp.raise_for_status()
            reset_payload = reset_resp.json()
            # Accept both {"observation": {...}} and a bare observation body.
            obs = reset_payload.get("observation", reset_payload)

            # 2. Extract Data & Allowed Errors
            tasks_resp = session.get(f"{env_url}/tasks", timeout=request_timeout)
            tasks_resp.raise_for_status()
            task_info = next(
                (t for t in tasks_resp.json()["tasks"] if t["id"] == task_id),
                None,
            )
            if task_info is None:
                # Explicit message instead of a bare StopIteration with empty text.
                raise ValueError(f"task {task_id!r} is not listed by {env_url}/tasks")
            error_types = task_info["error_types"]
            prompt = build_task_prompt(obs["task_description"], obs["documents"], error_types)

            # 3. Call LLM (via OpenAI Client — MANDATORY)
            logger.info(f"[{task_id}] Calling {MODEL_NAME} at {API_BASE_URL}...")
            start_time = time.monotonic()
            try:
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=MAX_TOKENS,
                    temperature=TEMPERATURE,
                    stream=False,
                )
                response_text = response.choices[0].message.content or ""
            except Exception as e:
                # Best effort: an LLM failure submits zero findings instead of aborting.
                logger.error(f"[{task_id}] Inference failed: {e}")
                response_text = "[]"
            logger.info(f"[{task_id}] LLM responded in {time.monotonic() - start_time:.1f}s")

            # 4. Parse & Submit Action
            findings = parse_llm_findings(response_text)
            logger.info(f"[{task_id}] Parsed {len(findings)} findings from LLM")
            submitted = [
                {
                    "document_id": str(f.get("document_id", "")),
                    "error_type": str(f.get("error_type", "")),
                    "description": str(f.get("description", "No description")),
                    "suggested_fix": str(f.get("suggested_fix", "")) if f.get("suggested_fix") else None,
                }
                for f in findings
                if isinstance(f, dict)  # non-object array items would crash f.get()
            ]
            action = {"findings": submitted, "submit_final": True}
            # Summarize the action for the log
            action_summary = f"submit_{len(submitted)}_findings"
            step_resp = session.post(
                f"{env_url}/step",
                json={"action": action},
                timeout=request_timeout,
            )
            step_resp.raise_for_status()
            step_data = step_resp.json()
            # A missing or falsy (0 / None) reward is reported as the 0.01 floor.
            step_reward = strict_unit_interval(step_data.get("reward", 0.01) or 0.01)
            step_done = step_data.get("done", True)
            rewards.append(step_reward)
            steps_taken = 1
            # Emit [STEP] log
            log_step(
                step=1,
                action=action_summary,
                reward=step_reward,
                done=step_done,
                error=None,
            )

            # 5. Get Grader Results
            grader_resp = session.get(f"{env_url}/grader", timeout=request_timeout)
            grader_resp.raise_for_status()
            grader_data = grader_resp.json()
            score = strict_unit_interval(grader_data.get("score", 0.01))
            success = score >= SUCCESS_SCORE_THRESHOLD
            result = {
                "task_id": task_id,
                "task_name": task_info["name"],
                "difficulty": task_info["difficulty"],
                "score": score,
                "precision": strict_unit_interval(grader_data.get("precision", 0.01)),
                "recall": strict_unit_interval(grader_data.get("recall", 0.01)),
            }
            logger.info(f"[{task_id}] Score: {result['score']:.4f} (P={result['precision']:.2f}, R={result['recall']:.2f})")
        except Exception as exc:
            logger.error(f"[{task_id}] Task failed: {exc}")
            result = {
                "task_id": task_id,
                "task_name": task_id,
                "difficulty": "unknown",
                "score": 0.01,
                "precision": 0.01,
                "recall": 0.01,
            }
        finally:
            # Keep output parseable for strict validators even on exceptions.
            if not rewards:
                rewards = [strict_unit_interval(score)]
            if steps_taken <= 0:
                steps_taken = 1
            score = strict_unit_interval(score)
            # Emit [END] log
            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return result
def _run_campaign_specialist(session, env_url, campaign_id, role, period, client, request_timeout) -> Optional[List[Dict]]:
    """Start, solve and submit one specialist task; return the findings it submitted.

    Returns ``None`` when ``/campaign/task/start`` does not answer 200, so the
    caller skips the role for this period. Otherwise returns the findings
    posted to ``/campaign/task/submit``. The list is empty when no client is
    configured or the LLM call fails.
    """
    task_start = session.post(
        f"{env_url}/campaign/task/start",
        json={"campaign_id": campaign_id, "role": role},
        timeout=request_timeout,
    )
    if task_start.status_code != 200:
        logger.warning(f" [{role}] task/start failed: {task_start.text}")
        return None
    task_data = task_start.json()

    # Generate findings using LLM (if available)
    findings: List[Dict] = []
    if client:
        try:
            prompt = _build_campaign_prompt(
                task_data.get("observation", {}),
                task_data.get("world_state", {}),
                task_data.get("findings_history", []),
                task_data.get("active_instructions", []),
                task_data.get("pending_regulatory_shocks", []),
                role,
            )
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
                max_tokens=2048,
            )
            findings = _parse_findings_response(response.choices[0].message.content)
        except Exception as e:
            # Best effort: a failed LLM call submits no findings instead of aborting the campaign.
            logger.warning(f" [{role}] LLM call failed: {e}")

    submit = session.post(
        f"{env_url}/campaign/task/submit",
        json={
            "campaign_id": campaign_id,
            "role": role,
            "action": {"findings": findings, "submit_final": True},
        },
        timeout=request_timeout,
    )
    if submit.status_code == 200:
        # Check for regulatory shocks
        reg_shocks = submit.json().get("observation", {}).get("pending_regulatory_shocks", [])
        if reg_shocks:
            logger.info(f" [{role}] REGULATORY SHOCK received: {len(reg_shocks)} new rule(s)")
            print(f"[STEP] period={period} role={role} regulatory_shock=true", flush=True)
    else:
        logger.warning(f" [{role}] task/submit failed: {submit.text}")
    return findings


def _campaign_overseer_decisions(findings_by_role: Dict[str, List[Dict]]) -> List[Dict[str, Any]]:
    """Build one ``approve`` decision per specialist finding for the overseer review."""
    decisions = []
    for findings in findings_by_role.values():
        for f in findings:
            confidence = f.get("confidence")
            decisions.append({
                "finding_ref": f"{f.get('document_id', '')}:{f.get('error_type', '')}",
                "verdict": "approve",
                "reason_code": "specialist_evidence",
                # _parse_findings_response stores confidence=None when the model omits
                # it, so f.get("confidence", 0.7) never fell back to 0.7.
                "confidence": 0.7 if confidence is None else confidence,
            })
    return decisions


def run_campaign_round(env_url: str, seed: int = SEED, client=None, request_timeout: float = 60.0) -> Dict[str, Any]:
    """
    Run a full multi-agent, multi-period campaign (Round 2).

    Flow per period:
    1. Start each specialist task in the fixed role order.
    2. Use the LLM to generate findings (if a client is available).
    3. Submit the findings.
    4. Log any regulatory shocks returned.
    5. The overseer approves all specialist findings.
    6. Advance to the next period (world mutation).

    Structured logs use the [START]/[STEP]/[END] format (mandatory hackathon
    logging). [END] is printed even when an error aborts the run. In that case
    it reports ``success=false`` and the exception propagates.

    Args:
        env_url: Base URL of the running environment server.
        seed: Seed forwarded to ``/campaign/start``.
        client: Optional OpenAI-compatible client. Without one, empty findings
            are submitted.
        request_timeout: Seconds to wait on each environment HTTP call.

    Returns:
        ``{"campaign_id", "period_results", "state"}``. ``state`` is the
        ``/campaign/state`` payload, or ``{}`` if that call fails.
    """
    roles = ["expense_specialist", "invoice_specialist", "gst_specialist", "fraud_specialist"]
    total_periods = 5
    campaign_id = None
    completed = False
    period_results: List[Dict[str, Any]] = []
    final_state: Dict[str, Any] = {}

    logger.info("[CAMPAIGN START] Initializing multi-agent campaign")
    print(f"[START] campaign seed={seed} periods={total_periods}", flush=True)

    with requests.Session() as session:
        try:
            start = session.post(
                f"{env_url}/campaign/start",
                json={"seed": seed, "total_periods": total_periods},
                timeout=request_timeout,
            )
            start.raise_for_status()
            campaign_id = start.json()["campaign_id"]

            for period in range(1, total_periods + 1):
                logger.info(f"[CAMPAIGN] Period {period}/{total_periods}")
                print(f"[STEP] period={period} phase=start", flush=True)

                period_findings_by_role: Dict[str, List[Dict]] = {}
                for role in roles:
                    findings = _run_campaign_specialist(
                        session, env_url, campaign_id, role, period, client, request_timeout
                    )
                    if findings is None:
                        continue  # task/start failed (already logged)
                    period_findings_by_role[role] = findings
                    print(f"[STEP] period={period} role={role} findings={len(findings)}", flush=True)

                # Overseer review
                all_decisions = _campaign_overseer_decisions(period_findings_by_role)
                review = session.post(
                    f"{env_url}/overseer/review",
                    json={
                        "campaign_id": campaign_id,
                        "action": {
                            "audit_trail_id": f"trail-{campaign_id}-p{period}",
                            "decisions": all_decisions,
                            "conflicts_resolved": [],
                            "task_reassignments": {},
                        },
                    },
                    timeout=request_timeout,
                )
                if review.status_code == 200:
                    logger.info(f" [overseer] Review complete: {review.json().get('result', {})}")
                else:
                    logger.warning(f" [overseer] review failed: {review.text}")
                print(f"[STEP] period={period} phase=overseer_review decisions={len(all_decisions)}", flush=True)

                period_results.append({
                    "period": period,
                    "findings_by_role": {r: len(f) for r, f in period_findings_by_role.items()},
                    "overseer_decisions": len(all_decisions),
                })

                # Advance to next period (unless last)
                if period < total_periods:
                    advance = session.post(
                        f"{env_url}/campaign/period/advance",
                        json={"campaign_id": campaign_id},
                        timeout=request_timeout,
                    )
                    if advance.status_code == 200:
                        ws = advance.json().get("observation", {}).get("world_state", {})
                        logger.info(f" Advanced to period {period + 1}, "
                                    f"policy_version={ws.get('policy_version', '?')}, "
                                    f"schema_version={ws.get('schema_version', '?')}")
                    else:
                        logger.warning(f" period/advance failed: {advance.text}")

            # Get final campaign state
            status = session.get(
                f"{env_url}/campaign/state",
                params={"campaign_id": campaign_id},
                timeout=request_timeout,
            )
            final_state = status.json() if status.status_code == 200 else {}
            completed = True
        finally:
            # Mirror run_agent_single_task: always close the structured log.
            print(
                f"[END] campaign_id={campaign_id} periods={total_periods} "
                f"success={'true' if completed else 'false'}",
                flush=True,
            )

    # Print summary table
    print(f"\n{'='*60}", file=sys.stderr)
    print(f" CAMPAIGN RESULTS — {total_periods} periods", file=sys.stderr)
    print(f"{'='*60}", file=sys.stderr)
    for pr in period_results:
        total_f = sum(pr["findings_by_role"].values())
        print(f" Period {pr['period']}: {total_f} findings, {pr['overseer_decisions']} overseer decisions", file=sys.stderr)
    print(f"{'='*60}", file=sys.stderr)
    return {
        "campaign_id": campaign_id,
        "period_results": period_results,
        "state": final_state,
    }
def _build_campaign_prompt(obs, world_state, history, instructions, shocks, role):
"""Build a comprehensive prompt for a specialist agent."""
task_desc = obs.get("task_description", "")
docs = json.dumps(obs.get("documents", {}), indent=2, default=str)[:8000]
prompt = f"You are a {role} in a multi-agent financial audit team.\n\n"
prompt += f"TASK:\n{task_desc}\n\n"
if world_state.get("policy_updates"):
prompt += "ACTIVE POLICY CHANGES:\n"
for p in world_state["policy_updates"]:
prompt += f" - {p}\n"
prompt += "\n"
if shocks:
prompt += "⚠️ REGULATORY SHOCKS (apply immediately):\n"
for s in shocks:
prompt += f" - {s.get('text', '')}\n"
prompt += "\n"
if history:
prompt += f"PRIOR PERIOD FINDINGS (use for cross-period patterns):\n"
for h in history[-10:]:
prompt += f" - P{h.get('period')}: {h.get('document_id')} ({h.get('error_type')})\n"
prompt += "\n"
if instructions:
prompt += "INSTRUCTIONS TO FOLLOW:\n"
for inst in instructions[:10]:
prompt += f" - [{inst.get('id')}] {inst.get('text')}\n"
prompt += "\n"
prompt += f"DOCUMENTS (truncated):\n{docs}\n\n"
prompt += (
"Report findings as JSON array. Each finding needs: "
"document_id, error_type, description, confidence (0.0-1.0).\n"
"Be precise (no false positives) and thorough (find all errors).\n"
"Output ONLY the JSON array."
)
return prompt
def _parse_findings_response(text: str) -> List[Dict]:
"""Parse LLM response into findings dicts."""
if not text:
return []
try:
match = re.search(r"\[.*\]", text, re.DOTALL)
if match:
parsed = json.loads(match.group())
if isinstance(parsed, list):
valid = []
for item in parsed:
if isinstance(item, dict) and "document_id" in item and "error_type" in item:
valid.append({
"document_id": str(item["document_id"]),
"error_type": str(item["error_type"]).lower(),
"description": str(item.get("description", "Finding")),
"confidence": float(item["confidence"]) if "confidence" in item else None,
})
return valid
except (json.JSONDecodeError, ValueError, TypeError):
pass
return []
def main():
    """CLI entry point: run the Round-2 campaign or the Round-1 tasks.

    The configuration banner and the results table are written to stderr.
    The structured [START]/[STEP]/[END] lines come from the runner functions.
    """
    parser = argparse.ArgumentParser(description="Financial Audit Env Inference")
    parser.add_argument("--env-url", default="http://localhost:8000", help="URL of the running environment")
    parser.add_argument("--task", default=None, help="Specific task ID to run (optional)")
    parser.add_argument("--campaign", action="store_true", help="Run Round 2 campaign flow instead of Round 1 tasks")
    parser.add_argument("--seed", type=int, default=SEED, help="Seed for reproducibility")
    args = parser.parse_args()

    rule = "=" * 60
    thin_rule = "-" * 60
    banner = [
        rule,
        " OpenEnv Financial Audit - Inference Configuration",
        f" Model Identifier: {MODEL_NAME}",
        f" API Base URL: {API_BASE_URL}",
        f" Environment URL: {args.env_url}",
        f"{rule}\n",
    ]
    for banner_line in banner:
        print(banner_line, file=sys.stderr)

    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    if args.campaign:
        campaign_output = run_campaign_round(args.env_url, args.seed, client=client)
        print(json.dumps({"campaign": campaign_output}, indent=2, default=str), file=sys.stderr)
        return

    task_ids = [args.task] if args.task else TASK_IDS
    last_position = len(task_ids) - 1
    results = {}
    total_score = 0.0
    for position, task_id in enumerate(task_ids):
        outcome = run_agent_single_task(args.env_url, task_id, client, args.seed)
        results[task_id] = outcome
        total_score += outcome["score"]
        if position < last_position:
            print("\n", file=sys.stderr)

    print(f"\n{rule}", file=sys.stderr)
    print(" RESULTS", file=sys.stderr)
    print(rule, file=sys.stderr)
    print(f"{'Task':<30} {'Difficulty':<12} {'Score':<8} {'P':<8} {'R':<8}", file=sys.stderr)
    print(thin_rule, file=sys.stderr)
    for row in results.values():
        print(
            f"{row['task_name']:<30} {row['difficulty']:<12} {row['score']:<8.4f} "
            f"{row['precision']:<8.2f} {row['recall']:<8.2f}",
            file=sys.stderr,
        )
    print(thin_rule, file=sys.stderr)
    print(f"{'AVERAGE':<42} {total_score / len(task_ids):<8.4f}", file=sys.stderr)
    print(rule, file=sys.stderr)


if __name__ == "__main__":
    main()