"""
FastAPI application exposing the Customer Support Environment
via HTTP endpoints compatible with OpenEnv specification.
Endpoints:
    POST /reset  — Reset environment, returns initial observation
    POST /step   — Execute an action, returns (obs, reward, done, info)
    GET  /state  — Get current internal state
    GET  /health — Health check
    GET  /tasks  — List available tasks
    GET  /       — Environment info
"""
import sys
import os
# Make the project root (the parent of this file's directory) importable so
# the absolute imports that follow (`models`, `server.environment`, `tasks`)
# resolve no matter which working directory the server is started from.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _project_root not in sys.path:
    # Prepend (rather than append) so the project's modules win over any
    # same-named packages that are already importable.
    sys.path.insert(0, _project_root)
from typing import Any, Dict, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator
from models import SupportAction, SupportObservation, SupportState, safe_score # type: ignore
from server.environment import CustomerSupportEnvironment # type: ignore
from tasks import TASK_IDS, TASKS # type: ignore
# ──────────────────────────────────────────────────────────────────
# Request / Response schemas
# ──────────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    # Request body for POST /reset. Both fields carry defaults, so an empty
    # JSON object is a valid body.

    # Identifier of the task to load; defaults to "easy_faq".
    task_id: Optional[str] = Field("easy_faq", description="Task ID to load")
    # Passed through to env.reset(); the field description marks it as unused.
    seed: Optional[int] = Field(None, description="Random seed (unused)")
class StepRequest(BaseModel):
    # Request body for POST /step.

    # Required field (no default): the action the agent wants to execute.
    action: SupportAction = Field(description="Agent action")
class StepResponse(BaseModel):
    """Response from the /step endpoint.
    Uses an auto-clamping validator instead of gt/lt constraints.
    This prevents Pydantic from raising ValidationError on boundary
    values and ensures the evaluator NEVER receives 0.0 or 1.0.
    """

    # Observation returned by the environment after the action.
    observation: SupportObservation
    # Step reward; always routed through `_clamp_reward` below.
    reward: float = Field(0.01, description="Step reward in strict (0, 1)")
    # Whether the episode has finished.
    done: bool
    # Free-form diagnostics from the environment (e.g. "reward_breakdown").
    info: Dict[str, Any]

    @field_validator("reward", mode="before")
    @classmethod
    def _clamp_reward(cls, value: Any) -> float:
        """Normalise the incoming reward with `safe_score`.

        Registered with ``mode="before"`` so the raw input is passed to
        `safe_score` before Pydantic applies its own float validation.
        """
        return safe_score(value)
class TaskInfo(BaseModel):
    # One entry of the GET /tasks listing.

    # Key of the task in the TASKS registry.
    task_id: str
    # Short display name (list_tasks fills it with the ticket subject).
    name: str
    # Human-readable summary (list_tasks combines difficulty and subject).
    description: str
    # Difficulty label (list_tasks passes the difficulty enum's `.value`).
    difficulty: str
    # Step budget configured for the task.
    max_steps: int
# ──────────────────────────────────────────────────────────────────
# App factory
# ──────────────────────────────────────────────────────────────────
# ASGI application object; the title/description/version below are what the
# generated OpenAPI document and interactive docs display.
app = FastAPI(
    title="Customer Support Environment — OpenEnv",
    description=(
        "AI-Powered Customer Support Ticket Resolution Environment. "
        "Train agents to handle real customer issues using step/reset/state APIs."
    ),
    version="1.0.0",
)

# Fully permissive CORS so browser-based clients on any origin can call the API.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers must be listed explicitly (not ["*"]) when
# allow_credentials=True. Confirm whether credentialed cross-origin requests
# are actually needed; if not, drop allow_credentials, otherwise enumerate
# the trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global environment instance (single-agent mode for simplicity).
# Every endpoint below operates on this one object, so all clients share the
# same episode state.
env = CustomerSupportEnvironment()
# ──────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────
@app.get("/", tags=["info"])
def root():
    """Environment info and available endpoints."""
    # Route summary keyed by "<METHOD> <path>"; kept separate from the
    # envelope so the list of routes is easy to scan and extend.
    endpoint_summary = {
        "POST /reset": "Reset environment with a task_id",
        "POST /step": "Execute an action",
        "GET /state": "Get current state",
        "GET /health": "Health check",
        "GET /tasks": "List available tasks",
    }
    service_info = {
        "name": "customer_support_env",
        "version": "1.0.0",
        "description": "AI-Powered Customer Support Ticket Resolution Environment",
        "endpoints": endpoint_summary,
        # Task identifiers accepted by POST /reset.
        "available_tasks": TASK_IDS,
    }
    return service_info
@app.get("/health", tags=["health"])
def health():
    """Health check endpoint."""
    # Constant payload: it does not consult the environment.
    return dict(status="healthy", environment="customer_support_env")
@app.get("/tasks", response_model=list[TaskInfo], tags=["tasks"])
def list_tasks():
    """List all available tasks with metadata."""
    # One TaskInfo per entry in the TASKS registry, in registry order.
    # Each task record is read as a mapping with a "ticket" sub-mapping
    # (providing "subject"), a "difficulty" enum member whose `.value` is a
    # string, and a "max_steps" integer.
    return [
        TaskInfo(
            task_id=tid,
            name=task["ticket"]["subject"],
            description=f"{task['difficulty'].value.upper()} — {task['ticket']['subject']}",
            difficulty=task["difficulty"].value,
            max_steps=task["max_steps"],
        )
        for tid, task in TASKS.items()
    ]
@app.post("/reset", response_model=SupportObservation, tags=["environment"])
def reset(request: ResetRequest = ResetRequest()):
    """Reset the environment and return the initial observation."""
    # The default ResetRequest() instance makes the request body optional:
    # a bare POST resets to the default task.
    try:
        return env.reset(task_id=request.task_id, seed=request.seed)
    except ValueError as e:
        # A ValueError from env.reset is reported to the client as a 400 with
        # the exception message (presumably an unknown task_id — confirm
        # against the environment implementation). Chain the cause so the
        # original traceback stays available in server logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.post("/step", response_model=StepResponse, tags=["environment"])
def step(request: StepRequest):
    """Execute an agent action and return the result."""
    # Only the environment call is guarded: a RuntimeError raised by
    # env.step is reported to the client as a 400 carrying its message.
    try:
        obs, reward, done, info = env.step(action=request.action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Clamp the reward here as well as in StepResponse's validator, so the
    # value is already normalised before the response model is built.
    clamped_reward = safe_score(reward)

    # Apply the same clamping to the per-component scores, when the
    # environment reports them as a plain dict under "reward_breakdown".
    # NOTE: this rewrites the dict returned by env.step in place.
    breakdown = info.get("reward_breakdown") if isinstance(info, dict) else None
    if isinstance(breakdown, dict):
        for key in ("correctness", "tone", "completeness", "efficiency", "total"):
            if key in breakdown:
                breakdown[key] = safe_score(breakdown[key])

    return StepResponse(
        observation=obs,
        reward=clamped_reward,
        done=done,
        info=info,
    )
@app.get("/state", response_model=SupportState, tags=["environment"])
def get_state():
    """Get the current internal state of the environment."""
    # Read-only from this handler's side: it simply forwards whatever the
    # shared environment reports; FastAPI serialises it via SupportState.
    snapshot = env.state()
    return snapshot
# ──────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────
def main():
    """Run the server directly.

    Reads the bind address from the environment:
      * ``HOST`` — interface to listen on (default ``0.0.0.0``).
      * ``PORT`` — TCP port (default ``7860``); must parse as an integer,
        otherwise ``int()`` raises ``ValueError``.
    """
    # Imported here rather than at module top, so importing this module for
    # its `app` object does not itself import uvicorn.
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()