| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import re |
| from collections import Counter |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| from openai import OpenAI |
|
|
| from client import SepsisTreatmentEnv |
| from models import SepsisAction, SepsisObservation |
|
|
|
|
| OUTPUT_DIR = Path("outputs") |
| MAX_STEPS_PER_TASK = {"easy": 8, "medium": 12, "hard": 16} |
| EPSILON = 0.1 |
| RNG = random.Random(7) |
| ENV_NAME = "sepsis-openenv" |
| LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") |
| DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1" |
| DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct" |
| TASK_IDS = ["easy", "medium", "hard"] |
| LAB_OPTIONS = ["lactate", "wbc", "creatinine", "bicarbonate", "platelets", "bilirubin"] |
| TREATMENT_OPTIONS = ["monitor", "fluids", "vasopressors", "combination"] |
| LAB_ALIASES = { |
| "lactate": ["lactate", "lactic acid", "serum lactate", "blood lactate"], |
| "wbc": ["wbc", "white blood cell", "white blood cell count", "complete blood count", "cbc"], |
| "creatinine": [ |
| "creatinine", |
| "renal panel", |
| "kidney function", |
| "bmp", |
| "basic metabolic panel", |
| "cmp", |
| "comprehensive metabolic panel", |
| "bun", |
| ], |
| "bicarbonate": ["bicarbonate", "hco3", "co2", "carbon dioxide", "blood gas", "abg", "vbg"], |
| "platelets": ["platelets", "platelet count"], |
| "bilirubin": ["bilirubin", "total bili", "total bilirubin", "liver function", "lft"], |
| } |
| TREATMENT_ALIASES = { |
| "monitor": ["monitor", "observe", "observation", "watch", "watchful waiting", "reassess"], |
| "fluids": ["fluids", "iv fluids", "fluid resuscitation", "crystalloid", "bolus"], |
| "vasopressors": ["vasopressors", "pressor", "pressors", "norepinephrine", "levophed"], |
| "combination": ["combination", "fluids and vasopressors", "dual therapy", "both fluids and pressors"], |
| } |
|
|
|
|
| def parse_args() -> argparse.Namespace: |
| parser = argparse.ArgumentParser(description="Run sepsis environment inference and aggregate metrics.") |
| parser.add_argument( |
| "--episodes", |
| type=int, |
| default=1, |
| help="Number of evaluation cycles to run. Each cycle runs easy, medium, and hard once.", |
| ) |
| parser.add_argument( |
| "--model", |
| choices=["auto", "heuristic", "llm", "id3qne"], |
| default="auto", |
| help="Policy mode. auto uses llm if credentials are present, otherwise heuristic.", |
| ) |
| parser.add_argument( |
| "--output", |
| default=str(OUTPUT_DIR / "baseline_scores.json"), |
| help="Path for the JSON summary output.", |
| ) |
| return parser.parse_args() |
|
|
|
|
| def format_action(action: SepsisAction) -> str: |
| if action.action_type == "request_lab": |
| return f"request_lab({action.lab_type}, suspect_sepsis={str(action.suspect_sepsis).lower()})" |
| if action.action_type == "request_treatment": |
| return f"request_treatment({action.treatment_type}, suspect_sepsis={str(action.suspect_sepsis).lower()})" |
| return f"monitor(suspect_sepsis={str(action.suspect_sepsis).lower()})" |
|
|
|
|
| def curriculum_action(observation: SepsisObservation) -> SepsisAction | None: |
| task_id = observation.task_id |
| step_index = observation.step_index |
|
|
| if task_id == "easy": |
| if step_index == 0: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="lactate", |
| rationale="Curriculum schedule: start with lactate for early workup.", |
| ) |
| if step_index == 2: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="creatinine", |
| rationale="Curriculum schedule: add renal assessment when the workup broadens.", |
| ) |
| return SepsisAction( |
| action_type="monitor", |
| suspect_sepsis=True, |
| rationale="Curriculum schedule: maintain observation after priority labs.", |
| ) |
|
|
| if task_id == "medium": |
| lab_plan = {0: "lactate", 1: "wbc", 2: "creatinine", 3: "bicarbonate"} |
| if step_index in lab_plan: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type=lab_plan[step_index], |
| rationale=f"Curriculum schedule: collect {lab_plan[step_index]} before treatment.", |
| ) |
| treatment_type = "fluids" if step_index <= 6 else "monitor" |
| return SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=True, |
| treatment_type=treatment_type, |
| rationale="Curriculum schedule: shift from early support to monitoring.", |
| ) |
|
|
| if task_id == "hard": |
| if step_index == 0: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="lactate", |
| rationale="Curriculum schedule: start unstable trajectory with lactate.", |
| ) |
| if step_index == 1: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="creatinine", |
| rationale="Curriculum schedule: check renal strain before extended management.", |
| ) |
| treatment_type = "monitor" if step_index in {3, 4} else "fluids" |
| return SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=True, |
| treatment_type=treatment_type, |
| rationale="Curriculum schedule: alternate stabilization and support across the harder case.", |
| ) |
|
|
| return None |
|
|
|
|
| def heuristic_action(observation: SepsisObservation) -> SepsisAction: |
| scheduled_action = curriculum_action(observation) |
| if scheduled_action is not None: |
| return scheduled_action |
|
|
| severity = observation.severity_proxy |
| shock = observation.vitals.get("Shock_Index", 0.0) |
| mean_bp = observation.vitals.get("MeanBP", 0.0) |
| visible_labs = observation.visible_labs |
| requested_labs = set(observation.requested_labs) |
|
|
| if RNG.random() < EPSILON: |
| unseen_labs = [lab for lab in ["lactate", "wbc", "creatinine", "bicarbonate"] if lab not in requested_labs] |
| if unseen_labs: |
| lab_choice = unseen_labs[0] |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=severity >= 1.0 or shock > 0.1 or mean_bp < 0.0, |
| lab_type=lab_choice, |
| rationale="Exploration step", |
| ) |
|
|
| lab_priority_order = ["lactate", "wbc", "creatinine", "bicarbonate"] |
| for lab in lab_priority_order: |
| should_request = False |
| if lab == "lactate": |
| should_request = lab not in requested_labs |
| elif lab == "wbc": |
| should_request = lab not in requested_labs and (severity >= 0.75 or shock > 0.08) |
| elif lab == "creatinine": |
| should_request = lab not in requested_labs and severity >= 1.2 |
| elif lab == "bicarbonate": |
| should_request = lab not in requested_labs and (severity >= 1.5 or mean_bp < -0.1) |
|
|
| if should_request: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=severity >= 1.0 or shock > 0.1 or mean_bp < 0.0, |
| lab_type=lab, |
| rationale=f"Exploring informative lab: {lab}", |
| ) |
|
|
| lactate = visible_labs.get("lactate", 0.0) or 0.0 |
| bicarbonate = visible_labs.get("bicarbonate", 0.0) or 0.0 |
| if severity < 0.8 and mean_bp >= -0.1: |
| treatment_type = "monitor" |
| elif severity >= 2.0 or mean_bp < -0.2: |
| treatment_type = "combination" if lactate > 0.25 else "vasopressors" |
| elif shock > 0.15 or severity >= 1.1 or bicarbonate < -0.15: |
| treatment_type = "fluids" |
| else: |
| treatment_type = "monitor" |
|
|
| return SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=severity >= 1.0 or lactate > 0.25, |
| treatment_type=treatment_type, |
| rationale="Improved staged policy with exploration and severity awareness.", |
| ) |
|
|
|
|
| def iter_text_fragments(value: Any) -> list[str]: |
| if value is None: |
| return [] |
| if isinstance(value, str): |
| return [value] |
| if isinstance(value, (list, tuple, set)): |
| fragments: list[str] = [] |
| for item in value: |
| fragments.extend(iter_text_fragments(item)) |
| return fragments |
| if isinstance(value, dict): |
| fragments: list[str] = [] |
| for item in value.values(): |
| fragments.extend(iter_text_fragments(item)) |
| return fragments |
| return [str(value)] |
|
|
|
|
| def normalize_text(value: Any) -> str: |
| fragments = iter_text_fragments(value) |
| raw = " ".join(fragment.strip().lower() for fragment in fragments if fragment) |
| return re.sub(r"[^a-z0-9]+", " ", raw).strip() |
|
|
|
|
| def match_alias(value: Any, alias_map: dict[str, list[str]]) -> str | None: |
| fragments = iter_text_fragments(value) |
| matches: list[str] = [] |
| for fragment in fragments: |
| normalized = normalize_text(fragment) |
| if not normalized: |
| continue |
| for canonical, aliases in alias_map.items(): |
| if normalized == canonical: |
| matches.append(canonical) |
| continue |
| if any(alias in normalized for alias in aliases): |
| matches.append(canonical) |
| continue |
|
|
| if not matches: |
| combined = normalize_text(value) |
| for canonical, aliases in alias_map.items(): |
| if combined == canonical or any(alias in combined for alias in aliases): |
| matches.append(canonical) |
|
|
| unique_matches = list(dict.fromkeys(matches)) |
| if len(unique_matches) == 1: |
| return unique_matches[0] |
| return None |
|
|
|
|
| def parse_boolish(value: Any, default: bool = False) -> bool: |
| if isinstance(value, bool): |
| return value |
| if value is None: |
| return default |
| normalized = normalize_text(value) |
| if normalized in {"true", "yes", "y", "1"}: |
| return True |
| if normalized in {"false", "no", "n", "0"}: |
| return False |
| return default |
|
|
|
|
| def normalize_lab_choice(value: Any) -> str | None: |
| return match_alias(value, LAB_ALIASES) |
|
|
|
|
| def normalize_treatment_choice(value: Any) -> str | None: |
| return match_alias(value, TREATMENT_ALIASES) |
|
|
|
|
| def normalize_action_type(value: Any, lab_choice: str | None, treatment_choice: str | None) -> str | None: |
| normalized = normalize_text(value) |
| if normalized in {"request lab", "lab", "labs", "test", "request tests", "request lab test"}: |
| return "request_lab" |
| if normalized in {"request treatment", "treatment", "treat", "therapy", "intervene"}: |
| return "request_treatment" |
| if normalized in {"monitor", "observe", "observation", "watch", "reassess"}: |
| return "monitor" |
| if lab_choice: |
| return "request_lab" |
| if treatment_choice: |
| return "request_treatment" |
| return None |
|
|
|
|
| def should_use_heuristic_guardrail( |
| candidate: SepsisAction, |
| heuristic: SepsisAction, |
| observation: SepsisObservation, |
| ) -> bool: |
| del observation |
| return ( |
| candidate.action_type != heuristic.action_type |
| or candidate.lab_type != heuristic.lab_type |
| or candidate.treatment_type != heuristic.treatment_type |
| ) |
|
|
|
|
| def repair_model_action(payload: dict[str, Any], observation: SepsisObservation) -> tuple[SepsisAction, str, str | None]: |
| heuristic = heuristic_action(observation) |
| normalized_lab = normalize_lab_choice(payload.get("lab_type")) |
| normalized_treatment = normalize_treatment_choice(payload.get("treatment_type")) |
| normalized_action_type = normalize_action_type( |
| payload.get("action_type"), |
| normalized_lab, |
| normalized_treatment, |
| ) |
| suspect_sepsis = parse_boolish(payload.get("suspect_sepsis"), default=heuristic.suspect_sepsis) |
| suspect_sepsis = suspect_sepsis or heuristic.suspect_sepsis |
| rationale = str(payload.get("rationale", "")).strip() |
|
|
| if normalized_action_type == "request_lab": |
| if normalized_lab is None: |
| return heuristic, "heuristic_guardrail", "LLM selected an unsupported lab; using heuristic action." |
| candidate = SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=suspect_sepsis, |
| lab_type=normalized_lab, |
| rationale=rationale or "LLM lab choice normalized into environment action space.", |
| ) |
| elif normalized_action_type == "request_treatment": |
| if normalized_treatment is None: |
| return heuristic, "heuristic_guardrail", "LLM selected an unsupported treatment; using heuristic action." |
| candidate = SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=suspect_sepsis, |
| treatment_type=normalized_treatment, |
| rationale=rationale or "LLM treatment choice normalized into environment action space.", |
| ) |
| elif normalized_action_type == "monitor": |
| candidate = SepsisAction( |
| action_type="monitor", |
| suspect_sepsis=suspect_sepsis, |
| rationale=rationale or "LLM monitor choice normalized into environment action space.", |
| ) |
| else: |
| return heuristic, "heuristic_guardrail", "LLM action could not be normalized; using heuristic action." |
|
|
| if should_use_heuristic_guardrail(candidate, heuristic, observation): |
| return heuristic, "heuristic_guardrail", "LLM action was valid but low-value for this step; using heuristic." |
|
|
| if candidate.model_dump(exclude={"rationale"}) == heuristic.model_dump(exclude={"rationale"}): |
| return candidate, "llm_aligned", None |
| return candidate, "llm_repaired", None |
|
|
|
|
| def id3qne_action(observation: SepsisObservation) -> SepsisAction: |
| task_id = observation.task_id |
| step_index = observation.step_index |
| severity = observation.severity_proxy |
| mean_bp = observation.vitals.get("MeanBP", 0.0) |
| shock = observation.vitals.get("Shock_Index", 0.0) |
| requested_labs = set(observation.requested_labs) |
| visible_labs = observation.visible_labs |
| suspect_sepsis = severity >= 1.0 or shock > 0.1 or mean_bp < 0.0 |
|
|
| if task_id == "easy": |
| if "lactate" not in requested_labs: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="lactate", |
| rationale="ID3QNE tree: always reveal lactate first in the easy workup branch.", |
| ) |
| if step_index >= 2 and "creatinine" not in requested_labs: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="creatinine", |
| rationale="ID3QNE tree: second split requests creatinine for renal assessment.", |
| ) |
| return SepsisAction( |
| action_type="monitor", |
| suspect_sepsis=True, |
| rationale="ID3QNE tree: monitor after the high-yield easy branch labs are collected.", |
| ) |
|
|
| if task_id == "medium": |
| for lab_name in ["lactate", "wbc", "creatinine", "bicarbonate"]: |
| if lab_name not in requested_labs: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type=lab_name, |
| rationale=f"ID3QNE tree: continue the medium-depth lab branch with {lab_name}.", |
| ) |
| treatment_type = "fluids" if step_index <= 6 else "monitor" |
| return SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=True, |
| treatment_type=treatment_type, |
| rationale="ID3QNE tree: treat early, then monitor after stabilization.", |
| ) |
|
|
| if task_id == "hard": |
| if "lactate" not in requested_labs: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="lactate", |
| rationale="ID3QNE tree: unstable branch starts with lactate.", |
| ) |
| if "creatinine" not in requested_labs: |
| return SepsisAction( |
| action_type="request_lab", |
| suspect_sepsis=True, |
| lab_type="creatinine", |
| rationale="ID3QNE tree: follow with creatinine when the trajectory turns unstable.", |
| ) |
| creatinine = visible_labs.get("creatinine", 0.0) or 0.0 |
| if step_index in {3, 4} and severity < 1.5 and mean_bp >= -0.2: |
| treatment_type = "monitor" |
| elif severity >= 2.0 and mean_bp < -0.2: |
| treatment_type = "combination" |
| elif severity >= 1.0 or creatinine > 0.15 or step_index >= 5: |
| treatment_type = "fluids" |
| else: |
| treatment_type = "monitor" |
| return SepsisAction( |
| action_type="request_treatment", |
| suspect_sepsis=suspect_sepsis or creatinine > 0.15 or step_index >= 1, |
| treatment_type=treatment_type, |
| rationale="ID3QNE tree: treatment branch uses severity, renal strain, and step progression.", |
| ) |
|
|
| return heuristic_action(observation) |
|
|
|
|
| def build_prompt(observation: SepsisObservation) -> str: |
| return ( |
| "You are controlling a sequential sepsis management simulator.\n" |
| f"Task: {observation.task_description}\n" |
| f"Step: {observation.step_index + 1}/{observation.max_steps}\n" |
| f"Severity proxy: {observation.severity_proxy:.2f}\n" |
| f"Mortality flag in logged trajectory: {observation.mortality_risk_flag}\n" |
| f"Demographics: {json.dumps(observation.demographics)}\n" |
| f"Vitals: {json.dumps(observation.vitals)}\n" |
| f"Context features: {json.dumps(observation.context_features)}\n" |
| f"Visible labs: {json.dumps(observation.visible_labs)}\n" |
| f"Requested labs so far: {json.dumps(observation.requested_labs)}\n" |
| "You must choose exactly one environment action.\n" |
| f"Allowed lab_type values: {json.dumps(LAB_OPTIONS)}.\n" |
| f"Allowed treatment_type values: {json.dumps(TREATMENT_OPTIONS)}.\n" |
| "If action_type is request_lab, lab_type must be one of the allowed values and treatment_type must be null.\n" |
| "If action_type is request_treatment, treatment_type must be one of the allowed values and lab_type must be null.\n" |
| "If action_type is monitor, both lab_type and treatment_type must be null.\n" |
| "Do not return lists, synonyms, antibiotics, blood cultures, or free-text clinical plans.\n" |
| "Return JSON only with keys action_type, suspect_sepsis, lab_type, treatment_type, rationale." |
| ) |
|
|
|
|
| def parse_model_json(content: str) -> dict[str, Any]: |
| candidate = content.strip() |
| if candidate.startswith("```"): |
| lines = candidate.splitlines() |
| if lines and lines[0].startswith("```"): |
| lines = lines[1:] |
| if lines and lines[-1].startswith("```"): |
| lines = lines[:-1] |
| candidate = "\n".join(lines).strip() |
|
|
| match = re.search(r"\{.*\}", candidate, re.DOTALL) |
| if match: |
| candidate = match.group(0) |
| return json.loads(candidate) |
|
|
|
|
| def model_action( |
| client: OpenAI | None, |
| model_name: str | None, |
| observation: SepsisObservation, |
| ) -> tuple[SepsisAction, str, str | None]: |
| if client is None or not model_name: |
| return heuristic_action(observation), "heuristic", None |
|
|
| messages = [ |
| {"role": "system", "content": "Return only valid JSON for a sepsis management action."}, |
| {"role": "user", "content": build_prompt(observation)}, |
| ] |
| try: |
| completion = client.chat.completions.create( |
| model=model_name, |
| messages=messages, |
| temperature=0.0, |
| max_tokens=200, |
| response_format={"type": "json_object"}, |
| ) |
| content = completion.choices[0].message.content or "" |
| payload = parse_model_json(content) |
| return repair_model_action(payload, observation) |
| except Exception as exc: |
| return heuristic_action(observation), "heuristic_fallback", str(exc) |
|
|
|
|
| def choose_action( |
| policy_mode: str, |
| client: OpenAI | None, |
| model_name: str | None, |
| observation: SepsisObservation, |
| ) -> tuple[SepsisAction, str, str | None]: |
| if policy_mode == "heuristic": |
| return heuristic_action(observation), "heuristic", None |
| if policy_mode == "id3qne": |
| return id3qne_action(observation), "id3qne", None |
| if policy_mode == "llm": |
| return model_action(client, model_name, observation) |
| raise ValueError(f"Unsupported policy mode: {policy_mode}") |
|
|
|
|
| def compute_action_entropy(action_history: list[str]) -> float: |
| if not action_history: |
| return 0.0 |
|
|
| action_lengths = [len(action.split()) for action in action_history] |
| counts = np.bincount(action_lengths) |
| nonzero_counts = counts[counts > 0] |
| probabilities = nonzero_counts / len(action_history) |
| entropy = float(-np.sum(probabilities * np.log2(probabilities))) |
| if abs(entropy) < 1e-12: |
| return 0.0 |
| return entropy if entropy > 0 else 0.0 |
|
|
|
|
| def compute_dense_reward_metrics( |
| reward_trace: list[float], |
| step_count: int, |
| max_steps: int, |
| action_history: list[str], |
| ) -> dict[str, float | int]: |
| nonzero_rewards = [reward for reward in reward_trace if reward != 0] |
|
|
| return { |
| "steps_taken": step_count, |
| "total_reward": float(sum(reward_trace)), |
| "reward_count": len(reward_trace), |
| "positive_rewards_count": sum(1 for reward in reward_trace if reward > 0), |
| "reward_density": float(sum(1 for reward in reward_trace if reward > 0) / len(reward_trace)) |
| if reward_trace |
| else 0.0, |
| "avg_reward_per_step": float(np.mean(reward_trace)) if reward_trace else 0.0, |
| "reward_variance": float(np.var(reward_trace)) if reward_trace else 0.0, |
| "max_single_reward": float(max(reward_trace)) if reward_trace else 0.0, |
| "episode_length_efficiency": float(step_count / max_steps) if max_steps else 0.0, |
| "positive_reward_ratio": float( |
| sum(1 for reward in reward_trace if reward > 0) / max(1, len(nonzero_rewards)) |
| ), |
| "unique_actions": len(set(action_history)), |
| "action_entropy": compute_action_entropy(action_history), |
| } |
|
|
|
|
| def run_task( |
| task_id: str, |
| policy_mode: str, |
| client: OpenAI | None, |
| model_name: str | None, |
| episode_index: int, |
| ) -> dict[str, Any]: |
| global EPSILON |
| if task_id == "easy": |
| EPSILON = 0.05 |
| elif task_id == "medium": |
| EPSILON = 0.1 |
| else: |
| EPSILON = 0.15 |
|
|
| env = SepsisTreatmentEnv(base_url=os.getenv("ENV_BASE_URL"), task_id=task_id) |
| result = env.reset() |
| observation = result.observation |
| final_info = result.info |
| reward_trace: list[float] = [] |
| action_history: list[str] = [] |
| policy_sources: Counter[str] = Counter() |
| policy_errors: list[str] = [] |
| success = False |
| step_count = 0 |
|
|
| print( |
| f"[START] task={task_id} env={ENV_NAME} policy={policy_mode} " |
| f"model={model_name or policy_mode}" |
| ) |
|
|
| try: |
| for step_number in range(1, MAX_STEPS_PER_TASK[task_id] + 1): |
| action, source, error_message = choose_action(policy_mode, client, model_name, observation) |
| formatted_action = format_action(action) |
| result = env.step(action) |
| observation = result.observation |
| final_info = result.info |
| reward = float(result.reward or 0.0) |
| reward_trace.append(reward) |
| action_history.append(formatted_action) |
| policy_sources[source] += 1 |
| if error_message: |
| policy_errors.append(error_message) |
| step_count = step_number |
| print( |
| f"[STEP] step={step_number} source={source} action={formatted_action} " |
| f"reward={reward:.2f} done={str(result.done).lower()} error=null" |
| ) |
| if result.done: |
| success = True |
| break |
| except Exception as exc: |
| policy_errors.append(str(exc)) |
| success = False |
| finally: |
| state = env.state() |
| env.close() |
| rewards_repr = ",".join(f"{reward:.2f}" for reward in reward_trace) |
| print(f"[END] success={str(success).lower()} steps={step_count} rewards={rewards_repr}") |
|
|
| metrics = final_info.get("metrics", {}) |
| dense_metrics = compute_dense_reward_metrics( |
| reward_trace=reward_trace, |
| step_count=step_count, |
| max_steps=MAX_STEPS_PER_TASK[task_id], |
| action_history=action_history, |
| ) |
| return { |
| "task_id": task_id, |
| "episode_id": state.episode_id, |
| "score": metrics.get("score", 0.0), |
| "avg_reward": metrics.get("avg_reward", 0.0), |
| "detection": metrics.get("detection", 0.0), |
| "lab_workup": metrics.get("lab_workup", 0.0), |
| "treatment": metrics.get("treatment", 0.0), |
| "timeliness": metrics.get("timeliness", 0.0), |
| "stability": metrics.get("stability", 0.0), |
| "safety": metrics.get("safety", 0.0), |
| "safety_violation_rate": metrics.get("safety_violation_rate", 0.0), |
| "safety_violations": metrics.get("safety_violations", 0), |
| "outcome": metrics.get("outcome", 0.0), |
| "steps": metrics.get("steps", state.step_count), |
| "episode_index": episode_index, |
| "policy_mode": policy_mode, |
| "policy_sources": dict(policy_sources), |
| "policy_error_count": len(policy_errors), |
| "policy_last_error": policy_errors[-1] if policy_errors else None, |
| **dense_metrics, |
| } |
|
|
|
|
| def summarize_runs( |
| all_results: list[dict[str, Any]], |
| per_episode_results: list[dict[str, Any]], |
| requested_policy: str, |
| active_policy: str, |
| model_name: str, |
| ) -> dict[str, Any]: |
| if not all_results: |
| raise ValueError("No results were generated.") |
|
|
| policy_source_totals: Counter[str] = Counter() |
| for result in all_results: |
| policy_source_totals.update(result.get("policy_sources", {})) |
|
|
| total_reward_count = sum(result["reward_count"] for result in all_results) |
| total_positive_rewards = sum(result["positive_rewards_count"] for result in all_results) |
| total_steps = sum(result["steps_taken"] for result in all_results) |
| total_safety_violations = sum(result["safety_violations"] for result in all_results) |
|
|
| return { |
| "results": all_results, |
| "episode_summaries": per_episode_results, |
| "mean_score": round(float(np.mean([item["score"] for item in all_results])), 4), |
| "score_std": round(float(np.std([item["score"] for item in all_results])), 4), |
| "mean_score_std": round(float(np.std([item["mean_score"] for item in per_episode_results])), 4) |
| if per_episode_results |
| else 0.0, |
| "mean_reward_density": round(float(np.mean([item["reward_density"] for item in all_results])), 4), |
| "global_reward_density": round(float(total_positive_rewards / total_reward_count), 4) |
| if total_reward_count |
| else 0.0, |
| "mean_avg_reward_per_step": round(float(np.mean([item["avg_reward_per_step"] for item in all_results])), 4), |
| "mean_reward_variance": round(float(np.mean([item["reward_variance"] for item in all_results])), 4), |
| "mean_positive_reward_ratio": round(float(np.mean([item["positive_reward_ratio"] for item in all_results])), 4), |
| "mean_action_entropy": round(float(np.mean([item["action_entropy"] for item in all_results])), 4), |
| "safety_violation_rate": round(float(total_safety_violations / total_steps), 4) if total_steps else 0.0, |
| "total_runs": len(all_results), |
| "episodes": len(per_episode_results), |
| "requested_policy": requested_policy, |
| "active_policy": active_policy, |
| "model_name": model_name, |
| "policy_source_totals": dict(policy_source_totals), |
| } |
|
|
|
|
| def main() -> None: |
| args = parse_args() |
| OUTPUT_DIR.mkdir(exist_ok=True) |
| api_base_url = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL) |
| model_name = os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME) |
| api_key = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") |
|
|
| llm_client = None |
| if api_base_url and model_name and api_key: |
| llm_client = OpenAI(base_url=api_base_url, api_key=api_key) |
|
|
| if args.episodes < 1: |
| raise SystemExit("--episodes must be at least 1.") |
|
|
| if args.model == "llm" and llm_client is None: |
| raise SystemExit("LLM mode requires OPENAI_API_KEY or HF_TOKEN plus API_BASE_URL and MODEL_NAME.") |
|
|
| active_policy = args.model |
| if args.model == "auto": |
| active_policy = "llm" if llm_client is not None else "heuristic" |
|
|
| all_results: list[dict[str, Any]] = [] |
| episode_summaries: list[dict[str, Any]] = [] |
| for episode_index in range(args.episodes): |
| episode_results = [ |
| run_task(task_id, active_policy, llm_client, model_name, episode_index) for task_id in TASK_IDS |
| ] |
| all_results.extend(episode_results) |
| episode_steps = sum(item["steps_taken"] for item in episode_results) |
| episode_safety_violations = sum(item["safety_violations"] for item in episode_results) |
| episode_summaries.append( |
| { |
| "episode_index": episode_index, |
| "mean_score": round(float(np.mean([item["score"] for item in episode_results])), 4), |
| "mean_reward_density": round(float(np.mean([item["reward_density"] for item in episode_results])), 4), |
| "safety_violation_rate": round(float(episode_safety_violations / episode_steps), 4) |
| if episode_steps |
| else 0.0, |
| } |
| ) |
|
|
| summary = summarize_runs( |
| all_results=all_results, |
| per_episode_results=episode_summaries, |
| requested_policy=args.model, |
| active_policy=active_policy, |
| model_name=model_name if active_policy == "llm" else active_policy, |
| ) |
| output_path = Path(args.output) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
| output_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|