deps-0001's picture
bug fix
8d728cc
"""FastAPI application for the FinePrint OpenEnv environment.
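Exposes the FinePrint policy compliance environment over HTTP so that
remote agents (or a HuggingFace Spaces front-end) can interact with it
via the standard OpenEnv ``/reset``, ``/step`` and ``/state`` endpoints.
It also serves ``/health`` (liveness probe), ``/tasks`` (available task
ids), and the static ``/`` and ``/blog`` HTML pages.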
"""
from __future__ import annotations
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
# ---------------------------------------------------------------------------
# Dual-import pattern -- works as a sub-package *and* when run directly.
# ---------------------------------------------------------------------------
try:
from .fineprint_environment import FinePrintEnvironment
except ImportError:
from fineprint_environment import FinePrintEnvironment # type: ignore[no-redef]
try:
from ..models import Action, Observation, State
except ImportError:
from models import Action, Observation, State # type: ignore[no-redef]
try:
from .tasks import TASK_IDS
except ImportError:
from tasks import TASK_IDS # type: ignore[no-redef]
# ---------------------------------------------------------------------------
# Application setup
# ---------------------------------------------------------------------------
app = FastAPI(
    version="0.1.0",
    title="FinePrint-Env",
    description="Consumer Policy Drift Detection Environment for OpenEnv",
)

# Every origin, method and header is permitted so that HuggingFace Spaces
# (or any other browser front-end) can call this API directly.
# NOTE(review): with allow_credentials=True and a "*" origin, Starlette may
# echo the caller's Origin rather than send "*", accepting credentialed
# cross-site requests from anywhere. Confirm this is intended before any
# cookie- or auth-based endpoint is added.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)
# ---------------------------------------------------------------------------
# Session management
# ---------------------------------------------------------------------------
# Registry of live environments, keyed by the caller-supplied session id.
# NOTE(review): entries are never removed, so every distinct session_id a
# client sends keeps an environment alive for the life of the process.
environments: dict[str, FinePrintEnvironment] = {}


def get_env(session_id: str = "default") -> FinePrintEnvironment:
    """Look up the environment bound to *session_id*, creating it lazily.

    The first request for an unseen id constructs a fresh
    ``FinePrintEnvironment`` and registers it in ``environments``. Later
    requests for the same id get that same instance back.
    """
    try:
        return environments[session_id]
    except KeyError:
        fresh = environments[session_id] = FinePrintEnvironment()
        return fresh
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
# Directory holding the HTML pages, resolved relative to this module.
_STATIC_DIR = Path(__file__).resolve().parent / "static"


def _serve_static(filename: str) -> FileResponse:
    """Build an HTML ``FileResponse`` for *filename* inside ``_STATIC_DIR``."""
    # NOTE(review): a missing file is not checked here; Starlette's
    # FileResponse reports it when the response is sent (likely a 500).
    # Confirm whether a 404 would be preferable.
    return FileResponse(_STATIC_DIR / filename, media_type="text/html")


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the landing page, ``static/index.html``."""
    return _serve_static("index.html")


@app.get("/blog", response_class=HTMLResponse)
async def blog():
    """Serve the blog page, ``static/blog.html``."""
    return _serve_static("blog.html")
@app.get("/health")
async def health() -> dict:
    """Liveness / readiness probe: answers ``{"status": "ok"}`` unconditionally."""
    return dict(status="ok")
@app.post("/reset")
async def reset(request: dict = {}) -> dict:  # noqa: B006
    """Reset the environment and start a new episode.

    Optional body fields
    --------------------
    session_id : str
        Identifies the caller's session (default ``"default"``).
    seed : int | None
        RNG seed for reproducibility.
    episode_id : str | None
        Caller-supplied episode identifier.
    options : dict
        May contain ``task_id`` (e.g. ``"quote_accuracy"``). A missing key
        or an explicit ``null`` is forwarded to the environment as ``{}``.

    Returns
    -------
    dict
        The initial observation of the new episode, via ``model_dump()``.
    """
    # The ``{}`` default is only read, never mutated, so sharing it across
    # calls is harmless (hence the B006 suppression above).
    session_id: str = request.get("session_id", "default")
    env = get_env(session_id)
    obs: Observation = env.reset(
        seed=request.get("seed"),
        episode_id=request.get("episode_id"),
        # ``.get("options", {})`` would pass an explicit JSON null through
        # as None; ``or {}`` normalises both absent and null to a dict.
        options=request.get("options") or {},
    )
    return obs.model_dump()
@app.post("/step")
async def step(request: dict) -> dict:
    """Execute one agent action and return the observation.

    Body fields
    -----------
    session_id : str
        Session identifier (default ``"default"``).
    action : dict
        Must contain ``command`` (str) and optionally ``args`` (dict)
        and ``metadata`` (dict). If the top-level dict already has a
        ``command`` key and no ``action`` wrapper, it is treated as the
        action directly for convenience. In that flat form ``session_id``
        is removed first, since it addresses the session, not the action.

    Returns
    -------
    dict
        The resulting observation, via ``model_dump()``. Malformed actions
        do not raise: they produce an observation whose ``output`` explains
        the problem and whose ``done`` is ``False``.
    """
    session_id: str = request.get("session_id", "default")
    # Accept either {"action": {…}} or a flat {command, args, …} body.
    if "action" in request:
        action_data = request["action"]
    else:
        # In a flat body session_id sits beside the action fields; drop it
        # so it is not forwarded to Action as an unexpected field.
        action_data = {
            key: value for key, value in request.items() if key != "session_id"
        }
    # Guard against a non-object ``action`` (e.g. null or a string): ``in``
    # would raise TypeError on None and do a substring test on a str.
    if not isinstance(action_data, dict) or "command" not in action_data:
        return Observation(
            output=(
                "Error: request must include 'command'. "
                "Send either {\"command\": \"view_policies\", \"args\": {}} "
                "or {\"action\": {\"command\": \"view_policies\", \"args\": {}}}. "
                "Available commands: view_policies, view_workflow, "
                "check_compliance, request_verification, quote_policy, "
                "respond_to_user, take_action, escalate, abort_workflow, "
                "clarify, submit"
            ),
            done=False,
        ).model_dump()
    try:
        action = Action(**action_data)
    except (TypeError, ValueError) as exc:
        # Action is presumably a pydantic model; pydantic's ValidationError
        # subclasses ValueError. Report bad payloads the same way as a
        # missing command instead of letting them escape as an HTTP 500.
        return Observation(
            output=f"Error: invalid action: {exc}",
            done=False,
        ).model_dump()
    env = get_env(session_id)
    obs: Observation = env.step(action)
    return obs.model_dump()
@app.get("/state")
async def get_state(session_id: str = "default") -> dict:
    """Return the episode state of *session_id*'s environment as a dict."""
    # NOTE(review): get_env registers a new environment for an unknown
    # session_id, so this GET can create a session as a side effect.
    current = get_env(session_id).state
    return current.model_dump()
@app.get("/tasks")
async def list_tasks() -> dict:
    """Report the task identifiers exposed by the ``tasks`` module."""
    return dict(tasks=TASK_IDS)