# Spaces: Sleeping
# File size: 4,092 Bytes
# Revision: a2896bf
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import asyncio
import json
from typing import Optional
from models import Observation, Action, Reward, State, StepResult, TaskDifficulty
from server.contextflow_environment import ContextFlowEnvironment
# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI(title="ContextFlow OpenEnv")
# Open WebSocket connections, keyed by the episode id taken from the
# /ws/{episode_id} URL. Entries are removed when the socket handler exits.
connections: dict[str, WebSocket] = {}
# Live environments held in this process's memory, keyed by episode id.
# Entries are added on reset and deleted once a step reports `done`.
environments: dict[str, ContextFlowEnvironment] = {}
@app.get("/")
async def root():
    """Return a small banner identifying the service and its version."""
    banner = dict(
        message="ContextFlow OpenEnv Environment",
        version="1.0.0",
    )
    return banner
@app.get("/health")
async def health():
    """Liveness probe: always reports a healthy status."""
    payload = dict(status="healthy")
    return payload
@app.post("/reset")
async def reset(difficulty: Optional[str] = "medium"):
    """Start a new episode and return its first observation.

    Args:
        difficulty: Name of the task difficulty, matched case-insensitively
            against ``TaskDifficulty``. ``None`` or any unrecognised value
            falls back to ``TaskDifficulty.MEDIUM``.

    Returns:
        A dict with the serialised initial ``observation`` and the
        ``episode_id`` under which the new environment was stored.
    """
    try:
        # The signature permits None; coalesce it to "" so it takes the same
        # ValueError fallback as any other unknown name instead of raising
        # AttributeError on `.lower()`.
        difficulty_enum = TaskDifficulty((difficulty or "").lower())
    except ValueError:
        difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()

    # The id reported by the first observation is the lookup key for all
    # later /step and /state calls.
    env_id = observation.episode_id
    environments[env_id] = env

    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }
@app.post("/step")
async def step(action: Action):
    """Apply one action to the environment named by ``action.episode_id``.

    Responds with HTTP 400 when the id is empty or not registered. When the
    step result reports ``done``, the environment is dropped from the
    registry before the serialised result is returned.
    """
    episode_id = action.episode_id
    is_known = bool(episode_id) and episode_id in environments
    if not is_known:
        error_body = {"error": "Invalid or missing episode_id"}
        return JSONResponse(status_code=400, content=error_body)

    outcome = environments[episode_id].step(action)

    # A finished episode can no longer be stepped, so release it now.
    if outcome.done:
        del environments[episode_id]

    return outcome.model_dump()
@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the serialised state of a registered episode, or HTTP 404."""
    # Only the registry lookup sits inside the try, so a KeyError raised
    # while computing the state is not mistaken for a missing episode.
    try:
        current_env = environments[episode_id]
    except KeyError:
        return JSONResponse(
            content={"error": "Episode not found"},
            status_code=404,
        )

    snapshot = current_env.state()
    return snapshot.model_dump()
def _coerce_difficulty(raw) -> TaskDifficulty:
    """Map a client-supplied difficulty to a TaskDifficulty, defaulting to MEDIUM."""
    # Try the value exactly as sent first so inputs that were already accepted
    # keep resolving the same way, then a lowercased form to mirror the
    # case-insensitive handling of the HTTP /reset endpoint.
    for candidate in (raw, str(raw).lower()):
        try:
            return TaskDifficulty(candidate)
        except ValueError:
            continue
    return TaskDifficulty.MEDIUM


@app.websocket("/ws/{episode_id}")
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Drive an existing episode over a WebSocket.

    The episode must already be registered; otherwise a single
    ``{"error": "Episode not found"}`` message is sent and the socket is
    closed. After that, each incoming text frame must be a JSON object whose
    ``"type"`` selects the operation:

    * ``"reset"`` - replace the episode's environment (optional
      ``"difficulty"``; unknown values fall back to MEDIUM) and reply with
      ``{"type": "reset", "observation": ...}``.
    * ``"step"`` - build an ``Action`` from ``message["action"]``, apply it,
      and reply with ``{"type": "step", "result": ...}``. A finished episode
      is removed from the registry.
    * ``"state"`` - reply with ``{"type": "state", "state": ...}``.

    Malformed frames and unknown types are answered with an ``{"error": ...}``
    message and the connection stays open.

    Args:
        websocket: The connection being served.
        episode_id: Episode id taken from the URL path.
    """
    await websocket.accept()

    if episode_id not in environments:
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    # Register only once the episode is known to exist, so the early return
    # above cannot leave a closed socket behind in `connections`.
    connections[episode_id] = websocket
    try:
        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json(
                    {"error": "Message must be a JSON object"}
                )
                continue

            message_type = message.get("type")

            if message_type == "reset":
                difficulty = _coerce_difficulty(
                    message.get("difficulty", "medium")
                )
                env = ContextFlowEnvironment(task_difficulty=difficulty)
                observation = env.reset()
                environments[episode_id] = env
                await websocket.send_json({
                    "type": "reset",
                    "observation": observation.model_dump(),
                })

            elif message_type == "step":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                # Only payload decoding is guarded here; errors raised by the
                # environment itself are deliberately left to propagate.
                try:
                    action = Action(**message["action"])
                except (KeyError, TypeError, ValueError):
                    # KeyError: no "action" key. TypeError: payload is not a
                    # mapping. ValueError: model validation failed.
                    await websocket.send_json(
                        {"error": "Invalid or missing action"}
                    )
                    continue
                result = env.step(action)
                if result.done:
                    environments.pop(episode_id, None)
                await websocket.send_json({
                    "type": "step",
                    "result": result.model_dump(),
                })

            elif message_type == "state":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                await websocket.send_json({
                    "type": "state",
                    "state": env.state().model_dump(),
                })

            else:
                await websocket.send_json({"error": "Unknown message type"})
    except WebSocketDisconnect:
        # The client going away is the normal way for this loop to end.
        pass
    finally:
        # Remove the entry only if it still refers to this socket; a newer
        # connection for the same episode may have replaced it.
        if connections.get(episode_id) is websocket:
            del connections[episode_id]
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|