SayedZahur786 commited on
Commit
3c855d7
·
1 Parent(s): 15a5d1d

fix: revert inference format and make /reset payload optional

Browse files
Files changed (3) hide show
  1. inference.py +7 -21
  2. src/server/app.py +3 -1
  3. tests/test_inference.py +13 -19
inference.py CHANGED
@@ -36,33 +36,19 @@ TASK_MAX_STEPS: dict[str, int] = {
36
  # ---------------------------------------------------------------------------
37
 
38
  def log_start(task: str, env: str, model: str):
39
- print(json.dumps({
40
- "type": "START",
41
- "task": task,
42
- "env": env,
43
- "model": model
44
- }), flush=True)
45
 
46
 
47
  def log_step(step: int, action, reward: float, done: bool, error=None):
48
- print(json.dumps({
49
- "type": "STEP",
50
- "step": step,
51
- "action": str(action),
52
- "reward": reward,
53
- "done": done,
54
- "error": str(error) if error else None
55
- }), flush=True)
56
 
57
 
58
  def log_end(success: bool, steps: int, score: float, rewards: list):
59
- print(json.dumps({
60
- "type": "END",
61
- "success": success,
62
- "steps": steps,
63
- "score": score,
64
- "rewards": rewards
65
- }), flush=True)
66
 
67
 
68
  # ---------------------------------------------------------------------------
 
36
  # ---------------------------------------------------------------------------
37
 
38
  def log_start(task: str, env: str, model: str):
39
+ print(f"[START] task={task} env={env} model={model}", flush=True)
 
 
 
 
 
40
 
41
 
42
  def log_step(step: int, action, reward: float, done: bool, error=None):
43
+ done_str = "true" if done else "false"
44
+ err_str = "null" if error is None else str(error)
45
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_str} error={err_str}", flush=True)
 
 
 
 
 
46
 
47
 
48
  def log_end(success: bool, steps: int, score: float, rewards: list):
49
+ success_str = "true" if success else "false"
50
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
51
+ print(f"[END] success={success_str} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
 
 
 
 
52
 
53
 
54
  # ---------------------------------------------------------------------------
src/server/app.py CHANGED
@@ -157,7 +157,9 @@ async def list_tasks() -> list[dict[str, str]]:
157
 
158
 
159
  @app.post("/reset")
160
- async def reset(request: ResetRequest) -> dict[str, Any]:
 
 
161
  global _env
162
  _env = OpenEnvEnvironment(task_id=request.task_id, seed=request.seed)
163
  obs = await _env.reset()
 
157
 
158
 
159
  @app.post("/reset")
160
+ async def reset(request: ResetRequest | None = None) -> dict[str, Any]:
161
+ if request is None:
162
+ request = ResetRequest()
163
  global _env
164
  _env = OpenEnvEnvironment(task_id=request.task_id, seed=request.seed)
165
  obs = await _env.reset()
tests/test_inference.py CHANGED
@@ -36,13 +36,10 @@ class TestInferenceFormatCompliance:
36
  assert returncode == 0, f"inference.py failed: {stderr}"
37
  tasks_run = []
38
  for line in stdout.split("\n"):
39
- if '"type": "START"' in line:
40
- try:
41
- import json
42
- d = json.loads(line)
43
- tasks_run.append(d.get("task"))
44
- except:
45
- pass
46
  assert tasks_run == self.TASK_IDS
47
 
48
  def test_start_line_format(self) -> None:
@@ -53,13 +50,10 @@ class TestInferenceFormatCompliance:
53
  "USE_RANDOM": "true",
54
  }
55
  _, stdout, _ = self._run_inference_capture(env)
 
56
  for line in stdout.split("\n"):
57
- if '"type": "START"' in line:
58
- import json
59
- d = json.loads(line)
60
- assert d.get("task") in self.TASK_IDS
61
- assert d.get("env") == "citywide-dispatch-supervisor"
62
- assert d.get("model") == "test-model"
63
 
64
  def test_step_line_error_format(self) -> None:
65
  env = {
@@ -69,12 +63,13 @@ class TestInferenceFormatCompliance:
69
  "USE_RANDOM": "true",
70
  }
71
  _, stdout, _ = self._run_inference_capture(env)
72
- valid_errors = {None, "max_steps_exceeded", "illegal_transition", "step_error"}
73
  for line in stdout.split("\n"):
74
- if '"type": "STEP"' in line:
75
- import json
76
- d = json.loads(line)
77
- assert d.get("error") in valid_errors or isinstance(d.get("error"), str)
 
78
 
79
 
80
  class TestEnvVarValidation:
@@ -84,7 +79,6 @@ class TestEnvVarValidation:
84
  merged_env.update(env)
85
 
86
  # Ensure tests are not affected by host environment variables.
87
- # If the test doesn't provide a required var, explicitly remove it.
88
  if "API_BASE_URL" not in env:
89
  merged_env.pop("API_BASE_URL", None)
90
  if "MODEL_NAME" not in env:
 
36
  assert returncode == 0, f"inference.py failed: {stderr}"
37
  tasks_run = []
38
  for line in stdout.split("\n"):
39
+ if line.startswith("[START]"):
40
+ match = re.match(r"\[START\] task=(\S+) env=(\S+) model=(\S+)", line)
41
+ assert match
42
+ tasks_run.append(match.group(1))
 
 
 
43
  assert tasks_run == self.TASK_IDS
44
 
45
  def test_start_line_format(self) -> None:
 
50
  "USE_RANDOM": "true",
51
  }
52
  _, stdout, _ = self._run_inference_capture(env)
53
+ pattern = r"\[START\] task=\S+ env=citywide-dispatch-supervisor model=\S+"
54
  for line in stdout.split("\n"):
55
+ if line.startswith("[START]"):
56
+ assert re.match(pattern, line)
 
 
 
 
57
 
58
  def test_step_line_error_format(self) -> None:
59
  env = {
 
63
  "USE_RANDOM": "true",
64
  }
65
  _, stdout, _ = self._run_inference_capture(env)
66
+ valid_errors = {"null", "max_steps_exceeded", "illegal_transition", "step_error"}
67
  for line in stdout.split("\n"):
68
+ if not line.startswith("[STEP]"):
69
+ continue
70
+ match = re.match(r"\[STEP\].+ error=(.+)", line)
71
+ assert match
72
+ assert match.group(1) in valid_errors
73
 
74
 
75
  class TestEnvVarValidation:
 
79
  merged_env.update(env)
80
 
81
  # Ensure tests are not affected by host environment variables.
 
82
  if "API_BASE_URL" not in env:
83
  merged_env.pop("API_BASE_URL", None)
84
  if "MODEL_NAME" not in env: