| from __future__ import annotations | |
| from typing import Any | |
| from tasks import TaskConfig | |
| def _clamp(value: float, low: float = 0.0, high: float = 1.0) -> float: | |
| return max(low, min(high, value)) | |
def grade_episode(task: TaskConfig, metrics: dict[str, Any]) -> float:
    """Compute the weighted score of an episode.

    Each metric named in ``task.score_weights`` is looked up in *metrics*
    (a missing metric counts as 0.0), clamped to [0, 1], and multiplied by
    its weight. Metrics that have no weight are ignored.

    Returns:
        The weighted sum, clamped to [0, 1] and rounded to 4 decimals.
    """
    weighted_terms = (
        weight * _clamp(metrics.get(name, 0.0))
        for name, weight in task.score_weights.items()
    )
    # Keep the built-in sum(): its float accumulation must match exactly.
    return round(_clamp(sum(weighted_terms)), 4)
def summarize_episode(total_reward: float, state_history: list[dict[str, Any]], terminal_outcome: str) -> dict[str, Any]:
    """Condense per-step episode records into a dictionary of summary metrics.

    Args:
        total_reward: Reward accumulated over the episode; reported as a
            per-step average.
        state_history: One dict per step. The keys read here are
            ``unsafe``, ``action_type``, ``detection_credit``,
            ``lab_score``, ``treatment_score`` and ``stability_score``;
            absent keys fall back to ``False`` / ``None`` / ``0.0``.
        terminal_outcome: Final outcome label; only ``"survived"`` yields
            an ``outcome`` of 1.0.

    Returns:
        A dict with ``steps``, ``avg_reward``, ``detection``,
        ``lab_workup``, ``treatment``, ``timeliness``, ``stability``,
        ``safety``, ``safety_violation_rate`` and ``outcome``. The six
        score entries are clamped to [0, 1] and rounded to 4 decimals;
        ``avg_reward`` and ``safety_violation_rate`` are left unrounded.
    """

    def _bounded(raw: float) -> float:
        # Shared post-processing for every score: clamp, then round.
        return round(_clamp(raw), 4)

    # An empty history counts as one step so the divisions below are safe.
    n_steps = len(state_history) or 1

    unsafe_count = len([rec for rec in state_history if rec.get("unsafe", False)])

    # Mean lab score over "request_lab" steps only (0.0 when there are none).
    lab_scores = [
        rec.get("lab_score", 0.0)
        for rec in state_history
        if rec.get("action_type") == "request_lab"
    ]
    lab_workup = sum(lab_scores) / len(lab_scores) if lab_scores else 0.0

    # Mean treatment score over "request_treatment" steps only.
    treatment_scores = [
        rec.get("treatment_score", 0.0)
        for rec in state_history
        if rec.get("action_type") == "request_treatment"
    ]
    treatment = sum(treatment_scores) / len(treatment_scores) if treatment_scores else 0.0

    # Detection is the best credit seen within the first three steps.
    detection = max(
        (rec.get("detection_credit", 0.0) for rec in state_history[:3]),
        default=0.0,
    )

    # Index of the first step with any detection credit or treatment score;
    # stays at n_steps (timeliness 0) when no such step exists.
    first_hit = n_steps
    for position, rec in enumerate(state_history):
        if rec.get("detection_credit", 0.0) > 0.0 or rec.get("treatment_score", 0.0) > 0.0:
            first_hit = position
            break

    stability = sum(rec.get("stability_score", 0.0) for rec in state_history) / n_steps
    violation_rate = unsafe_count / n_steps

    if terminal_outcome == "survived":
        outcome = 1.0
    else:
        outcome = 0.0

    return {
        "steps": n_steps,
        "avg_reward": total_reward / n_steps,
        "detection": _bounded(detection),
        "lab_workup": _bounded(lab_workup),
        "treatment": _bounded(treatment),
        "timeliness": _bounded(1.0 - (first_hit / n_steps)),
        "stability": _bounded(stability),
        "safety": _bounded(1.0 - violation_rate),
        "safety_violation_rate": violation_rate,
        "outcome": outcome,
    }