# JaamCTRL-OpenEnv / openenv_api.py
# Akshara
# Add OpenEnv HTTP API server for submission checker
# 5c3a0f6
"""
openenv_api.py
──────────────────────────────────────────────────────────────────────────────
HTTP API wrapper for JaamCTRL environment.
Exposes Gymnasium-style reset/step as POST routes and state/health as GET routes.
Supports OpenEnv hackathon submission checker.
Usage:
python openenv_api.py # Start server on 0.0.0.0:5000
python openenv_api.py --port 8000 # Custom port
FLASK_DEBUG=1 python openenv_api.py # Debug mode
Environment variables:
FLASK_PORT Server port (default: 5000)
MOCK_SUMO Force mock mode (default: "1")
──────────────────────────────────────────────────────────────────────────────
"""
from __future__ import annotations
import json
import os
import sys
import argparse
from typing import Any, Dict, Tuple
import numpy as np
from flask import Flask, request, jsonify
# ── Setup paths ─────────────────────────────────────────────────────────────
# Put this file's directory first on sys.path so the local `env` package
# imported below resolves regardless of the current working directory.
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT)
# Default to mock SUMO in API mode; an already-set MOCK_SUMO value is kept.
# This must happen before `env` is imported below.
os.environ.setdefault("MOCK_SUMO", "1")
# ── Imports ─────────────────────────────────────────────────────────────────
from env import JaamCTRLTrafficEnv, register_envs
register_envs()
# ── Flask app ───────────────────────────────────────────────────────────────
app = Flask(__name__)
# Keep response keys in insertion order (e.g. "observation" before "info").
# The JSON_SORT_KEYS config key is only honoured by older Flask releases;
# Flask >= 2.3 reads app.json.sort_keys instead, so set both.
app.config["JSON_SORT_KEYS"] = False
if hasattr(app, "json"):
    app.json.sort_keys = False
# Single environment instance shared by all requests. It is created lazily by
# /reset and rebuilt only when the requested task_id differs from
# _current_task; /step and /state operate on whatever instance is current.
_env = None
_current_task = 1
def _to_serializable(obj: Any) -> Any:
"""Convert numpy arrays and other non-JSON types to native Python types."""
if isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (np.integer, np.floating)):
return obj.item()
elif isinstance(obj, dict):
return {k: _to_serializable(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [_to_serializable(v) for v in obj]
return obj
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: reply 200 with a fixed status/version payload."""
    payload = {"status": "ok", "version": "1.0.0"}
    return jsonify(payload), 200
@app.route("/reset", methods=["POST"])
def reset():
    """
    Reset the environment, (re)creating it when the task changes.

    Request body (JSON, optional; every field is optional):
        {
            "task_id": 1,
            "seed": 42,
            "options": {}
        }

    Response (200):
        {
            "observation": {...},
            "info": {...}
        }

    Returns 400 when the body is JSON but not an object, and 500 with the
    exception message when creating or resetting the environment fails.
    """
    global _env, _current_task
    try:
        # silent=True: a missing body, a non-JSON Content-Type or malformed
        # JSON yields None instead of raising, so a bare POST /reset falls
        # back to the defaults rather than being reported as a 500.
        data = request.get_json(silent=True) or {}
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        task_id = data.get("task_id", 1)
        seed = data.get("seed")
        options = data.get("options", {})
        # Create/recreate environment if task changed
        if _env is None or task_id != _current_task:
            if _env is not None:
                # Drop the global reference before closing: if close() or the
                # constructor below raises, the next /reset starts from a
                # clean slate instead of reusing an already-closed instance.
                stale_env, _env = _env, None
                stale_env.close()
            _env = JaamCTRLTrafficEnv(task_id=task_id, mock_sumo=True)
            _current_task = task_id
        obs, info = _env.reset(seed=seed, options=options)
        return jsonify({
            "observation": _to_serializable(obs),
            "info": _to_serializable(info),
        }), 200
    except Exception as e:
        # Boundary handler: keep the traceback in the server log, since the
        # client only receives the message.
        app.logger.exception("POST /reset failed")
        return jsonify({"error": str(e)}), 500
@app.route("/step", methods=["POST"])
def step():
    """
    Execute one step in the environment.

    Request body (JSON):
        {
            "action": [0, 1, 2] or [1] depending on task
        }

    Response (200):
        {
            "observation": {...},
            "reward": 12.5,
            "terminated": false,
            "truncated": false,
            "info": {...}
        }

    Returns 400 when the environment has not been created via /reset, when
    the body is not a JSON object, or when "action" is missing or cannot be
    converted to an int64 array. Returns 500 with the exception message when
    the environment step itself fails.
    """
    if _env is None:
        return jsonify({"error": "Environment not initialized. Call /reset first."}), 400
    # silent=True: a missing/non-JSON/malformed body yields None instead of
    # raising, so it is reported as a client error below rather than a 500.
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    action = data.get("action")
    if action is None:
        return jsonify({"error": "Missing 'action' in request body"}), 400
    try:
        action = np.array(action, dtype=np.int64)
    except (TypeError, ValueError, OverflowError) as e:
        # Non-numeric, ragged or out-of-range payloads are the caller's fault.
        return jsonify({"error": f"Invalid 'action': {e}"}), 400
    try:
        obs, reward, terminated, truncated, info = _env.step(action)
        return jsonify({
            "observation": _to_serializable(obs),
            "reward": float(reward),
            "terminated": bool(terminated),
            "truncated": bool(truncated),
            "info": _to_serializable(info),
        }), 200
    except Exception as e:
        # Boundary handler: keep the traceback in the server log, since the
        # client only receives the message.
        app.logger.exception("POST /step failed")
        return jsonify({"error": str(e)}), 500
@app.route("/state", methods=["GET"])
def state():
    """
    Report the current environment state as JSON.

    The payload is whatever the environment's ``state()`` returns, passed
    through ``_to_serializable``. Example response:
        {
            "task_id": 1,
            "step": 42,
            "sim_time_s": 210.0,
            ...
        }

    Returns 400 if no environment exists yet and 500 with the exception
    message if reading or encoding the state fails.
    """
    if _env is None:
        return jsonify({"error": "Environment not initialized. Call /reset first."}), 400
    try:
        snapshot = _env.state()
        response = jsonify(_to_serializable(snapshot))
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
    else:
        return response, 200
@app.errorhandler(404)
def not_found(error):
    """Render 404 errors as a JSON payload that echoes the requested path."""
    body = {"error": "Not found", "path": request.path}
    return jsonify(body), 404
@app.errorhandler(405)
def method_not_allowed(error):
    """Render 405 errors as a JSON payload that echoes the HTTP method used."""
    body = {"error": "Method not allowed", "method": request.method}
    return jsonify(body), 405
@app.errorhandler(500)
def server_error(error):
    """Render 500 errors as a JSON payload carrying the error's string form."""
    body = {"error": "Internal server error", "details": str(error)}
    return jsonify(body), 500
if __name__ == "__main__":
    # Command-line entry point: parse host/port, print the endpoints, serve.
    parser = argparse.ArgumentParser(description="OpenEnv HTTP API server")
    parser.add_argument("--port", type=int, default=int(os.getenv("FLASK_PORT", "5000")),
                        help="Port to run server on")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    args = parser.parse_args()
    # The module usage notes advertise `FLASK_DEBUG=1 python openenv_api.py`,
    # but an explicit debug=False passed to app.run() overrides that variable.
    # Derive the flag from the environment so the documented switch works;
    # unset, empty, "0", "false" and "no" all keep debug mode off.
    # NOTE(review): debug mode enables the interactive debugger; avoid
    # combining it with a publicly reachable --host.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() not in ("", "0", "false", "no")
    print(f"Starting OpenEnv API server on {args.host}:{args.port}")
    print(f"Health check: GET http://localhost:{args.port}/health")
    print(f"Reset env: POST http://localhost:{args.port}/reset")
    print(f"Step env: POST http://localhost:{args.port}/step")
    print(f"Get state: GET http://localhost:{args.port}/state")
    print()
    app.run(host=args.host, port=args.port, debug=debug_enabled)