"""ContextFlow OpenEnv - Simple API Server.

Exposes ContextFlowEnvironment over HTTP. ``POST /reset`` creates a fresh
environment and registers it under the episode id of its initial observation;
``POST /step`` and ``GET /state/{episode_id}`` look the environment up by that
id. Error responses use the body ``{"error": <message>}`` with a non-2xx status.
"""

from typing import Optional, List, Dict, Any

import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from models import Observation, Action, Reward, State, StepResult, TaskDifficulty, ActionType
from server.contextflow_environment import ContextFlowEnvironment

app = FastAPI(title="ContextFlow OpenEnv")

# Live environments keyed by episode id. An entry is removed only when a step
# reports ``done``; episodes a client abandons earlier stay here for the life
# of the process.
# NOTE(review): consider an eviction/TTL policy if clients may abandon episodes.
# All handlers below are ``async def`` and never ``await``, so each one runs to
# completion on the event loop and access to this dict is not interleaved.
environments: Dict[str, ContextFlowEnvironment] = {}


def _error(status_code: int, message: str) -> JSONResponse:
    """Build the API's ``{"error": message}`` body with a real HTTP status code."""
    return JSONResponse(status_code=status_code, content={"error": message})


class ResetResponse(BaseModel):
    """Response body of ``POST /reset``."""

    observation: dict  # the initial Observation, serialized via model_dump()
    episode_id: str  # id to pass to /step and /state


class StepRequest(BaseModel):
    """Request body of ``POST /step``; the action fields are forwarded to ``Action``."""

    action_type: str  # must be a valid ActionType value
    predicted_confusion: Optional[float] = None
    intervention_type: Optional[str] = None
    intervention_intensity: Optional[float] = None
    episode_id: Optional[str] = None  # required in practice; validated in step()


@app.get("/")
async def root():
    """Identify the service and its version."""
    return {"message": "ContextFlow OpenEnv Environment", "version": "1.0.0"}


# A second GET "/" handler named ``read_root`` used to be registered here. It
# was unreachable (the first matching route wins) and it overwrote root's
# OpenAPI entry. The name is kept as an alias for any code that imports it.
read_root = root


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}


@app.post("/reset", response_model=ResetResponse)
async def reset(difficulty: str = "medium"):
    """Start a new episode and register its environment.

    ``difficulty`` (a query parameter) is matched case-insensitively against
    TaskDifficulty values; an unrecognized value deliberately falls back to
    ``TaskDifficulty.MEDIUM`` instead of failing the request.
    """
    try:
        difficulty_enum = TaskDifficulty(difficulty.lower())
    except ValueError:
        difficulty_enum = TaskDifficulty.MEDIUM

    env = ContextFlowEnvironment(task_difficulty=difficulty_enum)
    observation = env.reset()
    env_id = observation.episode_id
    environments[env_id] = env
    return ResetResponse(
        observation=observation.model_dump(),
        episode_id=env_id,
    )


@app.post("/step")
async def step(request: StepRequest):
    """Apply one action to an episode and return the StepResult as a dict.

    Responses:
        200: ``result.model_dump()``; the episode is unregistered once
            ``result.done`` is true.
        404: ``{"error": "Invalid or missing episode_id"}`` when the id is
            absent or not a live episode.
        422: ``{"error": "Invalid action: ..."}`` when the action fields fail
            validation.
    """
    env = environments.get(request.episode_id) if request.episode_id else None
    if env is None:
        return _error(404, "Invalid or missing episode_id")

    try:
        action = Action(
            action_type=ActionType(request.action_type),
            predicted_confusion=request.predicted_confusion,
            intervention_type=request.intervention_type,
            intervention_intensity=request.intervention_intensity,
        )
    except ValueError as exc:
        # ActionType(...) raises ValueError for unknown values, and pydantic's
        # ValidationError subclasses ValueError; left uncaught this became a 500.
        return _error(422, f"Invalid action: {exc}")

    result = env.step(action)
    if result.done:
        environments.pop(request.episode_id, None)
    return result.model_dump()


@app.get("/state/{episode_id}")
async def get_state(episode_id: str):
    """Return the current State of a live episode, or a 404 error body."""
    env = environments.get(episode_id)
    if env is None:
        return _error(404, "Episode not found")
    return env.get_state().model_dump()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)