#!/usr/bin/env python3 """Local agent runner for EmergencyEnv. This script acts as an agent only: - reset env - choose action from observation - step env - log trajectory """ from __future__ import annotations import argparse import json import math import os import random from pathlib import Path from datetime import datetime, timezone from typing import Union, TYPE_CHECKING, Optional, cast from app.environment.core import EmergencyEnv from app.models.action import Action if TYPE_CHECKING: from openai import OpenAI as OpenAIClient else: OpenAIClient = None try: from openai import OpenAI except Exception: # pragma: no cover - fallback for missing optional dependency OpenAI = None TASK_ORDER = ["acde_easy", "acde_medium", "acde_hard"] LEVEL_TO_TASK = { "low": "acde_easy", "medium": "acde_medium", "high": "acde_hard", } RANDOM_LEVELS = ("medium", "high") RANDOM_LEVEL_WEIGHTS = (0.25, 0.75) BASE_SPEED_KMH = 60.0 TRAFFIC_FACTOR = {"low": 1.0, "medium": 0.6, "high": 0.3} LEARNING_ARCHIVE_PATH = Path(__file__).resolve().parent / "data" / "learning_archive.json" LEARNING_ARCHIVE_VERSION = 2 DEFAULT_API_BASE_URL = "https://api-inference.huggingface.co/v1" DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct" REQUIRED_ENV_VARS = ("HF_TOKEN",) STRICT_SCORE_MIN = 0.001 STRICT_SCORE_MAX = 0.999 def clamp_strict_score(value: float) -> float: """Clamp score-like outputs to the strict open interval (0, 1).""" return max(STRICT_SCORE_MIN, min(STRICT_SCORE_MAX, float(value))) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="EmergencyEnv agent runner") parser.add_argument("--mode", choices=["single", "full"], default="full") parser.add_argument("--task", choices=TASK_ORDER, default=None) parser.add_argument("--level", choices=["low", "medium", "high"], default=None) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--episodes", type=int, default=1) parser.add_argument("--train-episodes", type=int, default=0) parser.add_argument("--train-same-seed", action="store_true") parser.add_argument( "--memory-file", default=str(Path(__file__).resolve().parent / "data" / "learning_memory.json"), ) return parser.parse_args() def emit_structured(tag: str, payload: dict) -> None: print(f"[{tag}] " + json.dumps(payload, ensure_ascii=True, separators=(",", ":"))) def runtime_llm_config() -> dict[str, str]: return { "API_BASE_URL": os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL).strip(), "MODEL_NAME": os.getenv("MODEL_NAME", DEFAULT_MODEL_NAME).strip(), "HF_TOKEN": os.getenv("HF_TOKEN", "").strip(), } def require_llm_config() -> tuple[OpenAIClient, str]: config = runtime_llm_config() missing = [name for name, value in config.items() if not value] if missing: raise SystemExit( "Missing required environment variables: " + ", ".join(missing) + ". Set HF_TOKEN before running inference.py" ) if OpenAI is None: raise SystemExit("openai package is required for inference.py LLM rationale generation.") client = OpenAI(base_url=config["API_BASE_URL"], api_key=config["HF_TOKEN"], timeout=8.0) return client, config["MODEL_NAME"] def llm_rationale( client: Union[OpenAIClient, None], model_name: str, observation: dict, chosen: dict, strategy: str, ) -> str: fallback = ( f"Selected {chosen['hospital_id']} by {strategy}; " f"score={chosen['policy_score']:.3f}, traffic={chosen['traffic']}, icu={chosen['icu']}" ) if client is None: return fallback try: prompt = ( "You are an emergency routing agent. Return one short sentence rationale " "for the selected hospital. Keep it under 25 words.\n" f"task={observation.get('task_id')} difficulty={observation.get('scenario_difficulty')} " f"step={observation.get('step')} patient={observation.get('patient_condition')} " f"required={observation.get('required_specialization')} " f"selected={chosen['hospital_id']} score={chosen['policy_score']:.3f} " f"distance={chosen['distance_km']:.1f}km traffic={chosen['traffic']} icu={chosen['icu']} " f"strategy={strategy}" ) completion = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": "Generate concise emergency triage rationale."}, {"role": "user", "content": prompt}, ], temperature=0.0, max_tokens=60, ) text = (completion.choices[0].message.content or "").strip() if not text: return fallback return " ".join(text.split())[:180] except Exception: return fallback def normalize_seed(raw_value: int | str) -> int: """Normalize arbitrary numeric/text input into a deterministic positive seed.""" if isinstance(raw_value, int): value = raw_value else: text = str(raw_value).strip() try: value = int(text) except ValueError: # Deterministic fallback for non-numeric input. value = sum((idx + 1) * ord(ch) for idx, ch in enumerate(text)) normalized = abs(value) % 1_000_000_000 return normalized if normalized != 0 else 202601 def ask_seed_if_missing(seed: int | None) -> int: if seed is not None: return normalize_seed(seed) # No CLI seed means a fresh randomized run. return normalize_seed(random.SystemRandom().randint(1, 999_999_999)) def ask_level_if_missing(level: str | None) -> str: if level in LEVEL_TO_TASK: return level # No CLI level means pick a random non-easy difficulty. return random.choices( RANDOM_LEVELS, weights=RANDOM_LEVEL_WEIGHTS, k=1, )[0] def append_trajectory_log(entry: dict) -> None: path = Path(__file__).resolve().parent / "data" / "trajectory_history.jsonl" path.parent.mkdir(parents=True, exist_ok=True) with path.open("a", encoding="utf-8") as fp: fp.write(json.dumps(entry, ensure_ascii=True) + "\n") def load_learning_archive() -> dict: LEARNING_ARCHIVE_PATH.parent.mkdir(parents=True, exist_ok=True) if not LEARNING_ARCHIVE_PATH.exists(): return {"version": LEARNING_ARCHIVE_VERSION, "profiles": {}, "episodes": []} try: payload_text = LEARNING_ARCHIVE_PATH.read_text(encoding="utf-8-sig").strip() payload = json.loads(payload_text) if payload_text else {} except json.JSONDecodeError: return {"version": LEARNING_ARCHIVE_VERSION, "profiles": {}, "episodes": []} if not isinstance(payload, dict): return {"version": LEARNING_ARCHIVE_VERSION, "profiles": {}, "episodes": []} if payload.get("version") != LEARNING_ARCHIVE_VERSION: return { "version": LEARNING_ARCHIVE_VERSION, "profiles": {}, "episodes": payload.get("episodes", [])[-500:] if isinstance(payload.get("episodes", []), list) else [], } payload.setdefault("version", LEARNING_ARCHIVE_VERSION) payload.setdefault("profiles", {}) payload.setdefault("episodes", []) return payload def save_learning_archive(archive: dict) -> None: LEARNING_ARCHIVE_PATH.parent.mkdir(parents=True, exist_ok=True) LEARNING_ARCHIVE_PATH.write_text(json.dumps(archive, indent=2, ensure_ascii=True), encoding="utf-8") def profile_key(seed: int, task_id: str) -> str: return f"{seed}|{task_id}" def _merge_step_stats(primary: dict, secondary: dict) -> dict: merged: dict = {} for step_key in set(primary.keys()) | set(secondary.keys()): merged[step_key] = {} step_primary = primary.get(step_key, {}) step_secondary = secondary.get(step_key, {}) for hospital_id in set(step_primary.keys()) | set(step_secondary.keys()): a = step_primary.get(hospital_id, {}) b = step_secondary.get(hospital_id, {}) count = int(a.get("count", 0)) + int(b.get("count", 0)) accepted = int(a.get("accepted", 0)) + int(b.get("accepted", 0)) partial = int(a.get("partial", 0)) + int(b.get("partial", 0)) rejected = int(a.get("rejected", 0)) + int(b.get("rejected", 0)) total_reward = float(a.get("total_reward", 0.0)) + float(b.get("total_reward", 0.0)) merged[step_key][hospital_id] = { "count": count, "success": int(a.get("success", 0)) + int(b.get("success", 0)), "accepted": accepted, "partial": partial, "rejected": rejected, "total_reward": total_reward, "avg_reward": (total_reward / max(1, count)), "success_rate": (accepted / max(1, count)), "last_status": a.get("last_status") or b.get("last_status"), "last_reason": a.get("last_reason") or b.get("last_reason"), } return merged def build_learning_profile( archive: dict, seed: int, task_id: str, required_specialization: str | None = None, ) -> dict | None: profiles = archive.get("profiles", {}) key = profile_key(seed, task_id) exact = profiles.get(key) if not exact: return None # Strict scope: learn only from same seed + same level/task. return { "attempts": int(exact.get("attempts", 0)), "best_score": float(exact.get("best_score", 0.0)), "best_actions": list(exact.get("best_actions", [])), "step_stats": exact.get("step_stats", {}), "best_scenario_name": exact.get("best_scenario_name"), "last_scenario_name": exact.get("last_scenario_name"), "source": "exact-only", } def _difficulty_policy_params(difficulty: str) -> tuple[float, float]: if difficulty == "easy": return 0.07, 0.18 if difficulty == "medium": return 0.16, 0.32 return 0.26, 0.44 def _sample_softmax(candidates: list[dict], key: str, temperature: float, rng: random.Random) -> dict: logits = [item[key] / max(temperature, 1e-6) for item in candidates] max_logit = max(logits) exps = [math.exp(v - max_logit) for v in logits] total = sum(exps) probs = [e / total for e in exps] roll = rng.random() cdf = 0.0 for item, prob in zip(candidates, probs): cdf += prob if roll <= cdf: return item return candidates[-1] def memory_score_for_hospital( hospital_id: str, memory_snapshot: dict, learning_profile: dict | None = None, step_number: int | None = None, ) -> float: entry = memory_snapshot.get(hospital_id) if not entry: return 0.5 success = int(entry.get("accepted", entry.get("success", 0))) fail = int(entry.get("rejected", entry.get("fail", 0))) avg = float(entry.get("avg", 0.0)) total = success + fail if total <= 0: return 0.5 success_rate = success / total # Fix 3: reliability-first memory scoring. value = (0.6 * success_rate) + (0.4 * avg) recent_failed = False if learning_profile and step_number is not None: step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {}) hospital_stats = step_stats.get(hospital_id) if hospital_stats: step_avg = float(hospital_stats.get("avg_reward", 0.0)) step_success = float(hospital_stats.get("success_rate", 0.0)) step_count = int(hospital_stats.get("count", 0)) value += min(0.20, (step_avg * 0.10) + (step_success * 0.08) + min(step_count, 5) * 0.01) recent_failed = str(hospital_stats.get("last_status", "")).upper() == "REJECTED" if recent_failed: value -= 0.3 return max(0.0, min(1.0, value)) def score_hospitals(observation: dict, learning_profile: dict | None = None) -> list[dict]: failed = set(observation.get("failed_hospitals", [])) recent_failed = set(observation.get("recent_failed_hospitals", [])) visited = set(observation.get("visited_hospitals", [])) memory_snapshot = observation.get("memory_snapshot", {}) previous_action = observation.get("previous_action") last_arrival = observation.get("last_arrival_outcome") or {} last_status = str(last_arrival.get("status", "")).lower() scored: list[dict] = [] initial_limit = float(observation.get("initial_critical_time_limit_minutes", observation["critical_time_limit_minutes"])) remaining_time = float(observation.get("remaining_time_minutes", observation["critical_time_limit_minutes"])) urgency = 1.0 - min(1.0, max(0.0, remaining_time / max(initial_limit, 1e-6))) patient_condition = observation.get("patient_condition", "").lower() critical_patient = patient_condition in {"critical", "unstable"} required_specialization = str(observation.get("required_specialization", "")) scenario_name = str(observation.get("scenario_name", "")) step_number = int(observation.get("step", 1)) difficulty = str(observation.get("scenario_difficulty", "medium")) attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0 preferred_route = [] if learning_profile: preferred_route = list(learning_profile.get("best_actions", [])) for hospital in observation.get("hospitals", []): traffic_factor = TRAFFIC_FACTOR[hospital["traffic"]] speed_kmh = BASE_SPEED_KMH * traffic_factor travel_time = (hospital["distance_km"] / max(speed_kmh, 1e-6)) * 60.0 distance_score = max(0.0, min(1.0, 1.0 - hospital["distance_km"] / 20.0)) icu_score = 1.0 if hospital["icu"] == "available" else 0.55 mem_score = memory_score_for_hospital( hospital["hospital_id"], memory_snapshot, learning_profile=learning_profile, step_number=step_number, ) memory_scenario = "" if learning_profile: memory_scenario = str( learning_profile.get("best_scenario_name") or learning_profile.get("last_scenario_name") or "" ) if memory_scenario and scenario_name and memory_scenario != scenario_name: mem_score *= 0.5 spec_match = ( hospital["specialization"] == observation["required_specialization"] or hospital["specialization"] == "general" or observation["required_specialization"] == "general" ) exact_spec_match = hospital["specialization"] == observation["required_specialization"] general_fallback = ( hospital["specialization"] == "general" and observation["required_specialization"] != "general" ) rejected_penalty = 0.40 if hospital["hospital_id"] in failed else 0.0 revisit_penalty = 0.14 if hospital["hospital_id"] in visited else 0.0 partial_repeat_penalty = ( 0.32 if last_status == "partial" and hospital["hospital_id"] == previous_action else 0.0 ) critical_unknown_penalty = 0.17 if critical_patient and hospital["icu"] == "unknown" else 0.03 traffic_penalty = 0.10 if hospital["traffic"] == "high" else 0.04 if hospital["traffic"] == "medium" else 0.0 if critical_patient and general_fallback: spec_penalty = {"easy": 0.08, "medium": 0.16, "hard": 0.26}.get(difficulty, 0.16) if attempts >= 5: spec_penalty += 0.06 else: spec_penalty = 0.0 spec_bonus = 0.16 if exact_spec_match else (0.08 if spec_match else 0.0) urgency_boost = urgency * (0.18 + max(0.0, 0.25 - travel_time / 100.0)) step_route_bonus = 0.0 if step_number - 1 < len(preferred_route) and preferred_route[step_number - 1] == hospital["hospital_id"]: step_route_bonus = 0.16 score = ( (icu_score * 0.30) + (distance_score * 0.18) + (traffic_factor * 0.14) + (mem_score * 0.24) + spec_bonus + urgency_boost + step_route_bonus - rejected_penalty - revisit_penalty - partial_repeat_penalty - spec_penalty - critical_unknown_penalty - traffic_penalty ) if hospital["hospital_id"] == previous_action and last_status == "rejected": score *= 0.01 if hospital["hospital_id"] in recent_failed: score *= 0.2 if hospital["specialization"] != required_specialization: if patient_condition == "critical": score *= 0.15 else: score *= 0.4 elif patient_condition == "critical": score *= 1.5 # Hard realism penalties to align policy scoring with validator outcomes. if hospital["specialization"] != required_specialization: score -= 0.6 if critical_patient and hospital["icu"] == "unknown": score -= 0.5 if critical_patient and hospital["traffic"] == "high": score -= 0.3 # Confidence-style risk multiplier keeps risky options from looking deceptively good. risk_factor = 1.0 if hospital["icu"] == "unknown": risk_factor *= 0.6 if not spec_match: risk_factor *= 0.5 if critical_patient and hospital["traffic"] == "high": risk_factor *= 0.7 score *= risk_factor # Reduce memory dominance in final decision score. memory_weight = 0.15 current_score_weight = 0.85 if step_number == 1: memory_weight = 0.1 current_score_weight = 0.9 base_current_score = score confidence_score = max(0.0, min(1.0, base_current_score)) effective_memory_score = mem_score in_best_route = hospital["hospital_id"] in preferred_route if in_best_route and confidence_score < 0.6: effective_memory_score = 0.0 if confidence_score < 0.2: effective_memory_score = 0.0 score = (current_score_weight * base_current_score) + (memory_weight * effective_memory_score) scored.append( { "hospital_id": hospital["hospital_id"], "icu": hospital["icu"], "distance_km": hospital["distance_km"], "traffic": hospital["traffic"], "specialization": hospital["specialization"], "travel_time": travel_time, "memory_score": mem_score, "policy_score": max(0.0, min(1.0, score)), "specialization_match": spec_match, "tie_break_score": ( (distance_score * 0.35) + (traffic_factor * 0.35) + (icu_score * 0.20) + (0.10 if spec_match else 0.0) ), } ) scored.sort(key=lambda item: item["policy_score"], reverse=True) if scored: min_score = min(item["policy_score"] for item in scored) max_score = max(item["policy_score"] for item in scored) spread = max_score - min_score if spread > 1e-9: for item in scored: normalized = (item["policy_score"] - min_score) / (spread + 1e-6) if normalized < 0.2: jitter_seed = ( int(observation.get("seed", 0)) + (step_number * 131) + sum(ord(ch) for ch in item["hospital_id"]) ) jitter_rng = random.Random(jitter_seed) normalized *= jitter_rng.uniform(0.3, 0.7) item["policy_score"] = max(0.0, min(1.0, normalized)) elif max_score > 0: for item in scored: normalized = item["policy_score"] / max_score if normalized < 0.2: jitter_seed = ( int(observation.get("seed", 0)) + (step_number * 131) + sum(ord(ch) for ch in item["hospital_id"]) ) jitter_rng = random.Random(jitter_seed) normalized *= jitter_rng.uniform(0.3, 0.7) item["policy_score"] = max(0.0, min(1.0, normalized)) else: tie_min = min(item.get("tie_break_score", 0.0) for item in scored) tie_max = max(item.get("tie_break_score", 0.0) for item in scored) tie_spread = tie_max - tie_min if tie_spread > 1e-9: for item in scored: normalized = (item.get("tie_break_score", 0.0) - tie_min) / (tie_spread + 1e-6) if normalized < 0.2: jitter_seed = ( int(observation.get("seed", 0)) + (step_number * 131) + sum(ord(ch) for ch in item["hospital_id"]) ) jitter_rng = random.Random(jitter_seed) normalized *= jitter_rng.uniform(0.3, 0.7) item["policy_score"] = max(0.0, min(1.0, normalized)) else: for item in scored: item["policy_score"] = 0.0 # Remove hard-zero scores and normalize to probability-like values. for item in scored: if item["policy_score"] <= 0.0: jitter_seed = ( int(observation.get("seed", 0)) + (step_number * 173) + sum(ord(ch) for ch in item["hospital_id"]) ) jitter_rng = random.Random(jitter_seed) if critical_patient and required_specialization != "general": if item.get("specialization") == required_specialization: item["policy_score"] = jitter_rng.uniform(0.08, 0.18) else: item["policy_score"] = jitter_rng.uniform(0.001, 0.01) else: item["policy_score"] = jitter_rng.uniform(0.05, 0.15) total_score = sum(item["policy_score"] for item in scored) if total_score > 0: for item in scored: item["policy_score"] = item["policy_score"] / (total_score + 1e-6) else: uniform = 1.0 / len(scored) for item in scored: item["policy_score"] = uniform # Final clinical-priority pass: in critical non-general cases, # exact specialization should dominate unless unavailable. if critical_patient and required_specialization != "general": for item in scored: if item.get("specialization") == required_specialization: item["policy_score"] *= 1.5 else: item["policy_score"] *= 0.15 boosted_total = sum(item["policy_score"] for item in scored) if boosted_total > 0: for item in scored: item["policy_score"] = item["policy_score"] / boosted_total for item in scored: raw_score = float(item["policy_score"]) normalized_score = raw_score / (1.0 + abs(raw_score)) # Keep a small floor so no action is fully eliminated from exploration. if normalized_score < 0.01: jitter_seed = ( int(observation.get("seed", 0)) + (step_number * 211) + sum(ord(ch) for ch in item["hospital_id"]) ) jitter_rng = random.Random(jitter_seed) normalized_score = jitter_rng.uniform(0.01, 0.03) item["policy_score"] = normalized_score scored.sort(key=lambda item: item["policy_score"], reverse=True) for item in scored: item.pop("tie_break_score", None) return scored def choose_hospital( scored: list[dict], observation: dict, rng: random.Random, learning_profile: dict | None = None, ) -> tuple[dict, str]: difficulty = observation.get("scenario_difficulty", "medium") epsilon, temperature = _difficulty_policy_params(difficulty) failed = set(observation.get("failed_hospitals", [])) recent_failed = set(observation.get("recent_failed_hospitals", [])) visited = set(observation.get("visited_hospitals", [])) previous_action = observation.get("previous_action") selected_hospital_id = observation.get("selected_hospital_id") visited_sequence = observation.get("visited_hospitals", []) or [] recent_hospital = previous_action or selected_hospital_id or (visited_sequence[-1] if visited_sequence else None) last_arrival = observation.get("last_arrival_outcome") or {} last_status = str(last_arrival.get("status", "")).lower() last_reason = str(last_arrival.get("reason", "")).lower() is_rerouting_phase = str(observation.get("ambulance_status", "")).lower() == "rerouting" # Cooldown logic: avoid recently failed hospitals first, then avoid visited when alternatives exist. candidates = [ item for item in scored if item["hospital_id"] not in recent_failed and item["hospital_id"] not in visited ] if not candidates: candidates = [item for item in scored if item["hospital_id"] not in recent_failed] if not candidates: # Last-resort fallback: if every hospital has failed already, avoid immediate retry. candidates = list(scored) if (last_status == "rejected" or is_rerouting_phase) and recent_hospital: redirected = [item for item in candidates if item["hospital_id"] != recent_hospital] if redirected: candidates = redirected step_number = int(observation.get("step", 1)) attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0 required_specialization = str(observation.get("required_specialization", "")) critical_patient = observation.get("patient_condition", "").lower() in {"critical", "unstable"} # Hard realism rule: never immediately retry the hospital that just rejected the patient. if (last_status == "rejected" or is_rerouting_phase) and recent_hospital: immediate_retry_block = [item for item in candidates if item["hospital_id"] != recent_hospital] if immediate_retry_block: candidates = immediate_retry_block elif len(candidates) == 1 and candidates[0]["hospital_id"] == recent_hospital: fallback_any = [item for item in scored if item["hospital_id"] != recent_hospital] if fallback_any: candidates = fallback_any # In critical non-general cases, prioritize exact specialization when available. if critical_patient and required_specialization != "general": exact_spec_candidates = [ item for item in candidates if item["specialization"] == required_specialization ] if exact_spec_candidates: candidates = exact_spec_candidates if step_number == 1: policy_mode = "safe" elif last_status == "rejected": policy_mode = "risk-aware" else: policy_mode = "balanced" safe_weight = 1.0 if policy_mode == "safe": safe_weight *= 0.8 epsilon *= 0.6 temperature *= 0.8 elif policy_mode == "risk-aware": epsilon *= 1.1 temperature *= 0.9 # Within-episode learning from concrete failure reasons. if "wrong hospital specialization" in last_reason: strict_spec = [ item for item in candidates if item["specialization"] == observation.get("required_specialization") ] if strict_spec: candidates = strict_spec if "icu unavailable" in last_reason: icu_known = [item for item in candidates if item["icu"] == "available"] if icu_known: candidates = icu_known if "specialist" in last_reason: strict_spec = [ item for item in candidates if item["specialization"] == observation.get("required_specialization") ] if strict_spec: candidates = strict_spec if "overloaded" in last_reason: non_high_traffic = [item for item in candidates if item["traffic"] != "high"] if non_high_traffic: candidates = non_high_traffic if "delay" in last_reason: candidates = sorted(candidates, key=lambda item: item["distance_km"]) def learned_utility(item: dict) -> float: base = float(item.get("policy_score", 0.0)) if not learning_profile: return base step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {}) stats = step_stats.get(item["hospital_id"], {}) count = int(stats.get("count", 0)) if count <= 0: exploration_bonus = 0.22 * math.sqrt(max(1.0, math.log(attempts + 2.0))) return base + exploration_bonus avg_reward = float(stats.get("avg_reward", 0.0)) success_rate = float(stats.get("success_rate", 0.0)) rejected = int(stats.get("rejected", 0)) rejection_rate = rejected / max(1, count) exploration_bonus = 0.18 * math.sqrt(max(0.0, math.log(attempts + 2.0) / (count + 1.0))) # Real-data utility: reward trend + success rate - rejection risk + exploration bonus. historical_weight = 0.35 historical_weight *= 0.6 historical_bonus = (avg_reward * historical_weight) + (success_rate * 0.30) - (rejection_rate * 0.22) if item["hospital_id"] in recent_failed: historical_bonus = 0.0 return base + historical_bonus + exploration_bonus def pick_improvement_candidate(route_choice_id: str | None) -> dict | None: if not candidates: return None ranked = sorted(candidates, key=learned_utility, reverse=True) if route_choice_id is None: return ranked[0] for item in ranked: if item["hospital_id"] != route_choice_id: return item return ranked[0] def enforce_score_guard(chosen: dict, strategy: str) -> tuple[dict, str]: # Absolute next-step guard: never pick the same hospital immediately after a rejection. if last_status == "rejected" and previous_action and chosen.get("hospital_id") == previous_action: alternatives = [item for item in scored if item["hospital_id"] != previous_action] if alternatives: rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0))) return rerouted, strategy + " + immediate-retry block" # Global guardrail: when a score gap is very large, prefer best option most # of the time while preserving some exploration. globally_eligible = [ item for item in scored if item["hospital_id"] not in recent_failed and not ( (last_status == "rejected" or is_rerouting_phase) and recent_hospital and item["hospital_id"] == recent_hospital ) ] if not globally_eligible: globally_eligible = list(scored) if globally_eligible: best_global = max(globally_eligible, key=lambda item: float(item.get("policy_score", 0.0))) chosen_score = float(chosen.get("policy_score", 0.0)) best_global_score = float(best_global.get("policy_score", 0.0)) # Cooldown hard guard: never immediately retry the just-failed hospital. if (last_status == "rejected" or is_rerouting_phase) and recent_hospital: if chosen.get("hospital_id") == recent_hospital: alternatives = [ item for item in scored if item["hospital_id"] != recent_hospital and item["hospital_id"] not in recent_failed ] if not alternatives: alternatives = [item for item in scored if item["hospital_id"] != recent_hospital] if alternatives: rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0))) return rerouted, strategy + " + cooldown reroute" if chosen_score < (best_global_score * 0.6): return best_global, strategy + " + anti-stupidity guard" if (best_global_score - chosen_score) > 0.25 and rng.random() < 0.75: return best_global, strategy + " + score-gap guard" return chosen, strategy # Learning-driven fail guard: avoid hospitals that repeatedly fail at this exact step. if learning_profile: step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {}) guard_blocked: set[str] = set() for hospital_id, stats in step_stats.items(): count = int(stats.get("count", 0)) success_rate = float(stats.get("success_rate", 0.0)) rejected = int(stats.get("rejected", 0)) if count >= 2 and success_rate <= 0.0 and rejected >= 2: guard_blocked.add(hospital_id) guarded_candidates = [item for item in candidates if item["hospital_id"] not in guard_blocked] if guarded_candidates: candidates = guarded_candidates # As attempts increase, reduce randomness and rely on learned utility. if attempts >= 3: epsilon *= 0.35 temperature *= 0.70 # Same seed + same task policy: # evaluate route combinations across all steps, not just one-step mutations. if learning_profile and policy_mode != "risk-aware": best_route = list(learning_profile.get("best_actions", [])) if step_number - 1 < len(best_route): baseline_id = best_route[step_number - 1] ranked = sorted(candidates, key=learned_utility, reverse=True) baseline_candidate = next((item for item in ranked if item["hospital_id"] == baseline_id), None) alternatives = [item for item in ranked if item["hospital_id"] != baseline_id] top_candidate = ranked[0] if ranked else None if ( step_number == 1 and baseline_candidate is not None and top_candidate is not None and float(baseline_candidate.get("policy_score", 0.0)) < float(top_candidate.get("policy_score", 0.0)) ): baseline_candidate = None alternatives = alternatives[: min(3, len(alternatives))] if attempts >= 1: # Mixed-radix route search: each run selects a step-wise digit. # digit 0 => keep baseline for this step, 1/2 => try alternative ranks. combo_index = max(0, attempts - 1) digit = (combo_index // (3 ** max(0, step_number - 1))) % 3 if digit == 0 and baseline_candidate is not None: return enforce_score_guard(baseline_candidate, "best-route retain") alt_rank = digit - 1 if alt_rank >= 0 and alt_rank < len(alternatives): return enforce_score_guard(alternatives[alt_rank], f"combination search step-{step_number} alt-{alt_rank + 1}") if baseline_candidate is not None: return enforce_score_guard(baseline_candidate, "best-route retain") if attempts >= 6: ranked = sorted(candidates, key=learned_utility, reverse=True) top_pool = ranked[: min(3, len(ranked))] return enforce_score_guard(_sample_softmax(top_pool, "policy_score", max(0.08, temperature * 0.85), rng), "learned utility exploit") if learning_profile and policy_mode == "safe": preferred_route = list(learning_profile.get("best_actions", [])) if step_number - 1 < len(preferred_route): preferred_hospital = preferred_route[step_number - 1] preferred_candidate = next((item for item in candidates if item["hospital_id"] == preferred_hospital), None) if preferred_candidate is not None: profile_score = float(learning_profile.get("best_score", 0.0)) if (profile_score * safe_weight) >= 0.85 or len(candidates) == 1: return enforce_score_guard(preferred_candidate, "learned best path") # If last outcome was partial, force trying a different hospital when possible. if last_status == "partial" and previous_action: redirected = [item for item in candidates if item["hospital_id"] != previous_action] if redirected: candidates = redirected # After partial treatment, reduce random exploration and favor safer follow-up routing. epsilon = min(epsilon, 0.04) temperature = min(temperature, 0.24) critical = observation.get("patient_condition", "").lower() in {"critical", "unstable"} strategy = f"{policy_mode} policy" if critical and policy_mode in {"safe", "balanced"}: confirmed = [item for item in candidates if item["icu"] == "available"] if confirmed: candidates = confirmed strategy = f"{policy_mode} policy + critical triage" if len(candidates) > 1 and rng.random() < 0.15: ranked = sorted(candidates, key=learned_utility, reverse=True) top_k = ranked[: min(3, len(ranked))] return enforce_score_guard(rng.choice(top_k), strategy + " + guided-exploration") if len(candidates) > 1: # Utility-aware candidate ordering for softmax sampling. ranked = sorted(candidates, key=learned_utility, reverse=True) chosen = _sample_softmax(ranked, "policy_score", temperature, rng) return enforce_score_guard(chosen, strategy) return enforce_score_guard(candidates[0], strategy) def print_options(scored: list[dict]) -> None: print(f"Hospital options ({len(scored)} total):") for idx, item in enumerate(scored, start=1): print( f" [{idx}] {item['hospital_id']} | {item['distance_km']:.1f} km | ICU {item['icu']} | " f"traffic {item['traffic']} | specialty {item['specialization']} | score {item['policy_score']:.3f}" ) def run_episode( env: EmergencyEnv, task_id: str, seed: int, archive: dict | None = None, llm_client: object | None = None, model_name: str | None = None, ) -> dict: observation_model = env.reset(seed=seed, task_id=task_id) observation = observation_model.model_dump() learning_profile = None if archive is not None: learning_profile = build_learning_profile( archive, seed, task_id, required_specialization=str(observation.get("required_specialization", "")) or None, ) print("\n" + "=" * 72) print(f"Scenario: {observation['scenario_name']}") print(f"Task: {task_id} | Difficulty: {observation['scenario_difficulty']} | Seed: {seed}") print(f"Patient condition: {observation['patient_condition']}") print(f"Required specialization: {observation['required_specialization']}") print("Objective: admit patient successfully (no fixed deadline window)") print("=" * 72) print(f"[START] task={task_id} env=acde-openenv model={model_name or 'none'}", flush=True) if learning_profile: print( f"Learning memory: best historical score {float(learning_profile.get('best_score', 0.0)):.3f} " f"across {int(learning_profile.get('attempts', 0))} attempts" ) if learning_profile.get("best_actions"): print(f"Best known route: {' -> '.join(learning_profile['best_actions'])}") total_reward = 0.0 all_rewards = [] steps = 0 done = False previous_policy_hospital_id: str | None = None previous_policy_outcome: str | None = None attempt_index = int(learning_profile.get("attempts", 0)) if learning_profile else 0 # Keep scenario deterministic by seed, but vary policy exploration across retries. rng = random.Random(seed + (attempt_index * 7919)) step_records: list[dict] = [] while not done: steps += 1 print(f"\nStep {observation['step']} | phase={observation['ambulance_status']}") scored = score_hospitals(observation, learning_profile=learning_profile) chosen, strategy = choose_hospital(scored, observation, rng, learning_profile=learning_profile) # Final policy-level guard: no immediate retry of the same hospital after rejection. if previous_policy_outcome == "REJECTED" and previous_policy_hospital_id and chosen["hospital_id"] == previous_policy_hospital_id: alternatives = [item for item in scored if item["hospital_id"] != previous_policy_hospital_id] if alternatives: chosen = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0))) strategy = strategy + " + immediate-retry override" print_options(scored) rationale = llm_rationale(cast(Optional[OpenAIClient], llm_client), model_name or "", observation, chosen, strategy) print(f"Decision: {chosen['hospital_id']} ({strategy})") step_result = env.step( Action( step=observation["step"], hospital_id=chosen["hospital_id"], rationale=rationale, ) ) next_obs_model = step_result["observation"] reward = float(step_result["reward"]) all_rewards.append(reward) done = bool(step_result["done"]) info = step_result.get("info", {}) or {} next_observation = next_obs_model.model_dump() total_reward += reward outcome = info.get("outcome", {}) status = str(outcome.get("status", "partial")).upper() reason = str(outcome.get("reason", "No reason provided")) previous_policy_hospital_id = chosen["hospital_id"] previous_policy_outcome = status print(f"Outcome: {status}") print(f"Reason: {reason}") print(f"Reward: {reward:.3f}") error_val = str(info.get("last_action_error")) if info.get("last_action_error") else "null" print(f"[STEP] step={observation.get('step')} action={chosen['hospital_id']} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True) append_trajectory_log( { "seed": seed, "task": task_id, "difficulty": observation.get("scenario_difficulty"), "step": observation.get("step"), "state": { "patient_condition": observation.get("patient_condition"), "remaining_time_minutes": observation.get("remaining_time_minutes"), "failed_hospitals": observation.get("failed_hospitals", []), "visited_hospitals": observation.get("visited_hospitals", []), "ambulance_status": observation.get("ambulance_status"), }, "action": { "hospital_id": chosen["hospital_id"], "policy_score": chosen["policy_score"], "strategy": strategy, }, "outcome": { "status": status, "reason": reason, }, "reward": reward, } ) step_records.append( { "step": observation.get("step"), "hospital_id": chosen["hospital_id"], "status": status, "reason": reason, "reward": reward, "policy_score": chosen["policy_score"], } ) observation = next_observation final_state = env.state() final_result = final_state.final_outcome or "FAILURE" final_score = clamp_strict_score(final_state.final_score) print("\nFinal result:") print(f" Result: {final_result}") print(f" Total steps: {steps}") print(f" Final score: {final_score:.3f}") print(f" Average reward: {total_reward / max(1, steps):.3f}") rewards_str = ",".join(f"{r:.2f}" for r in all_rewards) print(f"[END] success={str(final_result == 'SUCCESS').lower()} steps={steps} score={final_score:.2f} rewards={rewards_str}", flush=True) return { "success": final_result == "SUCCESS", "score": final_score, "steps": steps, "seed": seed, "task_id": task_id, "scenario_name": observation.get("scenario_name"), "scenario_type": observation.get("scenario_type"), "difficulty": observation.get("scenario_difficulty"), "required_specialization": observation.get("required_specialization"), "actions": [record["hospital_id"] for record in step_records], "step_records": step_records, "timestamp": datetime.now(timezone.utc).isoformat(), } def update_learning_archive(archive: dict, episode_result: dict) -> None: key = profile_key(int(episode_result["seed"]), str(episode_result["task_id"])) profiles = archive.setdefault("profiles", {}) profile = profiles.get( key, { "attempts": 0, "best_score": 0.0, "best_actions": [], "best_steps": 0, "step_stats": {}, }, ) profile["attempts"] = int(profile.get("attempts", 0)) + 1 profile["last_score"] = float(episode_result["score"]) profile["last_success"] = bool(episode_result["success"]) profile["last_run_at"] = episode_result["timestamp"] profile["last_actions"] = list(episode_result.get("actions", [])) profile["last_required_specialization"] = episode_result.get("required_specialization") profile["last_scenario_type"] = episode_result.get("scenario_type") profile["last_scenario_name"] = episode_result.get("scenario_name") if float(episode_result["score"]) >= float(profile.get("best_score", 0.0)): profile["best_score"] = float(episode_result["score"]) profile["best_actions"] = list(episode_result.get("actions", [])) profile["best_steps"] = int(episode_result.get("steps", 0)) profile["best_success"] = bool(episode_result["success"]) profile["best_scenario_name"] = episode_result.get("scenario_name") profile["best_difficulty"] = episode_result.get("difficulty") profile["best_required_specialization"] = episode_result.get("required_specialization") step_stats = profile.setdefault("step_stats", {}) for record in episode_result.get("step_records", []): step_key = str(record.get("step")) hospital_id = str(record.get("hospital_id")) step_bucket = step_stats.setdefault(step_key, {}) hospital_bucket = step_bucket.setdefault( hospital_id, { "count": 0, "success": 0, "accepted": 0, "partial": 0, "rejected": 0, "total_reward": 0.0, "avg_reward": 0.0, "last_status": None, "last_reason": None, }, ) hospital_bucket["count"] += 1 if record["status"] == "ACCEPTED": hospital_bucket["success"] += 1 hospital_bucket["accepted"] += 1 elif record["status"] == "PARTIAL": hospital_bucket["partial"] += 1 else: hospital_bucket["rejected"] += 1 hospital_bucket["total_reward"] = float(hospital_bucket["total_reward"]) + float(record["reward"]) hospital_bucket["avg_reward"] = hospital_bucket["total_reward"] / max(1, hospital_bucket["count"]) hospital_bucket["last_status"] = record["status"] hospital_bucket["last_reason"] = record["reason"] hospital_bucket["success_rate"] = hospital_bucket["accepted"] / max(1, hospital_bucket["count"]) profiles[key] = profile episodes = archive.setdefault("episodes", []) episodes.append( { "seed": episode_result["seed"], "task_id": episode_result["task_id"], "difficulty": episode_result["difficulty"], "required_specialization": episode_result.get("required_specialization"), "scenario_name": episode_result["scenario_name"], "score": episode_result["score"], "success": episode_result["success"], "actions": episode_result.get("actions", []), "timestamp": episode_result["timestamp"], } ) archive["episodes"] = episodes[-500:] def print_training_summary(results: list[dict]) -> None: if not results: return scores = [float(item["score"]) for item in results] successes = sum(1 for item in results if item["success"]) split = max(1, len(scores) // 2) early_scores = scores[:split] late_scores = scores[split:] if not late_scores: late_scores = scores[-split:] early_avg = sum(early_scores) / len(early_scores) late_avg = sum(late_scores) / len(late_scores) delta = late_avg - early_avg print("\nTraining summary:") print(f" Episodes: {len(results)}") print(f" Success rate: {successes / len(results):.1%}") print(f" Average score: {sum(scores) / len(scores):.3f}") print(f" Early avg score ({len(early_scores)} eps): {early_avg:.3f}") print(f" Late avg score ({len(late_scores)} eps): {late_avg:.3f}") print(f" Trend delta (late-early): {delta:+.3f}") def main() -> None: args = parse_args() llm_client, model_name = require_llm_config() seed = ask_seed_if_missing(args.seed) print(f"Using seed: {seed}") if args.mode == "full": tasks = TASK_ORDER else: chosen_task = args.task if chosen_task is None: chosen_level = ask_level_if_missing(args.level) chosen_task = LEVEL_TO_TASK[chosen_level] tasks = [chosen_task] env = EmergencyEnv(memory_file=args.memory_file) archive = load_learning_archive() results = [] run_count = args.train_episodes if args.train_episodes > 0 else args.episodes training_mode = args.train_episodes > 0 for episode in range(run_count): for idx, task_id in enumerate(tasks): if training_mode: if args.train_same_seed: task_seed = seed else: task_seed = seed + (episode * 100) + idx else: task_seed = seed + (episode * 100) + idx label = f"Training Episode {episode + 1}" if training_mode else f"Episode {episode + 1}" print(f"\n=== {label} | {task_id} | seed={task_seed} ===") episode_result = run_episode( env, task_id, task_seed, archive=archive, llm_client=llm_client, model_name=model_name, ) results.append(episode_result) update_learning_archive(archive, episode_result) save_learning_archive(archive) if training_mode: print_training_summary(results) return if results: print("\nBatch summary:") if len(results) == 1: episode_result = "SUCCESS" if results[0]["success"] else "FAILURE" print(f" Episode outcome: {episode_result}") print(f" Episode score: {results[0]['score']:.3f}") print(f" Episode steps: {results[0]['steps']}") print(" Note: run 30-50 episodes to estimate difficulty success rate.") else: print(f" Success rate: {sum(1 for item in results if item['success']) / len(results):.1%}") print(f" Average score: {sum(item['score'] for item in results) / len(results):.3f}") print(f" Average steps: {sum(item['steps'] for item in results) / len(results):.1f}") if __name__ == "__main__": main()