Spaces:
Sleeping
Sleeping
File size: 4,609 Bytes
1e1ca31 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 | """Flask REST API server for Email Triage OpenEnv."""
import os
from pathlib import Path
from flask import Flask, request, jsonify, render_template, send_file
from environment.env import EmailTriageEnv
from environment.types import Action
# Templates live next to this module, independent of the working directory.
_TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), "templates")
app = Flask(__name__, template_folder=_TEMPLATES_DIR)

# Cache of environments keyed by task name; populated lazily by get_env()
# so every request for the same task reuses one EmailTriageEnv instance.
environments = {}
def get_env(task_name: str = "spam_detection") -> EmailTriageEnv:
    """Return the cached environment for *task_name*, building it on first use.

    Instances are kept in the module-level ``environments`` dict, so later
    calls with the same task name return the very same object (and its state).
    """
    env = environments.get(task_name)
    if env is None:
        env = EmailTriageEnv(task_name=task_name)
        environments[task_name] = env
    return env
@app.route("/", methods=["GET"])
def index():
    """Serve the interactive dashboard at ``/``.

    Renders ``templates/index.html`` through Jinja. If rendering raises for any
    reason, the same file is returned verbatim when it exists on disk;
    otherwise a JSON overview of the API endpoints is returned with status 200.
    """
    try:
        return render_template("index.html")
    except Exception:
        # Best-effort fallback: bypass Jinja and serve the raw file if present.
        dashboard = Path(os.path.dirname(__file__)) / "templates" / "index.html"
        if dashboard.exists():
            return dashboard.read_text(encoding="utf-8")

        endpoint_map = {
            "health": "GET /health",
            "tasks": "GET /tasks",
            "reset": "POST /reset?task=spam_detection",
            "step": "POST /step?task=spam_detection",
            "state": "GET /state?task=spam_detection",
            "state-describe": "GET /state-describe?task=spam_detection",
        }
        overview = {
            "status": "ok",
            "message": "Email Triage OpenEnv API",
            "version": "1.0.0",
            "note": "Dashboard not available. Use API endpoints directly.",
            "endpoints": endpoint_map,
        }
        return jsonify(overview), 200
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers 200 with ``{"status": "ok"}``."""
    body = {"status": "ok"}
    return jsonify(body), 200
@app.route("/reset", methods=["POST"])
def reset():
    """Reset environment - POST /reset?task=spam_detection

    Calls ``reset()`` on the cached environment for the ``task`` query
    parameter (default ``spam_detection``) and returns the resulting
    observation together with the task name.
    """
    task = request.args.get("task", "spam_detection")
    initial_obs = get_env(task).reset()
    payload = {
        "observation": initial_obs.model_dump(mode="json"),
        "task": task,
    }
    return jsonify(payload), 200
@app.route("/step", methods=["POST"])
def step():
    """Step environment - POST /step?task=spam_detection with a JSON action.

    The request body must be a JSON object. Its ``classification`` value is
    passed to ``Action`` as-is, ``team`` defaults to ``"none"``, and
    ``priority`` defaults to ``1`` and is coerced with ``int()``.

    Returns:
        200 with ``observation``, ``reward``, ``done`` and ``info`` on success.
        400 with an ``error`` message when the body is missing, is not valid
        JSON, is not a JSON object, or cannot be converted into an ``Action``.
    """
    task_name = request.args.get("task", "spam_detection")
    env = get_env(task_name)

    # silent=True makes get_json() return None instead of raising a werkzeug
    # HTTP error (malformed JSON, or a non-JSON Content-Type). Without it,
    # those requests got an HTML error page rather than the JSON error below.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No action provided"}), 400
    if not isinstance(data, dict):
        # e.g. a JSON list or string body: there are no keys to read.
        return jsonify({"error": "Invalid action: expected a JSON object"}), 400

    try:
        action = Action(
            classification=data.get("classification"),
            team=data.get("team", "none"),
            priority=int(data.get("priority", 1)),
        )
    except Exception as e:  # e.g. Action validation failure or non-numeric priority
        return jsonify({"error": f"Invalid action: {str(e)}"}), 400

    obs, reward, done, info = env.step(action)
    return jsonify({
        "observation": obs.model_dump(mode="json"),
        "reward": reward.model_dump(mode="json"),
        "done": done,
        "info": info,
    }), 200
@app.route("/state", methods=["GET"])
def state():
    """Get current state - GET /state?task=spam_detection

    Returns the JSON-serialised result of ``state()`` on the cached
    environment for the requested task.
    """
    # Named `snapshot` so the local does not shadow this view function's name.
    snapshot = get_env(request.args.get("task", "spam_detection")).state()
    return jsonify(snapshot.model_dump(mode="json")), 200
@app.route("/state-describe", methods=["GET"])
def state_describe():
    """Describe the observation and action spaces of the requested task's env."""
    env = get_env(request.args.get("task", "spam_detection"))
    spaces = {
        "observation_space": env.describe_observation_space(),
        "action_space": env.describe_action_space(),
    }
    return jsonify(spaces), 200
@app.route("/tasks", methods=["GET"])
def tasks():
    """List available tasks with description, difficulty and email count."""
    # NOTE(review): hard-coded; presumably must stay in sync with the task
    # definitions inside EmailTriageEnv — confirm when adding tasks.
    catalogue = (
        ("spam_detection", "Binary spam/non-spam classification", "easy", 10),
        ("multi_class_routing", "Multi-class classification with routing", "medium", 12),
        ("context_aware_triage", "Complex context-aware triage with escalation", "hard", 20),
    )
    listing = [
        {
            "name": name,
            "description": description,
            "difficulty": difficulty,
            "num_emails": num_emails,
        }
        for name, description, difficulty, num_emails in catalogue
    ]
    return jsonify({"tasks": listing}), 200
def main():
    """Run the Flask server on all interfaces; port from $PORT (default 7860)."""
    port_setting = os.environ.get("PORT", 7860)
    app.run(debug=False, host="0.0.0.0", port=int(port_setting))


# Running this module directly starts the server.
if __name__ == "__main__":
    main()
|