File size: 3,881 Bytes
d34f0ce
 
 
 
 
 
 
 
 
 
 
 
 
 
5f2ce8f
d34f0ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f2ce8f
 
d34f0ce
5f2ce8f
 
d34f0ce
 
 
5f2ce8f
 
d34f0ce
 
 
 
 
 
 
 
5f2ce8f
 
 
 
d34f0ce
 
 
 
5f2ce8f
 
d34f0ce
 
 
 
 
 
 
 
 
 
 
5f2ce8f
 
 
 
 
 
 
d34f0ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f2ce8f
d34f0ce
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
FastAPI server for Customer Support Email Triage Environment.
Exposes OpenEnv-compliant API endpoints.
"""

import logging
import os
import sys
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models import EmailAction, EmailObservation, EmailState, StepReturn, ResetReturn
from .environment import CustomerSupportEnv

# Initialize FastAPI app
# Module-level application object; every endpoint below is registered on it.
app = FastAPI(
    title="Customer Support Email Triage Environment",
    description="OpenEnv-compliant environment for email classification and response generation",
    version="1.0.0"
)

# Add CORS middleware
# Allows any origin, method and header to call the API cross-origin.
# NOTE(review): allow_credentials=True together with allow_origins=["*"]
# effectively permits credentialed requests from any site (Starlette echoes
# the request's Origin in that case instead of sending "*"). None of the
# endpoints here appear to use cookies/auth, so allow_credentials=False is
# probably sufficient — confirm before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize environment
# One shared instance: every request to every endpoint reads and mutates this
# same object. NOTE(review): the endpoints are plain `def` functions, which
# FastAPI runs in a threadpool, so concurrent requests may touch this instance
# simultaneously — confirm whether the environment needs a lock.
env = CustomerSupportEnv()


@app.get("/health")
def health_check() -> Dict[str, str]:
    """Report that the server process is up.

    The environment itself is not inspected; a response simply proves the
    app is serving requests.

    Returns:
        The fixed mapping ``{"status": "healthy"}``.
    """
    current_status = "healthy"
    return dict(status=current_status)


@app.get("/info")
def info() -> Dict[str, Any]:
    """Describe the environment served by this API.

    All values are static literals; nothing is read from the live
    environment instance.

    Returns:
        Mapping with the environment name, version, description, the names
        of the action and observation models, the reward range, the number
        of tasks and the episode type.
    """
    # NOTE(review): the task count is hard-coded; verify it still matches
    # the environment's task set when tasks are added or removed.
    metadata: Dict[str, Any] = dict(
        name="customer_support_env",
        version="1.0.0",
        description="Customer Support Email Triage and Response System",
        action_space="EmailAction",
        observation_space="EmailObservation",
        reward_range=[0.0, 1.0],
        tasks=12,
        episode_type="multi-step",
    )
    return metadata


@app.post("/reset", response_model=ResetReturn)
def reset() -> ResetReturn:
    """
    Reset the shared environment and return its initial observation.

    Returns:
        ResetReturn built from the ``"observation"`` and ``"info"`` entries
        of the mapping returned by ``env.reset()``.

    Raises:
        HTTPException: 500 if resetting the environment or building the
            response fails. The original exception is chained as the cause
            and its traceback is logged.
    """
    try:
        result = env.reset()
        return ResetReturn(
            observation=result["observation"],
            info=result["info"]
        )
    except Exception as e:
        # Only str(e) reaches the client; record the full traceback
        # server-side so the failure can actually be diagnosed.
        logging.getLogger(__name__).exception("Environment reset failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/step", response_model=StepReturn)
def step(action: EmailAction) -> StepReturn:
    """
    Apply one action to the shared environment.

    Args:
        action: The EmailAction to apply (category, priority, response).

    Returns:
        StepReturn with the observation, reward, done flag and info from
        ``env.step(action)``. ``step_reward_breakdown`` defaults to ``{}``
        when the environment's result does not include it.

    Raises:
        HTTPException: 400 if the environment raises RuntimeError, 500 for
            any other failure. The original exception is chained as the
            cause in both cases; 500-level failures are also logged.
    """
    try:
        result = env.step(action)
        return StepReturn(
            observation=result["observation"],
            reward=result["reward"],
            done=result["done"],
            info=result["info"],
            step_reward_breakdown=result.get("step_reward_breakdown", {})
        )
    except RuntimeError as e:
        # RuntimeError from the environment is reported as a client error.
        # NOTE(review): NotImplementedError and RecursionError subclass
        # RuntimeError, so they also land here as 400 — confirm intended.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Only str(e) reaches the client; keep the traceback in server logs.
        logging.getLogger(__name__).exception("Environment step failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/state")
def get_state() -> Dict[str, Any]:
    """
    Return the current state of the shared environment.

    Returns:
        The value of ``env.get_state()``, passed through unchanged.

    Raises:
        HTTPException: 500 if reading the state fails. The original
            exception is chained as the cause and its traceback is logged.
    """
    try:
        return env.get_state()
    except Exception as e:
        # Only str(e) reaches the client; keep the traceback in server logs.
        logging.getLogger(__name__).exception("Reading environment state failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/stats")
def get_stats() -> Dict[str, Any]:
    """
    Return statistics reported by the shared environment.

    Returns:
        The value of ``env.get_stats()``, passed through unchanged.

    Raises:
        HTTPException: 500 if collecting the statistics fails. The original
            exception is chained as the cause and its traceback is logged.
    """
    try:
        return env.get_stats()
    except Exception as e:
        # Only str(e) reaches the client; keep the traceback in server logs.
        logging.getLogger(__name__).exception("Reading environment stats failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
def root() -> Dict[str, str]:
    """Point callers at the API's self-documentation.

    Returns:
        The service name and version, followed by the paths of the
        interactive docs page and the OpenAPI schema.
    """
    doc_links = {"docs": "/docs", "openapi": "/openapi.json"}
    return {
        "name": "Customer Support Email Triage Environment",
        "version": "1.0.0",
        **doc_links,
    }


def main(host: str = "0.0.0.0", port: int = 5001) -> None:
    """
    Run the API server with uvicorn.

    Args:
        host: Interface address to bind; the default binds all interfaces.
        port: TCP port to listen on.
    """
    # Imported here rather than at module top, so uvicorn is only needed
    # when the server is launched through main().
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()