import asyncio
from typing import Optional

from fastapi import FastAPI

from env.environment import PropagationShieldEnvironment
from env.models import Action
|
|
|
# Title shown in the generated OpenAPI docs for this server.
_SERVER_TITLE = "PropagationShield Environment Server"
app = FastAPI(title=_SERVER_TITLE)

# One environment instance, created at import time and shared by the
# /reset and /step handlers below (every request operates on this object).
env = PropagationShieldEnvironment()
|
|
|
@app.post("/reset")
async def reset_endpoint(config: Optional[dict] = None):
    """Reset the shared environment and return its initial observation and state.

    Args:
        config: Optional JSON request body forwarded to ``env.reset``.
            An omitted or ``null`` body is treated as an empty dict.

    Returns:
        A dict with ``"observation"`` and ``"state"``, holding the
        ``model_dump()`` of the two objects returned by ``env.reset``.
    """
    # Build a fresh dict per call rather than sharing a mutable default, so a
    # mutation of ``config`` inside env.reset cannot leak into later calls.
    if config is None:
        config = {}
    # NOTE(review): env.reset runs synchronously on the event-loop thread, so a
    # slow reset blocks other requests (unlike /step, which uses a worker
    # thread). Offloading it would need a lock to keep resets from overlapping.
    obs, state = env.reset(config)
    return {"observation": obs.model_dump(), "state": state.model_dump()}
|
|
|
# Maximum number of seconds /step waits for env.step before giving up.
STEP_TIMEOUT_SECONDS = 30.0


@app.post("/step")
async def step_endpoint(action: Action):
    """Advance the shared environment by one action.

    ``env.step`` runs in a worker thread via ``asyncio.to_thread`` so the event
    loop stays responsive, and the wait is capped at ``STEP_TIMEOUT_SECONDS``.

    Args:
        action: The ``Action`` request body, passed unchanged to ``env.step``.

    Returns:
        On success, a dict with ``"observation"`` (``obs.model_dump()``),
        ``"reward"``, ``"done"`` and ``"info"`` taken from ``env.step``.
        On timeout, a dict with the same keys: ``"observation": None``,
        ``"reward": {"total": 0.0}``, ``"done": True`` and
        ``"info": {"error": "timeout"}``.
    """
    try:
        obs, rewards, done, info = await asyncio.wait_for(
            asyncio.to_thread(env.step, action),
            timeout=STEP_TIMEOUT_SECONDS,
        )
    except asyncio.TimeoutError:
        # NOTE(review): wait_for cannot stop the worker thread; env.step keeps
        # running (and mutating the shared env) after this response is sent.
        # "observation" is included so the payload has the same keys as the
        # success response and clients can index it unconditionally.
        return {
            "observation": None,
            "reward": {"total": 0.0},
            "done": True,
            "info": {"error": "timeout"},
        }
    return {
        "observation": obs.model_dump(),
        "reward": rewards,
        "done": done,
        "info": info,
    }