diff --git a/Dockerfile b/Dockerfile index e83ae24d040673da9e0c4e8021aa19d2824fb43e..e0130af1c8e0fb06708f6eaf0b2fbcffa680468c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -77,5 +77,4 @@ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ # Run the FastAPI server # The module path is constructed to work with the /app/env structure -ENV ENABLE_WEB_INTERFACE=true CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"] diff --git a/README.md b/README.md index b4adaf9f6948ff8813c8fe493ad120abab00bad4..a530ea4736e06129e533b358dce5eaf24bd02479 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ colorTo: yellow sdk: docker pinned: false app_port: 8000 -base_path: /web tags: - openenv --- diff --git a/evals_hf.ipynb b/evals_hf.ipynb index 484847995750a4af72b775dcc0d66572c17ae7ad..3c500281fa0e927989444716a646bd5f81592d5e 100644 --- a/evals_hf.ipynb +++ b/evals_hf.ipynb @@ -26,10 +26,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "c3d4e5f6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Imports OK\n" + ] + } + ], "source": [ "import csv\n", "import json\n", @@ -65,7 +73,7 @@ }, { "cell_type": "markdown", - "id": "d4e5f6a7", + "id": "1f1d9271", "metadata": {}, "source": [ "## Cell 2 — Config ✏️ Edit this cell to change eval settings" @@ -73,10 +81,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "e5f6a7b8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model : Qwen/Qwen3-1.7B\n", + "4-bit : False\n", + "Temperature: 0.3\n", + "Levels : all 3 (easy, medium, hard)\n", + "Seeds : [1, 2, 3]\n", + "Output dir : ./outputs/hf_evals\n" + ] + } + ], "source": [ "# ── Model ─────────────────────────────────────────────────────────────────────\n", "MODEL_ID = \"Qwen/Qwen3-1.7B\" # HF model ID or local path / adapter dir\n", 
@@ -123,10 +144,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "a7b8c9d0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Difficulties to evaluate : ['easy', 'medium', 'hard']\n", + "Total episodes : 9\n" + ] + } + ], "source": [ "ALL_DIFFICULTIES: List[Dict[str, Any]] = [\n", " {\"difficulty\": \"easy\", \"max_steps\": 200, \"description\": \"1 fire source · slow spread · calm wind · high humidity\"},\n", @@ -195,10 +225,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "c9d0e1f2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HFChatModel class defined.\n" + ] + } + ], "source": [ "class HFChatModel(SimpleChatModel):\n", " \"\"\"\n", @@ -308,10 +346,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "e1f2a3b4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt builder and action parser defined.\n" + ] + } + ], "source": [ "def _build_user_message(obs: Dict[str, Any], history: List[str]) -> str:\n", " \"\"\"Convert a raw observation dict + history into the LLM user message.\"\"\"\n", @@ -466,10 +512,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "a3b4c5d6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Episode runner defined.\n" + ] + } + ], "source": [ "def run_episode(\n", " llm: HFChatModel,\n", @@ -637,10 +691,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "c5d6e7f8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Server not reachable at http://localhost:8000: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health 
(Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mConnectionRefusedError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:204\u001b[39m, in \u001b[36mHTTPConnection._new_conn\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 203\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m204\u001b[39m sock = \u001b[30;43mconnection\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mcreate_connection\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 205\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_dns_host\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mport\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 206\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 207\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43msource_address\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msource_address\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 208\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43msocket_options\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msocket_options\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 209\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 
210\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m socket.gaierror \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/connection.py:85\u001b[39m, in \u001b[36mcreate_connection\u001b[39m\u001b[34m(address, timeout, source_address, socket_options)\u001b[39m\n\u001b[32m 84\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m85\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m err\n\u001b[32m 86\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 87\u001b[39m \u001b[38;5;66;03m# Break explicitly a reference cycle\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/connection.py:73\u001b[39m, in \u001b[36mcreate_connection\u001b[39m\u001b[34m(address, timeout, source_address, socket_options)\u001b[39m\n\u001b[32m 72\u001b[39m sock.bind(source_address)\n\u001b[32m---> \u001b[39m\u001b[32m73\u001b[39m \u001b[30;43msock\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mconnect\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43msa\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 74\u001b[39m \u001b[38;5;66;03m# Break explicitly a reference cycle\u001b[39;00m\n", + "\u001b[31mConnectionRefusedError\u001b[39m: [Errno 111] Connection refused", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[31mNewConnectionError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:787\u001b[39m, in \u001b[36mHTTPConnectionPool.urlopen\u001b[39m\u001b[34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, 
**response_kw)\u001b[39m\n\u001b[32m 786\u001b[39m \u001b[38;5;66;03m# Make the request on the HTTPConnection object\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m787\u001b[39m response = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_make_request\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 788\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mconn\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 789\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 790\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 791\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtimeout_obj\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 792\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 793\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 794\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 795\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 796\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mresponse_conn\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mresponse_conn\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 797\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 798\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 799\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mresponse_kw\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 800\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 802\u001b[39m \u001b[38;5;66;03m# Everything went great!\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:493\u001b[39m, in \u001b[36mHTTPConnectionPool._make_request\u001b[39m\u001b[34m(self, conn, method, url, body, headers, retries, timeout, chunked, response_conn, preload_content, decode_content, enforce_content_length)\u001b[39m\n\u001b[32m 492\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m493\u001b[39m \u001b[30;43mconn\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 494\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 495\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 496\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 497\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 498\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 499\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 500\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 501\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43menforce_content_length\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43menforce_content_length\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 502\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 504\u001b[39m \u001b[38;5;66;03m# We are swallowing BrokenPipeError (errno.EPIPE) since the server is\u001b[39;00m\n\u001b[32m 505\u001b[39m \u001b[38;5;66;03m# legitimately able to close the connection after sending a valid response.\u001b[39;00m\n\u001b[32m 506\u001b[39m \u001b[38;5;66;03m# With this behaviour, the received response is still readable.\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:500\u001b[39m, in \u001b[36mHTTPConnection.request\u001b[39m\u001b[34m(self, method, url, body, headers, chunked, preload_content, decode_content, enforce_content_length)\u001b[39m\n\u001b[32m 499\u001b[39m \u001b[38;5;28mself\u001b[39m.putheader(header, value)\n\u001b[32m--> \u001b[39m\u001b[32m500\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mendheaders\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 502\u001b[39m \u001b[38;5;66;03m# If we're given a body we start sending that in chunks.\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1333\u001b[39m, in \u001b[36mHTTPConnection.endheaders\u001b[39m\u001b[34m(self, message_body, encode_chunked)\u001b[39m\n\u001b[32m 
1332\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m CannotSendHeader()\n\u001b[32m-> \u001b[39m\u001b[32m1333\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_send_output\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmessage_body\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43mencode_chunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mencode_chunked\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1093\u001b[39m, in \u001b[36mHTTPConnection._send_output\u001b[39m\u001b[34m(self, message_body, encode_chunked)\u001b[39m\n\u001b[32m 1092\u001b[39m \u001b[38;5;28;01mdel\u001b[39;00m \u001b[38;5;28mself\u001b[39m._buffer[:]\n\u001b[32m-> \u001b[39m\u001b[32m1093\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmsg\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 1095\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m message_body \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 1096\u001b[39m \n\u001b[32m 1097\u001b[39m \u001b[38;5;66;03m# create a consistent interface to message_body\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/http/client.py:1037\u001b[39m, in \u001b[36mHTTPConnection.send\u001b[39m\u001b[34m(self, data)\u001b[39m\n\u001b[32m 1036\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.auto_open:\n\u001b[32m-> \u001b[39m\u001b[32m1037\u001b[39m \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mconnect\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 1038\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:331\u001b[39m, in \u001b[36mHTTPConnection.connect\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 330\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mconnect\u001b[39m(\u001b[38;5;28mself\u001b[39m) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m331\u001b[39m \u001b[38;5;28mself\u001b[39m.sock = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43m_new_conn\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 332\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._tunnel_host:\n\u001b[32m 333\u001b[39m \u001b[38;5;66;03m# If we're tunneling it means we're connected to our proxy.\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connection.py:219\u001b[39m, in \u001b[36mHTTPConnection._new_conn\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 218\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m--> \u001b[39m\u001b[32m219\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m NewConnectionError(\n\u001b[32m 220\u001b[39m \u001b[38;5;28mself\u001b[39m, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mFailed to establish a new connection: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 221\u001b[39m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 223\u001b[39m sys.audit(\u001b[33m\"\u001b[39m\u001b[33mhttp.client.connect\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28mself\u001b[39m, \u001b[38;5;28mself\u001b[39m.host, \u001b[38;5;28mself\u001b[39m.port)\n", + "\u001b[31mNewConnectionError\u001b[39m: 
HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[31mMaxRetryError\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/adapters.py:645\u001b[39m, in \u001b[36mHTTPAdapter.send\u001b[39m\u001b[34m(self, request, stream, timeout, verify, cert, proxies)\u001b[39m\n\u001b[32m 644\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m645\u001b[39m resp = \u001b[30;43mconn\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43murlopen\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 646\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 647\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 648\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mbody\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 649\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mheaders\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 650\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mredirect\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 651\u001b[39m \u001b[30;43m 
\u001b[39;49m\u001b[30;43massert_same_host\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 652\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mpreload_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 653\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mdecode_content\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43;01mFalse\u001b[39;49;00m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 654\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mretries\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mmax_retries\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 655\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mtimeout\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 656\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mchunked\u001b[39;49m\u001b[30;43m,\u001b[39;49m\n\u001b[32m 657\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m (ProtocolError, \u001b[38;5;167;01mOSError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m err:\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/connectionpool.py:841\u001b[39m, in \u001b[36mHTTPConnectionPool.urlopen\u001b[39m\u001b[34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, preload_content, decode_content, **response_kw)\u001b[39m\n\u001b[32m 839\u001b[39m new_e = ProtocolError(\u001b[33m\"\u001b[39m\u001b[33mConnection aborted.\u001b[39m\u001b[33m\"\u001b[39m, new_e)\n\u001b[32m--> \u001b[39m\u001b[32m841\u001b[39m retries = 
\u001b[30;43mretries\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mincrement\u001b[39;49m\u001b[30;43m(\u001b[39;49m\n\u001b[32m 842\u001b[39m \u001b[30;43m \u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43merror\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mnew_e\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m_pool\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mself\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m_stacktrace\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43msys\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mexc_info\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m)\u001b[39;49m\u001b[30;43m[\u001b[39;49m\u001b[30;43m2\u001b[39;49m\u001b[30;43m]\u001b[39;49m\n\u001b[32m 843\u001b[39m \u001b[30;43m\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 844\u001b[39m retries.sleep()\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/urllib3/util/retry.py:535\u001b[39m, in \u001b[36mRetry.increment\u001b[39m\u001b[34m(self, method, url, response, error, _pool, _stacktrace)\u001b[39m\n\u001b[32m 534\u001b[39m reason = error \u001b[38;5;129;01mor\u001b[39;00m ResponseError(cause)\n\u001b[32m--> \u001b[39m\u001b[32m535\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m MaxRetryError(_pool, url, reason) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mreason\u001b[39;00m \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 537\u001b[39m log.debug(\u001b[33m\"\u001b[39m\u001b[33mIncremented Retry for (url=\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m): \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[33m\"\u001b[39m, url, new_retry)\n", + 
"\u001b[31mMaxRetryError\u001b[39m: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[31mConnectionError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m 5\u001b[39m print(f\"Server health check PASSED ({ENV_URL})\")\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Exception \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(f\"Server not reachable at {ENV_URL}: {exc}\")\n\u001b[32m 8\u001b[39m \n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/api.py:73\u001b[39m, in \u001b[36mget\u001b[39m\u001b[34m(url, params, **kwargs)\u001b[39m\n\u001b[32m 63\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"Sends a GET request.\u001b[39;00m\n\u001b[32m 64\u001b[39m \n\u001b[32m 65\u001b[39m \u001b[33;03m:param url: URL for the new :class:`Request` object.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 70\u001b[39m \u001b[33;03m:rtype: requests.Response\u001b[39;00m\n\u001b[32m 71\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m73\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43mget\u001b[39;49m\u001b[30;43m\"\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m 
\u001b[39;49m\u001b[30;43mparams\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mparams\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/api.py:59\u001b[39m, in \u001b[36mrequest\u001b[39m\u001b[34m(method, url, **kwargs)\u001b[39m\n\u001b[32m 58\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m sessions.Session() \u001b[38;5;28;01mas\u001b[39;00m session:\n\u001b[32m---> \u001b[39m\u001b[32m59\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[30;43msession\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43mmethod\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m=\u001b[39;49m\u001b[30;43murl\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/sessions.py:592\u001b[39m, in \u001b[36mSession.request\u001b[39m\u001b[34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001b[39m\n\u001b[32m 591\u001b[39m send_kwargs.update(settings)\n\u001b[32m--> \u001b[39m\u001b[32m592\u001b[39m resp = \u001b[30;43mself\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mprep\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m 
\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43msend_kwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 594\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/sessions.py:706\u001b[39m, in \u001b[36mSession.send\u001b[39m\u001b[34m(self, request, **kwargs)\u001b[39m\n\u001b[32m 705\u001b[39m \u001b[38;5;66;03m# Send the request\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m706\u001b[39m r = \u001b[30;43madapter\u001b[39;49m\u001b[30;43m.\u001b[39;49m\u001b[30;43msend\u001b[39;49m\u001b[30;43m(\u001b[39;49m\u001b[30;43mrequest\u001b[39;49m\u001b[30;43m,\u001b[39;49m\u001b[30;43m \u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43m*\u001b[39;49m\u001b[30;43mkwargs\u001b[39;49m\u001b[30;43m)\u001b[39;49m\n\u001b[32m 708\u001b[39m \u001b[38;5;66;03m# Total elapsed time of the request (approximately)\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/dccstor/kirushikesh/personal-projects/openenv-pyre/.venv/lib/python3.12/site-packages/requests/adapters.py:678\u001b[39m, in \u001b[36mHTTPAdapter.send\u001b[39m\u001b[34m(self, request, stream, timeout, verify, cert, proxies)\u001b[39m\n\u001b[32m 676\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m SSLError(e, request=request)\n\u001b[32m--> \u001b[39m\u001b[32m678\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mConnectionError\u001b[39;00m(e, request=request)\n\u001b[32m 680\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m ClosedPoolError \u001b[38;5;28;01mas\u001b[39;00m e:\n", + "\u001b[31mConnectionError\u001b[39m: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))", + "\nDuring handling of the above exception, another exception 
occurred:\n", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 7\u001b[39m\n\u001b[32m 3\u001b[39m health = requests.get(f\"{ENV_URL}/health\", timeout=\u001b[32m5\u001b[39m)\n\u001b[32m 4\u001b[39m health.raise_for_status()\n\u001b[32m 5\u001b[39m print(f\"Server health check PASSED ({ENV_URL})\")\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m Exception \u001b[38;5;28;01mas\u001b[39;00m exc:\n\u001b[32m----> \u001b[39m\u001b[32m7\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m RuntimeError(f\"Server not reachable at {ENV_URL}: {exc}\")\n\u001b[32m 8\u001b[39m \n\u001b[32m 9\u001b[39m \u001b[38;5;66;03m# Build and load the HF model — this is the expensive step\u001b[39;00m\n\u001b[32m 10\u001b[39m llm = HFChatModel(\n", + "\u001b[31mRuntimeError\u001b[39m: Server not reachable at http://localhost:8000: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /health (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8000): Failed to establish a new connection: [Errno 111] Connection refused\"))" + ] + } + ], "source": [ "# Health check first — fail fast before waiting for the large model to load\n", "try:\n", @@ -855,13 +953,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.0" } }, "nbformat": 4, diff --git a/examples/train_rl_agent.py b/examples/train_rl_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..f923c3fc60a65667d8f45ab944c95515578f4f11 --- /dev/null +++ b/examples/train_rl_agent.py @@ -0,0 +1,984 @@ +"""Train a deep RL baseline directly against the local Pyre environment. 
+ +This script makes the environment contract explicit: + - Observation: encoded from `PyreObservation.map_state` into a fixed-length vector + - Action: fixed discrete action table with a runtime validity mask from `available_actions_hint` + - Reward: the environment's composite reward returned by `PyreEnvironment.step()` + +It uses a self-contained NumPy actor-critic implementation so it can run in +this repository without external ML dependencies. + +Examples: + python examples/train_rl_agent.py --episodes 150 --difficulty easy + python examples/train_rl_agent.py --episodes 300 --difficulty-schedule easy,medium + python examples/train_rl_agent.py --episodes 200 --difficulty easy,medium,hard --observation-mode full + python examples/train_rl_agent.py --describe-only +""" + +from __future__ import annotations + +import argparse +import csv +import json +import math +import re +from collections import deque +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Sequence + +import numpy as np + +from pyre_env.models import PyreAction, PyreObservation +from pyre_env.server.pyre_env_environment import PyreEnvironment + + +MAX_GRID_W = 24 +MAX_GRID_H = 24 +MAX_DOORS = 16 +DIRECTIONS = ("north", "south", "west", "east") +WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST") +DIFFICULTIES = ("easy", "medium", "hard") + +MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS] +LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS] +WAIT_KEY = "wait()" +OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)] +CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)] +ACTION_KEYS = MOVE_KEYS + LOOK_KEYS + [WAIT_KEY] + OPEN_KEYS + CLOSE_KEYS +ACTION_DIM = len(ACTION_KEYS) +ACTION_TO_INDEX = {key: idx for idx, key in enumerate(ACTION_KEYS)} + +_MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)") +_LOOK_RE = 
def build_action_mask(observation: PyreObservation) -> np.ndarray:
    """Build the 0/1 validity mask over the fixed discrete action table.

    Every string in ``observation.available_actions_hint`` is looked up in
    ``ACTION_TO_INDEX``; the matching slot is set to 1.0.  Hints that have no
    slot in the table (for example a door numbered above ``MAX_DOORS``, or a
    spelling that differs from the canonical key) are ignored.

    Args:
        observation: Observation whose ``available_actions_hint`` lists the
            currently legal actions as canonical action strings.

    Returns:
        A float32 vector of length ``ACTION_DIM``.  If no hint maps to a slot,
        only the ``wait()`` slot is enabled so the policy always has at least
        one action to choose from.
    """
    mask = np.zeros(ACTION_DIM, dtype=np.float32)

    # ACTION_KEYS already contains the canonical spelling of every action, so
    # a direct lookup is the complete test.  The regex fallback that used to
    # follow could only re-derive a key that the lookup had already handled,
    # or derive one that is absent from the table (e.g. "door_01") and crash
    # with KeyError; unknown hints are now simply skipped.
    for hint in observation.available_actions_hint or ():
        idx = ACTION_TO_INDEX.get(hint)
        if idx is not None:
            mask[idx] = 1.0

    if not mask.any():
        mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0
    return mask
    def encode(self, observation: PyreObservation) -> np.ndarray:
        """Encode one observation into a flat float32 feature vector.

        Layout of the returned vector, in order:
          1. cell-type one-hot, flattened from (MAX_GRID_H, MAX_GRID_W, 6)
          2. fire channel, flattened from (MAX_GRID_H, MAX_GRID_W)
          3. smoke channel
          4. visible-cell mask
          5. agent-position mask
          6. 14 scalar features, then the wind one-hot, then the difficulty one-hot

        In ``"visible"`` mode, cells that are neither in ``visible_cells`` nor
        the agent's own cell are skipped and therefore stay all-zero.

        Raises:
            ValueError: if ``observation.map_state`` is ``None``.
        """
        map_state = observation.map_state
        if map_state is None:
            raise ValueError("PyreObservation.map_state is required for RL training.")

        # All planes are allocated at the fixed padded size; cells outside the
        # actual map stay zero.
        cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32)
        fire_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        smoke_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        visible_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        agent_channel = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)

        visible = {(x, y) for x, y in map_state.visible_cells}
        # NOTE(review): the loops index the padded planes with the map's own
        # grid_w/grid_h, so this assumes the map never exceeds
        # MAX_GRID_W x MAX_GRID_H — confirm against the environment.
        for y in range(map_state.grid_h):
            for x in range(map_state.grid_w):
                if self.mode == "visible" and (x, y) not in visible and (x, y) != (map_state.agent_x, map_state.agent_y):
                    continue
                # The per-cell grids are flat lists indexed row-major.
                i = y * map_state.grid_w + x
                cell_type = int(map_state.cell_grid[i])
                # Cell types outside 0..5 get no one-hot bit at all.
                if 0 <= cell_type <= 5:
                    cell_one_hot[y, x, cell_type] = 1.0
                fire_channel[y, x] = float(map_state.fire_grid[i])
                smoke_channel[y, x] = float(map_state.smoke_grid[i])
                visible_channel[y, x] = 1.0 if (x, y) in visible else 0.0

        if 0 <= map_state.agent_x < MAX_GRID_W and 0 <= map_state.agent_y < MAX_GRID_H:
            agent_channel[map_state.agent_y, map_state.agent_x] = 1.0

        grid_features = np.concatenate(
            [
                cell_one_hot.reshape(-1),
                fire_channel.reshape(-1),
                smoke_channel.reshape(-1),
                visible_channel.reshape(-1),
                agent_channel.reshape(-1),
            ]
        )

        # Metadata overrides map_state for wind; unknown wind strings fall back
        # to index 0 and unknown difficulties to index 1.
        metadata = observation.metadata or {}
        wind_dir = str(metadata.get("wind_dir", map_state.wind_dir or "CALM")).upper()
        difficulty = str(metadata.get("difficulty", "medium")).lower()
        wind_index = WINDS.index(wind_dir) if wind_dir in WINDS else 0
        difficulty_index = DIFFICULTIES.index(difficulty) if difficulty in DIFFICULTIES else 1

        global_features = np.concatenate(
            [
                np.array(
                    [
                        # Health appears twice: once from the observation and
                        # once from map_state.
                        float(observation.agent_health) / 100.0,
                        float(map_state.agent_health) / 100.0,
                        float(map_state.step_count) / max(1, map_state.max_steps),
                        float(map_state.fire_spread_rate),
                        float(map_state.humidity),
                        float(map_state.agent_x) / max(1, map_state.grid_w - 1),
                        float(map_state.agent_y) / max(1, map_state.grid_h - 1),
                        # NOTE(review): a missing key encodes as the maximum
                        # distance (1.0) but an explicit None encodes as 0.0,
                        # the same value as standing on an exit — confirm
                        # which meaning is intended for "no known distance".
                        float(metadata.get("nearest_exit_distance", MAX_GRID_W + MAX_GRID_H) or 0.0) / float(MAX_GRID_W + MAX_GRID_H),
                        float(metadata.get("reachable_exit_count", 0.0)) / 4.0,
                        float(metadata.get("visible_cell_count", 0.0)) / float(MAX_GRID_W * MAX_GRID_H),
                        float(metadata.get("fire_sources", 0.0)) / 5.0,
                        # Unrecognised smoke levels map to 0.0.
                        {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0),
                        1.0 if map_state.agent_alive else 0.0,
                        1.0 if map_state.agent_evacuated else 0.0,
                    ],
                    dtype=np.float32,
                ),
                _one_hot(wind_index, len(WINDS)),
                _one_hot(difficulty_index, len(DIFFICULTIES)),
            ]
        )

        return np.concatenate([grid_features, global_features]).astype(np.float32)
class AdamOptimizer:
    """Adam optimiser with global-norm gradient clipping for named NumPy parameters.

    First- and second-moment buffers are kept per parameter name and are
    created with the same shape and dtype as the parameters passed to the
    constructor.
    """

    def __init__(self, params: Dict[str, np.ndarray], lr: float = 3e-4, beta1: float = 0.9, beta2: float = 0.999):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = 1e-8
        self.t = 0
        self.m = {}
        self.v = {}
        for key, array in params.items():
            self.m[key] = np.zeros_like(array)
            self.v[key] = np.zeros_like(array)

    def step(self, params: Dict[str, np.ndarray], grads: Dict[str, np.ndarray], clip_norm: float = 1.0) -> None:
        """Apply one clipped Adam update to ``params`` in place.

        The gradients are rescaled together when their joint L2 norm exceeds
        ``clip_norm``; the step counter ``t`` drives the bias correction.
        """
        # Joint L2 norm across every gradient tensor.
        squared_norm = 0.0
        for g in grads.values():
            squared_norm += float(np.sum(g * g))
        joint_norm = math.sqrt(squared_norm)
        factor = clip_norm / (joint_norm + 1e-8) if joint_norm > clip_norm else 1.0

        self.t += 1
        decay_1 = self.beta1
        decay_2 = self.beta2
        correction_1 = 1.0 - decay_1 ** self.t
        correction_2 = 1.0 - decay_2 ** self.t

        for key in params:
            clipped = grads[key] * factor
            first = decay_1 * self.m[key] + (1.0 - decay_1) * clipped
            second = decay_2 * self.v[key] + (1.0 - decay_2) * (clipped * clipped)
            self.m[key] = first
            self.v[key] = second
            first_hat = first / correction_1
            second_hat = second / correction_2
            params[key] -= self.lr * first_hat / (np.sqrt(second_hat) + self.eps)
dtype=np.float32), + } + self.optimizer = AdamOptimizer(self.params) + + @staticmethod + def _init_weight(rng: np.random.Generator, in_dim: int, out_dim: int) -> np.ndarray: + scale = math.sqrt(2.0 / max(1, in_dim + out_dim)) + return (rng.standard_normal((in_dim, out_dim)) * scale).astype(np.float32) + + def forward(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray, Dict[str, np.ndarray]]: + z1 = x @ self.params["w1"] + self.params["b1"] + h1 = np.tanh(z1) + z2 = h1 @ self.params["w2"] + self.params["b2"] + h2 = np.tanh(z2) + logits = h2 @ self.params["wp"] + self.params["bp"] + values = (h2 @ self.params["wv"] + self.params["bv"]).reshape(-1) + cache = {"x": x, "h1": h1, "h2": h2} + return logits, values, cache + + def predict(self, x: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, float]: + logits, values, _ = self.forward(x[None, :]) + probs = softmax_with_mask(logits, mask[None, :])[0] + return probs, float(values[0]) + + def update( + self, + states: np.ndarray, + masks: np.ndarray, + actions: np.ndarray, + returns: np.ndarray, + advantages: np.ndarray, + value_coef: float = 0.5, + ) -> Dict[str, float]: + logits, values, cache = self.forward(states) + probs = softmax_with_mask(logits, masks) + + batch_size = max(1, states.shape[0]) + grad_logits = probs.copy() + grad_logits[np.arange(batch_size), actions] -= 1.0 + grad_logits *= advantages[:, None] / batch_size + grad_logits *= masks + + grad_values = ((values - returns)[:, None] * value_coef) / batch_size + + grads: Dict[str, np.ndarray] = {} + grads["wp"] = cache["h2"].T @ grad_logits + grads["bp"] = np.sum(grad_logits, axis=0) + grads["wv"] = cache["h2"].T @ grad_values + grads["bv"] = np.sum(grad_values, axis=0) + + dh2 = grad_logits @ self.params["wp"].T + grad_values @ self.params["wv"].T + dz2 = dh2 * (1.0 - cache["h2"] ** 2) + grads["w2"] = cache["h1"].T @ dz2 + grads["b2"] = np.sum(dz2, axis=0) + + dh1 = dz2 @ self.params["w2"].T + dz1 = dh1 * (1.0 - cache["h1"] ** 2) + grads["w1"] = 
def compute_gae(
    rewards: Sequence[float],
    values: Sequence[float],
    gamma: float,
    gae_lambda: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute value targets and GAE advantages for one finished episode.

    The sequence is walked backwards; the value after the final step is taken
    to be zero.  Returns ``(returns, advantages)`` as float32 arrays, where
    ``returns`` is ``advantages + values``.
    """
    reward_vec = np.asarray(rewards, dtype=np.float32)
    value_vec = np.asarray(values, dtype=np.float32)
    horizon = len(reward_vec)
    adv = np.zeros(horizon, dtype=np.float32)

    accumulated = 0.0
    following_value = 0.0  # no bootstrap past the last recorded step
    for t in reversed(range(horizon)):
        td_error = reward_vec[t] + gamma * following_value - value_vec[t]
        accumulated = td_error + gamma * gae_lambda * accumulated
        adv[t] = accumulated
        following_value = value_vec[t]

    targets = adv + value_vec
    return targets.astype(np.float32), adv.astype(np.float32)
def run_episode(
    env: PyreEnvironment,
    network: PolicyValueNetwork,
    encoder: ObservationEncoder,
    rng: np.random.Generator,
    difficulty: str,
    history_length: int,
    greedy: bool = False,
) -> Trajectory:
    """Play one episode and record what a policy update needs.

    For every step the stacked state, action mask, chosen action index,
    reward and value estimate are stored.  The loop ends when the environment
    reports ``done`` on the observation returned by ``step``.
    """
    obs = env.reset(difficulty=difficulty)

    # Pad the frame history with zero frames so the first stacked state
    # already has the full length.
    blank = np.zeros(encoder.base_dim, dtype=np.float32)
    history: deque[np.ndarray] = deque([blank.copy() for _ in range(history_length)], maxlen=history_length)
    history.append(encoder.encode(obs))

    rollout = Trajectory(
        states=[],
        masks=[],
        actions=[],
        rewards=[],
        values=[],
        evacuated=False,
        final_health=obs.agent_health,
        steps=0,
        total_reward=0.0,
    )

    finished = False
    while not finished:
        stacked = build_stacked_state(history)
        valid = build_action_mask(obs)
        chosen, estimate = select_action(network, stacked, valid, rng, greedy=greedy)
        env_action = action_index_to_env_action(chosen)

        obs = env.step(env_action)
        step_reward = float(obs.reward or 0.0)

        rollout.states.append(stacked)
        rollout.masks.append(valid)
        rollout.actions.append(chosen)
        rollout.rewards.append(step_reward)
        rollout.values.append(estimate)

        rollout.total_reward += step_reward
        rollout.steps += 1
        rollout.final_health = obs.agent_health
        rollout.evacuated = obs.agent_evacuated

        # The terminal observation is encoded too, before the done check.
        history.append(encoder.encode(obs))
        if obs.done:
            finished = True

    return rollout
def expand_difficulty_schedule(schedule_text: str, episodes: int) -> List[str]:
    """Turn a comma-separated curriculum into one difficulty per episode.

    Blank entries are dropped and names are lower-cased; an empty schedule
    becomes a single ``"medium"`` stage.  Each stage receives an equal share
    of episodes (at least one), and any remainder goes to the last stage.

    Raises:
        ValueError: if a stage name is not in ``DIFFICULTIES``.
    """
    names: List[str] = []
    for chunk in schedule_text.split(","):
        cleaned = chunk.strip()
        if cleaned:
            names.append(cleaned.lower())
    if not names:
        names = ["medium"]

    unknown = next((name for name in names if name not in DIFFICULTIES), None)
    if unknown is not None:
        raise ValueError(f"Invalid difficulty in schedule: {unknown}")

    per_stage = max(1, episodes // len(names))
    plan = [name for name in names for _ in range(per_stage)]

    # Integer division may leave a few episodes unassigned; the final stage
    # absorbs them.
    shortfall = episodes - len(plan)
    if shortfall > 0:
        plan.extend([names[-1]] * shortfall)
    return plan[:episodes]
+ ) + return "\n".join( + [ + "Pyre RL contract", + encoder.describe(history_length), + action_text, + reward_text, + ] + ) + + +def _moving_average(values: Sequence[float], window: int) -> List[float]: + if not values: + return [] + out: List[float] = [] + run = 0.0 + q: deque[float] = deque() + for value in values: + q.append(float(value)) + run += float(value) + if len(q) > window: + run -= q.popleft() + out.append(run / len(q)) + return out + + +def save_metrics_csv(path: Path, rows: List[Dict[str, float | int | str]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not rows: + return + with path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + + +def save_training_graph(path: Path, episode_rows: List[Dict[str, float | int | str]], eval_rows: List[Dict[str, float | int | str]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not episode_rows: + return + + width = 1260 + height = 780 + margin_left = 100 # extra room for rotated Y-axis label + tick values + margin_right = 110 # extra room for right axis label + tick values + margin_top = 70 # room for title + margin_bottom = 90 # room for X-axis label + tick values + legend + plot_w = width - margin_left - margin_right + plot_h = height - margin_top - margin_bottom + + # X: plot_left=100, plot_right=1150 Y: plot_top=70, plot_bottom=690 + + episodes = [int(r["episode"]) for r in episode_rows] + rewards = [float(r["reward"]) for r in episode_rows] + reward_ma = _moving_average(rewards, 20) + success_ma = _moving_average([float(r["evacuated"]) for r in episode_rows], 20) + + all_reward_values = rewards + reward_ma + [float(r["reward_mean"]) for r in eval_rows] + [float(r["reward_max"]) for r in eval_rows] + y_min = min(all_reward_values) if all_reward_values else -1.0 + y_max = max(all_reward_values) if all_reward_values else 1.0 + if abs(y_max - y_min) < 1e-6: + y_min -= 1.0 + 
y_max += 1.0 + y_pad = 0.1 * (y_max - y_min) + y_min -= y_pad + y_max += y_pad + + max_episode = max(episodes) if episodes else 1 + + plot_left = margin_left + plot_right = margin_left + plot_w + plot_top = margin_top + plot_bottom = margin_top + plot_h + + def x_pos(ep: float) -> float: + return plot_left + (float(ep) - 1.0) / max(1.0, max_episode - 1.0) * plot_w + + def y_pos_reward(value: float) -> float: + return plot_top + (y_max - float(value)) / max(1e-6, (y_max - y_min)) * plot_h + + def y_pos_success(value: float) -> float: + return plot_top + (1.0 - float(value)) * plot_h + + def polyline(points: List[tuple[float, float]]) -> str: + return " ".join(f"{x:.1f},{y:.1f}" for x, y in points) + + reward_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, rewards)] + reward_ma_points = [(x_pos(ep), y_pos_reward(val)) for ep, val in zip(episodes, reward_ma)] + success_points = [(x_pos(ep), y_pos_success(val)) for ep, val in zip(episodes, success_ma)] + eval_points = [(x_pos(float(r["episode"])), y_pos_success(float(r["success_rate"]))) for r in eval_rows] + + n_x_ticks = 8 + episode_ticks = sorted(set( + max(1, round(1 + i * (max_episode - 1) / n_x_ticks)) + for i in range(n_x_ticks + 1) + )) + n_y_ticks = 6 + reward_ticks = [y_min + (y_max - y_min) * i / n_y_ticks for i in range(n_y_ticks + 1)] + success_ticks = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + + svg = [] + svg.append(f'") + path.write_text("\n".join(svg), encoding="utf-8") + + +def train(args: argparse.Namespace) -> None: + rng = np.random.default_rng(args.seed) + encoder = ObservationEncoder(mode=args.observation_mode) + difficulty_schedule = expand_difficulty_schedule(args.difficulty_schedule, args.episodes) + input_dim = encoder.base_dim * args.history_length + network = PolicyValueNetwork(input_dim=input_dim, action_dim=ACTION_DIM, rng=rng) + env = PyreEnvironment(max_steps=args.max_steps) + + print(describe_environment_contract(encoder, args.history_length)) + print("") + + batch_states: 
List[np.ndarray] = [] + batch_masks: List[np.ndarray] = [] + batch_actions: List[int] = [] + batch_returns: List[np.ndarray] = [] + batch_advantages: List[np.ndarray] = [] + + reward_window: deque[float] = deque(maxlen=20) + success_window: deque[float] = deque(maxlen=20) + episode_metrics: List[Dict[str, float | int | str]] = [] + eval_metrics_rows: List[Dict[str, float | int | str]] = [] + + for episode_idx in range(args.episodes): + difficulty = difficulty_schedule[episode_idx] if args.difficulty_schedule else args.difficulty + traj = run_episode( + env=env, + network=network, + encoder=encoder, + rng=rng, + difficulty=difficulty, + history_length=args.history_length, + greedy=False, + ) + + returns, advantages = compute_gae(traj.rewards, traj.values, args.gamma, args.gae_lambda) + batch_states.extend(traj.states) + batch_masks.extend(traj.masks) + batch_actions.extend(traj.actions) + batch_returns.append(returns) + batch_advantages.append(advantages) + + reward_window.append(traj.total_reward) + success_window.append(float(traj.evacuated)) + episode_metrics.append( + { + "episode": episode_idx + 1, + "difficulty": difficulty, + "reward": round(traj.total_reward, 4), + "evacuated": int(traj.evacuated), + "steps": traj.steps, + "final_health": round(traj.final_health, 2), + "reward_mean_20": round(float(np.mean(reward_window)), 4), + "success_rate_20": round(float(np.mean(success_window)), 4), + } + ) + + print( + f"episode={episode_idx + 1:04d} difficulty={difficulty:<6} " + f"steps={traj.steps:03d} reward={traj.total_reward:+8.3f} " + f"evacuated={int(traj.evacuated)} health={traj.final_health:6.1f}" + ) + + should_update = (episode_idx + 1) % args.update_every == 0 or (episode_idx + 1) == args.episodes + if should_update and batch_states: + states_arr = np.asarray(batch_states, dtype=np.float32) + masks_arr = np.asarray(batch_masks, dtype=np.float32) + actions_arr = np.asarray(batch_actions, dtype=np.int64) + returns_arr = 
np.concatenate(batch_returns).astype(np.float32) + advantages_arr = np.concatenate(batch_advantages).astype(np.float32) + advantages_arr = (advantages_arr - advantages_arr.mean()) / (advantages_arr.std() + 1e-8) + + network.optimizer.lr = args.learning_rate + metrics = {} + for _ in range(args.update_epochs): + order = rng.permutation(len(states_arr)) + for start in range(0, len(states_arr), args.minibatch_size): + idx = order[start:start + args.minibatch_size] + metrics = network.update( + states=states_arr[idx], + masks=masks_arr[idx], + actions=actions_arr[idx], + returns=returns_arr[idx], + advantages=advantages_arr[idx], + value_coef=args.value_coef, + ) + + print( + f"update episodes={episode_idx + 1:04d} samples={len(states_arr):05d} " + f"reward_mean20={np.mean(reward_window):+8.3f} success20={np.mean(success_window):.2f} " + f"policy_loss={metrics['policy_loss']:+.4f} value_loss={metrics['value_loss']:.4f} " + f"entropy={metrics['entropy']:.4f}" + ) + + batch_states.clear() + batch_masks.clear() + batch_actions.clear() + batch_returns.clear() + batch_advantages.clear() + + should_eval = args.eval_every > 0 and ((episode_idx + 1) % args.eval_every == 0 or (episode_idx + 1) == args.episodes) + if should_eval: + eval_metrics = evaluate_policy( + env=env, + network=network, + encoder=encoder, + rng=rng, + difficulty=args.eval_difficulty, + history_length=args.history_length, + episodes=args.eval_episodes, + ) + print( + f"eval episodes={episode_idx + 1:04d} difficulty={args.eval_difficulty:<6} " + f"reward_mean={eval_metrics['eval_reward_mean']:+8.3f} " + f"reward_max={eval_metrics['eval_reward_max']:+8.3f} " + f"success={eval_metrics['eval_success_rate']:.2f} " + f"steps={eval_metrics['eval_steps_mean']:.1f}" + ) + eval_metrics_rows.append( + { + "episode": episode_idx + 1, + "difficulty": args.eval_difficulty, + "reward_mean": round(eval_metrics["eval_reward_mean"], 4), + "reward_max": round(eval_metrics["eval_reward_max"], 4), + "success_rate": 
round(eval_metrics["eval_success_rate"], 4), + "steps_mean": round(eval_metrics["eval_steps_mean"], 4), + } + ) + + if args.output: + output_path = Path(args.output) + network.save( + output_path, + metadata={ + "observation_mode": args.observation_mode, + "history_length": args.history_length, + "episodes": args.episodes, + "difficulty": args.difficulty, + "difficulty_schedule": args.difficulty_schedule, + "gamma": args.gamma, + "gae_lambda": args.gae_lambda, + "learning_rate": args.learning_rate, + "update_epochs": args.update_epochs, + "minibatch_size": args.minibatch_size, + "action_dim": ACTION_DIM, + "input_dim": input_dim, + }, + ) + print(f"saved model={output_path}") + if args.save_metrics: + metrics_path = output_path.with_suffix(".csv") + save_metrics_csv(metrics_path, episode_metrics) + print(f"saved metrics={metrics_path}") + if args.save_graph: + graph_path = output_path.with_suffix(".svg") + save_training_graph(graph_path, episode_metrics, eval_metrics_rows) + print(f"saved graph={graph_path}") + # Also save PNG + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.ticker as mticker + import matplotlib.patches as mpatches + + episodes_list = [int(r["episode"]) for r in episode_metrics] + rewards_list = [float(r["reward"]) for r in episode_metrics] + evacuated_list = [float(r["evacuated"]) for r in episode_metrics] + diff_list = [str(r["difficulty"]) for r in episode_metrics] + + def _ma(vals, w=20): + out, run, q = [], 0.0, [] + for v in vals: + q.append(v); run += v + if len(q) > w: run -= q.pop(0) + out.append(run / len(q)) + return out + + reward_ma = _ma(rewards_list) + success_ma = _ma(evacuated_list) + eval_eps = [int(r["episode"]) for r in eval_metrics_rows] + eval_succ = [float(r["success_rate"]) for r in eval_metrics_rows] + + diff_colors = {"easy": "#d4edda", "medium": "#fff3cd", "hard": "#f8d7da"} + regions = [] + if diff_list: + cur, start = diff_list[0], episodes_list[0] + for ep, d in 
zip(episodes_list[1:], diff_list[1:]): + if d != cur: + regions.append((start, ep, cur)); cur, start = d, ep + regions.append((start, episodes_list[-1], cur)) + + fig, ax1 = plt.subplots(figsize=(14, 6)) + ax2 = ax1.twinx() + for x0, x1, diff in regions: + ax1.axvspan(x0, x1, color=diff_colors.get(diff, "#eeeeee"), alpha=0.35, zorder=0) + ax1.axhline(0, color="#aaaaaa", linewidth=0.8, linestyle="--", zorder=1) + ax1.plot(episodes_list, rewards_list, color="#d1c7bc", linewidth=0.8, alpha=0.6, label="Episode reward", zorder=2) + ax1.plot(episodes_list, reward_ma, color="#c1661c", linewidth=2.5, label="Reward (MA-20)", zorder=3) + ax2.plot(episodes_list, success_ma, color="#1a7a8a", linewidth=2.5, label="Success rate (MA-20)", zorder=3) + if eval_eps: + ax2.scatter(eval_eps, eval_succ, color="#0d5b6b", s=60, zorder=5, marker="D", edgecolors="white", linewidths=1.2, label="Eval success") + ax1.set_xlabel("Episode", fontsize=13, fontweight="bold", labelpad=8) + ax1.set_ylabel("Reward", fontsize=13, fontweight="bold", color="#c1661c", labelpad=8) + ax2.set_ylabel("Success Rate", fontsize=13, fontweight="bold", color="#1a7a8a", labelpad=8) + ax1.tick_params(axis="y", labelcolor="#c1661c") + ax2.tick_params(axis="y", labelcolor="#1a7a8a") + ax2.set_ylim(-0.05, 1.05) + ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0, decimals=0)) + ax1.grid(True, linestyle="--", linewidth=0.6, color="#dddddd", alpha=0.8) + ax1.set_xlim(episodes_list[0], episodes_list[-1]) + diff_patches = [mpatches.Patch(color=diff_colors[d], alpha=0.6, label=d.capitalize()) + for d in ["easy", "medium", "hard"] if d in diff_list] + h1, l1 = ax1.get_legend_handles_labels() + h2, l2 = ax2.get_legend_handles_labels() + ax1.legend(h1 + h2 + diff_patches, l1 + l2 + [p.get_label() for p in diff_patches], + loc="upper left", fontsize=9, framealpha=0.85) + final_sr = success_ma[-1] if success_ma else 0.0 + fig.suptitle(f"Pyre NumPy A2C Training — {episodes_list[-1]} episodes | final success: 
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the training script.

    ``train`` uses the curriculum whenever ``difficulty_schedule`` is a
    non-empty string and only falls back to ``difficulty`` otherwise.  With a
    non-empty schedule default, ``--difficulty`` on its own was therefore
    silently ignored.  Both options now default to ``None`` during parsing so
    an explicit ``--difficulty`` can be detected:

    * neither option given  -> schedule ``"easy,medium"``, difficulty ``"easy"``
      (the previous defaults);
    * only ``--difficulty``  -> empty schedule, so the fixed difficulty is used;
    * ``--difficulty-schedule`` given -> the schedule wins, as before.

    Returns:
        The parsed namespace with ``difficulty`` and ``difficulty_schedule``
        always set to strings.
    """
    parser = argparse.ArgumentParser(description="Train a NumPy actor-critic baseline for Pyre.")
    parser.add_argument("--episodes", type=int, default=120, help="Training episodes.")
    parser.add_argument(
        "--difficulty",
        type=str,
        default=None,
        choices=DIFFICULTIES,
        help="Fixed training difficulty (default: easy). Ignored when --difficulty-schedule is given.",
    )
    parser.add_argument(
        "--difficulty-schedule",
        type=str,
        default=None,
        help=(
            "Comma-separated curriculum, expanded evenly across episodes "
            "(default: easy,medium unless --difficulty is given)."
        ),
    )
    parser.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES)
    parser.add_argument("--eval-episodes", type=int, default=5)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--update-every", type=int, default=5, help="Episodes per policy update.")
    parser.add_argument("--update-epochs", type=int, default=3, help="Gradient passes over each on-policy batch.")
    parser.add_argument("--minibatch-size", type=int, default=256, help="Samples per gradient step.")
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--learning-rate", type=float, default=3e-4)
    parser.add_argument("--value-coef", type=float, default=0.5)
    parser.add_argument("--history-length", type=int, default=4)
    parser.add_argument("--max-steps", type=int, default=150)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full"))
    parser.add_argument("--output", type=str, default="artifacts/pyre_actor_critic.npz")
    parser.add_argument("--save-metrics", action="store_true", help="Save per-episode metrics as CSV beside the model.")
    parser.add_argument(
        "--save-graph",
        action="store_true",
        help="Save an SVG training graph beside the model (plus a PNG when matplotlib is installed).",
    )
    parser.add_argument("--describe-only", action="store_true", help="Print observation/action/reward definitions and exit.")
    args = parser.parse_args()

    # Resolve the two difficulty options after parsing; an empty schedule is
    # what makes train() fall back to the fixed difficulty.
    if args.difficulty_schedule is None:
        args.difficulty_schedule = "" if args.difficulty is not None else "easy,medium"
    if args.difficulty is None:
        args.difficulty = "easy"
    return args
def _stub_matplotlib():
    """Register placeholder ``matplotlib`` modules in ``sys.modules``.

    Does nothing if ``matplotlib`` is already present in ``sys.modules``.
    Otherwise empty module objects are installed for ``matplotlib`` and the
    submodules listed below, so later ``import matplotlib.<sub>`` statements
    resolve to the placeholders instead of loading the real package.
    """
    if "matplotlib" in sys.modules:
        return
    _mpl = types.ModuleType("matplotlib")
    _mpl.figure = types.ModuleType("matplotlib.figure")
    _mpl.figure.Figure = object
    _mpl.use = lambda *a, **kw: None
    _mpl.__version__ = "0.0.0"
    sys.modules["matplotlib"] = _mpl
    sys.modules["matplotlib.figure"] = _mpl.figure
    for sub in ("matplotlib.pyplot", "matplotlib.ticker", "matplotlib.patches",
                "matplotlib.gridspec", "matplotlib.colors", "matplotlib.cm",
                "matplotlib.backend_bases", "matplotlib.backends",
                "matplotlib.backends.backend_agg"):
        sys.modules[sub] = types.ModuleType(sub)


_stub_matplotlib()

# Extend the import path BEFORE the project imports below; appending it after
# them (as was done previously) could not help those imports resolve.
sys.path.append(os.getcwd())

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from pyre_env.models import PyreAction, PyreObservation
from pyre_env.server.pyre_env_environment import PyreEnvironment
import torch as th


class PyreGymEnv(gym.Env):
    """Gymnasium wrapper for PyreEnvironment.

    Observations are a dict of three float32 arrays:
      * ``grid``   (7, 24, 24): wall, door_open, door_closed, exit, obstacle,
        fire, smoke. Floor is implicit (all structural channels zero).
      * ``global`` (9,): health, health again (oxygen proxy), step progress,
        fire spread rate, humidity, agent x, agent y, nearest-exit distance,
        coughing flag.
      * ``heat``   (1, 3, 3): fire values in the 3x3 neighbourhood of the agent.

    Actions are ``Discrete(41)``: 0-3 move, 4-7 look, 8 wait,
    9-24 open door_1..16, 25-40 close door_1..16.
    """

    # cell_grid code -> grid channel; code 0 (floor) has no channel.
    _CELL_CHANNEL = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}
    # Normaliser for nearest_exit_distance, also the value assumed when missing.
    _MAX_EXIT_DISTANCE = 48.0

    def __init__(self, difficulty="easy", max_steps=150, observation_mode="visible"):
        super().__init__()
        self.env = PyreEnvironment(max_steps=max_steps)
        self.difficulty = difficulty
        self.observation_mode = observation_mode

        # Action space:
        # 0-3: Move (N, S, W, E)
        # 4-7: Look (N, S, W, E)
        # 8: Wait
        # 9-24: Open Door 1-16
        # 25-40: Close Door 1-16
        self.action_space = spaces.Discrete(41)

        # Observation space: Multi-input
        # 1. Grid: 7x24x24 (Wall, Door_Open, Door_Closed, Exit, Obstacle, Fire, Smoke)
        # 2. Global: [health, oxygen, step_progress, fire_spread, humidity,
        #             agent_x, agent_y, nearest_exit_dist, is_coughing]
        # 3. Heat Sensor: 3x3
        self.observation_space = spaces.Dict({
            "grid": spaces.Box(low=0, high=1, shape=(7, 24, 24), dtype=np.float32),
            "global": spaces.Box(low=0, high=1, shape=(9,), dtype=np.float32),
            "heat": spaces.Box(low=0, high=1, shape=(1, 3, 3), dtype=np.float32),
        })

    def _get_obs(self, pyre_obs: PyreObservation):
        """Convert a PyreObservation into the dict described by observation_space."""
        map_state = pyre_obs.map_state
        w, h = map_state.grid_w, map_state.grid_h

        # Build 7-channel grid
        # Channels: 0:Wall, 1:Door_Open, 2:Door_Closed, 3:Exit, 4:Obstacle, 5:Fire, 6:Smoke
        grid = np.zeros((7, 24, 24), dtype=np.float32)

        visible = {(x, y) for x, y in map_state.visible_cells}
        agent_pos = (map_state.agent_x, map_state.agent_y)
        only_visible = self.observation_mode == "visible"
        for y in range(h):
            for x in range(w):
                # In "visible" mode unseen cells stay zero (agent cell is always shown).
                if only_visible and (x, y) not in visible and (x, y) != agent_pos:
                    continue

                i = y * w + x
                channel = self._CELL_CHANNEL.get(map_state.cell_grid[i])
                if channel is not None:
                    grid[channel, y, x] = 1.0

                grid[5, y, x] = float(map_state.fire_grid[i])
                grid[6, y, x] = float(map_state.smoke_grid[i])

        # Global features
        metadata = pyre_obs.metadata or {}
        raw_exit = metadata.get("nearest_exit_distance")
        if raw_exit is None:
            nearest_exit = 1.0
        else:
            # A distance of 0 is a legitimate value and must not be replaced by
            # the "unknown" default; clip to the range declared by the Box space.
            nearest_exit = min(max(float(raw_exit) / self._MAX_EXIT_DISTANCE, 0.0), 1.0)
        # smoke_level -> is_coughing proxy (moderate/heavy smoke = coughing)
        smoke = getattr(pyre_obs, "smoke_level", "none") or "none"
        is_coughing = 1.0 if smoke in ("moderate", "heavy") else 0.0

        global_feats = np.array([
            float(pyre_obs.agent_health) / 100.0,
            float(pyre_obs.agent_health) / 100.0,  # oxygen_level proxy
            # max(1, ...) avoids a ZeroDivisionError when max_steps is 0
            float(map_state.step_count) / float(max(1, map_state.max_steps)),
            float(map_state.fire_spread_rate),
            float(map_state.humidity),
            float(map_state.agent_x) / 24.0,
            float(map_state.agent_y) / 24.0,
            nearest_exit,
            is_coughing,
        ], dtype=np.float32)

        # Heat sensor: 3x3 fire neighbourhood around the agent, 0.0 outside the map
        ax, ay = map_state.agent_x, map_state.agent_y
        heat_vals = []
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                nx, ny = ax + dx, ay + dy
                if 0 <= nx < w and 0 <= ny < h:
                    heat_vals.append(float(map_state.fire_grid[ny * w + nx]))
                else:
                    heat_vals.append(0.0)
        heat = np.array(heat_vals, dtype=np.float32).reshape(1, 3, 3)

        return {
            "grid": grid,
            "global": global_feats,
            "heat": heat,
        }

    def reset(self, seed=None, options=None):
        """Reset the wrapped env; ``options["difficulty"]`` overrides the default."""
        super().reset(seed=seed)
        difficulty = (options or {}).get("difficulty", self.difficulty)
        pyre_obs = self.env.reset(seed=seed, difficulty=difficulty)
        return self._get_obs(pyre_obs), {}

    def step(self, action_idx):
        """Translate a discrete action index into a PyreAction and apply it."""
        action_idx = int(action_idx)
        dirs = ["north", "south", "west", "east"]
        if action_idx < 4:
            action = PyreAction(action="move", direction=dirs[action_idx])
        elif action_idx < 8:
            action = PyreAction(action="look", direction=dirs[action_idx - 4])
        elif action_idx == 8:
            action = PyreAction(action="wait")
        elif action_idx < 9 + 16:
            action = PyreAction(action="door", target_id=f"door_{action_idx - 8}", door_state="open")
        else:
            action = PyreAction(action="door", target_id=f"door_{action_idx - 24}", door_state="close")

        pyre_obs = self.env.step(action)

        obs = self._get_obs(pyre_obs)
        # Treat a missing reward as 0.0 so the return value is always a float.
        reward = float(pyre_obs.reward or 0.0)
        terminated = bool(pyre_obs.done)
        truncated = False  # Step limit handled by env.done

        return obs, reward, terminated, truncated, {"pyre_obs": pyre_obs}


if __name__ == "__main__":
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import CheckpointCallback
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=1500, help="Total episodes to train across all levels")
    parser.add_argument("--difficulty", type=str, default="curriculum", help="easy, medium, hard, random, or curriculum")
    parser.add_argument("--output", type=str, default="artifacts/ppo_pyre_multilevel")
    args = parser.parse_args()

    from gymnasium.wrappers import RecordEpisodeStatistics

    # Custom wrapper to handle difficulty changes
    class MultiLevelWrapper(gym.Wrapper):
        """Chooses the difficulty passed to ``reset`` according to ``mode``.

        ``random`` samples a level per episode, ``curriculum`` moves
        easy -> medium -> hard at 33% / 66% of ``total_training_steps``,
        any other value is used verbatim as the difficulty.
        """

        def __init__(self, env, mode="curriculum", total_training_steps=1):
            super().__init__(env)
            self.mode = mode
            self.current_difficulty = "easy"
            self.total_steps = 0
            self.total_training_steps = max(1, int(total_training_steps))

        def reset(self, **kwargs):
            if self.mode == "random":
                self.current_difficulty = str(np.random.choice(["easy", "medium", "hard"]))
            elif self.mode == "curriculum":
                if self.total_steps < 0.33 * self.total_training_steps:
                    self.current_difficulty = "easy"
                elif self.total_steps < 0.66 * self.total_training_steps:
                    self.current_difficulty = "medium"
                else:
                    self.current_difficulty = "hard"
            else:
                self.current_difficulty = self.mode

            # Copy caller-supplied options so the caller's dict is not mutated.
            options = dict(kwargs.get("options") or {})
            options["difficulty"] = self.current_difficulty
            kwargs["options"] = options

            return self.env.reset(**kwargs)

        def step(self, action):
            obs, reward, term, trunc, info = self.env.step(action)
            self.total_steps += 1
            info["difficulty"] = self.current_difficulty
            return obs, reward, term, trunc, info

    # Must equal the total_timesteps handed to model.learn() below
    # (episodes * 500); with the former episodes * 60 the curriculum reached
    # "hard" after roughly 8% of training.
    total_training_steps = args.episodes * 500

    env = PyreGymEnv(difficulty="easy")  # Base difficulty
    env = MultiLevelWrapper(env, mode=args.difficulty, total_training_steps=total_training_steps)
    env = RecordEpisodeStatistics(env)

    # Increased network capacity for multiple levels.
    # PPO's net_arch dict uses the keys "pi" and "vf" ("qf" belongs to
    # off-policy algorithms), and the kwargs only take effect when passed to PPO.
    policy_kwargs = dict(
        activation_fn=th.nn.ReLU,
        net_arch=dict(pi=[256, 128], vf=[256, 128]),
    )

    model = PPO(
        "MultiInputPolicy",
        env,
        verbose=1,
        learning_rate=2e-4,  # Slightly lower LR for stability across levels
        n_steps=2048,
        batch_size=128,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.02,  # Higher entropy to encourage exploration in procedural maps
        policy_kwargs=policy_kwargs,
    )

    print(f"Starting multi-level training (mode: {args.difficulty})...")

    # Add a simple callback to log episode rewards to a CSV
    from stable_baselines3.common.callbacks import BaseCallback
    import csv
    from pathlib import Path

    class CSVLogCallback(BaseCallback):
        """Collects finished-episode stats and rewrites the CSV after each rollout."""

        def __init__(self, filename):
            super().__init__()
            self.filename = filename
            self.results = []
            # Create the target directory up front so the first rollout does
            # not abort training with FileNotFoundError.
            Path(filename).parent.mkdir(parents=True, exist_ok=True)

        def _on_step(self):
            # Check every step for finished episodes
            for info in self.locals.get("infos", []):
                if "episode" in info:
                    self.results.append({
                        "step": self.num_timesteps,
                        "reward": info["episode"]["r"],
                        "length": info["episode"]["l"],
                    })
            return True

        def _on_rollout_end(self):
            # Save every rollout
            if self.results:
                with open(self.filename, "w", newline="") as f:
                    writer = csv.DictWriter(f, fieldnames=["step", "reward", "length"])
                    writer.writeheader()
                    writer.writerows(self.results)
            return True

    csv_path = args.output + ".csv"
    callback = CSVLogCallback(csv_path)

    # CNN MultiInputPolicy needs far more steps than a flat MLP to warm up.
    # episodes * 50 ≈ 15k steps (too few). Use episodes * 500 for meaningful learning.
+ model.learn(total_timesteps=args.episodes * 500, callback=callback) + + model.save(args.output) + print(f"Model saved to {args.output}") + print(f"Metrics saved to {csv_path}") + + # Generate a quick SVG graph if we have results + if callback.results: + try: + from examples.train_rl_agent import save_training_graph + # Mocking the row format expected by the baseline plotter + rows = [{"episode": i, "reward": r["reward"], "evacuated": 0} for i, r in enumerate(callback.results)] + save_training_graph(Path(args.output + ".svg"), rows, []) + print(f"Graph saved to {args.output}.svg") + except Exception as e: + print(f"Could not generate SVG automatically: {e}") + print("CSV is available at " + csv_path) diff --git a/examples/train_torch_ppo.py b/examples/train_torch_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..d899b3c86a0c717693ac90a847d290a979cb36b4 --- /dev/null +++ b/examples/train_torch_ppo.py @@ -0,0 +1,1278 @@ +""" +PyTorch PPO Agent for Pyre — Fire Evacuation RL Training Script. + +=== ENVIRONMENT SUMMARY === +Pyre is a partial-observability crisis navigation environment: + - Grid: 16×16 (easy/medium) or 20×24 (hard, procedural) + - Agent: Spawns inside a burning building, must evacuate before dying + - Fire: Spreads via cellular automaton — wind, humidity, fuel vary per episode + - Partial observability: visibility radius (2–5 cells) shrinks in heavy smoke + - Doors: Can be opened/closed to slow fire spread (+0.5 strategic door bonus) + - Health: 100 HP, drains from smoke (0.5–5/step) and fire (10/step) + +=== ACTION SPACE (41 discrete) === + 0–3 : move(north|south|west|east) + 4–7 : look(north|south|west|east) — scan without moving, still costs a step + 8 : wait() + 9–24 : door(door_1..16, open) + 25–40 : door(door_1..16, close) + Runtime action masking via `available_actions_hint` prevents invalid moves. 
+ +=== OBSERVATION ENCODING === + Per-step grid: 24×24 padded map × 10 channels + • 6 one-hot cell type (floor/wall/door_open/door_closed/exit/obstacle) + • fire intensity [0, 1] + • smoke density [0, 1] + • visibility mask (1=visible, 0=unseen) + • agent position mask + Global scalars (25): health (observation and map_state copies), step_progress, fire_spread, humidity, + agent_x, agent_y, exit_distance, reachable_exits, visible_cells, + fire_sources, smoke_severity, alive, evacuated, exit compass (dx, dy, manhattan), + wind (one-hot 5), difficulty (one-hot 3) + Frame stacking: 4 consecutive frames → input_dim = 5785 × 4 = 23140 + +=== REWARD STRUCTURE === + Per-step: + -0.01 time penalty (urgency) + +0.10 BFS progress toward nearest unblocked exit + -0.05 regression (moved farther from exit) + +0.05 safe-progress bonus (progress through smoke-free cell) + -0.50 danger penalty (moved into smoke≥moderate or fire-adjacent) + -0.02×dmg health drain penalty + +0.50 strategic door close (adjacent to fire, once per door per episode) + +0.02 exploration bonus (first visit to cell) + Terminal: + +5.00 evacuation success + +1.50×(hp/100) health survival bonus (max +1.5) + -10.0 death + -5.00 timeout + 0→+3.0 near-miss partial credit (based on closest exit approach) + +0.05×remaining_steps time bonus + +=== ALGORITHM: PPO (Proximal Policy Optimization) === +WHY PPO over alternatives: + • DQN — Off-policy, harder credit assignment for sparse terminal rewards; no clean action masking + • A2C — Simpler but no clipping → unstable on hard stochastic episodes + • SAC — Designed for continuous spaces; discrete SAC works but adds complexity + • LSTM-PPO — Better for fully text-only obs; grid map_state already encodes spatial state + → PPO + frame-stack + action-mask hits the sweet spot for this env + +Key PPO improvements over the existing NumPy A2C (train_rl_agent.py): + ✓ PPO clip (ε=0.2) prevents catastrophic updates + ✓ Entropy regularization sustains exploration in smoke-obscured corridors + ✓ Value function clipping stabilises critic under 
sparse terminal rewards + ✓ GPU acceleration 10–20× faster than NumPy baseline + ✓ LayerNorm in network improves gradient flow for large input dims + ✓ Linear LR decay stabilises late-stage convergence + ✓ Better curriculum 3-stage easy→medium→hard with patience gating + +Usage: + python examples/train_torch_ppo.py --episodes 500 --device cuda + python examples/train_torch_ppo.py --episodes 300 --difficulty-schedule easy,medium,hard + python examples/train_torch_ppo.py --resume artifacts/pyre_ppo_checkpoint.pt + python examples/train_torch_ppo.py --describe-only +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import sys +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np + +# --------------------------------------------------------------------------- +# Optional torch import — fail fast with a helpful message +# --------------------------------------------------------------------------- +try: + import torch + import torch.nn as nn + import torch.nn.functional as F + from torch.optim import Adam + from torch.optim.lr_scheduler import LinearLR +except ImportError: + sys.exit( + "PyTorch not found. 
Install with:\n" + " pip install torch --index-url https://download.pytorch.org/whl/cu121\n" + "or for CPU only:\n" + " pip install torch" + ) + +# --------------------------------------------------------------------------- +# Project imports — support both package install and direct run from root +# --------------------------------------------------------------------------- +_ROOT = Path(__file__).resolve().parent.parent +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) + +try: + from pyre_env.models import PyreAction, PyreObservation + from pyre_env.server.pyre_env_environment import PyreEnvironment +except ModuleNotFoundError: + try: + from models import PyreAction, PyreObservation + from server.pyre_env_environment import PyreEnvironment + except ModuleNotFoundError: + sys.exit( + "Cannot import Pyre modules. Run this script from the openenv-pyre root:\n" + " python examples/train_torch_ppo.py" + ) + +# --------------------------------------------------------------------------- +# Reuse the established observation/action interface from train_rl_agent.py +# These are the canonical definitions for this environment. 
# ---------------------------------------------------------------------------
MAX_GRID_W = 24
MAX_GRID_H = 24
MAX_DOORS = 16
DIRECTIONS = ("north", "south", "west", "east")
WINDS = ("CALM", "NORTH", "SOUTH", "WEST", "EAST")
DIFFICULTIES = ("easy", "medium", "hard")

# Canonical hint strings, in action-index order (see ACTION_KEYS).
MOVE_KEYS = [f"move(direction='{d}')" for d in DIRECTIONS]
LOOK_KEYS = [f"look(direction='{d}')" for d in DIRECTIONS]
WAIT_KEY = "wait()"
OPEN_KEYS = [f"door(target_id='door_{i}', door_state='open')" for i in range(1, MAX_DOORS + 1)]
CLOSE_KEYS = [f"door(target_id='door_{i}', door_state='close')" for i in range(1, MAX_DOORS + 1)]
ACTION_KEYS = MOVE_KEYS + LOOK_KEYS + [WAIT_KEY] + OPEN_KEYS + CLOSE_KEYS
ACTION_DIM = len(ACTION_KEYS)  # 41
ACTION_TO_INDEX = {key: idx for idx, key in enumerate(ACTION_KEYS)}

import re
# Kept as module-level names; build_action_mask no longer needs them because
# every string these patterns can validly map is already a key of ACTION_TO_INDEX.
_MOVE_RE = re.compile(r"move\(direction='(north|south|west|east)'\)")
_LOOK_RE = re.compile(r"look\(direction='(north|south|west|east)'\)")
_DOOR_RE = re.compile(r"door\(target_id='(door_(\d+))', door_state='(open|close)'\)")


def action_index_to_env_action(index: int) -> "PyreAction":
    """Translate a discrete action index (0..ACTION_DIM-1) into a PyreAction.

    Layout: 0-3 move, 4-7 look, 8 wait, 9-24 open door_1..16,
    25-40 close door_1..16.

    Raises:
        ValueError: if ``index`` is outside ``[0, ACTION_DIM)``. Previously such
            indices silently produced ids for doors that do not exist.
    """
    if not 0 <= index < ACTION_DIM:
        raise ValueError(f"action index must be in [0, {ACTION_DIM}), got {index}")
    if index < 4:
        return PyreAction(action="move", direction=DIRECTIONS[index])
    if index < 8:
        return PyreAction(action="look", direction=DIRECTIONS[index - 4])
    if index == 8:
        return PyreAction(action="wait")
    # Remaining indices: two consecutive banks of MAX_DOORS (open, then close).
    bank, slot = divmod(index - 9, MAX_DOORS)
    door_state = "open" if bank == 0 else "close"
    return PyreAction(action="door", target_id=f"door_{slot + 1}", door_state=door_state)


def build_action_mask(observation: "PyreObservation", exclude_look: bool = True) -> np.ndarray:
    """Build a binary validity mask over the 41-action space.

    Each entry of ``observation.available_actions_hint`` that equals one of
    ACTION_KEYS enables the matching index; any other hint is ignored.

    exclude_look=True (default for RL):
        Suppresses all 4 'look' actions so the policy concentrates on moves
        and doors. (Previously the exact-key lookup enabled look actions
        before this flag was ever consulted, so the flag had no effect.)

    If nothing ends up enabled, ``wait()`` is enabled so the mask is never
    all-zero.

    Returns:
        float32 array of shape (ACTION_DIM,), 1.0 = allowed, 0.0 = masked.
    """
    look_start = len(MOVE_KEYS)
    look_stop = look_start + len(LOOK_KEYS)

    mask = np.zeros(ACTION_DIM, dtype=np.float32)
    for hint in observation.available_actions_hint:
        idx = ACTION_TO_INDEX.get(hint)
        if idx is None:
            # Unknown or non-canonical hint (e.g. a door id outside 1..MAX_DOORS).
            continue
        if exclude_look and look_start <= idx < look_stop:
            continue
        mask[idx] = 1.0
    if mask.sum() == 0:
        mask[ACTION_TO_INDEX[WAIT_KEY]] = 1.0
    return mask


class ObservationEncoder:
    """Encode PyreObservation into a fixed-length float32 vector.

    Mode 'visible': only populate cells listed in ``visible_cells`` (plus the
                    agent's own cell); preferred for training.
    Mode 'full':    populate every cell of the grid; useful for debugging
                    or oracle upper-bound experiments.

    Output shape: (base_dim,) = (MAX_GRID_W × MAX_GRID_H × 10 + 25,) = (5785,)
    Layout: 6 one-hot cell-type planes, fire, smoke, visibility and agent
    planes (each MAX_GRID_H × MAX_GRID_W, flattened), then 17 scalars,
    5 wind one-hot values and 3 difficulty one-hot values.

    The last 3 scalars are exit-compass features (exit_dx_norm, exit_dy_norm,
    exit_manhattan_norm) computed from the whole cell grid regardless of mode.
    """

    base_dim = MAX_GRID_W * MAX_GRID_H * 10 + 25

    def __init__(self, mode: str = "visible"):
        if mode not in {"visible", "full"}:
            raise ValueError(f"mode must be 'visible' or 'full', got '{mode}'")
        self.mode = mode

    def encode(self, observation: "PyreObservation") -> np.ndarray:
        """Return the (base_dim,) float32 encoding of ``observation``.

        Raises:
            ValueError: if ``observation.map_state`` is None.
        """
        ms = observation.map_state
        if ms is None:
            raise ValueError("map_state is required for encoding.")

        cell_one_hot = np.zeros((MAX_GRID_H, MAX_GRID_W, 6), dtype=np.float32)
        fire_ch = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        smoke_ch = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        vis_ch = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)
        agent_ch = np.zeros((MAX_GRID_H, MAX_GRID_W), dtype=np.float32)

        visible = {(x, y) for x, y in ms.visible_cells}
        agent_pos = (ms.agent_x, ms.agent_y)
        only_visible = self.mode == "visible"
        for y in range(ms.grid_h):
            for x in range(ms.grid_w):
                if only_visible and (x, y) not in visible and (x, y) != agent_pos:
                    continue
                i = y * ms.grid_w + x
                ct = int(ms.cell_grid[i])
                if 0 <= ct <= 5:
                    cell_one_hot[y, x, ct] = 1.0
                fire_ch[y, x] = float(ms.fire_grid[i])
                smoke_ch[y, x] = float(ms.smoke_grid[i])
                vis_ch[y, x] = 1.0 if (x, y) in visible else 0.0

        if 0 <= ms.agent_x < MAX_GRID_W and 0 <= ms.agent_y < MAX_GRID_H:
            agent_ch[ms.agent_y, ms.agent_x] = 1.0

        grid_features = np.concatenate([
            cell_one_hot.reshape(-1),
            fire_ch.reshape(-1),
            smoke_ch.reshape(-1),
            vis_ch.reshape(-1),
            agent_ch.reshape(-1),
        ])

        meta = observation.metadata or {}
        wind = str(meta.get("wind_dir", ms.wind_dir or "CALM")).upper()
        diff = str(meta.get("difficulty", "medium")).lower()
        # Unknown values fall back to CALM / medium.
        wi = WINDS.index(wind) if wind in WINDS else 0
        di = DIFFICULTIES.index(diff) if diff in DIFFICULTIES else 1

        wind_oh = np.zeros(len(WINDS), dtype=np.float32)
        wind_oh[wi] = 1.0
        diff_oh = np.zeros(len(DIFFICULTIES), dtype=np.float32)
        diff_oh[di] = 1.0

        # Fix 3 — map-agnostic exit compass features.
        # Direction vector and normalised Manhattan distance to the nearest
        # exit cell (cell_type == 4), scanned in row-major order so the first
        # exit found wins ties. With no exit on the map the compass stays at
        # (0, 0) and the distance at its maximum of 1.0.
        EXIT_CELL_TYPE = 4
        ax, ay = ms.agent_x, ms.agent_y
        gw, gh = ms.grid_w, ms.grid_h
        best_dist = float(gw + gh)
        best_dx = 0.0
        best_dy = 0.0
        for cy in range(gh):
            for cx in range(gw):
                if int(ms.cell_grid[cy * gw + cx]) == EXIT_CELL_TYPE:
                    d = abs(cx - ax) + abs(cy - ay)
                    if d < best_dist:
                        best_dist = d
                        best_dx = float(cx - ax) / max(1, gw - 1)
                        best_dy = float(cy - ay) / max(1, gh - 1)
        exit_manhattan_norm = best_dist / float(gw + gh)

        global_features = np.array([
            float(observation.agent_health) / 100.0,
            float(ms.agent_health) / 100.0,
            float(ms.step_count) / max(1, ms.max_steps),
            float(ms.fire_spread_rate),
            float(ms.humidity),
            float(ms.agent_x) / max(1, ms.grid_w - 1),
            float(ms.agent_y) / max(1, ms.grid_h - 1),
            float(meta.get("nearest_exit_distance", MAX_GRID_W + MAX_GRID_H) or 0.0) / float(MAX_GRID_W + MAX_GRID_H),
            float(meta.get("reachable_exit_count", 0.0)) / 4.0,
            float(meta.get("visible_cell_count", 0.0)) / float(MAX_GRID_W * MAX_GRID_H),
            float(meta.get("fire_sources", 0.0)) / 5.0,
            {"none": 0.0, "light": 0.33, "moderate": 0.66, "heavy": 1.0}.get(observation.smoke_level, 0.0),
            1.0 if ms.agent_alive else 0.0,
            1.0 if ms.agent_evacuated else 0.0,
            # Fix 3: exit-compass (3 scalars — map-agnostic, layout-independent)
            best_dx,               # signed x-direction toward nearest exit
            best_dy,               # signed y-direction toward nearest exit
            exit_manhattan_norm,   # how far away the exit is (0 = here, 1 = max)
        ], dtype=np.float32)

        return np.concatenate([grid_features, global_features, wind_oh, diff_oh]).astype(np.float32)


# ---------------------------------------------------------------------------
# Neural Network
#
# ---------------------------------------------------------------------------

class ActorCritic(nn.Module):
    """Actor-critic network with one shared trunk and two linear heads.

    Trunk layout (``hidden_sizes`` must contain exactly three widths):
        LayerNorm(input) → Linear → LayerNorm → ReLU
                         → Linear → LayerNorm → ReLU
                         → Linear → ReLU
    The policy head maps trunk features to ``action_dim`` logits and the
    value head maps them to a single scalar.

    The input is expected to be the flat, frame-stacked observation vector.
    """

    def __init__(self, input_dim: int, action_dim: int, hidden_sizes: Tuple[int, ...] = (512, 256, 128)):
        super().__init__()
        width1, width2, width3 = hidden_sizes

        trunk = [nn.LayerNorm(input_dim)]
        trunk += [nn.Linear(input_dim, width1), nn.LayerNorm(width1), nn.ReLU()]
        trunk += [nn.Linear(width1, width2), nn.LayerNorm(width2), nn.ReLU()]
        trunk += [nn.Linear(width2, width3), nn.ReLU()]
        self.shared = nn.Sequential(*trunk)

        # Trunk linears: orthogonal weights (gain sqrt(2)), zero biases.
        self._init_orthogonal()

        self.policy_head = nn.Linear(width3, action_dim)
        self.value_head = nn.Linear(width3, 1)

        # Heads: a tiny gain on the policy keeps the initial logits close to
        # uniform; the value head uses unit gain. Biases start at zero.
        for head, gain in ((self.policy_head, 0.01), (self.value_head, 1.0)):
            nn.init.orthogonal_(head.weight, gain=gain)
            nn.init.zeros_(head.bias)

    def _init_orthogonal(self) -> None:
        """Orthogonally initialise every Linear layer of the shared trunk."""
        for linear in (m for m in self.shared if isinstance(m, nn.Linear)):
            nn.init.orthogonal_(linear.weight, gain=np.sqrt(2))
            nn.init.zeros_(linear.bias)

    def forward(
        self,
        obs: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.distributions.Categorical, torch.Tensor]:
        """Compute the masked action distribution and state values.

        Args:
            obs:  (B, input_dim) float32
            mask: (B, action_dim) float32 — 1.0 = valid, 0.0 = invalid
        Returns:
            dist:   Categorical over actions; invalid actions get a logit of -1e9
            values: (B,) float32
        """
        hidden = self.shared(obs)
        values = self.value_head(hidden).squeeze(-1)

        # Overwrite the logits of invalid actions with a large negative number
        # so their probability is effectively zero after the softmax.
        invalid = ~mask.bool()
        logits = self.policy_head(hidden).masked_fill(invalid, -1e9)

        return torch.distributions.Categorical(logits=logits), values

    def act(
        self,
        obs: torch.Tensor,
        mask: torch.Tensor,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Choose an action: the distribution's mode when ``deterministic``,
        otherwise a sample. Returns (action, log_prob, value)."""
        dist, values = self(obs, mask)
        if deterministic:
            action = dist.mode
        else:
            action = dist.sample()
        return action, dist.log_prob(action), values

    def evaluate(
        self,
        obs: torch.Tensor,
        mask: torch.Tensor,
        action: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score previously taken actions under the current parameters.
        Returns (log_prob, value, entropy)."""
        dist, values = self(obs, mask)
        return dist.log_prob(action), values, dist.entropy()


# ---------------------------------------------------------------------------
# Rollout buffer
# ---------------------------------------------------------------------------

@dataclass
class RolloutBuffer:
    """Parallel per-transition lists filled during rollouts and consumed by
    the PPO update; index i of every list refers to the same transition."""
    obs: List[np.ndarray] = field(default_factory=list)
    masks: List[np.ndarray] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)
    log_probs: List[float] = field(default_factory=list)
    values: List[float] = field(default_factory=list)
    dones: List[bool] = field(default_factory=list)

    def clear(self) -> None:
        """Empty every list in place."""
        for store in (self.obs, self.masks, self.actions, self.rewards,
                      self.log_probs, self.values, self.dones):
            store.clear()

    def __len__(self) -> int:
        """Number of stored transitions (length of the rewards list)."""
        return len(self.rewards)


# ---------------------------------------------------------------------------
# GAE computation
#
--------------------------------------------------------------------------- + +def compute_gae( + rewards: np.ndarray, + values: np.ndarray, + dones: np.ndarray, + gamma: float, + gae_lambda: float, +) -> Tuple[np.ndarray, np.ndarray]: + """Generalized Advantage Estimation. + + Returns (returns, advantages) — both shape (T,). + Episode boundaries (done=True) reset the GAE accumulator so advantages + don't bleed across episodes within a mixed batch. + """ + T = len(rewards) + advantages = np.zeros(T, dtype=np.float32) + gae = 0.0 + next_value = 0.0 + for t in reversed(range(T)): + if dones[t]: + next_value = 0.0 + gae = 0.0 + delta = rewards[t] + gamma * next_value * (1.0 - dones[t]) - values[t] + gae = delta + gamma * gae_lambda * (1.0 - dones[t]) * gae + advantages[t] = gae + next_value = values[t] + returns = advantages + values + return returns, advantages + + +# --------------------------------------------------------------------------- +# Episode runner +# --------------------------------------------------------------------------- + +@dataclass +class EpisodeResult: + total_reward: float + steps: int + evacuated: bool + final_health: float + difficulty: str + + +def run_episode( + env: PyreEnvironment, + network: ActorCritic, + encoder: ObservationEncoder, + device: torch.device, + difficulty: str, + history_length: int, + buffer: RolloutBuffer, + deterministic: bool = False, +) -> EpisodeResult: + """Run one episode, appending transitions to *buffer*.""" + observation = env.reset(difficulty=difficulty) + zero_frame = np.zeros(encoder.base_dim, dtype=np.float32) + frames: deque = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length) + frames.append(encoder.encode(observation)) + + total_reward = 0.0 + final_health = observation.agent_health + evacuated = False + steps = 0 + # Anti-loop tracking: remember the last LOOP_WINDOW positions this episode. + # Revisiting any of them means the agent is circling, not exploring. 
+ LOOP_WINDOW = 12 + recent_positions: deque = deque(maxlen=LOOP_WINDOW) + + network.eval() + with torch.no_grad(): + while True: + state_vec = np.concatenate(list(frames), dtype=np.float32) + # exclude_look=True: RL agent sees full grid — look wastes steps + action_mask = build_action_mask(observation, exclude_look=True) + + obs_t = torch.tensor(state_vec, dtype=torch.float32, device=device).unsqueeze(0) + mask_t = torch.tensor(action_mask, dtype=torch.float32, device=device).unsqueeze(0) + + action_t, log_prob_t, value_t = network.act(obs_t, mask_t, deterministic=deterministic) + + action_idx = int(action_t.item()) + env_action = action_index_to_env_action(action_idx) + next_obs = env.step(env_action) + + reward = float(next_obs.reward or 0.0) + + # ---------------------------------------------------------------- + # Reward shaping 1 — idle penalty + # The env's -0.01/step is too weak; make waiting explicitly costly. + # ---------------------------------------------------------------- + chosen_action = env_action.action + if chosen_action == "wait": + reward -= 0.05 + + # ---------------------------------------------------------------- + # Reward shaping 2 — fire-approach penalty (Fix 2) + # Penalise landing on (or moving next to) a cell with active fire. + # This is stronger than the env's DangerPenalty and fires *before* + # health drain accumulates, teaching the agent to predict spread. + # We look at the NEW observation's map to catch the current step. 
def ppo_update(
    network: ActorCritic,
    optimizer: Adam,
    buffer: RolloutBuffer,
    device: torch.device,
    clip_eps: float,
    value_clip_eps: float,
    entropy_coef: float,
    value_coef: float,
    n_epochs: int,
    minibatch_size: int,
    gamma: float,
    gae_lambda: float,
    max_grad_norm: float,
) -> Dict[str, float]:
    """Run a full PPO update over the collected rollout buffer.

    Args:
        network: actor-critic model; must expose ``evaluate(obs, mask, action)``
            returning ``(log_prob, value, entropy)``.
        optimizer: optimizer stepping ``network``'s parameters.
        buffer: rollout storage providing ``obs``, ``masks``, ``actions``,
            ``rewards``, ``log_probs``, ``values``, ``dones`` and ``len()``.
        device: device on which the minibatch tensors are created.
        clip_eps: clip range for the policy probability ratio.
        value_clip_eps: clip range for the value prediction around the
            value recorded at rollout time.
        entropy_coef: weight of the entropy bonus.
        value_coef: weight of the value loss.
        n_epochs: number of passes over the buffer.
        minibatch_size: samples per gradient step.
        gamma: discount factor forwarded to ``compute_gae``.
        gae_lambda: GAE lambda forwarded to ``compute_gae``.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        Mean ``policy_loss``, ``value_loss``, ``entropy``, ``approx_kl`` and
        ``clip_frac`` over all gradient steps (all zeros when no step ran).
    """
    metrics = {"policy_loss": 0.0, "value_loss": 0.0, "entropy": 0.0, "approx_kl": 0.0, "clip_frac": 0.0}

    T = len(buffer)
    if T == 0:
        # Nothing collected: np.stack([]) below would raise, so report zeros.
        return metrics

    rewards = np.array(buffer.rewards, dtype=np.float32)
    values = np.array(buffer.values, dtype=np.float32)
    dones = np.array(buffer.dones, dtype=np.float32)

    returns, advantages = compute_gae(rewards, values, dones, gamma, gae_lambda)

    # Normalize advantages across the whole batch (reduces variance)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    obs_arr = torch.tensor(np.stack(buffer.obs), dtype=torch.float32, device=device)
    mask_arr = torch.tensor(np.stack(buffer.masks), dtype=torch.float32, device=device)
    action_arr = torch.tensor(buffer.actions, dtype=torch.long, device=device)
    old_logp_arr = torch.tensor(buffer.log_probs, dtype=torch.float32, device=device)
    return_arr = torch.tensor(returns, dtype=torch.float32, device=device)
    adv_arr = torch.tensor(advantages, dtype=torch.float32, device=device)
    old_value_arr = torch.tensor(values, dtype=torch.float32, device=device)

    n_updates = 0

    network.train()
    for _ in range(n_epochs):
        perm = torch.randperm(T, device=device)
        for start in range(0, T, minibatch_size):
            idx = perm[start:start + minibatch_size]
            if len(idx) < 2:
                # Skip degenerate single-sample minibatches.
                continue

            log_prob, value, entropy = network.evaluate(obs_arr[idx], mask_arr[idx], action_arr[idx])

            # PPO ratio and clipped surrogate loss
            old_logp_mb = old_logp_arr[idx]
            ratio = torch.exp(log_prob - old_logp_mb)
            adv_mb = adv_arr[idx]
            surr1 = ratio * adv_mb
            surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_mb
            policy_loss = -torch.min(surr1, surr2).mean()

            # Clipped value loss: take the pessimistic (larger) error PER SAMPLE
            # and only then average. Comparing two already-averaged MSE scalars
            # would choose clipped vs. unclipped for the whole minibatch at once.
            ret_mb = return_arr[idx]
            old_val_mb = old_value_arr[idx]
            value_pred_clipped = old_val_mb + torch.clamp(value - old_val_mb, -value_clip_eps, value_clip_eps)
            value_loss = torch.max(
                F.mse_loss(value, ret_mb, reduction="none"),
                F.mse_loss(value_pred_clipped, ret_mb, reduction="none"),
            ).mean()

            entropy_loss = -entropy.mean()

            loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(network.parameters(), max_grad_norm)
            optimizer.step()

            with torch.no_grad():
                approx_kl = ((ratio - 1) - (log_prob - old_logp_mb)).mean().item()
                clip_frac = ((ratio - 1.0).abs() > clip_eps).float().mean().item()

            metrics["policy_loss"] += policy_loss.item()
            metrics["value_loss"] += value_loss.item()
            metrics["entropy"] += entropy.mean().item()
            metrics["approx_kl"] += approx_kl
            metrics["clip_frac"] += clip_frac
            n_updates += 1

    if n_updates > 0:
        for k in metrics:
            metrics[k] /= n_updates
    return metrics
} + + +# --------------------------------------------------------------------------- +# PNG graph (matplotlib) +# --------------------------------------------------------------------------- + +def save_training_graph_png( + path: Path, + episode_rows: List[Dict], + eval_rows: List[Dict], + window: int = 20, +) -> None: + """Save a publication-quality PNG training graph with dual Y-axes.""" + try: + import matplotlib + matplotlib.use("Agg") # non-interactive backend — no display needed + import matplotlib.pyplot as plt + import matplotlib.ticker as mticker + except ImportError: + print("[warn] matplotlib not installed — skipping PNG graph. Run: uv pip install matplotlib") + return + + if not episode_rows: + return + + path.parent.mkdir(parents=True, exist_ok=True) + + episodes = [int(r["episode"]) for r in episode_rows] + rewards = [float(r["reward"]) for r in episode_rows] + evacuated = [float(r["evacuated"]) for r in episode_rows] + difficulty = [str(r["difficulty"]) for r in episode_rows] + + # Moving average helper + def ma(values: list, w: int) -> list: + out, run, q = [], 0.0, [] + for v in values: + q.append(v); run += v + if len(q) > w: run -= q.pop(0) + out.append(run / len(q)) + return out + + reward_ma = ma(rewards, window) + success_ma = ma(evacuated, window) + + eval_eps = [int(r["episode"]) for r in eval_rows] + eval_succ = [float(r["success_rate"]) for r in eval_rows] + + # Difficulty shading regions + diff_colors = {"easy": "#d4edda", "medium": "#fff3cd", "hard": "#f8d7da"} + regions: List[tuple] = [] + if difficulty: + cur, start = difficulty[0], episodes[0] + for ep, d in zip(episodes[1:], difficulty[1:]): + if d != cur: + regions.append((start, ep, cur)) + cur, start = d, ep + regions.append((start, episodes[-1], cur)) + + fig, ax1 = plt.subplots(figsize=(14, 6)) + ax2 = ax1.twinx() + + # Shade difficulty regions + for x0, x1, diff in regions: + ax1.axvspan(x0, x1, color=diff_colors.get(diff, "#eeeeee"), alpha=0.35, zorder=0) + + # Zero line + 
ax1.axhline(0, color="#aaaaaa", linewidth=0.8, linestyle="--", zorder=1) + + # Raw reward (faint) + ax1.plot(episodes, rewards, color="#d1c7bc", linewidth=0.8, + alpha=0.6, label="Episode reward", zorder=2) + + # Reward moving average + ax1.plot(episodes, reward_ma, color="#c1661c", linewidth=2.5, + label=f"Reward (MA-{window})", zorder=3) + + # Success moving average (right axis) + ax2.plot(episodes, success_ma, color="#1a7a8a", linewidth=2.5, + linestyle="-", label=f"Success rate (MA-{window})", zorder=3) + + # Eval checkpoints + if eval_eps: + ax2.scatter(eval_eps, eval_succ, color="#0d5b6b", s=60, zorder=5, + marker="D", label="Eval success", edgecolors="white", linewidths=1.2) + + # Axes labels & formatting + ax1.set_xlabel("Episode", fontsize=13, fontweight="bold", labelpad=8) + ax1.set_ylabel("Reward", fontsize=13, fontweight="bold", color="#c1661c", labelpad=8) + ax2.set_ylabel("Success Rate", fontsize=13, fontweight="bold", color="#1a7a8a", labelpad=8) + + ax1.tick_params(axis="y", labelcolor="#c1661c") + ax2.tick_params(axis="y", labelcolor="#1a7a8a") + ax2.set_ylim(-0.05, 1.05) + ax2.yaxis.set_major_formatter(mticker.PercentFormatter(xmax=1.0, decimals=0)) + + ax1.grid(True, which="major", linestyle="--", linewidth=0.6, + color="#dddddd", alpha=0.8, zorder=0) + ax1.set_xlim(episodes[0], episodes[-1]) + + ax1.tick_params(axis="x", labelsize=10) + ax1.tick_params(axis="y", labelsize=10) + ax2.tick_params(axis="y", labelsize=10) + + # Title + total_eps = episodes[-1] + final_sr = success_ma[-1] if success_ma else 0.0 + fig.suptitle( + f"Pyre PPO Training — {total_eps} episodes | final success rate: {final_sr:.0%}", + fontsize=14, fontweight="bold", y=1.01, + ) + + # Difficulty legend patches + import matplotlib.patches as mpatches + diff_patches = [ + mpatches.Patch(color=diff_colors[d], alpha=0.6, label=d.capitalize()) + for d in ["easy", "medium", "hard"] if any(r == d for r in difficulty) + ] + + # Combine legends from both axes + h1, l1 = 
class PatienceCurriculum:
    """Dynamic difficulty scheduler that gates advancement on sustained success rate.

    Stays on the current difficulty until the reported success rate is
    >= ``threshold`` for ``patience_window`` consecutive ``step()`` calls,
    then advances to the next stage. While the current stage is ``"hard"``
    and a previous stage exists, a ``mix_ratio`` fraction of episodes is
    replayed on that previous stage to prevent catastrophic forgetting.

    Args:
        stages: ordered, non-empty list of difficulty strings,
            e.g. ['easy','medium','hard']
        threshold: minimum success rate (0–1) required before advancing
        patience_window: number of consecutive episodes that must meet threshold
        mix_ratio: fraction of hard-phase episodes to run on the previous
            stage instead (0–1)

    Raises:
        ValueError: if ``stages`` is empty.
    """

    def __init__(
        self,
        stages: List[str],
        threshold: float,
        patience_window: int,
        mix_ratio: float = 0.0,
    ) -> None:
        if not stages:
            # Fail here with a clear message instead of an IndexError from
            # `current` on the first episode.
            raise ValueError("PatienceCurriculum requires at least one stage.")
        self.stages = stages
        self.threshold = threshold
        self.patience_window = patience_window
        self.mix_ratio = mix_ratio
        self.stage_idx = 0
        self._streak = 0

    @property
    def current(self) -> str:
        """Difficulty of the stage the curriculum is currently on."""
        return self.stages[self.stage_idx]

    def step(self, success_rate_30: float) -> str:
        """Call once per episode *after* appending to success_window.

        Returns the difficulty to use for the *next* episode.
        Also handles the hard-phase replay of the previous stage.
        """
        if self.stage_idx < len(self.stages) - 1:
            if success_rate_30 >= self.threshold:
                self._streak += 1
            else:
                self._streak = 0
            if self._streak >= self.patience_window:
                self.stage_idx += 1
                self._streak = 0
                print(
                    f" [curriculum] Advanced to '{self.current}' "
                    f"(success_rate_30={success_rate_30:.2f} >= {self.threshold} "
                    f"for {self.patience_window} eps)"
                )

        # Hard-phase mix: occasionally replay the previous stage to prevent
        # forgetting. Require stage_idx > 0 so that `stage_idx - 1` can never
        # become -1 and silently wrap around to the LAST stage.
        if self.current == "hard" and self.mix_ratio > 0.0 and self.stage_idx > 0:
            if np.random.rand() < self.mix_ratio:
                return self.stages[self.stage_idx - 1]  # replay episode
        return self.current
def train(args: argparse.Namespace) -> None:
    """Train the PPO agent according to the parsed CLI ``args``.

    Builds the encoder, network, optimizer, optional LR scheduler and the
    environment, then runs episodes under either a patience-gated or a
    static curriculum. A PPO update runs every ``args.update_every``
    episodes (and after the last one); evaluation and checkpoints run on
    their own periods. At the end the model, metrics CSVs and the PNG
    graph are written next to ``args.output``.
    """
    device = torch.device(args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu")
    if args.device == "cuda" and not torch.cuda.is_available():
        print("[warn] CUDA not available - falling back to CPU.")

    print(f"[config] device={device} episodes={args.episodes} batch={args.update_every} eps "
          f"hidden={args.hidden_sizes} frames={args.history_length}")
    print(f"[config] curriculum: {args.difficulty_schedule}")
    print(f"[config] PPO clip_eps={args.clip_eps} entropy={args.entropy_coef} lr={args.learning_rate}\n")

    encoder = ObservationEncoder(mode=args.observation_mode)
    input_dim = encoder.base_dim * args.history_length

    hidden_sizes = tuple(int(h) for h in args.hidden_sizes.split(","))
    network = ActorCritic(input_dim=input_dim, action_dim=ACTION_DIM, hidden_sizes=hidden_sizes).to(device)
    optimizer = Adam(network.parameters(), lr=args.learning_rate, eps=1e-5)

    # The scheduler is stepped once per PPO update, not once per episode.
    total_steps_for_scheduler = args.episodes // args.update_every
    scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=args.lr_end_factor,
                         total_iters=max(1, total_steps_for_scheduler)) if args.lr_decay else None

    env = PyreEnvironment(max_steps=args.max_steps)

    # Build curriculum — patience-gated (dynamic) or static
    stages = [s.strip().lower() for s in args.difficulty_schedule.split(",") if s.strip()]
    if not stages:
        # Same fallback build_curriculum() applies to an empty schedule.
        stages = ["medium"]
    if args.patience_threshold > 0:
        patience_curriculum = PatienceCurriculum(
            stages=stages,
            threshold=args.patience_threshold,
            patience_window=args.patience_window,
            mix_ratio=args.hard_mix_ratio,
        )
        static_curriculum: Optional[List[str]] = None
        print(f"[curriculum] patience-gated: threshold={args.patience_threshold} "
              f"window={args.patience_window} mix={args.hard_mix_ratio}")
    else:
        patience_curriculum = None
        static_curriculum = build_curriculum(args.difficulty_schedule, args.episodes)
        print(f"[curriculum] static: {args.difficulty_schedule}")

    # NOTE: checkpoints store no curriculum state, so a resumed patience-gated
    # run starts again from the first stage.
    start_episode = 0
    if args.resume:
        resume_path = Path(args.resume)
        if resume_path.exists():
            start_episode = load_checkpoint(resume_path, network, optimizer, scheduler)
        else:
            print(f"[warn] Resume checkpoint not found, starting from scratch: {resume_path}")

    # Tracking
    buffer = RolloutBuffer()
    episode_rows: List[Dict] = []
    eval_rows: List[Dict] = []
    reward_window: deque = deque(maxlen=30)
    success_window: deque = deque(maxlen=30)

    n_params = sum(p.numel() for p in network.parameters())
    print(f"[network] Parameters: {n_params:,}")
    print(f"[network] Input dim: {input_dim:,} (encoder.base_dim={encoder.base_dim} x {args.history_length} frames)")
    print(f"[network] Action dim: {ACTION_DIM} (4 move + 4 look + 1 wait + {MAX_DOORS} open + {MAX_DOORS} close)")
    print()

    t_start = time.time()

    # Difficulty chosen by the patience curriculum for the upcoming episode.
    # step() may return the previous stage (hard-phase replay), so its return
    # value must be carried over rather than re-reading `.current`.
    next_difficulty = patience_curriculum.current if patience_curriculum is not None else None

    for ep_idx in range(start_episode, args.episodes):
        # Determine difficulty for this episode
        if patience_curriculum is not None:
            difficulty = next_difficulty
        else:
            difficulty = static_curriculum[ep_idx]  # type: ignore[index]

        result = run_episode(
            env=env, network=network, encoder=encoder, device=device,
            difficulty=difficulty, history_length=args.history_length,
            buffer=buffer, deterministic=False,
        )

        reward_window.append(result.total_reward)
        success_window.append(float(result.evacuated))
        reward_mean_30 = float(np.mean(reward_window))
        success_rate_30 = float(np.mean(success_window))

        # Advance patience curriculum *after* updating success_window. The
        # result applies to the NEXT episode; `difficulty` keeps describing
        # the episode that was just run so the log rows stay truthful.
        if patience_curriculum is not None:
            next_difficulty = patience_curriculum.step(success_rate_30)

        ep_num = ep_idx + 1
        episode_rows.append({
            "episode": ep_num,
            "difficulty": difficulty,
            "reward": round(result.total_reward, 4),
            "evacuated": int(result.evacuated),
            "steps": result.steps,
            "final_health": round(result.final_health, 2),
            "reward_mean_30": round(reward_mean_30, 4),
            "success_rate_30": round(success_rate_30, 4),
        })

        elapsed = time.time() - t_start
        print(
            f"ep={ep_num:04d} [{difficulty:<6}] "
            f"steps={result.steps:03d} "
            f"reward={result.total_reward:+8.3f} "
            f"evac={int(result.evacuated)} "
            f"hp={result.final_health:5.1f} "
            f"suc30={success_rate_30:.2f} "
            f"r30={reward_mean_30:+7.2f} "
            f"t={elapsed:.0f}s"
        )

        # PPO update every N episodes
        should_update = (ep_num % args.update_every == 0) or (ep_num == args.episodes)
        if should_update and len(buffer) > 0:
            # Record the batch size BEFORE clearing so the log shows it.
            n_samples = len(buffer)
            ppo_metrics = ppo_update(
                network=network, optimizer=optimizer, buffer=buffer, device=device,
                clip_eps=args.clip_eps, value_clip_eps=args.clip_eps,
                entropy_coef=args.entropy_coef, value_coef=args.value_coef,
                n_epochs=args.update_epochs, minibatch_size=args.minibatch_size,
                gamma=args.gamma, gae_lambda=args.gae_lambda,
                max_grad_norm=args.max_grad_norm,
            )
            if scheduler:
                scheduler.step()
            buffer.clear()

            cur_lr = optimizer.param_groups[0]["lr"]
            print(
                f" >> PPO update samples={n_samples} "
                f"pi_loss={ppo_metrics['policy_loss']:+.4f} "
                f"v_loss={ppo_metrics['value_loss']:.4f} "
                f"entropy={ppo_metrics['entropy']:.4f} "
                f"kl={ppo_metrics['approx_kl']:.4f} "
                f"clip%={ppo_metrics['clip_frac']:.2f} "
                f"lr={cur_lr:.2e}"
            )

        # Periodic evaluation
        if args.eval_every > 0 and (ep_num % args.eval_every == 0 or ep_num == args.episodes):
            eval_m = evaluate_policy(
                env=env, network=network, encoder=encoder, device=device,
                difficulty=args.eval_difficulty, history_length=args.history_length,
                n_episodes=args.eval_episodes,
            )
            eval_rows.append({"episode": ep_num, "difficulty": args.eval_difficulty, **{k: round(v, 4) for k, v in eval_m.items()}})
            print(
                f" ** EVAL [{args.eval_difficulty}] "
                f"reward={eval_m['reward_mean']:+.3f} "
                f"success={eval_m['success_rate']:.2f} "
                f"steps={eval_m['steps_mean']:.1f}"
            )

        # Periodic checkpoint
        if args.checkpoint and args.checkpoint_every > 0 and ep_num % args.checkpoint_every == 0:
            save_checkpoint(Path(args.checkpoint), network, optimizer, scheduler, ep_num, args)
            print(f" [ckpt] saved -> {args.checkpoint}")

    # Final save — every artifact path derives from `out`, so everything
    # below is written only when an output path was given.
    if args.output:
        out = Path(args.output)
        save_checkpoint(out, network, optimizer, scheduler, args.episodes, args)
        print(f"\n[done] Model saved -> {out}")

        if args.save_metrics:
            csv_path = out.with_suffix(".csv")
            save_csv(csv_path, episode_rows)
            print(f"[done] Metrics CSV -> {csv_path}")

            if eval_rows:
                eval_csv = out.parent / (out.stem + "_eval.csv")
                save_csv(eval_csv, eval_rows)
                print(f"[done] Eval CSV -> {eval_csv}")

        if args.save_graph:
            png_path = out.with_suffix(".png")
            save_training_graph_png(png_path, episode_rows, eval_rows)
            print(f"[done] Graph PNG -> {png_path}")

    # The windows are empty when no episode ran (e.g. resuming a finished run).
    final_success = float(np.mean(success_window)) if success_window else 0.0
    final_reward = float(np.mean(reward_window)) if reward_window else 0.0

    total_time = time.time() - t_start
    print(f"\n[summary] {args.episodes - start_episode} episodes in {total_time:.1f}s "
          f"({(args.episodes - start_episode) / max(1, total_time):.1f} eps/s)")
    print(f"[summary] Final success rate (last 30): {final_success:.2f}")
    print(f"[summary] Final reward mean (last 30): {final_reward:+.3f}")
" + "Prevents catastrophic forgetting of the medium policy.") + p.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES) + p.add_argument("--eval-episodes", type=int, default=10) + p.add_argument("--eval-every", type=int, default=50) + + # Observation + p.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full"), + help="'visible': partial obs (realistic); 'full': oracle grid (debug)") + p.add_argument("--history-length", type=int, default=4, + help="Frames stacked per observation (temporal context for partial obs)") + + # Network + p.add_argument("--hidden-sizes", type=str, default="512,256,128", + help="Comma-separated MLP hidden layer sizes") + + # PPO hyperparameters + p.add_argument("--update-every", type=int, default=5, + help="Episodes between PPO updates (smaller = faster feedback loop early in training)") + p.add_argument("--update-epochs", type=int, default=4, + help="Gradient passes over each collected batch (PPO allows >1)") + p.add_argument("--minibatch-size", type=int, default=256) + p.add_argument("--clip-eps", type=float, default=0.2, help="PPO surrogate clip ε") + p.add_argument("--entropy-coef", type=float, default=0.03, + help="Entropy bonus coefficient — higher = more exploration (0.03 default encourages early exit-seeking)") + p.add_argument("--value-coef", type=float, default=0.5) + p.add_argument("--gamma", type=float, default=0.99) + p.add_argument("--gae-lambda", type=float, default=0.95) + p.add_argument("--max-grad-norm", type=float, default=0.5) + + # Optimizer / LR schedule + p.add_argument("--learning-rate", type=float, default=3e-4) + p.add_argument("--lr-decay", action="store_true", default=True, + help="Linear LR decay to lr_end_factor × initial_lr over training") + p.add_argument("--lr-end-factor", type=float, default=0.1, + help="LR at end of training = initial_lr × this value") + + # Persistence + p.add_argument("--output", type=str, default="artifacts/pyre_ppo.pt", + 
help="Path to save final model checkpoint") + p.add_argument("--checkpoint", type=str, default="artifacts/pyre_ppo_checkpoint.pt", + help="Path for periodic checkpoints (also used by --resume)") + p.add_argument("--checkpoint-every", type=int, default=50) + p.add_argument("--resume", type=str, default=None, + help="Path to checkpoint to resume training from") + p.add_argument("--save-metrics", action="store_true", default=True, + help="Save per-episode metrics as CSV alongside the model") + p.add_argument("--save-graph", action="store_true", default=True, + help="Save a PNG training graph alongside the model (requires matplotlib)") + + # Misc + p.add_argument("--seed", type=int, default=42) + p.add_argument("--describe-only", action="store_true", + help="Print environment/algorithm description and exit") + + return p.parse_args() + + +def main() -> None: + args = parse_args() + + if args.describe_only: + describe_env() + return + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + train(args) + + +if __name__ == "__main__": + main() diff --git a/examples/train_torch_ppo_http.py b/examples/train_torch_ppo_http.py new file mode 100644 index 0000000000000000000000000000000000000000..14a6df52891b3120e39102de1e6662b314f7ccdc --- /dev/null +++ b/examples/train_torch_ppo_http.py @@ -0,0 +1,492 @@ +"""PPO trainer that talks to the Pyre env via HTTP (localhost:8000). + +Identical training logic to train_torch_ppo.py, but the environment is +accessed through the REST API instead of a direct Python import. This +lets you run the server once and connect any number of training scripts, +remote notebooks, or evaluation tools to the same live instance. + +Usage +----- +1. Start the server (in a separate terminal): + cd openenv-pyre + .venv/Scripts/python.exe server/app.py + +2. 
class HttpPyreEnv:
    """Thin wrapper around the Pyre REST API.

    Exposes the same ``reset()`` / ``step()`` interface as ``PyreEnvironment``
    so the episode runner needs no changes. Can be used as a context
    manager; ``close()`` releases the underlying HTTP session.

    POST /reset → {"difficulty": str, "seed"?: int}
    POST /step → {"action": str, "direction"?: str,
                  "target_id"?: str, "door_state"?: str}
    Both return → {"observation": {...}, "reward": float,
                   "done": bool, "metadata": {...}}
    """

    def __init__(self, base_url: str = "http://localhost:8000", timeout: int = 15):
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({"Content-Type": "application/json"})

    # ------------------------------------------------------------------
    def close(self) -> None:
        """Close the HTTP session and its pooled connections."""
        self.session.close()

    def __enter__(self) -> "HttpPyreEnv":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    @staticmethod
    def _value(raw: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``raw[key]``, or ``default`` when the key is missing or null."""
        found = raw.get(key)
        return default if found is None else found

    # ------------------------------------------------------------------
    def _parse(self, data: Dict[str, Any]) -> PyreObservation:
        """Convert a raw JSON response dict into a PyreObservation.

        A JSON ``null`` is treated like a missing key, so e.g.
        ``"reward": null`` becomes 0.0 instead of raising in ``float(None)``.
        """
        value = self._value
        obs_raw = value(data, "observation", data)

        map_state: Optional[PyreMapState] = None
        ms_raw = obs_raw.get("map_state")
        if ms_raw:
            map_state = PyreMapState(**ms_raw)

        return PyreObservation(
            narrative=value(obs_raw, "narrative", ""),
            agent_evacuated=value(obs_raw, "agent_evacuated", False),
            location_label=value(obs_raw, "location_label", ""),
            smoke_level=value(obs_raw, "smoke_level", "none"),
            fire_visible=value(obs_raw, "fire_visible", False),
            fire_direction=obs_raw.get("fire_direction"),
            agent_health=float(value(obs_raw, "agent_health", 100.0)),
            health_status=value(obs_raw, "health_status", "Good"),
            wind_dir=value(obs_raw, "wind_dir", "CALM"),
            visible_objects=value(obs_raw, "visible_objects", []),
            blocked_exit_ids=value(obs_raw, "blocked_exit_ids", []),
            audible_signals=value(obs_raw, "audible_signals", []),
            elapsed_steps=value(obs_raw, "elapsed_steps", 0),
            last_action_feedback=value(obs_raw, "last_action_feedback", ""),
            available_actions_hint=value(obs_raw, "available_actions_hint", []),
            map_state=map_state,
            reward=float(value(data, "reward", 0.0)),
            done=bool(data.get("done", False)),
            metadata=value(data, "metadata", {}),
        )

    # ------------------------------------------------------------------
    def reset(self, difficulty: str = "easy", seed: Optional[int] = None) -> PyreObservation:
        """POST /reset and return the initial observation.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload: Dict[str, Any] = {"difficulty": difficulty}
        if seed is not None:
            payload["seed"] = seed
        resp = self.session.post(
            f"{self.base_url}/reset", json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return self._parse(resp.json())

    # ------------------------------------------------------------------
    def step(self, action: PyreAction) -> PyreObservation:
        """POST /step with ``action`` and return the resulting observation.

        Optional action fields are sent only when they are not None.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        payload: Dict[str, Any] = {"action": action.action}
        if action.direction is not None:
            payload["direction"] = action.direction
        if action.target_id is not None:
            payload["target_id"] = action.target_id
        if action.door_state is not None:
            payload["door_state"] = action.door_state
        resp = self.session.post(
            f"{self.base_url}/step", json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return self._parse(resp.json())

    # ------------------------------------------------------------------
    def health_check(self) -> bool:
        """Return True if the server is reachable.

        Any response with a status below 500 counts as reachable.
        """
        try:
            r = self.session.get(f"{self.base_url}/state", timeout=5)
            return r.status_code < 500
        except requests.exceptions.RequestException:
            return False
zero_frame = np.zeros(encoder.base_dim, dtype=np.float32) + frames: deque = deque([zero_frame.copy() for _ in range(history_length)], maxlen=history_length) + frames.append(encoder.encode(observation)) + + total_reward = 0.0 + final_health = observation.agent_health + evacuated = False + steps = 0 + LOOP_WINDOW = 12 + recent_positions: deque = deque(maxlen=LOOP_WINDOW) + + network.eval() + with torch.no_grad(): + while True: + state_vec = np.concatenate(list(frames), dtype=np.float32) + action_mask = build_action_mask(observation, exclude_look=True) + + obs_t = torch.tensor(state_vec, dtype=torch.float32, device=device).unsqueeze(0) + mask_t = torch.tensor(action_mask, dtype=torch.float32, device=device).unsqueeze(0) + + action_t, log_prob_t, value_t = network.act(obs_t, mask_t, deterministic=deterministic) + + action_idx = int(action_t.item()) + env_action = action_index_to_env_action(action_idx) + next_obs = env.step(env_action) + + reward = float(next_obs.reward or 0.0) + chosen_action = env_action.action + + # Shaping 1 — idle penalty + if chosen_action == "wait": + reward -= 0.05 + + # Shaping 2 — fire-approach penalty + ms_next = next_obs.map_state + if ms_next is not None and chosen_action.startswith("move"): + ax, ay = ms_next.agent_x, ms_next.agent_y + gw, gh = ms_next.grid_w, ms_next.grid_h + for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)): + nx, ny = ax + dx, ay + dy + if 0 <= nx < gw and 0 <= ny < gh: + if float(ms_next.fire_grid[ny * gw + nx]) > 0.15: + reward -= 0.15 + break + + # Shaping 3 — anti-loop penalty + if ms_next is not None and chosen_action.startswith("move"): + cur_pos = (ms_next.agent_x, ms_next.agent_y) + if cur_pos in recent_positions: + reward -= 0.2 + recent_positions.append(cur_pos) + + done = bool(next_obs.done) + + buffer.obs.append(state_vec) + buffer.masks.append(action_mask) + buffer.actions.append(action_idx) + buffer.rewards.append(reward) + buffer.log_probs.append(float(log_prob_t.item())) + 
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

def _save_state_dict(network, path) -> Path:
    """Save ``network.state_dict()`` to *path*, creating parent directories first."""
    target = Path(path)
    # torch.save does not create missing directories, so make them here.
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(network.state_dict(), target)
    return target


def _write_csv(path, rows) -> None:
    """Write a non-empty list of dict rows to *path*; the first row's keys are the header."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)


def train(args: argparse.Namespace) -> None:
    """Run the PPO training loop against a Pyre environment served over HTTP.

    For every episode the difficulty is picked from the comma-separated
    ``args.difficulty_schedule`` (the schedule is stretched evenly over
    ``args.episodes``), one rollout is collected with ``run_episode`` into a
    shared ``RolloutBuffer``, and every ``args.update_every`` episodes the
    buffer is consumed by ``ppo_update`` and cleared. Deterministic evaluation
    rollouts, checkpoints, CSV metrics and a PNG graph are produced according
    to the corresponding ``args`` fields.

    Args:
        args: Namespace as produced by ``parse_args``.

    Raises:
        ValueError: If the difficulty schedule contains no entries.
        SystemExit: Via ``sys.exit(1)`` when the server health check fails.
    """
    device = torch.device("cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu")
    encoder = ObservationEncoder(mode=args.observation_mode)
    input_dim = encoder.base_dim * args.history_length
    hidden_sizes = [int(x) for x in args.hidden_sizes.split(",")]
    action_dim = ACTION_DIM

    # Tolerate "easy, medium" style input: strip blanks and drop empty entries.
    schedule = [name.strip() for name in args.difficulty_schedule.split(",") if name.strip()]
    if not schedule:
        raise ValueError("--difficulty-schedule must name at least one difficulty")

    # A non-positive interval would make the modulo checks below divide by zero.
    update_every = max(1, args.update_every)

    # Connect to server
    env = HttpPyreEnv(base_url=args.server)
    print(f"[server] Connecting to {args.server} ...", end=" ", flush=True)
    if not env.health_check():
        print("FAILED\n[error] Server not reachable. Start it with: python server/app.py")
        sys.exit(1)
    print("OK")

    # Network
    network = ActorCritic(input_dim, action_dim, hidden_sizes).to(device)
    optimizer = optim.Adam(network.parameters(), lr=args.lr)

    total_params = sum(p.numel() for p in network.parameters())
    print(f"\n[config] server={args.server}")
    print(f"[config] device={device} episodes={args.episodes} batch={args.update_every} eps")
    print(f"[config] curriculum: {args.difficulty_schedule}")
    print(f"[config] PPO clip_eps={args.clip_eps} entropy={args.entropy_coef} lr={args.lr}")
    print(f"\n[network] Parameters: {total_params:,}")
    print(f"[network] Input dim: {input_dim:,} (encoder.base_dim={encoder.base_dim} x {args.history_length} frames)")
    print(f"[network] Action dim: {action_dim} (4 move + 4 look + 1 wait + {MAX_DOORS} open + {MAX_DOORS} close)\n", flush=True)

    buffer = RolloutBuffer()
    metrics: list = []
    eval_metrics: list = []
    success_window: deque = deque(maxlen=30)
    reward_window: deque = deque(maxlen=30)
    t0 = time.time()

    # The scheduler is stepped once per PPO update (not once per episode), so
    # the decay horizon must be the number of updates; otherwise the learning
    # rate never gets close to end_factor.
    n_updates = max(1, args.episodes // update_every)
    lr_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.1, total_iters=n_updates
    )

    for ep in range(1, args.episodes + 1):
        # Map training progress in [0, 1) onto an index into the schedule.
        stage_idx = min(int((ep - 1) / args.episodes * len(schedule)), len(schedule) - 1)
        difficulty = schedule[stage_idx]

        result = run_episode(env, network, encoder, device, difficulty, args.history_length, buffer)
        success_window.append(1 if result.evacuated else 0)
        reward_window.append(result.total_reward)
        suc30 = sum(success_window) / len(success_window)
        r30 = sum(reward_window) / len(reward_window)
        elapsed = int(time.time() - t0)

        evac_sym = "1" if result.evacuated else "0"
        print(
            f"ep={ep:04d} [{difficulty:<6}] steps={result.steps:03d} "
            f"reward={result.total_reward:+8.3f} evac={evac_sym} "
            f"hp={result.final_health:5.1f} suc30={suc30:.2f} "
            f"r30={r30:+7.2f} t={elapsed}s"
        )

        metrics.append({
            "episode": ep, "difficulty": difficulty, "steps": result.steps,
            "reward": round(result.total_reward, 4), "evacuated": int(result.evacuated),
            "final_health": result.final_health, "suc30": round(suc30, 3), "r30": round(r30, 3),
        })

        # PPO update
        if ep % update_every == 0 and len(buffer.obs) > 0:
            network.train()
            stats = ppo_update(
                network=network, optimizer=optimizer, buffer=buffer, device=device,
                clip_eps=args.clip_eps, value_clip_eps=args.clip_eps,
                entropy_coef=args.entropy_coef, value_coef=args.value_coef,
                n_epochs=args.update_epochs, minibatch_size=args.minibatch_size,
                gamma=args.gamma, gae_lambda=args.gae_lambda,
                max_grad_norm=args.max_grad_norm,
            )
            lr_scheduler.step()
            cur_lr = optimizer.param_groups[0]["lr"]
            print(
                f" >> PPO update samples=flushed "
                f"pi_loss={stats['policy_loss']:+.4f} v_loss={stats['value_loss']:.4f} "
                f"entropy={stats['entropy']:.4f} kl={stats['approx_kl']:.4f} "
                f"clip%={stats['clip_frac']:.2f} lr={cur_lr:.2e}"
            )
            buffer.clear()
            network.eval()

        # Evaluation (an interval or episode count <= 0 disables it)
        if args.eval_every > 0 and args.eval_episodes > 0 and ep % args.eval_every == 0:
            eval_rewards, eval_success, eval_steps_list = [], [], []
            # Separate buffer so evaluation rollouts never reach the PPO update.
            eval_buf = RolloutBuffer()
            for _ in range(args.eval_episodes):
                er = run_episode(
                    env, network, encoder, device,
                    args.eval_difficulty, args.history_length,
                    eval_buf, deterministic=True,
                )
                eval_rewards.append(er.total_reward)
                eval_success.append(1 if er.evacuated else 0)
                eval_steps_list.append(er.steps)
            avg_r = sum(eval_rewards) / len(eval_rewards)
            avg_s = sum(eval_success) / len(eval_success)
            avg_st = sum(eval_steps_list) / len(eval_steps_list)
            print(f" ** EVAL [{args.eval_difficulty}] reward={avg_r:+.3f} success={avg_s:.2f} steps={avg_st:.1f}")
            eval_metrics.append({
                "episode": ep, "eval_difficulty": args.eval_difficulty,
                "avg_reward": round(avg_r, 4), "success_rate": round(avg_s, 3),
                "avg_steps": round(avg_st, 1),
            })

        # Checkpoint (an interval <= 0 or an empty path disables it)
        if args.checkpoint and args.checkpoint_every > 0 and ep % args.checkpoint_every == 0:
            _save_state_dict(network, args.checkpoint)
            print(f" [ckpt] saved -> {args.checkpoint}")

    # --- Save artefacts ---
    out = _save_state_dict(network, args.output)
    print(f"\n[done] Model saved -> {out}")

    if args.save_metrics and metrics:
        csv_path = out.with_suffix(".csv")
        _write_csv(csv_path, metrics)
        print(f"[done] Metrics CSV -> {csv_path}")

        if eval_metrics:
            eval_csv = out.with_stem(out.stem + "_eval").with_suffix(".csv")
            _write_csv(eval_csv, eval_metrics)
            print(f"[done] Eval CSV -> {eval_csv}")

    if args.save_graph:
        # Best effort: a plotting failure must not lose a finished training run.
        try:
            png_path = out.with_suffix(".png")
            save_training_graph_png(metrics, eval_metrics, str(png_path))
            print(f"[done] Graph PNG -> {png_path}")
        except Exception as e:
            print(f"[warn] Graph skipped: {e}")

    suc_final = sum(success_window) / max(1, len(success_window))
    r_final = sum(reward_window) / max(1, len(reward_window))
    elapsed_total = time.time() - t0
    # Guard against a zero interval on coarse clocks / zero-episode runs.
    eps_per_s = args.episodes / elapsed_total if elapsed_total > 0 else 0.0
    print(f"\n[summary] {args.episodes} episodes in {elapsed_total:.1f}s ({eps_per_s:.1f} eps/s)")
    print(f"[summary] Final success rate (last 30): {suc_final:.2f}")
    print(f"[summary] Final reward mean (last 30): {r_final:+.3f}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Build the command-line parser for the trainer and parse *argv*.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated namespace consumed by ``train``.
    """
    p = argparse.ArgumentParser(
        description="PPO trainer using the Pyre HTTP server (localhost:8000)"
    )

    # Server
    p.add_argument("--server", type=str, default="http://localhost:8000",
                   help="Base URL of the running Pyre env server")

    # Training
    p.add_argument("--episodes", type=int, default=400)
    p.add_argument("--device", type=str, default="cpu", choices=("cuda", "cpu"))

    # Curriculum
    p.add_argument("--difficulty-schedule", type=str, default="easy,easy,easy,medium,medium")
    p.add_argument("--eval-difficulty", type=str, default="medium", choices=DIFFICULTIES)
    p.add_argument("--eval-episodes", type=int, default=10)
    p.add_argument("--eval-every", type=int, default=50)

    # Observation
    p.add_argument("--observation-mode", type=str, default="visible", choices=("visible", "full"))
    p.add_argument("--history-length", type=int, default=4)

    # Network
    p.add_argument("--hidden-sizes", type=str, default="256,128,64")

    # PPO
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--gamma", type=float, default=0.99)
    p.add_argument("--gae-lambda", type=float, default=0.95)
    p.add_argument("--clip-eps", type=float, default=0.2)
    p.add_argument("--value-coef", type=float, default=0.5)
    p.add_argument("--entropy-coef", type=float, default=0.03)
    p.add_argument("--update-every", type=int, default=5)
    p.add_argument("--update-epochs", type=int, default=4)
    p.add_argument("--minibatch-size", type=int, default=256)
    p.add_argument("--max-grad-norm", type=float, default=0.5)

    # Output
    p.add_argument("--output", type=str, default="artifacts/pyre_ppo_http.pt")
    p.add_argument("--checkpoint", type=str, default="artifacts/pyre_ppo_http_ckpt.pt")
    p.add_argument("--checkpoint-every", type=int, default=50)
    # store_true with default=True could never be switched off; the optional
    # action keeps --save-metrics / --save-graph and adds the --no-* forms.
    p.add_argument("--save-metrics", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--save-graph", action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--seed", type=int, default=42)

    return p.parse_args(argv)


def main() -> None:
    """Parse the CLI, seed the torch and NumPy RNGs, and start training."""
    args = parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    train(args)


if __name__ == "__main__":
    main()
0000000000000000000000000000000000000000..d100c62e06da7611548c15427275df915e41e363 --- /dev/null +++ b/frontend/README.md @@ -0,0 +1,93 @@ +# Pyre — Frontend Visualization + +A cinematic real-time visualization for the **Pyre Crisis Navigation Environment** — a reinforcement learning environment where an LLM agent navigates a burning building. + +## Quick start + +```bash +# Open directly in a browser — no build step needed +open frontend/index.html +``` + +The app runs entirely in-browser. **Demo mode** simulates the fire physics in JavaScript (no server required). **Live mode** connects to the deployed environment. + +--- + +## Demo mode vs Live mode + +| | Demo | Live | +|---|---|---| +| Server needed | ✗ | ✓ | +| Fire physics | JS port (exact match) | Python server | +| Full reward rubric | Simplified | Complete | +| Toggle | Default | Click "Live" in topbar | + +**Live server:** `https://krooz-pyre-env.hf.space` + +--- + +## Controls + +| Key | Action | +|---|---| +| `Space` | Play / pause | +| `→` | Single step | +| `R` | New episode | +| `1`–`5` | Speed ½× / 1× / 2× / 4× / 8× | + +Bottom bar: difficulty selector, seed input, speed control, reset. + +--- + +## Recording episodes (Python) + +```bash +pip install requests # only stdlib used, no install needed + +python bridge/recorder.py \ + --url https://krooz-pyre-env.hf.space/web \ + --episodes 10 \ + --difficulty medium \ + --out episodes/ +``` + +Episodes are saved as JSON files to `episodes/`. Each file contains full frame-by-frame grid data (cell, fire, smoke grids + agent position + visible cells). 
+ +--- + +## File structure + +``` +frontend/ +├── index.html Main app — open this +└── js/ + ├── sim.js JS port of pyre_env fire simulation + floor plans + ├── renderer.js Canvas2D rendering (fire particles, fog-of-war, agent trail) + └── app.js App controller, charts, HUD, live/demo modes + +bridge/ +└── recorder.py Record live episodes to JSON for replay +``` + +--- + +## Architecture notes + +**Rendering:** HTML5 Canvas 2D — sufficient at 60fps for 16×16 grids; additive blending (`globalCompositeOperation: lighter`) for fire glow; ember particle pool (200 max); fog-of-war via per-cell alpha overlay. + +**Demo agent:** BFS toward nearest unblocked exit, 15% random exploration, avoids fire cells > 0.4 intensity. + +**Live bridge:** Polls `/web/scene` every 800ms; applies grid state to the same rendering pipeline. + +--- + +## Demo script (30-second stage walkthrough) + +1. **Open** `frontend/index.html` — fire simulation starts automatically at 1× +2. **Point out** the dark floor plan canvas with glowing fire cells, fog-of-war, and cyan agent dot +3. **Slow to ½×** to show per-step fire propagation and smoke spread +4. **Speed to 4×** — show agent navigating toward exits (green glow), closing doors (blue bars) to slow fire +5. **Highlight** the side panel: cumulative reward curve dipping on smoke exposure, fire cell count climbing, action histogram +6. **Describe partial observability** — the dark unexplored cells vs. visible corridor +7. **Reset (R)** with a different seed to show episode variety +8. 
If server is available: click **Live** — "Connected" chip turns green, real Python environment takes over diff --git a/frontend/eslint.config.js b/frontend/eslint.config.js new file mode 100644 index 0000000000000000000000000000000000000000..ef614d25c11dd1e89a8df0b2eaf934170b44daa8 --- /dev/null +++ b/frontend/eslint.config.js @@ -0,0 +1,22 @@ +import js from '@eslint/js' +import globals from 'globals' +import reactHooks from 'eslint-plugin-react-hooks' +import reactRefresh from 'eslint-plugin-react-refresh' +import tseslint from 'typescript-eslint' +import { defineConfig, globalIgnores } from 'eslint/config' + +export default defineConfig([ + globalIgnores(['dist']), + { + files: ['**/*.{ts,tsx}'], + extends: [ + js.configs.recommended, + tseslint.configs.recommended, + reactHooks.configs.flat.recommended, + reactRefresh.configs.vite, + ], + languageOptions: { + globals: globals.browser, + }, + }, +]) diff --git a/frontend/index.html b/frontend/index.html new file mode 100644 index 0000000000000000000000000000000000000000..177e5e08cc91c134fafcb3612944244dc658b7ef --- /dev/null +++ b/frontend/index.html @@ -0,0 +1,16 @@ + + +
+ + + +