test / inference.py
Jash
fix: use strict stdout log format for inference script
53effed
Raw
History Blame Contribute Delete
53.7 kB
#!/usr/bin/env python3
"""Local agent runner for EmergencyEnv.
This script acts as an agent only:
- reset env
- choose action from observation
- step env
- log trajectory
"""
from __future__ import annotations
import argparse
import json
import math
import os
import random
from pathlib import Path
from datetime import datetime, timezone
from typing import Union, TYPE_CHECKING, Optional, cast
from app.environment.core import EmergencyEnv
from app.models.action import Action
if TYPE_CHECKING:
from openai import OpenAI as OpenAIClient
else:
OpenAIClient = None
try:
from openai import OpenAI
except Exception: # pragma: no cover - fallback for missing optional dependency
OpenAI = None
# Task identifiers accepted by the --task CLI option.
TASK_ORDER = ["acde_easy", "acde_medium", "acde_hard"]
# Maps a --level CLI value to a task identifier.
LEVEL_TO_TASK = {
    "low": "acde_easy",
    "medium": "acde_medium",
    "high": "acde_hard",
}
# Levels (and matching sampling weights) drawn from when no valid --level is
# given; "low" is left out so a random run is never the easy level.
RANDOM_LEVELS = ("medium", "high")
RANDOM_LEVEL_WEIGHTS = (0.25, 0.75)
# Speed in km/h before traffic scaling; multiplied by a TRAFFIC_FACTOR entry
# to estimate travel time when scoring hospitals.
BASE_SPEED_KMH = 60.0
# Per-traffic-level multiplier; used both to scale speed and as a direct
# score term in score_hospitals.
TRAFFIC_FACTOR = {"low": 1.0, "medium": 0.6, "high": 0.3}
# On-disk location of the cross-run learning archive (next to this script).
LEARNING_ARCHIVE_PATH = Path(__file__).resolve().parent / "data" / "learning_archive.json"
# Archives whose stored "version" differs from this have their profiles
# discarded on load (see load_learning_archive).
LEARNING_ARCHIVE_VERSION = 2
# Fallbacks used when API_BASE_URL / MODEL_NAME are not set in the environment.
DEFAULT_API_BASE_URL = "https://api-inference.huggingface.co/v1"
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
# Environment variables that have no default value and must be provided.
REQUIRED_ENV_VARS = ("HF_TOKEN",)
# Inclusive bounds applied by clamp_strict_score so results stay inside (0, 1).
STRICT_SCORE_MIN = 0.001
STRICT_SCORE_MAX = 0.999
def clamp_strict_score(value: float) -> float:
    """Return *value* as a float pinned to [STRICT_SCORE_MIN, STRICT_SCORE_MAX].

    Because both bounds lie strictly between 0 and 1, the result is always
    inside the open interval (0, 1).
    """
    capped = min(STRICT_SCORE_MAX, float(value))
    return max(STRICT_SCORE_MIN, capped)
def parse_args() -> argparse.Namespace:
    """Parse the runner's command-line options from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace`` (mode, task, level, seed,
        episodes, train_episodes, train_same_seed, memory_file).
    """
    script_dir = Path(__file__).resolve().parent
    cli = argparse.ArgumentParser(description="EmergencyEnv agent runner")

    # Scenario selection.
    cli.add_argument("--mode", choices=["single", "full"], default="full")
    cli.add_argument("--task", choices=TASK_ORDER, default=None)
    cli.add_argument("--level", choices=["low", "medium", "high"], default=None)
    cli.add_argument("--seed", type=int, default=None)

    # Episode counts and training behaviour.
    cli.add_argument("--episodes", type=int, default=1)
    cli.add_argument("--train-episodes", type=int, default=0)
    cli.add_argument("--train-same-seed", action="store_true")

    # Persistent memory file; defaults to data/learning_memory.json beside the script.
    cli.add_argument(
        "--memory-file",
        default=str(script_dir / "data" / "learning_memory.json"),
    )
    return cli.parse_args()
def emit_structured(tag: str, payload: dict) -> None:
    """Print one log line: ``[tag]`` followed by *payload* as compact ASCII JSON."""
    encoded = json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
    print("[{}] ".format(tag) + encoded)
def runtime_llm_config() -> dict[str, str]:
    """Read the LLM settings from the environment.

    Returns:
        A dict with the keys ``API_BASE_URL``, ``MODEL_NAME`` and ``HF_TOKEN``.
        Each value is whitespace-stripped; unset variables fall back to the
        module defaults (an empty string for ``HF_TOKEN``).
    """
    fallbacks = (
        ("API_BASE_URL", DEFAULT_API_BASE_URL),
        ("MODEL_NAME", DEFAULT_MODEL_NAME),
        ("HF_TOKEN", ""),
    )
    return {name: os.getenv(name, fallback).strip() for name, fallback in fallbacks}
def require_llm_config() -> tuple[OpenAIClient, str]:
    """Build the OpenAI-compatible client, or exit if it cannot be built.

    Returns:
        ``(client, model_name)`` where the client targets the configured base
        URL with an 8 second timeout.

    Raises:
        SystemExit: if any configuration value is empty, or the ``openai``
            package could not be imported.
    """
    settings = runtime_llm_config()
    unset = [key for key in settings if not settings[key]]
    if unset:
        listed = ", ".join(unset)
        raise SystemExit(
            f"Missing required environment variables: {listed}. Set HF_TOKEN before running inference.py"
        )
    if OpenAI is None:
        raise SystemExit("openai package is required for inference.py LLM rationale generation.")
    return (
        OpenAI(base_url=settings["API_BASE_URL"], api_key=settings["HF_TOKEN"], timeout=8.0),
        settings["MODEL_NAME"],
    )
def llm_rationale(
    client: Union[OpenAIClient, None],
    model_name: str,
    observation: dict,
    chosen: dict,
    strategy: str,
) -> str:
    """Produce a one-sentence rationale for the selected hospital.

    Args:
        client: OpenAI-compatible client, or ``None`` to skip the LLM call.
        model_name: Model identifier passed to the chat completion request.
        observation: Current observation dict (read with ``.get``).
        chosen: Scored hospital dict; must contain ``hospital_id``,
            ``policy_score``, ``traffic`` and ``icu`` (plus ``distance_km``
            when the LLM is queried).
        strategy: Name of the strategy that picked the hospital.

    Returns:
        The model's reply with whitespace collapsed and cut to 180 characters,
        or a locally formatted sentence when there is no client, the reply is
        empty, or anything in the LLM path raises.
    """
    canned = "Selected {} by {}; score={:.3f}, traffic={}, icu={}".format(
        chosen["hospital_id"],
        strategy,
        chosen["policy_score"],
        chosen["traffic"],
        chosen["icu"],
    )
    if client is None:
        return canned

    try:
        look = observation.get
        # Space-separated key=value facts appended after the instruction line.
        facts = " ".join(
            [
                f"task={look('task_id')}",
                f"difficulty={look('scenario_difficulty')}",
                f"step={look('step')}",
                f"patient={look('patient_condition')}",
                f"required={look('required_specialization')}",
                f"selected={chosen['hospital_id']}",
                f"score={chosen['policy_score']:.3f}",
                f"distance={chosen['distance_km']:.1f}km",
                f"traffic={chosen['traffic']}",
                f"icu={chosen['icu']}",
                f"strategy={strategy}",
            ]
        )
        request_text = (
            "You are an emergency routing agent. Return one short sentence rationale "
            "for the selected hospital. Keep it under 25 words.\n"
        ) + facts
        conversation = [
            {"role": "system", "content": "Generate concise emergency triage rationale."},
            {"role": "user", "content": request_text},
        ]
        response = client.chat.completions.create(
            model=model_name,
            messages=conversation,
            temperature=0.0,
            max_tokens=60,
        )
        reply = (response.choices[0].message.content or "").strip()
        if reply:
            # Collapse internal whitespace/newlines, then cap the length.
            return " ".join(reply.split())[:180]
        return canned
    except Exception:
        # Best effort: any failure in the LLM path degrades to the canned text.
        return canned
def normalize_seed(raw_value: int | str) -> int:
    """Turn an integer or arbitrary text into a deterministic positive seed.

    Integers are used directly. Other values are stringified and stripped;
    numeric text is parsed, and anything else is hashed as the sum of
    ``position * code point`` (1-based positions). The magnitude is reduced
    modulo 1_000_000_000, and a zero result is replaced by 202601.
    """
    if isinstance(raw_value, int):
        number = raw_value
    else:
        cleaned = str(raw_value).strip()
        try:
            number = int(cleaned)
        except ValueError:
            # Position-weighted character sum keeps non-numeric input deterministic.
            number = 0
            for position, char in enumerate(cleaned, start=1):
                number += position * ord(char)
    bounded = abs(number) % 1_000_000_000
    return bounded or 202601
def ask_seed_if_missing(seed: int | None) -> int:
    """Return a normalized seed, drawing a fresh one when none was supplied.

    A missing seed is replaced by an OS-entropy random integer in
    [1, 999_999_999], so runs without ``--seed`` are not reproducible.
    """
    if seed is None:
        seed = random.SystemRandom().randint(1, 999_999_999)
    return normalize_seed(seed)
def ask_level_if_missing(level: str | None) -> str:
    """Return *level* when it is a known level, otherwise a weighted random one.

    The random draw uses RANDOM_LEVELS with RANDOM_LEVEL_WEIGHTS, which
    excludes the easy level.
    """
    if level not in LEVEL_TO_TASK:
        drawn = random.choices(RANDOM_LEVELS, weights=RANDOM_LEVEL_WEIGHTS, k=1)
        return drawn[0]
    return level
def append_trajectory_log(entry: dict) -> None:
    """Append *entry* as one ASCII JSON line to data/trajectory_history.jsonl.

    The ``data`` directory beside this script is created when missing.
    """
    data_dir = Path(__file__).resolve().parent / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
    log_file = data_dir / "trajectory_history.jsonl"
    with log_file.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(entry, ensure_ascii=True) + "\n")
def load_learning_archive() -> dict:
    """Load the persisted learning archive, falling back to an empty one.

    The archive directory is created when missing. An empty archive is
    returned when the file does not exist, cannot be decoded as UTF-8, is not
    valid JSON, or does not contain a JSON object. When the stored version
    differs from LEARNING_ARCHIVE_VERSION the profiles are dropped and only
    the last 500 episodes are kept.

    Returns:
        A dict that always has ``"version"``, a dict under ``"profiles"`` and
        a list under ``"episodes"``.
    """

    def _blank(episodes: list | None = None) -> dict:
        # Fresh dict on every call so callers never share mutable state.
        return {
            "version": LEARNING_ARCHIVE_VERSION,
            "profiles": {},
            "episodes": [] if episodes is None else episodes,
        }

    LEARNING_ARCHIVE_PATH.parent.mkdir(parents=True, exist_ok=True)
    if not LEARNING_ARCHIVE_PATH.exists():
        return _blank()
    try:
        payload_text = LEARNING_ARCHIVE_PATH.read_text(encoding="utf-8-sig").strip()
        payload = json.loads(payload_text) if payload_text else {}
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses): a corrupt archive is treated like a missing one.
        return _blank()
    if not isinstance(payload, dict):
        return _blank()
    if payload.get("version") != LEARNING_ARCHIVE_VERSION:
        # Incompatible schema: keep only recent raw episodes, drop profiles.
        stale_episodes = payload.get("episodes", [])
        return _blank(stale_episodes[-500:] if isinstance(stale_episodes, list) else [])
    # Guard against missing or wrongly-typed sections (e.g. JSON null) so
    # downstream ``.get`` / list operations cannot fail on them.
    if not isinstance(payload.get("profiles"), dict):
        payload["profiles"] = {}
    if not isinstance(payload.get("episodes"), list):
        payload["episodes"] = []
    return payload
def save_learning_archive(archive: dict) -> None:
    """Write *archive* to LEARNING_ARCHIVE_PATH as indented ASCII JSON.

    The parent directory is created first when it does not exist.
    """
    destination = LEARNING_ARCHIVE_PATH
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(archive, indent=2, ensure_ascii=True)
    destination.write_text(serialized, encoding="utf-8")
def profile_key(seed: int, task_id: str) -> str:
    """Return the archive lookup key ``"<seed>|<task_id>"``."""
    return "{}|{}".format(seed, task_id)
def _merge_step_stats(primary: dict, secondary: dict) -> dict:
merged: dict = {}
for step_key in set(primary.keys()) | set(secondary.keys()):
merged[step_key] = {}
step_primary = primary.get(step_key, {})
step_secondary = secondary.get(step_key, {})
for hospital_id in set(step_primary.keys()) | set(step_secondary.keys()):
a = step_primary.get(hospital_id, {})
b = step_secondary.get(hospital_id, {})
count = int(a.get("count", 0)) + int(b.get("count", 0))
accepted = int(a.get("accepted", 0)) + int(b.get("accepted", 0))
partial = int(a.get("partial", 0)) + int(b.get("partial", 0))
rejected = int(a.get("rejected", 0)) + int(b.get("rejected", 0))
total_reward = float(a.get("total_reward", 0.0)) + float(b.get("total_reward", 0.0))
merged[step_key][hospital_id] = {
"count": count,
"success": int(a.get("success", 0)) + int(b.get("success", 0)),
"accepted": accepted,
"partial": partial,
"rejected": rejected,
"total_reward": total_reward,
"avg_reward": (total_reward / max(1, count)),
"success_rate": (accepted / max(1, count)),
"last_status": a.get("last_status") or b.get("last_status"),
"last_reason": a.get("last_reason") or b.get("last_reason"),
}
return merged
def build_learning_profile(
    archive: dict,
    seed: int,
    task_id: str,
    required_specialization: str | None = None,
) -> dict | None:
    """Return the stored learning profile for exactly this seed and task.

    Args:
        archive: Learning archive; its ``"profiles"`` mapping is searched.
        seed: Scenario seed.
        task_id: Task identifier.
        required_specialization: Accepted for interface compatibility; not
            read by this function.

    Returns:
        A normalized profile dict tagged ``"source": "exact-only"``, or
        ``None`` when no (non-empty) profile exists for the key.
    """
    stored = archive.get("profiles", {}).get(profile_key(seed, task_id))
    if not stored:
        return None
    # Strict scope: learn only from same seed + same level/task.
    read = stored.get
    return {
        "attempts": int(read("attempts", 0)),
        "best_score": float(read("best_score", 0.0)),
        "best_actions": list(read("best_actions", [])),
        "step_stats": read("step_stats", {}),
        "best_scenario_name": read("best_scenario_name"),
        "last_scenario_name": read("last_scenario_name"),
        "source": "exact-only",
    }
def _difficulty_policy_params(difficulty: str) -> tuple[float, float]:
if difficulty == "easy":
return 0.07, 0.18
if difficulty == "medium":
return 0.16, 0.32
return 0.26, 0.44
def _sample_softmax(candidates: list[dict], key: str, temperature: float, rng: random.Random) -> dict:
logits = [item[key] / max(temperature, 1e-6) for item in candidates]
max_logit = max(logits)
exps = [math.exp(v - max_logit) for v in logits]
total = sum(exps)
probs = [e / total for e in exps]
roll = rng.random()
cdf = 0.0
for item, prob in zip(candidates, probs):
cdf += prob
if roll <= cdf:
return item
return candidates[-1]
def memory_score_for_hospital(
    hospital_id: str,
    memory_snapshot: dict,
    learning_profile: dict | None = None,
    step_number: int | None = None,
) -> float:
    """Score a hospital in [0, 1] from remembered outcomes.

    Args:
        hospital_id: Hospital to score.
        memory_snapshot: ``{hospital_id: entry}``; an entry supplies
            ``accepted``/``success``, ``rejected``/``fail`` and ``avg``.
        learning_profile: Optional profile whose ``step_stats`` add a
            per-step bonus (capped at 0.20) and a 0.3 penalty when the last
            recorded status at this step was a rejection.
        step_number: Step whose statistics should be consulted.

    Returns:
        0.5 when nothing is known about the hospital; otherwise
        ``0.6 * success_rate + 0.4 * avg`` plus the step adjustments, clamped
        to [0, 1].
    """
    record = memory_snapshot.get(hospital_id)
    if not record:
        return 0.5

    wins = int(record.get("accepted", record.get("success", 0)))
    losses = int(record.get("rejected", record.get("fail", 0)))
    mean_reward = float(record.get("avg", 0.0))
    outcomes = wins + losses
    if outcomes <= 0:
        return 0.5

    # Fix 3: reliability-first memory scoring.
    value = (0.6 * (wins / outcomes)) + (0.4 * mean_reward)

    step_entry = None
    if learning_profile and step_number is not None:
        per_step = learning_profile.get("step_stats", {}).get(str(step_number), {})
        step_entry = per_step.get(hospital_id)

    if step_entry:
        step_avg = float(step_entry.get("avg_reward", 0.0))
        step_success = float(step_entry.get("success_rate", 0.0))
        step_count = int(step_entry.get("count", 0))
        value += min(0.20, (step_avg * 0.10) + (step_success * 0.08) + min(step_count, 5) * 0.01)
        if str(step_entry.get("last_status", "")).upper() == "REJECTED":
            value -= 0.3

    return max(0.0, min(1.0, value))
def score_hospitals(observation: dict, learning_profile: dict | None = None) -> list[dict]:
    """Score every hospital in the observation and return them best-first.

    Each hospital gets a raw heuristic score (ICU, distance, traffic, memory,
    specialization, urgency, learned-route bonus, minus penalties). The scores
    are then min-max rescaled, turned into probability-like shares, optionally
    re-weighted for critical patients, squashed with ``x / (1 + |x|)`` and
    floored so no option ends at zero.

    Args:
        observation: Environment observation dict. ``critical_time_limit_minutes``
            and ``required_specialization`` are read with ``[]`` and must be
            present; the remaining fields are read with ``.get``.
        learning_profile: Optional profile from the learning archive
            (``attempts``, ``best_actions``, ``step_stats``, scenario names).

    Returns:
        A list of dicts (``hospital_id``, ``icu``, ``distance_km``, ``traffic``,
        ``specialization``, ``travel_time``, ``memory_score``, ``policy_score``,
        ``specialization_match``) sorted by ``policy_score`` descending.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation had been lost; confirm the extent of the ``if scored:`` and
    clinical-priority blocks against the repository copy.
    """
    failed = set(observation.get("failed_hospitals", []))
    recent_failed = set(observation.get("recent_failed_hospitals", []))
    visited = set(observation.get("visited_hospitals", []))
    memory_snapshot = observation.get("memory_snapshot", {})
    previous_action = observation.get("previous_action")
    last_arrival = observation.get("last_arrival_outcome") or {}
    last_status = str(last_arrival.get("status", "")).lower()
    scored: list[dict] = []
    initial_limit = float(observation.get("initial_critical_time_limit_minutes", observation["critical_time_limit_minutes"]))
    remaining_time = float(observation.get("remaining_time_minutes", observation["critical_time_limit_minutes"]))
    # Urgency grows from 0 to 1 as the remaining time shrinks relative to the initial limit.
    urgency = 1.0 - min(1.0, max(0.0, remaining_time / max(initial_limit, 1e-6)))
    patient_condition = observation.get("patient_condition", "").lower()
    # "unstable" is treated like "critical" for the penalties keyed on critical_patient;
    # the multipliers keyed on patient_condition below apply to "critical" only.
    critical_patient = patient_condition in {"critical", "unstable"}
    required_specialization = str(observation.get("required_specialization", ""))
    scenario_name = str(observation.get("scenario_name", ""))
    step_number = int(observation.get("step", 1))
    difficulty = str(observation.get("scenario_difficulty", "medium"))
    attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0
    preferred_route = []
    if learning_profile:
        preferred_route = list(learning_profile.get("best_actions", []))
    for hospital in observation.get("hospitals", []):
        # Raises KeyError for a traffic level that is not in TRAFFIC_FACTOR.
        traffic_factor = TRAFFIC_FACTOR[hospital["traffic"]]
        speed_kmh = BASE_SPEED_KMH * traffic_factor
        # Travel time in minutes.
        travel_time = (hospital["distance_km"] / max(speed_kmh, 1e-6)) * 60.0
        # Linear falloff: 0 km -> 1.0, 20 km or more -> 0.0.
        distance_score = max(0.0, min(1.0, 1.0 - hospital["distance_km"] / 20.0))
        icu_score = 1.0 if hospital["icu"] == "available" else 0.55
        mem_score = memory_score_for_hospital(
            hospital["hospital_id"],
            memory_snapshot,
            learning_profile=learning_profile,
            step_number=step_number,
        )
        memory_scenario = ""
        if learning_profile:
            memory_scenario = str(
                learning_profile.get("best_scenario_name")
                or learning_profile.get("last_scenario_name")
                or ""
            )
        # Halve the memory signal when it was learned on a differently named scenario.
        if memory_scenario and scenario_name and memory_scenario != scenario_name:
            mem_score *= 0.5
        # Loose match: exact specialization, or either side being "general".
        spec_match = (
            hospital["specialization"] == observation["required_specialization"]
            or hospital["specialization"] == "general"
            or observation["required_specialization"] == "general"
        )
        exact_spec_match = hospital["specialization"] == observation["required_specialization"]
        general_fallback = (
            hospital["specialization"] == "general"
            and observation["required_specialization"] != "general"
        )
        rejected_penalty = 0.40 if hospital["hospital_id"] in failed else 0.0
        revisit_penalty = 0.14 if hospital["hospital_id"] in visited else 0.0
        partial_repeat_penalty = (
            0.32
            if last_status == "partial" and hospital["hospital_id"] == previous_action
            else 0.0
        )
        # Note the else-branch: every other hospital still pays a 0.03 penalty.
        critical_unknown_penalty = 0.17 if critical_patient and hospital["icu"] == "unknown" else 0.03
        traffic_penalty = 0.10 if hospital["traffic"] == "high" else 0.04 if hospital["traffic"] == "medium" else 0.0
        if critical_patient and general_fallback:
            spec_penalty = {"easy": 0.08, "medium": 0.16, "hard": 0.26}.get(difficulty, 0.16)
            if attempts >= 5:
                spec_penalty += 0.06
        else:
            spec_penalty = 0.0
        spec_bonus = 0.16 if exact_spec_match else (0.08 if spec_match else 0.0)
        urgency_boost = urgency * (0.18 + max(0.0, 0.25 - travel_time / 100.0))
        # Bonus when this hospital is the learned best action for the current step.
        step_route_bonus = 0.0
        if step_number - 1 < len(preferred_route) and preferred_route[step_number - 1] == hospital["hospital_id"]:
            step_route_bonus = 0.16
        score = (
            (icu_score * 0.30)
            + (distance_score * 0.18)
            + (traffic_factor * 0.14)
            + (mem_score * 0.24)
            + spec_bonus
            + urgency_boost
            + step_route_bonus
            - rejected_penalty
            - revisit_penalty
            - partial_repeat_penalty
            - spec_penalty
            - critical_unknown_penalty
            - traffic_penalty
        )
        # Multiplicative damping for the hospital that just rejected us and for recent failures.
        if hospital["hospital_id"] == previous_action and last_status == "rejected":
            score *= 0.01
        if hospital["hospital_id"] in recent_failed:
            score *= 0.2
        if hospital["specialization"] != required_specialization:
            if patient_condition == "critical":
                score *= 0.15
            else:
                score *= 0.4
        elif patient_condition == "critical":
            score *= 1.5
        # Hard realism penalties to align policy scoring with validator outcomes.
        # These are subtracted in addition to the multiplicative factors above.
        if hospital["specialization"] != required_specialization:
            score -= 0.6
        if critical_patient and hospital["icu"] == "unknown":
            score -= 0.5
        if critical_patient and hospital["traffic"] == "high":
            score -= 0.3
        # Confidence-style risk multiplier keeps risky options from looking deceptively good.
        # When the score is already negative, multiplying by a factor < 1 moves it toward zero.
        risk_factor = 1.0
        if hospital["icu"] == "unknown":
            risk_factor *= 0.6
        if not spec_match:
            risk_factor *= 0.5
        if critical_patient and hospital["traffic"] == "high":
            risk_factor *= 0.7
        score *= risk_factor
        # Reduce memory dominance in final decision score.
        memory_weight = 0.15
        current_score_weight = 0.85
        if step_number == 1:
            memory_weight = 0.1
            current_score_weight = 0.9
        base_current_score = score
        confidence_score = max(0.0, min(1.0, base_current_score))
        effective_memory_score = mem_score
        in_best_route = hospital["hospital_id"] in preferred_route
        # Memory is ignored when the current evidence is weak (stricter bar for learned-route hospitals).
        if in_best_route and confidence_score < 0.6:
            effective_memory_score = 0.0
        if confidence_score < 0.2:
            effective_memory_score = 0.0
        score = (current_score_weight * base_current_score) + (memory_weight * effective_memory_score)
        scored.append(
            {
                "hospital_id": hospital["hospital_id"],
                "icu": hospital["icu"],
                "distance_km": hospital["distance_km"],
                "traffic": hospital["traffic"],
                "specialization": hospital["specialization"],
                "travel_time": travel_time,
                "memory_score": mem_score,
                "policy_score": max(0.0, min(1.0, score)),
                "specialization_match": spec_match,
                # Temporary field; only used when all policy scores collapse to zero.
                "tie_break_score": (
                    (distance_score * 0.35)
                    + (traffic_factor * 0.35)
                    + (icu_score * 0.20)
                    + (0.10 if spec_match else 0.0)
                ),
            }
        )
    scored.sort(key=lambda item: item["policy_score"], reverse=True)
    if scored:
        min_score = min(item["policy_score"] for item in scored)
        max_score = max(item["policy_score"] for item in scored)
        spread = max_score - min_score
        if spread > 1e-9:
            # Min-max rescale; entries in the bottom 20% are shrunk further by a
            # jitter that is deterministic in (seed, step, hospital id).
            for item in scored:
                normalized = (item["policy_score"] - min_score) / (spread + 1e-6)
                if normalized < 0.2:
                    jitter_seed = (
                        int(observation.get("seed", 0))
                        + (step_number * 131)
                        + sum(ord(ch) for ch in item["hospital_id"])
                    )
                    jitter_rng = random.Random(jitter_seed)
                    normalized *= jitter_rng.uniform(0.3, 0.7)
                item["policy_score"] = max(0.0, min(1.0, normalized))
        elif max_score > 0:
            # All scores equal and positive: dividing by the max makes each 1.0.
            for item in scored:
                normalized = item["policy_score"] / max_score
                if normalized < 0.2:
                    jitter_seed = (
                        int(observation.get("seed", 0))
                        + (step_number * 131)
                        + sum(ord(ch) for ch in item["hospital_id"])
                    )
                    jitter_rng = random.Random(jitter_seed)
                    normalized *= jitter_rng.uniform(0.3, 0.7)
                item["policy_score"] = max(0.0, min(1.0, normalized))
        else:
            # Every policy score is zero: rank by the tie-break score instead.
            tie_min = min(item.get("tie_break_score", 0.0) for item in scored)
            tie_max = max(item.get("tie_break_score", 0.0) for item in scored)
            tie_spread = tie_max - tie_min
            if tie_spread > 1e-9:
                for item in scored:
                    normalized = (item.get("tie_break_score", 0.0) - tie_min) / (tie_spread + 1e-6)
                    if normalized < 0.2:
                        jitter_seed = (
                            int(observation.get("seed", 0))
                            + (step_number * 131)
                            + sum(ord(ch) for ch in item["hospital_id"])
                        )
                        jitter_rng = random.Random(jitter_seed)
                        normalized *= jitter_rng.uniform(0.3, 0.7)
                    item["policy_score"] = max(0.0, min(1.0, normalized))
            else:
                for item in scored:
                    item["policy_score"] = 0.0
        # Remove hard-zero scores and normalize to probability-like values.
        for item in scored:
            if item["policy_score"] <= 0.0:
                jitter_seed = (
                    int(observation.get("seed", 0))
                    + (step_number * 173)
                    + sum(ord(ch) for ch in item["hospital_id"])
                )
                jitter_rng = random.Random(jitter_seed)
                if critical_patient and required_specialization != "general":
                    # Matching specialists get a noticeably higher floor than mismatches.
                    if item.get("specialization") == required_specialization:
                        item["policy_score"] = jitter_rng.uniform(0.08, 0.18)
                    else:
                        item["policy_score"] = jitter_rng.uniform(0.001, 0.01)
                else:
                    item["policy_score"] = jitter_rng.uniform(0.05, 0.15)
        total_score = sum(item["policy_score"] for item in scored)
        if total_score > 0:
            for item in scored:
                item["policy_score"] = item["policy_score"] / (total_score + 1e-6)
        else:
            uniform = 1.0 / len(scored)
            for item in scored:
                item["policy_score"] = uniform
    # Final clinical-priority pass: in critical non-general cases,
    # exact specialization should dominate unless unavailable.
    if critical_patient and required_specialization != "general":
        for item in scored:
            if item.get("specialization") == required_specialization:
                item["policy_score"] *= 1.5
            else:
                item["policy_score"] *= 0.15
        boosted_total = sum(item["policy_score"] for item in scored)
        if boosted_total > 0:
            for item in scored:
                item["policy_score"] = item["policy_score"] / boosted_total
    for item in scored:
        raw_score = float(item["policy_score"])
        # Squash: maps a non-negative x into [0, 1) (e.g. 1.0 -> 0.5).
        normalized_score = raw_score / (1.0 + abs(raw_score))
        # Keep a small floor so no action is fully eliminated from exploration.
        if normalized_score < 0.01:
            jitter_seed = (
                int(observation.get("seed", 0))
                + (step_number * 211)
                + sum(ord(ch) for ch in item["hospital_id"])
            )
            jitter_rng = random.Random(jitter_seed)
            normalized_score = jitter_rng.uniform(0.01, 0.03)
        item["policy_score"] = normalized_score
    scored.sort(key=lambda item: item["policy_score"], reverse=True)
    # Drop the temporary tie-break field before handing results to callers.
    for item in scored:
        item.pop("tie_break_score", None)
    return scored
def choose_hospital(
    scored: list[dict],
    observation: dict,
    rng: random.Random,
    learning_profile: dict | None = None,
) -> tuple[dict, str]:
    """Pick the next hospital from the scored options.

    The candidate pool is narrowed in stages (cooldown on failed/visited
    hospitals, no immediate retry, exact specialization for critical
    patients, filters derived from the last failure reason, learned fail
    guard). The choice then comes from the learned best route / combination
    search, or from guided exploration / softmax sampling. Every returned
    choice passes through ``enforce_score_guard``.

    Args:
        scored: Output of ``score_hospitals`` (dicts with ``hospital_id``,
            ``policy_score``, ``icu``, ``traffic``, ``specialization``,
            ``distance_km``).
        observation: Current environment observation dict.
        rng: Random source for exploration and sampling.
        learning_profile: Optional archive profile for this seed/task.

    Returns:
        ``(chosen_hospital_dict, strategy_label)``.

    Raises:
        IndexError: from the final ``candidates[0]`` when ``scored`` is empty
            and no earlier branch returned.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation had been lost; in particular confirm where the failure-reason
    filters, the ``attempts >= 3`` / ``attempts >= 6`` checks and the
    post-partial epsilon/temperature caps sit in the repository copy.
    NOTE(review): ``epsilon`` is adjusted in several places but no branch in
    this function reads it (exploration uses a fixed 0.15); ``failed`` and
    ``pick_improvement_candidate`` are likewise unused here.
    """
    difficulty = observation.get("scenario_difficulty", "medium")
    epsilon, temperature = _difficulty_policy_params(difficulty)
    failed = set(observation.get("failed_hospitals", []))
    recent_failed = set(observation.get("recent_failed_hospitals", []))
    visited = set(observation.get("visited_hospitals", []))
    previous_action = observation.get("previous_action")
    selected_hospital_id = observation.get("selected_hospital_id")
    visited_sequence = observation.get("visited_hospitals", []) or []
    # Most recent hospital we dealt with: first truthy of previous action,
    # current selection, last visited.
    recent_hospital = previous_action or selected_hospital_id or (visited_sequence[-1] if visited_sequence else None)
    last_arrival = observation.get("last_arrival_outcome") or {}
    last_status = str(last_arrival.get("status", "")).lower()
    last_reason = str(last_arrival.get("reason", "")).lower()
    is_rerouting_phase = str(observation.get("ambulance_status", "")).lower() == "rerouting"
    # Cooldown logic: avoid recently failed hospitals first, then avoid visited when alternatives exist.
    candidates = [
        item
        for item in scored
        if item["hospital_id"] not in recent_failed and item["hospital_id"] not in visited
    ]
    if not candidates:
        candidates = [item for item in scored if item["hospital_id"] not in recent_failed]
    if not candidates:
        # Last-resort fallback: if every hospital has failed already, avoid immediate retry.
        candidates = list(scored)
        if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
            redirected = [item for item in candidates if item["hospital_id"] != recent_hospital]
            if redirected:
                candidates = redirected
    step_number = int(observation.get("step", 1))
    attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0
    required_specialization = str(observation.get("required_specialization", ""))
    critical_patient = observation.get("patient_condition", "").lower() in {"critical", "unstable"}
    # Hard realism rule: never immediately retry the hospital that just rejected the patient.
    if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
        immediate_retry_block = [item for item in candidates if item["hospital_id"] != recent_hospital]
        if immediate_retry_block:
            candidates = immediate_retry_block
        elif len(candidates) == 1 and candidates[0]["hospital_id"] == recent_hospital:
            # Only the just-failed hospital is left: widen to anything else in scored.
            fallback_any = [item for item in scored if item["hospital_id"] != recent_hospital]
            if fallback_any:
                candidates = fallback_any
    # In critical non-general cases, prioritize exact specialization when available.
    if critical_patient and required_specialization != "general":
        exact_spec_candidates = [
            item for item in candidates if item["specialization"] == required_specialization
        ]
        if exact_spec_candidates:
            candidates = exact_spec_candidates
    # Policy mode: first step is "safe", the step after a rejection is
    # "risk-aware", everything else is "balanced".
    if step_number == 1:
        policy_mode = "safe"
    elif last_status == "rejected":
        policy_mode = "risk-aware"
    else:
        policy_mode = "balanced"
    safe_weight = 1.0
    if policy_mode == "safe":
        safe_weight *= 0.8
        epsilon *= 0.6
        temperature *= 0.8
    elif policy_mode == "risk-aware":
        epsilon *= 1.1
        temperature *= 0.9
    # Within-episode learning from concrete failure reasons.
    # Each filter is applied only when it leaves at least one candidate.
    if "wrong hospital specialization" in last_reason:
        strict_spec = [
            item
            for item in candidates
            if item["specialization"] == observation.get("required_specialization")
        ]
        if strict_spec:
            candidates = strict_spec
    if "icu unavailable" in last_reason:
        icu_known = [item for item in candidates if item["icu"] == "available"]
        if icu_known:
            candidates = icu_known
    if "specialist" in last_reason:
        strict_spec = [
            item
            for item in candidates
            if item["specialization"] == observation.get("required_specialization")
        ]
        if strict_spec:
            candidates = strict_spec
    if "overloaded" in last_reason:
        non_high_traffic = [item for item in candidates if item["traffic"] != "high"]
        if non_high_traffic:
            candidates = non_high_traffic
    if "delay" in last_reason:
        # Reorders (nearest first) rather than filters.
        candidates = sorted(candidates, key=lambda item: item["distance_km"])

    def learned_utility(item: dict) -> float:
        """Policy score adjusted by per-step history and an exploration bonus."""
        base = float(item.get("policy_score", 0.0))
        if not learning_profile:
            return base
        step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {})
        stats = step_stats.get(item["hospital_id"], {})
        count = int(stats.get("count", 0))
        if count <= 0:
            # Never tried at this step: flat bonus that grows slowly with attempts.
            exploration_bonus = 0.22 * math.sqrt(max(1.0, math.log(attempts + 2.0)))
            return base + exploration_bonus
        avg_reward = float(stats.get("avg_reward", 0.0))
        success_rate = float(stats.get("success_rate", 0.0))
        rejected = int(stats.get("rejected", 0))
        rejection_rate = rejected / max(1, count)
        # Bonus shrinks as the hospital accumulates tries at this step.
        exploration_bonus = 0.18 * math.sqrt(max(0.0, math.log(attempts + 2.0) / (count + 1.0)))
        # Real-data utility: reward trend + success rate - rejection risk + exploration bonus.
        historical_weight = 0.35
        historical_weight *= 0.6
        historical_bonus = (avg_reward * historical_weight) + (success_rate * 0.30) - (rejection_rate * 0.22)
        if item["hospital_id"] in recent_failed:
            historical_bonus = 0.0
        return base + historical_bonus + exploration_bonus

    def pick_improvement_candidate(route_choice_id: str | None) -> dict | None:
        """Best candidate by learned utility that differs from *route_choice_id*.

        Reads ``candidates`` at call time. Falls back to the top-ranked
        candidate when every candidate matches *route_choice_id*.
        """
        if not candidates:
            return None
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        if route_choice_id is None:
            return ranked[0]
        for item in ranked:
            if item["hospital_id"] != route_choice_id:
                return item
        return ranked[0]

    def enforce_score_guard(chosen: dict, strategy: str) -> tuple[dict, str]:
        """Final safety net: may replace *chosen* and extend the strategy label."""
        # Absolute next-step guard: never pick the same hospital immediately after a rejection.
        if last_status == "rejected" and previous_action and chosen.get("hospital_id") == previous_action:
            alternatives = [item for item in scored if item["hospital_id"] != previous_action]
            if alternatives:
                rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                return rerouted, strategy + " + immediate-retry block"
        # Global guardrail: when a score gap is very large, prefer best option most
        # of the time while preserving some exploration.
        globally_eligible = [
            item
            for item in scored
            if item["hospital_id"] not in recent_failed
            and not (
                (last_status == "rejected" or is_rerouting_phase)
                and recent_hospital
                and item["hospital_id"] == recent_hospital
            )
        ]
        if not globally_eligible:
            globally_eligible = list(scored)
        if globally_eligible:
            best_global = max(globally_eligible, key=lambda item: float(item.get("policy_score", 0.0)))
            chosen_score = float(chosen.get("policy_score", 0.0))
            best_global_score = float(best_global.get("policy_score", 0.0))
            # Cooldown hard guard: never immediately retry the just-failed hospital.
            if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
                if chosen.get("hospital_id") == recent_hospital:
                    alternatives = [
                        item
                        for item in scored
                        if item["hospital_id"] != recent_hospital and item["hospital_id"] not in recent_failed
                    ]
                    if not alternatives:
                        alternatives = [item for item in scored if item["hospital_id"] != recent_hospital]
                    if alternatives:
                        rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                        return rerouted, strategy + " + cooldown reroute"
            # Override choices scoring under 60% of the best eligible option.
            if chosen_score < (best_global_score * 0.6):
                return best_global, strategy + " + anti-stupidity guard"
            # Large absolute gap: switch to the best option 75% of the time.
            if (best_global_score - chosen_score) > 0.25 and rng.random() < 0.75:
                return best_global, strategy + " + score-gap guard"
        return chosen, strategy

    # Learning-driven fail guard: avoid hospitals that repeatedly fail at this exact step.
    if learning_profile:
        step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {})
        guard_blocked: set[str] = set()
        for hospital_id, stats in step_stats.items():
            count = int(stats.get("count", 0))
            success_rate = float(stats.get("success_rate", 0.0))
            rejected = int(stats.get("rejected", 0))
            if count >= 2 and success_rate <= 0.0 and rejected >= 2:
                guard_blocked.add(hospital_id)
        guarded_candidates = [item for item in candidates if item["hospital_id"] not in guard_blocked]
        if guarded_candidates:
            candidates = guarded_candidates
    # As attempts increase, reduce randomness and rely on learned utility.
    if attempts >= 3:
        epsilon *= 0.35
        temperature *= 0.70
    # Same seed + same task policy:
    # evaluate route combinations across all steps, not just one-step mutations.
    if learning_profile and policy_mode != "risk-aware":
        best_route = list(learning_profile.get("best_actions", []))
        if step_number - 1 < len(best_route):
            baseline_id = best_route[step_number - 1]
            ranked = sorted(candidates, key=learned_utility, reverse=True)
            baseline_candidate = next((item for item in ranked if item["hospital_id"] == baseline_id), None)
            alternatives = [item for item in ranked if item["hospital_id"] != baseline_id]
            top_candidate = ranked[0] if ranked else None
            # On step 1, drop the baseline when the top-utility candidate has a higher policy score.
            if (
                step_number == 1
                and baseline_candidate is not None
                and top_candidate is not None
                and float(baseline_candidate.get("policy_score", 0.0)) < float(top_candidate.get("policy_score", 0.0))
            ):
                baseline_candidate = None
            alternatives = alternatives[: min(3, len(alternatives))]
            if attempts >= 1:
                # Mixed-radix route search: each run selects a step-wise digit.
                # digit 0 => keep baseline for this step, 1/2 => try alternative ranks.
                combo_index = max(0, attempts - 1)
                digit = (combo_index // (3 ** max(0, step_number - 1))) % 3
                if digit == 0 and baseline_candidate is not None:
                    return enforce_score_guard(baseline_candidate, "best-route retain")
                alt_rank = digit - 1
                if alt_rank >= 0 and alt_rank < len(alternatives):
                    return enforce_score_guard(alternatives[alt_rank], f"combination search step-{step_number} alt-{alt_rank + 1}")
                # Requested alternative rank does not exist: keep the baseline if there is one.
                if baseline_candidate is not None:
                    return enforce_score_guard(baseline_candidate, "best-route retain")
    if attempts >= 6:
        # Sample among the top three by learned utility, weighted by policy score.
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        top_pool = ranked[: min(3, len(ranked))]
        return enforce_score_guard(_sample_softmax(top_pool, "policy_score", max(0.08, temperature * 0.85), rng), "learned utility exploit")
    if learning_profile and policy_mode == "safe":
        preferred_route = list(learning_profile.get("best_actions", []))
        if step_number - 1 < len(preferred_route):
            preferred_hospital = preferred_route[step_number - 1]
            preferred_candidate = next((item for item in candidates if item["hospital_id"] == preferred_hospital), None)
            if preferred_candidate is not None:
                profile_score = float(learning_profile.get("best_score", 0.0))
                # safe_weight is 0.8 in safe mode, so the test is best_score * 0.8 >= 0.85.
                if (profile_score * safe_weight) >= 0.85 or len(candidates) == 1:
                    return enforce_score_guard(preferred_candidate, "learned best path")
    # If last outcome was partial, force trying a different hospital when possible.
    if last_status == "partial" and previous_action:
        redirected = [item for item in candidates if item["hospital_id"] != previous_action]
        if redirected:
            candidates = redirected
        # After partial treatment, reduce random exploration and favor safer follow-up routing.
        epsilon = min(epsilon, 0.04)
        temperature = min(temperature, 0.24)
    critical = observation.get("patient_condition", "").lower() in {"critical", "unstable"}
    strategy = f"{policy_mode} policy"
    # Critical triage: restrict to hospitals with a confirmed ICU when any exist.
    if critical and policy_mode in {"safe", "balanced"}:
        confirmed = [item for item in candidates if item["icu"] == "available"]
        if confirmed:
            candidates = confirmed
            strategy = f"{policy_mode} policy + critical triage"
    # Guided exploration: 15% of the time pick uniformly among the top three by learned utility.
    if len(candidates) > 1 and rng.random() < 0.15:
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        top_k = ranked[: min(3, len(ranked))]
        return enforce_score_guard(rng.choice(top_k), strategy + " + guided-exploration")
    if len(candidates) > 1:
        # Utility-aware candidate ordering for softmax sampling.
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        chosen = _sample_softmax(ranked, "policy_score", temperature, rng)
        return enforce_score_guard(chosen, strategy)
    return enforce_score_guard(candidates[0], strategy)
def print_options(scored: list[dict]) -> None:
    """Print a header line followed by one numbered line per scored hospital."""
    row_layout = (
        " [{}] {} | {:.1f} km | ICU {} | "
        "traffic {} | specialty {} | score {:.3f}"
    )
    print("Hospital options ({} total):".format(len(scored)))
    rank = 0
    for option in scored:
        rank += 1
        print(
            row_layout.format(
                rank,
                option["hospital_id"],
                option["distance_km"],
                option["icu"],
                option["traffic"],
                option["specialization"],
                option["policy_score"],
            )
        )
def run_episode(
    env: EmergencyEnv,
    task_id: str,
    seed: int,
    archive: dict | None = None,
    llm_client: object | None = None,
    model_name: str | None = None,
) -> dict:
    """Run a single episode of ``task_id`` and return a summary of the run.

    The environment is reset with ``seed`` and ``task_id``. Then, until the
    environment reports ``done``, the hospitals in the current observation
    are scored, one is chosen, and the choice is submitted as an ``Action``.
    Each step is printed, written to the trajectory log and recorded for the
    learning archive.

    Args:
        env: Environment that is reset and stepped.
        task_id: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset``; it also seeds the policy RNG.
        archive: Learning archive. When given, a learning profile is built
            from it and handed to hospital scoring and selection.
        llm_client: Client forwarded to ``llm_rationale``.
        model_name: Model name forwarded to ``llm_rationale`` and shown in
            the ``[START]`` line (``none`` when unset).

    Returns:
        A dict with the keys ``success``, ``score``, ``steps``, ``seed``,
        ``task_id``, ``scenario_name``, ``scenario_type``, ``difficulty``,
        ``required_specialization``, ``actions``, ``step_records`` and
        ``timestamp`` (UTC, ISO 8601).
    """
    observation_model = env.reset(seed=seed, task_id=task_id)
    observation = observation_model.model_dump()

    learning_profile = None
    if archive is not None:
        learning_profile = build_learning_profile(
            archive,
            seed,
            task_id,
            required_specialization=str(observation.get("required_specialization", "")) or None,
        )

    print("\n" + "=" * 72)
    print(f"Scenario: {observation['scenario_name']}")
    print(f"Task: {task_id} | Difficulty: {observation['scenario_difficulty']} | Seed: {seed}")
    print(f"Patient condition: {observation['patient_condition']}")
    print(f"Required specialization: {observation['required_specialization']}")
    print("Objective: admit patient successfully (no fixed deadline window)")
    print("=" * 72)
    print(f"[START] task={task_id} env=acde-openenv model={model_name or 'none'}", flush=True)

    if learning_profile:
        print(
            f"Learning memory: best historical score {float(learning_profile.get('best_score', 0.0)):.3f} "
            f"across {int(learning_profile.get('attempts', 0))} attempts"
        )
        if learning_profile.get("best_actions"):
            print(f"Best known route: {' -> '.join(learning_profile['best_actions'])}")

    total_reward = 0.0
    all_rewards: list[float] = []
    steps = 0
    done = False
    previous_policy_hospital_id: str | None = None
    previous_policy_outcome: str | None = None
    attempt_index = int(learning_profile.get("attempts", 0)) if learning_profile else 0
    # Keep scenario deterministic by seed, but vary policy exploration across retries.
    rng = random.Random(seed + (attempt_index * 7919))
    step_records: list[dict] = []

    while not done:
        steps += 1
        print(f"\nStep {observation['step']} | phase={observation['ambulance_status']}")
        scored = score_hospitals(observation, learning_profile=learning_profile)
        chosen, strategy = choose_hospital(scored, observation, rng, learning_profile=learning_profile)

        # Final policy-level guard: no immediate retry of the same hospital after rejection.
        if (
            previous_policy_outcome == "REJECTED"
            and previous_policy_hospital_id
            and chosen["hospital_id"] == previous_policy_hospital_id
        ):
            alternatives = [item for item in scored if item["hospital_id"] != previous_policy_hospital_id]
            if alternatives:
                chosen = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                strategy = strategy + " + immediate-retry override"

        print_options(scored)
        rationale = llm_rationale(
            cast(Optional[OpenAIClient], llm_client),
            model_name or "",
            observation,
            chosen,
            strategy,
        )
        print(f"Decision: {chosen['hospital_id']} ({strategy})")

        step_result = env.step(
            Action(
                step=observation["step"],
                hospital_id=chosen["hospital_id"],
                rationale=rationale,
            )
        )
        next_obs_model = step_result["observation"]
        reward = float(step_result["reward"])
        all_rewards.append(reward)
        done = bool(step_result["done"])
        info = step_result.get("info", {}) or {}
        next_observation = next_obs_model.model_dump()
        total_reward += reward

        # Treat an explicit None outcome like a missing one, as is done for `info`.
        outcome = info.get("outcome") or {}
        status = str(outcome.get("status", "partial")).upper()
        reason = str(outcome.get("reason", "No reason provided"))
        previous_policy_hospital_id = chosen["hospital_id"]
        previous_policy_outcome = status

        print(f"Outcome: {status}")
        print(f"Reason: {reason}")
        print(f"Reward: {reward:.3f}")
        error_val = str(info.get("last_action_error")) if info.get("last_action_error") else "null"
        print(
            f"[STEP] step={observation.get('step')} action={chosen['hospital_id']} "
            f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
            flush=True,
        )

        # The logged state is the observation the decision was based on,
        # not the one returned by this step.
        append_trajectory_log(
            {
                "seed": seed,
                "task": task_id,
                "difficulty": observation.get("scenario_difficulty"),
                "step": observation.get("step"),
                "state": {
                    "patient_condition": observation.get("patient_condition"),
                    "remaining_time_minutes": observation.get("remaining_time_minutes"),
                    "failed_hospitals": observation.get("failed_hospitals", []),
                    "visited_hospitals": observation.get("visited_hospitals", []),
                    "ambulance_status": observation.get("ambulance_status"),
                },
                "action": {
                    "hospital_id": chosen["hospital_id"],
                    "policy_score": chosen["policy_score"],
                    "strategy": strategy,
                },
                "outcome": {
                    "status": status,
                    "reason": reason,
                },
                "reward": reward,
            }
        )
        step_records.append(
            {
                "step": observation.get("step"),
                "hospital_id": chosen["hospital_id"],
                "status": status,
                "reason": reason,
                "reward": reward,
                "policy_score": chosen["policy_score"],
            }
        )
        observation = next_observation

    final_state = env.state()
    final_result = final_state.final_outcome or "FAILURE"
    final_score = clamp_strict_score(final_state.final_score)
    succeeded = final_result == "SUCCESS"

    print("\nFinal result:")
    print(f" Result: {final_result}")
    print(f" Total steps: {steps}")
    print(f" Final score: {final_score:.3f}")
    print(f" Average reward: {total_reward / max(1, steps):.3f}")
    rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
    # Three decimals match the resolution of the clamp bounds. With two, a
    # clamped 0.999 would print as 1.00 and 0.001 as 0.00, i.e. outside the
    # strict (0, 1) range that clamp_strict_score guarantees.
    print(
        f"[END] success={str(succeeded).lower()} steps={steps} "
        f"score={final_score:.3f} rewards={rewards_str}",
        flush=True,
    )

    return {
        "success": succeeded,
        "score": final_score,
        "steps": steps,
        "seed": seed,
        "task_id": task_id,
        "scenario_name": observation.get("scenario_name"),
        "scenario_type": observation.get("scenario_type"),
        "difficulty": observation.get("scenario_difficulty"),
        "required_specialization": observation.get("required_specialization"),
        "actions": [record["hospital_id"] for record in step_records],
        "step_records": step_records,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
def update_learning_archive(archive: dict, episode_result: dict) -> None:
    """Merge the outcome of one finished episode into ``archive`` in place.

    Three parts of the archive are touched:

    * The profile stored under ``profile_key(seed, task_id)`` has its attempt
      counter incremented and its ``last_*`` fields overwritten.
    * The profile's ``best_*`` fields are replaced when the episode score is
      at least the stored best score, so ties go to the newest episode.
    * Per-step, per-hospital statistics are accumulated from the episode's
      ``step_records``.

    Finally a compact episode summary is appended to ``archive["episodes"]``,
    which is trimmed to its 500 most recent entries.
    """
    key = profile_key(int(episode_result["seed"]), str(episode_result["task_id"]))
    profiles = archive.setdefault("profiles", {})
    if key in profiles:
        profile = profiles[key]
    else:
        profile = {
            "attempts": 0,
            "best_score": 0.0,
            "best_actions": [],
            "best_steps": 0,
            "step_stats": {},
        }

    # Bookkeeping for the most recent run.
    profile["attempts"] = int(profile.get("attempts", 0)) + 1
    profile["last_score"] = float(episode_result["score"])
    profile["last_success"] = bool(episode_result["success"])
    profile["last_run_at"] = episode_result["timestamp"]
    profile["last_actions"] = list(episode_result.get("actions", []))
    for suffix in ("required_specialization", "scenario_type", "scenario_name"):
        profile[f"last_{suffix}"] = episode_result.get(suffix)

    # Promote this run to "best" when it matches or beats the stored score.
    new_score = float(episode_result["score"])
    if new_score >= float(profile.get("best_score", 0.0)):
        profile["best_score"] = new_score
        profile["best_actions"] = list(episode_result.get("actions", []))
        profile["best_steps"] = int(episode_result.get("steps", 0))
        profile["best_success"] = bool(episode_result["success"])
        profile["best_scenario_name"] = episode_result.get("scenario_name")
        profile["best_difficulty"] = episode_result.get("difficulty")
        profile["best_required_specialization"] = episode_result.get("required_specialization")

    # Counters incremented per step status; any other status counts as rejected.
    counters_by_status = {
        "ACCEPTED": ("success", "accepted"),
        "PARTIAL": ("partial",),
    }
    step_stats = profile.setdefault("step_stats", {})
    for record in episode_result.get("step_records", []):
        per_step = step_stats.setdefault(str(record.get("step")), {})
        bucket = per_step.setdefault(
            str(record.get("hospital_id")),
            {
                "count": 0,
                "success": 0,
                "accepted": 0,
                "partial": 0,
                "rejected": 0,
                "total_reward": 0.0,
                "avg_reward": 0.0,
                "last_status": None,
                "last_reason": None,
            },
        )
        bucket["count"] += 1
        for counter in counters_by_status.get(record["status"], ("rejected",)):
            bucket[counter] += 1
        bucket["total_reward"] = float(bucket["total_reward"]) + float(record["reward"])
        visits = max(1, bucket["count"])
        bucket["avg_reward"] = bucket["total_reward"] / visits
        bucket["last_status"] = record["status"]
        bucket["last_reason"] = record["reason"]
        bucket["success_rate"] = bucket["accepted"] / visits

    profiles[key] = profile

    summary = {
        "seed": episode_result["seed"],
        "task_id": episode_result["task_id"],
        "difficulty": episode_result["difficulty"],
        "required_specialization": episode_result.get("required_specialization"),
        "scenario_name": episode_result["scenario_name"],
        "score": episode_result["score"],
        "success": episode_result["success"],
        "actions": episode_result.get("actions", []),
        "timestamp": episode_result["timestamp"],
    }
    history = archive.setdefault("episodes", [])
    history.append(summary)
    # Keep only the newest 500 episode summaries.
    archive["episodes"] = history[-500:]
def print_training_summary(results: list[dict]) -> None:
    """Print aggregate statistics for a batch of training episodes.

    Reports the episode count, success rate and mean score, then compares the
    mean score of the early part of the run with the late part to show the
    score trend. Each entry of ``results`` must provide ``score`` and
    ``success``. Nothing is printed for an empty batch.
    """
    if not results:
        return

    def mean(values: list[float]) -> float:
        return sum(values) / len(values)

    episode_count = len(results)
    scores = [float(entry["score"]) for entry in results]
    success_count = len([entry for entry in results if entry["success"]])

    # The early window is the first half (at least one episode). The late
    # window is the remainder, or the last `half` scores when nothing remains.
    half = max(1, len(scores) // 2)
    early_window = scores[:half]
    late_window = scores[half:] or scores[-half:]
    early_mean = mean(early_window)
    late_mean = mean(late_window)

    report = [
        "\nTraining summary:",
        f" Episodes: {episode_count}",
        f" Success rate: {success_count / episode_count:.1%}",
        f" Average score: {mean(scores):.3f}",
        f" Early avg score ({len(early_window)} eps): {early_mean:.3f}",
        f" Late avg score ({len(late_window)} eps): {late_mean:.3f}",
        f" Trend delta (late-early): {late_mean - early_mean:+.3f}",
    ]
    for line in report:
        print(line)
def main() -> None:
    """Parse the command line, run the requested episodes and print a summary.

    In ``full`` mode every task in ``TASK_ORDER`` is run; otherwise a single
    task is taken from ``--task`` or derived from the difficulty level. When
    ``args.train_episodes`` is positive the run is a training run: that many
    rounds are executed and a training summary is printed. Otherwise
    ``args.episodes`` rounds are executed and a batch summary is printed. The
    learning archive is updated and saved after every episode.
    """
    args = parse_args()
    llm_client, model_name = require_llm_config()
    seed = ask_seed_if_missing(args.seed)
    print(f"Using seed: {seed}")

    if args.mode == "full":
        tasks = TASK_ORDER
    elif args.task is not None:
        tasks = [args.task]
    else:
        level = ask_level_if_missing(args.level)
        tasks = [LEVEL_TO_TASK[level]]

    env = EmergencyEnv(memory_file=args.memory_file)
    archive = load_learning_archive()
    results = []

    training_mode = args.train_episodes > 0
    run_count = args.train_episodes if training_mode else args.episodes
    label_prefix = "Training Episode" if training_mode else "Episode"

    for episode in range(run_count):
        for idx, task_id in enumerate(tasks):
            # Training may pin the seed; otherwise each (episode, task) pair
            # gets its own offset from the base seed.
            pin_seed = training_mode and args.train_same_seed
            task_seed = seed if pin_seed else seed + (episode * 100) + idx

            print(f"\n=== {label_prefix} {episode + 1} | {task_id} | seed={task_seed} ===")
            outcome = run_episode(
                env,
                task_id,
                task_seed,
                archive=archive,
                llm_client=llm_client,
                model_name=model_name,
            )
            results.append(outcome)
            update_learning_archive(archive, outcome)
            save_learning_archive(archive)

    if training_mode:
        print_training_summary(results)
        return

    if not results:
        return

    print("\nBatch summary:")
    if len(results) == 1:
        only = results[0]
        verdict = "SUCCESS" if only["success"] else "FAILURE"
        print(f" Episode outcome: {verdict}")
        print(f" Episode score: {only['score']:.3f}")
        print(f" Episode steps: {only['steps']}")
        print(" Note: run 30-50 episodes to estimate difficulty success rate.")
        return

    total = len(results)
    print(f" Success rate: {sum(1 for item in results if item['success']) / total:.1%}")
    print(f" Average score: {sum(item['score'] for item in results) / total:.3f}")
    print(f" Average steps: {sum(item['steps'] for item in results) / total:.1f}")
# Script entry point: run the agent only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()