"""FastAPI server exposing the FraudShield OpenEnv API.""" from __future__ import annotations import json import logging import os from contextlib import asynccontextmanager from pathlib import Path from typing import Any, Dict from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse, JSONResponse from fraudshield_env import FraudShieldEnvironment, TASK_CONFIG from llm_agent import SnapshotCalibratedFraudDetectionAgent from models import ( ActionTypeEnum, CaseScreenEnum, EpisodeState, FraudCheckAction, FraudCheckObservation, ResetResult, Reward, StepResult, TaskDifficulty, ) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) ROOT_DIR = Path(__file__).resolve().parents[1] DATA_PATH = ROOT_DIR / "data" APP_VERSION = "0.6.0" env = FraudShieldEnvironment(data_path=str(DATA_PATH), seed=42) @asynccontextmanager async def lifespan(_: FastAPI): """Load the frozen snapshot on startup.""" if not env.load_data(): logger.error("FraudShield failed to load bundled data from %s", DATA_PATH) yield app = FastAPI( title="FraudShield", description=( "Simulated fraud-investigation environment for OpenEnv. Agents operate under partial " "observability, reveal evidence with investigation tools, and route cases under limited budgets." ), version=APP_VERSION, docs_url="/docs", openapi_url="/openapi.json", lifespan=lifespan, ) def _explorer_html() -> str: return """ FraudShield Explorer

FraudShield Explorer

FraudShield is a simulated fraud review workflow. You start with a small amount of triage information, choose what evidence to inspect, and then make a final case decision.

1. Pick a task

Easy, medium, and hard tasks reveal different amounts of ambiguity and linked-case complexity.

2. Investigate

Use the workflow actions to reveal customer, merchant, network, payment, or policy evidence.

3. Decide

Add a note when required, then approve, hold, request documents, block, or escalate.

Current Case View

Everything below is what the model or analyst can currently see. Hidden evidence appears only after the matching investigation step.

Visible Hints

Revealed Evidence

{}

Current State Snapshot

No active episode yet.

Baseline Walkthrough

A plain reference flow that shows how the current rule-based baseline handles the same task.
""" def _ensure_data_loaded() -> None: if not env.data_loaded and not env.load_data(): raise RuntimeError(f"FraudShield failed to load data from {DATA_PATH}") def _task_payload() -> Dict[str, Any]: return { task.value: { "difficulty": task.value, "description": TASK_CONFIG[task]["description"], "num_cases": TASK_CONFIG[task]["num_cases"], "max_steps": TASK_CONFIG[task]["max_steps"], "sla_limit": TASK_CONFIG[task]["sla_limit"], "investigation_budget": TASK_CONFIG[task]["investigation_budget"], } for task in TaskDifficulty } def _workflow_views() -> list[str]: return [screen.value for screen in CaseScreenEnum] def _metadata_payload() -> Dict[str, Any]: _ensure_data_loaded() return { "name": "fraudshield", "title": "FraudShield", "version": APP_VERSION, "description": app.description, "transport": { "rest": { "health": "/health", "reset": "/reset", "step": "/step", "state": "/state", "info": "/info", "tasks": "/tasks", "metadata": "/metadata", "schema": "/schema", }, "mcp": "/mcp", "openapi": "/openapi.json", }, "action_families": [action.value for action in ActionTypeEnum], "workflow_views": _workflow_views(), "tasks": _task_payload(), "data_snapshot": env.data_loader.get_bundle_summary(), } def _schema_payload() -> Dict[str, Any]: return { "name": "fraudshield", "version": APP_VERSION, "action": FraudCheckAction.model_json_schema(), "observation": FraudCheckObservation.model_json_schema(), "reward": Reward.model_json_schema(), "state": EpisodeState.model_json_schema(), "reset_result": ResetResult.model_json_schema(), "step_result": StepResult.model_json_schema(), "tasks": _task_payload(), } def _demo_trace_payload(task: TaskDifficulty) -> Dict[str, Any]: demo_env = FraudShieldEnvironment(data_path=str(DATA_PATH), seed=42) demo_env.load_data() reset_result = demo_env.reset(task.value) agent = SnapshotCalibratedFraudDetectionAgent() observation = reset_result.observation action_trace: list[Dict[str, Any]] = [] max_steps = TASK_CONFIG[task]["max_steps"] while not demo_env.is_done and demo_env.step_count < max_steps: action = agent.decide(observation) result = demo_env.step(action) action_trace.append( { "step": demo_env.step_count, "action": action.model_dump(mode="json"), "reward": result.reward.model_dump(mode="json"), "done": result.done, } ) observation = result.observation return { "task": task.value, "agent_name": agent.name, "initial_observation": reset_result.observation.model_dump(mode="json"), "action_trace": action_trace, "episode_report": demo_env.get_episode_report(), } def _mcp_success(request_id: Any, result: Dict[str, Any]) -> JSONResponse: return JSONResponse({"jsonrpc": "2.0", "id": request_id, "result": result}) def _mcp_error(request_id: Any, code: int, message: str) -> JSONResponse: return JSONResponse({"jsonrpc": "2.0", "id": request_id, "error": {"code": code, "message": message}}) def _mcp_tool_result(payload: Dict[str, Any]) -> Dict[str, Any]: return { "content": [{"type": "text", "text": json.dumps(payload, ensure_ascii=True)}], "structuredContent": payload, "isError": False, } def _mcp_tool_descriptors() -> list[Dict[str, Any]]: task_values = [task.value for task in TaskDifficulty] return [ { "name": "environment.reset", "description": "Start a new easy, medium, or hard FraudShield episode.", "inputSchema": { "type": "object", "properties": { "task": {"type": "string", "enum": task_values, "default": TaskDifficulty.EASY.value} }, }, }, { "name": "environment.step", "description": "Submit one investigation or resolution action for the active case.", "inputSchema": FraudCheckAction.model_json_schema(), }, { "name": "environment.state", "description": "Read the full current episode state.", "inputSchema": {"type": "object", "properties": {}}, }, { "name": "environment.info", "description": "Read static environment information and dataset metadata.", "inputSchema": {"type": "object", "properties": {}}, }, { "name": "environment.tasks", "description": "List the available graded tasks.", "inputSchema": {"type": "object", "properties": {}}, }, { "name": "environment.metadata", "description": "Read runtime metadata for OpenEnv clients.", "inputSchema": {"type": "object", "properties": {}}, }, { "name": "environment.schema", "description": "Read the JSON schema for the typed models.", "inputSchema": {"type": "object", "properties": {}}, }, ] def _run_mcp_tool(name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: _ensure_data_loaded() if name == "environment.reset": task = arguments.get("task", TaskDifficulty.EASY.value) result = env.reset(str(task)) return {"observation": result.observation.model_dump(mode="json"), "info": result.info} if name == "environment.step": action = FraudCheckAction.model_validate(arguments) result = env.step(action) return { "observation": result.observation.model_dump(mode="json"), "reward": result.reward.model_dump(mode="json"), "done": result.done, "info": result.info, } if name == "environment.state": return env.state().model_dump(mode="json") if name == "environment.info": return { "name": "fraudshield", "version": APP_VERSION, "tasks": _task_payload(), "workflow_views": _workflow_views(), "data_snapshot": env.data_loader.get_bundle_summary(), } if name == "environment.tasks": return _task_payload() if name == "environment.metadata": return _metadata_payload() if name == "environment.schema": return _schema_payload() raise ValueError(f"Unknown MCP tool: {name}") @app.get("/", response_class=HTMLResponse) async def explorer() -> HTMLResponse: return HTMLResponse(_explorer_html()) @app.get("/health") async def health_check() -> Dict[str, Any]: if not env.data_loaded: env.load_data() return { "status": "healthy" if env.data_loaded else "degraded", "service": "fraudshield", "data_loaded": env.data_loaded, "workflow_views": _workflow_views(), } @app.post("/reset") async def reset(task: TaskDifficulty = TaskDifficulty.EASY) -> Dict[str, Any]: try: _ensure_data_loaded() result = env.reset(task.value) return {"observation": result.observation.model_dump(mode="json"), "info": result.info} except Exception as exc: logger.exception("Reset error") raise HTTPException(status_code=500, detail=str(exc)) from exc @app.post("/step") async def step(action: FraudCheckAction) -> Dict[str, Any]: try: _ensure_data_loaded() result = env.step(action) return { "observation": result.observation.model_dump(mode="json"), "reward": result.reward.model_dump(mode="json"), "done": result.done, "info": result.info, } except Exception as exc: logger.exception("Step error") raise HTTPException(status_code=500, detail=str(exc)) from exc @app.get("/state") async def get_state() -> Dict[str, Any]: try: _ensure_data_loaded() return env.state().model_dump(mode="json") except Exception as exc: logger.exception("State error") raise HTTPException(status_code=500, detail=str(exc)) from exc @app.get("/info") async def get_info() -> Dict[str, Any]: _ensure_data_loaded() return { "name": "fraudshield", "version": APP_VERSION, "description": app.description, "tasks": _task_payload(), "workflow_views": _workflow_views(), "data_snapshot": env.data_loader.get_bundle_summary(), } @app.get("/tasks") async def get_tasks() -> Dict[str, Any]: _ensure_data_loaded() return _task_payload() @app.get("/metadata") async def get_metadata() -> Dict[str, Any]: return _metadata_payload() @app.get("/schema") async def get_schema() -> Dict[str, Any]: _ensure_data_loaded() return _schema_payload() @app.get("/demo/trace") async def demo_trace(task: TaskDifficulty = TaskDifficulty.MEDIUM) -> Dict[str, Any]: try: return _demo_trace_payload(task) except Exception as exc: logger.exception("Demo trace error") raise HTTPException(status_code=500, detail=str(exc)) from exc @app.post("/mcp") async def mcp_endpoint(request: Dict[str, Any]) -> JSONResponse: request_id = request.get("id") method = request.get("method") params = request.get("params", {}) or {} try: if method == "initialize": return _mcp_success( request_id, { "protocolVersion": "2025-03-26", "capabilities": {"tools": {}, "prompts": {}, "resources": {}}, "serverInfo": {"name": "fraudshield", "version": APP_VERSION}, }, ) if method in {"notifications/initialized", "initialized", "ping"}: return _mcp_success(request_id, {}) if method == "tools/list": return _mcp_success(request_id, {"tools": _mcp_tool_descriptors()}) if method == "tools/call": tool_name = params.get("name") if not tool_name: return _mcp_error(request_id, -32602, "tools/call requires a tool name") arguments = params.get("arguments", {}) or {} return _mcp_success(request_id, _mcp_tool_result(_run_mcp_tool(tool_name, arguments))) if method == "resources/list": return _mcp_success(request_id, {"resources": []}) if method == "prompts/list": return _mcp_success(request_id, {"prompts": []}) return _mcp_error(request_id, -32601, f"Method not found: {method}") except Exception as exc: logger.exception("MCP error") return _mcp_error(request_id, -32000, str(exc)) @app.exception_handler(Exception) async def global_exception_handler(_: Any, exc: Exception) -> JSONResponse: logger.exception("Unhandled exception") return JSONResponse(status_code=500, content={"detail": str(exc)}) def main() -> None: import uvicorn port = int(os.getenv("PORT", "7860")) logger.info("Launching FraudShield server on port %d", port) uvicorn.run(app, host="0.0.0.0", port=port, workers=1) if __name__ == "__main__": # pragma: no cover main()