M ShreeRaj committed on
Commit
fa903fe
·
unverified ·
1 Parent(s): 55309da

Refactor main.py for OpenEnv integration

Browse files

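Refactor main.py to wrap the CLM environment in an OpenEnv-compatible Environment class (CLMEnvWrapper) and build the FastAPI app through OpenEnv's HTTPEnvServer, replacing the hand-written /, /reset, /step and /state handlers and the in-memory session store.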

Files changed (1) hide show
  1. backend/main.py +164 -101
backend/main.py CHANGED
@@ -1,108 +1,171 @@
1
- import uuid
2
- from typing import Dict, Any, Optional
3
-
4
- from fastapi import FastAPI, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from pydantic import BaseModel
7
-
8
- import sys
9
  import os
 
 
 
10
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
 
12
- from models import Action, Observation, generate_tasks, deterministic_grader, CLMEnvironment
 
 
13
 
14
- app = FastAPI(title="Cognitive Load Manager (CLM) Environment API")
 
 
 
 
 
 
 
15
 
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"],
19
- allow_credentials=True,
20
- allow_methods=["*"],
21
- allow_headers=["*"],
22
  )
23
 
24
- # In-memory session store
25
- sessions: Dict[str, CLMEnvironment] = {}
26
-
27
- class ResetRequest(BaseModel):
28
- task_id: str = "easy" # easy, medium, hard
29
- session_id: Optional[str] = None
30
-
31
- class ResetResponse(BaseModel):
32
- session_id: str
33
- observation: Observation
34
-
35
- class StepRequest(BaseModel):
36
- session_id: str = "default"
37
- action: Optional[Action] = None
38
-
39
- class StepResponse(BaseModel):
40
- observation: Observation
41
- reward: float
42
- done: bool
43
- info: Dict[str, Any]
44
-
45
- @app.get("/")
46
- def read_root():
47
- routes = []
48
- for route in app.routes:
49
- route_info = {"path": route.path, "name": getattr(route, "name", "")}
50
- if hasattr(route, "methods"):
51
- route_info["methods"] = list(route.methods)
52
- routes.append(route_info)
53
- return {
54
- "message": "Cognitive Load Manager is running 🚀",
55
- "routes": routes
56
- }
57
-
58
- @app.post("/reset", response_model=ResetResponse)
59
- def reset_env(req: Optional[ResetRequest] = None):
60
- if req is None:
61
- req = ResetRequest()
62
-
63
- if req.task_id not in ["easy", "medium", "hard"]:
64
- raise HTTPException(status_code=400, detail="Invalid task_id")
65
-
66
- tasks = generate_tasks(req.task_id)
67
- env = CLMEnvironment(tasks=tasks, max_steps=50) # Max 50 steps
68
- obs = env.reset()
69
-
70
- sess_id = req.session_id or str(uuid.uuid4())
71
- sessions[sess_id] = env
72
-
73
- return ResetResponse(session_id=sess_id, observation=obs)
74
-
75
- @app.post("/step", response_model=StepResponse)
76
- def step_env(req: Optional[StepRequest] = None):
77
- if req is None:
78
- req = StepRequest()
79
- if req.action is None:
80
- req.action = Action(type="work")
81
-
82
- if req.session_id not in sessions:
83
- tasks = generate_tasks("easy")
84
- env = CLMEnvironment(tasks=tasks, max_steps=50)
85
- env.reset()
86
- sessions[req.session_id] = env
87
-
88
- env = sessions[req.session_id]
89
- obs, reward, done, info = env.step(req.action)
90
-
91
- if done:
92
- score = deterministic_grader(env.state.tasks, env.state.time_step, env.state.energy)
93
- info["final_score"] = score
94
-
95
- return StepResponse(observation=obs, reward=reward, done=done, info=info)
96
-
97
- @app.get("/state")
98
- def get_state(session_id: Optional[str] = "default"):
99
- if session_id is None:
100
- session_id = "default"
101
-
102
- if session_id not in sessions:
103
- tasks = generate_tasks("easy")
104
- env = CLMEnvironment(tasks=tasks, max_steps=50)
105
- env.reset()
106
- sessions[session_id] = env
107
-
108
- return sessions[session_id].state_dict()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
3
+ from typing import Any, Dict, List, Optional
4
+
5
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
 
7
+ from fastapi import FastAPI
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import Field
10
 
11
+ from openenv.core.env_server.interfaces import Environment
12
+ from openenv.core.env_server.types import (
13
+ Action as OEAction,
14
+ Observation as OEObservation,
15
+ State as OEState,
16
+ EnvironmentMetadata,
17
+ )
18
+ from openenv.core.env_server.http_server import HTTPEnvServer
19
 
20
+ from models import (
21
+ Action as ModelAction,
22
+ Observation as ModelObservation,
23
+ generate_tasks,
24
+ deterministic_grader,
25
+ CLMEnvironment,
26
  )
27
 
28
+
29
+ # ── OpenEnv-compatible Action / Observation / State models ──────────────────
30
+
31
+ class CLMAction(OEAction):
32
+ """Action for the Cognitive Load Manager environment."""
33
+ type: str = Field(description="Action type: work, break, switch, or delay")
34
+ task_id: Optional[str] = Field(default=None, description="Task ID to act on")
35
+
36
+ model_config = {"extra": "allow"}
37
+
38
+
39
+ class CLMObservation(OEObservation):
40
+ """Observation from the Cognitive Load Manager environment."""
41
+ tasks: List[Dict[str, Any]] = Field(default_factory=list)
42
+ visible_state: Dict[str, Any] = Field(default_factory=dict)
43
+ time_step: int = Field(default=0)
44
+
45
+ model_config = {"extra": "allow"}
46
+
47
+
48
+ class CLMState(OEState):
49
+ """State for the Cognitive Load Manager environment."""
50
+ energy: float = Field(default=1.0)
51
+ stress: float = Field(default=0.0)
52
+ fatigue: float = Field(default=0.0)
53
+ current_task_id: Optional[str] = Field(default=None)
54
+ tasks: List[Dict[str, Any]] = Field(default_factory=list)
55
+
56
+ model_config = {"extra": "allow"}
57
+
58
+
59
+ # ── OpenEnv Environment wrapper ─────────────────────────────────────────────
60
+
61
+ class CLMEnvWrapper(Environment):
62
+ """
63
+ Cognitive Load Manager wrapped as an OpenEnv-compliant environment.
64
+
65
+ Three difficulty levels via the task_id reset parameter:
66
+ - easy: 2 tasks, no deadlines
67
+ - medium: 5 tasks with deadlines
68
+ - hard: 8 tasks with tight deadlines
69
+ """
70
+
71
+ SUPPORTS_CONCURRENT_SESSIONS = True
72
+
73
+ def __init__(self):
74
+ super().__init__()
75
+ level = os.getenv("CLM_LEVEL", "easy")
76
+ tasks = generate_tasks(level)
77
+ self._env = CLMEnvironment(tasks=tasks, max_steps=50)
78
+ self._final_score: float = 0.0
79
+
80
+ def _to_oe_obs(self, obs: ModelObservation, done: bool = False, reward: Optional[float] = None, info: Optional[dict] = None) -> CLMObservation:
81
+ return CLMObservation(
82
+ tasks=[t.model_dump() for t in obs.tasks],
83
+ visible_state=obs.visible_state.model_dump(),
84
+ time_step=obs.time_step,
85
+ done=done,
86
+ reward=reward,
87
+ metadata=info or {},
88
+ )
89
+
90
+ def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: str = "easy", **kwargs) -> CLMObservation:
91
+ if task_id not in ("easy", "medium", "hard"):
92
+ task_id = "easy"
93
+ tasks = generate_tasks(task_id)
94
+ self._env = CLMEnvironment(tasks=tasks, max_steps=50)
95
+ self._final_score = 0.0
96
+ obs = self._env.reset()
97
+ return self._to_oe_obs(obs)
98
+
99
+ def step(self, action: CLMAction, timeout_s: Optional[float] = None, **kwargs) -> CLMObservation:
100
+ model_action = ModelAction(type=action.type, task_id=action.task_id)
101
+ obs, reward, done, info = self._env.step(model_action)
102
+ if done:
103
+ self._final_score = deterministic_grader(
104
+ self._env.state.tasks,
105
+ self._env.state.time_step,
106
+ self._env.state.energy,
107
+ )
108
+ info["final_score"] = self._final_score
109
+ return self._to_oe_obs(obs, done=done, reward=float(reward), info=info)
110
+
111
+ @property
112
+ def state(self) -> CLMState:
113
+ raw = self._env.state_dict()
114
+ return CLMState(
115
+ energy=raw.get("energy", 1.0),
116
+ stress=raw.get("stress", 0.0),
117
+ fatigue=raw.get("fatigue", 0.0),
118
+ current_task_id=raw.get("current_task_id"),
119
+ tasks=raw.get("tasks", []),
120
+ step_count=raw.get("time_step", 0),
121
+ )
122
+
123
+ def get_metadata(self) -> EnvironmentMetadata:
124
+ return EnvironmentMetadata(
125
+ name="cognitive-load-manager",
126
+ description=(
127
+ "Cognitive Load Manager (CLM) simulates human cognitive load "
128
+ "(energy, stress, fatigue) while managing tasks with deadlines. "
129
+ "Three difficulty levels: easy (2 tasks, no deadlines), "
130
+ "medium (5 tasks with deadlines), hard (8 tasks with tight deadlines)."
131
+ ),
132
+ version="1.0.0",
133
+ author="Team Innovators",
134
+ )
135
+
136
+ def close(self) -> None:
137
+ pass
138
+
139
+
140
+ # ── Build FastAPI app via OpenEnv HTTPEnvServer ──────────────────────────────
141
+
142
+ def build_app() -> FastAPI:
143
+ server = HTTPEnvServer(
144
+ env=CLMEnvWrapper,
145
+ action_cls=CLMAction,
146
+ observation_cls=CLMObservation,
147
+ max_concurrent_envs=10,
148
+ )
149
+
150
+ _app = FastAPI(
151
+ title="Cognitive Load Manager (CLM) Environment API",
152
+ version="1.0.0",
153
+ description=(
154
+ "OpenEnv-compliant environment for the Meta PyTorch Hackathon. "
155
+ "Simulates cognitive load management with three difficulty levels."
156
+ ),
157
+ )
158
+
159
+ _app.add_middleware(
160
+ CORSMiddleware,
161
+ allow_origins=["*"],
162
+ allow_credentials=True,
163
+ allow_methods=["*"],
164
+ allow_headers=["*"],
165
+ )
166
+
167
+ server.register_routes(_app)
168
+ return _app
169
+
170
+
171
+ app = build_app()