Spaces:
Sleeping
Sleeping
Commit · 2e410e4
1 Parent(s): f8701a1
Harden API score fields to strict open interval
Browse files
- clinical-trial-triage/server/app.py +21 -9
- clinical-trial-triage/server/openenv_env.py +11 -1
- server/app.py +21 -9
- server/openenv_env.py +11 -1
clinical-trial-triage/server/app.py
CHANGED
|
@@ -209,7 +209,16 @@ async def step(
|
|
| 209 |
state.task_id,
|
| 210 |
float(normalized),
|
| 211 |
)
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
except RuntimeError as exc:
|
| 214 |
logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
|
| 215 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
@@ -220,12 +229,16 @@ async def step(
|
|
| 220 |
|
| 221 |
@app.get("/state")
|
| 222 |
async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
@app.get("/tasks")
|
|
@@ -318,11 +331,10 @@ async def grader(x_session_id: Optional[str] = Header(default="default")) -> Dic
|
|
| 318 |
"episode_id": s.episode_id,
|
| 319 |
"task_id": s.task_id,
|
| 320 |
"done": s.done,
|
| 321 |
-
|
| 322 |
"step_count": s.step_count,
|
| 323 |
"max_steps": s.max_steps,
|
| 324 |
"normalized_score": normalized_score,
|
| 325 |
-
"actions": s.actions_taken,
|
| 326 |
}
|
| 327 |
except RuntimeError as exc:
|
| 328 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
| 209 |
state.task_id,
|
| 210 |
float(normalized),
|
| 211 |
)
|
| 212 |
+
payload = result.model_dump()
|
| 213 |
+
info = payload.get("info")
|
| 214 |
+
if isinstance(info, dict):
|
| 215 |
+
session_state = env.state()
|
| 216 |
+
info["cumulative_reward"] = _clamp_open_score(
|
| 217 |
+
session_state.cumulative_reward / session_state.step_count
|
| 218 |
+
if session_state.step_count > 0
|
| 219 |
+
else _SCORE_EPS
|
| 220 |
+
)
|
| 221 |
+
return payload
|
| 222 |
except RuntimeError as exc:
|
| 223 |
logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
|
| 224 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
| 229 |
|
| 230 |
@app.get("/state")
|
| 231 |
async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
|
| 232 |
+
env = get_or_create_session(_safe_session_id(x_session_id))
|
| 233 |
+
try:
|
| 234 |
+
s = env.state()
|
| 235 |
+
payload = s.model_dump()
|
| 236 |
+
payload["cumulative_reward"] = _clamp_open_score(
|
| 237 |
+
s.cumulative_reward / s.step_count if s.step_count > 0 else _SCORE_EPS
|
| 238 |
+
)
|
| 239 |
+
return payload
|
| 240 |
+
except RuntimeError as exc:
|
| 241 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 242 |
|
| 243 |
|
| 244 |
@app.get("/tasks")
|
|
|
|
| 331 |
"episode_id": s.episode_id,
|
| 332 |
"task_id": s.task_id,
|
| 333 |
"done": s.done,
|
| 334 |
+
"cumulative_reward": normalized_score,
|
| 335 |
"step_count": s.step_count,
|
| 336 |
"max_steps": s.max_steps,
|
| 337 |
"normalized_score": normalized_score,
|
|
|
|
| 338 |
}
|
| 339 |
except RuntimeError as exc:
|
| 340 |
raise HTTPException(status_code=400, detail=str(exc))
|
clinical-trial-triage/server/openenv_env.py
CHANGED
|
@@ -30,6 +30,13 @@ from models import (
|
|
| 30 |
from server.environment import ClinicalTrialEnvironment
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class OpenEnvTriageAction(Action):
|
| 34 |
"""OpenEnv action wrapper for the clinical triage tasks."""
|
| 35 |
|
|
@@ -177,13 +184,16 @@ class ClinicalTrialOpenEnv(
|
|
| 177 |
@property
|
| 178 |
def state(self) -> OpenEnvTriageState:
|
| 179 |
state = self._core.state()
|
|
|
|
|
|
|
|
|
|
| 180 |
return OpenEnvTriageState(
|
| 181 |
episode_id=state.episode_id,
|
| 182 |
step_count=state.step_count,
|
| 183 |
task_id=TaskID(state.task_id),
|
| 184 |
max_steps=state.max_steps,
|
| 185 |
done=state.done,
|
| 186 |
-
cumulative_reward=
|
| 187 |
current_case_id=state.current_case_id,
|
| 188 |
)
|
| 189 |
|
|
|
|
| 30 |
from server.environment import ClinicalTrialEnvironment
|
| 31 |
|
| 32 |
|
| 33 |
+
_SCORE_EPS = 1e-3
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _clamp_open_score(value: float) -> float:
|
| 37 |
+
return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value)))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
class OpenEnvTriageAction(Action):
|
| 41 |
"""OpenEnv action wrapper for the clinical triage tasks."""
|
| 42 |
|
|
|
|
| 184 |
@property
|
| 185 |
def state(self) -> OpenEnvTriageState:
|
| 186 |
state = self._core.state()
|
| 187 |
+
normalized_cumulative = _clamp_open_score(
|
| 188 |
+
state.cumulative_reward / state.step_count if state.step_count > 0 else _SCORE_EPS
|
| 189 |
+
)
|
| 190 |
return OpenEnvTriageState(
|
| 191 |
episode_id=state.episode_id,
|
| 192 |
step_count=state.step_count,
|
| 193 |
task_id=TaskID(state.task_id),
|
| 194 |
max_steps=state.max_steps,
|
| 195 |
done=state.done,
|
| 196 |
+
cumulative_reward=normalized_cumulative,
|
| 197 |
current_case_id=state.current_case_id,
|
| 198 |
)
|
| 199 |
|
server/app.py
CHANGED
|
@@ -209,7 +209,16 @@ async def step(
|
|
| 209 |
state.task_id,
|
| 210 |
float(normalized),
|
| 211 |
)
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
except RuntimeError as exc:
|
| 214 |
logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
|
| 215 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
@@ -220,12 +229,16 @@ async def step(
|
|
| 220 |
|
| 221 |
@app.get("/state")
|
| 222 |
async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
@app.get("/tasks")
|
|
@@ -318,11 +331,10 @@ async def grader(x_session_id: Optional[str] = Header(default="default")) -> Dic
|
|
| 318 |
"episode_id": s.episode_id,
|
| 319 |
"task_id": s.task_id,
|
| 320 |
"done": s.done,
|
| 321 |
-
|
| 322 |
"step_count": s.step_count,
|
| 323 |
"max_steps": s.max_steps,
|
| 324 |
"normalized_score": normalized_score,
|
| 325 |
-
"actions": s.actions_taken,
|
| 326 |
}
|
| 327 |
except RuntimeError as exc:
|
| 328 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
| 209 |
state.task_id,
|
| 210 |
float(normalized),
|
| 211 |
)
|
| 212 |
+
payload = result.model_dump()
|
| 213 |
+
info = payload.get("info")
|
| 214 |
+
if isinstance(info, dict):
|
| 215 |
+
session_state = env.state()
|
| 216 |
+
info["cumulative_reward"] = _clamp_open_score(
|
| 217 |
+
session_state.cumulative_reward / session_state.step_count
|
| 218 |
+
if session_state.step_count > 0
|
| 219 |
+
else _SCORE_EPS
|
| 220 |
+
)
|
| 221 |
+
return payload
|
| 222 |
except RuntimeError as exc:
|
| 223 |
logger.warning("step runtime error: session_id=%s detail=%s", session_id, str(exc))
|
| 224 |
raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
| 229 |
|
| 230 |
@app.get("/state")
|
| 231 |
async def state(x_session_id: Optional[str] = Header(default="default")) -> Dict[str, Any]:
|
| 232 |
+
env = get_or_create_session(_safe_session_id(x_session_id))
|
| 233 |
+
try:
|
| 234 |
+
s = env.state()
|
| 235 |
+
payload = s.model_dump()
|
| 236 |
+
payload["cumulative_reward"] = _clamp_open_score(
|
| 237 |
+
s.cumulative_reward / s.step_count if s.step_count > 0 else _SCORE_EPS
|
| 238 |
+
)
|
| 239 |
+
return payload
|
| 240 |
+
except RuntimeError as exc:
|
| 241 |
+
raise HTTPException(status_code=400, detail=str(exc))
|
| 242 |
|
| 243 |
|
| 244 |
@app.get("/tasks")
|
|
|
|
| 331 |
"episode_id": s.episode_id,
|
| 332 |
"task_id": s.task_id,
|
| 333 |
"done": s.done,
|
| 334 |
+
"cumulative_reward": normalized_score,
|
| 335 |
"step_count": s.step_count,
|
| 336 |
"max_steps": s.max_steps,
|
| 337 |
"normalized_score": normalized_score,
|
|
|
|
| 338 |
}
|
| 339 |
except RuntimeError as exc:
|
| 340 |
raise HTTPException(status_code=400, detail=str(exc))
|
server/openenv_env.py
CHANGED
|
@@ -30,6 +30,13 @@ from models import (
|
|
| 30 |
from server.environment import ClinicalTrialEnvironment
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
class OpenEnvTriageAction(Action):
|
| 34 |
"""OpenEnv action wrapper for the clinical triage tasks."""
|
| 35 |
|
|
@@ -177,13 +184,16 @@ class ClinicalTrialOpenEnv(
|
|
| 177 |
@property
|
| 178 |
def state(self) -> OpenEnvTriageState:
|
| 179 |
state = self._core.state()
|
|
|
|
|
|
|
|
|
|
| 180 |
return OpenEnvTriageState(
|
| 181 |
episode_id=state.episode_id,
|
| 182 |
step_count=state.step_count,
|
| 183 |
task_id=TaskID(state.task_id),
|
| 184 |
max_steps=state.max_steps,
|
| 185 |
done=state.done,
|
| 186 |
-
cumulative_reward=
|
| 187 |
current_case_id=state.current_case_id,
|
| 188 |
)
|
| 189 |
|
|
|
|
| 30 |
from server.environment import ClinicalTrialEnvironment
|
| 31 |
|
| 32 |
|
| 33 |
+
_SCORE_EPS = 1e-3
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _clamp_open_score(value: float) -> float:
|
| 37 |
+
return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value)))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
class OpenEnvTriageAction(Action):
|
| 41 |
"""OpenEnv action wrapper for the clinical triage tasks."""
|
| 42 |
|
|
|
|
| 184 |
@property
|
| 185 |
def state(self) -> OpenEnvTriageState:
|
| 186 |
state = self._core.state()
|
| 187 |
+
normalized_cumulative = _clamp_open_score(
|
| 188 |
+
state.cumulative_reward / state.step_count if state.step_count > 0 else _SCORE_EPS
|
| 189 |
+
)
|
| 190 |
return OpenEnvTriageState(
|
| 191 |
episode_id=state.episode_id,
|
| 192 |
step_count=state.step_count,
|
| 193 |
task_id=TaskID(state.task_id),
|
| 194 |
max_steps=state.max_steps,
|
| 195 |
done=state.done,
|
| 196 |
+
cumulative_reward=normalized_cumulative,
|
| 197 |
current_case_id=state.current_case_id,
|
| 198 |
)
|
| 199 |
|