# Copyright (c) 2026. All rights reserved.
# Financial Audit Environment — FastAPI application.
#
# Exposes the environment over HTTP with:
# - Standard OpenEnv endpoints (reset, step, state, health)
# - Custom endpoints:
#     GET  /tasks       → list of tasks with action schema
#     GET  /grader      → F1 score for last completed episode
#     POST /baseline    → trigger baseline inference
#     GET  /leaderboard → top submitted scores
#     GET  /metrics     → basic usage statistics
# - Session-based multi-tenancy with global fallback
# - Security middleware

import logging
import math
import os
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from ..models import AuditAction, AuditObservation
from .environment import FinancialAuditEnvironment
from .security import setup_security
from .tasks import TASKS, get_all_tasks_summary

logger = logging.getLogger("financial_audit_env.app")

# ---------------------------------------------------------------------------
# Global environment instance (fallback when no session_id provided)
# ---------------------------------------------------------------------------
_env = FinancialAuditEnvironment()

# Session-based environments for multi-tenancy.
# Maps session_id -> {"env", "created_at", "last_access"}.
_sessions: Dict[str, Dict[str, Any]] = {}
_SESSION_TTL = 3600  # Sessions expire after 1 hour

# Leaderboard storage (kept sorted by "score", best first).
_leaderboard: List[Dict[str, Any]] = []
_LEADERBOARD_MAX_ENTRIES = 100

# Metrics tracking
_metrics: Dict[str, Any] = {
    "total_resets": 0,
    "total_steps": 0,
    "total_episodes_completed": 0,
    "task_reset_counts": defaultdict(int),
    "start_time": time.time(),
}

# ---------------------------------------------------------------------------
# Create the FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Financial Audit Environment",
    description=(
        "An OpenEnv-compatible RL environment for financial auditing tasks. "
        "Agents audit synthetic financial documents to find planted errors. "
        "Supports 4 tasks (easy→expert), investigation mode, and adaptive difficulty."
    ),
    version="2.0.0",
    docs_url="/docs",
)

logger.info("Financial Audit Environment v2.0 — standalone FastAPI mode")

# ---------------------------------------------------------------------------
# Apply security middleware
# ---------------------------------------------------------------------------
setup_security(app)


# ---------------------------------------------------------------------------
# Session management helpers
# ---------------------------------------------------------------------------
def _get_env(session_id: Optional[str] = None) -> FinancialAuditEnvironment:
    """Get environment instance — session-based or global fallback.

    Args:
        session_id: ID previously returned by ``POST /session``, or ``None``.

    Returns:
        The session's environment when ``session_id`` is known (its
        ``last_access`` timestamp is refreshed); otherwise the shared global
        environment.
    """
    if session_id:
        session = _sessions.get(session_id)
        if session is not None:
            session["last_access"] = time.time()
            return session["env"]
        # The caller believes it has an isolated session but is about to
        # share the global environment; make that visible in the logs.
        logger.warning(
            "Unknown or expired session_id %r; falling back to global environment",
            session_id,
        )
    return _env


def _create_session() -> str:
    """Create a new session with its own environment instance.

    Returns:
        The newly generated session ID (a UUID4 string).
    """
    session_id = str(uuid.uuid4())
    now = time.time()
    _sessions[session_id] = {
        "env": FinancialAuditEnvironment(),
        "created_at": now,
        "last_access": now,
    }
    # Cleanup expired sessions. The session just created has a fresh
    # last_access, so it is never removed here.
    _cleanup_sessions()
    return session_id


def _cleanup_sessions() -> None:
    """Remove expired sessions to prevent memory leak.

    A session is expired when it has not been accessed for more than
    ``_SESSION_TTL`` seconds.
    """
    now = time.time()
    # Collect first, delete afterwards: the dict must not change size
    # while it is being iterated.
    expired = [
        sid
        for sid, data in _sessions.items()
        if now - data["last_access"] > _SESSION_TTL
    ]
    for sid in expired:
        del _sessions[sid]
    if expired:
        logger.info("Cleaned up %d expired sessions", len(expired))


def _clamp_score(val: float) -> float:
    """Keep score-like fields within a stable open interval."""
    return max(0.01, min(0.99, val))


# ---------------------------------------------------------------------------
# Request/Response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Request body for the /reset endpoint."""

    task_id: Optional[str] = "expense_audit"
    seed: Optional[int] = 42
    episode_id: Optional[str] = None
    session_id: Optional[str] = None
    investigation_mode: Optional[bool] = False

    # Unknown extra fields in the request body are accepted, not rejected.
    model_config = {"extra": "allow"}


class StepRequest(BaseModel):
    """Request body for the /step endpoint."""

    action: AuditAction
    session_id: Optional[str] = None


class BaselineResponse(BaseModel):
    """Response from the /baseline endpoint."""

    scores: Dict[str, Any]
    model: str
    status: str


class LeaderboardEntry(BaseModel):
    """A single leaderboard entry."""

    model: str
    task_id: str
    score: float
    weighted_score: float
    risk_mitigation_pct: float
    timestamp: float


# ---------------------------------------------------------------------------
# Standard endpoints
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    """Root endpoint — welcome page with API docs link."""
    return {
        "name": "Financial Audit Environment",
        "version": "2.0.0",
        "description": (
            "An OpenEnv-compatible RL environment for financial auditing tasks. "
            "Agents audit synthetic financial documents to find planted errors."
        ),
        "tasks": ["expense_audit", "invoice_match", "gst_reconciliation", "fraud_detection"],
        "endpoints": {
            "docs": "/docs",
            "health": "/health",
            "tasks": "/tasks",
            "reset": "POST /reset",
            "step": "POST /step",
            "grader": "/grader",
            "leaderboard": "/leaderboard",
            "metrics": "/metrics",
        },
        "quickstart": (
            "1. POST /reset with {\"task_id\": \"expense_audit\", \"seed\": 42}\n"
            "2. Read the documents in the observation\n"
            "3. POST /step with your findings\n"
            "4. GET /grader to see your score"
        ),
        "github": "https://github.com/balloonmann/financial-audit-env",
    }


@app.get("/health")
async def health_check():
    """Health check endpoint — required for HF Space deployment."""
    return {
        "status": "healthy",
        "environment": "financial_audit_env",
        "version": "2.0.0",
        "tasks_available": len(TASKS),
        "active_sessions": len(_sessions),
    }


@app.post("/reset")
async def reset_endpoint(request: ResetRequest = ResetRequest()):
    """
    Reset the environment for a new episode.

    Args (JSON body):
        task_id: "expense_audit" | "invoice_match" | "gst_reconciliation" | "fraud_detection"
        seed: Random seed for reproducibility (default: 42)
        episode_id: Optional custom episode ID
        session_id: Optional session ID for multi-tenancy
        investigation_mode: If true, start in drill-down mode
    """
    try:
        env = _get_env(request.session_id)
        obs = env.reset(
            seed=request.seed,
            episode_id=request.episode_id,
            task_id=request.task_id,
            investigation_mode=request.investigation_mode or False,
        )

        # Track metrics
        _metrics["total_resets"] += 1
        _metrics["task_reset_counts"][request.task_id or "expense_audit"] += 1

        response = {
            "observation": obs.model_dump(),
            "done": obs.done,
            "reward": obs.reward,
        }
        if request.session_id:
            response["session_id"] = request.session_id
        return response
    except ValueError as e:
        # A ValueError raised while resetting is reported as a client error.
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/step")
async def step_endpoint(request: StepRequest):
    """
    Execute one step in the environment.

    Submit audit findings and receive feedback + reward.
    Set submit_final=True to end the episode and get final grading.
    """
    try:
        env = _get_env(request.session_id)
        obs = env.step(request.action)

        _metrics["total_steps"] += 1
        if obs.done:
            _metrics["total_episodes_completed"] += 1

        return {
            "observation": obs.model_dump(),
            "done": obs.done,
            "reward": obs.reward,
        }
    except (RuntimeError, ValueError) as e:
        # Both error types are surfaced to the client as a 400.
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/state")
async def state_endpoint(session_id: Optional[str] = None):
    """Get current episode state (step count, found errors, etc.)."""
    env = _get_env(session_id)
    return env.state.model_dump()


# ---------------------------------------------------------------------------
# Session management endpoints
# ---------------------------------------------------------------------------
@app.post("/session")
async def create_session():
    """Create a new isolated session for multi-tenancy."""
    session_id = _create_session()
    return {
        "session_id": session_id,
        "ttl_seconds": _SESSION_TTL,
        "message": "Session created. Include session_id in /reset and /step requests.",
    }


# ---------------------------------------------------------------------------
# Contest-required custom endpoints
# ---------------------------------------------------------------------------
@app.get("/tasks")
async def get_tasks():
    """List all available tasks with their descriptions and action schemas."""
    return {
        "tasks": get_all_tasks_summary(),
        "total_tasks": len(TASKS),
    }


@app.get("/grader")
async def get_grader_score(session_id: Optional[str] = None):
    """
    Get the grader score for the last completed episode.

    Includes F1, weighted F1, confusion matrix, and risk scoring.
    """
    env = _get_env(session_id)
    result = env.last_grader_result

    if result is None:
        return {
            "status": "no_completed_episode",
            "message": "No episode completed. Call /reset then /step with submit_final=True.",
        }

    return {
        "status": "completed",
        "task_id": env.state.task_id,
        # Primary score fields with a final stability clamp.
        "score": _clamp_score(result["score"]),
        "precision": _clamp_score(result["precision"]),
        "recall": _clamp_score(result["recall"]),
        "true_positives": result["true_positives"],
        "false_positives": result["false_positives"],
        "false_negatives": result["false_negatives"],
        "total_errors": result["total_errors"],
        # Enhanced scoring fields with the same stability clamp; both fall
        # back to the primary score when the grader did not produce them.
        "weighted_score": _clamp_score(result.get("weighted_score", result["score"])),
        "partial_credit_score": _clamp_score(
            result.get("partial_credit_score", result["score"])
        ),
        "partial_matches": result.get("partial_matches", 0),
        # Confusion matrix
        "confusion_matrix": result.get("confusion_matrix", {}),
        # Risk scoring
        "risk_score": result.get("risk_score", {}),
    }


# ---------------------------------------------------------------------------
# Leaderboard endpoint
# ---------------------------------------------------------------------------
@app.post("/leaderboard")
async def submit_to_leaderboard(
    model: str,
    task_id: str,
    score: float,
    weighted_score: float = 0.01,
    risk_mitigation_pct: float = 0.01,
):
    """
    Submit a score to the leaderboard.

    Returns the 1-based rank of the new entry, or -1 when it did not make
    the retained top entries. Non-finite numeric values are rejected
    with a 400.
    """
    # NaN/inf would corrupt the sort order and cannot be serialised as JSON,
    # which would break every subsequent GET /leaderboard.
    if not all(map(math.isfinite, (score, weighted_score, risk_mitigation_pct))):
        raise HTTPException(
            status_code=400,
            detail="score, weighted_score and risk_mitigation_pct must be finite numbers",
        )

    entry = {
        "model": model,
        "task_id": task_id,
        "score": score,
        "weighted_score": weighted_score,
        "risk_mitigation_pct": risk_mitigation_pct,
        "timestamp": time.time(),
    }
    _leaderboard.append(entry)

    # Keep only the top entries, best score first.
    _leaderboard.sort(key=lambda x: x["score"], reverse=True)
    del _leaderboard[_LEADERBOARD_MAX_ENTRIES:]

    # Identity (not equality) so an older entry with identical field values
    # is never mistaken for the one just submitted.
    rank = next(
        (i + 1 for i, e in enumerate(_leaderboard) if e is entry),
        -1,
    )
    return {"status": "submitted", "rank": rank}


@app.get("/leaderboard")
async def get_leaderboard(task_id: Optional[str] = None, limit: int = 20):
    """
    Get the leaderboard — submitted entries ordered by score, best first.

    Args (query):
        task_id: If given, only entries for this task are returned.
        limit: Maximum number of entries to return (negative values are
            treated as 0).
    """
    entries = _leaderboard
    if task_id:
        entries = [e for e in entries if e["task_id"] == task_id]
    # A negative slice bound would drop entries from the end instead of
    # limiting the count, so clamp it.
    return {
        "leaderboard": entries[: max(limit, 0)],
        "total_entries": len(entries),
    }


# ---------------------------------------------------------------------------
# Metrics endpoint
# ---------------------------------------------------------------------------
@app.get("/metrics")
async def get_metrics():
    """Get basic usage statistics."""
    uptime = time.time() - _metrics["start_time"]
    return {
        "uptime_seconds": round(uptime, 0),
        "total_resets": _metrics["total_resets"],
        "total_steps": _metrics["total_steps"],
        "total_episodes_completed": _metrics["total_episodes_completed"],
        # Convert the defaultdict to a plain dict for the response.
        "task_usage": dict(_metrics["task_reset_counts"]),
        "active_sessions": len(_sessions),
    }


# ---------------------------------------------------------------------------
# Adaptive difficulty endpoint
# ---------------------------------------------------------------------------
@app.get("/adaptive-difficulty")
async def get_adaptive_difficulty(session_id: Optional[str] = None):
    """Get adaptive difficulty recommendations based on score history."""
    env = _get_env(session_id)
    return env.get_adaptive_difficulty()


# ---------------------------------------------------------------------------
# Baseline endpoint
# ---------------------------------------------------------------------------
@app.post("/baseline")
async def run_baseline():
    """
    Run the baseline agent on all tasks and return scores.

    Requires HF_TOKEN environment variable.
    """
    # Imported lazily so the server still starts when the baseline module
    # cannot be imported.
    try:
        from ..baseline import run_baseline_all_tasks
    except ImportError:
        return JSONResponse(
            status_code=501,
            content={
                "status": "error",
                "message": "Baseline not available. Run baseline.py directly.",
            },
        )

    hf_token = os.environ.get("HF_TOKEN", "")
    if not hf_token:
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "message": "HF_TOKEN not set. Get one at https://huggingface.co/settings/tokens",
            },
        )

    try:
        # NOTE(review): this call is synchronous inside an async handler and
        # uses the shared global environment — confirm that is acceptable
        # for concurrent traffic.
        scores = run_baseline_all_tasks(env=_env, hf_token=hf_token)
        return BaselineResponse(
            scores=scores,
            model="meta-llama/Llama-3.1-8B-Instruct",
            status="completed",
        )
    except Exception as e:
        # Top-level boundary: log the full traceback, return a generic
        # message so internals are not leaked to the client.
        logger.exception("Baseline failed: %s", e)
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": "Baseline failed. Check logs."},
        )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main():
    """Run the server directly: python -m financial_audit_env.server.app"""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()