ziadbc's picture
feat: add friendly root endpoint with API info
df9de4c
# Copyright (c) 2024 TeeUnit Project
# SPDX-License-Identifier: MIT
"""
FastAPI application for the TeeUnit Environment.
This module creates an HTTP server that exposes the TeeEnvironment
over HTTP and WebSocket endpoints, compatible with MCPToolClient.
Usage:
# Development (with auto-reload):
uvicorn teeunit_env.server.app:app --reload --host 0.0.0.0 --port 8000
# Production:
uvicorn teeunit_env.server.app:app --host 0.0.0.0 --port 8000 --workers 4
# Or run directly:
python -m teeunit_env.server.app
"""
# Support both in-repo and standalone imports
try:
from openenv.core.env_server.http_server import create_app
from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
from .tee_environment import TeeEnvironment
# Create the app with web interface.
# Pass the class (factory) instead of an instance: per the original note this
# is what enables WebSocket session support in create_app.
# Use MCP types for action/observation since this is a pure MCP environment
# (every action is a CallToolAction, every observation a CallToolObservation).
app = create_app(
    TeeEnvironment, CallToolAction, CallToolObservation, env_name="teeunit_env"
)
# Add a friendly root endpoint
@app.get("/")
async def root():
    """Describe the running environment and point clients at its API.

    Returns a JSON-serialisable mapping with service metadata, the routes this
    deployment advertises, the MCP tool names, and sample WebSocket messages.
    """
    advertised_routes = {
        "health": "GET /health - Health check",
        "reset": "POST /reset - Reset environment",
        "step": "POST /step - Execute action",
        "state": "GET /state - Get current state",
        "metadata": "GET /metadata - Environment info",
        "schema": "GET /schema - Action/Observation schema",
        "websocket": "WS /ws - Stateful WebSocket session (recommended for RL)",
    }
    tool_names = ["move", "jump", "aim", "shoot", "hook", "get_status"]

    # Example WebSocket payloads: a reset, then a step wrapping an MCP tool call.
    step_payload = {"type": "call_tool", "tool_name": "get_status", "arguments": {}}
    examples = {
        "websocket": "Connect to wss://ziadbc-teeunit-env.hf.space/ws for stateful sessions",
        "reset": {"type": "reset", "data": {}},
        "step": {"type": "step", "data": step_payload},
    }

    info = {
        "name": "TeeUnit Environment",
        "description": "OpenEnv-compatible Teeworlds arena for LLM-based RL training",
        "version": "0.1.0",
        "status": "running",
        "docs": "/docs",
    }
    info["endpoints"] = advertised_routes
    info["tools"] = tool_names
    info["example"] = examples
    info["github"] = "https://github.com/ziadgit/teeunit"
    return info
except ImportError:
# Fallback: Create a simple FastAPI app for development/testing
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, Any
import json
# Import our environment
try:
from .tee_environment import TeeEnvironment
except ImportError:
from tee_environment import TeeEnvironment
app = FastAPI(
    title="TeeUnit OpenEnv",
    description="OpenEnv-compatible Teeworlds arena environment for LLM RL training",
    version="0.1.0",
)
# Permissive CORS so browser tools on any origin can call this dev/test server.
# Credentials are deliberately disabled: FastAPI's CORS documentation states that
# allow_origins=["*"] must not be combined with allow_credentials=True (with that
# combination Starlette echoes back any requesting origin for credentialed
# requests). Nothing in this module reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Environment instances keyed by episode/session id; written by POST /reset
# and read (or lazily populated) by POST /call_tool.
# NOTE(review): entries are never evicted, so this grows with every /reset.
# Consider a size cap or TTL if this fallback is used for long-running jobs.
_environments: Dict[str, TeeEnvironment] = {}
class ResetRequest(BaseModel):
    # Request body for POST /reset. Both fields are forwarded unchanged to
    # TeeEnvironment.reset(seed=..., episode_id=...).
    seed: Optional[int] = None  # optional seed passed through to env.reset
    episode_id: Optional[str] = None  # optional caller-chosen episode id
class ResetResponse(BaseModel):
    # Response body for POST /reset.
    status: str  # always "ready" as produced by the /reset handler
    episode_id: str  # key to pass as ?session_id= to POST /call_tool
    message: str  # env-provided message, or "Environment ready" if absent
class ToolCallRequest(BaseModel):
    # Request body for POST /call_tool: which MCP tool to run and its kwargs.
    tool_name: str  # e.g. one of the names listed by GET /tools
    # NOTE(review): relies on pydantic copying this mutable default per
    # instance; confirm for the pinned pydantic version.
    arguments: Dict[str, Any] = {}
class ToolCallResponse(BaseModel):
    # Response body for POST /call_tool.
    result: Any  # tool output text, or "Error: ..." when the call failed
    reward: float  # reward after the simulated tick (0.0 on error)
    done: bool  # episode finished (step budget exhausted or agent dead)
    # {"step", "tick"} on success, {"error"} on failure.
    metadata: Dict[str, Any] = {}
class ToolInfo(BaseModel):
    # Shape of a single tool descriptor (name/description/parameters), matching
    # the entries GET /tools returns. NOTE(review): not referenced by any
    # handler visible in this module; confirm whether it is still needed.
    name: str
    description: str
    parameters: Dict[str, Any] = {}
@app.get("/")
async def root():
    """Root endpoint with environment info.

    Returns service metadata plus the routes this fallback server exposes.
    The endpoint catalogue lists every route registered in this module,
    including ``GET /health``, which earlier versions omitted.
    """
    return {
        "name": "TeeUnit OpenEnv",
        "version": "0.1.0",
        "description": "OpenEnv-compatible Teeworlds arena for LLM training",
        "endpoints": {
            "health": "GET /health",
            "reset": "POST /reset",
            "tools": "GET /tools",
            "call_tool": "POST /call_tool",
            "websocket": "WS /ws",
        },
    }
@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports the service as healthy."""
    state = "healthy"
    return {"status": state}
@app.post("/reset", response_model=ResetResponse)
async def reset(request: ResetRequest):
    """Start a new episode in a brand-new environment and register it.

    The new environment is stored in ``_environments`` under the episode id
    reported in the observation metadata (``"default"`` if none is reported).
    Any environment previously stored under that id is replaced.
    """
    fresh_env = TeeEnvironment()
    observation = fresh_env.reset(seed=request.seed, episode_id=request.episode_id)
    meta = observation.metadata

    episode = meta.get("episode_id", "default")
    _environments[episode] = fresh_env

    ready_message = meta.get("message", "Environment ready")
    return ResetResponse(status="ready", episode_id=episode, message=ready_message)
@app.get("/tools")
async def list_tools():
    """Return the static catalogue of advertised MCP tools.

    Every entry has ``name``, ``description`` and a JSON-schema-like
    ``parameters`` mapping (empty for tools that take no arguments).
    """

    def spec(name, description, parameters=None):
        # One catalogue entry; builds a fresh empty dict for argument-less tools.
        return {
            "name": name,
            "description": description,
            "parameters": {} if parameters is None else parameters,
        }

    catalogue = [
        spec(
            "move",
            "Move the tee left, right, or none",
            {"direction": {"type": "string", "enum": ["left", "right", "none"]}},
        ),
        spec("jump", "Make the tee jump"),
        spec(
            "aim",
            "Aim at target coordinates",
            {
                "x": {"type": "integer", "description": "Target X coordinate"},
                "y": {"type": "integer", "description": "Target Y coordinate"},
            },
        ),
        spec(
            "shoot",
            "Fire the specified weapon",
            {
                "weapon": {
                    "type": "integer",
                    "description": "Weapon ID (0-5) or -1 for current",
                    "default": -1,
                },
            },
        ),
        spec("hook", "Use the grappling hook"),
        spec("get_status", "Get current game state as text"),
    ]
    return {"tools": catalogue}
@app.post("/call_tool", response_model=ToolCallResponse)
async def call_tool(request: ToolCallRequest, session_id: str = "default"):
    """Run one MCP tool on a session's environment, then advance one tick.

    Args:
        request: The tool name and its arguments.
        session_id: Key into ``_environments``. An unknown id lazily gets a
            newly created and reset environment.

    Returns:
        ``ToolCallResponse`` with the tool's text output, the reward after the
        simulated tick, and ``done`` set when the step budget is used up or
        the current agent is no longer alive. Any exception raised while
        calling the tool or simulating is reported as an ``"Error: ..."``
        result with ``reward=0.0`` and ``done=False``.
    """
    env = _environments.get(session_id)
    if env is None:
        env = TeeEnvironment()
        env.reset()
        _environments[session_id] = env

    server = env._mcp
    name, args = request.tool_name, request.arguments

    try:
        outcome = await server.call_tool(name, args)
        # Use the first content block's text when available, else stringify.
        if not (outcome and outcome.content):
            text = str(outcome)
        else:
            head = outcome.content[0]
            text = head.text if hasattr(head, "text") else str(head)

        env._simulate_tick()
        reward = env._calculate_reward()

        # Episode ends when the step budget runs out or the agent has died.
        finished = env._state.step_count >= env._max_steps
        agent = env._agents.get(env._current_agent_id)
        if agent and not agent.is_alive:
            finished = True

        return ToolCallResponse(
            result=text,
            reward=reward,
            done=finished,
            metadata={"step": env._state.step_count, "tick": env._tick},
        )
    except Exception as exc:
        detail = str(exc)
        return ToolCallResponse(
            result=f"Error: {detail}",
            reward=0.0,
            done=False,
            metadata={"error": detail},
        )
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time, stateful interaction.

    Each connection owns its own ``TeeEnvironment``. Every incoming text frame
    must be a JSON object with a ``type`` field (default ``"call_tool"``):

    * ``"reset"``: optional ``seed`` / ``episode_id``; replies ``type="reset"``.
    * ``"list_tools"``: replies ``type="tools"`` with the tool catalogue.
    * ``"call_tool"``: ``tool_name`` and ``arguments``; runs the tool, advances
      one tick and replies ``type="tool_result"``. ``done`` is set when the step
      budget is exhausted or the current agent has died, matching
      ``POST /call_tool``.

    Invalid JSON, non-object payloads and unknown message types are answered
    with ``type="error"``. They no longer tear down the connection (bad JSON)
    or leave the client waiting for a reply that never comes (unknown type).
    """
    await websocket.accept()
    env = TeeEnvironment()
    env.reset()
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError as exc:
                await websocket.send_json({"type": "error", "error": f"Invalid JSON: {exc}"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json(
                    {"type": "error", "error": "Message must be a JSON object"}
                )
                continue

            action_type = message.get("type", "call_tool")
            if action_type == "reset":
                obs = env.reset(
                    seed=message.get("seed"),
                    episode_id=message.get("episode_id"),
                )
                await websocket.send_json({
                    "type": "reset",
                    "status": "ready",
                    "episode_id": obs.metadata.get("episode_id"),
                    "message": obs.metadata.get("message"),
                })
            elif action_type == "list_tools":
                await websocket.send_json({
                    "type": "tools",
                    "tools": [
                        {"name": "move", "description": "Move left/right/none"},
                        {"name": "jump", "description": "Jump"},
                        {"name": "aim", "description": "Aim at x,y"},
                        {"name": "shoot", "description": "Fire weapon"},
                        {"name": "hook", "description": "Use hook"},
                        {"name": "get_status", "description": "Get game state"},
                    ],
                })
            elif action_type == "call_tool":
                tool_name = message.get("tool_name")
                arguments = message.get("arguments", {})
                # Call tool using FastMCP's async call_tool method
                mcp = env._mcp
                try:
                    tool_result = await mcp.call_tool(tool_name, arguments)
                    if tool_result and tool_result.content:
                        first = tool_result.content[0]
                        result = first.text if hasattr(first, "text") else str(first)
                    else:
                        result = str(tool_result)
                except Exception as e:
                    # Tool failures are reported to the client; the tick still advances.
                    result = f"Error: {str(e)}"
                # Simulate and get reward
                env._simulate_tick()
                reward = env._calculate_reward()
                done = env._state.step_count >= env._max_steps
                # Mirror POST /call_tool: the episode also ends on agent death.
                player = env._agents.get(env._current_agent_id)
                if player and not player.is_alive:
                    done = True
                await websocket.send_json({
                    "type": "tool_result",
                    "tool_name": tool_name,
                    "result": result,
                    "reward": reward,
                    "done": done,
                    "step": env._state.step_count,
                    "tick": env._tick,
                })
            else:
                await websocket.send_json(
                    {"type": "error", "error": f"Unknown message type: {action_type!r}"}
                )
    except WebSocketDisconnect:
        pass
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Entry point for direct execution.

    Serves ``app`` with uvicorn. ``host`` and ``port`` default to the values
    that were previously hard-coded, so ``main()`` with no arguments behaves
    exactly as before. Callers can now bind elsewhere, e.g.
    ``main(port=7860)``.

    Usage:
        python -m teeunit_env.server.app
        uvicorn teeunit_env.server.app:app --host 0.0.0.0 --port 8000
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)
# Start the server with default host/port when this module is executed
# directly (e.g. `python -m teeunit_env.server.app`), but not on import.
if __name__ == "__main__":
    main()