| """ |
Wildfire Containment Simulator — Inference Script
| =================================================== |
| Runs an LLM agent (via OpenAI-compatible client) against all three task tiers |
| and emits structured [START] / [STEP] / [END] logs for automated evaluation. |
| |
| Required environment variables: |
| API_BASE_URL LLM endpoint (default: https://router.huggingface.co/v1) |
| MODEL_NAME Model identifier (default: Qwen/Qwen2.5-72B-Instruct) |
| HF_TOKEN HuggingFace / API key |
| |
| Optional: |
| TASK_NAME Run a single task: easy | medium | hard (default: all three) |
| """ |
|
|
| import json |
| import os |
| import textwrap |
| from typing import List, Optional |
|
|
| from openai import OpenAI |
|
|
| from env import WildfireEnv, Action, ActionType |
| from env.models import Observation |
|
|
| |
# API credentials/endpoint; HF_TOKEN takes precedence over the generic API_KEY.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
# OpenAI-compatible endpoint base URL.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
# Model identifier passed to chat.completions.create.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Task tiers executed by default; the TASK_NAME env var can select a single one.
TASKS = ["easy", "medium", "hard"]
SEED = 42  # fixed seed passed to env.reset for reproducible episodes
SUCCESS_THRESHOLD = 0.5  # an episode counts as a success when score >= this
TEMPERATURE = 0.2  # low temperature keeps the model's JSON output stable
MAX_TOKENS = 120  # completion cap; a single action JSON is short
|
|
| |
|
|
def log_start(task: str, model: str) -> None:
    """Print the [START] marker line that opens an episode's structured log."""
    line = "[START] task={} env=wildfire-containment-simulator model={}".format(task, model)
    print(line, flush=True)
|
|
|
|
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] marker line; a missing/empty error renders as 'null'."""
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
|
|
|
|
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing [END] marker with all per-step rewards comma-joined."""
    joined = ",".join("{:.2f}".format(r) for r in rewards)
    head = f"[END] success={str(success).lower()} steps={steps}"
    tail = f"score={score:.2f} rewards={joined}"
    print(head + " " + tail, flush=True)
|
|
|
|
| |
|
|
# Static system prompt: declares the JSON action schema the agent must emit and
# high-level strategy guidance. dedent().strip() removes the literal's common
# indentation and the surrounding blank lines.
SYSTEM_PROMPT = textwrap.dedent("""
    You are an AI wildfire incident commander. Each step issue exactly ONE action as JSON.

    Action types and required fields:
    deploy_crew : {"action_type":"deploy_crew","crew_id":"crew_N","target_row":R,"target_col":C}
    move_crew : {"action_type":"move_crew","crew_id":"crew_N","direction":"N|S|E|W|NE|NW|SE|SW"}
    drop_retardant : {"action_type":"drop_retardant","tanker_id":"tanker_N","target_row":R,"target_col":C}
    build_firebreak: {"action_type":"build_firebreak","crew_id":"crew_N","direction":"N|S|E|W|NE|NW|SE|SW"}
    recon_flight : {"action_type":"recon_flight","target_row":R,"target_col":C}
    idle : {"action_type":"idle","reason":"..."}

    Strategy:
    - DEPLOY undeployed crews first (deploy_crew) before any other crew action.
    - MOVE crews toward fire to suppress it.
    - BUILD firebreaks between fire and populated zones.
    - DROP retardant on high-intensity clusters near populated cells.
    - Output ONLY raw JSON. No explanation, no markdown, no code fences.
    """).strip()
|
|
|
|
def build_user_prompt(obs: Observation, step: int, history: List[str]) -> str:
    """Render the current observation into the per-step user prompt.

    Summarizes fire spread, populated cells still safe, crew/tanker status,
    weather, resource budgets and the recent action history. Lists are
    truncated (12 burning cells, 8 populated cells, last 4 history lines)
    to bound prompt size.

    Fix: the wind-direction unit was a mojibake sequence ("Β°"); it is now a
    proper degree sign ("°").
    """
    stats = obs.stats
    weather = obs.weather
    res = obs.resources

    # Up to 12 actively burning or ember cells as "(row,col,intensity)" triples.
    burning = [
        f"({cell.row},{cell.col},{cell.intensity_bin.value})"
        for row in obs.grid for cell in row
        if cell.fire_state.value in ("burning", "ember")
    ][:12]

    # Up to 8 populated cells not currently burning or burned out — the zones
    # the agent should prioritize defending.
    populated_safe = [
        f"({cell.row},{cell.col})"
        for row in obs.grid for cell in row
        if cell.is_populated and cell.fire_state.value not in ("burned_out", "burning")
    ][:8]

    crews = [f"{c.crew_id}@({c.row},{c.col}) deployed={c.is_deployed} active={c.is_active}"
             for c in res.crews]
    tankers = [f"{t.tanker_id} cooldown={t.cooldown_remaining} active={t.is_active}"
               for t in res.tankers]

    history_block = "\n".join(history[-4:]) if history else "none"

    # NOTE(review): dedent runs on the *formatted* string, so once
    # history_block is multi-line its unindented lines defeat dedent and the
    # template keeps its indentation — consider dedenting before formatting.
    return textwrap.dedent(f"""
        Step {step} / {stats.max_steps}
        Fire: {stats.cells_burning} burning, {stats.cells_burned} burned out
        Population lost: {stats.population_lost} | Containment: {stats.containment_pct:.1f}%
        Weather: {weather.wind_speed_kmh:.0f} km/h @ {weather.wind_direction_deg:.0f}° | humidity {weather.humidity_pct:.0f}% | rain={weather.rain_active}

        Burning cells (row,col,intensity): {burning}
        Safe populated cells: {populated_safe}

        Crews: {crews}
        Tankers: {tankers}
        Firebreak budget: {res.firebreak_budget} | Recon budget: {res.recon_budget}

        Recent events: {obs.recent_events}
        Last actions:
        {history_block}

        Output your next action as JSON:
        """).strip()
|
|
|
|
| |
|
|
def _compact_action(action: Action) -> str:
    """Short human-readable string for [STEP] log."""
    kind = action.action_type.value
    if kind in ("deploy_crew", "drop_retardant"):
        # Both target a cell; they differ only in which unit performs the drop.
        actor = action.crew_id if kind == "deploy_crew" else action.tanker_id
        return f"{kind}({actor},{action.target_row},{action.target_col})"
    if kind in ("move_crew", "build_firebreak"):
        return f"{kind}({action.crew_id},{action.direction.value})"
    if kind == "recon_flight":
        return f"recon_flight({action.target_row},{action.target_col})"
    # idle and any unrecognized type both render as idle(reason).
    return f"idle({action.reason or ''})"
|
|
|
|
def get_llm_action(
    client: OpenAI,
    obs: Observation,
    step: int,
    history: List[str],
) -> tuple[Action, str, Optional[str]]:
    """Query the LLM for one action and parse its JSON reply.

    Returns (action, compact_label, error). On any failure — API error,
    unparseable JSON, invalid action fields — degrades to an IDLE action
    and reports the truncated error message.
    """
    prompt = build_user_prompt(obs, step, history)

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()

        # Strip a markdown code fence if the model ignored the "raw JSON" rule.
        if "```" in text:
            pieces = text.split("```")
            if len(pieces) > 1:
                text = pieces[1]
            if text.lower().startswith("json"):
                text = text[4:].strip()

        parsed = Action(**json.loads(text))
        return parsed, _compact_action(parsed), None

    except Exception as exc:  # broad by design: any failure degrades to IDLE
        fallback = Action(action_type=ActionType.IDLE, reason="llm_parse_error")
        return fallback, "idle(llm_parse_error)", str(exc)[:80]
|
|
|
|
| |
|
|
def run_task(client: OpenAI, task_id: str, seed: int) -> float:
    """Play one full episode of `task_id` and return its final score in [0, 1].

    Emits [START]/[STEP]/[END] log markers. The score is the reward of the
    final step clamped to [0, 1]; success means score >= SUCCESS_THRESHOLD.
    The [END] line is always printed, even when the episode crashes.
    """
    env = WildfireEnv()
    obs = env.reset(task_id=task_id, seed=seed)

    rewards: List[float] = []
    history: List[str] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, model=MODEL_NAME)

    try:
        step_no = 0
        while not env.done:
            step_no += 1
            action, action_str, error = get_llm_action(client, obs, step_no, history)

            outcome = env.step(action)
            obs = outcome.observation
            reward = outcome.reward
            steps_taken = step_no

            rewards.append(reward)
            log_step(step=step_no, action=action_str, reward=reward,
                     done=outcome.done, error=error)
            history.append(f"Step {step_no}: {action_str} -> reward {reward:.2f}")

        # Final score = last reward clamped into [0, 1].
        if rewards:
            score = min(1.0, max(0.0, rewards[-1]))
        success = score >= SUCCESS_THRESHOLD

    except Exception as exc:  # keep the [END] marker even on a crashed episode
        print(f"[DEBUG] Episode error: {str(exc)[:120]}", flush=True)

    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return score
|
|
|
|
| |
|
|
def main() -> None:
    """Entry point: run every selected task tier and print a summary line."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # TASK_NAME limits the run to a single tier; otherwise run all of TASKS.
    chosen = os.getenv("TASK_NAME")
    task_list = [chosen] if chosen else TASKS

    scores = {tid: run_task(client, tid, seed=SEED) for tid in task_list}

    parts = [f"{tid}={val:.3f}" for tid, val in scores.items()]
    print("\n[SUMMARY] " + " | ".join(parts), flush=True)


if __name__ == "__main__":
    main()
|
|