Commit d5d9d45 · Parent(s): 5ace282
Godreign-Y committed

Add W&B experiment tracking and structured logging
Files changed:
- README.md +10 -0
- policy_to_logic_env/server/requirements.txt +1 -0
- pyproject.toml +1 -0
- training/colab_training.ipynb +523 -119
- training/results-iteration1/accuracy_curve (1).png +0 -0
- training/results-iteration1/improvement_chart (1).png +0 -0
- training/results-iteration1/metrics (1).json +218 -0
- training/results-iteration1/reward_curve (1).png +0 -0
- training/trajectory_optimizer.py +169 -56
- training/update_colab.py +173 -0
- uv.lock +101 -0
README.md
CHANGED

@@ -24,6 +24,7 @@ short_description: Meta pytorch hugging face hackathon
  | **HF Space (Live Environment)** | [godreign-policy2logic.hf.space](https://godreign-policy2logic.hf.space) |
  | **Training Notebook (Colab)** | [Open in Colab](https://colab.research.google.com/github/GodreignElgin/policy2logic/blob/main/training/colab_training.ipynb) |
  | **Writeup / Slides** | *TBD — add your link here* |
+ | **Experiment Tracking (W&B)** | [Wandb Project](https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl) |
 
  ---
 
@@ -44,6 +45,15 @@ improving agent behavior across episodes without weight updates.
 
  ---
 
+ ## 📈 Experiment Tracking
+
+ All training runs are logged to Weights & Biases.
+ Metrics tracked per episode: total reward, final accuracy, steps used, success rate, few-shot examples used.
+
+ Live dashboard: [wandb.ai/YOUR_USERNAME/policy-to-logic-rl](https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl)
+
+ ---
+
  ## 🧠 What This Is
 
  This project builds a **verifiable RL environment** where:
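The README section above names five per-episode metrics. As a minimal sketch of how such a metric list maps onto `wandb.log` calls (the `EpisodeResult` container and its field names are illustrative stand-ins, not classes from this repo; the full implementation lives in the notebook diff below):

```python
import wandb
from dataclasses import dataclass

@dataclass
class EpisodeResult:
    """Hypothetical stand-in for one finished episode's summary."""
    total_reward: float
    final_accuracy: float
    steps_used: int
    success: bool
    few_shots_used: int

def log_episode(episode: int, r: EpisodeResult) -> None:
    # One wandb.log call per episode; keys mirror the metric list above.
    wandb.log({
        "total_reward": r.total_reward,
        "final_accuracy": r.final_accuracy,
        "steps_used": r.steps_used,
        "success": int(r.success),            # logged as 0/1; averages to a success rate
        "few_shot_examples_used": r.few_shots_used,
        "episode": episode,
    })

# Usage:
# wandb.init(project="policy-to-logic-rl")
# log_episode(1, EpisodeResult(2.4, 0.95, 3, True, 2))
# wandb.finish()
```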
policy_to_logic_env/server/requirements.txt
CHANGED

@@ -3,3 +3,4 @@ pydantic>=2.0
  fastapi>=0.104.0
  uvicorn>=0.24.0
  requests>=2.25.0
+ wandb>=0.16.0
pyproject.toml
CHANGED

@@ -13,6 +13,7 @@ dependencies = [
      "huggingface-hub>=1.12.0",
      "matplotlib>=3.7.0",
      "numpy>=1.24.0",
+     "wandb>=0.16.0",
  ]
 
  [project.optional-dependencies]
training/colab_training.ipynb
CHANGED

@@ -19,7 +19,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-  "# Policy-to-Logic RL Environment
+  "# Policy-to-Logic RL Environment \u2014 Training Notebook\n",
   "\n",
   "This notebook runs the **reward-guided trajectory optimization loop** against the deployed environment.\n",
   "\n",
@@ -27,7 +27,8 @@
   "1. Connects to the live HF Spaces environment\n",
   "2. Runs 8 episodes per task (3 tasks = 24 total episodes)\n",
   "3. Accumulates high-reward trajectories as few-shot examples\n",
-  "4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)"
+  "4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)\n",
+  "5. Logs everything to Weights & Biases"
   ]
  },
  {
@@ -37,7 +38,7 @@
   "outputs": [],
   "source": [
   "# Cell 1: Install dependencies\n",
-  "!pip install openai requests matplotlib numpy"
+  "!pip install openai requests matplotlib numpy wandb"
   ]
  },
  {
@@ -52,12 +53,16 @@
   "# SET THESE BEFORE RUNNING\n",
   "HF_TOKEN = \"\"  # Your Hugging Face token with inference access\n",
   "ENV_URL = \"https://godreign-policy2logic.hf.space\"  # Your deployed environment URL\n",
+  "WANDB_API_KEY = \"\"  # Your Wandb API key\n",
   "\n",
   "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
   "os.environ[\"ENV_BASE_URL\"] = ENV_URL\n",
+  "if WANDB_API_KEY:\n",
+  "    os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
   "\n",
   "print(f\"Environment URL: {ENV_URL}\")\n",
-  "print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")"
+  "print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")\n",
+  "print(f\"Wandb Token set: {'Yes' if WANDB_API_KEY else 'NO - WILL PROMPT'}\")"
   ]
  },
  {
@@ -84,18 +89,40 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "# Cell 4:
-  "
+  "# Cell 4: Wandb login \u2014 run this before training\n",
+  "import wandb\n",
+  "wandb.login()  # Will prompt for API key if WANDB_API_KEY is not set"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "# Cell 5: Training loop implementation (full trajectory_optimizer.py)\n",
+  "\"\"\"\n",
+  "Reward-Guided Trajectory Optimization Loop\n",
+  "==========================================\n",
+  "Optimizes agent behavior across episodes by accumulating high-reward\n",
+  "trajectories as few-shot examples. Uses environment reward signal to\n",
+  "drive improvement \u2014 no weight updates required.\n",
+  "\n",
+  "This implements a policy improvement loop where:\n",
+  "  - reward_signal \u2192 trajectory_selection \u2192 context_construction \u2192 improved_policy\n",
+  "\"\"\"\n",
   "\n",
   "import json\n",
   "import os\n",
   "import time\n",
   "import requests\n",
+  "import logging\n",
+  "import wandb\n",
   "from dataclasses import dataclass, field\n",
   "from typing import Optional\n",
   "from openai import OpenAI\n",
   "\n",
-  "#
+  "# \u2500\u2500 Configuration \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
   "\n",
   "ENV_BASE_URL = os.getenv(\"ENV_BASE_URL\", \"http://localhost:7860\")\n",
   "HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\")\n",
@@ -103,11 +130,14 @@
   "TEMPERATURE = 0.3\n",
   "MAX_TOKENS = 1024\n",
   "\n",
-  "
-  "
-  "
+  "# Training hyperparameters\n",
+  "NUM_EPISODES_PER_TASK = 8   # Episodes to run per task\n",
+  "TOP_K_TRAJECTORIES = 3      # Max few-shot examples to keep\n",
+  "MIN_REWARD_THRESHOLD = 0.3  # Minimum reward to store trajectory\n",
   "TASKS = [\"data_access\", \"resource_access\", \"transaction_approval\"]\n",
   "\n",
+  "# \u2500\u2500 Data Structures \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
+  "\n",
   "@dataclass\n",
   "class Step:\n",
   "    step_number: int\n",
@@ -128,7 +158,10 @@
   "    success: bool = False\n",
   "\n",
   "    def to_few_shot_string(self) -> str:\n",
-  "
+  "        \"\"\"Convert trajectory to a few-shot example string for prompting.\"\"\"\n",
+  "        lines = [\n",
+  "            f\"=== Example Episode (reward={self.total_reward:.2f}, accuracy={self.final_accuracy:.2f}) ===\",\n",
+  "        ]\n",
   "        for s in self.steps:\n",
   "            lines.append(f\"Step {s.step_number}: action={s.action_type}\")\n",
   "            lines.append(f\"  Content: {s.action_content[:200]}\")\n",
@@ -137,6 +170,8 @@
   "            lines.append(f\"  Feedback: {s.feedback[:150]}\")\n",
   "        return \"\\n\".join(lines)\n",
   "\n",
+  "# \u2500\u2500 Environment Client \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
+  "\n",
   "class EnvClient:\n",
   "    def __init__(self, base_url: str):\n",
   "        self.base_url = base_url.rstrip(\"/\")\n",
@@ -148,7 +183,10 @@
   "        return r.json()\n",
   "\n",
   "    def step(self, action_type: str, content: str) -> dict:\n",
-  "        r = self.session.post(f\"{self.base_url}/step\", json={\
+  "        r = self.session.post(f\"{self.base_url}/step\", json={\n",
+  "            \"action_type\": action_type,\n",
+  "            \"content\": content\n",
+  "        })\n",
   "        r.raise_for_status()\n",
   "        return r.json()\n",
   "\n",
@@ -159,18 +197,40 @@
   "        except Exception:\n",
   "            return False\n",
   "\n",
+  "# \u2500\u2500 LLM Agent \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
+  "\n",
   "class Agent:\n",
   "    def __init__(self, hf_token: str):\n",
-  "        self.client = OpenAI(
-  "\n",
-  "
-  "
+  "        self.client = OpenAI(\n",
+  "            base_url=\"https://router.huggingface.co/v1\",\n",
+  "            api_key=hf_token\n",
+  "        )\n",
+  "\n",
+  "    def get_action(\n",
+  "        self,\n",
+  "        observation: dict,\n",
+  "        step_number: int,\n",
+  "        episode_history: list[str],\n",
+  "        few_shot_examples: list[Trajectory],\n",
+  "        task_name: str = \"\"\n",
+  "    ) -> tuple[str, str]:\n",
+  "        \"\"\"\n",
+  "        Returns (action_type, content_json_string).\n",
+  "        action_type: one of ask_clarification | propose_rules | refine_rules\n",
+  "        content: JSON string appropriate for that action\n",
+  "        \"\"\"\n",
+  "        system_prompt = self._build_system_prompt(few_shot_examples, task_name)\n",
   "        user_prompt = self._build_user_prompt(observation, step_number, episode_history)\n",
+  "\n",
   "        try:\n",
   "            response = self.client.chat.completions.create(\n",
   "                model=MODEL,\n",
-  "                messages=[
-  "
+  "                messages=[\n",
+  "                    {\"role\": \"system\", \"content\": system_prompt},\n",
+  "                    {\"role\": \"user\", \"content\": user_prompt}\n",
+  "                ],\n",
+  "                temperature=TEMPERATURE,\n",
+  "                max_tokens=MAX_TOKENS\n",
   "            )\n",
   "            raw = response.choices[0].message.content.strip()\n",
   "            return self._parse_response(raw, observation)\n",
@@ -178,155 +238,489 @@
   "            print(f\"  [LLM ERROR] {e}\")\n",
   "            return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
   "\n",
-  "    def _build_system_prompt(self, few_shot_examples):\n",
-  "        base = \"\"\"You are a policy-to-logic agent.
+  "    def _build_system_prompt(self, few_shot_examples: list[Trajectory], task_name: str = \"\") -> str:\n",
+  "        base = \"\"\"You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.\n",
   "\n",
   "AVAILABLE ACTIONS:\n",
   "1. ask_clarification: {\"type\": \"clarification\", \"question\": \"your question\"}\n",
   "2. propose_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
   "3. refine_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
   "\n",
-  "DSL FORMAT
-  "
+  "DSL FORMAT for rules:\n",
+  "{\n",
+  "  \"rules\": [\n",
+  "    {\n",
+  "      \"if\": [\n",
+  "        {\"field\": \"FIELD_NAME\", \"op\": \"OPERATOR\", \"value\": VALUE}\n",
+  "      ],\n",
+  "      \"then\": \"DECISION\"\n",
+  "    }\n",
+  "  ],\n",
+  "  \"default\": \"FALLBACK_DECISION\"\n",
+  "}\n",
+  "\n",
+  "Operators: >, <, >=, <=, ==, !=\n",
+  "Rules execute top-to-bottom. First match wins. Default applies if no rule matches.\n",
   "\n",
-  "STRATEGY: Ask 1-2 clarifications first, then propose rules, then refine based on failures.\n",
-  "OUTPUT: Respond ONLY with valid JSON: {\"action_type\": \"...\", \"content\": \"...\"}\"\"\"\n",
+  "STRATEGY:\n",
+  "- Step 1: Ask 1-2 targeted clarification questions about ambiguous terms\n",
+  "- Step 2: Propose initial rules based on policy + clarifications\n",
+  "- Step 3+: Refine rules based on failure feedback\n",
+  "\n",
+  "OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.\n",
+  "{\"action_type\": \"propose_rules\", \"content\": \"{...escaped json string...}\"}\n",
+  "\"\"\"\n",
+  "        # Task-specific guidance for complex tasks\n",
+  "        if task_name == \"transaction_approval\":\n",
+  "            base += \"\"\"\n",
+  "IMPORTANT \u2014 TRANSACTION APPROVAL TASK:\n",
+  "This task has 4 possible decisions: APPROVE, REQUIRE_APPROVAL, COMPLIANCE_REVIEW, HOLD\n",
+  "Rules are evaluated TOP-TO-BOTTOM. Order matters critically. You MUST order rules by priority:\n",
+  "  1. FIRST: Check if transfer_type == \"international\" \u2192 then COMPLIANCE_REVIEW (always, overrides everything)\n",
+  "  2. SECOND: Check if amount >= 10000 AND time is outside business hours (hour < 9 or hour >= 17) \u2192 then HOLD\n",
+  "  3. THIRD: Check if amount > 5000 AND initiator_role != \"manager\" \u2192 then REQUIRE_APPROVAL\n",
+  "  4. DEFAULT: APPROVE\n",
+  "\n",
+  "Key details:\n",
+  "- Standard limit is $5,000 (amount > 5000 triggers approval, NOT >=)\n",
+  "- High-value threshold is $10,000 (amount >= 10000)\n",
+  "- Business hours: hour >= 9 AND hour < 17\n",
+  "- Manager exemption ONLY applies to the standard $5,000 limit, NOT to international or high-value HOLD rules\n",
+  "- \"system\" role follows the same rules as \"employee\"\n",
+  "\n",
+  "Here is a working example of valid rules for this task:\n",
+  "{\"rules\": [{\"if\": [{\"field\": \"transfer_type\", \"op\": \"==\", \"value\": \"international\"}], \"then\": \"COMPLIANCE_REVIEW\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">=\", \"value\": 10000}, {\"field\": \"time\", \"op\": \">=\", \"value\": 17}], \"then\": \"HOLD\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">=\", \"value\": 10000}, {\"field\": \"time\", \"op\": \"<\", \"value\": 9}], \"then\": \"HOLD\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">\", \"value\": 5000}, {\"field\": \"initiator_role\", \"op\": \"!=\", \"value\": \"manager\"}], \"then\": \"REQUIRE_APPROVAL\"}], \"default\": \"APPROVE\"}\n",
+  "\"\"\"\n",
+  "        elif task_name == \"resource_access\":\n",
+  "            base += \"\"\"\n",
+  "IMPORTANT \u2014 RESOURCE ACCESS TASK:\n",
+  "This task has roles: junior, senior, contractor. Document types: public, internal, confidential.\n",
+  "- Senior employees: ALLOW everything always\n",
+  "- Contractors: ALLOW only public, DENY everything else\n",
+  "- Junior + confidential: ALWAYS DENY (regardless of time \u2014 the policy is misleading about this)\n",
+  "- Junior + internal: ALLOW only during business hours (hour >= 8 AND hour < 17)\n",
+  "- Junior + public: ALLOW always\n",
+  "- Business hours: hour >= 8 AND hour < 17\n",
+  "\"\"\"\n",
   "\n",
   "        if few_shot_examples:\n",
-  "            base += \"\\n\\nLEARNED FROM PREVIOUS EPISODES:\\n\"\n",
+  "            base += \"\\n\\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\\n\"\n",
   "            for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:\n",
   "                base += \"\\n\" + traj.to_few_shot_string() + \"\\n\"\n",
   "        return base\n",
   "\n",
-  "    def _build_user_prompt(self, obs, step, history):\n",
-  "        lines = [
-  "
+  "    def _build_user_prompt(self, obs: dict, step: int, history: list[str]) -> str:\n",
+  "        lines = [\n",
+  "            f\"TASK: {obs.get('task_name', 'unknown')}\",\n",
+  "            f\"STEP: {step} of {obs.get('max_steps', 7)}\",\n",
+  "            f\"\\nPOLICY:\\n{obs.get('policy_text', '')}\",\n",
+  "        ]\n",
+  "        if obs.get(\"clarification_response\"):\n",
+  "            lines.append(f\"\\nLAST CLARIFICATION ANSWER:\\n{obs['clarification_response']}\")\n",
   "        if obs.get(\"test_results\"):\n",
   "            tr = obs[\"test_results\"]\n",
-  "            lines.append(f\"\\nTEST RESULTS: {tr.get('passed',0)}/{tr.get('total',0)} (
-  "            if tr.get(\"sample_failures\"):
-  "
-  "
-  "
-  "
+  "            lines.append(f\"\\nTEST RESULTS: {tr.get('passed', 0)}/{tr.get('total', 0)} passed (accuracy={obs.get('current_accuracy', 0):.2f})\")\n",
+  "            if tr.get(\"sample_failures\"):\n",
+  "                lines.append(\"SAMPLE FAILURES:\")\n",
+  "                for f in tr[\"sample_failures\"][:3]:\n",
+  "                    lines.append(f\"  - {f}\")\n",
+  "        if obs.get(\"feedback\"):\n",
+  "            lines.append(f\"\\nFEEDBACK: {obs['feedback']}\")\n",
+  "        if history:\n",
+  "            lines.append(f\"\\nACTION HISTORY (last 3):\\n\" + \"\\n\".join(history[-3:]))\n",
+  "        lines.append(f\"\\nAVAILABLE ACTIONS: {obs.get('available_actions', [])}\")\n",
+  "        lines.append(\"\\nRespond with JSON only: {\\\"action_type\\\": \\\"...\\\", \\\"content\\\": \\\"...\\\"}\")\n",
   "        return \"\\n\".join(lines)\n",
   "\n",
-  "    def _parse_response(self, raw, obs):\n",
+  "    def _parse_response(self, raw: str, obs: dict) -> tuple[str, str]:\n",
+  "        # Strip markdown code fences if present\n",
   "        if \"```\" in raw:\n",
   "            raw = raw.split(\"```\")[1]\n",
-  "        if raw.startswith(\"json\"):
+  "        if raw.startswith(\"json\"):\n",
+  "            raw = raw[4:]\n",
   "        raw = raw.strip()\n",
+  "\n",
   "        try:\n",
   "            parsed = json.loads(raw)\n",
   "            action_type = parsed.get(\"action_type\", \"propose_rules\")\n",
   "            content = parsed.get(\"content\", \"{}\")\n",
-  "
-  "
-  "
+  "\n",
+  "            # Validate action_type\n",
+  "            valid_actions = obs.get(\"available_actions\", [\"propose_rules\", \"ask_clarification\"])\n",
+  "            if action_type not in valid_actions:\n",
+  "                action_type = \"propose_rules\" if \"propose_rules\" in valid_actions else valid_actions[0]\n",
+  "\n",
+  "            # Ensure content is a string\n",
+  "            if isinstance(content, dict):\n",
+  "                content = json.dumps(content)\n",
   "            return action_type, content\n",
-  "        except
+  "        except Exception:\n",
+  "            return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
+  "\n",
+  "# \u2500\u2500 Trajectory Bank \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
   "\n",
   "class TrajectoryBank:\n",
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
+  "    \"\"\"Stores and retrieves high-reward trajectories per task.\"\"\"\n",
+  "\n",
+  "    def __init__(self):\n",
+  "        self.bank: dict[str, list[Trajectory]] = {task: [] for task in TASKS}\n",
+  "\n",
+  "    def store(self, trajectory: Trajectory):\n",
+  "        if trajectory.total_reward >= MIN_REWARD_THRESHOLD:\n",
+  "            self.bank[trajectory.task_name].append(trajectory)\n",
+  "            # Keep only top-K by reward\n",
+  "            self.bank[trajectory.task_name].sort(key=lambda t: t.total_reward, reverse=True)\n",
+  "            self.bank[trajectory.task_name] = self.bank[trajectory.task_name][:TOP_K_TRAJECTORIES]\n",
+  "\n",
+  "    def get_examples(self, task_name: str) -> list[Trajectory]:\n",
+  "        return self.bank.get(task_name, [])\n",
+  "\n",
+  "    def summary(self) -> dict:\n",
+  "        return {\n",
+  "            task: {\n",
+  "                \"stored\": len(trajs),\n",
+  "                \"best_reward\": max((t.total_reward for t in trajs), default=0),\n",
+  "                \"best_accuracy\": max((t.final_accuracy for t in trajs), default=0)\n",
+  "            }\n",
+  "            for task, trajs in self.bank.items()\n",
+  "        }\n",
+  "\n",
+  "# \u2500\u2500 Training Loop \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
   "\n",
   "class TrainingLoop:\n",
-  "    def __init__(self, env_url, hf_token):\n",
+  "    def __init__(self, env_url: str, hf_token: str):\n",
   "        self.env = EnvClient(env_url)\n",
   "        self.agent = Agent(hf_token)\n",
   "        self.bank = TrajectoryBank()\n",
-  "        self.metrics = []\n",
-  "\n",
-  "
+  "        self.metrics = []  # List of {episode, task, reward, accuracy, success}\n",
+  "\n",
+  "        os.makedirs(\"training/logs\", exist_ok=True)\n",
+  "        log_filename = f\"training/logs/run_{int(time.time())}.log\"\n",
+  "        logging.basicConfig(\n",
+  "            level=logging.INFO,\n",
+  "            format=\"%(asctime)s [%(levelname)s] %(message)s\",\n",
+  "            handlers=[\n",
+  "                logging.FileHandler(log_filename),\n",
+  "                logging.StreamHandler()  # also print to console\n",
+  "            ]\n",
+  "        )\n",
+  "        self.logger = logging.getLogger(\"TrainingLoop\")\n",
+  "        self.log_file = log_filename\n",
+  "\n",
+  "    def run_episode(self, task_name: str, episode_id: int) -> Trajectory:\n",
+  "        \"\"\"Run a single episode and return the trajectory.\"\"\"\n",
   "        few_shots = self.bank.get_examples(task_name)\n",
-  "
+  "        trajectory = Trajectory(task_name=task_name, episode_id=episode_id)\n",
+  "\n",
+  "        # Reset environment\n",
   "        result = self.env.reset(task_name)\n",
-  "        obs
-  "
+  "        obs = result.get(\"observation\", {})\n",
+  "        done = result.get(\"done\", False)\n",
+  "        history = []\n",
+  "\n",
+  "        self.logger.info(f\"START episode={episode_id} task={task_name} few_shots_available={len(few_shots)}\")\n",
+  "\n",
   "        step_num = 0\n",
   "        while not done and step_num < obs.get(\"max_steps\", 7):\n",
   "            step_num += 1\n",
-  "
+  "\n",
+  "            # Get action from agent\n",
+  "            action_type, content = self.agent.get_action(\n",
+  "                observation=obs,\n",
+  "                step_number=step_num,\n",
+  "                episode_history=history,\n",
+  "                few_shot_examples=few_shots,\n",
+  "                task_name=task_name\n",
+  "            )\n",
+  "\n",
+  "            # Execute action\n",
   "            result = self.env.step(action_type, content)\n",
-  "            reward
-  "
-  "
-  "
-  "
-  "
+  "            reward = result.get(\"reward\", 0.0)\n",
+  "            done = result.get(\"done\", False)\n",
+  "            obs = result.get(\"observation\", {})\n",
+  "            info = result.get(\"info\", {})\n",
+  "\n",
+  "            # Record step\n",
+  "            step = Step(\n",
+  "                step_number=step_num,\n",
+  "                action_type=action_type,\n",
+  "                action_content=content[:300],\n",
+  "                reward=reward,\n",
+  "                accuracy=obs.get(\"current_accuracy\", 0.0),\n",
+  "                feedback=obs.get(\"feedback\", \"\") or \"\",\n",
+  "                clarification_response=obs.get(\"clarification_response\")\n",
+  "            )\n",
+  "            trajectory.steps.append(step)\n",
+  "            trajectory.total_reward += reward\n",
+  "\n",
+  "            # Update history\n",
+  "            history.append(f\"Step {step_num}: {action_type} \u2192 reward={reward:.2f} acc={step.accuracy:.2f}\")\n",
+  "\n",
+  "            self.logger.info(f\"STEP episode={episode_id} step={step_num} action={action_type} reward={reward:.4f} accuracy={step.accuracy:.4f}\")\n",
+  "\n",
   "            if done:\n",
-  "
-  "
+  "                episode_score = info.get(\"episode_score\", obs.get(\"current_accuracy\", 0.0))\n",
+  "                trajectory.final_accuracy = episode_score\n",
+  "                trajectory.success = obs.get(\"current_accuracy\", 0.0) >= 0.9\n",
   "                break\n",
-  "
-  "
+  "\n",
+  "        if not trajectory.steps:\n",
+  "            trajectory.final_accuracy = 0.0\n",
+  "\n",
+  "        self.logger.info(f\"END episode={episode_id} task={task_name} total_reward={trajectory.total_reward:.4f} final_accuracy={trajectory.final_accuracy:.4f} success={trajectory.success} steps={len(trajectory.steps)}\")\n",
+  "\n",
+  "        return trajectory\n",
   "\n",
   "    def run(self):\n",
-  "
-  "
-  "
-  "
-  "
-  "
-  "
+  "        \"\"\"Run full training loop across all tasks.\"\"\"\n",
+  "        self.logger.info(\"=\" * 60)\n",
+  "        self.logger.info(\"REWARD-GUIDED TRAJECTORY OPTIMIZATION\")\n",
+  "        self.logger.info(f\"Tasks: {TASKS}\")\n",
+  "        self.logger.info(f\"Episodes per task: {NUM_EPISODES_PER_TASK}\")\n",
+  "        self.logger.info(f\"Top-K trajectories: {TOP_K_TRAJECTORIES}\")\n",
+  "        self.logger.info(\"=\" * 60)\n",
+  "        self.logger.info(f\"Log file: {self.log_file}\")\n",
+  "\n",
+  "        try:\n",
+  "            wandb.init(\n",
+  "                project=\"policy-to-logic-rl\",\n",
+  "                name=f\"trajectory-opt-{int(time.time())}\",\n",
+  "                config={\n",
+  "                    \"num_episodes_per_task\": NUM_EPISODES_PER_TASK,\n",
+  "                    \"top_k_trajectories\": TOP_K_TRAJECTORIES,\n",
+  "                    \"min_reward_threshold\": MIN_REWARD_THRESHOLD,\n",
+  "                    \"model\": MODEL,\n",
+  "                    \"temperature\": TEMPERATURE,\n",
+  "                    \"tasks\": TASKS,\n",
+  "                    \"env_url\": ENV_BASE_URL,\n",
+  "                }\n",
+  "            )\n",
+  "        except Exception as e:\n",
+  "            self.logger.warning(f\"Wandb init failed: {e}. Continuing without W&B.\")\n",
+  "\n",
+  "        # Health check\n",
+  "        if not self.env.health():\n",
+  "            raise RuntimeError(f\"Environment not reachable at {ENV_BASE_URL}\")\n",
+  "        self.logger.info(f\"Environment: OK ({ENV_BASE_URL})\\n\")\n",
+  "\n",
+  "        global_episode = 0\n",
+  "\n",
   "        for task in TASKS:\n",
-  "
+  "            self.logger.info(f\"\\n{'\u2500'*40}\")\n",
+  "            self.logger.info(f\"TASK: {task}\")\n",
+  "            self.logger.info(f\"{'\u2500'*40}\")\n",
+  "\n",
   "            task_rewards = []\n",
+  "            task_accuracies = []\n",
+  "\n",
   "            for ep in range(1, NUM_EPISODES_PER_TASK + 1):\n",
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
-  "
+  "                global_episode += 1\n",
+  "                trajectory = self.run_episode(task, ep)\n",
+  "\n",
+  "                # Store in bank\n",
+  "                self.bank.store(trajectory)\n",
+  "\n",
+  "                try:\n",
+  "                    wandb.log({\n",
+  "                        f\"{task}/total_reward\": trajectory.total_reward,\n",
+  "                        f\"{task}/final_accuracy\": trajectory.final_accuracy,\n",
+  "                        f\"{task}/num_steps\": len(trajectory.steps),\n",
+  "                        f\"{task}/success\": int(trajectory.success),\n",
+  "                        f\"{task}/few_shots_used\": len(self.bank.get_examples(task)),\n",
+  "                        \"global/total_reward\": trajectory.total_reward,\n",
+  "                        \"global/final_accuracy\": trajectory.final_accuracy,\n",
+  "                        \"episode\": global_episode,\n",
+  "                    })\n",
+  "                except Exception:\n",
+  "                    pass\n",
+  "\n",
+  "                # Record metrics\n",
+  "                self.metrics.append({\n",
+  "                    \"global_episode\": global_episode,\n",
+  "                    \"task\": task,\n",
+  "                    \"episode_in_task\": ep,\n",
+  "                    \"total_reward\": trajectory.total_reward,\n",
+  "                    \"final_accuracy\": trajectory.final_accuracy,\n",
+  "                    \"success\": trajectory.success,\n",
+  "                    \"num_steps\": len(trajectory.steps),\n",
+  "                    \"few_shots_used\": len(self.bank.get_examples(task)) - (1 if trajectory.total_reward >= MIN_REWARD_THRESHOLD else 0)\n",
+  "                })\n",
+  "\n",
+  "                task_rewards.append(trajectory.total_reward)\n",
+  "                task_accuracies.append(trajectory.final_accuracy)\n",
+  "\n",
+  "                self.logger.info(f\"  \u2192 Episode {ep} complete: reward={trajectory.total_reward:.3f} accuracy={trajectory.final_accuracy:.2f} success={trajectory.success}\")\n",
+  "                time.sleep(0.5)  # Rate limiting\n",
+  "\n",
+  "            self.logger.info(f\"\\n  Task summary:\")\n",
+  "            self.logger.info(f\"    First episode reward: {task_rewards[0]:.3f}\")\n",
+  "            self.logger.info(f\"    Last episode reward: {task_rewards[-1]:.3f}\")\n",
+  "            self.logger.info(f\"    Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}\")\n",
+  "\n",
+  "        self.logger.info(\"\\n\" + \"=\" * 60)\n",
+  "        self.logger.info(\"TRAINING COMPLETE\")\n",
+  "        self.logger.info(f\"Bank summary: {self.bank.summary()}\")\n",
+  "        self.logger.info(\"=\" * 60)\n",
+  "\n",
+  "        try:\n",
+  "            wandb.finish()\n",
+  "        except Exception:\n",
+  "            pass\n",
+  "\n",
   "        return self.metrics\n",
   "\n",
-  "
-  "
-  "
+  "# \u2500\u2500 Plot Generation \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
+  "\n",
+  "def save_plots(metrics: list[dict]):\n",
+  "    \"\"\"\n",
+  "    Save reward curve and accuracy curve as PNG files.\n",
+  "    These are REQUIRED for hackathon submission \u2014 must be committed to repo.\n",
+  "    \"\"\"\n",
+  "    try:\n",
+  "        import matplotlib\n",
+  "        matplotlib.use(\"Agg\")  # Non-interactive backend\n",
+  "        import matplotlib.pyplot as plt\n",
+  "        import numpy as np\n",
+  "    except ImportError:\n",
+  "        print(\"matplotlib not installed. Run: pip install matplotlib\")\n",
+  "        return\n",
+  "\n",
   "    os.makedirs(\"training/plots\", exist_ok=True)\n",
+  "\n",
   "    episodes = [m[\"global_episode\"] for m in metrics]\n",
   "    rewards = [m[\"total_reward\"] for m in metrics]\n",
-  "
-  "
+  "    accuracies = [m[\"final_accuracy\"] for m in metrics]\n",
+  "    tasks = [m[\"task\"] for m in metrics]\n",
+  "\n",
+  "    colors = {\n",
+  "        \"data_access\": \"#2196F3\",\n",
+  "        \"resource_access\": \"#FF9800\",\n",
+  "        \"transaction_approval\": \"#4CAF50\"\n",
+  "    }\n",
+  "\n",
+  "    # \u2500\u2500 Plot 1: Reward Curve (per-task trend lines) \u2500\u2500\u2500\u2500\n",
   "    fig, ax = plt.subplots(figsize=(10, 5))\n",
+  "\n",
   "    for task in TASKS:\n",
-  "
-  "
-  "        ax.plot(
-  "
-  "
-  "
-  "
-  "
+  "        task_eps = [m[\"global_episode\"] for m in metrics if m[\"task\"] == task]\n",
+  "        task_rews = [m[\"total_reward\"] for m in metrics if m[\"task\"] == task]\n",
+  "        ax.plot(task_eps, task_rews, marker=\"o\", label=task,\n",
+  "                color=colors.get(task, \"gray\"), linewidth=2, markersize=5)\n",
+  "        # Per-task trend line\n",
+  "        if len(task_eps) >= 2:\n",
+  "            z = np.polyfit(task_eps, task_rews, 1)\n",
+  "            p = np.poly1d(z)\n",
+  "            ax.plot(task_eps, p(task_eps), \"--\",\n",
+  "                    color=colors.get(task, \"gray\"), alpha=0.4, linewidth=1.5)\n",
+  "\n",
+  "    ax.set_xlabel(\"Episode\")\n",
+  "    ax.set_ylabel(\"Total Reward\")\n",
+  "    ax.set_title(\"Reward Curve \u2014 Reward-Guided Trajectory Optimization\")\n",
+  "    ax.legend()\n",
+  "    ax.grid(True, alpha=0.3)\n",
+  "    ax.set_ylim(bottom=0)\n",
+  "\n",
+  "    plt.tight_layout()\n",
+  "    plt.savefig(\"training/plots/reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
+  "    plt.close()\n",
+  "    print(\"Saved: training/plots/reward_curve.png\")\n",
+  "\n",
+  "    # \u2500\u2500 Plot 2: Accuracy Curve \u2500\u2500\u2500\u2500\n",
   "    fig, ax = plt.subplots(figsize=(10, 5))\n",
+  "\n",
   "    for task in TASKS:\n",
-  "
-  "
-  "        ax.plot(
-  "
-  "
-  "
-  "
-  "
-  "
+  "        task_eps = [m[\"global_episode\"] for m in metrics if m[\"task\"] == task]\n",
+  "        task_accs = [m[\"final_accuracy\"] for m in metrics if m[\"task\"] == task]\n",
+  "        ax.plot(task_eps, task_accs, marker=\"s\", label=task,\n",
+  "                color=colors.get(task, \"gray\"), linewidth=2, markersize=5)\n",
+  "\n",
+  "    ax.axhline(y=0.9, color=\"red\", linestyle=\"--\", alpha=0.7, label=\"success threshold (0.9)\")\n",
+  "\n",
+  "    ax.set_xlabel(\"Episode\")\n",
+  "    ax.set_ylabel(\"Final Accuracy\")\n",
+  "    ax.set_title(\"Accuracy Curve \u2014 Policy-to-Logic Agent\")\n",
+  "    ax.legend()\n",
+  "    ax.grid(True, alpha=0.3)\n",
+  "    ax.set_ylim(0, 1.05)\n",
+  "\n",
+  "    plt.tight_layout()\n",
+  "    plt.savefig(\"training/plots/accuracy_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
+  "    plt.close()\n",
+  "    print(\"Saved: training/plots/accuracy_curve.png\")\n",
+  "\n",
+  "    # \u2500\u2500 Plot 3: Per-Task Summary (Accuracy + Efficiency) \u2500\u2500\u2500\u2500\n",
+  "    fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
+  "\n",
+  "    task_labels = []\n",
+  "    acc_improvements = []\n",
+  "    eff_improvements = []\n",
+  "    best_accuracies = []\n",
+  "\n",
   "    for task in TASKS:\n",
-  "
-  "        if len(
-  "
-  "
-  "
-  "
-  "
-  "
+  "        task_data = [m for m in metrics if m[\"task\"] == task]\n",
+  "        if len(task_data) >= 2:\n",
+  "            task_labels.append(task.replace(\"_\", \"\\n\"))\n",
+  "            acc_improvements.append(task_data[-1][\"final_accuracy\"] - task_data[0][\"final_accuracy\"])\n",
+  "            # Efficiency: steps saved (first vs best)\n",
+  "            first_steps = task_data[0][\"num_steps\"]\n",
+  "            best_steps = min(m[\"num_steps\"] for m in task_data)\n",
+  "            eff_pct = ((first_steps - best_steps) / first_steps * 100) if first_steps > 0 else 0\n",
+  "            eff_improvements.append(eff_pct)\n",
+  "            best_accuracies.append(max(m[\"final_accuracy\"] for m in task_data))\n",
+  "\n",
+  "    # Left: Best accuracy per task\n",
+  "    bars1 = axes[0].bar(task_labels, best_accuracies,\n",
+  "                        color=[\"#2196F3\", \"#FF9800\", \"#4CAF50\"][:len(task_labels)],\n",
+  "                        edgecolor=\"white\", linewidth=1.5)\n",
+  "    axes[0].axhline(y=0.9, color=\"red\", linestyle=\"--\", alpha=0.7, label=\"success threshold\")\n",
+  "    axes[0].set_ylabel(\"Best Accuracy Achieved\")\n",
+  "    axes[0].set_title(\"Best Accuracy Per Task\")\n",
+  "    axes[0].set_ylim(0, 1.1)\n",
+  "    axes[0].grid(True, axis=\"y\", alpha=0.3)\n",
+  "    axes[0].legend()\n",
+  "    for bar, val in zip(bars1, best_accuracies):\n",
+  "        axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,\n",
+  "                     f\"{val:.0%}\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\n",
+  "\n",
+  "    # Right: Efficiency improvement (% steps saved)\n",
+  "    bars2 = axes[1].bar(task_labels, eff_improvements,\n",
+  "                        color=[\"#2196F3\", \"#FF9800\", \"#4CAF50\"][:len(task_labels)],\n",
+  "                        edgecolor=\"white\", linewidth=1.5)\n",
+  "    axes[1].axhline(y=0, color=\"black\", linewidth=0.8)\n",
+  "    axes[1].set_ylabel(\"Steps Saved (%)\")\n",
+  "    axes[1].set_title(\"Efficiency Improvement (First \u2192 Best Episode)\")\n",
+  "    axes[1].grid(True, axis=\"y\", alpha=0.3)\n",
+  "    for bar, val in zip(bars2, eff_improvements):\n",
+  "        y_pos = max(bar.get_height() + 1, 2)\n",
+  "        axes[1].text(bar.get_x() + bar.get_width() / 2, y_pos,\n",
+  "                     f\"{val:.0f}%\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\n",
+  "\n",
+  "    plt.tight_layout()\n",
+  "    plt.savefig(\"training/plots/improvement_chart.png\", dpi=150, bbox_inches=\"tight\")\n",
+  "    plt.close()\n",
+  "    print(\"Saved: training/plots/improvement_chart.png\")\n",
+  "\n",
+  "    # \u2500\u2500 Save raw metrics as JSON \u2500\u2500\u2500\u2500\n",
+  "    timestamp = int(time.time())\n",
+  "    with open(f\"training/plots/metrics_{timestamp}.json\", \"w\") as f:\n",
+  "        json.dump(metrics, f, indent=2)\n",
+  "    with open(\"training/plots/metrics_latest.json\", \"w\") as f:\n",
+  "        json.dump(metrics, f, indent=2)\n",
+  "    print(f\"Saved: training/plots/metrics_{timestamp}.json and metrics_latest.json\")\n",
+  "\n",
+  "# \u2500\u2500 Entry Point \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
+  "\n",
+  "if __name__ == \"__main__\":\n",
+  "    hf_token = os.getenv(\"HF_TOKEN\", \"\")\n",
+  "    if not hf_token:\n",
+  "        raise ValueError(\"HF_TOKEN environment variable not set\")\n",
+  "\n",
+  "    loop = TrainingLoop(ENV_BASE_URL, hf_token)\n",
+  "    metrics = loop.run()\n",
+  "    save_plots(metrics)\n",
+  "\n",
+  "    print(\"\\nNext step: commit training/plots/*.png to repo for submission.\")\n",
+  ""
   ]
  },
  {
@@ -335,7 +729,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "# Cell
+  "# Cell 6: Run training loop\n",
   "loop = TrainingLoop(ENV_URL, HF_TOKEN)\n",
   "metrics = loop.run()\n",
   "print(f\"\\nTotal episodes run: {len(metrics)}\")"
@@ -347,7 +741,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "# Cell
+  "# Cell 7: Generate plots and display inline\n",
   "save_plots(metrics)\n",
   "\n",
   "from IPython.display import Image, display\n",
@@ -362,17 +756,27 @@
   "metadata": {},
   "outputs": [],
   "source": [
-  "# Cell
-  "
   "from google.colab import files\n",
   "\n",
   "files.download(\"training/plots/reward_curve.png\")\n",
   "files.download(\"training/plots/accuracy_curve.png\")\n",
   "files.download(\"training/plots/improvement_chart.png\")\n",
-  "files.download(\"training/plots/
   "\n",
   "print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
   ]
  }
  ]
- }
|
| 746 |
"\n",
|
| 747 |
"from IPython.display import Image, display\n",
|
|
|
|
| 756 |
"metadata": {},
|
| 757 |
"outputs": [],
|
| 758 |
"source": [
|
| 759 |
+
"# Cell 8: Display wandb run link\n",
|
| 760 |
+
"print(f\"Wandb run: https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl\")\n",
|
| 761 |
+
"print(\"Add this link to your README under Deliverables.\")"
|
| 762 |
+
]
|
| 763 |
+
},
|
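Since the YOUR_USERNAME placeholder in Cell 8 is easy to leave unedited, a more robust variant reads the URL off the live run object instead. A minimal sketch (it assumes only the public wandb API: wandb.run is the active run and becomes None once wandb.finish() has been called, hence the guard):

import wandb

run = wandb.run  # None after wandb.finish()
if run is not None:
    print(f"Wandb run: {run.url}")
else:
    print("Run already finished; copy the URL printed by wandb.init during training.")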
| 764 |
+
{
|
| 765 |
+
"cell_type": "code",
|
| 766 |
+
"execution_count": null,
|
| 767 |
+
"metadata": {},
|
| 768 |
+
"outputs": [],
|
| 769 |
+
"source": [
|
| 770 |
+
"# Cell 9: Download plots to commit to repo\n",
|
| 771 |
"from google.colab import files\n",
|
| 772 |
"\n",
|
| 773 |
"files.download(\"training/plots/reward_curve.png\")\n",
|
| 774 |
"files.download(\"training/plots/accuracy_curve.png\")\n",
|
| 775 |
"files.download(\"training/plots/improvement_chart.png\")\n",
|
| 776 |
+
"files.download(\"training/plots/metrics_latest.json\")\n",
|
| 777 |
"\n",
|
| 778 |
"print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
|
| 779 |
]
|
| 780 |
}
|
| 781 |
]
|
| 782 |
+
}
|
training/results-iteration1/accuracy_curve (1).png
ADDED
|
training/results-iteration1/improvement_chart (1).png
ADDED
|
training/results-iteration1/metrics (1).json
ADDED
|
@@ -0,0 +1,218 @@
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"global_episode": 1,
|
| 4 |
+
"task": "data_access",
|
| 5 |
+
"episode_in_task": 1,
|
| 6 |
+
"total_reward": 1.1988333333333334,
|
| 7 |
+
"final_accuracy": 0.92,
|
| 8 |
+
"success": true,
|
| 9 |
+
"num_steps": 4
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"global_episode": 2,
|
| 13 |
+
"task": "data_access",
|
| 14 |
+
"episode_in_task": 2,
|
| 15 |
+
"total_reward": 0.7585,
|
| 16 |
+
"final_accuracy": 0.9600000000000001,
|
| 17 |
+
"success": true,
|
| 18 |
+
"num_steps": 2
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"global_episode": 3,
|
| 22 |
+
"task": "data_access",
|
| 23 |
+
"episode_in_task": 3,
|
| 24 |
+
"total_reward": 0.7585,
|
| 25 |
+
"final_accuracy": 0.9600000000000001,
|
| 26 |
+
"success": true,
|
| 27 |
+
"num_steps": 2
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"global_episode": 4,
|
| 31 |
+
"task": "data_access",
|
| 32 |
+
"episode_in_task": 4,
|
| 33 |
+
"total_reward": 0.7585,
|
| 34 |
+
"final_accuracy": 0.9600000000000001,
|
| 35 |
+
"success": true,
|
| 36 |
+
"num_steps": 2
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"global_episode": 5,
|
| 40 |
+
"task": "data_access",
|
| 41 |
+
"episode_in_task": 5,
|
| 42 |
+
"total_reward": 0.7585,
|
| 43 |
+
"final_accuracy": 0.9600000000000001,
|
| 44 |
+
"success": true,
|
| 45 |
+
"num_steps": 2
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"global_episode": 6,
|
| 49 |
+
"task": "data_access",
|
| 50 |
+
"episode_in_task": 6,
|
| 51 |
+
"total_reward": 0.7585,
|
| 52 |
+
"final_accuracy": 0.9600000000000001,
|
| 53 |
+
"success": true,
|
| 54 |
+
"num_steps": 2
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"global_episode": 7,
|
| 58 |
+
"task": "data_access",
|
| 59 |
+
"episode_in_task": 7,
|
| 60 |
+
"total_reward": 0.7585,
|
| 61 |
+
"final_accuracy": 0.9600000000000001,
|
| 62 |
+
"success": true,
|
| 63 |
+
"num_steps": 2
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"global_episode": 8,
|
| 67 |
+
"task": "data_access",
|
| 68 |
+
"episode_in_task": 8,
|
| 69 |
+
"total_reward": 0.7585,
|
| 70 |
+
"final_accuracy": 0.9600000000000001,
|
| 71 |
+
"success": true,
|
| 72 |
+
"num_steps": 2
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"global_episode": 9,
|
| 76 |
+
"task": "resource_access",
|
| 77 |
+
"episode_in_task": 1,
|
| 78 |
+
"total_reward": 1.041,
|
| 79 |
+
"final_accuracy": 0.8931428571428572,
|
| 80 |
+
"success": true,
|
| 81 |
+
"num_steps": 3
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"global_episode": 10,
|
| 85 |
+
"task": "resource_access",
|
| 86 |
+
"episode_in_task": 2,
|
| 87 |
+
"total_reward": 1.041,
|
| 88 |
+
"final_accuracy": 0.8931428571428572,
|
| 89 |
+
"success": true,
|
| 90 |
+
"num_steps": 3
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"global_episode": 11,
|
| 94 |
+
"task": "resource_access",
|
| 95 |
+
"episode_in_task": 3,
|
| 96 |
+
"total_reward": 0.7335,
|
| 97 |
+
"final_accuracy": 0.9074285714285715,
|
| 98 |
+
"success": true,
|
| 99 |
+
"num_steps": 2
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"global_episode": 12,
|
| 103 |
+
"task": "resource_access",
|
| 104 |
+
"episode_in_task": 4,
|
| 105 |
+
"total_reward": 0.7335,
|
| 106 |
+
"final_accuracy": 0.9074285714285715,
|
| 107 |
+
"success": true,
|
| 108 |
+
"num_steps": 2
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"global_episode": 13,
|
| 112 |
+
"task": "resource_access",
|
| 113 |
+
"episode_in_task": 5,
|
| 114 |
+
"total_reward": 0.7434999999999999,
|
| 115 |
+
"final_accuracy": 0.9234285714285714,
|
| 116 |
+
"success": true,
|
| 117 |
+
"num_steps": 2
|
| 118 |
+
},
|
| 119 |
+
{
|
| 120 |
+
"global_episode": 14,
|
| 121 |
+
"task": "resource_access",
|
| 122 |
+
"episode_in_task": 6,
|
| 123 |
+
"total_reward": 0.7434999999999999,
|
| 124 |
+
"final_accuracy": 0.9234285714285714,
|
| 125 |
+
"success": true,
|
| 126 |
+
"num_steps": 2
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"global_episode": 15,
|
| 130 |
+
"task": "resource_access",
|
| 131 |
+
"episode_in_task": 7,
|
| 132 |
+
"total_reward": 1.929,
|
| 133 |
+
"final_accuracy": 0.8645714285714287,
|
| 134 |
+
"success": true,
|
| 135 |
+
"num_steps": 5
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"global_episode": 16,
|
| 139 |
+
"task": "resource_access",
|
| 140 |
+
"episode_in_task": 8,
|
| 141 |
+
"total_reward": 0.7335,
|
| 142 |
+
"final_accuracy": 0.9074285714285715,
|
| 143 |
+
"success": true,
|
| 144 |
+
"num_steps": 2
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"global_episode": 17,
|
| 148 |
+
"task": "transaction_approval",
|
| 149 |
+
"episode_in_task": 1,
|
| 150 |
+
"total_reward": 0.10799999999999998,
|
| 151 |
+
"final_accuracy": 0.0,
|
| 152 |
+
"success": false,
|
| 153 |
+
"num_steps": 7
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"global_episode": 18,
|
| 157 |
+
"task": "transaction_approval",
|
| 158 |
+
"episode_in_task": 2,
|
| 159 |
+
"total_reward": 0.10799999999999998,
|
| 160 |
+
"final_accuracy": 0.0,
|
| 161 |
+
"success": false,
|
| 162 |
+
"num_steps": 7
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"global_episode": 19,
|
| 166 |
+
"task": "transaction_approval",
|
| 167 |
+
"episode_in_task": 3,
|
| 168 |
+
"total_reward": 0.10799999999999998,
|
| 169 |
+
"final_accuracy": 0.0,
|
| 170 |
+
"success": false,
|
| 171 |
+
"num_steps": 7
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"global_episode": 20,
|
| 175 |
+
"task": "transaction_approval",
|
| 176 |
+
"episode_in_task": 4,
|
| 177 |
+
"total_reward": 0.10799999999999998,
|
| 178 |
+
"final_accuracy": 0.0,
|
| 179 |
+
"success": false,
|
| 180 |
+
"num_steps": 7
|
| 181 |
+
},
|
| 182 |
+
{
|
| 183 |
+
"global_episode": 21,
|
| 184 |
+
"task": "transaction_approval",
|
| 185 |
+
"episode_in_task": 5,
|
| 186 |
+
"total_reward": 0.10799999999999998,
|
| 187 |
+
"final_accuracy": 0.0,
|
| 188 |
+
"success": false,
|
| 189 |
+
"num_steps": 7
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"global_episode": 22,
|
| 193 |
+
"task": "transaction_approval",
|
| 194 |
+
"episode_in_task": 6,
|
| 195 |
+
"total_reward": 0.10799999999999998,
|
| 196 |
+
"final_accuracy": 0.0,
|
| 197 |
+
"success": false,
|
| 198 |
+
"num_steps": 7
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"global_episode": 23,
|
| 202 |
+
"task": "transaction_approval",
|
| 203 |
+
"episode_in_task": 7,
|
| 204 |
+
"total_reward": 0.10799999999999998,
|
| 205 |
+
"final_accuracy": 0.0,
|
| 206 |
+
"success": false,
|
| 207 |
+
"num_steps": 7
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
"global_episode": 24,
|
| 211 |
+
"task": "transaction_approval",
|
| 212 |
+
"episode_in_task": 8,
|
| 213 |
+
"total_reward": 0.10799999999999998,
|
| 214 |
+
"final_accuracy": 0.0,
|
| 215 |
+
"success": false,
|
| 216 |
+
"num_steps": 7
|
| 217 |
+
}
|
| 218 |
+
]
|
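A few numbers are directly readable from this dump: data_access drops from 4 steps to 2 after the first episode (50% steps saved under the eff_pct formula above), resource_access peaks at roughly 0.92 accuracy, and transaction_approval sits at 0.0 accuracy across all eight episodes, which is consistent with the task-specific prompt guidance this commit adds for that task in trajectory_optimizer.py below. A minimal sketch for summarizing such a dump, assuming only the schema shown above (the path mirrors the file added here):

import json
from collections import defaultdict

with open("training/results-iteration1/metrics (1).json") as f:
    episodes = json.load(f)

by_task = defaultdict(list)
for ep in episodes:
    by_task[ep["task"]].append(ep)

for task, eps in by_task.items():
    rewards = [e["total_reward"] for e in eps]
    print(f"{task}: avg_reward={sum(rewards) / len(rewards):.3f} "
          f"best_accuracy={max(e['final_accuracy'] for e in eps):.2f} "
          f"success_rate={sum(e['success'] for e in eps) / len(eps):.0%}")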
training/results-iteration1/reward_curve (1).png
ADDED
|
training/trajectory_optimizer.py
CHANGED
|
@@ -13,6 +13,8 @@ import json
|
|
| 13 |
import os
|
| 14 |
import time
|
| 15 |
import requests
|
|
| 16 |
from dataclasses import dataclass, field
|
| 17 |
from typing import Optional
|
| 18 |
from openai import OpenAI
|
|
@@ -106,14 +108,15 @@ class Agent:
|
|
| 106 |
observation: dict,
|
| 107 |
step_number: int,
|
| 108 |
episode_history: list[str],
|
| 109 |
-
few_shot_examples: list[Trajectory]
|
| 110 |
) -> tuple[str, str]:
|
| 111 |
"""
|
| 112 |
Returns (action_type, content_json_string).
|
| 113 |
action_type: one of ask_clarification | propose_rules | refine_rules
|
| 114 |
content: JSON string appropriate for that action
|
| 115 |
"""
|
| 116 |
-
system_prompt = self._build_system_prompt(few_shot_examples)
|
| 117 |
user_prompt = self._build_user_prompt(observation, step_number, episode_history)
|
| 118 |
|
| 119 |
try:
|
|
@@ -132,7 +135,7 @@ class Agent:
|
|
| 132 |
print(f" [LLM ERROR] {e}")
|
| 133 |
return "propose_rules", json.dumps({"rules": [], "default": "DENY"})
|
| 134 |
|
| 135 |
-
def _build_system_prompt(self, few_shot_examples: list[Trajectory]) -> str:
|
| 136 |
base = """You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.
|
| 137 |
|
| 138 |
AVAILABLE ACTIONS:
|
|
@@ -164,6 +167,39 @@ STRATEGY:
|
|
| 164 |
OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.
|
| 165 |
{"action_type": "propose_rules", "content": "{...escaped json string...}"}
|
| 166 |
"""
|
|
| 167 |
if few_shot_examples:
|
| 168 |
base += "\n\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\n"
|
| 169 |
for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:
|
|
@@ -255,6 +291,19 @@ class TrainingLoop:
|
|
| 255 |
self.bank = TrajectoryBank()
|
| 256 |
self.metrics = [] # List of {episode, task, reward, accuracy, success}
|
| 257 |
|
|
| 258 |
def run_episode(self, task_name: str, episode_id: int) -> Trajectory:
|
| 259 |
"""Run a single episode and return the trajectory."""
|
| 260 |
few_shots = self.bank.get_examples(task_name)
|
|
@@ -266,7 +315,7 @@ class TrainingLoop:
|
|
| 266 |
done = result.get("done", False)
|
| 267 |
history = []
|
| 268 |
|
| 269 |
-
|
| 270 |
|
| 271 |
step_num = 0
|
| 272 |
while not done and step_num < obs.get("max_steps", 7):
|
|
@@ -277,7 +326,8 @@ class TrainingLoop:
|
|
| 277 |
observation=obs,
|
| 278 |
step_number=step_num,
|
| 279 |
episode_history=history,
|
| 280 |
-
few_shot_examples=few_shots
|
|
|
|
| 281 |
)
|
| 282 |
|
| 283 |
# Execute action
|
|
@@ -303,7 +353,7 @@ class TrainingLoop:
|
|
| 303 |
# Update history
|
| 304 |
history.append(f"Step {step_num}: {action_type} → reward={reward:.2f} acc={step.accuracy:.2f}")
|
| 305 |
|
| 306 |
-
|
| 307 |
|
| 308 |
if done:
|
| 309 |
episode_score = info.get("episode_score", obs.get("current_accuracy", 0.0))
|
|
@@ -314,28 +364,48 @@ class TrainingLoop:
|
|
| 314 |
if not trajectory.steps:
|
| 315 |
trajectory.final_accuracy = 0.0
|
| 316 |
|
|
| 317 |
return trajectory
|
| 318 |
|
| 319 |
def run(self):
|
| 320 |
"""Run full training loop across all tasks."""
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
| 327 |
|
| 328 |
# Health check
|
| 329 |
if not self.env.health():
|
| 330 |
raise RuntimeError(f"Environment not reachable at {ENV_BASE_URL}")
|
| 331 |
-
|
| 332 |
|
| 333 |
global_episode = 0
|
| 334 |
|
| 335 |
for task in TASKS:
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
|
| 340 |
task_rewards = []
|
| 341 |
task_accuracies = []
|
|
@@ -347,6 +417,20 @@ class TrainingLoop:
|
|
| 347 |
# Store in bank
|
| 348 |
self.bank.store(trajectory)
|
| 349 |
|
|
| 350 |
# Record metrics
|
| 351 |
self.metrics.append({
|
| 352 |
"global_episode": global_episode,
|
|
@@ -362,18 +446,23 @@ class TrainingLoop:
|
|
| 362 |
task_rewards.append(trajectory.total_reward)
|
| 363 |
task_accuracies.append(trajectory.final_accuracy)
|
| 364 |
|
| 365 |
-
|
| 366 |
time.sleep(0.5) # Rate limiting
|
| 367 |
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
| 377 |
|
| 378 |
return self.metrics
|
| 379 |
|
|
@@ -406,7 +495,7 @@ def save_plots(metrics: list[dict]):
|
|
| 406 |
"transaction_approval": "#4CAF50"
|
| 407 |
}
|
| 408 |
|
| 409 |
-
# ── Plot 1: Reward Curve ────────────────────────────
|
| 410 |
fig, ax = plt.subplots(figsize=(10, 5))
|
| 411 |
|
| 412 |
for task in TASKS:
|
|
@@ -414,11 +503,12 @@ def save_plots(metrics: list[dict]):
|
|
| 414 |
task_rews = [m["total_reward"] for m in metrics if m["task"] == task]
|
| 415 |
ax.plot(task_eps, task_rews, marker="o", label=task,
|
| 416 |
color=colors.get(task, "gray"), linewidth=2, markersize=5)
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
|
|
|
| 422 |
|
| 423 |
ax.set_xlabel("Episode")
|
| 424 |
ax.set_ylabel("Total Reward")
|
|
@@ -455,32 +545,52 @@ def save_plots(metrics: list[dict]):
|
|
| 455 |
plt.close()
|
| 456 |
print("Saved: training/plots/accuracy_curve.png")
|
| 457 |
|
| 458 |
-
# ── Plot 3: Per-Task
|
| 459 |
-
fig,
|
| 460 |
|
| 461 |
-
|
| 462 |
-
|
|
| 463 |
|
| 464 |
for task in TASKS:
|
| 465 |
-
|
| 466 |
-
if len(
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
|
|
| 484 |
|
| 485 |
plt.tight_layout()
|
| 486 |
plt.savefig("training/plots/improvement_chart.png", dpi=150, bbox_inches="tight")
|
|
@@ -488,9 +598,12 @@ def save_plots(metrics: list[dict]):
|
|
| 488 |
print("Saved: training/plots/improvement_chart.png")
|
| 489 |
|
| 490 |
# ── Save raw metrics as JSON ──────────────────────────────────────────────
|
| 491 |
-
|
|
|
| 492 |
json.dump(metrics, f, indent=2)
|
| 493 |
-
print("Saved: training/plots/
|
| 494 |
|
| 495 |
# ── Entry Point ───────────────────────────────────────────────────────────────
|
| 496 |
|
|
| 13 |
import os
|
| 14 |
import time
|
| 15 |
import requests
|
| 16 |
+
import logging
|
| 17 |
+
import wandb
|
| 18 |
from dataclasses import dataclass, field
|
| 19 |
from typing import Optional
|
| 20 |
from openai import OpenAI
|
|
|
|
| 108 |
observation: dict,
|
| 109 |
step_number: int,
|
| 110 |
episode_history: list[str],
|
| 111 |
+
few_shot_examples: list[Trajectory],
|
| 112 |
+
task_name: str = ""
|
| 113 |
) -> tuple[str, str]:
|
| 114 |
"""
|
| 115 |
Returns (action_type, content_json_string).
|
| 116 |
action_type: one of ask_clarification | propose_rules | refine_rules
|
| 117 |
content: JSON string appropriate for that action
|
| 118 |
"""
|
| 119 |
+
system_prompt = self._build_system_prompt(few_shot_examples, task_name)
|
| 120 |
user_prompt = self._build_user_prompt(observation, step_number, episode_history)
|
| 121 |
|
| 122 |
try:
|
|
|
|
| 135 |
print(f" [LLM ERROR] {e}")
|
| 136 |
return "propose_rules", json.dumps({"rules": [], "default": "DENY"})
|
| 137 |
|
| 138 |
+
def _build_system_prompt(self, few_shot_examples: list[Trajectory], task_name: str = "") -> str:
|
| 139 |
base = """You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.
|
| 140 |
|
| 141 |
AVAILABLE ACTIONS:
|
|
|
|
| 167 |
OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.
|
| 168 |
{"action_type": "propose_rules", "content": "{...escaped json string...}"}
|
| 169 |
"""
|
| 170 |
+
# Task-specific guidance for complex tasks
|
| 171 |
+
if task_name == "transaction_approval":
|
| 172 |
+
base += """
|
| 173 |
+
IMPORTANT — TRANSACTION APPROVAL TASK:
|
| 174 |
+
This task has 4 possible decisions: APPROVE, REQUIRE_APPROVAL, COMPLIANCE_REVIEW, HOLD
|
| 175 |
+
Rules are evaluated TOP-TO-BOTTOM. Order matters critically. You MUST order rules by priority:
|
| 176 |
+
1. FIRST: Check if transfer_type == "international" → then COMPLIANCE_REVIEW (always, overrides everything)
|
| 177 |
+
2. SECOND: Check if amount >= 10000 AND time is outside business hours (hour < 9 or hour >= 17) → then HOLD
|
| 178 |
+
3. THIRD: Check if amount > 5000 AND initiator_role != "manager" → then REQUIRE_APPROVAL
|
| 179 |
+
4. DEFAULT: APPROVE
|
| 180 |
+
|
| 181 |
+
Key details:
|
| 182 |
+
- Standard limit is $5,000 (amount > 5000 triggers approval, NOT >=)
|
| 183 |
+
- High-value threshold is $10,000 (amount >= 10000)
|
| 184 |
+
- Business hours: hour >= 9 AND hour < 17
|
| 185 |
+
- Manager exemption ONLY applies to the standard $5,000 limit, NOT to international or high-value HOLD rules
|
| 186 |
+
- "system" role follows the same rules as "employee"
|
| 187 |
+
|
| 188 |
+
Here is a working example of valid rules for this task:
|
| 189 |
+
{"rules": [{"if": [{"field": "transfer_type", "op": "==", "value": "international"}], "then": "COMPLIANCE_REVIEW"}, {"if": [{"field": "amount", "op": ">=", "value": 10000}, {"field": "time", "op": ">=", "value": 17}], "then": "HOLD"}, {"if": [{"field": "amount", "op": ">=", "value": 10000}, {"field": "time", "op": "<", "value": 9}], "then": "HOLD"}, {"if": [{"field": "amount", "op": ">", "value": 5000}, {"field": "initiator_role", "op": "!=", "value": "manager"}], "then": "REQUIRE_APPROVAL"}], "default": "APPROVE"}
|
| 190 |
+
"""
|
| 191 |
+
elif task_name == "resource_access":
|
| 192 |
+
base += """
|
| 193 |
+
IMPORTANT — RESOURCE ACCESS TASK:
|
| 194 |
+
This task has roles: junior, senior, contractor. Document types: public, internal, confidential.
|
| 195 |
+
- Senior employees: ALLOW everything always
|
| 196 |
+
- Contractors: ALLOW only public, DENY everything else
|
| 197 |
+
- Junior + confidential: ALWAYS DENY (regardless of time — the policy is misleading about this)
|
| 198 |
+
- Junior + internal: ALLOW only during business hours (hour >= 8 AND hour < 17)
|
| 199 |
+
- Junior + public: ALLOW always
|
| 200 |
+
- Business hours: hour >= 8 AND hour < 17
|
| 201 |
+
"""
|
| 202 |
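The resource_access guidance above enumerates the target decision table but, unlike the transaction_approval block, ships no worked example. Under the same rule schema (rules evaluated top-to-bottom, first match wins), one consistent ordering is sketched below; the field names "role", "doc_type", and "time" are illustrative guesses, since the environment's actual observation keys are not shown in this diff:

{"rules": [
  {"if": [{"field": "role", "op": "==", "value": "senior"}], "then": "ALLOW"},
  {"if": [{"field": "doc_type", "op": "==", "value": "public"}], "then": "ALLOW"},
  {"if": [{"field": "role", "op": "==", "value": "contractor"}], "then": "DENY"},
  {"if": [{"field": "doc_type", "op": "==", "value": "confidential"}], "then": "DENY"},
  {"if": [{"field": "doc_type", "op": "==", "value": "internal"}, {"field": "time", "op": ">=", "value": 8}, {"field": "time", "op": "<", "value": 17}], "then": "ALLOW"}
], "default": "DENY"}

Placing the unconditional public ALLOW rule second keeps contractor-plus-public allowed before the blanket contractor DENY fires; anything that falls through (for example junior-plus-internal outside 8-17) hits the DENY default.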
+
|
| 203 |
if few_shot_examples:
|
| 204 |
base += "\n\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\n"
|
| 205 |
for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:
|
|
|
|
| 291 |
self.bank = TrajectoryBank()
|
| 292 |
self.metrics = [] # List of {episode, task, reward, accuracy, success}
|
| 293 |
|
| 294 |
+
os.makedirs("training/logs", exist_ok=True)
|
| 295 |
+
log_filename = f"training/logs/run_{int(time.time())}.log"
|
| 296 |
+
logging.basicConfig(
|
| 297 |
+
level=logging.INFO,
|
| 298 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 299 |
+
handlers=[
|
| 300 |
+
logging.FileHandler(log_filename),
|
| 301 |
+
logging.StreamHandler() # also print to console
|
| 302 |
+
]
|
| 303 |
+
)
|
| 304 |
+
self.logger = logging.getLogger("TrainingLoop")
|
| 305 |
+
self.log_file = log_filename
|
| 306 |
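One caveat with this setup: logging.basicConfig is a no-op whenever the root logger already has handlers, which is common on a second run inside the same Colab kernel, so a new TrainingLoop would silently keep writing to the first run's log file. Since Python 3.8 the standard escape hatch is force=True; a minimal sketch of the same call with that flag:

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.FileHandler(log_filename), logging.StreamHandler()],
    force=True,  # drop any handlers left over from a previous cell run
)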
+
|
| 307 |
def run_episode(self, task_name: str, episode_id: int) -> Trajectory:
|
| 308 |
"""Run a single episode and return the trajectory."""
|
| 309 |
few_shots = self.bank.get_examples(task_name)
|
|
|
|
| 315 |
done = result.get("done", False)
|
| 316 |
history = []
|
| 317 |
|
| 318 |
+
self.logger.info(f"START episode={episode_id} task={task_name} few_shots_available={len(few_shots)}")
|
| 319 |
|
| 320 |
step_num = 0
|
| 321 |
while not done and step_num < obs.get("max_steps", 7):
|
|
|
|
| 326 |
observation=obs,
|
| 327 |
step_number=step_num,
|
| 328 |
episode_history=history,
|
| 329 |
+
few_shot_examples=few_shots,
|
| 330 |
+
task_name=task_name
|
| 331 |
)
|
| 332 |
|
| 333 |
# Execute action
|
|
|
|
| 353 |
# Update history
|
| 354 |
history.append(f"Step {step_num}: {action_type} → reward={reward:.2f} acc={step.accuracy:.2f}")
|
| 355 |
|
| 356 |
+
self.logger.info(f"STEP episode={episode_id} step={step_num} action={action_type} reward={reward:.4f} accuracy={step.accuracy:.4f}")
|
| 357 |
|
| 358 |
if done:
|
| 359 |
episode_score = info.get("episode_score", obs.get("current_accuracy", 0.0))
|
|
|
|
| 364 |
if not trajectory.steps:
|
| 365 |
trajectory.final_accuracy = 0.0
|
| 366 |
|
| 367 |
+
self.logger.info(f"END episode={episode_id} task={task_name} total_reward={trajectory.total_reward:.4f} final_accuracy={trajectory.final_accuracy:.4f} success={trajectory.success} steps={len(trajectory.steps)}")
|
| 368 |
+
|
| 369 |
return trajectory
|
| 370 |
|
| 371 |
def run(self):
|
| 372 |
"""Run full training loop across all tasks."""
|
| 373 |
+
self.logger.info("=" * 60)
|
| 374 |
+
self.logger.info("REWARD-GUIDED TRAJECTORY OPTIMIZATION")
|
| 375 |
+
self.logger.info(f"Tasks: {TASKS}")
|
| 376 |
+
self.logger.info(f"Episodes per task: {NUM_EPISODES_PER_TASK}")
|
| 377 |
+
self.logger.info(f"Top-K trajectories: {TOP_K_TRAJECTORIES}")
|
| 378 |
+
self.logger.info("=" * 60)
|
| 379 |
+
self.logger.info(f"Log file: {self.log_file}")
|
| 380 |
+
|
| 381 |
+
try:
|
| 382 |
+
wandb.init(
|
| 383 |
+
project="policy-to-logic-rl",
|
| 384 |
+
name=f"trajectory-opt-{int(time.time())}",
|
| 385 |
+
config={
|
| 386 |
+
"num_episodes_per_task": NUM_EPISODES_PER_TASK,
|
| 387 |
+
"top_k_trajectories": TOP_K_TRAJECTORIES,
|
| 388 |
+
"min_reward_threshold": MIN_REWARD_THRESHOLD,
|
| 389 |
+
"model": MODEL,
|
| 390 |
+
"temperature": TEMPERATURE,
|
| 391 |
+
"tasks": TASKS,
|
| 392 |
+
"env_url": ENV_BASE_URL,
|
| 393 |
+
}
|
| 394 |
+
)
|
| 395 |
+
except Exception as e:
|
| 396 |
+
self.logger.warning(f"Wandb init failed: {e}. Continuing without W&B.")
|
| 397 |
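Degrading gracefully when wandb.init fails is a sensible default here. For machines with no network route to wandb, the standard alternative is offline mode, which stores runs locally for a later `wandb sync`; a one-line sketch, set before wandb.init runs:

os.environ["WANDB_MODE"] = "offline"  # runs land in ./wandb; upload later with `wandb sync`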
|
| 398 |
# Health check
|
| 399 |
if not self.env.health():
|
| 400 |
raise RuntimeError(f"Environment not reachable at {ENV_BASE_URL}")
|
| 401 |
+
self.logger.info(f"Environment: OK ({ENV_BASE_URL})\n")
|
| 402 |
|
| 403 |
global_episode = 0
|
| 404 |
|
| 405 |
for task in TASKS:
|
| 406 |
+
self.logger.info(f"\n{'─'*40}")
|
| 407 |
+
self.logger.info(f"TASK: {task}")
|
| 408 |
+
self.logger.info(f"{'─'*40}")
|
| 409 |
|
| 410 |
task_rewards = []
|
| 411 |
task_accuracies = []
|
|
|
|
| 417 |
# Store in bank
|
| 418 |
self.bank.store(trajectory)
|
| 419 |
|
| 420 |
+
try:
|
| 421 |
+
wandb.log({
|
| 422 |
+
f"{task}/total_reward": trajectory.total_reward,
|
| 423 |
+
f"{task}/final_accuracy": trajectory.final_accuracy,
|
| 424 |
+
f"{task}/num_steps": len(trajectory.steps),
|
| 425 |
+
f"{task}/success": int(trajectory.success),
|
| 426 |
+
f"{task}/few_shots_used": len(self.bank.get_examples(task)),
|
| 427 |
+
"global/total_reward": trajectory.total_reward,
|
| 428 |
+
"global/final_accuracy": trajectory.final_accuracy,
|
| 429 |
+
"episode": global_episode,
|
| 430 |
+
})
|
| 431 |
+
except Exception:
|
| 432 |
+
pass
|
| 433 |
+
|
| 434 |
# Record metrics
|
| 435 |
self.metrics.append({
|
| 436 |
"global_episode": global_episode,
|
|
|
|
| 446 |
task_rewards.append(trajectory.total_reward)
|
| 447 |
task_accuracies.append(trajectory.final_accuracy)
|
| 448 |
|
| 449 |
+
self.logger.info(f" → Episode {ep} complete: reward={trajectory.total_reward:.3f} accuracy={trajectory.final_accuracy:.2f} success={trajectory.success}")
|
| 450 |
time.sleep(0.5) # Rate limiting
|
| 451 |
|
| 452 |
+
self.logger.info(f"\n Task summary:")
|
| 453 |
+
self.logger.info(f" First episode reward: {task_rewards[0]:.3f}")
|
| 454 |
+
self.logger.info(f" Last episode reward: {task_rewards[-1]:.3f}")
|
| 455 |
+
self.logger.info(f" Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}")
|
| 456 |
|
| 457 |
+
self.logger.info("\n" + "=" * 60)
|
| 458 |
+
self.logger.info("TRAINING COMPLETE")
|
| 459 |
+
self.logger.info(f"Bank summary: {self.bank.summary()}")
|
| 460 |
+
self.logger.info("=" * 60)
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
wandb.finish()
|
| 464 |
+
except Exception:
|
| 465 |
+
pass
|
| 466 |
|
| 467 |
return self.metrics
|
| 468 |
|
|
|
|
| 495 |
"transaction_approval": "#4CAF50"
|
| 496 |
}
|
| 497 |
|
| 498 |
+
# ── Plot 1: Reward Curve (per-task trend lines) ────────────────────────────
|
| 499 |
fig, ax = plt.subplots(figsize=(10, 5))
|
| 500 |
|
| 501 |
for task in TASKS:
|
|
|
|
| 503 |
task_rews = [m["total_reward"] for m in metrics if m["task"] == task]
|
| 504 |
ax.plot(task_eps, task_rews, marker="o", label=task,
|
| 505 |
color=colors.get(task, "gray"), linewidth=2, markersize=5)
|
| 506 |
+
# Per-task trend line
|
| 507 |
+
if len(task_eps) >= 2:
|
| 508 |
+
z = np.polyfit(task_eps, task_rews, 1)
|
| 509 |
+
p = np.poly1d(z)
|
| 510 |
+
ax.plot(task_eps, p(task_eps), "--",
|
| 511 |
+
color=colors.get(task, "gray"), alpha=0.4, linewidth=1.5)
|
| 512 |
|
| 513 |
ax.set_xlabel("Episode")
|
| 514 |
ax.set_ylabel("Total Reward")
|
|
|
|
| 545 |
plt.close()
|
| 546 |
print("Saved: training/plots/accuracy_curve.png")
|
| 547 |
|
| 548 |
+
# ── Plot 3: Per-Task Summary (Accuracy + Efficiency) ──────────────────────
|
| 549 |
+
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
| 550 |
|
| 551 |
+
task_labels = []
|
| 552 |
+
acc_improvements = []
|
| 553 |
+
eff_improvements = []
|
| 554 |
+
best_accuracies = []
|
| 555 |
|
| 556 |
for task in TASKS:
|
| 557 |
+
task_data = [m for m in metrics if m["task"] == task]
|
| 558 |
+
if len(task_data) >= 2:
|
| 559 |
+
task_labels.append(task.replace("_", "\n"))
|
| 560 |
+
acc_improvements.append(task_data[-1]["final_accuracy"] - task_data[0]["final_accuracy"])
|
| 561 |
+
# Efficiency: steps saved (first vs best)
|
| 562 |
+
first_steps = task_data[0]["num_steps"]
|
| 563 |
+
best_steps = min(m["num_steps"] for m in task_data)
|
| 564 |
+
eff_pct = ((first_steps - best_steps) / first_steps * 100) if first_steps > 0 else 0
|
| 565 |
+
eff_improvements.append(eff_pct)
|
| 566 |
+
best_accuracies.append(max(m["final_accuracy"] for m in task_data))
|
| 567 |
+
|
| 568 |
+
# Left: Best accuracy per task
|
| 569 |
+
bars1 = axes[0].bar(task_labels, best_accuracies,
|
| 570 |
+
color=["#2196F3", "#FF9800", "#4CAF50"][:len(task_labels)],
|
| 571 |
+
edgecolor="white", linewidth=1.5)
|
| 572 |
+
axes[0].axhline(y=0.9, color="red", linestyle="--", alpha=0.7, label="success threshold")
|
| 573 |
+
axes[0].set_ylabel("Best Accuracy Achieved")
|
| 574 |
+
axes[0].set_title("Best Accuracy Per Task")
|
| 575 |
+
axes[0].set_ylim(0, 1.1)
|
| 576 |
+
axes[0].grid(True, axis="y", alpha=0.3)
|
| 577 |
+
axes[0].legend()
|
| 578 |
+
for bar, val in zip(bars1, best_accuracies):
|
| 579 |
+
axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
|
| 580 |
+
f"{val:.0%}", ha="center", va="bottom", fontweight="bold")
|
| 581 |
+
|
| 582 |
+
# Right: Efficiency improvement (% steps saved)
|
| 583 |
+
bars2 = axes[1].bar(task_labels, eff_improvements,
|
| 584 |
+
color=["#2196F3", "#FF9800", "#4CAF50"][:len(task_labels)],
|
| 585 |
+
edgecolor="white", linewidth=1.5)
|
| 586 |
+
axes[1].axhline(y=0, color="black", linewidth=0.8)
|
| 587 |
+
axes[1].set_ylabel("Steps Saved (%)")
|
| 588 |
+
axes[1].set_title("Efficiency Improvement (First → Best Episode)")
|
| 589 |
+
axes[1].grid(True, axis="y", alpha=0.3)
|
| 590 |
+
for bar, val in zip(bars2, eff_improvements):
|
| 591 |
+
y_pos = max(bar.get_height() + 1, 2)
|
| 592 |
+
axes[1].text(bar.get_x() + bar.get_width() / 2, y_pos,
|
| 593 |
+
f"{val:.0f}%", ha="center", va="bottom", fontweight="bold")
|
| 594 |
|
| 595 |
plt.tight_layout()
|
| 596 |
plt.savefig("training/plots/improvement_chart.png", dpi=150, bbox_inches="tight")
|
|
|
|
| 598 |
print("Saved: training/plots/improvement_chart.png")
|
| 599 |
|
| 600 |
# ── Save raw metrics as JSON ──────────────────────────────────────────────
|
| 601 |
+
timestamp = int(time.time())
|
| 602 |
+
with open(f"training/plots/metrics_{timestamp}.json", "w") as f:
|
| 603 |
+
json.dump(metrics, f, indent=2)
|
| 604 |
+
with open("training/plots/metrics_latest.json", "w") as f:
|
| 605 |
json.dump(metrics, f, indent=2)
|
| 606 |
+
print(f"Saved: training/plots/metrics_{timestamp}.json and metrics_latest.json")
|
| 607 |
|
| 608 |
# ── Entry Point ───────────────────────────────────────────────────────────────
|
| 609 |
|
training/update_colab.py
ADDED
|
@@ -0,0 +1,173 @@
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
with open("training/trajectory_optimizer.py", "r") as f:
|
| 5 |
+
traj_code = f.read()
|
| 6 |
+
|
| 7 |
+
notebook = {
|
| 8 |
+
"nbformat": 4,
|
| 9 |
+
"nbformat_minor": 0,
|
| 10 |
+
"metadata": {
|
| 11 |
+
"colab": {
|
| 12 |
+
"provenance": [],
|
| 13 |
+
"name": "Policy-to-Logic Training"
|
| 14 |
+
},
|
| 15 |
+
"kernelspec": {
|
| 16 |
+
"name": "python3",
|
| 17 |
+
"display_name": "Python 3"
|
| 18 |
+
},
|
| 19 |
+
"language_info": {
|
| 20 |
+
"name": "python"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"cells": [
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "markdown",
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"source": [
|
| 28 |
+
"# Policy-to-Logic RL Environment \u2014 Training Notebook\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"This notebook runs the **reward-guided trajectory optimization loop** against the deployed environment.\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"**What it does:**\n",
|
| 33 |
+
"1. Connects to the live HF Spaces environment\n",
|
| 34 |
+
"2. Runs 8 episodes per task (3 tasks = 24 total episodes)\n",
|
| 35 |
+
"3. Accumulates high-reward trajectories as few-shot examples\n",
|
| 36 |
+
"4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)\n",
|
| 37 |
+
"5. Logs everything to Weights & Biases"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": None,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"# Cell 1: Install dependencies\n",
|
| 47 |
+
"!pip install openai requests matplotlib numpy wandb"
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"cell_type": "code",
|
| 52 |
+
"execution_count": None,
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"outputs": [],
|
| 55 |
+
"source": [
|
| 56 |
+
"# Cell 2: Configuration\n",
|
| 57 |
+
"import os\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"# SET THESE BEFORE RUNNING\n",
|
| 60 |
+
"HF_TOKEN = \"\" # Your Hugging Face token with inference access\n",
|
| 61 |
+
"ENV_URL = \"https://godreign-policy2logic.hf.space\" # Your deployed environment URL\n",
|
| 62 |
+
"WANDB_API_KEY = \"\" # Your Wandb API key\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
|
| 65 |
+
"os.environ[\"ENV_BASE_URL\"] = ENV_URL\n",
|
| 66 |
+
"if WANDB_API_KEY:\n",
|
| 67 |
+
" os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"print(f\"Environment URL: {ENV_URL}\")\n",
|
| 70 |
+
"print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")\n",
|
| 71 |
+
"print(f\"Wandb Token set: {'Yes' if WANDB_API_KEY else 'NO - WILL PROMPT'}\")"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": None,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": [
|
| 80 |
+
"# Cell 3: Verify environment is reachable\n",
|
| 81 |
+
"import requests\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"r = requests.get(f\"{ENV_URL}/health\")\n",
|
| 84 |
+
"print(f\"Status: {r.status_code}\")\n",
|
| 85 |
+
"print(f\"Response: {r.json()}\")\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"r2 = requests.get(f\"{ENV_URL}/tasks\")\n",
|
| 88 |
+
"tasks = r2.json()\n",
|
| 89 |
+
"print(f\"\\nAvailable tasks: {list(tasks['tasks'].keys())}\")"
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
{
|
| 93 |
+
"cell_type": "code",
|
| 94 |
+
"execution_count": None,
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"outputs": [],
|
| 97 |
+
"source": [
|
| 98 |
+
"# Cell 4: Wandb login — run this before training\n",
|
| 99 |
+
"import wandb\n",
|
| 100 |
+
"wandb.login() # Will prompt for API key if WANDB_API_KEY is not set"
|
| 101 |
+
]
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"cell_type": "code",
|
| 105 |
+
"execution_count": None,
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"outputs": [],
|
| 108 |
+
"source": [
|
| 109 |
+
"# Cell 5: Training loop implementation (full trajectory_optimizer.py)\n"
|
| 110 |
+
] + [line + "\n" for line in traj_code.split("\n")[:-1]] + [traj_code.split("\n")[-1]]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"cell_type": "code",
|
| 114 |
+
"execution_count": None,
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"outputs": [],
|
| 117 |
+
"source": [
|
| 118 |
+
"# Cell 6: Run training loop\n",
|
| 119 |
+
"loop = TrainingLoop(ENV_URL, HF_TOKEN)\n",
|
| 120 |
+
"metrics = loop.run()\n",
|
| 121 |
+
"print(f\"\\nTotal episodes run: {len(metrics)}\")"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": None,
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"outputs": [],
|
| 129 |
+
"source": [
|
| 130 |
+
"# Cell 7: Generate plots and display inline\n",
|
| 131 |
+
"save_plots(metrics)\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"from IPython.display import Image, display\n",
|
| 134 |
+
"display(Image(\"training/plots/reward_curve.png\"))\n",
|
| 135 |
+
"display(Image(\"training/plots/accuracy_curve.png\"))\n",
|
| 136 |
+
"display(Image(\"training/plots/improvement_chart.png\"))"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": None,
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"outputs": [],
|
| 144 |
+
"source": [
|
| 145 |
+
"# Cell 8: Display wandb run link\n",
|
| 146 |
+
"print(f\"Wandb run: https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl\")\n",
|
| 147 |
+
"print(\"Add this link to your README under Deliverables.\")"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"cell_type": "code",
|
| 152 |
+
"execution_count": None,
|
| 153 |
+
"metadata": {},
|
| 154 |
+
"outputs": [],
|
| 155 |
+
"source": [
|
| 156 |
+
"# Cell 9: Download plots to commit to repo\n",
|
| 157 |
+
"from google.colab import files\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"files.download(\"training/plots/reward_curve.png\")\n",
|
| 160 |
+
"files.download(\"training/plots/accuracy_curve.png\")\n",
|
| 161 |
+
"files.download(\"training/plots/improvement_chart.png\")\n",
|
| 162 |
+
"files.download(\"training/plots/metrics_latest.json\")\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
|
| 165 |
+
]
|
| 166 |
+
}
|
| 167 |
+
]
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
with open("training/colab_training.ipynb", "w") as f:
|
| 171 |
+
json.dump(notebook, f, indent=1)
|
| 172 |
+
|
| 173 |
+
print("Colab Notebook updated successfully")
|
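After running this generator (python training/update_colab.py from the repo root), a cheap sanity check is to confirm the emitted notebook is well-formed JSON with the expected cell count. A minimal sketch, assuming nothing beyond the structure built above:

import json

with open("training/colab_training.ipynb") as f:
    nb = json.load(f)
assert nb["nbformat"] == 4
print(f"{len(nb['cells'])} cells")  # expect 10: 1 markdown + 9 code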
uv.lock
CHANGED
|
@@ -449,6 +449,30 @@ wheels = [
|
|
| 449 |
{ url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" },
|
| 450 |
]
|
| 451 |
|
|
|
| 452 |
[[package]]
|
| 453 |
name = "h11"
|
| 454 |
version = "0.16.0"
|
|
@@ -1161,6 +1185,15 @@ wheels = [
|
|
| 1161 |
{ url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" },
|
| 1162 |
]
|
| 1163 |
|
|
|
| 1164 |
[[package]]
|
| 1165 |
name = "pluggy"
|
| 1166 |
version = "1.6.0"
|
|
@@ -1185,6 +1218,7 @@ dependencies = [
|
|
| 1185 |
{ name = "pydantic" },
|
| 1186 |
{ name = "requests" },
|
| 1187 |
{ name = "uvicorn" },
|
|
|
|
| 1188 |
]
|
| 1189 |
|
| 1190 |
[package.optional-dependencies]
|
|
@@ -1206,9 +1240,25 @@ requires-dist = [
|
|
| 1206 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
|
| 1207 |
{ name = "requests", specifier = ">=2.25.0" },
|
| 1208 |
{ name = "uvicorn", specifier = ">=0.24.0" },
|
|
|
|
| 1209 |
]
|
| 1210 |
provides-extras = ["dev"]
|
| 1211 |
|
|
|
| 1212 |
[[package]]
|
| 1213 |
name = "pydantic"
|
| 1214 |
version = "2.13.3"
|
|
@@ -1480,6 +1530,19 @@ wheels = [
|
|
| 1480 |
{ url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
|
| 1481 |
]
|
| 1482 |
|
|
|
| 1483 |
[[package]]
|
| 1484 |
name = "shellingham"
|
| 1485 |
version = "1.5.4"
|
|
@@ -1498,6 +1561,15 @@ wheels = [
|
|
| 1498 |
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
|
| 1499 |
]
|
| 1500 |
|
|
|
| 1501 |
[[package]]
|
| 1502 |
name = "sniffio"
|
| 1503 |
version = "1.3.1"
|
|
@@ -1644,3 +1716,32 @@ sdist = { url = "https://files.pythonhosted.org/packages/1f/93/041fca8274050e40e
|
|
| 1644 |
wheels = [
|
| 1645 |
{ url = "https://files.pythonhosted.org/packages/31/a3/5b1562db76a5a488274b2332a97199b32d0442aca0ed193697fd47786316/uvicorn-0.46.0-py3-none-any.whl", hash = "sha256:bbebbcbed972d162afca128605223022bedd345b7bc7855ce66deb31487a9048", size = 70926, upload-time = "2026-04-23T07:15:58.355Z" },
|
| 1646 |
]
|
|
| 449 |
{ url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" },
|
| 450 |
]
|
| 451 |
|
| 452 |
+
[[package]]
|
| 453 |
+
name = "gitdb"
|
| 454 |
+
version = "4.0.12"
|
| 455 |
+
source = { registry = "https://pypi.org/simple" }
|
| 456 |
+
dependencies = [
|
| 457 |
+
{ name = "smmap" },
|
| 458 |
+
]
|
| 459 |
+
sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" }
|
| 460 |
+
wheels = [
|
| 461 |
+
{ url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" },
|
| 462 |
+
]
|
| 463 |
+
|
| 464 |
+
[[package]]
|
| 465 |
+
name = "gitpython"
|
| 466 |
+
version = "3.1.47"
|
| 467 |
+
source = { registry = "https://pypi.org/simple" }
|
| 468 |
+
dependencies = [
|
| 469 |
+
{ name = "gitdb" },
|
| 470 |
+
]
|
| 471 |
+
sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" }
|
| 472 |
+
wheels = [
|
| 473 |
+
{ url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" },
|
| 474 |
+
]
|
| 475 |
+
|
| 476 |
[[package]]
|
| 477 |
name = "h11"
|
| 478 |
version = "0.16.0"
|
|
|
|
| 1185 |
{ url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" },
|
| 1186 |
]
|
| 1187 |
|
| 1188 |
+
[[package]]
|
| 1189 |
+
name = "platformdirs"
|
| 1190 |
+
version = "4.9.6"
|
| 1191 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1192 |
+
sdist = { url = "https://files.pythonhosted.org/packages/9f/4a/0883b8e3802965322523f0b200ecf33d31f10991d0401162f4b23c698b42/platformdirs-4.9.6.tar.gz", hash = "sha256:3bfa75b0ad0db84096ae777218481852c0ebc6c727b3168c1b9e0118e458cf0a", size = 29400, upload-time = "2026-04-09T00:04:10.812Z" }
|
| 1193 |
+
wheels = [
|
| 1194 |
+
{ url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" },
|
| 1195 |
+
]
|
| 1196 |
+
|
| 1197 |
[[package]]
|
| 1198 |
name = "pluggy"
|
| 1199 |
version = "1.6.0"
|
|
|
|
| 1218 |
{ name = "pydantic" },
|
| 1219 |
{ name = "requests" },
|
| 1220 |
{ name = "uvicorn" },
|
| 1221 |
+
{ name = "wandb" },
|
| 1222 |
]
|
| 1223 |
|
| 1224 |
[package.optional-dependencies]
|
|
|
|
| 1240 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
|
| 1241 |
{ name = "requests", specifier = ">=2.25.0" },
|
| 1242 |
{ name = "uvicorn", specifier = ">=0.24.0" },
|
| 1243 |
+
{ name = "wandb", specifier = ">=0.16.0" },
|
| 1244 |
]
|
| 1245 |
provides-extras = ["dev"]
|
| 1246 |
|
| 1247 |
+
[[package]]
|
| 1248 |
+
name = "protobuf"
|
| 1249 |
+
version = "7.34.1"
|
| 1250 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1251 |
+
sdist = { url = "https://files.pythonhosted.org/packages/6b/6b/a0e95cad1ad7cc3f2c6821fcab91671bd5b78bd42afb357bb4765f29bc41/protobuf-7.34.1.tar.gz", hash = "sha256:9ce42245e704cc5027be797c1db1eb93184d44d1cdd71811fb2d9b25ad541280", size = 454708, upload-time = "2026-03-20T17:34:47.036Z" }
|
| 1252 |
+
wheels = [
|
| 1253 |
+
{ url = "https://files.pythonhosted.org/packages/ec/11/3325d41e6ee15bf1125654301211247b042563bcc898784351252549a8ad/protobuf-7.34.1-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:d8b2cc79c4d8f62b293ad9b11ec3aebce9af481fa73e64556969f7345ebf9fc7", size = 429247, upload-time = "2026-03-20T17:34:37.024Z" },
|
| 1254 |
+
{ url = "https://files.pythonhosted.org/packages/eb/9d/aa69df2724ff63efa6f72307b483ce0827f4347cc6d6df24b59e26659fef/protobuf-7.34.1-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:5185e0e948d07abe94bb76ec9b8416b604cfe5da6f871d67aad30cbf24c3110b", size = 325753, upload-time = "2026-03-20T17:34:38.751Z" },
|
| 1255 |
+
{ url = "https://files.pythonhosted.org/packages/92/e8/d174c91fd48e50101943f042b09af9029064810b734e4160bbe282fa1caa/protobuf-7.34.1-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:403b093a6e28a960372b44e5eb081775c9b056e816a8029c61231743d63f881a", size = 340198, upload-time = "2026-03-20T17:34:39.871Z" },
|
| 1256 |
+
{ url = "https://files.pythonhosted.org/packages/53/1b/3b431694a4dc6d37b9f653f0c64b0a0d9ec074ee810710c0c3da21d67ba7/protobuf-7.34.1-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:8ff40ce8cd688f7265326b38d5a1bed9bfdf5e6723d49961432f83e21d5713e4", size = 324267, upload-time = "2026-03-20T17:34:41.1Z" },
|
| 1257 |
+
{ url = "https://files.pythonhosted.org/packages/85/29/64de04a0ac142fb685fd09999bc3d337943fb386f3a0ec57f92fd8203f97/protobuf-7.34.1-cp310-abi3-win32.whl", hash = "sha256:34b84ce27680df7cca9f231043ada0daa55d0c44a2ddfaa58ec1d0d89d8bf60a", size = 426628, upload-time = "2026-03-20T17:34:42.536Z" },
|
| 1258 |
+
{ url = "https://files.pythonhosted.org/packages/4d/87/cb5e585192a22b8bd457df5a2c16a75ea0db9674c3a0a39fc9347d84e075/protobuf-7.34.1-cp310-abi3-win_amd64.whl", hash = "sha256:e97b55646e6ce5cbb0954a8c28cd39a5869b59090dfaa7df4598a7fba869468c", size = 437901, upload-time = "2026-03-20T17:34:44.112Z" },
|
| 1259 |
+
{ url = "https://files.pythonhosted.org/packages/88/95/608f665226bca68b736b79e457fded9a2a38c4f4379a4a7614303d9db3bc/protobuf-7.34.1-py3-none-any.whl", hash = "sha256:bb3812cd53aefea2b028ef42bd780f5b96407247f20c6ef7c679807e9d188f11", size = 170715, upload-time = "2026-03-20T17:34:45.384Z" },
|
| 1260 |
+
]
|
| 1261 |
+
|
| 1262 |
[[package]]
|
| 1263 |
name = "pydantic"
|
| 1264 |
version = "2.13.3"
|
|
|
|
| 1530 |
{ url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
|
| 1531 |
]
|
| 1532 |
|
| 1533 |
+
[[package]]
|
| 1534 |
+
name = "sentry-sdk"
|
| 1535 |
+
version = "2.58.0"
|
| 1536 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1537 |
+
dependencies = [
|
| 1538 |
+
{ name = "certifi" },
|
| 1539 |
+
{ name = "urllib3" },
|
| 1540 |
+
]
|
| 1541 |
+
sdist = { url = "https://files.pythonhosted.org/packages/26/b3/fb8291170d0e844173164709fc0fa0c221ed75a5da740c8746f2a83b4eb1/sentry_sdk-2.58.0.tar.gz", hash = "sha256:c1144d947352d54e5b7daa63596d9f848adf684989c06c4f5a659f0c85a18f6f", size = 438764, upload-time = "2026-04-13T17:23:26.265Z" }
|
| 1542 |
+
wheels = [
|
| 1543 |
+
{ url = "https://files.pythonhosted.org/packages/fa/eb/d875669993b762556ae8b2efd86219943b4c0864d22204d622a9aee3052b/sentry_sdk-2.58.0-py2.py3-none-any.whl", hash = "sha256:688d1c704ddecf382ea3326f21a67453d4caa95592d722b7c780a36a9d23109e", size = 460919, upload-time = "2026-04-13T17:23:24.675Z" },
|
| 1544 |
+
]
|
| 1545 |
+
|
| 1546 |
[[package]]
|
| 1547 |
name = "shellingham"
|
| 1548 |
version = "1.5.4"
|
|
|
|
| 1561 |
{ url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
|
| 1562 |
]
|
| 1563 |
|
| 1564 |
+
[[package]]
|
| 1565 |
+
name = "smmap"
|
| 1566 |
+
version = "5.0.3"
|
| 1567 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1568 |
+
sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" }
|
| 1569 |
+
wheels = [
|
| 1570 |
+
{ url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = "sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" },
|
| 1571 |
+
]
|
| 1572 |
+
|
| 1573 |
[[package]]
|
| 1574 |
name = "sniffio"
|
| 1575 |
version = "1.3.1"
|
|
|
|
| 1716 |
wheels = [
|
| 1717 |
{ url = "https://files.pythonhosted.org/packages/31/a3/5b1562db76a5a488274b2332a97199b32d0442aca0ed193697fd47786316/uvicorn-0.46.0-py3-none-any.whl", hash = "sha256:bbebbcbed972d162afca128605223022bedd345b7bc7855ce66deb31487a9048", size = 70926, upload-time = "2026-04-23T07:15:58.355Z" },
|
| 1718 |
]
|
| 1719 |
+
|
| 1720 |
+
[[package]]
|
| 1721 |
+
name = "wandb"
|
| 1722 |
+
version = "0.26.1"
|
| 1723 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1724 |
+
dependencies = [
|
| 1725 |
+
{ name = "click" },
|
| 1726 |
+
{ name = "gitpython" },
|
| 1727 |
+
{ name = "packaging" },
|
| 1728 |
+
{ name = "platformdirs" },
|
| 1729 |
+
{ name = "protobuf" },
|
| 1730 |
+
{ name = "pydantic" },
|
| 1731 |
+
{ name = "pyyaml" },
|
| 1732 |
+
{ name = "requests" },
|
| 1733 |
+
{ name = "sentry-sdk" },
|
| 1734 |
+
{ name = "typing-extensions" },
|
| 1735 |
+
]
|
| 1736 |
+
sdist = { url = "https://files.pythonhosted.org/packages/6a/a4/72a6640e1f566e81f184a426e3e45298d4c6672664de41adb7eb6f64370a/wandb-0.26.1.tar.gz", hash = "sha256:eef2dbaea06f0b1c0cdc5d76f544ae4c2b8848fc512442a00bd59f0502fc8aa1", size = 42159814, upload-time = "2026-04-23T16:27:34.033Z" }
|
| 1737 |
+
wheels = [
|
| 1738 |
+
{ url = "https://files.pythonhosted.org/packages/8c/09/3296235f3906e904f06f2df29eed4d672fb23c0932c9486e2af64f2f2a66/wandb-0.26.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2955fe190c005fb83ee6d73f066c8a33f09f3212a1f2eb53faa6581440e456be", size = 24857204, upload-time = "2026-04-23T16:26:58.576Z" },
|
| 1739 |
+
{ url = "https://files.pythonhosted.org/packages/a1/ad/e39ca3086534129e42208ba00ed2c6247ce425f890219eeec33b4f162864/wandb-0.26.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:55d91cabde98162d7116a5e19ddd052bd9848556243f1da4cbb9ffb7ad435bfc", size = 26014649, upload-time = "2026-04-23T16:27:02.559Z" },
|
| 1740 |
+
{ url = "https://files.pythonhosted.org/packages/56/af/400d84a3bdce0b062b4baa70acb6becd2c8018697f4fbf5af9a9e1e406e5/wandb-0.26.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:7c78bc2454cfe1ffa1c3a256060a387356eed8a4488e024d9d2eba8f2b5bd51d", size = 25421317, upload-time = "2026-04-23T16:27:06.411Z" },
|
| 1741 |
+
{ url = "https://files.pythonhosted.org/packages/7b/e9/b4bf8f3509dcea1cec52233a38991459654635b5a8e6a494eb912e1b9cfb/wandb-0.26.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:a2c8eeec8706dcd2872e69c3b4d20ec523082fdb4440295491556e219ad2aa67", size = 27192831, upload-time = "2026-04-23T16:27:10.308Z" },
|
| 1742 |
+
{ url = "https://files.pythonhosted.org/packages/62/cf/4a6dce0c782223ef0eeea7139daee73418a7322befcf083512c31cebaa18/wandb-0.26.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2fa768ee0636a569afb7541cf996e56309c47070566a38916823f94e02afe586", size = 25593326, upload-time = "2026-04-23T16:27:14.259Z" },
|
| 1743 |
+
{ url = "https://files.pythonhosted.org/packages/df/99/58c3d8c36ae8e2b7d70bf6493eb5daa1cca0231a04b025717b4cd1a78f1e/wandb-0.26.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5854928725cfeff1f284d5c043cd353f810e5da02eead2c120ef5056ad026fea", size = 27535542, upload-time = "2026-04-23T16:27:18.473Z" },
|
| 1744 |
+
{ url = "https://files.pythonhosted.org/packages/7c/d0/4e846ffc1d0cc435518dfa581ce73ac82cfd0ebbf35f3853c9277f632e5f/wandb-0.26.1-py3-none-win32.whl", hash = "sha256:5c2bd44e575ae9944e2764d1aaa031461178276bf2636d5558399c2816ef5cfe", size = 24968151, upload-time = "2026-04-23T16:27:22.086Z" },
|
| 1745 |
+
{ url = "https://files.pythonhosted.org/packages/e3/9b/487413eaccefdb58799a226726e24b486e9192d2671c75a4550c160aba23/wandb-0.26.1-py3-none-win_amd64.whl", hash = "sha256:5817785467d3f1676f1812ec19a89f77f6e56dfe67d9f47080075af95f705d3e", size = 24968155, upload-time = "2026-04-23T16:27:25.731Z" },
|
| 1746 |
+
{ url = "https://files.pythonhosted.org/packages/04/dc/5baf3e99b3eeb709d6f75124b5bec8cb73d4b38d2b10df7fdcfde4966200/wandb-0.26.1-py3-none-win_arm64.whl", hash = "sha256:f848b7744f896bc04cabbb28360a2814d1551a91fa2c456243e06435729c8a2e", size = 22912416, upload-time = "2026-04-23T16:27:29.456Z" },
|
| 1747 |
+
]
|