{
"cells": [
{
"cell_type": "markdown",
"id": "a1b2c3d4",
"metadata": {},
"source": [
"# Pyre — HuggingFace Baseline Evaluation\n",
"\n",
"Evaluates an **un-fine-tuned** (or fine-tuned) HuggingFace model across all 3 Pyre difficulty levels\n",
"using the same episode-running logic as `evals.py`.\n",
"\n",
"**Workflow:**\n",
"1. Edit **Cell 2 (Config)** with your model path, server URL, and eval settings\n",
"2. Run all cells top-to-bottom (`Run All`)\n",
"3. To change any parameter — update Cell 2, re-run Cell 2 and Cell 3 (they derive the seed and difficulty lists), then re-run from Cell 7 onward"
]
},
{
"cell_type": "markdown",
"id": "b2c3d4e5",
"metadata": {},
"source": [
"## Cell 1 — Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c3d4e5f6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Imports OK\n"
]
}
],
"source": [
"import csv\n",
"import json\n",
"import logging\n",
"import re\n",
"import textwrap\n",
"from collections import defaultdict\n",
"from datetime import datetime\n",
"from pathlib import Path\n",
"from typing import Any, Dict, List, Optional, Tuple\n",
"\n",
"import requests\n",
"from langchain_core.language_models.chat_models import SimpleChatModel\n",
"from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage\n",
"from pydantic import PrivateAttr\n",
"\n",
"from pyre_env import PyreEnv, PyreAction\n",
"\n",
"try:\n",
" from dotenv import load_dotenv\n",
" load_dotenv()\n",
"except ImportError:\n",
" pass\n",
"\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" format=\"%(asctime)s [%(levelname)s] %(message)s\",\n",
" datefmt=\"%H:%M:%S\",\n",
")\n",
"log = logging.getLogger(\"evals_hf\")\n",
"print(\"Imports OK\")"
]
},
{
"cell_type": "markdown",
"id": "1f1d9271",
"metadata": {},
"source": [
"## Cell 2 — Config ✏️ Edit this cell to change eval settings"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e5f6a7b8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model : Qwen/Qwen3-1.7B\n",
"4-bit : False\n",
"Temperature: 0.3\n",
"Levels : all 3 (easy, medium, hard)\n",
"Seeds : [1, 2, 3]\n",
"Output dir : ./outputs/hf_evals\n"
]
}
],
"source": [
"# ── Model ─────────────────────────────────────────────────────────────────────\n",
"MODEL_ID = \"Qwen/Qwen3-1.7B\"  # HF model ID or local path / adapter dir\n",
"LOAD_4BIT = False  # True → bitsandbytes 4-bit quant (low VRAM)\n",
"TEMPERATURE = 0.3  # 0 = greedy decoding\n",
"MAX_NEW_TOKENS = 512  # max tokens per model response\n",
"\n",
"# ── Environment ───────────────────────────────────────────────────────────────\n",
"ENV_URL = \"http://localhost:8000\"  # running Pyre server\n",
"\n",
"# ── Difficulties ──────────────────────────────────────────────────────────────\n",
"# Set to None to run all 3 levels, or list specific ones:\n",
"# [\"easy\", \"medium\", \"hard\"]\n",
"DIFFICULTIES_TO_RUN = None\n",
"\n",
"# ── Seeds ─────────────────────────────────────────────────────────────────────\n",
"NUM_SEEDS = 3  # episodes per difficulty level\n",
"SEED_START = 1  # first seed value\n",
"\n",
"# ── Output ────────────────────────────────────────────────────────────────────\n",
"OUTPUT_DIR = \"./outputs/hf_evals\"  # CSV + debug traces saved here\n",
"VERBOSE = False  # True → per-step DEBUG logs\n",
"\n",
"# ── Derived (do not edit) ─────────────────────────────────────────────────────\n",
"# Set the level in both directions: with a one-way `if VERBOSE:` the logger\n",
"# stays at DEBUG after VERBOSE is flipped back to False and this cell is re-run.\n",
"log.setLevel(logging.DEBUG if VERBOSE else logging.INFO)\n",
"\n",
"# Fail here rather than silently running an empty or ill-configured evaluation.\n",
"if NUM_SEEDS < 1:\n",
"    raise ValueError(f\"NUM_SEEDS must be >= 1, got {NUM_SEEDS}\")\n",
"if TEMPERATURE < 0:\n",
"    raise ValueError(f\"TEMPERATURE must be >= 0, got {TEMPERATURE}\")\n",
"if MAX_NEW_TOKENS < 1:\n",
"    raise ValueError(f\"MAX_NEW_TOKENS must be >= 1, got {MAX_NEW_TOKENS}\")\n",
"\n",
"SEEDS = list(range(SEED_START, SEED_START + NUM_SEEDS))\n",
"print(f\"Model : {MODEL_ID}\")\n",
"print(f\"4-bit : {LOAD_4BIT}\")\n",
"print(f\"Temperature: {TEMPERATURE}\")\n",
"print(f\"Levels : {DIFFICULTIES_TO_RUN or 'all 3 (easy, medium, hard)'}\")\n",
"print(f\"Seeds : {SEEDS}\")\n",
"print(f\"Output dir : {OUTPUT_DIR}\")"
]
},
{
"cell_type": "markdown",
"id": "f6a7b8c9",
"metadata": {},
"source": [
"## Cell 3 — Difficulty Registry & System Prompt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a7b8c9d0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Difficulties to evaluate : ['easy', 'medium', 'hard']\n",
"Total episodes : 9\n"
]
}
],
"source": [
"# Registry of difficulty levels: name, per-episode step budget, human-readable summary.\n",
"ALL_DIFFICULTIES: List[Dict[str, Any]] = [\n",
"    {\"difficulty\": \"easy\", \"max_steps\": 200, \"description\": \"1 fire source · slow spread · calm wind · high humidity\"},\n",
"    {\"difficulty\": \"medium\", \"max_steps\": 150, \"description\": \"2–4 fire sources · moderate spread · any wind · moderate humidity\"},\n",
"    {\"difficulty\": \"hard\", \"max_steps\": 100, \"description\": \"3–5 fire sources · fast spread · always windy · low humidity\"},\n",
"]\n",
"\n",
"# Resolve DIFFICULTIES_TO_RUN (Config cell) against the registry.\n",
"known_difficulties = [d[\"difficulty\"] for d in ALL_DIFFICULTIES]\n",
"if DIFFICULTIES_TO_RUN:\n",
"    # A bare string such as \"easy\" is treated as a one-element list; otherwise\n",
"    # the membership test below would do substring matching on the string.\n",
"    requested_difficulties = (\n",
"        [DIFFICULTIES_TO_RUN] if isinstance(DIFFICULTIES_TO_RUN, str) else list(DIFFICULTIES_TO_RUN)\n",
"    )\n",
"    # Reject typos up front instead of silently evaluating fewer levels than asked for.\n",
"    unknown_difficulties = [name for name in requested_difficulties if name not in known_difficulties]\n",
"    if unknown_difficulties:\n",
"        raise ValueError(\n",
"            f\"Unknown difficulties {unknown_difficulties}. Available: {known_difficulties}\"\n",
"        )\n",
"    difficulties_to_run = [d for d in ALL_DIFFICULTIES if d[\"difficulty\"] in requested_difficulties]\n",
"else:\n",
"    difficulties_to_run = ALL_DIFFICULTIES\n",
"\n",
"# System prompt sent with every step. The \u003cthink\u003e...\u003c/think\u003e markup must match\n",
"# what the action parser and the think-rate metric look for.\n",
"SYSTEM_PROMPT = textwrap.dedent(\"\"\"\n",
"You are an agent trapped inside a burning building.\n",
"Your goal is to navigate to an EXIT before your health reaches zero or time runs out.\n",
"\n",
"ENVIRONMENT RULES\n",
"- You have partial vision: you cannot see through walls or dense smoke.\n",
"- Fire and smoke spread each step — do NOT linger in hazardous areas.\n",
"- Closing a door adjacent to active fire slows its spread (strategic move).\n",
"- Your health drains faster in moderate/heavy smoke and on fire cells.\n",
"- Exits may be BLOCKED if fire burns directly on them — check available hints.\n",
"\n",
"OUTPUT FORMAT (STRICT)\n",
"You MUST reason inside \u003cthink\u003e...\u003c/think\u003e tags first, then emit EXACTLY ONE JSON object.\n",
"Output NOTHING else — no extra text, no markdown fences, no second JSON block.\n",
"\n",
"\u003cthink\u003e\n",
"Brief reasoning: what can I see, where is danger, what is the safest next move?\n",
"\u003c/think\u003e\n",
"{\"action\": \"move\", \"direction\": \"north\"}\n",
"\n",
"AVAILABLE ACTIONS\n",
"- move : {\"action\": \"move\", \"direction\": \"north|south|east|west\"}\n",
"- look : {\"action\": \"look\", \"direction\": \"north|south|east|west\"} ← scan 5 cells ahead\n",
"- door : {\"action\": \"door\", \"target_id\": \"door_X\", \"door_state\": \"open|close\"}\n",
"- wait : {\"action\": \"wait\"}\n",
"\n",
"REWARD SIGNAL (shown in history after each step)\n",
"- Positive reward → you moved closer to an exit or played a smart move.\n",
"- Negative reward → you moved into danger, stalled, or wasted a step.\n",
"- Use the reward trend to judge if your current direction is working.\n",
"\n",
"STRATEGY TIPS\n",
"- Use `look` to scout a direction before entering an unknown corridor.\n",
"- Closing a door between you and fire buys time; re-open when clear.\n",
"- Prefer moves that increase reward — progress toward the exit is rewarded.\n",
"- If smoke is heavy, back away; your health drains fast in thick smoke.\n",
"- Door IDs (e.g. door_3) appear in the Visible objects list — use them with the door action.\n",
"\"\"\").strip()\n",
"\n",
"print(f\"Difficulties to evaluate : {[d['difficulty'] for d in difficulties_to_run]}\")\n",
"print(f\"Total episodes : {len(difficulties_to_run) * len(SEEDS)}\")"
]
},
{
"cell_type": "markdown",
"id": "b8c9d0e1",
"metadata": {},
"source": [
"## Cell 4 — HuggingFace Chat Model Wrapper"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c9d0e1f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HFChatModel class defined.\n"
]
}
],
"source": [
"class HFChatModel(SimpleChatModel):\n",
"    \"\"\"Locally-loaded HuggingFace causal LM exposed as a LangChain chat model.\n",
"\n",
"    Construct with ``model_id`` (plus optional ``temperature`` and\n",
"    ``max_new_tokens``), then call :meth:`load` once before invoking. Prompts\n",
"    are rendered with the tokenizer's chat template; ``enable_thinking`` is\n",
"    requested first (Qwen3) with a fallback for templates that reject it.\n",
"    \"\"\"\n",
"\n",
"    model_id: str\n",
"    temperature: float = 0.3  # 0 → greedy decoding\n",
"    max_new_tokens: int = 512\n",
"\n",
"    # Both are populated by load(); None means the model has not been loaded yet.\n",
"    _model: Any = PrivateAttr(default=None)\n",
"    _tokenizer: Any = PrivateAttr(default=None)\n",
"\n",
"    def load(self, load_4bit: bool = False) -> \"HFChatModel\":\n",
"        \"\"\"Load tokenizer and model weights. Call once after construction.\n",
"\n",
"        Args:\n",
"            load_4bit: If True, load with bitsandbytes 4-bit quantisation;\n",
"                otherwise load the weights in bfloat16.\n",
"\n",
"        Returns:\n",
"            ``self``, so construction and loading can be chained.\n",
"        \"\"\"\n",
"        import torch\n",
"        from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"        # NOTE: trust_remote_code=True runs Python shipped with the model repo;\n",
"        # only point MODEL_ID at sources you trust.\n",
"        log.info(\"Loading tokenizer from %s …\", self.model_id)\n",
"        tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)\n",
"        if tokenizer.pad_token is None:\n",
"            tokenizer.pad_token = tokenizer.eos_token\n",
"        self._tokenizer = tokenizer\n",
"\n",
"        load_kwargs: Dict[str, Any] = {\"device_map\": \"auto\", \"trust_remote_code\": True}\n",
"        if load_4bit:\n",
"            # Passing load_in_4bit straight to from_pretrained is deprecated;\n",
"            # quantisation settings go through a BitsAndBytesConfig instead.\n",
"            from transformers import BitsAndBytesConfig\n",
"\n",
"            load_kwargs[\"quantization_config\"] = BitsAndBytesConfig(load_in_4bit=True)\n",
"        else:\n",
"            load_kwargs[\"torch_dtype\"] = torch.bfloat16\n",
"\n",
"        log.info(\"Loading model from %s (4bit=%s) …\", self.model_id, load_4bit)\n",
"        model = AutoModelForCausalLM.from_pretrained(self.model_id, **load_kwargs)\n",
"        model.eval()\n",
"        self._model = model\n",
"        log.info(\"Model loaded. Device map: %s\", getattr(model, \"hf_device_map\", \"auto\"))\n",
"        return self\n",
"\n",
"    def _call(\n",
"        self,\n",
"        messages: List[BaseMessage],\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager=None,\n",
"        **kwargs,\n",
"    ) -> str:\n",
"        \"\"\"Generate one completion for ``messages`` and return the new text only.\n",
"\n",
"        Args:\n",
"            messages: System/human/AI messages; other message types are skipped\n",
"                with a warning.\n",
"            stop: Optional stop sequences; the completion is cut at the earliest\n",
"                occurrence of any of them.\n",
"            run_manager: Accepted for LangChain compatibility; unused.\n",
"            **kwargs: Accepted for LangChain compatibility; unused.\n",
"\n",
"        Raises:\n",
"            RuntimeError: If :meth:`load` has not been called yet.\n",
"        \"\"\"\n",
"        import torch\n",
"\n",
"        if self._model is None or self._tokenizer is None:\n",
"            raise RuntimeError(\"HFChatModel.load() must be called before invoking the model.\")\n",
"\n",
"        # Convert LangChain messages → HF chat-template dicts\n",
"        role_by_type = (\n",
"            (SystemMessage, \"system\"),\n",
"            (HumanMessage, \"user\"),\n",
"            (AIMessage, \"assistant\"),\n",
"        )\n",
"        conversation: List[Dict[str, str]] = []\n",
"        for msg in messages:\n",
"            role = next((name for cls, name in role_by_type if isinstance(msg, cls)), None)\n",
"            if role is None:\n",
"                log.warning(\"Skipping unsupported message type: %s\", type(msg).__name__)\n",
"                continue\n",
"            conversation.append({\"role\": role, \"content\": msg.content})\n",
"\n",
"        # Apply chat template — try with enable_thinking first (Qwen3), fall back for others\n",
"        template_kwargs = dict(tokenize=False, add_generation_prompt=True)\n",
"        try:\n",
"            prompt_text = self._tokenizer.apply_chat_template(\n",
"                conversation, enable_thinking=True, **template_kwargs\n",
"            )\n",
"        except TypeError:\n",
"            prompt_text = self._tokenizer.apply_chat_template(conversation, **template_kwargs)\n",
"\n",
"        # The rendered template already contains its special tokens; adding them\n",
"        # again here would duplicate e.g. BOS on tokenizers that prepend one.\n",
"        inputs = self._tokenizer(\n",
"            prompt_text, return_tensors=\"pt\", add_special_tokens=False\n",
"        ).to(self._model.device)\n",
"        in_len = inputs[\"input_ids\"].shape[1]\n",
"\n",
"        gen_kwargs: Dict[str, Any] = dict(\n",
"            max_new_tokens=self.max_new_tokens,\n",
"            pad_token_id=self._tokenizer.eos_token_id,\n",
"            eos_token_id=self._tokenizer.eos_token_id,\n",
"        )\n",
"        if self.temperature > 0:\n",
"            gen_kwargs.update(do_sample=True, temperature=self.temperature)\n",
"        else:\n",
"            gen_kwargs[\"do_sample\"] = False\n",
"\n",
"        with torch.inference_mode():\n",
"            output_ids = self._model.generate(**inputs, **gen_kwargs)\n",
"\n",
"        # Drop the prompt tokens so only the newly generated text is returned.\n",
"        text = self._tokenizer.decode(output_ids[0][in_len:], skip_special_tokens=True)\n",
"\n",
"        # Honour the `stop` argument: cut at the earliest stop sequence, if any.\n",
"        cuts = [pos for pos in (text.find(s) for s in (stop or []) if s) if pos != -1]\n",
"        return text[: min(cuts)] if cuts else text\n",
"\n",
"    @property\n",
"    def _llm_type(self) -> str:\n",
"        \"\"\"Identifier LangChain uses for this model type.\"\"\"\n",
"        return \"hf-chat-model\"\n",
"\n",
"print(\"HFChatModel class defined.\")"
]
},
{
"cell_type": "markdown",
"id": "d0e1f2a3",
"metadata": {},
"source": [
"## Cell 5 — Prompt Builder & Action Parser"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e1f2a3b4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt builder and action parser defined.\n"
]
}
],
"source": [
"def _build_user_message(\n",
"    obs: Dict[str, Any],\n",
"    history: List[str],\n",
"    max_history: int = 8,\n",
") -> str:\n",
"    \"\"\"Render one observation plus recent history as the LLM user message.\n",
"\n",
"    Args:\n",
"        obs: Observation as a plain dict. Keys read here: narrative,\n",
"            agent_health, health_status, location_label, smoke_level,\n",
"            fire_visible, fire_direction, wind_dir, elapsed_steps,\n",
"            blocked_exit_ids, visible_objects, audible_signals. Missing keys\n",
"            fall back to placeholders.\n",
"        history: One formatted entry per past step, oldest first.\n",
"        max_history: Number of most recent history entries to include, to stay\n",
"            within context limits.\n",
"\n",
"    Returns:\n",
"        Prompt text: observation, status block, optional recent history, and\n",
"        the response-format reminder.\n",
"    \"\"\"\n",
"    narrative = obs.get(\"narrative\")\n",
"    narrative = \"(no narrative)\" if narrative is None else str(narrative)\n",
"    # Strip the \"Available actions:\" line — the system prompt already covers this.\n",
"    narrative = re.sub(r\"\\nAvailable actions:.*$\", \"\", narrative, flags=re.MULTILINE)\n",
"\n",
"    # Only real numbers take the \":.1f\" format; a missing health shows as \"?\".\n",
"    health = obs.get(\"agent_health\")\n",
"    health_txt = f\"{health:.1f}\" if isinstance(health, (int, float)) else \"?\"\n",
"    health_st = obs.get(\"health_status\", \"?\")\n",
"    location = obs.get(\"location_label\", \"?\")\n",
"    smoke = obs.get(\"smoke_level\", \"none\")\n",
"    fire_vis = obs.get(\"fire_visible\", False)\n",
"    fire_dir = obs.get(\"fire_direction\") or \"none\"\n",
"    wind = obs.get(\"wind_dir\", \"CALM\")\n",
"    elapsed = obs.get(\"elapsed_steps\", 0)\n",
"    blocked_exits = obs.get(\"blocked_exit_ids\") or []\n",
"    visible_objs = obs.get(\"visible_objects\") or []\n",
"    audible = obs.get(\"audible_signals\") or []\n",
"\n",
"    status_line = (\n",
"        f\"Health: {health_txt} | Status: {health_st} | Location: {location}\\n\"\n",
"        f\"Smoke: {smoke} | Fire visible: {fire_vis}\"\n",
"        + (f\" (direction: {fire_dir})\" if fire_vis else \"\")\n",
"        + f\"\\nWind: {wind} | Steps elapsed: {elapsed}\"\n",
"    )\n",
"    if blocked_exits:\n",
"        status_line += f\"\\nBLOCKED exits (fire on them): {', '.join(map(str, blocked_exits))}\"\n",
"    if visible_objs:\n",
"        obj_strs = [\n",
"            f\"{o.get('type','?')} '{o.get('id','?')}' {o.get('relative_pos','?')}\"\n",
"            + (f\" [{o.get('state','')}]\" if o.get(\"state\") else \"\")\n",
"            for o in visible_objs\n",
"        ]\n",
"        status_line += f\"\\nVisible objects: {'; '.join(obj_strs)}\"\n",
"    if audible:\n",
"        status_line += f\"\\nSounds: {'; '.join(map(str, audible))}\"\n",
"\n",
"    history_str = \"\"\n",
"    # Guard max_history <= 0 explicitly: history[-0:] would return the whole list.\n",
"    recent = history[-max_history:] if max_history > 0 else []\n",
"    if recent:\n",
"        history_str = (\n",
"            \"=== RECENT ACTION HISTORY (action → feedback → reward → health) ===\\n\"\n",
"            + \"\\n\".join(recent) + \"\\n\\n\"\n",
"        )\n",
"\n",
"    return (\n",
"        f\"=== CURRENT OBSERVATION ===\\n{narrative}\\n\\n\"\n",
"        f\"=== STATUS ===\\n{status_line}\\n\\n\"\n",
"        + history_str\n",
"        + \"What is your next action? Respond with \u003cthink\u003e...\u003c/think\u003e then a single JSON action.\"\n",
"    )\n",
"\n",
"\n",
"_VALID_ACTIONS = {\"move\", \"door\", \"look\", \"wait\"}\n",
"_VALID_DIRECTIONS = {\"north\", \"south\", \"east\", \"west\"}\n",
"_VALID_DOOR_STATES = {\"open\", \"close\"}\n",
"_FALLBACK_ACTION = {\"action\": \"wait\"}\n",
"\n",
"# Reasoning-block delimiters the model is asked to use.\n",
"_THINK_OPEN = \"\u003cthink\u003e\"\n",
"_THINK_CLOSE = \"\u003c/think\u003e\"\n",
"# Innermost {...} object (no nested braces).\n",
"_JSON_OBJECT_RE = re.compile(r\"\\{[^{}]+\\}\")\n",
"# Full door identifier as shown to the model, e.g. door_3.\n",
"_DOOR_ID_RE = re.compile(r\"\\bdoor_\\w+\", re.IGNORECASE)\n",
"\n",
"\n",
"def _validate_pyre_action(blob: Dict[str, Any]) -> Optional[Dict[str, Any]]:\n",
"    \"\"\"Return a sanitised action dict, or None if ``blob`` is unusable.\n",
"\n",
"    String fields are stripped and lower-cased. ``move``/``look`` need a valid\n",
"    ``direction``; ``door`` needs ``target_id`` (or ``target``) and a valid\n",
"    ``door_state``; ``wait`` needs nothing else.\n",
"    \"\"\"\n",
"    # str() first: a model may emit a non-string action (null, number, list),\n",
"    # which must be rejected here rather than raise.\n",
"    action = str(blob.get(\"action\", \"\")).strip().lower()\n",
"    if action not in _VALID_ACTIONS:\n",
"        return None\n",
"\n",
"    out: Dict[str, Any] = {\"action\": action}\n",
"\n",
"    if action in (\"move\", \"look\"):\n",
"        direction = str(blob.get(\"direction\", \"\")).strip().lower()\n",
"        if direction not in _VALID_DIRECTIONS:\n",
"            return None\n",
"        out[\"direction\"] = direction\n",
"\n",
"    elif action == \"door\":\n",
"        tid = blob.get(\"target_id\", \"\") or blob.get(\"target\", \"\")\n",
"        ds = str(blob.get(\"door_state\", \"\")).strip().lower()\n",
"        if not tid or ds not in _VALID_DOOR_STATES:\n",
"            return None\n",
"        out[\"target_id\"] = str(tid)\n",
"        out[\"door_state\"] = ds\n",
"\n",
"    return out\n",
"\n",
"\n",
"def _json_action(segment: str) -> Optional[Tuple[Dict[str, Any], bool]]:\n",
"    \"\"\"Find a valid JSON action in ``segment``: (action, parsed_cleanly) or None.\"\"\"\n",
"    # Level 1: everything between the first '{' and the last '}' is one JSON object.\n",
"    start, end = segment.find(\"{\"), segment.rfind(\"}\")\n",
"    if start != -1 and end > start:\n",
"        try:\n",
"            blob = json.loads(segment[start:end + 1])\n",
"        except json.JSONDecodeError:\n",
"            blob = None\n",
"        if isinstance(blob, dict):\n",
"            action = _validate_pyre_action(blob)\n",
"            if action is not None:\n",
"                return action, True\n",
"\n",
"    # Level 2: first innermost {...} that carries a valid \"action\".\n",
"    for match in _JSON_OBJECT_RE.finditer(segment):\n",
"        try:\n",
"            blob = json.loads(match.group())\n",
"        except json.JSONDecodeError:\n",
"            continue\n",
"        if isinstance(blob, dict) and \"action\" in blob:\n",
"            action = _validate_pyre_action(blob)\n",
"            if action is not None:\n",
"                return action, False\n",
"    return None\n",
"\n",
"\n",
"def _keyword_action(segment: str) -> Optional[Dict[str, Any]]:\n",
"    \"\"\"Level 3: recover an action from bare keywords in ``segment``, or None.\"\"\"\n",
"    lower = segment.lower()\n",
"    # Take the earliest mention in the text; iterating the direction set would\n",
"    # make the choice depend on hash ordering when several directions appear.\n",
"    for verb in (\"move\", \"look\"):\n",
"        hit = re.search(verb + r\" (north|south|east|west)\", lower)\n",
"        if hit:\n",
"            return {\"action\": verb, \"direction\": hit.group(1)}\n",
"    door = _DOOR_ID_RE.search(segment)\n",
"    if door:\n",
"        # Keep the whole identifier (door_3, not 3) in its original case.\n",
"        return {\n",
"            \"action\": \"door\",\n",
"            \"target_id\": door.group(0),\n",
"            \"door_state\": \"close\" if \"clos\" in lower else \"open\",\n",
"        }\n",
"    if \"wait\" in lower:\n",
"        return {\"action\": \"wait\"}\n",
"    return None\n",
"\n",
"\n",
"def _parse_pyre_action(text: str) -> Tuple[Dict[str, Any], float]:\n",
"    \"\"\"\n",
"    Extract a Pyre action from raw LLM text.\n",
"\n",
"    Returns (action_dict, format_score) where format_score reflects output quality:\n",
"      1.0 — valid JSON + \u003cthink\u003e...\u003c/think\u003e tags\n",
"      0.7 — valid JSON, no \u003cthink\u003e block\n",
"      0.4 — partial JSON rescued via regex\n",
"      0.1 — action keyword found in raw text (last resort)\n",
"      0.0 — completely unparseable → {\"action\": \"wait\"} fallback\n",
"    \"\"\"\n",
"    has_think = _THINK_OPEN in text and _THINK_CLOSE in text\n",
"\n",
"    # When a reasoning block is present, look at the text after it first so a\n",
"    # JSON fragment quoted inside the reasoning cannot shadow the final answer;\n",
"    # the full text remains the fallback.\n",
"    segments = [text]\n",
"    if has_think:\n",
"        segments.insert(0, text.rsplit(_THINK_CLOSE, 1)[1])\n",
"\n",
"    # Levels 1–2: JSON\n",
"    for segment in segments:\n",
"        found = _json_action(segment)\n",
"        if found is not None:\n",
"            action, parsed_cleanly = found\n",
"            if parsed_cleanly:\n",
"                return action, (1.0 if has_think else 0.7)\n",
"            return action, 0.4\n",
"\n",
"    # Level 3: bare keyword extraction\n",
"    for segment in segments:\n",
"        action = _keyword_action(segment)\n",
"        if action is not None:\n",
"            return action, 0.1\n",
"\n",
"    # Level 4: total parse failure\n",
"    return dict(_FALLBACK_ACTION), 0.0\n",
"\n",
"\n",
"print(\"Prompt builder and action parser defined.\")"
]
},
{
"cell_type": "markdown",
"id": "f2a3b4c5",
"metadata": {},
"source": [
"## Cell 6 — Episode Runner"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a3b4c5d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode runner defined.\n"
]
}
],
"source": [
"def run_episode(\n",
"    llm: HFChatModel,\n",
"    difficulty: str,\n",
"    seed: int,\n",
"    max_steps: int,\n",
"    debug_dir: Optional[Path] = None,\n",
"    env_url: Optional[str] = None,\n",
") -> Dict[str, Any]:\n",
"    \"\"\"Run one full episode using the PyreEnv sync client (WebSocket-based, stateful).\n",
"\n",
"    The episode is self-contained: reset → loop(observe → LLM → act) → done.\n",
"    Uses PyreEnv.sync() so the WebSocket session persists across all steps.\n",
"\n",
"    Args:\n",
"        llm: Loaded chat model used to choose each action.\n",
"        difficulty: Difficulty name passed to ``env.reset``.\n",
"        seed: Seed passed to ``env.reset``.\n",
"        max_steps: Upper bound on LLM/environment steps.\n",
"        debug_dir: If given, the initial server state and the raw LLM trace are\n",
"            written here (best effort; failures are only logged).\n",
"        env_url: Pyre server URL; defaults to the notebook-level ``ENV_URL``.\n",
"\n",
"    Returns:\n",
"        A metrics dict (evacuated, cause_of_end, final_health, rewards, step\n",
"        counts, think/parse rates, ``error=None``). If the episode cannot run\n",
"        at all, a short dict with ``difficulty``, ``seed``, ``error`` and\n",
"        ``evacuated=0`` is returned instead.\n",
"\n",
"    Note:\n",
"        An episode cut short by an LLM or step failure takes a -0.20 penalty\n",
"        and is then classified from the last observation, so it is reported as\n",
"        \"timeout\" unless the agent had already evacuated or died.\n",
"    \"\"\"\n",
"    base_url = env_url or ENV_URL\n",
"    think_open, think_close = \"\u003cthink\u003e\", \"\u003c/think\u003e\"\n",
"    abort_penalty = -0.20  # charged once when the LLM call or the env step raises\n",
"    format_penalty_scale = -0.10  # per-step penalty for a completely unparseable reply\n",
"\n",
"    step_rewards: List[float] = []\n",
"    history: List[str] = []\n",
"    llm_responses_log: List[str] = []\n",
"    fmt_scores: List[float] = []\n",
"    steps_taken = 0\n",
"    think_steps = 0\n",
"    parsed_steps = 0\n",
"    final_health = 0.0\n",
"    agent_evacuated = False\n",
"\n",
"    try:\n",
"        with PyreEnv(base_url=base_url).sync() as env:\n",
"            # ── Reset ─────────────────────────────────────────────────────────\n",
"            result = env.reset(difficulty=difficulty, seed=seed)\n",
"            obs = result.observation  # PyreObservation\n",
"            done = result.done\n",
"\n",
"            if debug_dir is not None:\n",
"                try:\n",
"                    # NOTE(review): fetched over plain HTTP, outside the WebSocket\n",
"                    # session — confirm /state reflects this episode.\n",
"                    state_resp = requests.get(f\"{base_url}/state\", timeout=10)\n",
"                    if state_resp.ok:\n",
"                        debug_dir.mkdir(parents=True, exist_ok=True)\n",
"                        (debug_dir / f\"{difficulty}_seed{seed}_init_state.json\").write_text(\n",
"                            json.dumps(state_resp.json(), indent=2),\n",
"                            encoding=\"utf-8\",\n",
"                        )\n",
"                except Exception as exc:\n",
"                    log.warning(\"Could not save initial state: %s\", exc)\n",
"\n",
"            # ── Episode loop ─────────────────────────────────────────────────\n",
"            for step_idx in range(max_steps):\n",
"                if done:\n",
"                    break\n",
"                step_no = step_idx + 1\n",
"\n",
"                user_msg = _build_user_message(obs.model_dump(), history)\n",
"                messages = [\n",
"                    SystemMessage(content=SYSTEM_PROMPT),\n",
"                    HumanMessage(content=user_msg),\n",
"                ]\n",
"\n",
"                # ── LLM call ─────────────────────────────────────────────────\n",
"                try:\n",
"                    completion_text = llm.invoke(messages).content\n",
"                except Exception as exc:\n",
"                    log.warning(\"LLM call failed at step %d: %s\", step_no, exc)\n",
"                    step_rewards.append(abort_penalty)\n",
"                    break\n",
"                if not isinstance(completion_text, str):\n",
"                    # Message content is not guaranteed to be a plain string.\n",
"                    completion_text = str(completion_text)\n",
"\n",
"                llm_responses_log.append(f\"## Step {step_no}\\n{completion_text}\\n\")\n",
"\n",
"                if think_open in completion_text and think_close in completion_text:\n",
"                    think_steps += 1\n",
"                steps_taken += 1\n",
"\n",
"                # ── Parse action ─────────────────────────────────────────────\n",
"                # A parser error on odd model output costs this step its format\n",
"                # score; it must not abort the whole episode.\n",
"                try:\n",
"                    action_dict, fmt_score = _parse_pyre_action(completion_text)\n",
"                except Exception as exc:\n",
"                    log.warning(\"Action parsing failed at step %d: %s\", step_no, exc)\n",
"                    action_dict, fmt_score = {\"action\": \"wait\"}, 0.0\n",
"                fmt_scores.append(fmt_score)\n",
"\n",
"                if fmt_score > 0.0:\n",
"                    parsed_steps += 1\n",
"                else:\n",
"                    log.debug(\" step=%d UNPARSEABLE — using wait fallback\", step_no)\n",
"\n",
"                log.debug(\n",
"                    \" step=%d fmt=%.1f action=%s\",\n",
"                    step_no, fmt_score, json.dumps(action_dict),\n",
"                )\n",
"\n",
"                # ── Step the environment ──────────────────────────────────────\n",
"                try:\n",
"                    result = env.step(PyreAction(**action_dict))\n",
"                    obs = result.observation\n",
"                    step_reward = float(result.reward or 0.0)\n",
"                    done = result.done\n",
"                except Exception as exc:\n",
"                    log.warning(\"Step failed at step %d: %s\", step_no, exc)\n",
"                    step_rewards.append(abort_penalty)\n",
"                    break\n",
"\n",
"                # Format penalty: imperfect parse loses up to -0.10\n",
"                fmt_penalty = (1.0 - fmt_score) * format_penalty_scale\n",
"                step_rewards.append(step_reward + fmt_penalty)\n",
"\n",
"                # History shows the raw environment reward, without the format penalty.\n",
"                feedback = obs.last_action_feedback or \"\"\n",
"                history.append(\n",
"                    f\"Step {step_no}: {json.dumps(action_dict)}\"\n",
"                    + (f\"\\n → {feedback}\" if feedback else \"\")\n",
"                    + f\"\\n reward: {step_reward:+.3f} health: {obs.agent_health:.1f}\"\n",
"                )\n",
"\n",
"            # ── Final state ───────────────────────────────────────────────────\n",
"            agent_evacuated = obs.agent_evacuated\n",
"            final_health = float(obs.agent_health)\n",
"\n",
"    except Exception as exc:\n",
"        log.error(\"Episode failed (difficulty=%s seed=%d): %s\", difficulty, seed, exc)\n",
"        return {\"difficulty\": difficulty, \"seed\": seed, \"error\": str(exc), \"evacuated\": 0}\n",
"\n",
"    # ── Determine cause of end ────────────────────────────────────────────────\n",
"    if agent_evacuated:\n",
"        cause_of_end = \"evacuated\"\n",
"    elif final_health <= 0.0:\n",
"        cause_of_end = \"death\"\n",
"    else:\n",
"        cause_of_end = \"timeout\"\n",
"\n",
"    if debug_dir is not None and llm_responses_log:\n",
"        try:\n",
"            debug_dir.mkdir(parents=True, exist_ok=True)\n",
"            # Explicit UTF-8: model output may contain characters the platform\n",
"            # default encoding cannot represent.\n",
"            (debug_dir / f\"{difficulty}_seed{seed}_llm_trace.md\").write_text(\n",
"                \"\\n\".join(llm_responses_log),\n",
"                encoding=\"utf-8\",\n",
"            )\n",
"        except Exception as exc:\n",
"            log.warning(\"Could not save LLM trace: %s\", exc)\n",
"\n",
"    # max(..., 1) keeps the averages defined when no step was taken.\n",
"    total_reward = sum(step_rewards)\n",
"    mean_step_reward = total_reward / max(len(step_rewards), 1)\n",
"    think_rate = think_steps / max(steps_taken, 1)\n",
"    parse_rate = parsed_steps / max(steps_taken, 1)\n",
"    fmt_score_avg = sum(fmt_scores) / max(len(fmt_scores), 1)\n",
"\n",
"    return {\n",
"        \"difficulty\": difficulty,\n",
"        \"seed\": seed,\n",
"        \"evacuated\": int(agent_evacuated),\n",
"        \"cause_of_end\": cause_of_end,\n",
"        \"final_health\": round(final_health, 2),\n",
"        \"total_reward\": round(total_reward, 4),\n",
"        \"mean_step_reward\": round(mean_step_reward, 4),\n",
"        \"steps_taken\": steps_taken,\n",
"        \"max_steps\": max_steps,\n",
"        \"think_rate\": round(think_rate, 4),\n",
"        \"parse_rate\": round(parse_rate, 4),\n",
"        \"format_score_avg\": round(fmt_score_avg, 4),\n",
"        \"error\": None,\n",
"    }\n",
"\n",
"\n",
"print(\"Episode runner defined.\")"
]
},
{
"cell_type": "markdown",
"id": "b4c5d6e7",
"metadata": {},
"source": [
"## Cell 7 — Server Health Check & Model Load"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c5d6e7f8",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Server not reachable at http://localhost:8000: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mConnectionRefusedError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:204\u001b[39m, in \u001b[36mHTTPConnection._new_conn\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 203\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m204\u001b[39m sock = \u001b[30;43mconnection\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mcreate_connection\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 205\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_dns_host\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mport\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 206\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 207\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43msource_address\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msource_address\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 208\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43msocket_options\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msocket_options\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 209\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 210\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m socket.gaierror \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/connection.py:85\u001b[39m, in \u001b[36mcreate_connection\u001b[39m\u001b[34m(address, timeout, source_address, socket_options)\u001b[39m\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m85\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m err\n\u001b[32m 86\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 87\u001b[39m \u001b[38;5;66;03m# Break explicitly a reference cycle\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/connection.py:73\u001b[39m, in \u001b[36mcreate_connection\u001b[39m\u001b[34m(address, timeout, source_address, socket_options)\u001b[39m\n\u001b[32m 72\u001b[39m sock.bind(source_address)\n\u001b[32m---> \u001b[39m\u001b[32m73\u001b[39m \u001b[30;43msock\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mconnect\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43msa\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 74\u001b[39m \u001b[38;5;66;03m# Break explicitly a reference cycle\u001b[39;00m\n",
"\u001b[31mConnectionRefusedError\u001b[39m: [Errno 111] Connection refused",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mNewConnectionError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:787\u001b[39m, in \u001b[36mHTTPConnectionPool.urlopen\u001b[39m\u001b[34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)\u001b[39m\n\u001b[32m 786\u001b[39m \u001b[38;5;66;03m# Make the request on the HTTPConnection object\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m787\u001b[39m response = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_make_request\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 788\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mconn\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 789\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 790\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 791\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtimeout_obj\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 792\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 793\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 794\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 795\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 796\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43mresponse_conn\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mresponse_conn\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 797\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 798\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 799\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mresponse_kw\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 800\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 802\u001b[39m \u001b[38;5;66;03m# Everything went great!\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:493\u001b[39m, in \u001b[36mHTTPConnectionPool._make_request\u001b[39m\u001b[34m(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)\u001b[39m\n\u001b[32m 492\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m493\u001b[39m \u001b[30;43mconn\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 494\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 495\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 496\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 497\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 498\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 499\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 500\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 501\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43menforce_content_length\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43menforce_content_length\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 502\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 504\u001b[39m \u001b[38;5;66;03m# We are swallowing BrokenPipeError (errno.EPIPE) since the server is\u001b[39;00m\n\u001b[32m 505\u001b[39m \u001b[38;5;66;03m# legitimately able to close the connection after sending a valid response.\u001b[39;00m\n\u001b[32m 506\u001b[39m \u001b[38;5;66;03m# With this behaviour, the received response is still readable.\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:500\u001b[39m, in \u001b[36mHTTPConnection.request\u001b[39m\u001b[34m(self, method, url, body, headers, chunked, preload_content, decode_content, enforce_content_length)\u001b[39m\n\u001b[32m 499\u001b[39m \u001b[38;5;28mself\u001b[39m.putheader(header, value)\n\u001b[32m--> \u001b[39m\u001b[32m500\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mendheaders\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 502\u001b[39m \u001b[38;5;66;03m# If we're given a body we start sending that in chunks.\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1333\u001b[39m, in \u001b[36mHTTPConnection.endheaders\u001b[39m\u001b[34m(self, message_body, encode_chunked)\u001b[39m\n\u001b[32m 1332\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[32m-> \u001b[39m\u001b[32m1333\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_send_output\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmessage_body\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mencode_chunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mencode_chunked\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1093\u001b[39m, in \u001b[36mHTTPConnection._send_output\u001b[39m\u001b[34m(self, message_body, encode_chunked)\u001b[39m\n\u001b[32m 1092\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m._buffer[:]\n\u001b[32m-> \u001b[39m\u001b[32m1093\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmsg\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 1095\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 1096\u001b[39m \n\u001b[32m 1097\u001b[39m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1037\u001b[39m, in \u001b[36mHTTPConnection.send\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 1036\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.auto_open:\n\u001b[32m-> \u001b[39m\u001b[32m1037\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mconnect\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 1038\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:331\u001b[39m, in \u001b[36mHTTPConnection.connect\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 330\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mconnect\u001b[39m(\u001b[38;5;28mself\u001b[39m) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m331\u001b[39m \u001b[38;5;28mself\u001b[39m.sock = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_new_conn\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 332\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._tunnel_host:\n\u001b[32m 333\u001b[39m \u001b[38;5;66;03m# If we're tunneling it means we're connected to our proxy.\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:219\u001b[39m, in \u001b[36mHTTPConnection._new_conn\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 218\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m--> \u001b[39m\u001b[32m219\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m NewConnectionError(\n\u001b[32m 220\u001b[39m \u001b[38;5;28mself\u001b[39m, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFailed to establish a new connection: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 221\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 223\u001b[39m sys.audit(\u001b[33m\"\u001b[39m\u001b[33mhttp.client.connect\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mself\u001b[39m.host, \u001b[38;5;28mself\u001b[39m.port)\n",
"\u001b[31mNewConnectionError\u001b[39m: HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[31mMaxRetryError\u001b[39m Traceback (most recent call last)",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/adapters.py:645\u001b[39m, in \u001b[36mHTTPAdapter.send\u001b[39m\u001b[34m(self, request, stream, timeout, verify, cert, proxies)\u001b[39m\n\u001b[32m 644\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m645\u001b[39m resp = \u001b[30;43mconn\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43murlopen\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 646\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 647\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 648\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 649\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 650\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mredirect\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 651\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43massert_same_host\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 652\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 653\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 654\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mmax_retries\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 655\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 656\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 657\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m (ProtocolError, \u001b[38;5;167;01mOSError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m err:\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:841\u001b[39m, in \u001b[36mHTTPConnectionPool.urlopen\u001b[39m\u001b[34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)\u001b[39m\n\u001b[32m 839\u001b[39m new_e = ProtocolError(\u001b[33m\"\u001b[39m\u001b[33mConnection aborted.\u001b[39m\u001b[33m\"\u001b[39m, new_e)\n\u001b[32m--> \u001b[39m\u001b[32m841\u001b[39m retries = \u001b[30;43mretries\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mincrement\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 842\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43merror\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mnew_e\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m_pool\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m_stacktrace\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43msys\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mexc_info\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m[\u001b[39;49m\u001b[30;43m2\u001b[39;49m\u001b[30;43m]\u001b[39;49m\n\u001b[32m 843\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 844\u001b[39m retries.sleep()\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/retry.py:535\u001b[39m, in \u001b[36mRetry.increment\u001b[39m\u001b[34m(self, method, url, response, error, _pool, _stacktrace)\u001b[39m\n\u001b[32m 534\u001b[39m reason = error \u001b[38;5;129;01mor\u001b[39;00m ResponseError(cause)\n\u001b[32m--> \u001b[39m\u001b[32m535\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m MaxRetryError(_pool, url, reason) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mreason\u001b[39;00m \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 537\u001b[39m log.debug(\u001b[33m\"\u001b[39m\u001b[33mIncremented Retry for (url=\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m): \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m\"\u001b[39m, url, new_retry)\n",
"\u001b[31mMaxRetryError\u001b[39m: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[31mConnectionError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m 5\u001b[39m print(f\"Server health check PASSED ({ENV_URL})\")\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Exception \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(f\"Server not reachable at {ENV_URL}: {exc}\")\n\u001b[32m 8\u001b[39m \n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/api.py:73\u001b[39m, in \u001b[36mget\u001b[39m\u001b[34m(url, params, **kwargs)\u001b[39m\n\u001b[32m 63\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"Sends a GET request.\u001b[39;00m\n\u001b[32m 64\u001b[39m \n\u001b[32m 65\u001b[39m \u001b[33;03m:param url: URL for the new :class:`Request` object.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 70\u001b[39m \u001b[33;03m:rtype: requests.Response\u001b[39;00m\n\u001b[32m 71\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m73\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43mget\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mparams\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mparams\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/api.py:59\u001b[39m, in \u001b[36mrequest\u001b[39m\u001b[34m(method, url, **kwargs)\u001b[39m\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m sessions.Session() \u001b[38;5;28;01mas\u001b[39;00m session:\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43msession\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/sessions.py:592\u001b[39m, in \u001b[36mSession.request\u001b[39m\u001b[34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[39m\n\u001b[32m 591\u001b[39m send_kwargs.update(settings)\n\u001b[32m--> \u001b[39m\u001b[32m592\u001b[39m resp = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mprep\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43msend_kwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 594\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/sessions.py:706\u001b[39m, in \u001b[36mSession.send\u001b[39m\u001b[34m(self, request, **kwargs)\u001b[39m\n\u001b[32m 705\u001b[39m \u001b[38;5;66;03m# Send the request\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m706\u001b[39m r = \u001b[30;43madapter\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 708\u001b[39m \u001b[38;5;66;03m# Total elapsed time of the request (approximately)\u001b[39;00m\n",
"\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/adapters.py:678\u001b[39m, in \u001b[36mHTTPAdapter.send\u001b[39m\u001b[34m(self, request, stream, timeout, verify, cert, proxies)\u001b[39m\n\u001b[32m 676\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m SSLError(e, request=request)\n\u001b[32m--> \u001b[39m\u001b[32m678\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mConnectionError\u001b[39;00m(e, request=request)\n\u001b[32m 680\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ClosedPoolError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
"\u001b[31mConnectionError\u001b[39m: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m 3\u001b[39m health = requests.get(f\"{ENV_URL}/health\", timeout=\u001b[32m5\u001b[39m)\n\u001b[32m 4\u001b[39m health.raise_for_status()\n\u001b[32m 5\u001b[39m print(f\"Server health check PASSED ({ENV_URL})\")\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Exception \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(f\"Server not reachable at {ENV_URL}: {exc}\")\n\u001b[32m 8\u001b[39m \n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# Build and load the HF model — this is the expensive step\u001b[39;00m\n\u001b[32m 10\u001b[39m llm = HFChatModel(\n",
"\u001b[31mRuntimeError\u001b[39m: Server not reachable at http://localhost:8000: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))"
]
}
],
"source": [
"# Health check first — fail fast before waiting for the large model to load.\n",
"# Only HTTP/network failures from `requests` are translated into RuntimeError;\n",
"# any other exception is a genuine bug and is left to surface unchanged.\n",
"try:\n",
"    health = requests.get(f\"{ENV_URL}/health\", timeout=5)\n",
"    health.raise_for_status()\n",
"except requests.RequestException as exc:\n",
"    # `from exc` records the requests error as the direct cause, instead of the\n",
"    # implicit \"during handling of the above exception\" chain.\n",
"    raise RuntimeError(f\"Server not reachable at {ENV_URL}: {exc}\") from exc\n",
"else:\n",
"    print(f\"Server health check PASSED ({ENV_URL})\")\n",
"\n",
"# Build and load the HF model — this is the expensive step\n",
"llm = HFChatModel(\n",
"    model_id=MODEL_ID,\n",
"    temperature=TEMPERATURE,\n",
"    max_new_tokens=MAX_NEW_TOKENS,\n",
").load(load_4bit=LOAD_4BIT)\n",
"\n",
"print(f\"\\nModel ready: {MODEL_ID}\")"
]
},
{
"cell_type": "markdown",
"id": "d6e7f8a9",
"metadata": {},
"source": [
"## Cell 8 — Run Evaluation Loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7f8a9b0",
"metadata": {},
"outputs": [],
"source": [
"# Run one episode per (difficulty, seed) pair and collect every result dict.\n",
"results: List[Dict[str, Any]] = []\n",
"debug_dir = Path(OUTPUT_DIR) / \"debug\"\n",
"total = len(difficulties_to_run) * len(SEEDS)\n",
"\n",
"done_count = 0  # 1-based progress counter shown in the log prefix\n",
"for diff_cfg in difficulties_to_run:\n",
"    difficulty = diff_cfg[\"difficulty\"]\n",
"    max_steps = diff_cfg[\"max_steps\"]\n",
"\n",
"    for seed in SEEDS:\n",
"        done_count += 1\n",
"        log.info(\"[%d/%d] difficulty=%-8s seed=%d\", done_count, total, difficulty, seed)\n",
"\n",
"        result = run_episode(\n",
"            llm=llm,\n",
"            difficulty=difficulty,\n",
"            seed=seed,\n",
"            max_steps=max_steps,\n",
"            debug_dir=debug_dir,\n",
"        )\n",
"        results.append(result)\n",
"\n",
"        # Results carrying a truthy \"error\" are only reported, not summarised.\n",
"        if result.get(\"error\"):\n",
"            log.warning(\" ERROR: %s\", result[\"error\"])\n",
"            continue\n",
"\n",
"        if result[\"evacuated\"]:\n",
"            evac_sym = \"✓ EVACUATED\"\n",
"        else:\n",
"            evac_sym = f\"✗ {result['cause_of_end'].upper()}\"\n",
"\n",
"        log.info(\n",
"            \" %-14s health=%5.1f reward=%+.2f steps=%d/%d think=%d%% parse=%d%%\",\n",
"            evac_sym,\n",
"            result[\"final_health\"],\n",
"            result[\"total_reward\"],\n",
"            result[\"steps_taken\"],\n",
"            max_steps,\n",
"            result[\"think_rate\"] * 100,\n",
"            result[\"parse_rate\"] * 100,\n",
"        )\n",
"\n",
"print(f\"\\nAll {total} episodes complete.\")"
]
},
{
"cell_type": "markdown",
"id": "f8a9b0c1",
"metadata": {},
"source": [
"## Cell 9 — Summary Table & CSV Export"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9b0c1d2",
"metadata": {},
"outputs": [],
"source": [
"def _avg(vals: List[float]) -> float:\n",
"    \"\"\"Return the arithmetic mean of ``vals`` rounded to 4 decimals.\n",
"\n",
"    An empty list yields ``0.0`` instead of raising ``ZeroDivisionError``.\n",
"    \"\"\"\n",
"    if not vals:\n",
"        return 0.0\n",
"    mean = sum(vals) / len(vals)\n",
"    return round(mean, 4)\n",
"\n",
"\n",
"# Bucket successful episodes by difficulty. Results with a truthy \"error\" are\n",
"# excluded from the averages and only counted, so the exclusion can be reported.\n",
"by_diff: Dict[str, List[Dict]] = defaultdict(list)\n",
"n_errored = 0\n",
"for r in results:\n",
"    if r.get(\"error\"):\n",
"        n_errored += 1\n",
"    else:\n",
"        by_diff[r[\"difficulty\"]].append(r)\n",
"\n",
"col = \"{:<10} {:>8} {:>10} {:>10} {:>11} {:>9} {:>9} {:>9}\"\n",
"header = col.format(\n",
"    \"Difficulty\", \"Evac%\", \"AvgHealth\", \"TotalRew\",\n",
"    \"MeanStpRew\", \"Steps/Max\", \"Think%\", \"Parse%\",\n",
")\n",
"sep = \"=\" * len(header)\n",
"\n",
"print(f\"\\n{sep}\")\n",
"print(f\" PYRE EVAL — model: {MODEL_ID}\")\n",
"print(sep)\n",
"print(header)\n",
"print(\"-\" * len(header))\n",
"\n",
"# Human-readable hints collected while walking the per-difficulty rows.\n",
"diagnosis: List[str] = []\n",
"\n",
"for cfg in ALL_DIFFICULTIES:\n",
"    diff = cfg[\"difficulty\"]\n",
"    max_steps = cfg[\"max_steps\"]\n",
"    rows = by_diff.get(diff, [])\n",
"    if not rows:\n",
"        # No successful episode for this difficulty: print a placeholder row.\n",
"        print(col.format(diff, \"n/a\", \"n/a\", \"n/a\", \"n/a\", \"n/a\", \"n/a\", \"n/a\"))\n",
"        continue\n",
"\n",
"    evac_rate = _avg([r[\"evacuated\"] for r in rows])\n",
"    avg_health = _avg([r[\"final_health\"] for r in rows])\n",
"    avg_total_rew = _avg([r[\"total_reward\"] for r in rows])\n",
"    avg_step_rew = _avg([r[\"mean_step_reward\"] for r in rows])\n",
"    avg_steps = _avg([r[\"steps_taken\"] for r in rows])\n",
"    avg_think = _avg([r[\"think_rate\"] for r in rows])\n",
"    avg_parse = _avg([r[\"parse_rate\"] for r in rows])\n",
"\n",
"    cause_counts: Dict[str, int] = defaultdict(int)\n",
"    for r in rows:\n",
"        cause_counts[r[\"cause_of_end\"]] += 1\n",
"\n",
"    print(col.format(\n",
"        diff,\n",
"        f\"{evac_rate * 100:.0f}%\",\n",
"        f\"{avg_health:.1f}\",\n",
"        f\"{avg_total_rew:+.2f}\",\n",
"        f\"{avg_step_rew:+.3f}\",\n",
"        f\"{avg_steps:.0f}/{max_steps}\",\n",
"        f\"{avg_think * 100:.0f}%\",\n",
"        f\"{avg_parse * 100:.0f}%\",\n",
"    ))\n",
"\n",
"    # Accumulate diagnosis hints. `rows` is non-empty here, so n >= 1.\n",
"    n = len(rows)\n",
"    death_n = cause_counts.get(\"death\", 0)\n",
"    timeout_n = cause_counts.get(\"timeout\", 0)\n",
"    # Count evacuations from the same `evacuated` flag that feeds evac_rate, so\n",
"    # the hint below cannot disagree with the Evac% column if `cause_of_end`\n",
"    # uses a different label for successful episodes.\n",
"    evac_n = sum(1 for r in rows if r[\"evacuated\"])\n",
"\n",
"    if avg_parse < 0.80:\n",
"        diagnosis.append(\n",
"            f\"[{diff}] Low parse rate ({avg_parse*100:.0f}%) — \"\n",
"            \"action schema may be unclear; consider simplifying JSON keys.\"\n",
"        )\n",
"    if avg_think < 0.50:\n",
"        diagnosis.append(\n",
"            f\"[{diff}] Low think rate ({avg_think*100:.0f}%) — \"\n",
"            \"model is skipping tags; prompt may need stronger CoT instruction.\"\n",
"        )\n",
"    if evac_n == 0 and diff == \"easy\":\n",
"        diagnosis.append(\n",
"            \"[easy] Zero evacuations on easy difficulty — \"\n",
"            \"check exit reachability, narrative clarity, or BFS distance from spawn.\"\n",
"        )\n",
"    if death_n / n > 0.80:\n",
"        diagnosis.append(\n",
"            f\"[{diff}] {death_n}/{n} episodes end in death — \"\n",
"            \"damage rates may be too high or smoke/fire proximity warnings are not clear.\"\n",
"        )\n",
"    if timeout_n / n > 0.80:\n",
"        diagnosis.append(\n",
"            f\"[{diff}] {timeout_n}/{n} episodes time out — \"\n",
"            \"model may be looping/waiting; exits might be hard to locate from narratives.\"\n",
"        )\n",
"    if avg_step_rew < -0.3 and evac_rate < 0.1:\n",
"        diagnosis.append(\n",
"            f\"[{diff}] Very negative mean step reward ({avg_step_rew:+.3f}) with near-zero \"\n",
"            \"success — model may be actively moving into fire; check DangerPenalty trigger conditions.\"\n",
"        )\n",
"\n",
"print(sep + \"\\n\")\n",
"\n",
"# Make the exclusion of failed episodes visible instead of silently dropping them.\n",
"if n_errored:\n",
"    print(f\" {n_errored}/{len(results)} episode(s) errored and are excluded from the table above.\\n\")\n",
"\n",
"if diagnosis:\n",
"    print(\"DIAGNOSTICS (environment design signals)\")\n",
"    print(\"-\" * 50)\n",
"    for d in diagnosis:\n",
"        print(f\" • {d}\")\n",
"    print()\n",
"else:\n",
"    print(\" No red-flag patterns detected — environment appears legible to this model.\\n\")\n",
"\n",
"# ── Save CSV ───────────────────────────────────────────────────────────────────\n",
"# One row per episode (errored ones included); the file name encodes the model\n",
"# id and a timestamp so repeated runs do not overwrite each other.\n",
"timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"model_slug = MODEL_ID.replace(\"/\", \"_\").replace(\".\", \"-\")\n",
"csv_path = Path(OUTPUT_DIR) / f\"hf_baseline_{model_slug}_{timestamp}.csv\"\n",
"csv_path.parent.mkdir(parents=True, exist_ok=True)\n",
"\n",
"fieldnames = [\n",
"    \"difficulty\", \"seed\", \"evacuated\", \"cause_of_end\",\n",
"    \"final_health\", \"total_reward\", \"mean_step_reward\",\n",
"    \"steps_taken\", \"max_steps\",\n",
"    \"think_rate\", \"parse_rate\", \"format_score_avg\", \"error\",\n",
"]\n",
"# Explicit UTF-8 keeps the export independent of the platform locale; free-text\n",
"# fields such as \"error\" may contain non-ASCII characters.\n",
"# extrasaction=\"ignore\" drops result keys that are not listed in `fieldnames`.\n",
"with open(csv_path, \"w\", newline=\"\", encoding=\"utf-8\") as f:\n",
"    writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction=\"ignore\")\n",
"    writer.writeheader()\n",
"    writer.writerows(results)\n",
"\n",
"print(f\"CSV saved → {csv_path}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}