File size: 2,500 Bytes
788411f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
ContextFlow OpenEnv - Simple API Server
"""

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List, Dict, Any
import uvicorn

from models import Observation, Action, Reward, State, StepResult, TaskDifficulty, ActionType
from server.contextflow_environment import ContextFlowEnvironment

# The ASGI application; started by uvicorn when this module runs as a script.
app = FastAPI(title="ContextFlow OpenEnv")

# Live environments keyed by episode id. /reset adds entries and /step removes
# an entry once its episode reports done; nothing in this module evicts
# episodes that are abandoned before finishing.
environments: Dict[str, ContextFlowEnvironment] = dict()


class ResetResponse(BaseModel):
    """Response body returned by ``POST /reset``."""

    # Initial observation, serialized via the observation's ``model_dump()``.
    observation: dict
    # Identifier the client passes back to /step and /state/{episode_id}.
    episode_id: str


class StepRequest(BaseModel):
    """Request body for ``POST /step``.

    Only ``action_type`` is required by this schema; every other field
    defaults to ``None``.
    """

    # Action name; the /step handler converts it with ``ActionType(...)``.
    action_type: str
    # Optional action parameters, forwarded as-is into ``Action``.
    predicted_confusion: Optional[float] = None
    intervention_type: Optional[str] = None
    intervention_intensity: Optional[float] = None
    # Episode to act on, as returned by /reset. Optional in the schema, but
    # /step answers with an error dict when it is missing or unknown.
    episode_id: Optional[str] = None


@app.get("/")
async def root():
    """Return a short banner naming the service and its version."""
    banner = {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}
    return banner


@app.get("/health")
async def health():
    """Health-check endpoint; unconditionally reports ``healthy``."""
    status = "healthy"
    return {"status": status}


@app.post("/reset", response_model=ResetResponse)
async def reset(difficulty: str = "medium"):
    """Start a new episode and register its environment.

    Args:
        difficulty: Task difficulty name. It is lower-cased and then looked
            up as a ``TaskDifficulty`` value; anything that does not match
            falls back to ``TaskDifficulty.MEDIUM`` instead of failing.

    Returns:
        ResetResponse holding the initial observation and the episode id
        that later /step and /state calls must use.
    """
    normalized = difficulty.lower()
    try:
        task_difficulty = TaskDifficulty(normalized)
    except ValueError:
        # Best-effort: unknown names degrade to the default difficulty.
        task_difficulty = TaskDifficulty.MEDIUM

    new_env = ContextFlowEnvironment(task_difficulty=task_difficulty)
    first_obs = new_env.reset()

    # The environment's own episode id doubles as the registry key.
    new_id = first_obs.episode_id
    environments[new_id] = new_env

    return ResetResponse(observation=first_obs.model_dump(), episode_id=new_id)


@app.post("/step")
async def step(request: StepRequest):
    """Apply one action to a live episode.

    Args:
        request: Action payload plus the ``episode_id`` returned by /reset.

    Returns:
        The dumped ``StepResult`` on success. On failure it returns an
        ``{"error": ...}`` dict, either for a missing or unknown episode or
        for an action that cannot be constructed. A bad action no longer
        surfaces as an unhandled 500.
    """
    if not request.episode_id or request.episode_id not in environments:
        return {"error": "Invalid or missing episode_id"}

    env = environments[request.episode_id]

    try:
        action = Action(
            action_type=ActionType(request.action_type),
            predicted_confusion=request.predicted_confusion,
            intervention_type=request.intervention_type,
            intervention_intensity=request.intervention_intensity,
        )
    except ValueError as exc:
        # ActionType(...) raises ValueError for an unknown name. If Action is
        # a pydantic model, its ValidationError also subclasses ValueError,
        # so rejected field values land here too.
        return {"error": f"Invalid action: {exc}"}

    result = env.step(action)

    # Finished episodes leave the registry; later /step or /state calls for
    # this id will report it as unknown.
    if result.done:
        del environments[request.episode_id]

    return result.model_dump()


@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the dumped state of a live episode, or an error dict if unknown."""
    try:
        env = environments[episode_id]
    except KeyError:
        return {"error": "Episode not found"}
    return env.get_state().model_dump()


async def read_root():
    """Backward-compatible alias for ``root``.

    This coroutine used to be registered as a second ``GET /`` handler. The
    router dispatches to the first matching route, so ``root`` (registered
    earlier) always served ``/``, and this handler was unreachable while
    advertising a different banner. It is no longer registered as a route.
    It is kept as a plain coroutine delegating to ``root``, so any direct
    importer gets the banner that is actually served.
    """
    return await root()


if __name__ == "__main__":
    # Listen on every interface (0.0.0.0, not just loopback) on port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)