File size: 15,326 Bytes
d75cfa2
 
 
126bdbd
d75cfa2
126bdbd
 
 
 
 
 
 
 
d75cfa2
 
 
126bdbd
 
 
 
d75cfa2
 
 
 
 
 
 
 
 
 
 
 
 
126bdbd
d75cfa2
 
 
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
 
 
 
 
126bdbd
 
d75cfa2
126bdbd
d75cfa2
 
126bdbd
d75cfa2
 
 
 
 
 
 
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
 
 
 
 
 
126bdbd
 
d75cfa2
eecf965
 
d75cfa2
 
 
 
126bdbd
d75cfa2
 
 
 
 
 
 
 
 
126bdbd
 
 
 
 
 
 
 
 
 
d75cfa2
126bdbd
d75cfa2
 
d2e4b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
126bdbd
 
 
 
 
 
 
d75cfa2
 
 
eecf965
d75cfa2
 
 
 
126bdbd
d75cfa2
 
126bdbd
 
d75cfa2
 
126bdbd
 
d75cfa2
 
 
126bdbd
d75cfa2
126bdbd
 
 
 
 
d75cfa2
 
 
 
126bdbd
 
 
d75cfa2
 
 
 
 
 
 
 
 
 
 
 
 
126bdbd
 
 
 
 
d75cfa2
 
 
 
 
 
 
 
 
 
 
 
126bdbd
d75cfa2
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
 
 
 
 
 
126bdbd
d75cfa2
 
 
 
 
 
 
126bdbd
d75cfa2
 
126bdbd
d75cfa2
126bdbd
 
d75cfa2
 
 
 
 
 
a16cc4e
f76bbd6
2c804dd
a16cc4e
d75cfa2
 
126bdbd
f76bbd6
a16cc4e
 
 
d75cfa2
 
 
 
f76bbd6
a16cc4e
 
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c804dd
 
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
126bdbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d75cfa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
# Copyright (c) 2026. All rights reserved.
# Financial Audit Environment — FastAPI application.
#
# Exposes the environment over HTTP with:
# - Standard OpenEnv endpoints (reset, step, state, health)
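# - Custom endpoints:
#   GET  /tasks               → list of tasks with action schema
#   GET  /grader              → F1 score for last completed episode
#   POST /baseline            → trigger baseline inference
#   GET  /leaderboard         → submitted scores, highest first (optional task filter)
#   POST /leaderboard         → submit a score
#   POST /session             → create an isolated session
#   GET  /adaptive-difficulty → difficulty recommendations from score history
#   GET  /metrics             → basic usage statistics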
# - Session-based multi-tenancy with global fallback
# - Security middleware

import logging
import math
import os
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from ..models import AuditAction, AuditObservation
from .environment import FinancialAuditEnvironment
from .security import setup_security
from .tasks import TASKS, get_all_tasks_summary

logger = logging.getLogger("financial_audit_env.app")

# ---------------------------------------------------------------------------
# Global environment instance (fallback when no session_id provided)
# ---------------------------------------------------------------------------
# Shared by every request that does not name a session (see _get_env) and by
# the /baseline endpoint, so all of those callers see and mutate one episode.
_env = FinancialAuditEnvironment()

# Session-based environments for multi-tenancy.
# Maps session_id -> {"env": FinancialAuditEnvironment, "created_at": float,
# "last_access": float}; timestamps are time.time() values (see _create_session).
_sessions: Dict[str, Dict[str, Any]] = {}
# Idle timeout in seconds (1 hour). Expired entries are purged by
# _cleanup_sessions, which _create_session calls after each new session.
_SESSION_TTL = 3600

# Leaderboard storage: plain dicts appended by POST /leaderboard, which also
# sorts and caps the list. Held in process memory only (no persistence here).
_leaderboard: List[Dict[str, Any]] = []

# Metrics tracking: in-process counters reported by GET /metrics.
_metrics = {
    "total_resets": 0,  # incremented after each successful env.reset in /reset
    "total_steps": 0,  # incremented after each successful env.step in /step
    "total_episodes_completed": 0,  # /step results whose observation has done=True
    "task_reset_counts": defaultdict(int),  # task_id -> number of successful resets
    "start_time": time.time(),  # module import time; basis for uptime_seconds
}

# ---------------------------------------------------------------------------
# Create the FastAPI app
# ---------------------------------------------------------------------------
# Interactive OpenAPI docs are served at /docs. The "2.0.0" version string is
# repeated as a literal in root() and health_check(); keep all three in sync.
app = FastAPI(
    title="Financial Audit Environment",
    description=(
        "An OpenEnv-compatible RL environment for financial auditing tasks. "
        "Agents audit synthetic financial documents to find planted errors. "
        "Supports 4 tasks (easy→expert), investigation mode, and adaptive difficulty."
    ),
    version="2.0.0",
    docs_url="/docs",
)
logger.info("Financial Audit Environment v2.0 — standalone FastAPI mode")

# ---------------------------------------------------------------------------
# Apply security middleware
# ---------------------------------------------------------------------------
# setup_security (from .security) is handed the app before any route below is
# registered; what it installs is not visible in this module — see security.py.
setup_security(app)


# ---------------------------------------------------------------------------
# Session management helpers
# ---------------------------------------------------------------------------

def _get_env(session_id: Optional[str] = None) -> FinancialAuditEnvironment:
    """Get environment instance — session-based or global fallback.

    Args:
        session_id: ID returned by POST /session. ``None`` or ``""`` selects
            the shared global environment.

    Returns:
        The session's own environment (refreshing its idle timer), or the
        global ``_env`` when no session was requested.

    Raises:
        HTTPException: 404 when ``session_id`` is given but is unknown or has
            been idle for longer than ``_SESSION_TTL`` seconds.
    """
    if not session_id:
        return _env

    session = _sessions.get(session_id)
    now = time.time()
    if session is None or now - session["last_access"] > _SESSION_TTL:
        # An explicit but unknown/expired session must not silently fall back
        # to the shared global environment: that would mix this client's
        # episode with every session-less client's episode.
        _sessions.pop(session_id, None)
        raise HTTPException(
            status_code=404,
            detail=f"Unknown or expired session_id: {session_id}",
        )

    session["last_access"] = now
    return session["env"]


def _create_session() -> str:
    """Register a fresh, isolated environment and return its new session id.

    Every call also purges sessions whose idle time exceeds the TTL.
    """
    new_id = str(uuid.uuid4())
    record = {
        "env": FinancialAuditEnvironment(),
        "created_at": time.time(),
        "last_access": time.time(),
    }
    _sessions[new_id] = record
    # Piggy-back expiry on creation so the session map cannot grow without bound.
    _cleanup_sessions()
    return new_id


def _cleanup_sessions() -> None:
    """Evict every session idle for longer than ``_SESSION_TTL`` seconds.

    Keeps ``_sessions`` from leaking memory; logs the eviction count when
    anything was removed.
    """
    now = time.time()
    stale_ids = [
        sid
        for sid, record in _sessions.items()
        if now - record["last_access"] > _SESSION_TTL
    ]
    for sid in stale_ids:
        _sessions.pop(sid)
    if not stale_ids:
        return
    logger.info("Cleaned up %d expired sessions", len(stale_ids))


# ---------------------------------------------------------------------------
# Request/Response models
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    """JSON body for POST /reset; every field is optional.

    Field meanings are documented on ``reset_endpoint``.
    """
    # An explicit null is forwarded to env.reset as None, although the reset
    # metrics still count it under "expense_audit".
    task_id: Optional[str] = "expense_audit"
    seed: Optional[int] = 42  # forwarded to env.reset for reproducibility
    episode_id: Optional[str] = None  # optional custom episode ID for env.reset
    session_id: Optional[str] = None  # None -> shared global environment (_get_env)
    investigation_mode: Optional[bool] = False  # reset_endpoint coerces null to False

    # Accept unknown JSON fields instead of rejecting the request with a 422.
    model_config = {"extra": "allow"}


class StepRequest(BaseModel):
    """Request body for the /step endpoint."""
    action: AuditAction  # passed unchanged to env.step
    session_id: Optional[str] = None  # None -> shared global environment (_get_env)


class BaselineResponse(BaseModel):
    """Response from the /baseline endpoint."""
    scores: Dict[str, Any]  # whatever run_baseline_all_tasks returned
    model: str  # model identifier reported by run_baseline
    status: str  # "completed" on success


class LeaderboardEntry(BaseModel):
    """A single leaderboard entry.

    Mirrors the keys of the dicts POST /leaderboard stores; the endpoints in
    this module build plain dicts and do not appear to reference this model.
    """
    model: str
    task_id: str
    score: float  # the leaderboard is ordered by this field, highest first
    weighted_score: float
    risk_mitigation_pct: float
    timestamp: float  # time.time() at submission


# ---------------------------------------------------------------------------
# Standard endpoints
# ---------------------------------------------------------------------------

@app.get("/")
async def root():
    """Landing payload: service identity, supported tasks, route map and a quickstart."""
    summary = (
        "An OpenEnv-compatible RL environment for financial auditing tasks. "
        "Agents audit synthetic financial documents to find planted errors."
    )
    routes = {
        "docs": "/docs",
        "health": "/health",
        "tasks": "/tasks",
        "reset": "POST /reset",
        "step": "POST /step",
        "grader": "/grader",
        "leaderboard": "/leaderboard",
        "metrics": "/metrics",
    }
    steps = "".join([
        "1. POST /reset with {\"task_id\": \"expense_audit\", \"seed\": 42}\n",
        "2. Read the documents in the observation\n",
        "3. POST /step with your findings\n",
        "4. GET /grader to see your score",
    ])
    return {
        "name": "Financial Audit Environment",
        "version": "2.0.0",
        "description": summary,
        "tasks": ["expense_audit", "invoice_match", "gst_reconciliation", "fraud_detection"],
        "endpoints": routes,
        "quickstart": steps,
        "github": "https://github.com/balloonmann/financial-audit-env",
    }


@app.get("/health")
async def health_check():
    """Liveness probe for the HF Space deployment, plus task and session counts."""
    payload = dict(
        status="healthy",
        environment="financial_audit_env",
        version="2.0.0",
        tasks_available=len(TASKS),
        active_sessions=len(_sessions),
    )
    return payload


@app.post("/reset")
async def reset_endpoint(request: ResetRequest = ResetRequest()):
    """
    Reset the environment for a new episode.

    Args (JSON body):
        task_id: "expense_audit" | "invoice_match" | "gst_reconciliation" | "fraud_detection"
        seed: Random seed for reproducibility (default: 42)
        episode_id: Optional custom episode ID
        session_id: Optional session ID for multi-tenancy
        investigation_mode: If true, start in drill-down mode

    Returns:
        ``observation`` (dumped model), ``done`` and ``reward``; ``session_id``
        is echoed back when one was supplied.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ValueError (presumably an
            unknown task_id — confirm in environment.py).
    """
    env = _get_env(request.session_id)
    try:
        obs = env.reset(
            seed=request.seed,
            episode_id=request.episode_id,
            task_id=request.task_id,
            investigation_mode=request.investigation_mode or False,
        )
    except ValueError as e:
        # Chain the cause so the original error is kept as __cause__ for debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Track metrics only after a successful reset.
    _metrics["total_resets"] += 1
    _metrics["task_reset_counts"][request.task_id or "expense_audit"] += 1

    response = {
        "observation": obs.model_dump(),
        "done": obs.done,
        "reward": obs.reward,
    }
    if request.session_id:
        response["session_id"] = request.session_id
    return response


@app.post("/step")
async def step_endpoint(request: StepRequest):
    """
    Execute one step in the environment.

    Submit audit findings and receive feedback + reward.
    Set submit_final=True to end the episode and get final grading.

    Returns:
        ``observation`` (dumped model), ``done`` and ``reward``.

    Raises:
        HTTPException: 400 when ``env.step`` raises RuntimeError or ValueError
            (presumably e.g. stepping before /reset or a malformed action).
    """
    env = _get_env(request.session_id)
    try:
        obs = env.step(request.action)
    except (RuntimeError, ValueError) as e:
        # Chain the cause so the original error is kept as __cause__ for debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e

    _metrics["total_steps"] += 1
    if obs.done:
        _metrics["total_episodes_completed"] += 1
    return {
        "observation": obs.model_dump(),
        "done": obs.done,
        "reward": obs.reward,
    }


@app.get("/state")
async def state_endpoint(session_id: Optional[str] = None):
    """Return the current episode state (step count, found errors, etc.) as a dict.

    Args:
        session_id: Optional session to inspect; omitted means the global environment.
    """
    episode_state = _get_env(session_id).state
    return episode_state.model_dump()


# ---------------------------------------------------------------------------
# Session management endpoints
# ---------------------------------------------------------------------------

@app.post("/session")
async def create_session():
    """Allocate an isolated environment and return its session_id and idle TTL."""
    new_id = _create_session()
    ttl = _SESSION_TTL
    return {
        "session_id": new_id,
        "ttl_seconds": ttl,
        "message": "Session created. Include session_id in /reset and /step requests.",
    }


# ---------------------------------------------------------------------------
# Contest-required custom endpoints
# ---------------------------------------------------------------------------

@app.get("/tasks")
async def get_tasks():
    """Enumerate every task with its description and action schema, plus a count."""
    summaries = get_all_tasks_summary()
    task_count = len(TASKS)
    return {"tasks": summaries, "total_tasks": task_count}


@app.get("/grader")
async def get_grader_score(session_id: Optional[str] = None):
    """
    Get the grader score for the last completed episode.

    Includes F1, weighted F1, confusion matrix, and risk scoring. Score-like
    fields are clamped into the closed interval [0.01, 0.99]; count fields,
    the confusion matrix and the risk score are passed through unchanged.

    Args:
        session_id: Optional session; omitted means the shared global environment.

    Returns:
        A ``no_completed_episode`` status dict when the environment has no
        grader result yet, otherwise the ``completed`` score payload.
    """
    env = _get_env(session_id)
    result = env.last_grader_result
    if result is None:
        return {
            "status": "no_completed_episode",
            "message": "No episode completed. Call /reset then /step with submit_final=True.",
        }

    def final_clamp(val: float) -> float:
        """Clamp *val* into [0.01, 0.99]; NaN maps to the 0.01 floor.

        Without the NaN check, ``min(0.99, nan)`` returns 0.99, so an
        undefined score would be reported as near-perfect.
        """
        if math.isnan(val):
            return 0.01
        return max(0.01, min(0.99, val))

    return {
        "status": "completed",
        # NOTE(review): this is the *current* episode's task; if env.reset does
        # not clear last_grader_result, an older result could be mislabelled.
        "task_id": env.state.task_id,
        # Primary score fields with a final stability clamp.
        "score": final_clamp(result["score"]),
        "precision": final_clamp(result["precision"]),
        "recall": final_clamp(result["recall"]),
        "true_positives": result["true_positives"],
        "false_positives": result["false_positives"],
        "false_negatives": result["false_negatives"],
        "total_errors": result["total_errors"],
        # Enhanced scoring fields with the same stability clamp.
        "weighted_score": final_clamp(result.get("weighted_score", result["score"])),
        "partial_credit_score": final_clamp(result.get("partial_credit_score", result["score"])),
        "partial_matches": result.get("partial_matches", 0),
        # Confusion matrix
        "confusion_matrix": result.get("confusion_matrix", {}),
        # Risk scoring
        "risk_score": result.get("risk_score", {}),
    }


# ---------------------------------------------------------------------------
# Leaderboard endpoint
# ---------------------------------------------------------------------------

# Maximum number of entries retained in _leaderboard (lowest scores are dropped).
_LEADERBOARD_MAX_ENTRIES = 100


@app.post("/leaderboard")
async def submit_to_leaderboard(
    model: str,
    task_id: str,
    score: float,
    weighted_score: float = 0.01,
    risk_mitigation_pct: float = 0.01,
):
    """Submit a score to the leaderboard.

    The leaderboard is kept sorted by ``score`` (descending, ties keep
    submission order) and truncated to the best ``_LEADERBOARD_MAX_ENTRIES``.

    Returns:
        ``{"status": "submitted", "rank": r}`` where ``r`` is the 1-based
        position of this submission, or -1 if it did not make the cut.

    Raises:
        HTTPException: 422 if any numeric field is NaN or infinite.
    """
    # Reject NaN/inf: Starlette's JSONResponse serializes with allow_nan=False,
    # so a stored non-finite value would make every later GET /leaderboard fail
    # with a 500; NaN would also break the score ordering below.
    for field_name, value in (
        ("score", score),
        ("weighted_score", weighted_score),
        ("risk_mitigation_pct", risk_mitigation_pct),
    ):
        if not math.isfinite(value):
            raise HTTPException(
                status_code=422, detail=f"{field_name} must be a finite number"
            )

    entry = {
        "model": model,
        "task_id": task_id,
        "score": score,
        "weighted_score": weighted_score,
        "risk_mitigation_pct": risk_mitigation_pct,
        "timestamp": time.time(),
    }
    _leaderboard.append(entry)
    _leaderboard.sort(key=lambda item: item["score"], reverse=True)
    del _leaderboard[_LEADERBOARD_MAX_ENTRIES:]
    # Identity, not equality: an identical earlier submission must not be
    # mistaken for this one.
    rank = next(
        (pos for pos, item in enumerate(_leaderboard, start=1) if item is entry), -1
    )
    return {"status": "submitted", "rank": rank}


@app.get("/leaderboard")
async def get_leaderboard(task_id: Optional[str] = None, limit: int = 20):
    """Return leaderboard submissions, highest ``score`` first.

    This is every stored submission, not one best entry per model. The
    ordering comes from POST /leaderboard, which sorts on each insert.

    Args:
        task_id: If given (non-empty), only entries for this task are returned.
        limit: Maximum number of entries to return; must be >= 0.

    Returns:
        ``leaderboard`` (at most *limit* entries) and ``total_entries`` (the
        count after the task filter, before applying *limit*).

    Raises:
        HTTPException: 422 if *limit* is negative (a negative slice bound would
            silently drop entries from the end instead).
    """
    if limit < 0:
        raise HTTPException(status_code=422, detail="limit must be >= 0")
    entries = _leaderboard
    if task_id:
        entries = [e for e in entries if e["task_id"] == task_id]
    return {
        "leaderboard": entries[:limit],
        "total_entries": len(entries),
    }


# ---------------------------------------------------------------------------
# Metrics endpoint
# ---------------------------------------------------------------------------

@app.get("/metrics")
async def get_metrics():
    """Report in-process usage counters, uptime and the number of live sessions."""
    elapsed = time.time() - _metrics["start_time"]
    counters = _metrics
    return {
        "uptime_seconds": round(elapsed, 0),
        "total_resets": counters["total_resets"],
        "total_steps": counters["total_steps"],
        "total_episodes_completed": counters["total_episodes_completed"],
        "task_usage": dict(counters["task_reset_counts"]),
        "active_sessions": len(_sessions),
    }


# ---------------------------------------------------------------------------
# Adaptive difficulty endpoint
# ---------------------------------------------------------------------------

@app.get("/adaptive-difficulty")
async def get_adaptive_difficulty(session_id: Optional[str] = None):
    """Return the environment's difficulty recommendations derived from score history.

    Args:
        session_id: Optional session; omitted means the shared global environment.
    """
    return _get_env(session_id).get_adaptive_difficulty()


# ---------------------------------------------------------------------------
# Baseline endpoint
# ---------------------------------------------------------------------------

@app.post("/baseline")
async def run_baseline():
    """
    Run the baseline agent via ``run_baseline_all_tasks`` and return its scores.

    Which tasks are covered is decided by ``run_baseline_all_tasks``; this app
    currently defines ``len(TASKS)`` tasks. Requires the HF_TOKEN environment
    variable.

    Returns:
        ``BaselineResponse`` on success; otherwise a JSON error with status 501
        (baseline module not importable), 400 (HF_TOKEN missing) or 500
        (the baseline run raised).
    """
    try:
        from ..baseline import run_baseline_all_tasks
    except ImportError:
        return JSONResponse(
            status_code=501,
            content={
                "status": "error",
                "message": "Baseline not available. Run baseline.py directly.",
            },
        )

    hf_token = os.environ.get("HF_TOKEN", "")
    if not hf_token:
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "message": "HF_TOKEN not set. Get one at https://huggingface.co/settings/tokens",
            },
        )

    try:
        # NOTE(review): this synchronous call runs inside an async handler, so it
        # blocks the event loop for the whole baseline run. It is also given the
        # shared global _env, presumably resetting any episode a session-less
        # client has in progress. Consider a dedicated env + threadpool.
        scores = run_baseline_all_tasks(env=_env, hf_token=hf_token)
        return BaselineResponse(
            scores=scores,
            # NOTE(review): must match the model run_baseline_all_tasks uses — confirm.
            model="meta-llama/Llama-3.1-8B-Instruct",
            status="completed",
        )
    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, not in the response.
        logger.exception("Baseline failed: %s", e)
        return JSONResponse(
            status_code=500,
            content={"status": "error", "message": "Baseline failed. Check logs."},
        )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """Serve ``app`` with uvicorn, listening on 0.0.0.0:8000.

    Invoke as ``python -m financial_audit_env.server.app``.
    """
    # Imported here so uvicorn is only needed when serving directly.
    import uvicorn

    bind = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind)


if __name__ == "__main__":
    main()