vvinayakkk commited on
Commit
2e410e4
·
1 Parent(s): f8701a1

Harden API score fields to strict open interval

Browse files
clinical-trial-triage/server/app.py CHANGED
@@ -209,7 +209,16 @@ async def step(
209
  state.task_id,
210
  float(normalized),
211
  )
212
- return result.model_dump()
 
 
 
 
 
 
 
 
 
213
  except RuntimeError as exc:
214
  logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
215
  raise HTTPException(status_code=400, detail=str(exc))
@@ -220,12 +229,16 @@ async def step(
220
 
221
  @app.get("/state")
222
  async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
223
- env = get_or_create_session(_safe_session_id(x_session_id))
224
- try:
225
- s = env.state()
226
- return s.model_dump()
227
- except RuntimeError as exc:
228
- raise HTTPException(status_code=400, detail=str(exc))
 
 
 
 
229
 
230
 
231
  @app.get("/tasks")
@@ -318,11 +331,10 @@ async def grader(x_session_id: Optional[str] = Header(default="default")) -> Dic
318
  "episode_id": s.episode_id,
319
  "task_id": s.task_id,
320
  "done": s.done,
321
- "cumulative_reward": s.cumulative_reward,
322
  "step_count": s.step_count,
323
  "max_steps": s.max_steps,
324
  "normalized_score": normalized_score,
325
- "actions": s.actions_taken,
326
  }
327
  except RuntimeError as exc:
328
  raise HTTPException(status_code=400, detail=str(exc))
 
209
  state.task_id,
210
  float(normalized),
211
  )
212
+ payload = result.model_dump()
213
+ info = payload.get("info")
214
+ if isinstance(info, dict):
215
+ session_state = env.state()
216
+ info["cumulative_reward"] = _clamp_open_score(
217
+ session_state.cumulative_reward / session_state.step_count
218
+ if session_state.step_count > 0
219
+ else _SCORE_EPS
220
+ )
221
+ return payload
222
  except RuntimeError as exc:
223
  logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
224
  raise HTTPException(status_code=400, detail=str(exc))
 
229
 
230
  @app.get("/state")
231
  async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
232
+ env = get_or_create_session(_safe_session_id(x_session_id))
233
+ try:
234
+ s = env.state()
235
+ payload = s.model_dump()
236
+ payload["cumulative_reward"] = _clamp_open_score(
237
+ s.cumulative_reward / s.step_count if s.step_count > 0 else _SCORE_EPS
238
+ )
239
+ return payload
240
+ except RuntimeError as exc:
241
+ raise HTTPException(status_code=400, detail=str(exc))
242
 
243
 
244
  @app.get("/tasks")
 
331
  "episode_id": s.episode_id,
332
  "task_id": s.task_id,
333
  "done": s.done,
334
+ "cumulative_reward": normalized_score,
335
  "step_count": s.step_count,
336
  "max_steps": s.max_steps,
337
  "normalized_score": normalized_score,
 
338
  }
339
  except RuntimeError as exc:
340
  raise HTTPException(status_code=400, detail=str(exc))
clinical-trial-triage/server/openenv_env.py CHANGED
@@ -30,6 +30,13 @@ from models import (
30
  from server.environment import ClinicalTrialEnvironment
31
 
32
 
 
 
 
 
 
 
 
33
  class OpenEnvTriageAction(Action):
34
  """OpenEnv action wrapper for the clinical triage tasks."""
35
 
@@ -177,13 +184,16 @@ class ClinicalTrialOpenEnv(
177
  @property
178
  def state(self) -> OpenEnvTriageState:
179
  state = self._core.state()
 
 
 
180
  return OpenEnvTriageState(
181
  episode_id=state.episode_id,
182
  step_count=state.step_count,
183
  task_id=TaskID(state.task_id),
184
  max_steps=state.max_steps,
185
  done=state.done,
186
- cumulative_reward=state.cumulative_reward,
187
  current_case_id=state.current_case_id,
188
  )
189
 
 
30
  from server.environment import ClinicalTrialEnvironment
31
 
32
 
33
+ _SCORE_EPS = 1e-3
34
+
35
+
36
+ def _clamp_open_score(value: float) -> float:
37
+ return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value)))
38
+
39
+
40
  class OpenEnvTriageAction(Action):
41
  """OpenEnv action wrapper for the clinical triage tasks."""
42
 
 
184
  @property
185
  def state(self) -> OpenEnvTriageState:
186
  state = self._core.state()
187
+ normalized_cumulative = _clamp_open_score(
188
+ state.cumulative_reward / state.step_count if state.step_count > 0 else _SCORE_EPS
189
+ )
190
  return OpenEnvTriageState(
191
  episode_id=state.episode_id,
192
  step_count=state.step_count,
193
  task_id=TaskID(state.task_id),
194
  max_steps=state.max_steps,
195
  done=state.done,
196
+ cumulative_reward=normalized_cumulative,
197
  current_case_id=state.current_case_id,
198
  )
199
 
server/app.py CHANGED
@@ -209,7 +209,16 @@ async def step(
209
  state.task_id,
210
  float(normalized),
211
  )
212
- return result.model_dump()
 
 
 
 
 
 
 
 
 
213
  except RuntimeError as exc:
214
  logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
215
  raise HTTPException(status_code=400, detail=str(exc))
@@ -220,12 +229,16 @@ async def step(
220
 
221
  @app.get("/state")
222
  async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
223
- env = get_or_create_session(_safe_session_id(x_session_id))
224
- try:
225
- s = env.state()
226
- return s.model_dump()
227
- except RuntimeError as exc:
228
- raise HTTPException(status_code=400, detail=str(exc))
 
 
 
 
229
 
230
 
231
  @app.get("/tasks")
@@ -318,11 +331,10 @@ async def grader(x_session_id: Optional[str] = Header(default="default")) -> Dic
318
  "episode_id": s.episode_id,
319
  "task_id": s.task_id,
320
  "done": s.done,
321
- "cumulative_reward": s.cumulative_reward,
322
  "step_count": s.step_count,
323
  "max_steps": s.max_steps,
324
  "normalized_score": normalized_score,
325
- "actions": s.actions_taken,
326
  }
327
  except RuntimeError as exc:
328
  raise HTTPException(status_code=400, detail=str(exc))
 
209
  state.task_id,
210
  float(normalized),
211
  )
212
+ payload = result.model_dump()
213
+ info = payload.get("info")
214
+ if isinstance(info, dict):
215
+ session_state = env.state()
216
+ info["cumulative_reward"] = _clamp_open_score(
217
+ session_state.cumulative_reward / session_state.step_count
218
+ if session_state.step_count > 0
219
+ else _SCORE_EPS
220
+ )
221
+ return payload
222
  except RuntimeError as exc:
223
  logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
224
  raise HTTPException(status_code=400, detail=str(exc))
 
229
 
230
  @app.get("/state")
231
  async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
232
+ env = get_or_create_session(_safe_session_id(x_session_id))
233
+ try:
234
+ s = env.state()
235
+ payload = s.model_dump()
236
+ payload["cumulative_reward"] = _clamp_open_score(
237
+ s.cumulative_reward / s.step_count if s.step_count > 0 else _SCORE_EPS
238
+ )
239
+ return payload
240
+ except RuntimeError as exc:
241
+ raise HTTPException(status_code=400, detail=str(exc))
242
 
243
 
244
  @app.get("/tasks")
 
331
  "episode_id": s.episode_id,
332
  "task_id": s.task_id,
333
  "done": s.done,
334
+ "cumulative_reward": normalized_score,
335
  "step_count": s.step_count,
336
  "max_steps": s.max_steps,
337
  "normalized_score": normalized_score,
 
338
  }
339
  except RuntimeError as exc:
340
  raise HTTPException(status_code=400, detail=str(exc))
server/openenv_env.py CHANGED
@@ -30,6 +30,13 @@ from models import (
30
  from server.environment import ClinicalTrialEnvironment
31
 
32
 
 
 
 
 
 
 
 
33
  class OpenEnvTriageAction(Action):
34
  """OpenEnv action wrapper for the clinical triage tasks."""
35
 
@@ -177,13 +184,16 @@ class ClinicalTrialOpenEnv(
177
  @property
178
  def state(self) -> OpenEnvTriageState:
179
  state = self._core.state()
 
 
 
180
  return OpenEnvTriageState(
181
  episode_id=state.episode_id,
182
  step_count=state.step_count,
183
  task_id=TaskID(state.task_id),
184
  max_steps=state.max_steps,
185
  done=state.done,
186
- cumulative_reward=state.cumulative_reward,
187
  current_case_id=state.current_case_id,
188
  )
189
 
 
30
  from server.environment import ClinicalTrialEnvironment
31
 
32
 
33
+ _SCORE_EPS = 1e-3
34
+
35
+
36
+ def _clamp_open_score(value: float) -> float:
37
+ return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value)))
38
+
39
+
40
  class OpenEnvTriageAction(Action):
41
  """OpenEnv action wrapper for the clinical triage tasks."""
42
 
 
184
  @property
185
  def state(self) -> OpenEnvTriageState:
186
  state = self._core.state()
187
+ normalized_cumulative = _clamp_open_score(
188
+ state.cumulative_reward / state.step_count if state.step_count > 0 else _SCORE_EPS
189
+ )
190
  return OpenEnvTriageState(
191
  episode_id=state.episode_id,
192
  step_count=state.step_count,
193
  task_id=TaskID(state.task_id),
194
  max_steps=state.max_steps,
195
  done=state.done,
196
+ cumulative_reward=normalized_cumulative,
197
  current_case_id=state.current_case_id,
198
  )
199