File size: 9,721 Bytes
2930dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
SupportEnv — FastAPI server

Endpoints
---------
POST   /reset          Create a new episode
POST   /step           Advance the episode
GET    /state          Current episode state
GET    /tasks          List tasks and action schema
POST   /grader         Grade a finished episode
POST   /baseline       Run the built-in baseline agent on all tasks
GET    /health         Liveness check
GET    /               Info / spec link
"""
from __future__ import annotations

import os
import subprocess
import sys
import tempfile
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import environment as env
from data import TASK_META
from models import (
    Action,
    BaselineResult,
    GraderResponse,
    Observation,
    State,
    StepResult,
    TaskInfo,
)

# ASGI application object; every route decorator in this module registers
# against it. Interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="SupportEnv",
    version="1.0.0",
    description=(
        "An OpenEnv-compliant customer-support triage environment. "
        "Agents learn to classify, extract information from, "
        "and resolve real-world SaaS support tickets."
    ),
    redoc_url="/redoc",
    docs_url="/docs",
)

# CORS is configured with wildcards for origins, methods and headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


# ---------------------------------------------------------------------------
# Request / response shapes for endpoints not covered by models.py
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    # JSON body accepted by POST /reset.
    # Identifier of the task to start; forwarded unchanged to env.reset.
    task_id: str
    # Optional ticket selector; stays None when the client omits it and is
    # forwarded to env.reset as-is.
    ticket_index: Optional[int] = None


class StepRequest(BaseModel):
    # JSON body accepted by POST /step.
    # Identifier of the episode to advance; forwarded unchanged to env.step.
    episode_id: str
    # Action to apply, validated against the Action model from models.py.
    action: Action


class GraderRequest(BaseModel):
    # JSON body accepted by POST /grader.
    # Identifier of the episode to grade; forwarded unchanged to env.grade.
    episode_id: str


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/", tags=["meta"])
def root():
    # Service descriptor: name, version, spec link, the task ids known to
    # TASK_META and a human-readable index of the available endpoints.
    # Documented with comments (not a docstring) so the generated OpenAPI
    # description of this route stays as it was.
    endpoint_index = {
        "reset": "POST /reset",
        "step": "POST /step",
        "state": "GET /state?episode_id=...",
        "tasks": "GET /tasks",
        "grader": "POST /grader",
        "baseline": "POST /baseline",
        "health": "GET /health",
        "docs": "GET /docs",
    }
    task_ids = list(TASK_META)
    return {
        "name": "SupportEnv",
        "version": "1.0.0",
        "description": "OpenEnv customer-support ticket triage environment",
        "openenv_spec": "https://github.com/openenv/openenv",
        "tasks": task_ids,
        "endpoints": endpoint_index,
    }


@app.get("/health", tags=["meta"])
def health():
    # Liveness probe: always answers with a constant payload and touches no
    # environment state.
    payload = dict(status="ok")
    return payload


# ---------------------------------------------------------------------------
# Core OpenEnv endpoints
# ---------------------------------------------------------------------------

@app.post("/reset", response_model=Observation, tags=["openenv"])
def reset(request: ResetRequest) -> Observation:
    """
    Start a new episode.

    - **task_id**: `task1` | `task2` | `task3`
    - **ticket_index**: 0-indexed ticket to use (optional; default 0)
    """
    # NOTE(review): the docstring promises a default of 0, but an omitted
    # ticket_index reaches env.reset as None — confirm env.reset maps None
    # to ticket 0.
    try:
        return env.reset(request.task_id, request.ticket_index)
    except ValueError as err:
        # A ValueError from the environment is reported as a client error.
        # Chaining with `from` keeps the original error as the explicit
        # cause in server-side tracebacks.
        raise HTTPException(status_code=400, detail=str(err)) from err


@app.post("/step", response_model=StepResult, tags=["openenv"])
def step(request: StepRequest) -> StepResult:
    """
    Submit an action and advance the episode.

    The `action` object must include `action_type` and the fields relevant
    to that action type (see GET /tasks for the schema).
    """
    try:
        return env.step(request.episode_id, request.action)
    except KeyError as err:
        # NOTE(review): any KeyError escaping env.step is reported as a
        # missing episode — confirm env.step raises KeyError only for an
        # unknown episode_id.
        raise HTTPException(
            status_code=404,
            detail=f"Episode '{request.episode_id}' not found. Call POST /reset first.",
        ) from err
    except ValueError as err:
        # A ValueError from the environment is reported as a client error;
        # the original exception is kept as the explicit cause.
        raise HTTPException(status_code=400, detail=str(err)) from err


@app.get("/state", response_model=State, tags=["openenv"])
def state(episode_id: str = Query(..., description="Episode UUID from POST /reset")) -> State:
    """Return the current state of an episode."""
    try:
        return env.state(episode_id)
    except KeyError as err:
        # A KeyError from env.state is reported as an unknown episode (404);
        # the original exception is kept as the explicit cause for
        # server-side tracebacks.
        raise HTTPException(
            status_code=404,
            detail=f"Episode '{episode_id}' not found.",
        ) from err


# ---------------------------------------------------------------------------
# /tasks — task listing + action schema
# ---------------------------------------------------------------------------

# JSON Schema fragments for the Action payload (subset used in each task).
# Every task shares the mandatory ``action_type`` property and layers its own
# task-specific properties and description on top of it.
_BASE_ACTION_SCHEMA = {
    "type": "object",
    "required": ["action_type"],
    "properties": {
        "action_type": {
            "type": "string",
            "description": "One of the available_actions listed in the Observation",
        },
    },
}


def _task_action_schema(
    description: str, extra_properties: Dict[str, Any]
) -> Dict[str, Any]:
    """Return the base schema extended with *description* and *extra_properties*.

    The result is a shallow copy of ``_BASE_ACTION_SCHEMA``: the ``required``
    list and the ``action_type`` property dict are shared with the base.
    """
    merged_properties = dict(_BASE_ACTION_SCHEMA["properties"])
    merged_properties.update(extra_properties)

    schema = dict(_BASE_ACTION_SCHEMA)
    schema["description"] = description
    # Overwrites the copied base entry in place, so key order stays
    # type / required / properties / description.
    schema["properties"] = merged_properties
    return schema


_ACTION_SCHEMAS: Dict[str, Dict[str, Any]] = {
    "task1": _task_action_schema(
        "classify action: set category + priority; then submit",
        {
            "category": {
                "type": "string",
                "enum": [
                    "billing", "technical", "account",
                    "feature_request", "complaint", "general",
                ],
            },
            "priority": {
                "type": "string",
                "enum": ["low", "medium", "high", "critical"],
            },
        },
    ),
    "task2": _task_action_schema(
        "extract action: populate extracted_entities + required_actions; then submit",
        {
            "extracted_entities": {
                "type": "object",
                "additionalProperties": True,
                "description": "Key-value pairs extracted from the ticket text",
            },
            "required_actions": {
                "type": "array",
                "items": {"type": "string"},
                "description": "List of action identifiers (snake_case) needed to close the ticket",
            },
        },
    ),
    "task3": _task_action_schema(
        (
            "respond or resolve action: write response_text + resolution_steps; "
            "optionally escalate; then submit"
        ),
        {
            "response_text": {
                "type": "string",
                "description": "Full professional response to send to the customer",
            },
            "resolution_steps": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Ordered steps for support staff to resolve the ticket",
            },
            "escalation_team": {
                "type": "string",
                "enum": ["billing_team", "engineering", "account_management", "legal"],
            },
            "escalation_reason": {"type": "string"},
        },
    ),
}


@app.get("/tasks", response_model=List[TaskInfo], tags=["openenv"])
def list_tasks() -> List[TaskInfo]:
    """Return metadata and action schema for all tasks."""
    # One TaskInfo per TASK_META entry, in TASK_META iteration order; the
    # matching action schema is looked up by the same task id.
    return [
        TaskInfo(
            task_id=tid,
            name=info["name"],
            description=info["description"],
            difficulty=info["difficulty"],
            max_steps=info["max_steps"],
            action_schema=_ACTION_SCHEMAS[tid],
        )
        for tid, info in TASK_META.items()
    ]


# ---------------------------------------------------------------------------
# /grader — grade a finished episode
# ---------------------------------------------------------------------------

@app.post("/grader", response_model=GraderResponse, tags=["openenv"])
def grader(request: GraderRequest) -> GraderResponse:
    """
    Grade a finished episode.

    The episode must have reached `done=True` (either via a `submit` action
    or by exhausting `max_steps`).
    """
    try:
        return env.grade(request.episode_id)
    except KeyError as err:
        # A KeyError from env.grade is reported as an unknown episode (404);
        # the original exception is kept as the explicit cause.
        raise HTTPException(
            status_code=404,
            detail=f"Episode '{request.episode_id}' not found.",
        ) from err
    except ValueError as err:
        # A ValueError from the environment is reported as a client error,
        # with the environment's message as the response detail.
        raise HTTPException(status_code=400, detail=str(err)) from err


# ---------------------------------------------------------------------------
# /baseline — run the built-in baseline agent
# ---------------------------------------------------------------------------

class BaselineRequest(BaseModel):
    # JSON body accepted by POST /baseline.
    # Model name supplied by the client. NOTE(review): the /baseline handler
    # in this file never reads this field and always reports
    # "heuristic-baseline" — confirm whether it is kept only for
    # compatibility.
    model: str = "gpt-4o-mini"
    # Ticket index handed to the heuristic baseline; the handler falls back
    # to 0 when this is None.
    ticket_index: Optional[int] = 0


@app.post("/baseline", response_model=BaselineResult, tags=["openenv"])
def run_baseline(request: BaselineRequest) -> BaselineResult:
    """
    Run the heuristic baseline agent against all three tasks.

    The built-in baseline does NOT require an OpenAI key — it uses the
    deterministic heuristic baseline from `baseline.py`.
    If you want to run the LLM baseline, call `baseline.py` directly.
    """
    # `request.model` is intentionally not consulted here: the response always
    # reports "heuristic-baseline".
    try:
        # Imported lazily so that a failure to import baseline.py is reported
        # as a 500 from this endpoint rather than at module import time.
        from baseline import run_heuristic_baseline

        ticket_index = request.ticket_index if request.ticket_index is not None else 0
        scores = run_heuristic_baseline(ticket_index=ticket_index)

        # An empty result would otherwise surface as an opaque
        # "division by zero" detail when computing the mean below.
        if not scores:
            raise RuntimeError("Baseline agent returned no scores.")

        avg = round(sum(s["score"] for s in scores) / len(scores), 4)
        return BaselineResult(
            model="heuristic-baseline",
            scores=[
                {"task_id": s["task_id"], "score": s["score"], "details": s}
                for s in scores
            ],
            average_score=avg,
        )
    except Exception as exc:
        # Endpoint boundary: any failure becomes a 500 carrying the error
        # text, with the original exception kept as the explicit cause.
        raise HTTPException(status_code=500, detail=str(exc)) from exc