File size: 4,609 Bytes
1e1ca31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Flask REST API server for Email Triage OpenEnv."""

import os
from pathlib import Path

from flask import Flask, request, jsonify, render_template, send_file

from environment.env import EmailTriageEnv
from environment.types import Action

# Flask application; templates are looked up in the "templates" directory that
# sits next to this file.
_TEMPLATE_DIR = str(Path(__file__).parent / 'templates')
app = Flask(__name__, template_folder=_TEMPLATE_DIR)

# Lazily-created environment instances, keyed by task name (filled by get_env).
environments = {}

def get_env(task_name: str = "spam_detection") -> EmailTriageEnv:
    """Return the cached environment for *task_name*, creating it on first use.

    Instances are stored in the module-level ``environments`` dict, so every
    request for the same task name reuses the same ``EmailTriageEnv``.

    NOTE(review): *task_name* comes straight from a query parameter in the
    routes below; if ``EmailTriageEnv`` accepts arbitrary names this cache can
    grow without bound — confirm the constructor rejects unknown tasks.
    """
    cached = environments.get(task_name)
    if cached is None:
        cached = EmailTriageEnv(task_name=task_name)
        environments[task_name] = cached
    return cached

@app.route("/", methods=["GET"])
def index():
    """Root endpoint - serve the interactive dashboard.

    Tries to render ``templates/index.html`` through Jinja. If rendering fails
    for any reason, the failure is logged and the raw HTML file is returned
    unrendered when it exists. Otherwise a JSON summary of the API endpoints
    is returned.

    Returns:
        The HTML page as a string, or a ``(Response, 200)`` JSON fallback.
    """
    try:
        return render_template("index.html")
    except Exception:
        # Best-effort page: never fail the root URL, but keep the reason the
        # template could not be rendered instead of discarding it silently.
        app.logger.warning(
            "Could not render index.html; falling back to raw file / JSON",
            exc_info=True,
        )
        # Fallback: read HTML file directly
        html_path = os.path.join(os.path.dirname(__file__), 'templates', 'index.html')
        if os.path.exists(html_path):
            with open(html_path, 'r', encoding='utf-8') as f:
                return f.read()
        return jsonify({
            "status": "ok",
            "message": "Email Triage OpenEnv API",
            "version": "1.0.0",
            "note": "Dashboard not available. Use API endpoints directly.",
            "endpoints": {
                "health": "GET /health",
                "tasks": "GET /tasks",
                "reset": "POST /reset?task=spam_detection",
                "step": "POST /step?task=spam_detection",
                "state": "GET /state?task=spam_detection",
                "state-describe": "GET /state-describe?task=spam_detection"
            }
        }), 200

@app.route("/health", methods=["GET"])
def health():
    """Health check: always answers ``{"status": "ok"}`` with HTTP 200."""
    payload = {"status": "ok"}
    return jsonify(payload), 200

@app.route("/reset", methods=["POST"])
def reset():
    """Reset environment - POST /reset?task=spam_detection

    Resets the environment for the ``task`` query parameter (default
    ``spam_detection``). Responds with the initial observation, serialised in
    JSON mode, plus the task name that was reset.
    """
    task_name = request.args.get("task", "spam_detection")
    initial_obs = get_env(task_name).reset()
    body = {
        "observation": initial_obs.model_dump(mode="json"),
        "task": task_name,
    }
    return jsonify(body), 200

@app.route("/step", methods=["POST"])
def step():
    """Step environment - POST /step?task=spam_detection with a JSON action body.

    The body must be a JSON object with the action fields:
      - ``classification``,
      - optional ``team`` (defaults to ``"none"``),
      - optional ``priority`` (defaults to ``1``; converted with ``int``).

    Returns:
        ``(json, 200)`` with the ``observation``, ``reward``, ``done`` and
        ``info`` produced by ``env.step``.
        ``(json, 400)`` with an ``error`` message when the body is missing,
        is not valid JSON, is not a JSON object, or cannot be turned into an
        ``Action``.
    """
    task_name = request.args.get("task", "spam_detection")
    env = get_env(task_name)

    # silent=True: a missing/non-JSON Content-Type or malformed JSON yields
    # None instead of Flask aborting with its own non-JSON 400/415 response,
    # so clients always receive the JSON error below.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No action provided"}), 400
    if not isinstance(data, dict):
        # Valid JSON but not an object (e.g. a list or a string): there are no
        # named action fields to read.
        return jsonify({"error": "Invalid action: request body must be a JSON object"}), 400

    try:
        action = Action(
            classification=data.get("classification"),
            team=data.get("team", "none"),
            priority=int(data.get("priority", 1)),
        )
    except Exception as e:
        # Request boundary: any conversion/validation failure is the client's
        # fault and is reported as 400 rather than a 500.
        return jsonify({"error": f"Invalid action: {e}"}), 400

    obs, reward, done, info = env.step(action)

    return jsonify({
        "observation": obs.model_dump(mode="json"),
        "reward": reward.model_dump(mode="json"),
        "done": done,
        "info": info,
    }), 200

@app.route("/state", methods=["GET"])
def state():
    """Get current state - GET /state?task=spam_detection

    Returns the environment's ``state()`` for the requested task, serialised
    in JSON mode.
    """
    task_name = request.args.get("task", "spam_detection")
    current_state = get_env(task_name).state()
    serialised = current_state.model_dump(mode="json")
    return jsonify(serialised), 200

@app.route("/state-describe", methods=["GET"])
def state_describe():
    """Describe the observation and action spaces of the requested task's environment."""
    env = get_env(request.args.get("task", "spam_detection"))
    observation_space = env.describe_observation_space()
    action_space = env.describe_action_space()
    return jsonify({
        "observation_space": observation_space,
        "action_space": action_space,
    }), 200

@app.route("/tasks", methods=["GET"])
def tasks():
    """List available tasks.

    Returns a fixed catalogue. Each entry has a ``name``, a ``description``,
    a ``difficulty`` label and the ``num_emails`` count advertised for it.
    """
    catalogue = (
        ("spam_detection", "Binary spam/non-spam classification", "easy", 10),
        ("multi_class_routing", "Multi-class classification with routing", "medium", 12),
        ("context_aware_triage", "Complex context-aware triage with escalation", "hard", 20),
    )
    task_list = [
        {
            "name": name,
            "description": description,
            "difficulty": difficulty,
            "num_emails": num_emails,
        }
        for name, description, difficulty, num_emails in catalogue
    ]
    return jsonify({"tasks": task_list}), 200


def main():
    """Run the Flask server on all interfaces with debug disabled.

    The port is taken from the ``PORT`` environment variable, defaulting to 7860.
    """
    listen_port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)


if __name__ == "__main__":
    main()