File size: 4,092 Bytes
a2896bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
import asyncio
import json
from typing import Optional

from models import Observation, Action, Reward, State, StepResult, TaskDifficulty
from server.contextflow_environment import ContextFlowEnvironment


app = FastAPI(title="ContextFlow OpenEnv")

# Open WebSocket per episode id; a dict, so at most one socket is tracked per id.
connections: dict[str, WebSocket] = dict()
# In-memory environment registry keyed by episode id. Entries are added by the
# reset handlers and removed when a step result reports ``done``.
# NOTE(review): episodes that never reach ``done`` are never evicted.
environments: dict[str, ContextFlowEnvironment] = dict()


@app.get("/")
async def root():
    """Identify the service and report its version string."""
    payload = dict(message="ContextFlow OpenEnv Environment")
    payload["version"] = "1.0.0"
    return payload


@app.get("/health")
async def health():
    """Liveness probe: unconditionally reports a healthy status."""
    return dict(status="healthy")


@app.post("/reset")
async def reset(difficulty: Optional[str] = "medium"):
    """Start a new episode over HTTP and register its environment.

    Args:
        difficulty: Task difficulty name, matched case-insensitively against
            ``TaskDifficulty`` values. ``None``, an empty string, or an
            unrecognized value falls back to ``TaskDifficulty.MEDIUM``.

    Returns:
        A dict holding the serialized initial ``observation`` and the
        ``episode_id`` under which the environment was registered (the id to
        use with ``/step`` and ``/state/{episode_id}``).
    """
    try:
        # The parameter is declared Optional; without the ``or`` guard a None
        # value hit ``None.lower()`` and surfaced as an AttributeError (500).
        difficulty_enum = TaskDifficulty((difficulty or "medium").lower())
    except ValueError:
        difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()

    # The environment's own episode id becomes the registry key.
    env_id = observation.episode_id
    environments[env_id] = env

    return {
        "observation": observation.model_dump(),
        "episode_id": env_id,
    }


@app.post("/step")
async def step(action: Action):
    """Advance the episode named by ``action.episode_id`` by one step.

    Responds with HTTP 400 when the id is missing/empty or not registered.
    When the step result reports ``done`` the environment is removed from the
    registry, so later requests for that id are rejected.
    """
    requested_id = action.episode_id
    active_env = environments.get(requested_id) if requested_id else None
    if active_env is None:
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid or missing episode_id"},
        )

    outcome = active_env.step(action)
    if outcome.done:
        environments.pop(action.episode_id)
    return outcome.model_dump()


@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the serialized state of a registered episode, or HTTP 404."""
    active_env = environments.get(episode_id)
    if active_env is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Episode not found"},
        )
    return active_env.state().model_dump()


@app.websocket("/ws/{episode_id}")
async def websocket_endpoint(websocket: WebSocket, episode_id: str):
    """Drive an already-registered episode over a WebSocket with JSON frames.

    If ``episode_id`` is not in the registry, an error is sent and the socket
    is closed. Otherwise each incoming text frame must be a JSON object with a
    ``type`` field:

    * ``{"type": "reset", "difficulty": str}``: build a fresh environment,
      store it under this URL's ``episode_id`` and reply
      ``{"type": "reset", "observation": ...}``. A missing or unrecognized
      difficulty falls back to ``TaskDifficulty.MEDIUM`` (same as
      ``POST /reset``).
    * ``{"type": "step", "action": {...}}``: apply the action and reply
      ``{"type": "step", "result": ...}``. The environment is unregistered
      once the result reports ``done``.
    * ``{"type": "state"}``: reply ``{"type": "state", "state": ...}``.

    Malformed frames, invalid actions, unknown types and requests against an
    unregistered episode are answered with ``{"error": ...}``. The connection
    stays open in those cases.
    """
    await websocket.accept()

    if episode_id not in environments:
        # Checked before registering the socket: this early return skips the
        # ``finally`` cleanup below, so registering first leaked the entry.
        await websocket.send_json({"error": "Episode not found"})
        await websocket.close()
        return

    connections[episode_id] = websocket

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON"})
                continue
            if not isinstance(message, dict):
                await websocket.send_json(
                    {"error": "Message must be a JSON object"}
                )
                continue

            msg_type = message.get("type")

            if msg_type == "reset":
                difficulty = message.get("difficulty") or "medium"
                try:
                    difficulty_enum = TaskDifficulty(str(difficulty).lower())
                except ValueError:
                    difficulty_enum = TaskDifficulty.MEDIUM
                env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
                observation = env.reset()
                # NOTE(review): the new environment is keyed by the URL id, not
                # by ``observation.episode_id``; confirm this is intended.
                environments[episode_id] = env
                await websocket.send_json({
                    "type": "reset",
                    "observation": observation.model_dump(),
                })

            elif msg_type == "step":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                try:
                    action = Action(**message["action"])
                except (KeyError, TypeError, ValueError):
                    # Missing "action", a non-object payload, or model
                    # validation failure. This presumes Action is a pydantic
                    # model, whose ValidationError subclasses ValueError.
                    await websocket.send_json(
                        {"error": "Invalid or missing action"}
                    )
                    continue
                result = env.step(action)
                if result.done:
                    environments.pop(episode_id, None)
                await websocket.send_json({
                    "type": "step",
                    "result": result.model_dump(),
                })

            elif msg_type == "state":
                env = environments.get(episode_id)
                if env is None:
                    await websocket.send_json({"error": "Episode not found"})
                    continue
                await websocket.send_json({
                    "type": "state",
                    "state": env.state().model_dump(),
                })

            else:
                # Previously ignored silently, leaving the client waiting.
                await websocket.send_json(
                    {"error": f"Unknown message type: {msg_type!r}"}
                )

    except WebSocketDisconnect:
        pass
    finally:
        # Remove only our own registration: a newer socket for the same
        # episode id may have replaced it while this one was open.
        if connections.get(episode_id) is websocket:
            del connections[episode_id]


if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)