File size: 7,466 Bytes
81aa69d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c591d0
81aa69d
11c71eb
 
 
81aa69d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c591d0
 
 
 
 
 
81aa69d
6c591d0
81aa69d
 
 
6c591d0
 
 
 
 
 
81aa69d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c591d0
 
 
 
 
3932d4b
 
 
 
 
6c591d0
 
81aa69d
 
6c591d0
81aa69d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40b0e9f
81aa69d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
FastAPI application exposing the Customer Support Environment
via HTTP endpoints compatible with OpenEnv specification.

Endpoints:
    POST /reset        — Reset environment, returns initial observation
    POST /step         — Execute an action, returns (obs, reward, done, info)
    GET  /state        — Get current internal state
    GET  /health       — Health check
    GET  /tasks        — List available tasks
    GET  /             — Environment info
"""

import sys
import os

# Put the directory one level above this file's own directory at the front of
# sys.path (unless it is already listed) so the project-local imports further
# down resolve when this module is launched directly.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _project_root not in sys.path:
    sys.path[:0] = [_project_root]

from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator

from models import SupportAction, SupportObservation, SupportState, safe_score  # type: ignore
from server.environment import CustomerSupportEnvironment  # type: ignore
from tasks import TASK_IDS, TASKS  # type: ignore


# ──────────────────────────────────────────────────────────────────
# Request / Response schemas
# ──────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """Request body for POST /reset; both fields have defaults."""

    # Task to load; "easy_faq" is used when the client omits the field.
    task_id: Optional[str] = Field("easy_faq", description="Task ID to load")
    # Passed through to the environment's reset(); the description labels it unused.
    seed: Optional[int] = Field(None, description="Random seed (unused)")


class StepRequest(BaseModel):
    """Request body for POST /step."""

    # No default is supplied, so the field is required.
    action: SupportAction = Field(description="Agent action")


class StepResponse(BaseModel):
    """Payload returned by the /step endpoint.

    The reward field carries no gt/lt constraints. Instead, a "before"
    validator hands every explicitly supplied reward to ``safe_score`` so that
    boundary values are adjusted rather than rejected with a ValidationError.
    """

    observation: SupportObservation
    # safe_score is imported from models; per the description below it is
    # expected to keep the value strictly inside (0, 1) -- confirm there.
    reward: float = Field(0.01, description="Step reward in strict (0, 1)")
    done: bool
    info: Dict[str, Any]

    @field_validator("reward", mode="before")
    @classmethod
    def _clamp_reward(cls, raw_reward: Any) -> float:
        """Return the incoming reward after passing it through ``safe_score``."""
        clamped = safe_score(raw_reward)
        return clamped


class TaskInfo(BaseModel):
    """Metadata describing one task, as returned by GET /tasks."""

    # Key of the task in the TASKS mapping.
    task_id: str
    # Filled from the task ticket's subject by list_tasks().
    name: str
    # Upper-cased difficulty followed by the ticket subject.
    description: str
    # The difficulty's ``.value`` (presumably an Enum member's value).
    difficulty: str
    # Copied from the task's "max_steps" entry.
    max_steps: int


# ──────────────────────────────────────────────────────────────────
# Application instance, middleware, and shared environment
# ──────────────────────────────────────────────────────────────────

# The FastAPI application object that every route below is registered on.
# The title previously contained a mis-encoded em dash (UTF-8 bytes decoded
# with a legacy code page); it is restored here so the generated API docs show
# the intended text.
app = FastAPI(
    title="Customer Support Environment — OpenEnv",
    description=(
        "AI-Powered Customer Support Ticket Resolution Environment. "
        "Train agents to handle real customer issues using step/reset/state APIs."
    ),
    version="1.0.0",
)

# Allow browser clients on any origin to call the API with any method/header.
# Credentialed cross-origin requests are intentionally disabled: the FastAPI
# CORS documentation states that allow_origins / allow_methods / allow_headers
# may not be ["*"] while allow_credentials is True, and no endpoint visible in
# this file reads cookies or authorization data. To support credentials later,
# list the permitted origins, methods and headers explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global environment instance (single-agent mode for simplicity).
# The /reset, /step and /state handlers below all operate on this one object.
# NOTE(review): concurrent clients would therefore share a single episode --
# confirm that this is acceptable for the deployment.
env = CustomerSupportEnvironment()


# ──────────────────────────────────────────────────────────────────
# Endpoints
# ──────────────────────────────────────────────────────────────────

@app.get("/", tags=["info"])
def root():
    """Describe this environment and the HTTP endpoints it serves."""
    # Human-readable route summary; the "GET /" route itself is not listed.
    endpoint_summary = {
        "POST /reset": "Reset environment with a task_id",
        "POST /step": "Execute an action",
        "GET /state": "Get current state",
        "GET /health": "Health check",
        "GET /tasks": "List available tasks",
    }
    env_info = {
        "name": "customer_support_env",
        "version": "1.0.0",
        "description": "AI-Powered Customer Support Ticket Resolution Environment",
        "endpoints": endpoint_summary,
        # Task identifiers come straight from the tasks module.
        "available_tasks": TASK_IDS,
    }
    return env_info


@app.get("/health", tags=["health"])
def health():
    """Report service health; returns a fixed payload and touches no state."""
    status_payload = {
        "status": "healthy",
        "environment": "customer_support_env",
    }
    return status_payload


@app.get("/tasks", response_model=list[TaskInfo], tags=["tasks"])
def list_tasks():
    """List all available tasks with metadata.

    Builds one ``TaskInfo`` per entry of ``TASKS``, in the mapping's iteration
    order. Each task mapping is read for ``"ticket"`` (with a ``"subject"``),
    ``"difficulty"`` (an object exposing ``.value``, presumably an Enum member)
    and ``"max_steps"``.
    """
    # The separator in the description is a real em dash; the previous literal
    # held a mis-encoded byte sequence that was sent to clients as garbage.
    return [
        TaskInfo(
            task_id=tid,
            name=task["ticket"]["subject"],
            description=f"{task['difficulty'].value.upper()} — {task['ticket']['subject']}",
            difficulty=task["difficulty"].value,
            max_steps=task["max_steps"],
        )
        for tid, task in TASKS.items()
    ]


@app.post("/reset", response_model=SupportObservation, tags=["environment"])
def reset(request: ResetRequest = ResetRequest()):
    """Reset the environment and return the initial observation.

    The request's ``task_id`` and ``seed`` are forwarded to the shared
    environment. A ``ValueError`` raised by the environment (presumably for an
    unknown task id) is reported to the client as HTTP 400 with the error text
    as the detail.
    """
    try:
        return env.reset(task_id=request.task_id, seed=request.seed)
    except ValueError as exc:
        # Chain the original error so server-side tracebacks show the cause.
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/step", response_model=StepResponse, tags=["environment"])
def step(request: StepRequest):
    """Execute an agent action and return the result.

    The reward and the known ``reward_breakdown`` components are passed through
    ``safe_score`` before the response is built. A ``RuntimeError`` raised by
    the environment is reported to the client as HTTP 400 with the error text
    as the detail.
    """
    # Guard only the environment call: that is the operation whose
    # RuntimeError is meant to be surfaced as a client error.
    try:
        obs, reward, done, info = env.step(action=request.action)
    except RuntimeError as exc:
        # Chain the original error so server-side tracebacks show the cause.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Clamp here as well as in StepResponse's own reward validator.
    clamped_reward = safe_score(reward)

    # Clamp the component scores in place, i.e. on the very dict object the
    # environment returned, leaving any other keys untouched.
    breakdown = info.get("reward_breakdown")
    if isinstance(breakdown, dict):
        for key in ("correctness", "tone", "completeness", "efficiency", "total"):
            if key in breakdown:
                breakdown[key] = safe_score(breakdown[key])

    return StepResponse(
        observation=obs,
        reward=clamped_reward,
        done=done,
        info=info,
    )


@app.get("/state", response_model=SupportState, tags=["environment"])
def get_state():
    """Return whatever the shared environment reports as its current state."""
    current_state = env.state()
    return current_state


# ──────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────

def main():
    """Serve ``app`` with uvicorn.

    The bind address comes from the HOST environment variable (default
    "0.0.0.0") and the port from PORT (default 7860).
    """
    # Imported lazily so uvicorn is only needed when running as a script.
    import uvicorn

    listen_port = int(os.getenv("PORT", 7860))
    listen_host = os.getenv("HOST", "0.0.0.0")
    uvicorn.run(app, host=listen_host, port=listen_port)


if __name__ == "__main__":
    main()