Spaces:
Sleeping
Sleeping
File size: 6,776 Bytes
5c3a0f6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 | """
openenv_api.py
──────────────────────────────────────────────────────────────────────────────
HTTP API wrapper for JaamCTRL environment.
Exposes Gymnasium reset/step/state endpoints as POST routes.
Supports OpenEnv hackathon submission checker.
Usage:
    python openenv_api.py                     # Start server on 0.0.0.0:5000
    python openenv_api.py --port 8000         # Custom port
    python openenv_api.py --host 127.0.0.1    # Custom bind address
    FLASK_DEBUG=1 python openenv_api.py       # Debug mode
Environment variables:
    FLASK_PORT   Server port (default: 5000)
    MOCK_SUMO    Mock-SUMO flag; set to "1" (mock mode) when not already defined
──────────────────────────────────────────────────────────────────────────────
"""
from __future__ import annotations
import json
import os
import sys
import argparse
from typing import Any, Dict, Tuple
import numpy as np
from flask import Flask, request, jsonify
# ── Setup paths ─────────────────────────────────────────────────────────────
# Put this file's directory first on sys.path so sibling modules such as
# ``env`` can be imported no matter where the server is launched from.
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT)
# Default the API server to mock SUMO. A MOCK_SUMO value already present in
# the environment is left as-is. This is done before importing ``env`` in case
# that module reads the variable at import time (not verifiable from here).
os.environ.setdefault("MOCK_SUMO", "1")
# ── Imports ─────────────────────────────────────────────────────────────────
from env import JaamCTRLTrafficEnv, register_envs  # noqa: E402  (needs sys.path above)
register_envs()
# ── Flask app ───────────────────────────────────────────────────────────────
app = Flask(__name__)
# Keep JSON response keys in insertion order. Flask >= 2.2 takes this setting
# from its JSON provider. The JSON_SORT_KEYS config key was deprecated in 2.2
# and has no effect from 2.3 on. Older Flask versions only understand the
# config key.
if hasattr(app, "json") and hasattr(app.json, "sort_keys"):
    app.json.sort_keys = False
else:
    app.config["JSON_SORT_KEYS"] = False
# Process-wide environment instance shared by all routes. It is created
# lazily by /reset and rebuilt only when a different task_id is requested.
_env = None
# task_id the current ``_env`` was built for.
_current_task = 1
def _to_serializable(obj: Any) -> Any:
"""Convert numpy arrays and other non-JSON types to native Python types."""
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.integer, np.floating)):
return obj.item()
elif isinstance(obj, dict):
return {k: _to_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [_to_serializable(v) for v in obj]
return obj
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answers 200 with a fixed status and API version."""
    payload = {"status": "ok", "version": "1.0.0"}
    return jsonify(payload), 200
@app.route("/reset", methods=["POST"])
def reset():
    """
    Reset the environment, creating it first if needed.
    The environment is (re)built when none exists yet or when ``task_id``
    differs from the task of the current instance. Otherwise the existing
    instance is reset in place.
    Request body (JSON object, optional; an empty body means defaults):
    {
        "task_id": 1,
        "seed": 42,
        "options": {}
    }
    Response (200):
    {
        "observation": {...},
        "info": {...}
    }
    Errors:
        400  body present but not a JSON object (including malformed JSON)
        500  environment construction/reset raised; body is {"error": str(exc)}
    """
    global _env, _current_task
    # Treat a bare POST (no body, or a body without a JSON Content-Type) as
    # "use defaults". A plain get_json() raises in those cases on recent
    # Werkzeug versions, and the broad except below would turn that into a 500.
    raw_body = request.get_data(cache=True)
    data = request.get_json(force=True, silent=True) if raw_body.strip() else {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    task_id = data.get("task_id", 1)
    seed = data.get("seed")
    options = data.get("options", {})
    try:
        if _env is None or task_id != _current_task:
            if _env is not None:
                # Drop the global reference before closing. If close() or the
                # constructor below raises, /step and /state then see "not
                # initialized" instead of a closed environment.
                old_env, _env = _env, None
                old_env.close()
            # NOTE(review): mock_sumo=True is hard-coded, so MOCK_SUMO=0 does
            # not switch this API to real SUMO unless JaamCTRLTrafficEnv
            # consults the variable itself. Confirm this is intended.
            _env = JaamCTRLTrafficEnv(task_id=task_id, mock_sumo=True)
            _current_task = task_id
        obs, info = _env.reset(seed=seed, options=options)
        return jsonify({
            "observation": _to_serializable(obs),
            "info": _to_serializable(info),
        }), 200
    except Exception as e:
        # Top-level request boundary: log the traceback server-side and
        # report the message to the client.
        app.logger.exception("POST /reset failed")
        return jsonify({"error": str(e)}), 500
@app.route("/step", methods=["POST"])
def step():
    """
    Execute one step in the environment.
    Request body (JSON object):
    {
        "action": [0, 1, 2] or [1] depending on task
    }
    The action is converted to an ``np.int64`` array before ``env.step``.
    Response (200):
    {
        "observation": {...},
        "reward": 12.5,
        "terminated": false,
        "truncated": false,
        "info": {...}
    }
    Errors:
        400  no environment yet (call /reset first), body not a JSON object,
             missing "action", or an action not convertible to int64
        500  env.step raised; body is {"error": str(exc)}
    """
    # _env is only read here, so no ``global`` declaration is needed.
    if _env is None:
        return jsonify({"error": "Environment not initialized. Call /reset first."}), 400
    # Parse the body leniently (force=True also accepts a missing Content-Type).
    # An empty body falls through to the "missing action" error below rather
    # than raising inside get_json().
    raw_body = request.get_data(cache=True)
    data = request.get_json(force=True, silent=True) if raw_body.strip() else {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    action = data.get("action")
    if action is None:
        return jsonify({"error": "Missing 'action' in request body"}), 400
    try:
        action = np.array(action, dtype=np.int64)
    except (TypeError, ValueError, OverflowError) as e:
        # A malformed action is a client error, not a server failure.
        return jsonify({"error": f"Invalid 'action': {e}"}), 400
    try:
        obs, reward, terminated, truncated, info = _env.step(action)
        return jsonify({
            "observation": _to_serializable(obs),
            "reward": float(reward),
            "terminated": bool(terminated),
            "truncated": bool(truncated),
            "info": _to_serializable(info),
        }), 200
    except Exception as e:
        # Top-level request boundary: log the traceback server-side and
        # report the message to the client.
        app.logger.exception("POST /step failed")
        return jsonify({"error": str(e)}), 500
@app.route("/state", methods=["GET"])
def state():
    """
    Return the current environment's ``state()`` as JSON.
    Example response:
    {
        "task_id": 1,
        "step": 42,
        "sim_time_s": 210.0,
        ...
    }
    Returns 400 before the first /reset. Returns 500 with {"error": ...} if
    reading or serialising the state raises.
    """
    env = _env
    if env is None:
        return jsonify({"error": "Environment not initialized. Call /reset first."}), 400
    try:
        snapshot = _to_serializable(env.state())
        return jsonify(snapshot), 200
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
@app.errorhandler(404)
def not_found(error):
    """Render 404 responses as JSON, echoing back the requested path."""
    body = {"error": "Not found", "path": request.path}
    return jsonify(body), 404
@app.errorhandler(405)
def method_not_allowed(error):
    """Render 405 responses as JSON, echoing back the HTTP method used."""
    body = {"error": "Method not allowed", "method": request.method}
    return jsonify(body), 405
@app.errorhandler(500)
def server_error(error):
    """Render 500 responses as JSON, including the error's string form."""
    details = str(error)
    return jsonify({"error": "Internal server error", "details": details}), 500
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="OpenEnv HTTP API server")
    parser.add_argument("--port", type=int, default=int(os.getenv("FLASK_PORT", 5000)),
                        help="Port to run server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()
    # Honour FLASK_DEBUG as advertised in the module docstring. Flask's
    # app.run() lets an explicit ``debug=`` argument override the environment
    # variable, so passing a literal False would silently disable it. The
    # truthiness rule mirrors Flask's get_debug_flag(): any non-empty value
    # other than "0", "false" or "no" (case-insensitive) enables debug mode.
    debug_env = os.getenv("FLASK_DEBUG", "")
    debug = bool(debug_env) and debug_env.lower() not in {"0", "false", "no"}
    print(f"Starting OpenEnv API server on {args.host}:{args.port}")
    print(f"Health check: GET http://localhost:{args.port}/health")
    print(f"Reset env: POST http://localhost:{args.port}/reset")
    print(f"Step env: POST http://localhost:{args.port}/step")
    print(f"Get state: GET http://localhost:{args.port}/state")
    if debug:
        print("Debug mode: on (FLASK_DEBUG)")
    print()
    # NOTE(review): all requests share the single module-level env; confirm
    # the deployment does not rely on concurrent requests to it.
    app.run(host=args.host, port=args.port, debug=debug)
|