Spaces:
Sleeping
Sleeping
File size: 5,531 Bytes
"""FastAPI application for the FinePrint OpenEnv environment.

Exposes the FinePrint policy compliance environment over HTTP so that
remote agents (or a HuggingFace Spaces front-end) can interact with it
via the standard OpenEnv ``/reset``, ``/step``, ``/state`` endpoints.
It also serves ``/health``, ``/tasks`` and the static landing/blog pages.
"""
from __future__ import annotations
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
# ---------------------------------------------------------------------------
# Dual-import pattern -- works as a sub-package *and* when run directly.
# ---------------------------------------------------------------------------
try:
from .fineprint_environment import FinePrintEnvironment
except ImportError:
from fineprint_environment import FinePrintEnvironment # type: ignore[no-redef]
try:
from ..models import Action, Observation, State
except ImportError:
from models import Action, Observation, State # type: ignore[no-redef]
try:
from .tasks import TASK_IDS
except ImportError:
from tasks import TASK_IDS # type: ignore[no-redef]
# ---------------------------------------------------------------------------
# Application setup
# ---------------------------------------------------------------------------
app = FastAPI(
    title="FinePrint-Env",
    version="0.1.0",
    description="Consumer Policy Drift Detection Environment for OpenEnv",
)

# CORS is left wide open: any origin, method and header is accepted so that
# browser front-ends (HuggingFace Spaces included) can reach the API.
# NOTE(review): the FastAPI docs state that allow_credentials=True should not
# be combined with "*" origins; Starlette handles it by mirroring the caller's
# Origin, which lets any site send credentialed requests. Nothing in this file
# reads cookies or auth headers -- consider allow_credentials=False if that
# remains true (verify no front-end sends credentials first).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# ---------------------------------------------------------------------------
# Session management
# ---------------------------------------------------------------------------
# Registry of live environments, keyed by the caller-supplied session id.
environments: dict[str, FinePrintEnvironment] = {}


def get_env(session_id: str = "default") -> FinePrintEnvironment:
    """Look up the environment bound to *session_id*, creating it lazily.

    A new ``FinePrintEnvironment`` is constructed and stored only the first
    time an id is seen; later calls return that same instance.

    NOTE(review): session ids come straight from request bodies/queries and
    nothing in this section removes entries, so the registry can grow without
    bound -- confirm whether eviction is handled elsewhere.
    """
    try:
        return environments[session_id]
    except KeyError:
        fresh = FinePrintEnvironment()
        environments[session_id] = fresh
        return fresh
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
# Absolute path of the "static" directory that sits next to this module.
_STATIC_DIR = Path(__file__).resolve().parent.joinpath("static")
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the landing page stored at ``static/index.html``.

    The handler returns a ``FileResponse`` itself, so ``response_class`` only
    shapes the generated OpenAPI description.
    """
    landing_page = _STATIC_DIR / "index.html"
    return FileResponse(path=landing_page, media_type="text/html")
@app.get("/blog", response_class=HTMLResponse)
async def blog():
    """Serve the blog page stored at ``static/blog.html``.

    The handler returns a ``FileResponse`` itself, so ``response_class`` only
    shapes the generated OpenAPI description.
    """
    blog_page = _STATIC_DIR / "blog.html"
    return FileResponse(path=blog_page, media_type="text/html")
@app.get("/health")
async def health() -> dict:
    """Liveness/readiness probe: answers with a fixed ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
@app.post("/reset")
async def reset(request: dict = {}) -> dict:  # noqa: B006
    """Start a new episode on the caller's environment and return its first observation.

    The body is optional; every field has a fallback.

    Optional body fields
    --------------------
    session_id : str
        Selects which environment to reset (``"default"`` when omitted).
    seed : int | None
        RNG seed forwarded to ``env.reset`` for reproducibility.
    episode_id : str | None
        Caller-supplied episode identifier forwarded to ``env.reset``.
    options : dict
        Extra reset options, e.g. ``{"task_id": "quote_accuracy"}``
        (``{}`` when omitted).
    """
    # The shared default dict is only read here, never mutated, so B006 is moot.
    env = get_env(request.get("session_id", "default"))
    reset_kwargs = {
        "seed": request.get("seed"),
        "episode_id": request.get("episode_id"),
        "options": request.get("options", {}),
    }
    first_observation: Observation = env.reset(**reset_kwargs)
    return first_observation.model_dump()
@app.post("/step")
async def step(request: dict) -> dict:
    """Execute one agent action and return the observation.

    Body fields
    -----------
    session_id : str
        Session identifier (default ``"default"``).
    action : dict
        Must contain ``command`` (str) and optionally ``args`` (dict)
        and ``metadata`` (dict). If the body has no ``action`` wrapper,
        the top-level dict itself is treated as the action for
        convenience. In that case the ``session_id`` routing key is
        stripped before the ``Action`` is built.

    Malformed requests get an error observation (``done=False``) instead
    of an HTTP 500. This covers a missing ``command``, an ``action`` value
    that is not an object, and fields the ``Action`` model rejects.
    """
    session_id: str = request.get("session_id", "default")
    # Accept either {"action": {...}} or a flat {command, args, ...} body.
    if "action" in request:
        action_data = request["action"]
    else:
        # Flat body: session_id is routing info, not part of the action.
        action_data = {
            key: value for key, value in request.items() if key != "session_id"
        }
    # The isinstance check matters: a string "action" would otherwise pass the
    # `in` test as a substring match, and None would raise TypeError.
    if not isinstance(action_data, dict) or "command" not in action_data:
        return Observation(
            output=(
                "Error: request must include 'command'. "
                "Send either {\"command\": \"view_policies\", \"args\": {}} "
                "or {\"action\": {\"command\": \"view_policies\", \"args\": {}}}. "
                "Available commands: view_policies, view_workflow, "
                "check_compliance, request_verification, quote_policy, "
                "respond_to_user, take_action, escalate, abort_workflow, "
                "clarify, submit"
            ),
            done=False,
        ).model_dump()
    try:
        action = Action(**action_data)
    except (TypeError, ValueError) as exc:
        # pydantic's ValidationError subclasses ValueError; report invalid
        # actions through the same error-observation path as above.
        return Observation(
            output=f"Error: invalid action: {exc}",
            done=False,
        ).model_dump()
    env = get_env(session_id)
    obs: Observation = env.step(action)
    return obs.model_dump()
@app.get("/state")
async def get_state(session_id: str = "default") -> dict:
    """Dump the current episode state (step count, task id, etc.) for *session_id*.

    ``get_env`` creates an environment on first use, so an unseen session id
    reports the state of a freshly constructed environment.
    """
    return get_env(session_id).state.model_dump()
@app.get("/tasks")
async def list_tasks() -> dict:
    """Return the task identifiers exported as ``TASK_IDS`` by the tasks module."""
    response = {"tasks": TASK_IDS}
    return response
|