# Copyright (c) 2026. All rights reserved.
# Financial Audit Environment — FastAPI application.
#
# Exposes the environment over HTTP with:
# - Standard OpenEnv endpoints (reset, step, state, health)
# - Custom endpoints:
#     GET  /tasks               → list of tasks with action schema
#     GET  /grader              → F1 score for last completed episode
#     POST /baseline            → trigger baseline inference
#     POST /leaderboard         → submit a score
#     GET  /leaderboard         → top submitted scores (highest first)
#     GET  /metrics             → basic usage statistics
#     POST /session             → create an isolated session
#     GET  /adaptive-difficulty → difficulty recommendations
# - Session-based multi-tenancy with global fallback
# - Security middleware
import logging
import os
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ..models import AuditAction, AuditObservation
from .environment import FinancialAuditEnvironment
from .security import setup_security
from .tasks import TASKS, get_all_tasks_summary
# Module logger, registered under the fixed name "financial_audit_env.app".
logger = logging.getLogger("financial_audit_env.app")
# ---------------------------------------------------------------------------
# Global environment instance (fallback when no session_id provided)
# ---------------------------------------------------------------------------
# Built once at import time. _get_env() returns this instance for every
# request whose session_id is missing or not registered in _sessions, so all
# such callers share its state; /baseline also runs against it.
_env = FinancialAuditEnvironment()
# Per-session environments for multi-tenancy, keyed by session id. Each value
# holds the keys "env", "created_at" and "last_access".
_sessions: Dict[str, Dict[str, Any]] = {}
# Idle time, in seconds, after which a session may be cleaned up (one hour).
_SESSION_TTL = 60 * 60

# Leaderboard entries, held in process memory only.
_leaderboard: List[Dict[str, Any]] = []

# Usage counters reported by GET /metrics.
_metrics = dict(
    total_resets=0,
    total_steps=0,
    total_episodes_completed=0,
    task_reset_counts=defaultdict(int),  # task_id -> number of resets
    start_time=time.time(),  # import time; basis for the uptime figure
)
# ---------------------------------------------------------------------------
# Create the FastAPI app
# ---------------------------------------------------------------------------
# The ASGI application object; every route below is registered on it.
app = FastAPI(
    version="2.0.0",
    docs_url="/docs",
    title="Financial Audit Environment",
    description=(
        "An OpenEnv-compatible RL environment for financial auditing tasks. "
        "Agents audit synthetic financial documents to find planted errors. "
        "Supports 4 tasks (easy→expert), investigation mode, and adaptive difficulty."
    ),
)
logger.info("Financial Audit Environment v2.0 — standalone FastAPI mode")
# ---------------------------------------------------------------------------
# Apply security middleware
# ---------------------------------------------------------------------------
# Must run after `app` exists: the helper receives the application instance.
setup_security(app)
# ---------------------------------------------------------------------------
# Session management helpers
# ---------------------------------------------------------------------------
def _get_env(session_id: Optional[str] = None) -> FinancialAuditEnvironment:
    """Return the environment that should serve a request.

    Args:
        session_id: Identifier issued by ``POST /session``, or ``None``/empty
            when the caller does not use sessions.

    Returns:
        The session's own environment when ``session_id`` is registered in
        ``_sessions`` (its ``last_access`` timestamp is refreshed as a side
        effect); otherwise the shared module-level ``_env``.
    """
    if not session_id:
        return _env

    session = _sessions.get(session_id)
    if session is None:
        # The fallback is kept for backward compatibility, but it is no longer
        # silent: a caller presenting an unknown or already-cleaned-up id ends
        # up on the environment shared by every session-less client, which is
        # worth surfacing in the logs. %r keeps odd characters in the
        # client-supplied id from being written raw into the log line.
        logger.warning(
            "Unknown or expired session_id %r; using the shared global environment",
            session_id,
        )
        return _env

    session["last_access"] = time.time()
    return session["env"]
def _create_session() -> str:
    """Register a fresh session and return its id.

    The session gets its own ``FinancialAuditEnvironment`` and is stored in
    ``_sessions`` under a random UUID4 string. Expired sessions are cleaned
    up afterwards, on every call.
    """
    new_id = str(uuid.uuid4())

    record: Dict[str, Any] = {"env": FinancialAuditEnvironment()}
    record["created_at"] = time.time()
    record["last_access"] = time.time()
    _sessions[new_id] = record

    # Creating a session is the only point where stale ones are cleaned up.
    _cleanup_sessions()
    return new_id
def _cleanup_sessions() -> None:
    """Drop every session idle for longer than ``_SESSION_TTL`` seconds.

    Idle time is measured from the session's ``last_access`` timestamp. One
    info-level log line is emitted when at least one session was removed.
    """
    now = time.time()
    removed = 0
    # Iterate over a snapshot of the keys because entries are deleted in-loop.
    for sid in list(_sessions):
        if now - _sessions[sid]["last_access"] > _SESSION_TTL:
            del _sessions[sid]
            removed += 1
    if removed:
        logger.info(f"Cleaned up {removed} expired sessions")
# ---------------------------------------------------------------------------
# Request/Response models
# ---------------------------------------------------------------------------
class ResetRequest(BaseModel):
    """Request body for the /reset endpoint.

    Every field is optional and has a default, so an empty body is valid.
    """

    # Fields not declared below are accepted instead of being rejected.
    model_config = dict(extra="allow")

    task_id: Optional[str] = "expense_audit"
    seed: Optional[int] = 42
    episode_id: Optional[str] = None
    session_id: Optional[str] = None
    investigation_mode: Optional[bool] = False
class StepRequest(BaseModel):
    """Request body for the /step endpoint."""
    # Action handed as-is to the selected environment's step() call.
    action: AuditAction
    # Optional id obtained from POST /session; when omitted, the shared
    # global environment handles the step.
    session_id: Optional[str] = None
class BaselineResponse(BaseModel):
    """Response from the /baseline endpoint."""
    # Whatever run_baseline_all_tasks() returned; its exact structure is not
    # visible in this file.
    scores: Dict[str, Any]
    # Model name label; run_baseline() fills in a fixed string.
    model: str
    # run_baseline() sets this to "completed" on the success path.
    status: str
class LeaderboardEntry(BaseModel):
    """A single leaderboard entry."""
    # NOTE(review): no endpoint visible in this file references this model;
    # submit_to_leaderboard() stores plain dicts carrying these same keys.
    model: str
    task_id: str
    score: float
    weighted_score: float
    risk_mitigation_pct: float
    # Seconds since the epoch (time.time()) in the dicts built by
    # submit_to_leaderboard().
    timestamp: float
# ---------------------------------------------------------------------------
# Standard endpoints
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    """Welcome payload: service name, version, task ids, endpoint map and quickstart."""
    summary = (
        "An OpenEnv-compatible RL environment for financial auditing tasks. "
        "Agents audit synthetic financial documents to find planted errors."
    )
    task_names = ["expense_audit", "invoice_match", "gst_reconciliation", "fraud_detection"]
    endpoint_map = {
        "docs": "/docs",
        "health": "/health",
        "tasks": "/tasks",
        "reset": "POST /reset",
        "step": "POST /step",
        "grader": "/grader",
        "leaderboard": "/leaderboard",
        "metrics": "/metrics",
    }
    # Joined with newlines below; the final step carries no trailing newline.
    quickstart_steps = [
        '1. POST /reset with {"task_id": "expense_audit", "seed": 42}',
        "2. Read the documents in the observation",
        "3. POST /step with your findings",
        "4. GET /grader to see your score",
    ]
    return {
        "name": "Financial Audit Environment",
        "version": "2.0.0",
        "description": summary,
        "tasks": task_names,
        "endpoints": endpoint_map,
        "quickstart": "\n".join(quickstart_steps),
        "github": "https://github.com/balloonmann/financial-audit-env",
    }
@app.get("/health")
async def health_check():
    """Health check endpoint — required for HF Space deployment.

    Reports a fixed "healthy" status together with the number of registered
    tasks and the number of sessions currently held in memory.
    """
    return dict(
        status="healthy",
        environment="financial_audit_env",
        version="2.0.0",
        tasks_available=len(TASKS),
        active_sessions=len(_sessions),
    )
@app.post("/reset")
async def reset_endpoint(request: ResetRequest = ResetRequest()):
    """
    Start a new episode on the selected environment.

    Args (JSON body):
        task_id: "expense_audit" | "invoice_match" | "gst_reconciliation" | "fraud_detection"
        seed: Random seed for reproducibility (default: 42)
        episode_id: Optional custom episode ID
        session_id: Optional session ID for multi-tenancy
        investigation_mode: If true, start in drill-down mode

    Returns:
        A dict with "observation", "done" and "reward"; "session_id" is
        echoed back whenever the request carried one.

    Raises:
        HTTPException: 400 when a ValueError is raised while handling the request.
    """
    try:
        target_env = _get_env(request.session_id)
        observation = target_env.reset(
            seed=request.seed,
            episode_id=request.episode_id,
            task_id=request.task_id,
            investigation_mode=request.investigation_mode or False,
        )

        # Usage counters; a null task_id is counted under "expense_audit".
        counted_task = request.task_id or "expense_audit"
        _metrics["total_resets"] += 1
        _metrics["task_reset_counts"][counted_task] += 1

        payload = dict(
            observation=observation.model_dump(),
            done=observation.done,
            reward=observation.reward,
        )
        if request.session_id:
            payload["session_id"] = request.session_id
        return payload
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
@app.post("/step")
async def step_endpoint(request: StepRequest):
    """
    Advance the selected environment by one step.

    Submit audit findings and receive feedback + reward.
    Set submit_final=True to end the episode and get final grading.

    Returns:
        A dict with "observation", "done" and "reward".

    Raises:
        HTTPException: 400 when the step raises RuntimeError or ValueError.
    """
    try:
        observation = _get_env(request.session_id).step(request.action)

        _metrics["total_steps"] += 1
        if observation.done:
            _metrics["total_episodes_completed"] += 1

        return dict(
            observation=observation.model_dump(),
            done=observation.done,
            reward=observation.reward,
        )
    except (RuntimeError, ValueError) as err:
        # Both error types map to the same 400 response.
        raise HTTPException(status_code=400, detail=str(err))
@app.get("/state")
async def state_endpoint(session_id: Optional[str] = None):
    """Return the current episode state of the selected environment as a dict.

    Args:
        session_id: Optional session id (query parameter); without a known id
            the shared global environment is inspected.
    """
    return _get_env(session_id).state.model_dump()
# ---------------------------------------------------------------------------
# Session management endpoints
# ---------------------------------------------------------------------------
@app.post("/session")
async def create_session():
    """Open a new isolated session and return its id and time-to-live."""
    new_session_id = _create_session()
    return dict(
        session_id=new_session_id,
        ttl_seconds=_SESSION_TTL,
        message="Session created. Include session_id in /reset and /step requests.",
    )
# ---------------------------------------------------------------------------
# Contest-required custom endpoints
# ---------------------------------------------------------------------------
@app.get("/tasks")
async def get_tasks():
    """Return the task summaries from get_all_tasks_summary() plus the task count."""
    summaries = get_all_tasks_summary()
    task_count = len(TASKS)
    return {"tasks": summaries, "total_tasks": task_count}
@app.get("/grader")
async def get_grader_score(session_id: Optional[str] = None):
    """
    Report the grader result of the last completed episode.

    Args:
        session_id: Optional session id selecting which environment to read.

    Returns:
        A "no_completed_episode" notice while the environment has no grader
        result yet; otherwise the scores, counts, confusion matrix and risk
        scoring. Score-like fields are clamped to the range [0.01, 0.99].
    """
    env = _get_env(session_id)
    result = env.last_grader_result

    if result is None:
        return dict(
            status="no_completed_episode",
            message="No episode completed. Call /reset then /step with submit_final=True.",
        )

    lower_bound, upper_bound = 0.01, 0.99

    def final_clamp(val: float) -> float:
        """Pin a score-like value inside [lower_bound, upper_bound]."""
        return max(lower_bound, min(upper_bound, val))

    task_id = env.state.task_id
    # Optional enhanced fields fall back to the primary score when absent.
    weighted = result.get("weighted_score", result["score"])
    partial_credit = result.get("partial_credit_score", result["score"])

    report = {
        "status": "completed",
        "task_id": task_id,
        # Primary score fields, clamped.
        "score": final_clamp(result["score"]),
        "precision": final_clamp(result["precision"]),
        "recall": final_clamp(result["recall"]),
    }
    # Raw counts are passed through untouched.
    for count_key in ("true_positives", "false_positives", "false_negatives", "total_errors"):
        report[count_key] = result[count_key]
    report["weighted_score"] = final_clamp(weighted)
    report["partial_credit_score"] = final_clamp(partial_credit)
    report["partial_matches"] = result.get("partial_matches", 0)
    report["confusion_matrix"] = result.get("confusion_matrix", {})
    report["risk_score"] = result.get("risk_score", {})
    return report
# ---------------------------------------------------------------------------
# Leaderboard endpoint
# ---------------------------------------------------------------------------
@app.post("/leaderboard")
async def submit_to_leaderboard(
    model: str,
    task_id: str,
    score: float,
    weighted_score: float = 0.01,
    risk_mitigation_pct: float = 0.01,
):
    """Submit a score to the in-memory leaderboard.

    Args:
        model: Name of the model the score belongs to.
        task_id: Task the score was obtained on.
        score: Primary score; the leaderboard is ordered by it, highest first.
        weighted_score: Optional weighted score (default 0.01).
        risk_mitigation_pct: Optional risk-mitigation figure (default 0.01).

    Returns:
        ``{"status": "submitted", "rank": n}`` where ``n`` is the entry's
        1-based position, or -1 when it did not make the retained top 100.

    Raises:
        HTTPException: 422 when any numeric argument is NaN or infinite.
    """
    import math  # local import keeps this block self-contained

    # NaN has no consistent ordering, so it would corrupt the sort, and
    # non-finite floats cannot be emitted by the JSON response encoder, so one
    # such entry would make GET /leaderboard fail for every caller.
    numeric_fields = {
        "score": score,
        "weighted_score": weighted_score,
        "risk_mitigation_pct": risk_mitigation_pct,
    }
    non_finite = [name for name, value in numeric_fields.items() if not math.isfinite(value)]
    if non_finite:
        raise HTTPException(
            status_code=422,
            detail=f"Non-finite value not allowed for: {', '.join(non_finite)}",
        )

    entry = {
        "model": model,
        "task_id": task_id,
        "score": score,
        "weighted_score": weighted_score,
        "risk_mitigation_pct": risk_mitigation_pct,
        "timestamp": time.time(),
    }
    _leaderboard.append(entry)

    # Keep the top 100 entries; the sort is stable, so ties stay in
    # submission order.
    _leaderboard.sort(key=lambda item: item["score"], reverse=True)
    del _leaderboard[100:]

    # Match by identity, not equality: an earlier entry with identical field
    # values must not be reported as this submission's rank.
    rank = next(
        (position for position, item in enumerate(_leaderboard, start=1) if item is entry),
        -1,
    )
    return {"status": "submitted", "rank": rank}
@app.get("/leaderboard")
async def get_leaderboard(task_id: Optional[str] = None, limit: int = 20):
    """Return the highest-scoring leaderboard entries.

    Entries are individual submissions ordered by score (highest first); they
    are not de-duplicated per model.

    Args:
        task_id: When given, only entries for this task are considered.
        limit: Maximum number of entries to return. Values below zero are
            treated as zero.

    Returns:
        A dict with "leaderboard" (at most ``limit`` entries) and
        "total_entries" (number of entries matching the filter, before the
        limit is applied).
    """
    entries = _leaderboard
    if task_id:
        entries = [e for e in entries if e["task_id"] == task_id]

    # A negative limit would slice from the end (entries[:-n]) and return
    # almost the whole list instead of none of it.
    capped_limit = max(limit, 0)
    return {
        "leaderboard": entries[:capped_limit],
        "total_entries": len(entries),
    }
# ---------------------------------------------------------------------------
# Metrics endpoint
# ---------------------------------------------------------------------------
@app.get("/metrics")
async def get_metrics():
    """Return process-wide usage counters and uptime since module import."""
    elapsed = time.time() - _metrics["start_time"]
    counters = {
        key: _metrics[key]
        for key in ("total_resets", "total_steps", "total_episodes_completed")
    }
    return {
        "uptime_seconds": round(elapsed, 0),
        **counters,
        # Copy into a plain dict so the defaultdict itself is not exposed.
        "task_usage": dict(_metrics["task_reset_counts"]),
        "active_sessions": len(_sessions),
    }
# ---------------------------------------------------------------------------
# Adaptive difficulty endpoint
# ---------------------------------------------------------------------------
@app.get("/adaptive-difficulty")
async def get_adaptive_difficulty(session_id: Optional[str] = None):
    """Return the selected environment's adaptive-difficulty recommendations.

    Args:
        session_id: Optional session id; without a known id the shared global
            environment is queried.
    """
    return _get_env(session_id).get_adaptive_difficulty()
# ---------------------------------------------------------------------------
# Baseline endpoint
# ---------------------------------------------------------------------------
@app.post("/baseline")
async def run_baseline():
    """
    Run the baseline agent through run_baseline_all_tasks() and return its scores.

    Requires HF_TOKEN environment variable. The run uses the shared global
    environment.

    Returns:
        A BaselineResponse on success. Otherwise a JSON error payload:
        501 when the baseline module cannot be imported, 400 when HF_TOKEN is
        unset or empty, 500 when the run raises.
    """

    def failure(status_code: int, message: str) -> JSONResponse:
        """Build the error payload shared by all failure paths."""
        return JSONResponse(
            status_code=status_code,
            content={"status": "error", "message": message},
        )

    try:
        from ..baseline import run_baseline_all_tasks
    except ImportError:
        return failure(501, "Baseline not available. Run baseline.py directly.")

    token = os.environ.get("HF_TOKEN", "")
    if not token:
        return failure(
            400,
            "HF_TOKEN not set. Get one at https://huggingface.co/settings/tokens",
        )

    try:
        return BaselineResponse(
            scores=run_baseline_all_tasks(env=_env, hf_token=token),
            model="meta-llama/Llama-3.1-8B-Instruct",
            status="completed",
        )
    except Exception as exc:
        # Full details go to the log only; the client gets a generic message.
        logger.error(f"Baseline failed: {exc}", exc_info=True)
        return failure(500, "Baseline failed. Check logs.")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Run the server directly: python -m financial_audit_env.server.app

    Args:
        host: Interface to bind. The default "0.0.0.0" listens on all
            interfaces.
        port: TCP port to listen on (default 8000).
    """
    # Imported here, not at module top, so uvicorn is only needed when the
    # server is actually launched through this function.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()