# Ayush-Singh's picture
# server fixes
# 9d69b2d
"""
FastAPI server for StructuralDesignEnv (OpenEnv HTTP interface).
Endpoints:
GET /health
GET / and GET /demo (serve demo.html)
GET /tasks
POST /reset
POST /step
GET /state?session_id=...
GET /action_schema
"""
from __future__ import annotations
import uuid
from typing import Any, Dict, Optional
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import HTMLResponse
import os
from pydantic import BaseModel
from structural_design_env.env import StructuralDesignEnv
from structural_design_env.tasks import TASK_REGISTRY
# ---------------------------------------------------------------------------
# Session manager
# ---------------------------------------------------------------------------
class SessionManager:
    """In-memory registry that maps session ids to StructuralDesignEnv instances."""

    def __init__(self):
        # session id -> environment instance; lives only as long as the process.
        self._sessions: Dict[str, StructuralDesignEnv] = {}

    def create(self, session_id: Optional[str] = None) -> tuple[str, StructuralDesignEnv]:
        """Store a brand-new environment and return ``(session_id, env)``.

        A random UUID4 string is used when ``session_id`` is falsy. An
        environment already stored under the same id is replaced.
        """
        key = session_id if session_id else str(uuid.uuid4())
        fresh_env = StructuralDesignEnv()
        self._sessions[key] = fresh_env
        return key, fresh_env

    def get(self, session_id: str) -> Optional[StructuralDesignEnv]:
        """Return the environment stored under ``session_id``, or ``None``."""
        return self._sessions.get(session_id)

    def get_or_create(self, session_id: Optional[str]) -> tuple[str, StructuralDesignEnv]:
        """Return the existing ``(session_id, env)`` pair, creating one if absent."""
        if not session_id or session_id not in self._sessions:
            return self.create(session_id)
        return session_id, self._sessions[session_id]


# Module-level singleton shared by the request handlers.
session_manager = SessionManager()
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# Application object; the route decorators in this module register against it.
app = FastAPI(
    version="1.0.0",
    title="StructuralDesignEnv",
    description="OpenEnv API for steel frame structural engineering RL environment.",
)
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
# Payload accepted by POST /reset; every field is optional.
class ResetRequest(BaseModel):
    # Key into TASK_REGISTRY (checked by the /reset handler).
    task_id: str = "task1_warehouse"
    # Existing session to reuse; a new id is generated when omitted.
    session_id: Optional[str] = None
    # Forwarded to env.reset(seed=...).
    seed: Optional[int] = None
# Payload accepted by POST /step.
class StepRequest(BaseModel):
    # Session to step; the handler falls back to get_or_create when omitted.
    session_id: Optional[str] = None
    # JSON-encoded StructuralAction, passed to env.step() unchanged.
    message: str
# Body returned by POST /reset.
class ResetResponse(BaseModel):
    # Id of the session that was reset (generated if the client sent none).
    session_id: str
    # Value returned by env.reset().
    observation: dict
# Body returned by POST /step; mirrors the 4-tuple produced by env.step().
class StepResponse(BaseModel):
    # Id of the session that was stepped.
    session_id: str
    # First element of the env.step() result.
    observation: dict
    # Step reward (the handler rounds it to 6 decimal places).
    reward: float
    # Third element of the env.step() result.
    done: bool
    # Fourth element of the env.step() result.
    info: dict
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Liveness probe: return a static status payload."""
    payload = dict(status="ok", env="StructuralDesignEnv", version="1.0.0")
    return payload
@app.get("/", response_class=HTMLResponse)
@app.get("/demo", response_class=HTMLResponse)
def serve_demo():
    """Serve ``demo.html`` from this module's directory.

    Returns the file's text, or the plain message ``demo.html not found.``
    (still as an HTML response) when the file does not exist.
    """
    module_dir = os.path.dirname(__file__)
    html_file = os.path.join(module_dir, "demo.html")
    if not os.path.exists(html_file):
        return "demo.html not found."
    with open(html_file, "r", encoding="utf-8") as handle:
        return handle.read()
@app.get("/tasks")
def list_tasks():
    """Return a summary of every task registered in ``TASK_REGISTRY``.

    Each registry value is unpacked as a 2-tuple; only its first element
    (the task config) is read here.
    """
    summaries = [
        {
            "id": task_id,
            "name": config.name,
            "difficulty": config.difficulty,
            "max_steps": config.max_steps,
            "n_floors": config.n_floors,
            "site_width_m": config.site_width_m,
            "site_depth_m": config.site_depth_m,
            "description": _task_description(task_id),
        }
        for task_id, (config, _unused) in TASK_REGISTRY.items()
    ]
    return {"tasks": summaries}
@app.post("/reset", response_model=ResetResponse)
def reset_env(body: Optional[Dict[str, Any]] = Body(default=None)):
    """Reset (or create) a session's environment for the requested task.

    The body is optional; a missing or empty body falls back to the
    ``ResetRequest`` defaults.

    Raises:
        HTTPException(422): the body does not validate as a ``ResetRequest``.
        HTTPException(400): ``task_id`` is not a key of ``TASK_REGISTRY``.
    """
    # Local import keeps this handler self-contained; pydantic is already a
    # dependency of this module.
    from pydantic import ValidationError

    # The model is built by hand (not by FastAPI), so validation failures must
    # be mapped to a client error here; otherwise they surface as HTTP 500.
    try:
        req = ResetRequest(**(body or {}))
    except ValidationError as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    if req.task_id not in TASK_REGISTRY:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{req.task_id}'. Valid: {list(TASK_REGISTRY)}",
        )

    sid, env = session_manager.get_or_create(req.session_id)
    obs = env.reset(task_id=req.task_id, seed=req.seed)
    return ResetResponse(session_id=sid, observation=obs)
@app.post("/step", response_model=StepResponse)
def step_env(req: StepRequest):
    """Apply one action to a session's environment and return the transition.

    NOTE: the session is resolved with ``get_or_create``, so an unknown or
    missing ``session_id`` yields a freshly constructed environment that is
    stepped without an explicit ``reset`` call from this handler.
    """
    session_id, env = session_manager.get_or_create(req.session_id)
    transition = env.step(req.message)
    observation, raw_reward, finished, extra = transition
    return StepResponse(
        session_id=session_id,
        observation=observation,
        reward=round(raw_reward, 6),
        done=finished,
        info=extra,
    )
@app.get("/state")
def get_state(session_id: Optional[str] = None):
    """Return ``env.state()`` for an existing session.

    Raises:
        HTTPException(400): ``session_id`` is missing or empty.
        HTTPException(404): no session is stored under ``session_id``.
    """
    if not session_id:
        raise HTTPException(status_code=400, detail="session_id is required")
    found = session_manager.get(session_id)
    if found is not None:
        return found.state()
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
@app.get("/action_schema")
def action_schema():
    """Describe every action a client may send in the ``message`` field of /step.

    Returns a dict with a human-readable ``description``, the list of
    ``actions`` (each with its type, description, field help text and an
    example payload) and the ``sections`` catalogue of allowed profile names.
    """
    # Single source of truth for the section names, so the per-field help text
    # and the "sections" catalogue below cannot drift apart.
    column_sections = ["HEB140", "HEB160", "HEB200", "HEB240", "HEB300", "HEB360", "HEB400"]
    beam_sections = ["IPE200", "IPE240", "IPE300", "IPE360", "IPE400", "IPE450", "IPE500"]
    column_choices = ", ".join(column_sections)
    beam_choices = ", ".join(beam_sections)

    return {
        "description": "Actions for StructuralDesignEnv. Send as JSON string in the 'message' field of /step.",
        "actions": [
            {
                "action_type": "place_column",
                "description": "Place a steel column at a grid position on a floor.",
                "fields": {
                    "grid_x": "int [0-19] — horizontal grid position (1 cell = 1m)",
                    "grid_y": "int [0-19] — depth grid position",
                    "floor": "int [0..n_floors-1] — floor index (0=ground)",
                    "section": f"str — one of {column_choices}",
                },
                "example": {
                    "action_type": "place_column",
                    "grid_x": 5,
                    "grid_y": 0,
                    "floor": 0,
                    "section": "HEB200",
                },
            },
            {
                "action_type": "place_beam",
                "description": "Place a steel beam connecting two column nodes at the same floor.",
                "fields": {
                    "from_node_x": "int — x of start column",
                    "from_node_y": "int — y of start column",
                    "to_node_x": "int — x of end column (must share x OR y with start)",
                    "to_node_y": "int — y of end column",
                    "floor": "int — floor where both columns sit",
                    "section": f"str — one of {beam_choices}",
                    "orientation": "'x' (horizontal) or 'y' (depth direction)",
                },
                "example": {
                    "action_type": "place_beam",
                    "from_node_x": 0,
                    "from_node_y": 0,
                    "to_node_x": 5,
                    "to_node_y": 0,
                    "floor": 0,
                    "section": "IPE300",
                    "orientation": "x",
                },
            },
            {
                "action_type": "upgrade_section",
                "description": "Upgrade an element to the next larger standard section.",
                "fields": {
                    "element_id": "str — element ID from placed_elements list (e.g. 'col_5_0_0')",
                },
                "example": {"action_type": "upgrade_section", "element_id": "col_5_0_0"},
            },
            {
                "action_type": "downgrade_section",
                "description": "Downgrade an element to the next smaller standard section.",
                "fields": {
                    "element_id": "str — element ID",
                },
                "example": {"action_type": "downgrade_section", "element_id": "beam_0_0_5_0_0"},
            },
            {
                "action_type": "remove_element",
                "description": "Remove an element from the structure.",
                "fields": {"element_id": "str — element ID"},
                "example": {"action_type": "remove_element", "element_id": "col_5_0_0"},
            },
            {
                "action_type": "add_wall",
                "description": "Add a concrete shear wall between two column nodes.",
                "fields": {
                    "from_node_x": "int",
                    "from_node_y": "int",
                    "to_node_x": "int",
                    "to_node_y": "int",
                    "floor": "int",
                    "thickness_m": "float — 0.2 or 0.3",
                    "orientation": "'x' or 'y'",
                },
                "example": {
                    "action_type": "add_wall",
                    "from_node_x": 0,
                    "from_node_y": 0,
                    "to_node_x": 5,
                    "to_node_y": 0,
                    "floor": 0,
                    "thickness_m": 0.2,
                    "orientation": "x",
                },
            },
            {
                "action_type": "done",
                "description": "Signal that the design is complete. Triggers final grading.",
                "example": {"action_type": "done"},
            },
        ],
        "sections": {
            "columns": column_sections,
            "beams": beam_sections,
        },
    }
def _task_description(tid: str) -> str:
descriptions = {
"task1_warehouse": "Single-story 20Γ—10m warehouse. No lateral loads. Score by validity + material efficiency.",
"task2_office": "3-story 20Γ—20m office with wind and light seismic. Score by drift + efficiency + torsional balance.",
"task3_hospital": "3-story hospital in seismic Zone 3. Score by seismic drift + budget + redundancy + utilisation.",
}
return descriptions.get(tid, "")