|
|
| """Local agent runner for EmergencyEnv.
|
|
|
| This script acts as an agent only:
|
| - reset env
|
| - choose action from observation
|
| - step env
|
| - log trajectory
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import math
|
| import os
|
| import random
|
| from pathlib import Path
|
| from datetime import datetime, timezone
|
| from typing import Union, TYPE_CHECKING, Optional, cast
|
|
|
| from app.environment.core import EmergencyEnv
|
| from app.models.action import Action
|
|
|
# Only static type checkers see the real OpenAI class under this alias; at
# runtime the alias is a None placeholder. Annotations that mention it are
# never evaluated because of `from __future__ import annotations` above.
if TYPE_CHECKING:
    from openai import OpenAI as OpenAIClient
else:
    OpenAIClient = None

# The openai package is optional at import time: if importing it fails for any
# reason, OpenAI is None and require_llm_config() exits with an explicit
# message instead of this module failing to import.
try:
    from openai import OpenAI
except Exception:
    OpenAI = None
|
|
|
# Task identifiers accepted by the --task option.
TASK_ORDER = ["acde_easy", "acde_medium", "acde_hard"]
# Level names (the same strings --level accepts) mapped to task ids.
LEVEL_TO_TASK = {
    "low": "acde_easy",
    "medium": "acde_medium",
    "high": "acde_hard",
}
# Levels and matching weights that ask_level_if_missing() draws from when no
# valid level was supplied (25% medium, 75% high; "low" is never drawn).
RANDOM_LEVELS = ("medium", "high")
RANDOM_LEVEL_WEIGHTS = (0.25, 0.75)
# score_hospitals() multiplies the base speed by the per-traffic factor to
# estimate the travel time to each hospital.
BASE_SPEED_KMH = 60.0
TRAFFIC_FACTOR = {"low": 1.0, "medium": 0.6, "high": 0.3}
# JSON archive read by load_learning_archive() and written by
# save_learning_archive(); it lives in data/ next to this file.
LEARNING_ARCHIVE_PATH = Path(__file__).resolve().parent / "data" / "learning_archive.json"
# Archive schema version: load_learning_archive() discards the stored profiles
# when the version found on disk differs from this value.
LEARNING_ARCHIVE_VERSION = 2
# Fallbacks used by runtime_llm_config() when API_BASE_URL / MODEL_NAME are not
# present in the environment.
DEFAULT_API_BASE_URL = "https://api-inference.huggingface.co/v1"
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
# NOTE(review): not referenced in the part of the file reviewed here;
# require_llm_config() checks every runtime_llm_config() value instead.
# Confirm whether this constant is still needed.
REQUIRED_ENV_VARS = ("HF_TOKEN",)
# Inclusive bounds applied by clamp_strict_score().
STRICT_SCORE_MIN = 0.001
STRICT_SCORE_MAX = 0.999
|
|
|
|
|
def clamp_strict_score(value: float) -> float:
    """Return *value* as a float pinned to [STRICT_SCORE_MIN, STRICT_SCORE_MAX].

    Both bounds lie strictly inside (0, 1), so the result is never exactly
    0 or 1.
    """
    capped = min(STRICT_SCORE_MAX, float(value))
    return max(STRICT_SCORE_MIN, capped)
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the runner's command-line options from ``sys.argv``.

    Returns:
        The populated namespace with ``mode``, ``task``, ``level``, ``seed``,
        ``episodes``, ``train_episodes``, ``train_same_seed`` and
        ``memory_file`` attributes.
    """
    default_memory_file = Path(__file__).resolve().parent / "data" / "learning_memory.json"
    cli = argparse.ArgumentParser(description="EmergencyEnv agent runner")
    # Declared as data so the option list reads as a table; the order is the
    # order in which the options are registered.
    option_table = (
        ("--mode", {"choices": ["single", "full"], "default": "full"}),
        ("--task", {"choices": TASK_ORDER, "default": None}),
        ("--level", {"choices": ["low", "medium", "high"], "default": None}),
        ("--seed", {"type": int, "default": None}),
        ("--episodes", {"type": int, "default": 1}),
        ("--train-episodes", {"type": int, "default": 0}),
        ("--train-same-seed", {"action": "store_true"}),
        ("--memory-file", {"default": str(default_memory_file)}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
|
|
|
|
def emit_structured(tag: str, payload: dict) -> None:
    """Print one line: ``[tag]`` followed by *payload* as compact ASCII JSON."""
    pieces = [
        f"[{tag}] ",
        json.dumps(payload, ensure_ascii=True, separators=(",", ":")),
    ]
    print("".join(pieces))
|
|
|
|
|
def runtime_llm_config() -> dict[str, str]:
    """Read the LLM connection settings from the environment.

    Returns:
        A dict with ``API_BASE_URL``, ``MODEL_NAME`` and ``HF_TOKEN``. Values
        are whitespace-stripped. ``API_BASE_URL`` and ``MODEL_NAME`` fall back
        to their module defaults; ``HF_TOKEN`` falls back to ``""``.
    """

    def _env_or_default(name: str, default: str) -> str:
        # A variable that is exported but blank (e.g. `API_BASE_URL=` in a
        # .env file) is treated like an unset one, so it cannot replace a
        # usable default with an empty string.
        value = (os.getenv(name) or "").strip()
        return value or default

    return {
        "API_BASE_URL": _env_or_default("API_BASE_URL", DEFAULT_API_BASE_URL),
        "MODEL_NAME": _env_or_default("MODEL_NAME", DEFAULT_MODEL_NAME),
        "HF_TOKEN": _env_or_default("HF_TOKEN", ""),
    }
|
|
|
|
|
def require_llm_config() -> tuple[OpenAIClient, str]:
    """Build the LLM client, exiting when it cannot be configured.

    Returns:
        ``(client, model_name)`` where the client targets the configured base
        URL with ``HF_TOKEN`` as API key and an 8 second timeout.

    Raises:
        SystemExit: if any value from ``runtime_llm_config()`` is empty, or if
            the ``openai`` package could not be imported.
    """
    settings = runtime_llm_config()
    absent = [key for key, current in settings.items() if not current]
    if absent:
        message = (
            "Missing required environment variables: "
            + ", ".join(absent)
            + ". Set HF_TOKEN before running inference.py"
        )
        raise SystemExit(message)
    if OpenAI is None:
        raise SystemExit("openai package is required for inference.py LLM rationale generation.")

    return (
        OpenAI(base_url=settings["API_BASE_URL"], api_key=settings["HF_TOKEN"], timeout=8.0),
        settings["MODEL_NAME"],
    )
|
|
|
|
|
def llm_rationale(
    client: Union[OpenAIClient, None],
    model_name: str,
    observation: dict,
    chosen: dict,
    strategy: str,
) -> str:
    """Return a one-line rationale for the selected hospital.

    Args:
        client: Chat-completions client, or ``None`` to skip the LLM call.
        model_name: Model identifier passed to the completion request.
        observation: Current observation; only read through ``.get``.
        chosen: Scored hospital entry; must contain ``hospital_id``,
            ``policy_score``, ``traffic`` and ``icu`` (plus ``distance_km``
            when a client is given).
        strategy: Name of the strategy that picked the hospital.

    Returns:
        The model's reply with whitespace collapsed and cut to 180 characters,
        or a locally formatted summary when there is no client, the reply is
        empty, or the request raises.
    """
    default_text = (
        f"Selected {chosen['hospital_id']} by {strategy}; "
        f"score={chosen['policy_score']:.3f}, traffic={chosen['traffic']}, icu={chosen['icu']}"
    )
    if client is None:
        return default_text

    try:
        segments = [
            "You are an emergency routing agent. Return one short sentence rationale ",
            "for the selected hospital. Keep it under 25 words.\n",
            f"task={observation.get('task_id')} difficulty={observation.get('scenario_difficulty')} ",
            f"step={observation.get('step')} patient={observation.get('patient_condition')} ",
            f"required={observation.get('required_specialization')} ",
            f"selected={chosen['hospital_id']} score={chosen['policy_score']:.3f} ",
            f"distance={chosen['distance_km']:.1f}km traffic={chosen['traffic']} icu={chosen['icu']} ",
            f"strategy={strategy}",
        ]
        conversation = [
            {"role": "system", "content": "Generate concise emergency triage rationale."},
            {"role": "user", "content": "".join(segments)},
        ]
        reply = client.chat.completions.create(
            model=model_name,
            messages=conversation,
            temperature=0.0,
            max_tokens=60,
        )
        answer = (reply.choices[0].message.content or "").strip()
        if answer:
            # Collapse internal whitespace/newlines, then cap the length.
            return " ".join(answer.split())[:180]
        return default_text
    except Exception:
        # Best effort: any failure while building or sending the request
        # degrades to the local summary.
        return default_text
|
|
|
|
|
def normalize_seed(raw_value: int | str) -> int:
    """Turn an integer or arbitrary text into a deterministic positive seed.

    Integers (and text that parses as an integer) are used directly; any other
    text is reduced to a position-weighted sum of its character code points.
    The magnitude is folded into ``0..999_999_999`` and a result of 0 is
    replaced by 202601, so the seed is always positive.
    """
    if isinstance(raw_value, int):
        number = raw_value
    else:
        cleaned = str(raw_value).strip()
        try:
            number = int(cleaned)
        except ValueError:
            # Non-numeric text: weight each character by its 1-based position
            # so that anagrams produce different seeds.
            number = 0
            for position, char in enumerate(cleaned, start=1):
                number += position * ord(char)

    folded = abs(number) % 1_000_000_000
    return folded or 202601
|
|
|
|
|
def ask_seed_if_missing(seed: int | None) -> int:
    """Return a normalized seed, drawing one from the OS RNG when none is given."""
    chosen_seed = seed
    if chosen_seed is None:
        chosen_seed = random.SystemRandom().randint(1, 999_999_999)
    return normalize_seed(chosen_seed)
|
|
|
|
|
def ask_level_if_missing(level: str | None) -> str:
    """Return *level* when it is a known level, otherwise draw one at random.

    The random draw uses ``RANDOM_LEVELS`` weighted by
    ``RANDOM_LEVEL_WEIGHTS`` and the module-level ``random`` generator.
    """
    if level not in LEVEL_TO_TASK:
        (drawn,) = random.choices(RANDOM_LEVELS, weights=RANDOM_LEVEL_WEIGHTS, k=1)
        return drawn
    return level
|
|
|
|
|
def append_trajectory_log(entry: dict) -> None:
    """Append *entry* as one ASCII JSON line to ``data/trajectory_history.jsonl``.

    The ``data`` directory sits next to this file and is created on demand.
    """
    log_dir = Path(__file__).resolve().parent / "data"
    log_dir.mkdir(parents=True, exist_ok=True)
    with (log_dir / "trajectory_history.jsonl").open("a", encoding="utf-8") as handle:
        handle.write(f"{json.dumps(entry, ensure_ascii=True)}\n")
|
|
|
|
|
def load_learning_archive() -> dict:
    """Load the learning archive from ``LEARNING_ARCHIVE_PATH``.

    The parent directory is created if needed. A missing, empty, undecodable
    or non-object file yields a fresh empty archive. When the stored version
    differs from ``LEARNING_ARCHIVE_VERSION`` the profiles are dropped and only
    the last 500 episodes are carried over.

    Returns:
        A dict that always has ``version``, a dict under ``profiles`` and a
        list under ``episodes``. Other keys of a same-version archive are
        preserved.
    """

    def _fresh(episodes: list | None = None) -> dict:
        return {
            "version": LEARNING_ARCHIVE_VERSION,
            "profiles": {},
            "episodes": [] if episodes is None else episodes,
        }

    LEARNING_ARCHIVE_PATH.parent.mkdir(parents=True, exist_ok=True)
    if not LEARNING_ARCHIVE_PATH.exists():
        return _fresh()

    try:
        # utf-8-sig tolerates a byte-order mark left by some editors.
        payload_text = LEARNING_ARCHIVE_PATH.read_text(encoding="utf-8-sig").strip()
        payload = json.loads(payload_text) if payload_text else {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Bytes that are not valid UTF-8 are just another form of corruption;
        # handle them like malformed JSON instead of crashing the run.
        return _fresh()

    if not isinstance(payload, dict):
        return _fresh()

    episodes = payload.get("episodes", [])
    if not isinstance(episodes, list):
        episodes = []

    if payload.get("version") != LEARNING_ARCHIVE_VERSION:
        return _fresh(episodes[-500:])

    # A same-version file can still carry null or wrong-typed containers
    # (hand edits, partial writes). Normalise them here so later code can rely
    # on `profiles` being a dict and `episodes` being a list.
    if not isinstance(payload.get("profiles"), dict):
        payload["profiles"] = {}
    payload["episodes"] = episodes
    return payload
|
|
|
|
|
def save_learning_archive(archive: dict) -> None:
    """Write *archive* to ``LEARNING_ARCHIVE_PATH`` as indented ASCII JSON.

    The data is written to a sibling temporary file first and then moved over
    the destination, so an interrupted run cannot leave a half-written archive
    behind (which the loader would treat as empty, discarding all learning).
    """
    LEARNING_ARCHIVE_PATH.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(archive, indent=2, ensure_ascii=True)
    staging_path = LEARNING_ARCHIVE_PATH.with_name(LEARNING_ARCHIVE_PATH.name + ".tmp")
    staging_path.write_text(serialized, encoding="utf-8")
    # Path.replace overwrites an existing destination in one rename step.
    staging_path.replace(LEARNING_ARCHIVE_PATH)
|
|
|
|
|
def profile_key(seed: int, task_id: str) -> str:
    """Return the archive lookup key for a seed/task pair: ``"<seed>|<task_id>"``."""
    return "{}|{}".format(seed, task_id)
|
|
|
|
|
def _merge_step_stats(primary: dict, secondary: dict) -> dict:
    """Combine two ``{step: {hospital_id: stats}}`` mappings into a new one.

    Counters and ``total_reward`` are summed, ``avg_reward`` and
    ``success_rate`` are recomputed from the sums (``accepted / count`` for the
    latter), and ``last_status`` / ``last_reason`` take the first truthy value,
    preferring *primary*. Neither input is modified.
    """

    def _int_sum(left: dict, right: dict, field: str) -> int:
        return int(left.get(field, 0)) + int(right.get(field, 0))

    combined: dict = {}
    for step_key in set(primary.keys()) | set(secondary.keys()):
        per_hospital: dict = {}
        combined[step_key] = per_hospital
        first_step = primary.get(step_key, {})
        second_step = secondary.get(step_key, {})
        for hospital_id in set(first_step.keys()) | set(second_step.keys()):
            left = first_step.get(hospital_id, {})
            right = second_step.get(hospital_id, {})
            visits = _int_sum(left, right, "count")
            accepted_total = _int_sum(left, right, "accepted")
            partial_total = _int_sum(left, right, "partial")
            rejected_total = _int_sum(left, right, "rejected")
            reward_total = float(left.get("total_reward", 0.0)) + float(right.get("total_reward", 0.0))
            # Guard the averages against a zero visit count.
            divisor = max(1, visits)
            per_hospital[hospital_id] = {
                "count": visits,
                "success": _int_sum(left, right, "success"),
                "accepted": accepted_total,
                "partial": partial_total,
                "rejected": rejected_total,
                "total_reward": reward_total,
                "avg_reward": (reward_total / divisor),
                "success_rate": (accepted_total / divisor),
                "last_status": left.get("last_status") or right.get("last_status"),
                "last_reason": left.get("last_reason") or right.get("last_reason"),
            }
    return combined
|
|
|
|
|
def build_learning_profile(
    archive: dict,
    seed: int,
    task_id: str,
    required_specialization: str | None = None,
) -> dict | None:
    """Return the stored learning profile for exactly this seed and task.

    Args:
        archive: Learning archive; its ``profiles`` mapping is consulted.
        seed: Episode seed, part of the lookup key.
        task_id: Task identifier, part of the lookup key.
        required_specialization: Accepted for call compatibility; not used by
            the lookup.

    Returns:
        A normalized copy of the profile tagged ``"source": "exact-only"``, or
        ``None`` when no (non-empty) profile exists for the key.
    """
    stored_profiles = archive.get("profiles", {})
    lookup = profile_key(seed, task_id)
    record = stored_profiles.get(lookup)
    if not record:
        return None

    profile = {
        "attempts": int(record.get("attempts", 0)),
        "best_score": float(record.get("best_score", 0.0)),
        "best_actions": list(record.get("best_actions", [])),
        "step_stats": record.get("step_stats", {}),
    }
    for name_field in ("best_scenario_name", "last_scenario_name"):
        profile[name_field] = record.get(name_field)
    profile["source"] = "exact-only"
    return profile
|
|
|
|
|
def _difficulty_policy_params(difficulty: str) -> tuple[float, float]:
    """Return the ``(epsilon, temperature)`` pair for a difficulty label.

    ``"easy"`` and ``"medium"`` have dedicated values; every other label gets
    the hardest setting.
    """
    known_levels = (
        ("easy", (0.07, 0.18)),
        ("medium", (0.16, 0.32)),
    )
    for label, params in known_levels:
        if difficulty == label:
            return params
    return 0.26, 0.44
|
|
|
|
|
def _sample_softmax(candidates: list[dict], key: str, temperature: float, rng: random.Random) -> dict:
    """Draw one candidate with probability proportional to ``exp(score / T)``.

    Args:
        candidates: Non-empty list of dicts holding a numeric value under *key*.
        key: Name of the score field.
        temperature: Softmax temperature; values below ``1e-6`` are raised to it.
        rng: Random source; exactly one ``rng.random()`` call is made.

    Returns:
        The selected candidate (the same dict object, not a copy).

    Raises:
        ValueError: if *candidates* is empty.
    """
    scale = max(temperature, 1e-6)
    scaled = [entry[key] / scale for entry in candidates]
    # Subtracting the peak keeps exp() from overflowing on large scores.
    peak = max(scaled)
    weights = [math.exp(level - peak) for level in scaled]
    norm = sum(weights)

    draw = rng.random()
    running = 0.0
    for entry, weight in zip(candidates, weights):
        running += weight / norm
        if draw <= running:
            return entry
    # Rounding can leave the cumulative sum a hair below 1.0.
    return candidates[-1]
|
|
|
|
|
def memory_score_for_hospital(
    hospital_id: str,
    memory_snapshot: dict,
    learning_profile: dict | None = None,
    step_number: int | None = None,
) -> float:
    """Estimate in [0, 1] how well a hospital has worked out in the past.

    Args:
        hospital_id: Hospital to look up.
        memory_snapshot: Per-hospital history. Each entry may carry
            ``accepted``/``rejected`` (or the older ``success``/``fail``)
            counters and an ``avg`` reward.
        learning_profile: Optional profile whose ``step_stats`` refine the
            estimate for the given step.
        step_number: Step whose statistics to use; ``None`` disables the
            profile adjustment.

    Returns:
        0.5 when there is no usable history. Otherwise
        ``0.6 * success_rate + 0.4 * avg`` plus a step bonus capped at 0.20,
        minus 0.3 if the step's last status was a rejection, clamped to [0, 1].
    """
    neutral = 0.5
    record = memory_snapshot.get(hospital_id)
    if not record:
        return neutral

    wins = int(record.get("accepted", record.get("success", 0)))
    losses = int(record.get("rejected", record.get("fail", 0)))
    mean_reward = float(record.get("avg", 0.0))
    observed = wins + losses
    if observed <= 0:
        return neutral

    estimate = (0.6 * (wins / observed)) + (0.4 * mean_reward)

    per_step = None
    if learning_profile and step_number is not None:
        per_step = learning_profile.get("step_stats", {}).get(str(step_number), {}).get(hospital_id)

    if per_step:
        step_avg = float(per_step.get("avg_reward", 0.0))
        step_success = float(per_step.get("success_rate", 0.0))
        step_count = int(per_step.get("count", 0))
        estimate += min(0.20, (step_avg * 0.10) + (step_success * 0.08) + min(step_count, 5) * 0.01)
        if str(per_step.get("last_status", "")).upper() == "REJECTED":
            estimate -= 0.3

    return max(0.0, min(1.0, estimate))
|
|
|
|
|
def score_hospitals(observation: dict, learning_profile: dict | None = None) -> list[dict]:
    """Score every hospital in *observation* and return the entries best-first.

    Args:
        observation: Environment observation as a plain dict. Must contain
            ``critical_time_limit_minutes``; when ``hospitals`` is non-empty it
            must also contain ``required_specialization`` and each hospital
            needs ``hospital_id``, ``traffic``, ``distance_km``, ``icu`` and
            ``specialization``. All other fields are optional.
        learning_profile: Optional profile (as built by
            ``build_learning_profile``) providing ``attempts``,
            ``best_actions``, ``step_stats`` and scenario names.

    Returns:
        One dict per hospital with ``hospital_id``, ``icu``, ``distance_km``,
        ``traffic``, ``specialization``, ``travel_time``, ``memory_score``,
        ``policy_score`` and ``specialization_match``, sorted by descending
        ``policy_score``. An observation without hospitals yields ``[]``.

    NOTE(review): the block nesting below was rebuilt from a copy of this file
    that had lost its indentation; confirm the structure against the canonical
    source, in particular the consecutive ``if`` statements in the per-hospital
    loop.
    """
    failed = set(observation.get("failed_hospitals", []))
    recent_failed = set(observation.get("recent_failed_hospitals", []))
    visited = set(observation.get("visited_hospitals", []))
    memory_snapshot = observation.get("memory_snapshot", {})
    previous_action = observation.get("previous_action")
    last_arrival = observation.get("last_arrival_outcome") or {}
    last_status = str(last_arrival.get("status", "")).lower()

    scored: list[dict] = []
    initial_limit = float(observation.get("initial_critical_time_limit_minutes", observation["critical_time_limit_minutes"]))
    remaining_time = float(observation.get("remaining_time_minutes", observation["critical_time_limit_minutes"]))
    # 0.0 with the full time budget left, 1.0 once it is used up.
    urgency = 1.0 - min(1.0, max(0.0, remaining_time / max(initial_limit, 1e-6)))

    patient_condition = observation.get("patient_condition", "").lower()
    critical_patient = patient_condition in {"critical", "unstable"}
    required_specialization = str(observation.get("required_specialization", ""))
    scenario_name = str(observation.get("scenario_name", ""))
    step_number = int(observation.get("step", 1))
    difficulty = str(observation.get("scenario_difficulty", "medium"))
    attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0
    preferred_route = []
    if learning_profile:
        preferred_route = list(learning_profile.get("best_actions", []))

    def _hospital_rng(step_salt: int, hospital_id: str) -> random.Random:
        # Deterministic per (seed, step, hospital) generator used for jitter.
        # The seed is read lazily so it is only parsed when jitter is needed.
        jitter_seed = (
            int(observation.get("seed", 0))
            + (step_number * step_salt)
            + sum(ord(ch) for ch in hospital_id)
        )
        return random.Random(jitter_seed)

    def _dampen_low(normalized: float, hospital_id: str) -> float:
        # Push entries in the bottom fifth further down by a seeded factor,
        # then clamp to [0, 1].
        if normalized < 0.2:
            normalized *= _hospital_rng(131, hospital_id).uniform(0.3, 0.7)
        return max(0.0, min(1.0, normalized))

    for hospital in observation.get("hospitals", []):
        traffic_factor = TRAFFIC_FACTOR[hospital["traffic"]]
        speed_kmh = BASE_SPEED_KMH * traffic_factor
        travel_time = (hospital["distance_km"] / max(speed_kmh, 1e-6)) * 60.0

        # Distances of 20 km or more score 0.
        distance_score = max(0.0, min(1.0, 1.0 - hospital["distance_km"] / 20.0))
        icu_score = 1.0 if hospital["icu"] == "available" else 0.55
        mem_score = memory_score_for_hospital(
            hospital["hospital_id"],
            memory_snapshot,
            learning_profile=learning_profile,
            step_number=step_number,
        )

        memory_scenario = ""
        if learning_profile:
            memory_scenario = str(
                learning_profile.get("best_scenario_name")
                or learning_profile.get("last_scenario_name")
                or ""
            )
        # History recorded under a differently named scenario counts half.
        if memory_scenario and scenario_name and memory_scenario != scenario_name:
            mem_score *= 0.5

        spec_match = (
            hospital["specialization"] == observation["required_specialization"]
            or hospital["specialization"] == "general"
            or observation["required_specialization"] == "general"
        )
        exact_spec_match = hospital["specialization"] == observation["required_specialization"]
        general_fallback = (
            hospital["specialization"] == "general"
            and observation["required_specialization"] != "general"
        )

        rejected_penalty = 0.40 if hospital["hospital_id"] in failed else 0.0
        revisit_penalty = 0.14 if hospital["hospital_id"] in visited else 0.0
        partial_repeat_penalty = (
            0.32
            if last_status == "partial" and hospital["hospital_id"] == previous_action
            else 0.0
        )
        # Every hospital pays at least 0.03 here; 0.17 applies to a critical
        # patient combined with unknown ICU status.
        critical_unknown_penalty = 0.17 if critical_patient and hospital["icu"] == "unknown" else 0.03
        traffic_penalty = 0.10 if hospital["traffic"] == "high" else 0.04 if hospital["traffic"] == "medium" else 0.0
        if critical_patient and general_fallback:
            spec_penalty = {"easy": 0.08, "medium": 0.16, "hard": 0.26}.get(difficulty, 0.16)
            if attempts >= 5:
                spec_penalty += 0.06
        else:
            spec_penalty = 0.0
        spec_bonus = 0.16 if exact_spec_match else (0.08 if spec_match else 0.0)
        urgency_boost = urgency * (0.18 + max(0.0, 0.25 - travel_time / 100.0))
        step_route_bonus = 0.0
        if step_number - 1 < len(preferred_route) and preferred_route[step_number - 1] == hospital["hospital_id"]:
            step_route_bonus = 0.16

        score = (
            (icu_score * 0.30)
            + (distance_score * 0.18)
            + (traffic_factor * 0.14)
            + (mem_score * 0.24)
            + spec_bonus
            + urgency_boost
            + step_route_bonus
            - rejected_penalty
            - revisit_penalty
            - partial_repeat_penalty
            - spec_penalty
            - critical_unknown_penalty
            - traffic_penalty
        )

        # Multiplicative adjustments.
        if hospital["hospital_id"] == previous_action and last_status == "rejected":
            score *= 0.01

        if hospital["hospital_id"] in recent_failed:
            score *= 0.2

        if hospital["specialization"] != required_specialization:
            if patient_condition == "critical":
                score *= 0.15
            else:
                score *= 0.4
        elif patient_condition == "critical":
            score *= 1.5

        # Flat deductions, applied after the multipliers above.
        if hospital["specialization"] != required_specialization:
            score -= 0.6
        if critical_patient and hospital["icu"] == "unknown":
            score -= 0.5
        if critical_patient and hospital["traffic"] == "high":
            score -= 0.3

        risk_factor = 1.0
        if hospital["icu"] == "unknown":
            risk_factor *= 0.6
        if not spec_match:
            risk_factor *= 0.5
        if critical_patient and hospital["traffic"] == "high":
            risk_factor *= 0.7
        score *= risk_factor

        # Blend the situational score with the memory score; memory gets less
        # weight on the first step and none when the situational score is weak.
        memory_weight = 0.15
        current_score_weight = 0.85
        if step_number == 1:
            memory_weight = 0.1
            current_score_weight = 0.9
        base_current_score = score
        confidence_score = max(0.0, min(1.0, base_current_score))
        effective_memory_score = mem_score
        in_best_route = hospital["hospital_id"] in preferred_route
        if in_best_route and confidence_score < 0.6:
            effective_memory_score = 0.0
        if confidence_score < 0.2:
            effective_memory_score = 0.0

        score = (current_score_weight * base_current_score) + (memory_weight * effective_memory_score)

        scored.append(
            {
                "hospital_id": hospital["hospital_id"],
                "icu": hospital["icu"],
                "distance_km": hospital["distance_km"],
                "traffic": hospital["traffic"],
                "specialization": hospital["specialization"],
                "travel_time": travel_time,
                "memory_score": mem_score,
                "policy_score": max(0.0, min(1.0, score)),
                "specialization_match": spec_match,
                # Temporary field, removed again before returning.
                "tie_break_score": (
                    (distance_score * 0.35)
                    + (traffic_factor * 0.35)
                    + (icu_score * 0.20)
                    + (0.10 if spec_match else 0.0)
                ),
            }
        )

    # Nothing to rank. Returning here also keeps the normalisation below from
    # ever running on an empty list (the former uniform fallback divided by
    # len(scored) in exactly that case).
    if not scored:
        return scored

    scored.sort(key=lambda item: item["policy_score"], reverse=True)

    # Stage 1: rescale the scores to [0, 1].
    min_score = min(item["policy_score"] for item in scored)
    max_score = max(item["policy_score"] for item in scored)
    spread = max_score - min_score
    if spread > 1e-9:
        for item in scored:
            item["policy_score"] = _dampen_low(
                (item["policy_score"] - min_score) / (spread + 1e-6),
                item["hospital_id"],
            )
    elif max_score > 0:
        # All scores are (nearly) equal and positive.
        for item in scored:
            item["policy_score"] = _dampen_low(item["policy_score"] / max_score, item["hospital_id"])
    else:
        # Every score was clamped to zero: rank by the tie-break score instead.
        tie_min = min(item.get("tie_break_score", 0.0) for item in scored)
        tie_max = max(item.get("tie_break_score", 0.0) for item in scored)
        tie_spread = tie_max - tie_min
        if tie_spread > 1e-9:
            for item in scored:
                item["policy_score"] = _dampen_low(
                    (item.get("tie_break_score", 0.0) - tie_min) / (tie_spread + 1e-6),
                    item["hospital_id"],
                )
        else:
            for item in scored:
                item["policy_score"] = 0.0

    # Stage 2: give zero-scored entries a small seeded floor so that every
    # hospital keeps a strictly positive score.
    for item in scored:
        if item["policy_score"] <= 0.0:
            jitter_rng = _hospital_rng(173, item["hospital_id"])
            if critical_patient and required_specialization != "general":
                if item.get("specialization") == required_specialization:
                    item["policy_score"] = jitter_rng.uniform(0.08, 0.18)
                else:
                    item["policy_score"] = jitter_rng.uniform(0.001, 0.01)
            else:
                item["policy_score"] = jitter_rng.uniform(0.05, 0.15)

    # Stage 3: normalise to a distribution. The total is strictly positive
    # because of the floor applied in stage 2.
    total_score = sum(item["policy_score"] for item in scored)
    for item in scored:
        item["policy_score"] = item["policy_score"] / (total_score + 1e-6)

    # Stage 4: for critical patients needing a specific specialization, favour
    # exact matches and renormalise.
    if critical_patient and required_specialization != "general":
        for item in scored:
            if item.get("specialization") == required_specialization:
                item["policy_score"] *= 1.5
            else:
                item["policy_score"] *= 0.15

        boosted_total = sum(item["policy_score"] for item in scored)
        if boosted_total > 0:
            for item in scored:
                item["policy_score"] = item["policy_score"] / boosted_total

    # Stage 5: squash with x / (1 + |x|) and keep a small seeded minimum.
    for item in scored:
        raw_score = float(item["policy_score"])
        normalized_score = raw_score / (1.0 + abs(raw_score))

        if normalized_score < 0.01:
            normalized_score = _hospital_rng(211, item["hospital_id"]).uniform(0.01, 0.03)
        item["policy_score"] = normalized_score

    scored.sort(key=lambda item: item["policy_score"], reverse=True)

    for item in scored:
        item.pop("tie_break_score", None)
    return scored
|
|
|
|
|
def choose_hospital(
    scored: list[dict],
    observation: dict,
    rng: random.Random,
    learning_profile: dict | None = None,
) -> tuple[dict, str]:
    """Pick one hospital from *scored* and name the strategy that chose it.

    Args:
        scored: Scored hospital entries; each needs ``hospital_id``,
            ``policy_score``, ``specialization``, ``icu``, ``traffic`` and
            ``distance_km``.
        observation: Current observation dict; read only through ``.get``.
        rng: Random source for exploration, softmax sampling and the
            score-gap guard.
        learning_profile: Optional profile providing ``attempts``,
            ``best_actions``, ``best_score`` and ``step_stats``.

    Returns:
        ``(chosen_entry, strategy_label)``. Every return path goes through
        ``enforce_score_guard``, which may swap the entry and append a suffix
        to the label.

    NOTE(review): the block nesting in this function was rebuilt from a copy
    of the file that had lost its indentation. Spots where more than one
    nesting is plausible are marked below; confirm against the canonical
    source.
    """
    difficulty = observation.get("scenario_difficulty", "medium")
    # epsilon is adjusted below but never read afterwards in this function;
    # only temperature reaches the sampling calls.
    epsilon, temperature = _difficulty_policy_params(difficulty)

    # `failed` is not read again in this function.
    failed = set(observation.get("failed_hospitals", []))
    recent_failed = set(observation.get("recent_failed_hospitals", []))
    visited = set(observation.get("visited_hospitals", []))
    previous_action = observation.get("previous_action")
    selected_hospital_id = observation.get("selected_hospital_id")
    visited_sequence = observation.get("visited_hospitals", []) or []
    # Most recent hospital: previous action, else the selected hospital id,
    # else the last visited entry.
    recent_hospital = previous_action or selected_hospital_id or (visited_sequence[-1] if visited_sequence else None)
    last_arrival = observation.get("last_arrival_outcome") or {}
    last_status = str(last_arrival.get("status", "")).lower()
    last_reason = str(last_arrival.get("reason", "")).lower()
    is_rerouting_phase = str(observation.get("ambulance_status", "")).lower() == "rerouting"

    # Candidate pool, relaxed step by step: unvisited and not recently failed,
    # then merely not recently failed, then everything.
    candidates = [
        item
        for item in scored
        if item["hospital_id"] not in recent_failed and item["hospital_id"] not in visited
    ]
    if not candidates:
        candidates = [item for item in scored if item["hospital_id"] not in recent_failed]
    if not candidates:
        candidates = list(scored)
    # After a rejection or while rerouting, drop the most recent hospital as
    # long as something else remains.
    if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
        redirected = [item for item in candidates if item["hospital_id"] != recent_hospital]
        if redirected:
            candidates = redirected

    step_number = int(observation.get("step", 1))
    attempts = int(learning_profile.get("attempts", 0)) if learning_profile else 0
    required_specialization = str(observation.get("required_specialization", ""))
    critical_patient = observation.get("patient_condition", "").lower() in {"critical", "unstable"}

    # Same exclusion as above, plus a fallback to the full scored list when
    # the only remaining candidate is the most recent hospital.
    if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
        immediate_retry_block = [item for item in candidates if item["hospital_id"] != recent_hospital]
        if immediate_retry_block:
            candidates = immediate_retry_block
        elif len(candidates) == 1 and candidates[0]["hospital_id"] == recent_hospital:
            fallback_any = [item for item in scored if item["hospital_id"] != recent_hospital]
            if fallback_any:
                candidates = fallback_any

    # Critical patients with a specific requirement: keep only exact
    # specialization matches when any exist.
    if critical_patient and required_specialization != "general":
        exact_spec_candidates = [
            item for item in candidates if item["specialization"] == required_specialization
        ]
        if exact_spec_candidates:
            candidates = exact_spec_candidates

    if step_number == 1:
        policy_mode = "safe"
    elif last_status == "rejected":
        policy_mode = "risk-aware"
    else:
        policy_mode = "balanced"

    safe_weight = 1.0
    if policy_mode == "safe":
        safe_weight *= 0.8
        epsilon *= 0.6
        temperature *= 0.8
    elif policy_mode == "risk-aware":
        epsilon *= 1.1
        temperature *= 0.9

    # Narrow the pool according to keywords in the last arrival reason; each
    # filter is applied only if it leaves at least one candidate.
    if "wrong hospital specialization" in last_reason:
        strict_spec = [
            item
            for item in candidates
            if item["specialization"] == observation.get("required_specialization")
        ]
        if strict_spec:
            candidates = strict_spec
    if "icu unavailable" in last_reason:
        icu_known = [item for item in candidates if item["icu"] == "available"]
        if icu_known:
            candidates = icu_known
    if "specialist" in last_reason:
        strict_spec = [
            item
            for item in candidates
            if item["specialization"] == observation.get("required_specialization")
        ]
        if strict_spec:
            candidates = strict_spec
    if "overloaded" in last_reason:
        non_high_traffic = [item for item in candidates if item["traffic"] != "high"]
        if non_high_traffic:
            candidates = non_high_traffic
    if "delay" in last_reason:
        # Reorders only; nothing is removed.
        candidates = sorted(candidates, key=lambda item: item["distance_km"])

    def learned_utility(item: dict) -> float:
        """Policy score plus history and exploration bonuses for this step."""
        base = float(item.get("policy_score", 0.0))
        if not learning_profile:
            return base
        step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {})
        stats = step_stats.get(item["hospital_id"], {})
        count = int(stats.get("count", 0))
        if count <= 0:
            # Never tried at this step: exploration bonus only.
            exploration_bonus = 0.22 * math.sqrt(max(1.0, math.log(attempts + 2.0)))
            return base + exploration_bonus
        avg_reward = float(stats.get("avg_reward", 0.0))
        success_rate = float(stats.get("success_rate", 0.0))
        rejected = int(stats.get("rejected", 0))
        rejection_rate = rejected / max(1, count)
        # Shrinks as the hospital is tried more often at this step.
        exploration_bonus = 0.18 * math.sqrt(max(0.0, math.log(attempts + 2.0) / (count + 1.0)))

        historical_weight = 0.35
        historical_weight *= 0.6
        historical_bonus = (avg_reward * historical_weight) + (success_rate * 0.30) - (rejection_rate * 0.22)
        if item["hospital_id"] in recent_failed:
            historical_bonus = 0.0
        return base + historical_bonus + exploration_bonus

    # NOTE(review): this helper is not called anywhere in choose_hospital.
    def pick_improvement_candidate(route_choice_id: str | None) -> dict | None:
        """Best candidate by learned utility, skipping *route_choice_id* when possible."""
        if not candidates:
            return None
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        if route_choice_id is None:
            return ranked[0]
        for item in ranked:
            if item["hospital_id"] != route_choice_id:
                return item
        return ranked[0]

    def enforce_score_guard(chosen: dict, strategy: str) -> tuple[dict, str]:
        """Final check on a pick; may replace it and extend the strategy label."""
        # Never go straight back to the hospital that just rejected us if any
        # alternative exists.
        if last_status == "rejected" and previous_action and chosen.get("hospital_id") == previous_action:
            alternatives = [item for item in scored if item["hospital_id"] != previous_action]
            if alternatives:
                rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                return rerouted, strategy + " + immediate-retry block"

        # Everything not recently failed and, after a rejection or while
        # rerouting, not the most recent hospital.
        globally_eligible = [
            item
            for item in scored
            if item["hospital_id"] not in recent_failed
            and not (
                (last_status == "rejected" or is_rerouting_phase)
                and recent_hospital
                and item["hospital_id"] == recent_hospital
            )
        ]
        if not globally_eligible:
            globally_eligible = list(scored)

        if globally_eligible:
            best_global = max(globally_eligible, key=lambda item: float(item.get("policy_score", 0.0)))
            chosen_score = float(chosen.get("policy_score", 0.0))
            best_global_score = float(best_global.get("policy_score", 0.0))

            if (last_status == "rejected" or is_rerouting_phase) and recent_hospital:
                if chosen.get("hospital_id") == recent_hospital:
                    alternatives = [
                        item
                        for item in scored
                        if item["hospital_id"] != recent_hospital and item["hospital_id"] not in recent_failed
                    ]
                    if not alternatives:
                        alternatives = [item for item in scored if item["hospital_id"] != recent_hospital]
                    if alternatives:
                        rerouted = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                        return rerouted, strategy + " + cooldown reroute"

            # NOTE(review): these two guards are taken to apply to every pick,
            # not only while rerouting -- confirm the nesting.
            if chosen_score < (best_global_score * 0.6):
                return best_global, strategy + " + anti-stupidity guard"
            # Consumes one rng.random() draw whenever the gap exceeds 0.25.
            if (best_global_score - chosen_score) > 0.25 and rng.random() < 0.75:
                return best_global, strategy + " + score-gap guard"

        return chosen, strategy

    if learning_profile:
        step_stats = learning_profile.get("step_stats", {}).get(str(step_number), {})
        # Hospitals tried at least twice at this step with no recorded
        # success and at least two rejections.
        guard_blocked: set[str] = set()
        for hospital_id, stats in step_stats.items():
            count = int(stats.get("count", 0))
            success_rate = float(stats.get("success_rate", 0.0))
            rejected = int(stats.get("rejected", 0))
            if count >= 2 and success_rate <= 0.0 and rejected >= 2:
                guard_blocked.add(hospital_id)

        guarded_candidates = [item for item in candidates if item["hospital_id"] not in guard_blocked]
        if guarded_candidates:
            candidates = guarded_candidates

        if attempts >= 3:
            epsilon *= 0.35
            temperature *= 0.70

    if learning_profile and policy_mode != "risk-aware":
        best_route = list(learning_profile.get("best_actions", []))
        if step_number - 1 < len(best_route):
            baseline_id = best_route[step_number - 1]
            ranked = sorted(candidates, key=learned_utility, reverse=True)
            baseline_candidate = next((item for item in ranked if item["hospital_id"] == baseline_id), None)
            alternatives = [item for item in ranked if item["hospital_id"] != baseline_id]
            top_candidate = ranked[0] if ranked else None

            # On the first step, forget the baseline when the top-ranked
            # candidate has a higher policy score.
            if (
                step_number == 1
                and baseline_candidate is not None
                and top_candidate is not None
                and float(baseline_candidate.get("policy_score", 0.0)) < float(top_candidate.get("policy_score", 0.0))
            ):
                baseline_candidate = None

            alternatives = alternatives[: min(3, len(alternatives))]

            if attempts >= 1:
                # Read (attempts - 1) as a base-3 number with one digit per
                # step: digit 0 keeps the baseline, digits 1 and 2 pick the
                # first or second alternative.
                combo_index = max(0, attempts - 1)
                digit = (combo_index // (3 ** max(0, step_number - 1))) % 3

                if digit == 0 and baseline_candidate is not None:
                    return enforce_score_guard(baseline_candidate, "best-route retain")

                alt_rank = digit - 1
                if alt_rank >= 0 and alt_rank < len(alternatives):
                    return enforce_score_guard(alternatives[alt_rank], f"combination search step-{step_number} alt-{alt_rank + 1}")

                # NOTE(review): assumed to sit inside `if attempts >= 1` --
                # confirm the nesting.
                if baseline_candidate is not None:
                    return enforce_score_guard(baseline_candidate, "best-route retain")

        # NOTE(review): assumed to sit inside the learning-profile /
        # non-risk-aware block -- confirm the nesting.
        if attempts >= 6:
            ranked = sorted(candidates, key=learned_utility, reverse=True)
            top_pool = ranked[: min(3, len(ranked))]
            return enforce_score_guard(_sample_softmax(top_pool, "policy_score", max(0.08, temperature * 0.85), rng), "learned utility exploit")

    if learning_profile and policy_mode == "safe":
        preferred_route = list(learning_profile.get("best_actions", []))
        if step_number - 1 < len(preferred_route):
            preferred_hospital = preferred_route[step_number - 1]
            preferred_candidate = next((item for item in candidates if item["hospital_id"] == preferred_hospital), None)
            if preferred_candidate is not None:
                profile_score = float(learning_profile.get("best_score", 0.0))
                # safe_weight is 0.8 in safe mode, so a best_score of 1.0
                # still falls short of the 0.85 threshold.
                if (profile_score * safe_weight) >= 0.85 or len(candidates) == 1:
                    return enforce_score_guard(preferred_candidate, "learned best path")

    # After a partial outcome, avoid repeating the previous action if possible.
    if last_status == "partial" and previous_action:
        redirected = [item for item in candidates if item["hospital_id"] != previous_action]
        if redirected:
            candidates = redirected

    epsilon = min(epsilon, 0.04)
    temperature = min(temperature, 0.24)

    # Same test as `critical_patient` above.
    critical = observation.get("patient_condition", "").lower() in {"critical", "unstable"}
    strategy = f"{policy_mode} policy"

    if critical and policy_mode in {"safe", "balanced"}:
        confirmed = [item for item in candidates if item["icu"] == "available"]
        if confirmed:
            candidates = confirmed
            strategy = f"{policy_mode} policy + critical triage"

    # With 15% probability pick uniformly among the top three by learned
    # utility.
    if len(candidates) > 1 and rng.random() < 0.15:
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        top_k = ranked[: min(3, len(ranked))]
        return enforce_score_guard(rng.choice(top_k), strategy + " + guided-exploration")

    if len(candidates) > 1:
        # Ordered by learned utility, but sampled on the plain policy_score.
        ranked = sorted(candidates, key=learned_utility, reverse=True)
        chosen = _sample_softmax(ranked, "policy_score", temperature, rng)
        return enforce_score_guard(chosen, strategy)

    # Raises IndexError when `scored` (and therefore `candidates`) is empty.
    return enforce_score_guard(candidates[0], strategy)
|
|
|
|
|
def print_options(scored: list[dict]) -> None:
    """Print a numbered, one-line summary for each scored hospital option.

    Args:
        scored: Hospital candidates; each entry must provide ``hospital_id``,
            ``distance_km``, ``icu``, ``traffic``, ``specialization`` and
            ``policy_score``.
    """
    total = len(scored)
    print(f"Hospital options ({total} total):")

    position = 0
    for option in scored:
        position += 1  # options are numbered from 1 for display
        hospital, distance, icu, traffic, specialty, score = (
            option["hospital_id"],
            option["distance_km"],
            option["icu"],
            option["traffic"],
            option["specialization"],
            option["policy_score"],
        )
        print(
            f" [{position}] {hospital} | {distance:.1f} km | ICU {icu} | "
            f"traffic {traffic} | specialty {specialty} | score {score:.3f}"
        )
|
|
|
|
|
def run_episode(
    env: EmergencyEnv,
    task_id: str,
    seed: int,
    archive: dict | None = None,
    llm_client: object | None = None,
    model_name: str | None = None,
) -> dict:
    """Run one episode against ``env`` and return a summary of it.

    The environment is reset with ``seed`` and ``task_id``. Each step scores
    the hospitals in the current observation, picks one, submits it as an
    ``Action`` and logs the transition, until the environment reports
    ``done``.

    Args:
        env: Environment to drive; it is reset at the start of the episode.
        task_id: Task identifier forwarded to ``env.reset``.
        seed: Seed forwarded to ``env.reset`` and used to seed the policy RNG.
        archive: Optional learning archive. When given, a learning profile is
            built from it and passed to the scoring and selection helpers.
        llm_client: Optional client forwarded to ``llm_rationale``.
        model_name: Model name forwarded to ``llm_rationale`` and echoed in
            the ``[START]`` log line (``none`` when unset).

    Returns:
        A dict with the episode outcome (``success``, ``score``, ``steps``),
        the ``seed`` and ``task_id``, scenario metadata read from the last
        observation, the ordered ``actions``, the per-step ``step_records``
        and a UTC ISO-8601 ``timestamp``.
    """
    observation = env.reset(seed=seed, task_id=task_id).model_dump()

    learning_profile = None
    if archive is not None:
        learning_profile = build_learning_profile(
            archive,
            seed,
            task_id,
            # An empty specialization string is passed on as None.
            required_specialization=str(observation.get("required_specialization", "")) or None,
        )

    print("\n" + "=" * 72)
    print(f"Scenario: {observation['scenario_name']}")
    print(f"Task: {task_id} | Difficulty: {observation['scenario_difficulty']} | Seed: {seed}")
    print(f"Patient condition: {observation['patient_condition']}")
    print(f"Required specialization: {observation['required_specialization']}")
    print("Objective: admit patient successfully (no fixed deadline window)")
    print("=" * 72)
    print(f"[START] task={task_id} env=acde-openenv model={model_name or 'none'}", flush=True)

    if learning_profile:
        print(
            f"Learning memory: best historical score {float(learning_profile.get('best_score', 0.0)):.3f} "
            f"across {int(learning_profile.get('attempts', 0))} attempts"
        )
        if learning_profile.get("best_actions"):
            print(f"Best known route: {' -> '.join(learning_profile['best_actions'])}")

    total_reward = 0.0
    all_rewards: list[float] = []
    steps = 0
    done = False
    previous_policy_hospital_id: str | None = None
    previous_policy_outcome: str | None = None
    attempt_index = int(learning_profile.get("attempts", 0)) if learning_profile else 0

    # Offset the RNG seed by the number of recorded attempts so that repeated
    # runs of the same seed/task do not replay the same random draws.
    rng = random.Random(seed + (attempt_index * 7919))
    step_records: list[dict] = []

    while not done:
        steps += 1
        print(f"\nStep {observation['step']} | phase={observation['ambulance_status']}")

        scored = score_hospitals(observation, learning_profile=learning_profile)
        chosen, strategy = choose_hospital(scored, observation, rng, learning_profile=learning_profile)

        # Never go straight back to the hospital that just rejected us when
        # any other option exists; take the best-scored alternative instead.
        if (
            previous_policy_outcome == "REJECTED"
            and previous_policy_hospital_id
            and chosen["hospital_id"] == previous_policy_hospital_id
        ):
            alternatives = [item for item in scored if item["hospital_id"] != previous_policy_hospital_id]
            if alternatives:
                chosen = max(alternatives, key=lambda item: float(item.get("policy_score", 0.0)))
                strategy = strategy + " + immediate-retry override"

        print_options(scored)
        rationale = llm_rationale(
            cast(Optional[OpenAIClient], llm_client),
            model_name or "",
            observation,
            chosen,
            strategy,
        )
        print(f"Decision: {chosen['hospital_id']} ({strategy})")

        step_result = env.step(
            Action(
                step=observation["step"],
                hospital_id=chosen["hospital_id"],
                rationale=rationale,
            )
        )
        next_obs_model = step_result["observation"]
        reward = float(step_result["reward"])
        all_rewards.append(reward)
        done = bool(step_result["done"])
        info = step_result.get("info", {}) or {}
        next_observation = next_obs_model.model_dump()
        total_reward += reward

        # Guard against an explicit null outcome the same way ``info`` is
        # guarded above; ``.get("outcome", {})`` alone would return None.
        outcome = info.get("outcome") or {}
        status = str(outcome.get("status", "partial")).upper()
        reason = str(outcome.get("reason", "No reason provided"))
        previous_policy_hospital_id = chosen["hospital_id"]
        previous_policy_outcome = status

        print(f"Outcome: {status}")
        print(f"Reason: {reason}")
        print(f"Reward: {reward:.3f}")
        error_val = str(info.get("last_action_error")) if info.get("last_action_error") else "null"
        print(
            f"[STEP] step={observation.get('step')} action={chosen['hospital_id']} "
            f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
            flush=True,
        )

        # The logged state is the observation the decision was made from,
        # not the one returned by this step.
        append_trajectory_log(
            {
                "seed": seed,
                "task": task_id,
                "difficulty": observation.get("scenario_difficulty"),
                "step": observation.get("step"),
                "state": {
                    "patient_condition": observation.get("patient_condition"),
                    "remaining_time_minutes": observation.get("remaining_time_minutes"),
                    "failed_hospitals": observation.get("failed_hospitals", []),
                    "visited_hospitals": observation.get("visited_hospitals", []),
                    "ambulance_status": observation.get("ambulance_status"),
                },
                "action": {
                    "hospital_id": chosen["hospital_id"],
                    "policy_score": chosen["policy_score"],
                    "strategy": strategy,
                },
                "outcome": {
                    "status": status,
                    "reason": reason,
                },
                "reward": reward,
            }
        )

        step_records.append(
            {
                "step": observation.get("step"),
                "hospital_id": chosen["hospital_id"],
                "status": status,
                "reason": reason,
                "reward": reward,
                "policy_score": chosen["policy_score"],
            }
        )

        observation = next_observation

    final_state = env.state()
    final_result = final_state.final_outcome or "FAILURE"
    # NOTE(review): final_outcome is treated as possibly unset above, so the
    # score is given the same treatment instead of crashing in float(None).
    raw_final_score = final_state.final_score
    final_score = clamp_strict_score(raw_final_score if raw_final_score is not None else 0.0)

    print("\nFinal result:")
    print(f" Result: {final_result}")
    print(f" Total steps: {steps}")
    print(f" Final score: {final_score:.3f}")
    print(f" Average reward: {total_reward / max(1, steps):.3f}")
    rewards_str = ",".join(f"{r:.2f}" for r in all_rewards)
    print(
        f"[END] success={str(final_result == 'SUCCESS').lower()} steps={steps} "
        f"score={final_score:.2f} rewards={rewards_str}",
        flush=True,
    )

    return {
        "success": final_result == "SUCCESS",
        "score": final_score,
        "steps": steps,
        "seed": seed,
        "task_id": task_id,
        "scenario_name": observation.get("scenario_name"),
        "scenario_type": observation.get("scenario_type"),
        "difficulty": observation.get("scenario_difficulty"),
        "required_specialization": observation.get("required_specialization"),
        "actions": [record["hospital_id"] for record in step_records],
        "step_records": step_records,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
|
|
|
|
|
def _tally_step_record(step_stats: dict, record: dict) -> None:
    """Add one step record to the per-step, per-hospital counters in place."""
    per_step = step_stats.setdefault(str(record.get("step")), {})
    bucket = per_step.setdefault(
        str(record.get("hospital_id")),
        {
            "count": 0,
            "success": 0,
            "accepted": 0,
            "partial": 0,
            "rejected": 0,
            "total_reward": 0.0,
            "avg_reward": 0.0,
            "last_status": None,
            "last_reason": None,
        },
    )

    bucket["count"] += 1
    status = record["status"]
    if status == "ACCEPTED":
        counters = ("success", "accepted")
    elif status == "PARTIAL":
        counters = ("partial",)
    else:
        # Every status other than ACCEPTED/PARTIAL is counted as a rejection.
        counters = ("rejected",)
    for counter in counters:
        bucket[counter] += 1

    bucket["total_reward"] = float(bucket["total_reward"]) + float(record["reward"])
    bucket["avg_reward"] = bucket["total_reward"] / max(1, bucket["count"])
    bucket["last_status"] = record["status"]
    bucket["last_reason"] = record["reason"]
    bucket["success_rate"] = bucket["accepted"] / max(1, bucket["count"])


def update_learning_archive(archive: dict, episode_result: dict) -> None:
    """Fold one finished episode into the learning archive, in place.

    Updates the profile stored under the episode's seed/task key (attempt
    count, last-run fields, best-run fields and per-step hospital statistics)
    and appends a compact entry to ``archive["episodes"]``, which is trimmed
    to its 500 most recent entries.

    Args:
        archive: Archive dict; ``profiles`` and ``episodes`` are created when
            missing.
        episode_result: Episode summary dict; ``seed``, ``task_id``, ``score``,
            ``success``, ``timestamp``, ``difficulty`` and ``scenario_name``
            are required keys.
    """
    key = profile_key(int(episode_result["seed"]), str(episode_result["task_id"]))
    profiles = archive.setdefault("profiles", {})
    if key in profiles:
        profile = profiles[key]
    else:
        profile = {
            "attempts": 0,
            "best_score": 0.0,
            "best_actions": [],
            "best_steps": 0,
            "step_stats": {},
        }

    profile["attempts"] = int(profile.get("attempts", 0)) + 1

    # Most recent run.
    score = profile["last_score"] = float(episode_result["score"])
    succeeded = profile["last_success"] = bool(episode_result["success"])
    profile["last_run_at"] = episode_result["timestamp"]
    profile["last_actions"] = list(episode_result.get("actions", []))
    profile["last_required_specialization"] = episode_result.get("required_specialization")
    profile["last_scenario_type"] = episode_result.get("scenario_type")
    profile["last_scenario_name"] = episode_result.get("scenario_name")

    # Best run so far; a tie replaces the stored best with this episode.
    if not score < float(profile.get("best_score", 0.0)):
        profile.update(
            {
                "best_score": score,
                "best_actions": list(episode_result.get("actions", [])),
                "best_steps": int(episode_result.get("steps", 0)),
                "best_success": succeeded,
                "best_scenario_name": episode_result.get("scenario_name"),
                "best_difficulty": episode_result.get("difficulty"),
                "best_required_specialization": episode_result.get("required_specialization"),
            }
        )

    step_stats = profile.setdefault("step_stats", {})
    for record in episode_result.get("step_records", []):
        _tally_step_record(step_stats, record)

    profiles[key] = profile

    history_entry = {
        "seed": episode_result["seed"],
        "task_id": episode_result["task_id"],
        "difficulty": episode_result["difficulty"],
        "required_specialization": episode_result.get("required_specialization"),
        "scenario_name": episode_result["scenario_name"],
        "score": episode_result["score"],
        "success": episode_result["success"],
        "actions": episode_result.get("actions", []),
        "timestamp": episode_result["timestamp"],
    }
    history = archive.setdefault("episodes", [])
    history.append(history_entry)
    archive["episodes"] = history[-500:]
|
|
|
|
|
def print_training_summary(results: list[dict]) -> None:
    """Print aggregate statistics for a training run.

    Reports episode count, success rate, mean score, and the mean score of
    the first half of the run versus the rest. Prints nothing when
    ``results`` is empty.

    Args:
        results: Episode summaries; each needs a ``score`` and a ``success``
            entry.
    """
    if not results:
        return

    episode_count = len(results)
    scores = [float(entry["score"]) for entry in results]
    success_count = sum(1 for entry in results if entry["success"])

    # The first half is "early"; with a single episode the tail slice is
    # empty, so that episode is counted as "late" as well.
    half = max(1, len(scores) // 2)
    early = scores[:half]
    late = scores[half:] or scores[-half:]

    early_mean = sum(early) / len(early)
    late_mean = sum(late) / len(late)
    trend = late_mean - early_mean

    report = [
        "\nTraining summary:",
        f" Episodes: {episode_count}",
        f" Success rate: {success_count / episode_count:.1%}",
        f" Average score: {sum(scores) / len(scores):.3f}",
        f" Early avg score ({len(early)} eps): {early_mean:.3f}",
        f" Late avg score ({len(late)} eps): {late_mean:.3f}",
        f" Trend delta (late-early): {trend:+.3f}",
    ]
    for line in report:
        print(line)
|
|
|
|
|
def main() -> None:
    """Parse CLI arguments, run the requested episodes and print a summary.

    In ``full`` mode every task in ``TASK_ORDER`` is run per episode;
    otherwise a single task is run, taken from ``--task`` or derived from the
    chosen level. When ``--train-episodes`` is positive the run is treated as
    training: that value is the episode count and a training summary is
    printed instead of the batch summary.

    The learning archive is saved even when an episode raises, so learning
    from episodes that already finished is not lost on an interrupted run.
    """
    args = parse_args()
    llm_client, model_name = require_llm_config()
    seed = ask_seed_if_missing(args.seed)
    print(f"Using seed: {seed}")

    if args.mode == "full":
        tasks = TASK_ORDER
    else:
        chosen_task = args.task
        if chosen_task is None:
            chosen_level = ask_level_if_missing(args.level)
            chosen_task = LEVEL_TO_TASK[chosen_level]
        tasks = [chosen_task]

    env = EmergencyEnv(memory_file=args.memory_file)
    archive = load_learning_archive()

    training_mode = args.train_episodes > 0
    run_count = args.train_episodes if training_mode else args.episodes
    # Only training runs may pin every episode to the same seed.
    reuse_seed = training_mode and args.train_same_seed
    label_prefix = "Training Episode" if training_mode else "Episode"

    results: list[dict] = []
    try:
        for episode in range(run_count):
            for idx, task_id in enumerate(tasks):
                # Otherwise each (episode, task) pair gets its own seed.
                task_seed = seed if reuse_seed else seed + (episode * 100) + idx

                print(f"\n=== {label_prefix} {episode + 1} | {task_id} | seed={task_seed} ===")
                episode_result = run_episode(
                    env,
                    task_id,
                    task_seed,
                    archive=archive,
                    llm_client=llm_client,
                    model_name=model_name,
                )
                results.append(episode_result)
                update_learning_archive(archive, episode_result)
    finally:
        # Persist whatever was learned, including on Ctrl-C or a crash.
        save_learning_archive(archive)

    if training_mode:
        print_training_summary(results)
        return

    if not results:
        return

    print("\nBatch summary:")
    if len(results) == 1:
        only = results[0]
        outcome_label = "SUCCESS" if only["success"] else "FAILURE"
        print(f" Episode outcome: {outcome_label}")
        print(f" Episode score: {only['score']:.3f}")
        print(f" Episode steps: {only['steps']}")
        print(" Note: run 30-50 episodes to estimate difficulty success rate.")
    else:
        print(f" Success rate: {sum(1 for item in results if item['success']) / len(results):.1%}")
        print(f" Average score: {sum(item['score'] for item in results) / len(results):.3f}")
        print(f" Average steps: {sum(item['steps'] for item in results) / len(results):.1f}")
|
|
|
|
|
# Script entry point: run the agent only when this file is executed directly.
if __name__ == "__main__":
    main()
|
|
|