soumi guria committed on
Commit
1b81250
·
1 Parent(s): 68a1f40

Updated main.py and the inference file so that they are compatible with both the plain FastAPI approach and the OpenEnv-compatible HTTP server

Browse files
Files changed (2) hide show
  1. backend/main.py +341 -80
  2. inference.py +3 -4
backend/main.py CHANGED
@@ -1,13 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import sys
3
- from typing import Any, Dict, List, Optional
4
 
 
 
 
 
 
5
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
 
7
- from fastapi import FastAPI
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import Field
 
 
 
 
 
10
 
 
11
  from openenv.core.env_server.interfaces import Environment
12
  from openenv.core.env_server.types import (
13
  Action as OEAction,
@@ -17,67 +203,163 @@ from openenv.core.env_server.types import (
17
  )
18
  from openenv.core.env_server.http_server import HTTPEnvServer
19
 
20
- from models import (
21
- Action as ModelAction,
22
- Observation as ModelObservation,
23
- generate_tasks,
24
- deterministic_grader,
25
- CLMEnvironment,
 
 
 
 
 
 
 
 
 
 
26
  )
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # ── OpenEnv-compatible Action / Observation / State models ──────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  class CLMAction(OEAction):
32
- """Action for the Cognitive Load Manager environment."""
33
- type: str = Field(description="Action type: work, break, switch, or delay")
34
- task_id: Optional[str] = Field(default=None, description="Task ID to act on")
35
 
36
  model_config = {"extra": "allow"}
37
 
38
 
39
  class CLMObservation(OEObservation):
40
- """Observation from the Cognitive Load Manager environment."""
41
  tasks: List[Dict[str, Any]] = Field(default_factory=list)
42
  visible_state: Dict[str, Any] = Field(default_factory=dict)
43
- time_step: int = Field(default=0)
44
 
45
  model_config = {"extra": "allow"}
46
 
47
 
48
  class CLMState(OEState):
49
- """State for the Cognitive Load Manager environment."""
50
- energy: float = Field(default=1.0)
51
- stress: float = Field(default=0.0)
52
- fatigue: float = Field(default=0.0)
53
- current_task_id: Optional[str] = Field(default=None)
54
  tasks: List[Dict[str, Any]] = Field(default_factory=list)
55
 
56
  model_config = {"extra": "allow"}
57
 
58
 
59
- # ── OpenEnv Environment wrapper ─────────────────────────────────────────────
60
-
61
  class CLMEnvWrapper(Environment):
62
- """
63
- Cognitive Load Manager wrapped as an OpenEnv-compliant environment.
64
-
65
- Three difficulty levels via the task_id reset parameter:
66
- - easy: 2 tasks, no deadlines
67
- - medium: 5 tasks with deadlines
68
- - hard: 8 tasks with tight deadlines
69
- """
70
 
71
  SUPPORTS_CONCURRENT_SESSIONS = True
72
 
73
  def __init__(self):
74
  super().__init__()
75
- level = os.getenv("CLM_LEVEL", "easy")
76
- tasks = generate_tasks(level)
77
  self._env = CLMEnvironment(tasks=tasks, max_steps=50)
78
- self._final_score: float = 0.0
79
 
80
- def _to_oe_obs(self, obs: ModelObservation, done: bool = False, reward: Optional[float] = None, info: Optional[dict] = None) -> CLMObservation:
81
  return CLMObservation(
82
  tasks=[t.model_dump() for t in obs.tasks],
83
  visible_state=obs.visible_state.model_dump(),
@@ -87,18 +369,21 @@ class CLMEnvWrapper(Environment):
87
  metadata=info or {},
88
  )
89
 
90
- def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: str = "easy", **kwargs) -> CLMObservation:
91
  if task_id not in ("easy", "medium", "hard"):
92
  task_id = "easy"
 
93
  tasks = generate_tasks(task_id)
94
  self._env = CLMEnvironment(tasks=tasks, max_steps=50)
95
- self._final_score = 0.0
96
  obs = self._env.reset()
97
- return self._to_oe_obs(obs)
 
 
 
98
 
99
- def step(self, action: CLMAction, timeout_s: Optional[float] = None, **kwargs) -> CLMObservation:
100
- model_action = ModelAction(type=action.type, task_id=action.task_id)
101
  obs, reward, done, info = self._env.step(model_action)
 
102
  if done:
103
  self._final_score = deterministic_grader(
104
  self._env.state.tasks,
@@ -106,10 +391,11 @@ class CLMEnvWrapper(Environment):
106
  self._env.state.energy,
107
  )
108
  info["final_score"] = self._final_score
109
- return self._to_oe_obs(obs, done=done, reward=float(reward), info=info)
 
110
 
111
  @property
112
- def state(self) -> CLMState:
113
  raw = self._env.state_dict()
114
  return CLMState(
115
  energy=raw.get("energy", 1.0),
@@ -120,52 +406,27 @@ class CLMEnvWrapper(Environment):
120
  step_count=raw.get("time_step", 0),
121
  )
122
 
123
- def get_metadata(self) -> EnvironmentMetadata:
124
  return EnvironmentMetadata(
125
  name="cognitive-load-manager",
126
- description=(
127
- "Cognitive Load Manager (CLM) simulates human cognitive load "
128
- "(energy, stress, fatigue) while managing tasks with deadlines. "
129
- "Three difficulty levels: easy (2 tasks, no deadlines), "
130
- "medium (5 tasks with deadlines), hard (8 tasks with tight deadlines)."
131
- ),
132
  version="1.0.0",
133
  author="Team Innovators",
134
  )
135
 
136
- def close(self) -> None:
137
  pass
138
 
139
 
140
- # ── Build FastAPI app via OpenEnv HTTPEnvServer ──────────────────────────────
141
-
142
- def build_app() -> FastAPI:
143
- server = HTTPEnvServer(
144
- env=CLMEnvWrapper,
145
- action_cls=CLMAction,
146
- observation_cls=CLMObservation,
147
- max_concurrent_envs=10,
148
- )
149
-
150
- _app = FastAPI(
151
- title="Cognitive Load Manager (CLM) Environment API",
152
- version="1.0.0",
153
- description=(
154
- "OpenEnv-compliant environment for the Meta PyTorch Hackathon. "
155
- "Simulates cognitive load management with three difficulty levels."
156
- ),
157
- )
158
-
159
- _app.add_middleware(
160
- CORSMiddleware,
161
- allow_origins=["*"],
162
- allow_credentials=True,
163
- allow_methods=["*"],
164
- allow_headers=["*"],
165
- )
166
-
167
- server.register_routes(_app)
168
- return _app
169
 
 
 
 
 
 
 
170
 
171
- app = build_app()
 
1
+ # import os
2
+ # import sys
3
+ # from typing import Any, Dict, List, Optional
4
+
5
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6
+
7
+ # from fastapi import FastAPI
8
+ # from fastapi.middleware.cors import CORSMiddleware
9
+ # from pydantic import Field
10
+
11
+ # from openenv.core.env_server.interfaces import Environment
12
+ # from openenv.core.env_server.types import (
13
+ # Action as OEAction,
14
+ # Observation as OEObservation,
15
+ # State as OEState,
16
+ # EnvironmentMetadata,
17
+ # )
18
+ # from openenv.core.env_server.http_server import HTTPEnvServer
19
+
20
+ # from models import (
21
+ # Action as ModelAction,
22
+ # Observation as ModelObservation,
23
+ # generate_tasks,
24
+ # deterministic_grader,
25
+ # CLMEnvironment,
26
+ # )
27
+
28
+
29
+ # # ── OpenEnv-compatible Action / Observation / State models ──────────────────
30
+
31
+ # class CLMAction(OEAction):
32
+ # """Action for the Cognitive Load Manager environment."""
33
+ # type: str = Field(description="Action type: work, break, switch, or delay")
34
+ # task_id: Optional[str] = Field(default=None, description="Task ID to act on")
35
+
36
+ # model_config = {"extra": "allow"}
37
+
38
+
39
+ # class CLMObservation(OEObservation):
40
+ # """Observation from the Cognitive Load Manager environment."""
41
+ # tasks: List[Dict[str, Any]] = Field(default_factory=list)
42
+ # visible_state: Dict[str, Any] = Field(default_factory=dict)
43
+ # time_step: int = Field(default=0)
44
+
45
+ # model_config = {"extra": "allow"}
46
+
47
+
48
+ # class CLMState(OEState):
49
+ # """State for the Cognitive Load Manager environment."""
50
+ # energy: float = Field(default=1.0)
51
+ # stress: float = Field(default=0.0)
52
+ # fatigue: float = Field(default=0.0)
53
+ # current_task_id: Optional[str] = Field(default=None)
54
+ # tasks: List[Dict[str, Any]] = Field(default_factory=list)
55
+
56
+ # model_config = {"extra": "allow"}
57
+
58
+
59
+ # # ── OpenEnv Environment wrapper ─────────────────────────────────────────────
60
+
61
+ # class CLMEnvWrapper(Environment):
62
+ # """
63
+ # Cognitive Load Manager wrapped as an OpenEnv-compliant environment.
64
+
65
+ # Three difficulty levels via the task_id reset parameter:
66
+ # - easy: 2 tasks, no deadlines
67
+ # - medium: 5 tasks with deadlines
68
+ # - hard: 8 tasks with tight deadlines
69
+ # """
70
+
71
+ # SUPPORTS_CONCURRENT_SESSIONS = True
72
+
73
+ # def __init__(self):
74
+ # super().__init__()
75
+ # level = os.getenv("CLM_LEVEL", "easy")
76
+ # tasks = generate_tasks(level)
77
+ # self._env = CLMEnvironment(tasks=tasks, max_steps=50)
78
+ # self._final_score: float = 0.0
79
+
80
+ # def _to_oe_obs(self, obs: ModelObservation, done: bool = False, reward: Optional[float] = None, info: Optional[dict] = None) -> CLMObservation:
81
+ # return CLMObservation(
82
+ # tasks=[t.model_dump() for t in obs.tasks],
83
+ # visible_state=obs.visible_state.model_dump(),
84
+ # time_step=obs.time_step,
85
+ # done=done,
86
+ # reward=reward,
87
+ # metadata=info or {},
88
+ # )
89
+
90
+ # def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: str = "easy", **kwargs) -> CLMObservation:
91
+ # if task_id not in ("easy", "medium", "hard"):
92
+ # task_id = "easy"
93
+ # tasks = generate_tasks(task_id)
94
+ # self._env = CLMEnvironment(tasks=tasks, max_steps=50)
95
+ # self._final_score = 0.0
96
+ # obs = self._env.reset()
97
+ # return self._to_oe_obs(obs)
98
+
99
+ # def step(self, action: CLMAction, timeout_s: Optional[float] = None, **kwargs) -> CLMObservation:
100
+ # model_action = ModelAction(type=action.type, task_id=action.task_id)
101
+ # obs, reward, done, info = self._env.step(model_action)
102
+ # if done:
103
+ # self._final_score = deterministic_grader(
104
+ # self._env.state.tasks,
105
+ # self._env.state.time_step,
106
+ # self._env.state.energy,
107
+ # )
108
+ # info["final_score"] = self._final_score
109
+ # return self._to_oe_obs(obs, done=done, reward=float(reward), info=info)
110
+
111
+ # @property
112
+ # def state(self) -> CLMState:
113
+ # raw = self._env.state_dict()
114
+ # return CLMState(
115
+ # energy=raw.get("energy", 1.0),
116
+ # stress=raw.get("stress", 0.0),
117
+ # fatigue=raw.get("fatigue", 0.0),
118
+ # current_task_id=raw.get("current_task_id"),
119
+ # tasks=raw.get("tasks", []),
120
+ # step_count=raw.get("time_step", 0),
121
+ # )
122
+
123
+ # def get_metadata(self) -> EnvironmentMetadata:
124
+ # return EnvironmentMetadata(
125
+ # name="cognitive-load-manager",
126
+ # description=(
127
+ # "Cognitive Load Manager (CLM) simulates human cognitive load "
128
+ # "(energy, stress, fatigue) while managing tasks with deadlines. "
129
+ # "Three difficulty levels: easy (2 tasks, no deadlines), "
130
+ # "medium (5 tasks with deadlines), hard (8 tasks with tight deadlines)."
131
+ # ),
132
+ # version="1.0.0",
133
+ # author="Team Innovators",
134
+ # )
135
+
136
+ # def close(self) -> None:
137
+ # pass
138
+
139
+
140
+ # # ── Build FastAPI app via OpenEnv HTTPEnvServer ──────────────────────────────
141
+
142
+ # def build_app() -> FastAPI:
143
+ # server = HTTPEnvServer(
144
+ # env=CLMEnvWrapper,
145
+ # action_cls=CLMAction,
146
+ # observation_cls=CLMObservation,
147
+ # max_concurrent_envs=10,
148
+ # )
149
+
150
+ # _app = FastAPI(
151
+ # title="Cognitive Load Manager (CLM) Environment API",
152
+ # version="1.0.0",
153
+ # description=(
154
+ # "OpenEnv-compliant environment for the Meta PyTorch Hackathon. "
155
+ # "Simulates cognitive load management with three difficulty levels."
156
+ # ),
157
+ # )
158
+
159
+ # _app.add_middleware(
160
+ # CORSMiddleware,
161
+ # allow_origins=["*"],
162
+ # allow_credentials=True,
163
+ # allow_methods=["*"],
164
+ # allow_headers=["*"],
165
+ # )
166
+
167
+ # server.register_routes(_app)
168
+ # return _app
169
+
170
+
171
+ # app = build_app()
172
+
173
+
174
+
175
+ import uuid
176
  import os
177
  import sys
178
+ from typing import Dict, Any, Optional, List
179
 
180
+ from fastapi import FastAPI, HTTPException
181
+ from fastapi.middleware.cors import CORSMiddleware
182
+ from pydantic import BaseModel, Field
183
+
184
+ # Fix imports
185
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
186
 
187
+ # ── Your core models ─────────────────────────────────────────────────────────
188
+ from models import (
189
+ Action,
190
+ Observation,
191
+ generate_tasks,
192
+ deterministic_grader,
193
+ CLMEnvironment,
194
+ )
195
 
196
+ # ── OpenEnv imports ──────────────────────────────────────────────────────────
197
  from openenv.core.env_server.interfaces import Environment
198
  from openenv.core.env_server.types import (
199
  Action as OEAction,
 
203
  )
204
  from openenv.core.env_server.http_server import HTTPEnvServer
205
 
206
+
207
+ # =============================================================================
208
+ # ── PART 1: SIMPLE FASTAPI API (Your Original API) ────────────────────────────
209
+ # =============================================================================
210
+
211
+ app = FastAPI(
212
+ title="Cognitive Load Manager (CLM) Environment API",
213
+ version="1.0.0"
214
+ )
215
+
216
+ app.add_middleware(
217
+ CORSMiddleware,
218
+ allow_origins=["*"],
219
+ allow_credentials=True,
220
+ allow_methods=["*"],
221
+ allow_headers=["*"],
222
  )
223
 
224
+ # In-memory session store
225
+ sessions: Dict[str, CLMEnvironment] = {}
226
+
227
+
228
+ # ── Request / Response Models ────────────────────────────────────────────────
229
+
230
+ class ResetRequest(BaseModel):
231
+ level: str = "easy"
232
+ task_id: str = "easy"
233
+ session_id: Optional[str] = None
234
+
235
+
236
+ class ResetResponse(BaseModel):
237
+ session_id: str
238
+ observation: Any
239
+
240
+
241
+ class StepRequest(BaseModel):
242
+ session_id: str = "default"
243
+ action: Optional[Action] = None
244
+
245
+
246
+ class StepResponse(BaseModel):
247
+ observation: Any
248
+ reward: float
249
+ done: bool
250
+ info: Dict[str, Any]
251
+
252
+
253
+ # ── Routes ──────────────────────────────────────────────────────────────────
254
+
255
+ @app.post("/reset", response_model=ResetResponse)
256
+ def reset_env(req: Optional[ResetRequest] = None):
257
+ if req is None:
258
+ req = ResetRequest()
259
+
260
+ if req.level not in ["easy", "medium", "hard"]:
261
+ raise HTTPException(status_code=400, detail="Invalid level")
262
+
263
+ if req.task_id not in ["easy", "medium", "hard"]:
264
+ raise HTTPException(status_code=400, detail="Invalid task_id")
265
+
266
+ # NOTE: both `level` and `task_id` are validated above, but only `task_id` selects the task set
267
+ tasks = generate_tasks(req.task_id)
268
+
269
+ env = CLMEnvironment(tasks=tasks, max_steps=50)
270
+ obs = env.reset()
271
+
272
+ sess_id = req.session_id or str(uuid.uuid4())
273
+ sessions[sess_id] = env
274
 
275
+ return ResetResponse(session_id=sess_id, observation=obs)
276
+
277
+
278
+ @app.post("/step", response_model=StepResponse)
279
+ def step_env(req: Optional[StepRequest] = None):
280
+ if req is None:
281
+ req = StepRequest()
282
+
283
+ if req.action is None:
284
+ req.action = Action(type="work")
285
+
286
+ if req.session_id not in sessions:
287
+ tasks = generate_tasks("easy")
288
+ env = CLMEnvironment(tasks=tasks, max_steps=50)
289
+ env.reset()
290
+ sessions[req.session_id] = env
291
+
292
+ env = sessions[req.session_id]
293
+
294
+ obs, reward, done, info = env.step(req.action)
295
+
296
+ if done:
297
+ score = deterministic_grader(
298
+ env.state.tasks,
299
+ env.state.time_step,
300
+ env.state.energy
301
+ )
302
+ info["final_score"] = score
303
+
304
+ return StepResponse(
305
+ observation=obs,
306
+ reward=reward,
307
+ done=done,
308
+ info=info
309
+ )
310
+
311
+
312
+ @app.get("/state")
313
+ def get_state(session_id: Optional[str] = "default"):
314
+ if session_id not in sessions:
315
+ tasks = generate_tasks("easy")
316
+ env = CLMEnvironment(tasks=tasks, max_steps=50)
317
+ env.reset()
318
+ sessions[session_id] = env
319
+
320
+ return sessions[session_id].state_dict()
321
+
322
+
323
+ # =============================================================================
324
+ # ── PART 2: OPENENV COMPATIBLE WRAPPER ───────────────────────────────────────
325
+ # =============================================================================
326
 
327
  class CLMAction(OEAction):
328
+ type: str = Field(description="work, break, switch, delay")
329
+ task_id: Optional[str] = None
 
330
 
331
  model_config = {"extra": "allow"}
332
 
333
 
334
  class CLMObservation(OEObservation):
 
335
  tasks: List[Dict[str, Any]] = Field(default_factory=list)
336
  visible_state: Dict[str, Any] = Field(default_factory=dict)
337
+ time_step: int = 0
338
 
339
  model_config = {"extra": "allow"}
340
 
341
 
342
  class CLMState(OEState):
343
+ energy: float = 1.0
344
+ stress: float = 0.0
345
+ fatigue: float = 0.0
346
+ current_task_id: Optional[str] = None
 
347
  tasks: List[Dict[str, Any]] = Field(default_factory=list)
348
 
349
  model_config = {"extra": "allow"}
350
 
351
 
 
 
352
  class CLMEnvWrapper(Environment):
 
 
 
 
 
 
 
 
353
 
354
  SUPPORTS_CONCURRENT_SESSIONS = True
355
 
356
  def __init__(self):
357
  super().__init__()
358
+ tasks = generate_tasks("easy")
 
359
  self._env = CLMEnvironment(tasks=tasks, max_steps=50)
360
+ self._final_score = 0.0
361
 
362
+ def _to_obs(self, obs: Observation, done=False, reward=None, info=None):
363
  return CLMObservation(
364
  tasks=[t.model_dump() for t in obs.tasks],
365
  visible_state=obs.visible_state.model_dump(),
 
369
  metadata=info or {},
370
  )
371
 
372
+ def reset(self, task_id: str = "easy", **kwargs):
373
  if task_id not in ("easy", "medium", "hard"):
374
  task_id = "easy"
375
+
376
  tasks = generate_tasks(task_id)
377
  self._env = CLMEnvironment(tasks=tasks, max_steps=50)
378
+
379
  obs = self._env.reset()
380
+ return self._to_obs(obs)
381
+
382
+ def step(self, action: CLMAction, **kwargs):
383
+ model_action = Action(type=action.type, task_id=action.task_id)
384
 
 
 
385
  obs, reward, done, info = self._env.step(model_action)
386
+
387
  if done:
388
  self._final_score = deterministic_grader(
389
  self._env.state.tasks,
 
391
  self._env.state.energy,
392
  )
393
  info["final_score"] = self._final_score
394
+
395
+ return self._to_obs(obs, done=done, reward=float(reward), info=info)
396
 
397
  @property
398
+ def state(self):
399
  raw = self._env.state_dict()
400
  return CLMState(
401
  energy=raw.get("energy", 1.0),
 
406
  step_count=raw.get("time_step", 0),
407
  )
408
 
409
+ def get_metadata(self):
410
  return EnvironmentMetadata(
411
  name="cognitive-load-manager",
412
+ description="CLM environment with cognitive load simulation",
 
 
 
 
 
413
  version="1.0.0",
414
  author="Team Innovators",
415
  )
416
 
417
+ def close(self):
418
  pass
419
 
420
 
421
+ # =============================================================================
422
+ # ── PART 3: REGISTER OPENENV ROUTES ──────────────────────────────────────────
423
+ # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
+ server = HTTPEnvServer(
426
+ env=CLMEnvWrapper,
427
+ action_cls=CLMAction,
428
+ observation_cls=CLMObservation,
429
+ max_concurrent_envs=10,
430
+ )
431
 
432
+ server.register_routes(app)
inference.py CHANGED
@@ -29,11 +29,10 @@ def post_json(url: str, payload: dict) -> dict:
29
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
 
31
  # ── Environment variables ────────────────────────────────────────────────────
32
- API_BASE_URL = os.getenv("API_BASE_URL")
33
- if not API_BASE_URL:
34
- raise ValueError("API_BASE_URL environment variable is required")
35
 
36
- API_KEY = os.getenv("API_KEY")
37
  if not API_KEY:
38
  raise ValueError("API_KEY environment variable is required")
39
 
 
29
  raise Exception(f"HTTP Error {e.code}: {e.read().decode('utf-8')}")
30
 
31
  # ── Environment variables ────────────────────────────────────────────────────
32
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
33
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
34
 
35
+ API_KEY = HF_TOKEN or os.getenv("API_KEY")
36
  if not API_KEY:
37
  raise ValueError("API_KEY environment variable is required")
38