File size: 1,052 Bytes
dada368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import asyncio
from typing import Optional

from fastapi import FastAPI

from env.models import Action
from env.environment import PropagationShieldEnvironment

# FastAPI application object; the route decorators below register their
# handlers on it.
app = FastAPI(title="PropagationShield Environment Server")
# Single module-level environment instance shared by every request handler in
# this file (both /reset and /step operate on this same object).
env = PropagationShieldEnvironment()

@app.post("/reset")
async def reset_endpoint(config: Optional[dict] = None):
    """Reset the shared environment and return its initial observation and state.

    Args:
        config: Optional reset configuration taken from the JSON request body
            and forwarded to ``env.reset``. A missing (or null) value is
            treated as an empty configuration.

    Returns:
        A dict with the keys ``"observation"`` and ``"state"``, each holding
        the ``model_dump()`` output of the corresponding object returned by
        ``env.reset``.
    """
    # Build a fresh dict per call instead of sharing one mutable default
    # object between every invocation that omits ``config``.
    reset_config = {} if config is None else config
    # env.reset is a synchronous call; run it in a worker thread (as the /step
    # handler does for env.step) so it does not block the event loop.
    obs, state = await asyncio.to_thread(env.reset, reset_config)
    # We return the dict version of the Pydantic models for the JSON response
    return {"observation": obs.model_dump(), "state": state.model_dump()}

@app.post("/step")
async def step_endpoint(action: Action):
    """Apply one action to the shared environment and report the outcome.

    ``env.step`` is executed in a worker thread and awaited for at most
    30 seconds.

    Args:
        action: The action parsed from the request body, passed unchanged to
            ``env.step``.

    Returns:
        On success, a dict with ``"observation"`` (the ``model_dump()`` of the
        returned observation), ``"reward"``, ``"done"`` and ``"info"`` exactly
        as produced by ``env.step``. If the wait times out, a fallback dict
        with a zero total reward, ``"done": True`` and a timeout error marker;
        that fallback carries no ``"observation"`` key.
    """
    # Timeout wrapper as required by Section 8.5.
    # NOTE(review): a worker thread cannot be interrupted, so after a timeout
    # env.step presumably keeps running in the background -- confirm this is
    # acceptable for the shared environment.
    pending_step = asyncio.to_thread(env.step, action)
    try:
        outcome = await asyncio.wait_for(pending_step, timeout=30.0)
    except asyncio.TimeoutError:
        return {"reward": {"total": 0.0}, "done": True, "info": {"error": "timeout"}}

    observation, reward, finished, details = outcome
    return {
        "observation": observation.model_dump(),
        "reward": reward,
        "done": finished,
        "info": details,
    }