Sepsis-OpenEnv / graders.py
BAIBHAV1234's picture
Upload folder using huggingface_hub
8017cc7 verified
from __future__ import annotations
from typing import Any
from tasks import TaskConfig
SCORE_EPS = 1e-3
SCORE_MARGIN = 1e-6
def _clamp(value: float, low: float = SCORE_EPS, high: float = 1.0 - SCORE_EPS) -> float:
try:
numeric_value = float(value)
except (TypeError, ValueError):
numeric_value = low
if numeric_value <= low:
return low + SCORE_MARGIN
if numeric_value >= high:
return high - SCORE_MARGIN
return numeric_value
def _strict_score(value: float) -> float:
return _clamp(value, SCORE_EPS, 1.0 - SCORE_EPS)
def _normalize_weights(weights: dict[str, float]) -> dict[str, float]:
total = sum(float(weight) for weight in weights.values())
if total <= 0:
return weights
return {metric_name: float(weight) / total for metric_name, weight in weights.items()}
def _format_metric(value: float) -> float:
return float(f"{_clamp(value):.6f}")
def grade_episode(task: TaskConfig, metrics: dict[str, Any]) -> float:
weights = _normalize_weights(task.score_weights)
score = sum(weight * _clamp(metrics.get(metric_name, 0.0)) for metric_name, weight in weights.items())
return float(f"{_strict_score(score):.6f}")
def summarize_episode(total_reward: float, state_history: list[dict[str, Any]], terminal_outcome: str) -> dict[str, Any]:
step_count = max(len(state_history), 1)
safety_violations = sum(1 for item in state_history if item.get("unsafe", False))
lab_steps = [item for item in state_history if item.get("action_type") == "request_lab"]
treatment_steps = [item for item in state_history if item.get("action_type") == "request_treatment"]
early_window = state_history[: min(3, len(state_history))] or state_history
detection = max((item.get("detection_credit", 0.0) for item in early_window), default=0.0)
lab_workup = (
sum(item.get("lab_score", 0.0) for item in lab_steps) / len(lab_steps)
if lab_steps
else 0.0
)
treatment = (
sum(item.get("treatment_score", 0.0) for item in treatment_steps) / len(treatment_steps)
if treatment_steps
else 0.0
)
first_meaningful_step = next(
(
idx
for idx, item in enumerate(state_history)
if item.get("detection_credit", 0.0) > 0.0 or item.get("treatment_score", 0.0) > 0.0
),
step_count,
)
timeliness = _clamp(1.0 - (first_meaningful_step / step_count))
stability = _clamp(sum(item.get("stability_score", 0.0) for item in state_history) / step_count)
safety = _clamp(1.0 - (safety_violations / step_count))
outcome = _format_metric(1.0 if terminal_outcome == "survived" else 0.0)
return {
"steps": step_count,
"avg_reward": _format_metric(total_reward / step_count),
"detection": _format_metric(detection),
"lab_workup": _format_metric(lab_workup),
"treatment": _format_metric(treatment),
"timeliness": _format_metric(timeliness),
"stability": _format_metric(stability),
"safety": _format_metric(safety),
"safety_violation_rate": _format_metric(safety_violations / step_count),
"outcome": outcome,
}