File size: 5,531 Bytes
0b6a889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
916c16e
 
 
 
 
 
0b6a889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d728cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b6a889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""FastAPI application for the FinePrint OpenEnv environment.

Exposes the FinePrint policy compliance environment over HTTP so that
remote agents (or a HuggingFace Spaces front-end) can interact with it
via the standard OpenEnv ``/reset``, ``/step``, ``/state`` endpoints.
"""

from __future__ import annotations

from pathlib import Path

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles

# ---------------------------------------------------------------------------
# Dual-import pattern -- works as a sub-package *and* when run directly.
# ---------------------------------------------------------------------------

try:
    from .fineprint_environment import FinePrintEnvironment
except ImportError:
    from fineprint_environment import FinePrintEnvironment  # type: ignore[no-redef]

try:
    from ..models import Action, Observation, State
except ImportError:
    from models import Action, Observation, State  # type: ignore[no-redef]

try:
    from .tasks import TASK_IDS
except ImportError:
    from tasks import TASK_IDS  # type: ignore[no-redef]

# ---------------------------------------------------------------------------
# Application setup
# ---------------------------------------------------------------------------

app = FastAPI(
    title="FinePrint-Env",
    version="0.1.0",
    description="Consumer Policy Drift Detection Environment for OpenEnv",
)

# Permissive CORS: the API is meant to be called from browser front-ends
# served on other origins (e.g. a HuggingFace Space), so every origin,
# method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# any site send credentialed requests; nothing in this module reads cookies
# or auth headers, so allow_credentials=False may suffice — confirm with the
# front-end before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# ---------------------------------------------------------------------------
# Session management
# ---------------------------------------------------------------------------

environments: dict[str, FinePrintEnvironment] = {}


def get_env(session_id: str = "default") -> FinePrintEnvironment:
    """Fetch the environment bound to *session_id*, creating it on first use.

    Environments live in the module-level ``environments`` mapping. Nothing in
    this module removes entries, so each distinct ``session_id`` keeps its
    environment alive for as long as the process runs.
    """
    # NOTE(review): session ids arrive straight from request bodies / query
    # strings, so this store is unbounded — consider capping or expiring it.
    try:
        return environments[session_id]
    except KeyError:
        fresh = FinePrintEnvironment()
        environments[session_id] = fresh
        return fresh


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

_STATIC_DIR = Path(__file__).resolve().parent / "static"


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the landing page stored at ``static/index.html``."""
    index_page = _STATIC_DIR / "index.html"
    return FileResponse(index_page, media_type="text/html")


@app.get("/blog", response_class=HTMLResponse)
async def blog():
    """Serve the blog page stored at ``static/blog.html``."""
    blog_page = _STATIC_DIR / "blog.html"
    return FileResponse(blog_page, media_type="text/html")


@app.get("/health")
async def health() -> dict:
    """Answer liveness/readiness probes with a constant ``ok`` status."""
    return dict(status="ok")


@app.post("/reset")
async def reset(request: dict = {}) -> dict:  # noqa: B006
    """Start a fresh episode for the caller's session.

    Accepted (all optional) body keys
    ---------------------------------
    session_id : str
        Which session to reset; ``"default"`` when omitted.
    seed : int | None
        Seed forwarded to the environment for reproducibility.
    episode_id : str | None
        Identifier chosen by the caller for the new episode.
    options : dict
        Extra reset options, e.g. ``{"task_id": "quote_accuracy"}``.

    Returns the initial observation serialised with ``model_dump()``.
    """
    # The shared ``{}`` default is only read here, never mutated.
    env = get_env(request.get("session_id", "default"))
    reset_kwargs = {
        "seed": request.get("seed"),
        "episode_id": request.get("episode_id"),
        "options": request.get("options", {}),
    }
    first_obs: Observation = env.reset(**reset_kwargs)
    return first_obs.model_dump()


def _error_observation(message: str) -> dict:
    """Serialise a non-terminal ``Observation`` carrying *message* for the agent."""
    return Observation(output=message, done=False).model_dump()


@app.post("/step")
async def step(request: dict) -> dict:
    """Execute one agent action and return the observation.

    Body fields
    -----------
    session_id : str
        Session identifier (default ``"default"``).
    action : dict
        Must contain ``command`` (str) and optionally ``args`` (dict)
        and ``metadata`` (dict).  If the top-level dict has no ``action``
        wrapper, the body itself (minus ``session_id``) is treated as the
        action for convenience.

    Malformed requests (missing ``command``, a non-object ``action``, or
    fields that ``Action`` rejects) are answered with an error observation
    (``done=False``) rather than an HTTP 500, so the agent can retry.
    """
    session_id: str = request.get("session_id", "default")

    # Accept either {"action": {…}} or a flat {command, args, …} body.
    if "action" in request:
        action_data = request["action"]
    else:
        # session_id only routes the request; in the flat form it must not
        # leak into the Action fields (the wrapped form never includes it).
        action_data = {k: v for k, v in request.items() if k != "session_id"}

    # isinstance guard: {"action": null} would raise TypeError on ``in``, and
    # {"action": "..."} would turn ``in`` into a substring test.
    if not isinstance(action_data, dict) or "command" not in action_data:
        return _error_observation(
            "Error: request must include 'command'. "
            "Send either {\"command\": \"view_policies\", \"args\": {}} "
            "or {\"action\": {\"command\": \"view_policies\", \"args\": {}}}. "
            "Available commands: view_policies, view_workflow, "
            "check_compliance, request_verification, quote_policy, "
            "respond_to_user, take_action, escalate, abort_workflow, "
            "clarify, submit"
        )

    try:
        action = Action(**action_data)
    except (TypeError, ValueError) as exc:
        # Assuming Action is a pydantic model: its ValidationError subclasses
        # ValueError. Report bad fields to the agent instead of a 500.
        return _error_observation(f"Error: invalid action: {exc}")

    env = get_env(session_id)
    obs: Observation = env.step(action)
    return obs.model_dump()


@app.get("/state")
async def get_state(session_id: str = "default") -> dict:
    """Snapshot the session's current episode state (step count, task id, ...).

    An unknown ``session_id`` gets a freshly created environment via
    ``get_env`` before its state is read.
    """
    return get_env(session_id).state.model_dump()


@app.get("/tasks")
async def list_tasks() -> dict:
    """Report every task identifier this server can run."""
    return dict(tasks=TASK_IDS)