"""Flask REST API server for Email Triage OpenEnv."""

import os
from pathlib import Path

from flask import Flask, request, jsonify, render_template, send_file

from environment.env import EmailTriageEnv
from environment.types import Action

app = Flask(__name__, template_folder=os.path.join(os.path.dirname(__file__), 'templates'))

# Task used when a request omits the ``?task=...`` query parameter.
DEFAULT_TASK = "spam_detection"

# Static catalog served by ``GET /tasks``.
TASK_CATALOG = [
    {
        "name": "spam_detection",
        "description": "Binary spam/non-spam classification",
        "difficulty": "easy",
        "num_emails": 10,
    },
    {
        "name": "multi_class_routing",
        "description": "Multi-class classification with routing",
        "difficulty": "medium",
        "num_emails": 12,
    },
    {
        "name": "context_aware_triage",
        "description": "Complex context-aware triage with escalation",
        "difficulty": "hard",
        "num_emails": 20,
    },
]

# Global environment instances (one per task), created lazily by get_env().
environments = {}


def get_env(task_name: str = DEFAULT_TASK) -> EmailTriageEnv:
    """Return the cached environment for ``task_name``, creating it on first use.

    Args:
        task_name: Task identifier forwarded to ``EmailTriageEnv(task_name=...)``.

    Returns:
        The single ``EmailTriageEnv`` instance stored for this task name.
    """
    # NOTE(review): task_name comes straight from the query string. If
    # EmailTriageEnv accepts arbitrary names, every distinct name adds a new
    # cached environment. Consider validating against TASK_CATALOG once the
    # full set of supported tasks is confirmed.
    if task_name not in environments:
        environments[task_name] = EmailTriageEnv(task_name=task_name)
    return environments[task_name]


def _requested_task() -> str:
    """Return the ``task`` query parameter of the current request, or DEFAULT_TASK."""
    return request.args.get("task", DEFAULT_TASK)


@app.route("/", methods=["GET"])
def index():
    """Root endpoint - Interactive dashboard.

    Renders ``templates/index.html``. If rendering fails, the raw file is
    served unrendered. If the file is missing too, a JSON summary of the API
    is returned instead.
    """
    try:
        return render_template("index.html")
    except Exception:
        # Deliberately broad: a template that exists but fails to render
        # (e.g. literal "{{" in inline JS tripping Jinja) is still served raw.
        html_path = os.path.join(os.path.dirname(__file__), 'templates', 'index.html')
        if os.path.exists(html_path):
            with open(html_path, 'r', encoding='utf-8') as f:
                return f.read()
        return jsonify({
            "status": "ok",
            "message": "Email Triage OpenEnv API",
            "version": "1.0.0",
            "note": "Dashboard not available. Use API endpoints directly.",
            "endpoints": {
                "health": "GET /health",
                "tasks": "GET /tasks",
                "reset": "POST /reset?task=spam_detection",
                "step": "POST /step?task=spam_detection",
                "state": "GET /state?task=spam_detection",
                "state-describe": "GET /state-describe?task=spam_detection",
            },
        }), 200


@app.route("/health", methods=["GET"])
def health():
    """Health check endpoint."""
    return jsonify({"status": "ok"}), 200


@app.route("/reset", methods=["POST"])
def reset():
    """Reset environment - POST /reset?task=spam_detection.

    Returns the initial observation and the task name.
    """
    task_name = _requested_task()
    env = get_env(task_name)
    obs = env.reset()
    return jsonify({
        "observation": obs.model_dump(mode="json"),
        "task": task_name,
    }), 200


@app.route("/step", methods=["POST"])
def step():
    """Step environment - POST /step?task=... with a JSON action body.

    Expected body: ``{"classification": ..., "team": ..., "priority": ...}``.
    ``team`` defaults to ``"none"`` and ``priority`` to ``1``.

    Returns:
        200 with observation, reward, done and info; or 400 with
        ``{"error": ...}`` if the body is missing, is not a JSON object,
        or does not form a valid ``Action``.
    """
    task_name = _requested_task()
    env = get_env(task_name)

    # silent=True: malformed JSON or a non-JSON Content-Type yields None rather
    # than Flask aborting with its own 400/415 page, so clients always receive
    # the JSON error below.
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "No action provided"}), 400
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid action: expected a JSON object"}), 400

    try:
        action = Action(
            classification=data.get("classification"),
            team=data.get("team", "none"),
            priority=int(data.get("priority", 1)),
        )
    except Exception as e:
        # Covers int() conversion failures and Action's own validation errors.
        return jsonify({"error": f"Invalid action: {e}"}), 400

    obs, reward, done, info = env.step(action)
    return jsonify({
        "observation": obs.model_dump(mode="json"),
        "reward": reward.model_dump(mode="json"),
        "done": done,
        "info": info,
    }), 200


@app.route("/state", methods=["GET"])
def state():
    """Get current state - GET /state?task=spam_detection."""
    env = get_env(_requested_task())
    current_state = env.state()
    return jsonify(current_state.model_dump(mode="json")), 200


@app.route("/state-describe", methods=["GET"])
def state_describe():
    """Describe observation and action spaces for the requested task."""
    env = get_env(_requested_task())
    return jsonify({
        "observation_space": env.describe_observation_space(),
        "action_space": env.describe_action_space(),
    }), 200


@app.route("/tasks", methods=["GET"])
def tasks():
    """List available tasks."""
    return jsonify({"tasks": TASK_CATALOG}), 200


def main():
    """Main entry point for the server; listens on $PORT (default 7860)."""
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)


if __name__ == "__main__":
    main()