# pragunk's picture
# Upload 14 files
# dada368 verified
import asyncio
from typing import Optional

from fastapi import FastAPI

from env.environment import PropagationShieldEnvironment
from env.models import Action
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(title="PropagationShield Environment Server")
# Single module-level environment instance used by the endpoints in this module.
env = PropagationShieldEnvironment()
@app.post("/reset")
async def reset_endpoint(config: Optional[dict] = None):
    """Reset the shared environment and return its initial observation and state.

    Args:
        config: Optional reset configuration forwarded to ``env.reset``.
            When omitted (or ``None``), a fresh empty dict is used.

    Returns:
        A dict with ``"observation"`` and ``"state"`` keys, each holding the
        ``model_dump()`` of the corresponding object returned by ``env.reset``.
    """
    # Build a new dict per call: a ``{}`` default in the signature is a single
    # object shared by every call that omits ``config``, so any mutation made
    # downstream would leak into later resets.
    if config is None:
        config = {}
    obs, state = env.reset(config)
    # Return the dict version of the Pydantic models for the JSON response.
    return {"observation": obs.model_dump(), "state": state.model_dump()}
@app.post("/step")
async def step_endpoint(action: Action):
    """Apply one action to the shared environment, bounded by a 30-second timeout.

    ``env.step`` is dispatched through ``asyncio.to_thread`` and awaited via
    ``asyncio.wait_for``.

    Args:
        action: The action passed to ``env.step``.

    Returns:
        On success, a dict with ``"observation"`` (the ``model_dump()`` of the
        returned observation), ``"reward"``, ``"done"`` and ``"info"``.
        If the wait times out, a dict with a zero total reward, ``"done"`` set
        to ``True`` and ``"info"`` carrying ``{"error": "timeout"}``; this
        fallback has no ``"observation"`` key.
    """
    # Calling to_thread only creates the awaitable; it runs when awaited below.
    pending_step = asyncio.to_thread(env.step, action)
    try:
        # Timeout wrapper required by Section 8.5.
        step_result = await asyncio.wait_for(pending_step, timeout=30.0)
        observation, reward, finished, details = step_result
        response = {
            "observation": observation.model_dump(),
            "reward": reward,
            "done": finished,
            "info": details,
        }
        return response
    except asyncio.TimeoutError:
        timeout_response = {
            "reward": {"total": 0.0},
            "done": True,
            "info": {"error": "timeout"},
        }
        return timeout_response