jayantaggarwal-sketch commited on
Commit
2f3c64b
·
1 Parent(s): 0194e2e

Fix reset endpoint to honor task_id query parameter.

Browse files
Files changed (1) hide show
  1. server/app.py +28 -1
server/app.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import os
6
 
7
  from openenv.core.env_server import create_fastapi_app
 
8
 
9
  from constants import PROJECT_DESCRIPTION, VERSION
10
  from models import CommitmentAction, CommitmentObservation, CommitmentState
@@ -26,10 +27,36 @@ app.version = VERSION
26
 
27
  app.routes[:] = [
28
  r for r in app.routes
29
- if not (hasattr(r, "path") and r.path in ("/state", "/mcp"))
30
  ]
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  @app.get("/state", response_model=CommitmentState)
34
  def get_state() -> CommitmentState:
35
  return _shared_env.state
 
5
  import os
6
 
7
  from openenv.core.env_server import create_fastapi_app
8
+ from fastapi import Query
9
 
10
  from constants import PROJECT_DESCRIPTION, VERSION
11
  from models import CommitmentAction, CommitmentObservation, CommitmentState
 
27
 
28
  app.routes[:] = [
29
  r for r in app.routes
30
+ if not (hasattr(r, "path") and r.path in ("/state", "/mcp", "/reset"))
31
  ]
32
 
33
 
34
+ @app.post("/reset")
35
+ def reset_episode(
36
+ task_id: str | None = Query(default=None),
37
+ difficulty: str | None = Query(default=None),
38
+ seed: int | None = Query(default=None),
39
+ episode_id: str | None = Query(default=None),
40
+ ) -> dict[str, object]:
41
+ """Reset endpoint with explicit query-param support.
42
+
43
+ The default OpenEnv route did not reliably propagate ``task_id`` from
44
+ query params in this deployment setup, which made scenario selection
45
+ non-deterministic for demos/evaluations.
46
+ """
47
+ obs = _shared_env.reset(
48
+ seed=seed,
49
+ episode_id=episode_id,
50
+ task_id=task_id,
51
+ difficulty=difficulty,
52
+ )
53
+ return {
54
+ "observation": obs.model_dump(),
55
+ "reward": float(obs.reward),
56
+ "done": bool(obs.done),
57
+ }
58
+
59
+
60
  @app.get("/state", response_model=CommitmentState)
61
  def get_state() -> CommitmentState:
62
  return _shared_env.state