Godreign-Y committed
Commit d5d9d45 · 1 Parent(s): 5ace282

Add W&B experiment tracking and structured logging
README.md CHANGED
@@ -24,6 +24,7 @@ short_description: Meta pytorch hugging face hackathon
  | **HF Space (Live Environment)** | [godreign-policy2logic.hf.space](https://godreign-policy2logic.hf.space) |
  | **Training Notebook (Colab)** | [Open in Colab](https://colab.research.google.com/github/GodreignElgin/policy2logic/blob/main/training/colab_training.ipynb) |
  | **Writeup / Slides** | *TBD — add your link here* |
+ | **Experiment Tracking (W&B)** | [Wandb Project](https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl) |

  ---

@@ -44,6 +45,15 @@ improving agent behavior across episodes without weight updates.

  ---

+ ## 📈 Experiment Tracking
+
+ All training runs are logged to Weights & Biases.
+ Metrics tracked per episode: total reward, final accuracy, steps used, success rate, few-shot examples used.
+
+ Live dashboard: [wandb.ai/YOUR_USERNAME/policy-to-logic-rl](https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl)
+
+ ---
+
  ## 🧠 What This Is

  This project builds a **verifiable RL environment** where:
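
For quick reference, the per-episode logging this README section describes reduces to the shape below (a minimal sketch: the project name and metric keys mirror the `wandb.log` call in the notebook diff further down, while the numeric values are placeholders, not real run data):

```python
import wandb

# Sketch of the per-episode W&B logging added by this commit.
# mode="offline" lets the snippet run without a W&B account.
run = wandb.init(project="policy-to-logic-rl", mode="offline")

task, episode = "data_access", 1
run.log({
    f"{task}/total_reward": 2.4,     # summed step rewards for the episode
    f"{task}/final_accuracy": 0.85,  # accuracy on the hidden test cases
    f"{task}/num_steps": 4,          # steps used before the episode ended
    f"{task}/success": 1,            # 1 if final accuracy >= 0.9
    f"{task}/few_shots_used": 2,     # trajectories banked for this task
    "episode": episode,
})
run.finish()
```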
policy_to_logic_env/server/requirements.txt CHANGED
@@ -3,3 +3,4 @@ pydantic>=2.0
  fastapi>=0.104.0
  uvicorn>=0.24.0
  requests>=2.25.0
+ wandb>=0.16.0
pyproject.toml CHANGED
@@ -13,6 +13,7 @@ dependencies = [
  "huggingface-hub>=1.12.0",
  "matplotlib>=3.7.0",
  "numpy>=1.24.0",
+ "wandb>=0.16.0",
  ]

  [project.optional-dependencies]
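
This makes `wandb` a hard install-time dependency. The notebook below already guards `wandb.init` and `wandb.log` with try/except so training continues when W&B is unreachable; if you also wanted the import itself to be optional, a guard could look like the following (a sketch of an alternative, not what this commit does):

```python
# Hypothetical optional-import guard. This commit instead declares wandb
# as a required dependency; use something like this only if runs must
# also work with wandb absent.
try:
    import wandb
except ImportError:
    wandb = None

def log_metrics(payload: dict) -> None:
    # Log to the active W&B run when one exists; otherwise print.
    if wandb is not None and wandb.run is not None:
        wandb.log(payload)
    else:
        print(payload)
```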
training/colab_training.ipynb CHANGED
@@ -19,7 +19,7 @@
19
  "cell_type": "markdown",
20
  "metadata": {},
21
  "source": [
22
- "# Policy-to-Logic RL Environment Training Notebook\n",
23
  "\n",
24
  "This notebook runs the **reward-guided trajectory optimization loop** against the deployed environment.\n",
25
  "\n",
@@ -27,7 +27,8 @@
27
  "1. Connects to the live HF Spaces environment\n",
28
  "2. Runs 8 episodes per task (3 tasks = 24 total episodes)\n",
29
  "3. Accumulates high-reward trajectories as few-shot examples\n",
30
- "4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)"
 
31
  ]
32
  },
33
  {
@@ -37,7 +38,7 @@
37
  "outputs": [],
38
  "source": [
39
  "# Cell 1: Install dependencies\n",
40
- "!pip install openai requests matplotlib numpy"
41
  ]
42
  },
43
  {
@@ -52,12 +53,16 @@
52
  "# SET THESE BEFORE RUNNING\n",
53
  "HF_TOKEN = \"\" # Your Hugging Face token with inference access\n",
54
  "ENV_URL = \"https://godreign-policy2logic.hf.space\" # Your deployed environment URL\n",
 
55
  "\n",
56
  "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
57
  "os.environ[\"ENV_BASE_URL\"] = ENV_URL\n",
 
 
58
  "\n",
59
  "print(f\"Environment URL: {ENV_URL}\")\n",
60
- "print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")"
 
61
  ]
62
  },
63
  {
@@ -84,18 +89,40 @@
84
  "metadata": {},
85
  "outputs": [],
86
  "source": [
87
- "# Cell 4: Training loop implementation\n",
88
- "# Full contents of training/trajectory_optimizer.py\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  "\n",
90
  "import json\n",
91
  "import os\n",
92
  "import time\n",
93
  "import requests\n",
 
 
94
  "from dataclasses import dataclass, field\n",
95
  "from typing import Optional\n",
96
  "from openai import OpenAI\n",
97
  "\n",
98
- "# ── Configuration ────────────────────────────────────────────────────────────\n",
99
  "\n",
100
  "ENV_BASE_URL = os.getenv(\"ENV_BASE_URL\", \"http://localhost:7860\")\n",
101
  "HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\")\n",
@@ -103,11 +130,14 @@
103
  "TEMPERATURE = 0.3\n",
104
  "MAX_TOKENS = 1024\n",
105
  "\n",
106
- "NUM_EPISODES_PER_TASK = 8\n",
107
- "TOP_K_TRAJECTORIES = 3\n",
108
- "MIN_REWARD_THRESHOLD = 0.3\n",
 
109
  "TASKS = [\"data_access\", \"resource_access\", \"transaction_approval\"]\n",
110
  "\n",
 
 
111
  "@dataclass\n",
112
  "class Step:\n",
113
  " step_number: int\n",
@@ -128,7 +158,10 @@
128
  " success: bool = False\n",
129
  "\n",
130
  " def to_few_shot_string(self) -> str:\n",
131
- " lines = [f\"=== Example Episode (reward={self.total_reward:.2f}, accuracy={self.final_accuracy:.2f}) ===\"]\n",
 
 
 
132
  " for s in self.steps:\n",
133
  " lines.append(f\"Step {s.step_number}: action={s.action_type}\")\n",
134
  " lines.append(f\" Content: {s.action_content[:200]}\")\n",
@@ -137,6 +170,8 @@
137
  " lines.append(f\" Feedback: {s.feedback[:150]}\")\n",
138
  " return \"\\n\".join(lines)\n",
139
  "\n",
 
 
140
  "class EnvClient:\n",
141
  " def __init__(self, base_url: str):\n",
142
  " self.base_url = base_url.rstrip(\"/\")\n",
@@ -148,7 +183,10 @@
148
  " return r.json()\n",
149
  "\n",
150
  " def step(self, action_type: str, content: str) -> dict:\n",
151
- " r = self.session.post(f\"{self.base_url}/step\", json={\"action_type\": action_type, \"content\": content})\n",
 
 
 
152
  " r.raise_for_status()\n",
153
  " return r.json()\n",
154
  "\n",
@@ -159,18 +197,40 @@
159
  " except Exception:\n",
160
  " return False\n",
161
  "\n",
 
 
162
  "class Agent:\n",
163
  " def __init__(self, hf_token: str):\n",
164
- " self.client = OpenAI(base_url=\"https://router.huggingface.co/v1\", api_key=hf_token)\n",
165
- "\n",
166
- " def get_action(self, observation, step_number, episode_history, few_shot_examples):\n",
167
- " system_prompt = self._build_system_prompt(few_shot_examples)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  " user_prompt = self._build_user_prompt(observation, step_number, episode_history)\n",
 
169
  " try:\n",
170
  " response = self.client.chat.completions.create(\n",
171
  " model=MODEL,\n",
172
- " messages=[{\"role\": \"system\", \"content\": system_prompt}, {\"role\": \"user\", \"content\": user_prompt}],\n",
173
- " temperature=TEMPERATURE, max_tokens=MAX_TOKENS\n",
 
 
 
 
174
  " )\n",
175
  " raw = response.choices[0].message.content.strip()\n",
176
  " return self._parse_response(raw, observation)\n",
@@ -178,155 +238,489 @@
178
  " print(f\" [LLM ERROR] {e}\")\n",
179
  " return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
180
  "\n",
181
- " def _build_system_prompt(self, few_shot_examples):\n",
182
- " base = \"\"\"You are a policy-to-logic agent. Convert natural language policies into executable rules.\n",
183
  "\n",
184
  "AVAILABLE ACTIONS:\n",
185
  "1. ask_clarification: {\"type\": \"clarification\", \"question\": \"your question\"}\n",
186
  "2. propose_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
187
  "3. refine_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
188
  "\n",
189
- "DSL FORMAT: {\"rules\": [{\"if\": [{\"field\": \"NAME\", \"op\": \"OP\", \"value\": VAL}], \"then\": \"DECISION\"}], \"default\": \"FALLBACK\"}\n",
190
- "Operators: >, <, >=, <=, ==, !=. Rules execute top-to-bottom, first match wins.\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  "\n",
192
- "STRATEGY: Ask 1-2 clarifications first, then propose rules, then refine based on failures.\n",
193
- "OUTPUT: Respond ONLY with valid JSON: {\"action_type\": \"...\", \"content\": \"...\"}\"\"\"\n",
194
  " if few_shot_examples:\n",
195
- " base += \"\\n\\nLEARNED FROM PREVIOUS EPISODES:\\n\"\n",
196
  " for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:\n",
197
  " base += \"\\n\" + traj.to_few_shot_string() + \"\\n\"\n",
198
  " return base\n",
199
  "\n",
200
- " def _build_user_prompt(self, obs, step, history):\n",
201
- " lines = [f\"TASK: {obs.get('task_name', 'unknown')}\", f\"STEP: {step} of {obs.get('max_steps', 7)}\", f\"\\nPOLICY:\\n{obs.get('policy_text', '')}\"]\n",
202
- " if obs.get(\"clarification_response\"): lines.append(f\"\\nLAST CLARIFICATION:\\n{obs['clarification_response']}\")\n",
 
 
 
 
 
203
  " if obs.get(\"test_results\"):\n",
204
  " tr = obs[\"test_results\"]\n",
205
- " lines.append(f\"\\nTEST RESULTS: {tr.get('passed',0)}/{tr.get('total',0)} (acc={obs.get('current_accuracy',0):.2f})\")\n",
206
- " if tr.get(\"sample_failures\"): lines.extend([f\" - {f}\" for f in tr[\"sample_failures\"][:3]])\n",
207
- " if obs.get(\"feedback\"): lines.append(f\"\\nFEEDBACK: {obs['feedback']}\")\n",
208
- " if history: lines.append(f\"\\nHISTORY:\\n\" + \"\\n\".join(history[-3:]))\n",
209
- " lines.append(f\"\\nAVAILABLE: {obs.get('available_actions', [])}\")\n",
210
- " lines.append(\"\\nRespond with JSON only.\")\n",
 
 
 
 
 
211
  " return \"\\n\".join(lines)\n",
212
  "\n",
213
- " def _parse_response(self, raw, obs):\n",
 
214
  " if \"```\" in raw:\n",
215
  " raw = raw.split(\"```\")[1]\n",
216
- " if raw.startswith(\"json\"): raw = raw[4:]\n",
 
217
  " raw = raw.strip()\n",
 
218
  " try:\n",
219
  " parsed = json.loads(raw)\n",
220
  " action_type = parsed.get(\"action_type\", \"propose_rules\")\n",
221
  " content = parsed.get(\"content\", \"{}\")\n",
222
- " valid = obs.get(\"available_actions\", [\"propose_rules\"])\n",
223
- " if action_type not in valid: action_type = valid[0]\n",
224
- " if isinstance(content, dict): content = json.dumps(content)\n",
 
 
 
 
 
 
225
  " return action_type, content\n",
226
- " except: return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
 
 
 
227
  "\n",
228
  "class TrajectoryBank:\n",
229
- " def __init__(self): self.bank = {task: [] for task in TASKS}\n",
230
- " def store(self, t):\n",
231
- " if t.total_reward >= MIN_REWARD_THRESHOLD:\n",
232
- " self.bank[t.task_name].append(t)\n",
233
- " self.bank[t.task_name].sort(key=lambda x: x.total_reward, reverse=True)\n",
234
- " self.bank[t.task_name] = self.bank[t.task_name][:TOP_K_TRAJECTORIES]\n",
235
- " def get_examples(self, task): return self.bank.get(task, [])\n",
236
- " def summary(self): return {t: {\"stored\": len(v), \"best_reward\": max((x.total_reward for x in v), default=0)} for t,v in self.bank.items()}\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  "\n",
238
  "class TrainingLoop:\n",
239
- " def __init__(self, env_url, hf_token):\n",
240
  " self.env = EnvClient(env_url)\n",
241
  " self.agent = Agent(hf_token)\n",
242
  " self.bank = TrajectoryBank()\n",
243
- " self.metrics = []\n",
244
- "\n",
245
- " def run_episode(self, task_name, episode_id):\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  " few_shots = self.bank.get_examples(task_name)\n",
247
- " traj = Trajectory(task_name=task_name, episode_id=episode_id)\n",
 
 
248
  " result = self.env.reset(task_name)\n",
249
- " obs, done, history = result.get(\"observation\", {}), result.get(\"done\", False), []\n",
250
- " print(f\" [Episode {episode_id}] task={task_name} few_shots={len(few_shots)}\")\n",
 
 
 
 
251
  " step_num = 0\n",
252
  " while not done and step_num < obs.get(\"max_steps\", 7):\n",
253
  " step_num += 1\n",
254
- " action_type, content = self.agent.get_action(obs, step_num, history, few_shots)\n",
 
 
 
 
 
 
 
 
 
 
255
  " result = self.env.step(action_type, content)\n",
256
- " reward, done = result.get(\"reward\", 0.0), result.get(\"done\", False)\n",
257
- " obs, info = result.get(\"observation\", {}), result.get(\"info\", {})\n",
258
- " step = Step(step_num, action_type, content[:300], reward, obs.get(\"current_accuracy\", 0.0), obs.get(\"feedback\", \"\") or \"\", obs.get(\"clarification_response\"))\n",
259
- " traj.steps.append(step); traj.total_reward += reward\n",
260
- " history.append(f\"Step {step_num}: {action_type} -> reward={reward:.2f} acc={step.accuracy:.2f}\")\n",
261
- " print(f\" step={step_num} action={action_type} reward={reward:.3f} acc={step.accuracy:.2f}\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  " if done:\n",
263
- " traj.final_accuracy = info.get(\"episode_score\", obs.get(\"current_accuracy\", 0.0))\n",
264
- " traj.success = obs.get(\"current_accuracy\", 0.0) >= 0.9\n",
 
265
  " break\n",
266
- " if not traj.steps: traj.final_accuracy = 0.0\n",
267
- " return traj\n",
 
 
 
 
 
268
  "\n",
269
  " def run(self):\n",
270
- " print(\"=\" * 60)\n",
271
- " print(\"REWARD-GUIDED TRAJECTORY OPTIMIZATION\")\n",
272
- " print(f\"Tasks: {TASKS}, Episodes/task: {NUM_EPISODES_PER_TASK}\")\n",
273
- " print(\"=\" * 60)\n",
274
- " if not self.env.health(): raise RuntimeError(f\"Env not reachable at {ENV_BASE_URL}\")\n",
275
- " print(f\"Environment: OK\\n\")\n",
276
- " global_ep = 0\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  " for task in TASKS:\n",
278
- " print(f\"\\n--- TASK: {task} ---\")\n",
 
 
 
279
  " task_rewards = []\n",
 
 
280
  " for ep in range(1, NUM_EPISODES_PER_TASK + 1):\n",
281
- " global_ep += 1\n",
282
- " traj = self.run_episode(task, ep)\n",
283
- " self.bank.store(traj)\n",
284
- " self.metrics.append({\"global_episode\": global_ep, \"task\": task, \"episode_in_task\": ep, \"total_reward\": traj.total_reward, \"final_accuracy\": traj.final_accuracy, \"success\": traj.success, \"num_steps\": len(traj.steps)})\n",
285
- " task_rewards.append(traj.total_reward)\n",
286
- " print(f\" -> Ep {ep}: reward={traj.total_reward:.3f} acc={traj.final_accuracy:.2f} success={traj.success}\")\n",
287
- " time.sleep(0.5)\n",
288
- " print(f\" Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}\")\n",
289
- " print(\"\\n\" + \"=\" * 60 + \"\\nTRAINING COMPLETE\\n\" + \"=\" * 60)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  " return self.metrics\n",
291
  "\n",
292
- "def save_plots(metrics):\n",
293
- " import matplotlib; matplotlib.use(\"Agg\")\n",
294
- " import matplotlib.pyplot as plt; import numpy as np\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  " os.makedirs(\"training/plots\", exist_ok=True)\n",
 
296
  " episodes = [m[\"global_episode\"] for m in metrics]\n",
297
  " rewards = [m[\"total_reward\"] for m in metrics]\n",
298
- " colors = {\"data_access\": \"#2196F3\", \"resource_access\": \"#FF9800\", \"transaction_approval\": \"#4CAF50\"}\n",
299
- " # Plot 1: Reward\n",
 
 
 
 
 
 
 
 
300
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
 
301
  " for task in TASKS:\n",
302
- " te = [m[\"global_episode\"] for m in metrics if m[\"task\"]==task]\n",
303
- " tr = [m[\"total_reward\"] for m in metrics if m[\"task\"]==task]\n",
304
- " ax.plot(te, tr, marker=\"o\", label=task, color=colors.get(task), linewidth=2, markersize=5)\n",
305
- " z = np.polyfit(episodes, rewards, 1); p = np.poly1d(z)\n",
306
- " ax.plot(episodes, p(episodes), \"--\", color=\"red\", alpha=0.5, label=\"trend\")\n",
307
- " ax.set_xlabel(\"Episode\"); ax.set_ylabel(\"Total Reward\"); ax.set_title(\"Reward Curve\"); ax.legend(); ax.grid(True, alpha=0.3); ax.set_ylim(bottom=0)\n",
308
- " plt.tight_layout(); plt.savefig(\"training/plots/reward_curve.png\", dpi=150); plt.close()\n",
309
- " # Plot 2: Accuracy\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
 
311
  " for task in TASKS:\n",
312
- " te = [m[\"global_episode\"] for m in metrics if m[\"task\"]==task]\n",
313
- " ta = [m[\"final_accuracy\"] for m in metrics if m[\"task\"]==task]\n",
314
- " ax.plot(te, ta, marker=\"s\", label=task, color=colors.get(task), linewidth=2, markersize=5)\n",
315
- " ax.axhline(y=0.9, color=\"red\", linestyle=\"--\", alpha=0.7, label=\"threshold\")\n",
316
- " ax.set_xlabel(\"Episode\"); ax.set_ylabel(\"Accuracy\"); ax.set_title(\"Accuracy Curve\"); ax.legend(); ax.grid(True, alpha=0.3); ax.set_ylim(0, 1.05)\n",
317
- " plt.tight_layout(); plt.savefig(\"training/plots/accuracy_curve.png\", dpi=150); plt.close()\n",
318
- " # Plot 3: Improvement\n",
319
- " fig, ax = plt.subplots(figsize=(8, 5))\n",
320
- " tnames, imps = [], []\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  " for task in TASKS:\n",
322
- " accs = [m[\"final_accuracy\"] for m in metrics if m[\"task\"]==task]\n",
323
- " if len(accs) >= 2: tnames.append(task.replace(\"_\",\"\\n\")); imps.append(accs[-1]-accs[0])\n",
324
- " bars = ax.bar(tnames, imps, color=[\"#2196F3\",\"#FF9800\",\"#4CAF50\"][:len(tnames)])\n",
325
- " ax.axhline(y=0, color=\"black\"); ax.set_ylabel(\"Improvement\"); ax.set_title(\"Per-Task Improvement\"); ax.grid(True, axis=\"y\", alpha=0.3)\n",
326
- " for bar, val in zip(bars, imps): ax.text(bar.get_x()+bar.get_width()/2, bar.get_height()+0.01, f\"{val:+.2f}\", ha=\"center\", fontweight=\"bold\")\n",
327
- " plt.tight_layout(); plt.savefig(\"training/plots/improvement_chart.png\", dpi=150); plt.close()\n",
328
- " with open(\"training/plots/metrics.json\", \"w\") as f: json.dump(metrics, f, indent=2)\n",
329
- " print(\"All plots saved to training/plots/\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  ]
331
  },
332
  {
@@ -335,7 +729,7 @@
335
  "metadata": {},
336
  "outputs": [],
337
  "source": [
338
- "# Cell 5: Run training loop\n",
339
  "loop = TrainingLoop(ENV_URL, HF_TOKEN)\n",
340
  "metrics = loop.run()\n",
341
  "print(f\"\\nTotal episodes run: {len(metrics)}\")"
@@ -347,7 +741,7 @@
347
  "metadata": {},
348
  "outputs": [],
349
  "source": [
350
- "# Cell 6: Generate plots and display inline\n",
351
  "save_plots(metrics)\n",
352
  "\n",
353
  "from IPython.display import Image, display\n",
@@ -362,17 +756,27 @@
362
  "metadata": {},
363
  "outputs": [],
364
  "source": [
365
- "# Cell 7: Download plots to commit to repo\n",
366
- "# After running this, download the files and commit them to your GitHub repo\n",
 
 
 
 
 
 
 
 
 
 
367
  "from google.colab import files\n",
368
  "\n",
369
  "files.download(\"training/plots/reward_curve.png\")\n",
370
  "files.download(\"training/plots/accuracy_curve.png\")\n",
371
  "files.download(\"training/plots/improvement_chart.png\")\n",
372
- "files.download(\"training/plots/metrics.json\")\n",
373
  "\n",
374
  "print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
375
  ]
376
  }
377
  ]
378
- }
 
19
  "cell_type": "markdown",
20
  "metadata": {},
21
  "source": [
22
+ "# Policy-to-Logic RL Environment \u2014 Training Notebook\n",
23
  "\n",
24
  "This notebook runs the **reward-guided trajectory optimization loop** against the deployed environment.\n",
25
  "\n",
 
27
  "1. Connects to the live HF Spaces environment\n",
28
  "2. Runs 8 episodes per task (3 tasks = 24 total episodes)\n",
29
  "3. Accumulates high-reward trajectories as few-shot examples\n",
30
+ "4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)\n",
31
+ "5. Logs everything to Weights & Biases"
32
  ]
33
  },
34
  {
 
38
  "outputs": [],
39
  "source": [
40
  "# Cell 1: Install dependencies\n",
41
+ "!pip install openai requests matplotlib numpy wandb"
42
  ]
43
  },
44
  {
 
53
  "# SET THESE BEFORE RUNNING\n",
54
  "HF_TOKEN = \"\" # Your Hugging Face token with inference access\n",
55
  "ENV_URL = \"https://godreign-policy2logic.hf.space\" # Your deployed environment URL\n",
56
+ "WANDB_API_KEY = \"\" # Your Wandb API key\n",
57
  "\n",
58
  "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
59
  "os.environ[\"ENV_BASE_URL\"] = ENV_URL\n",
60
+ "if WANDB_API_KEY:\n",
61
+ " os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
62
  "\n",
63
  "print(f\"Environment URL: {ENV_URL}\")\n",
64
+ "print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")\n",
65
+ "print(f\"Wandb Token set: {'Yes' if WANDB_API_KEY else 'NO - WILL PROMPT'}\")"
66
  ]
67
  },
68
  {
 
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
92
+ "# Cell 4: Wandb login \u2014 run this before training\n",
93
+ "import wandb\n",
94
+ "wandb.login() # Will prompt for API key if WANDB_API_KEY is not set"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "# Cell 5: Training loop implementation (full trajectory_optimizer.py)\n",
104
+ "\"\"\"\n",
105
+ "Reward-Guided Trajectory Optimization Loop\n",
106
+ "==========================================\n",
107
+ "Optimizes agent behavior across episodes by accumulating high-reward\n",
108
+ "trajectories as few-shot examples. Uses environment reward signal to\n",
109
+ "drive improvement \u00e2\u20ac\u201d no weight updates required.\n",
110
+ "\n",
111
+ "This implements a policy improvement loop where:\n",
112
+ " - reward_signal \u00e2\u2020\u2019 trajectory_selection \u00e2\u2020\u2019 context_construction \u00e2\u2020\u2019 improved_policy\n",
113
+ "\"\"\"\n",
114
  "\n",
115
  "import json\n",
116
  "import os\n",
117
  "import time\n",
118
  "import requests\n",
119
+ "import logging\n",
120
+ "import wandb\n",
121
  "from dataclasses import dataclass, field\n",
122
  "from typing import Optional\n",
123
  "from openai import OpenAI\n",
124
  "\n",
125
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Configuration \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
126
  "\n",
127
  "ENV_BASE_URL = os.getenv(\"ENV_BASE_URL\", \"http://localhost:7860\")\n",
128
  "HF_TOKEN = os.getenv(\"HF_TOKEN\", \"\")\n",
 
130
  "TEMPERATURE = 0.3\n",
131
  "MAX_TOKENS = 1024\n",
132
  "\n",
133
+ "# Training hyperparameters\n",
134
+ "NUM_EPISODES_PER_TASK = 8 # Episodes to run per task\n",
135
+ "TOP_K_TRAJECTORIES = 3 # Max few-shot examples to keep\n",
136
+ "MIN_REWARD_THRESHOLD = 0.3 # Minimum reward to store trajectory\n",
137
  "TASKS = [\"data_access\", \"resource_access\", \"transaction_approval\"]\n",
138
  "\n",
139
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Data Structures \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
140
+ "\n",
141
  "@dataclass\n",
142
  "class Step:\n",
143
  " step_number: int\n",
 
158
  " success: bool = False\n",
159
  "\n",
160
  " def to_few_shot_string(self) -> str:\n",
161
+ " \"\"\"Convert trajectory to a few-shot example string for prompting.\"\"\"\n",
162
+ " lines = [\n",
163
+ " f\"=== Example Episode (reward={self.total_reward:.2f}, accuracy={self.final_accuracy:.2f}) ===\",\n",
164
+ " ]\n",
165
  " for s in self.steps:\n",
166
  " lines.append(f\"Step {s.step_number}: action={s.action_type}\")\n",
167
  " lines.append(f\" Content: {s.action_content[:200]}\")\n",
 
170
  " lines.append(f\" Feedback: {s.feedback[:150]}\")\n",
171
  " return \"\\n\".join(lines)\n",
172
  "\n",
173
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Environment Client \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
174
+ "\n",
175
  "class EnvClient:\n",
176
  " def __init__(self, base_url: str):\n",
177
  " self.base_url = base_url.rstrip(\"/\")\n",
 
183
  " return r.json()\n",
184
  "\n",
185
  " def step(self, action_type: str, content: str) -> dict:\n",
186
+ " r = self.session.post(f\"{self.base_url}/step\", json={\n",
187
+ " \"action_type\": action_type,\n",
188
+ " \"content\": content\n",
189
+ " })\n",
190
  " r.raise_for_status()\n",
191
  " return r.json()\n",
192
  "\n",
 
197
  " except Exception:\n",
198
  " return False\n",
199
  "\n",
200
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac LLM Agent \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
201
+ "\n",
202
  "class Agent:\n",
203
  " def __init__(self, hf_token: str):\n",
204
+ " self.client = OpenAI(\n",
205
+ " base_url=\"https://router.huggingface.co/v1\",\n",
206
+ " api_key=hf_token\n",
207
+ " )\n",
208
+ "\n",
209
+ " def get_action(\n",
210
+ " self,\n",
211
+ " observation: dict,\n",
212
+ " step_number: int,\n",
213
+ " episode_history: list[str],\n",
214
+ " few_shot_examples: list[Trajectory],\n",
215
+ " task_name: str = \"\"\n",
216
+ " ) -> tuple[str, str]:\n",
217
+ " \"\"\"\n",
218
+ " Returns (action_type, content_json_string).\n",
219
+ " action_type: one of ask_clarification | propose_rules | refine_rules\n",
220
+ " content: JSON string appropriate for that action\n",
221
+ " \"\"\"\n",
222
+ " system_prompt = self._build_system_prompt(few_shot_examples, task_name)\n",
223
  " user_prompt = self._build_user_prompt(observation, step_number, episode_history)\n",
224
+ "\n",
225
  " try:\n",
226
  " response = self.client.chat.completions.create(\n",
227
  " model=MODEL,\n",
228
+ " messages=[\n",
229
+ " {\"role\": \"system\", \"content\": system_prompt},\n",
230
+ " {\"role\": \"user\", \"content\": user_prompt}\n",
231
+ " ],\n",
232
+ " temperature=TEMPERATURE,\n",
233
+ " max_tokens=MAX_TOKENS\n",
234
  " )\n",
235
  " raw = response.choices[0].message.content.strip()\n",
236
  " return self._parse_response(raw, observation)\n",
 
238
  " print(f\" [LLM ERROR] {e}\")\n",
239
  " return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
240
  "\n",
241
+ " def _build_system_prompt(self, few_shot_examples: list[Trajectory], task_name: str = \"\") -> str:\n",
242
+ " base = \"\"\"You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.\n",
243
  "\n",
244
  "AVAILABLE ACTIONS:\n",
245
  "1. ask_clarification: {\"type\": \"clarification\", \"question\": \"your question\"}\n",
246
  "2. propose_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
247
  "3. refine_rules: {\"rules\": [...], \"default\": \"DECISION\"}\n",
248
  "\n",
249
+ "DSL FORMAT for rules:\n",
250
+ "{\n",
251
+ " \"rules\": [\n",
252
+ " {\n",
253
+ " \"if\": [\n",
254
+ " {\"field\": \"FIELD_NAME\", \"op\": \"OPERATOR\", \"value\": VALUE}\n",
255
+ " ],\n",
256
+ " \"then\": \"DECISION\"\n",
257
+ " }\n",
258
+ " ],\n",
259
+ " \"default\": \"FALLBACK_DECISION\"\n",
260
+ "}\n",
261
+ "\n",
262
+ "Operators: >, <, >=, <=, ==, !=\n",
263
+ "Rules execute top-to-bottom. First match wins. Default applies if no rule matches.\n",
264
+ "\n",
265
+ "STRATEGY:\n",
266
+ "- Step 1: Ask 1-2 targeted clarification questions about ambiguous terms\n",
267
+ "- Step 2: Propose initial rules based on policy + clarifications \n",
268
+ "- Step 3+: Refine rules based on failure feedback\n",
269
+ "\n",
270
+ "OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.\n",
271
+ "{\"action_type\": \"propose_rules\", \"content\": \"{...escaped json string...}\"}\n",
272
+ "\"\"\"\n",
273
+ " # Task-specific guidance for complex tasks\n",
274
+ " if task_name == \"transaction_approval\":\n",
275
+ " base += \"\"\"\n",
276
+ "IMPORTANT \u00e2\u20ac\u201d TRANSACTION APPROVAL TASK:\n",
277
+ "This task has 4 possible decisions: APPROVE, REQUIRE_APPROVAL, COMPLIANCE_REVIEW, HOLD\n",
278
+ "Rules are evaluated TOP-TO-BOTTOM. Order matters critically. You MUST order rules by priority:\n",
279
+ " 1. FIRST: Check if transfer_type == \"international\" \u00e2\u2020\u2019 then COMPLIANCE_REVIEW (always, overrides everything)\n",
280
+ " 2. SECOND: Check if amount >= 10000 AND time is outside business hours (hour < 9 or hour >= 17) \u00e2\u2020\u2019 then HOLD\n",
281
+ " 3. THIRD: Check if amount > 5000 AND initiator_role != \"manager\" \u00e2\u2020\u2019 then REQUIRE_APPROVAL\n",
282
+ " 4. DEFAULT: APPROVE\n",
283
+ "\n",
284
+ "Key details:\n",
285
+ "- Standard limit is $5,000 (amount > 5000 triggers approval, NOT >=)\n",
286
+ "- High-value threshold is $10,000 (amount >= 10000)\n",
287
+ "- Business hours: hour >= 9 AND hour < 17\n",
288
+ "- Manager exemption ONLY applies to the standard $5,000 limit, NOT to international or high-value HOLD rules\n",
289
+ "- \"system\" role follows the same rules as \"employee\"\n",
290
+ "\n",
291
+ "Here is a working example of valid rules for this task:\n",
292
+ "{\"rules\": [{\"if\": [{\"field\": \"transfer_type\", \"op\": \"==\", \"value\": \"international\"}], \"then\": \"COMPLIANCE_REVIEW\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">=\", \"value\": 10000}, {\"field\": \"time\", \"op\": \">=\", \"value\": 17}], \"then\": \"HOLD\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">=\", \"value\": 10000}, {\"field\": \"time\", \"op\": \"<\", \"value\": 9}], \"then\": \"HOLD\"}, {\"if\": [{\"field\": \"amount\", \"op\": \">\", \"value\": 5000}, {\"field\": \"initiator_role\", \"op\": \"!=\", \"value\": \"manager\"}], \"then\": \"REQUIRE_APPROVAL\"}], \"default\": \"APPROVE\"}\n",
293
+ "\"\"\"\n",
294
+ " elif task_name == \"resource_access\":\n",
295
+ " base += \"\"\"\n",
296
+ "IMPORTANT \u00e2\u20ac\u201d RESOURCE ACCESS TASK:\n",
297
+ "This task has roles: junior, senior, contractor. Document types: public, internal, confidential.\n",
298
+ "- Senior employees: ALLOW everything always\n",
299
+ "- Contractors: ALLOW only public, DENY everything else\n",
300
+ "- Junior + confidential: ALWAYS DENY (regardless of time \u00e2\u20ac\u201d the policy is misleading about this)\n",
301
+ "- Junior + internal: ALLOW only during business hours (hour >= 8 AND hour < 17)\n",
302
+ "- Junior + public: ALLOW always\n",
303
+ "- Business hours: hour >= 8 AND hour < 17\n",
304
+ "\"\"\"\n",
305
  "\n",
 
 
306
  " if few_shot_examples:\n",
307
+ " base += \"\\n\\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\\n\"\n",
308
  " for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:\n",
309
  " base += \"\\n\" + traj.to_few_shot_string() + \"\\n\"\n",
310
  " return base\n",
311
  "\n",
312
+ " def _build_user_prompt(self, obs: dict, step: int, history: list[str]) -> str:\n",
313
+ " lines = [\n",
314
+ " f\"TASK: {obs.get('task_name', 'unknown')}\",\n",
315
+ " f\"STEP: {step} of {obs.get('max_steps', 7)}\",\n",
316
+ " f\"\\nPOLICY:\\n{obs.get('policy_text', '')}\",\n",
317
+ " ]\n",
318
+ " if obs.get(\"clarification_response\"):\n",
319
+ " lines.append(f\"\\nLAST CLARIFICATION ANSWER:\\n{obs['clarification_response']}\")\n",
320
  " if obs.get(\"test_results\"):\n",
321
  " tr = obs[\"test_results\"]\n",
322
+ " lines.append(f\"\\nTEST RESULTS: {tr.get('passed', 0)}/{tr.get('total', 0)} passed (accuracy={obs.get('current_accuracy', 0):.2f})\")\n",
323
+ " if tr.get(\"sample_failures\"):\n",
324
+ " lines.append(\"SAMPLE FAILURES:\")\n",
325
+ " for f in tr[\"sample_failures\"][:3]:\n",
326
+ " lines.append(f\" - {f}\")\n",
327
+ " if obs.get(\"feedback\"):\n",
328
+ " lines.append(f\"\\nFEEDBACK: {obs['feedback']}\")\n",
329
+ " if history:\n",
330
+ " lines.append(f\"\\nACTION HISTORY (last 3):\\n\" + \"\\n\".join(history[-3:]))\n",
331
+ " lines.append(f\"\\nAVAILABLE ACTIONS: {obs.get('available_actions', [])}\")\n",
332
+ " lines.append(\"\\nRespond with JSON only: {\\\"action_type\\\": \\\"...\\\", \\\"content\\\": \\\"...\\\"}\")\n",
333
  " return \"\\n\".join(lines)\n",
334
  "\n",
335
+ " def _parse_response(self, raw: str, obs: dict) -> tuple[str, str]:\n",
336
+ " # Strip markdown code fences if present\n",
337
  " if \"```\" in raw:\n",
338
  " raw = raw.split(\"```\")[1]\n",
339
+ " if raw.startswith(\"json\"):\n",
340
+ " raw = raw[4:]\n",
341
  " raw = raw.strip()\n",
342
+ "\n",
343
  " try:\n",
344
  " parsed = json.loads(raw)\n",
345
  " action_type = parsed.get(\"action_type\", \"propose_rules\")\n",
346
  " content = parsed.get(\"content\", \"{}\")\n",
347
+ "\n",
348
+ " # Validate action_type\n",
349
+ " valid_actions = obs.get(\"available_actions\", [\"propose_rules\", \"ask_clarification\"])\n",
350
+ " if action_type not in valid_actions:\n",
351
+ " action_type = \"propose_rules\" if \"propose_rules\" in valid_actions else valid_actions[0]\n",
352
+ "\n",
353
+ " # Ensure content is a string\n",
354
+ " if isinstance(content, dict):\n",
355
+ " content = json.dumps(content)\n",
356
  " return action_type, content\n",
357
+ " except Exception:\n",
358
+ " return \"propose_rules\", json.dumps({\"rules\": [], \"default\": \"DENY\"})\n",
359
+ "\n",
360
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Trajectory Bank \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
361
  "\n",
362
  "class TrajectoryBank:\n",
363
+ " \"\"\"Stores and retrieves high-reward trajectories per task.\"\"\"\n",
364
+ "\n",
365
+ " def __init__(self):\n",
366
+ " self.bank: dict[str, list[Trajectory]] = {task: [] for task in TASKS}\n",
367
+ "\n",
368
+ " def store(self, trajectory: Trajectory):\n",
369
+ " if trajectory.total_reward >= MIN_REWARD_THRESHOLD:\n",
370
+ " self.bank[trajectory.task_name].append(trajectory)\n",
371
+ " # Keep only top-K by reward\n",
372
+ " self.bank[trajectory.task_name].sort(key=lambda t: t.total_reward, reverse=True)\n",
373
+ " self.bank[trajectory.task_name] = self.bank[trajectory.task_name][:TOP_K_TRAJECTORIES]\n",
374
+ "\n",
375
+ " def get_examples(self, task_name: str) -> list[Trajectory]:\n",
376
+ " return self.bank.get(task_name, [])\n",
377
+ "\n",
378
+ " def summary(self) -> dict:\n",
379
+ " return {\n",
380
+ " task: {\n",
381
+ " \"stored\": len(trajs),\n",
382
+ " \"best_reward\": max((t.total_reward for t in trajs), default=0),\n",
383
+ " \"best_accuracy\": max((t.final_accuracy for t in trajs), default=0)\n",
384
+ " }\n",
385
+ " for task, trajs in self.bank.items()\n",
386
+ " }\n",
387
+ "\n",
388
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Training Loop \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
389
  "\n",
390
  "class TrainingLoop:\n",
391
+ " def __init__(self, env_url: str, hf_token: str):\n",
392
  " self.env = EnvClient(env_url)\n",
393
  " self.agent = Agent(hf_token)\n",
394
  " self.bank = TrajectoryBank()\n",
395
+ " self.metrics = [] # List of {episode, task, reward, accuracy, success}\n",
396
+ "\n",
397
+ " os.makedirs(\"training/logs\", exist_ok=True)\n",
398
+ " log_filename = f\"training/logs/run_{int(time.time())}.log\"\n",
399
+ " logging.basicConfig(\n",
400
+ " level=logging.INFO,\n",
401
+ " format=\"%(asctime)s [%(levelname)s] %(message)s\",\n",
402
+ " handlers=[\n",
403
+ " logging.FileHandler(log_filename),\n",
404
+ " logging.StreamHandler() # also print to console\n",
405
+ " ]\n",
406
+ " )\n",
407
+ " self.logger = logging.getLogger(\"TrainingLoop\")\n",
408
+ " self.log_file = log_filename\n",
409
+ "\n",
410
+ " def run_episode(self, task_name: str, episode_id: int) -> Trajectory:\n",
411
+ " \"\"\"Run a single episode and return the trajectory.\"\"\"\n",
412
  " few_shots = self.bank.get_examples(task_name)\n",
413
+ " trajectory = Trajectory(task_name=task_name, episode_id=episode_id)\n",
414
+ "\n",
415
+ " # Reset environment\n",
416
  " result = self.env.reset(task_name)\n",
417
+ " obs = result.get(\"observation\", {})\n",
418
+ " done = result.get(\"done\", False)\n",
419
+ " history = []\n",
420
+ "\n",
421
+ " self.logger.info(f\"START episode={episode_id} task={task_name} few_shots_available={len(few_shots)}\")\n",
422
+ "\n",
423
  " step_num = 0\n",
424
  " while not done and step_num < obs.get(\"max_steps\", 7):\n",
425
  " step_num += 1\n",
426
+ "\n",
427
+ " # Get action from agent\n",
428
+ " action_type, content = self.agent.get_action(\n",
429
+ " observation=obs,\n",
430
+ " step_number=step_num,\n",
431
+ " episode_history=history,\n",
432
+ " few_shot_examples=few_shots,\n",
433
+ " task_name=task_name\n",
434
+ " )\n",
435
+ "\n",
436
+ " # Execute action\n",
437
  " result = self.env.step(action_type, content)\n",
438
+ " reward = result.get(\"reward\", 0.0)\n",
439
+ " done = result.get(\"done\", False)\n",
440
+ " obs = result.get(\"observation\", {})\n",
441
+ " info = result.get(\"info\", {})\n",
442
+ "\n",
443
+ " # Record step\n",
444
+ " step = Step(\n",
445
+ " step_number=step_num,\n",
446
+ " action_type=action_type,\n",
447
+ " action_content=content[:300],\n",
448
+ " reward=reward,\n",
449
+ " accuracy=obs.get(\"current_accuracy\", 0.0),\n",
450
+ " feedback=obs.get(\"feedback\", \"\") or \"\",\n",
451
+ " clarification_response=obs.get(\"clarification_response\")\n",
452
+ " )\n",
453
+ " trajectory.steps.append(step)\n",
454
+ " trajectory.total_reward += reward\n",
455
+ "\n",
456
+ " # Update history\n",
457
+ " history.append(f\"Step {step_num}: {action_type} \u00e2\u2020\u2019 reward={reward:.2f} acc={step.accuracy:.2f}\")\n",
458
+ "\n",
459
+ " self.logger.info(f\"STEP episode={episode_id} step={step_num} action={action_type} reward={reward:.4f} accuracy={step.accuracy:.4f}\")\n",
460
+ "\n",
461
  " if done:\n",
462
+ " episode_score = info.get(\"episode_score\", obs.get(\"current_accuracy\", 0.0))\n",
463
+ " trajectory.final_accuracy = episode_score\n",
464
+ " trajectory.success = obs.get(\"current_accuracy\", 0.0) >= 0.9\n",
465
  " break\n",
466
+ "\n",
467
+ " if not trajectory.steps:\n",
468
+ " trajectory.final_accuracy = 0.0\n",
469
+ "\n",
470
+ " self.logger.info(f\"END episode={episode_id} task={task_name} total_reward={trajectory.total_reward:.4f} final_accuracy={trajectory.final_accuracy:.4f} success={trajectory.success} steps={len(trajectory.steps)}\")\n",
471
+ "\n",
472
+ " return trajectory\n",
473
  "\n",
474
  " def run(self):\n",
475
+ " \"\"\"Run full training loop across all tasks.\"\"\"\n",
476
+ " self.logger.info(\"=\" * 60)\n",
477
+ " self.logger.info(\"REWARD-GUIDED TRAJECTORY OPTIMIZATION\")\n",
478
+ " self.logger.info(f\"Tasks: {TASKS}\")\n",
479
+ " self.logger.info(f\"Episodes per task: {NUM_EPISODES_PER_TASK}\")\n",
480
+ " self.logger.info(f\"Top-K trajectories: {TOP_K_TRAJECTORIES}\")\n",
481
+ " self.logger.info(\"=\" * 60)\n",
482
+ " self.logger.info(f\"Log file: {self.log_file}\")\n",
483
+ "\n",
484
+ " try:\n",
485
+ " wandb.init(\n",
486
+ " project=\"policy-to-logic-rl\",\n",
487
+ " name=f\"trajectory-opt-{int(time.time())}\",\n",
488
+ " config={\n",
489
+ " \"num_episodes_per_task\": NUM_EPISODES_PER_TASK,\n",
490
+ " \"top_k_trajectories\": TOP_K_TRAJECTORIES,\n",
491
+ " \"min_reward_threshold\": MIN_REWARD_THRESHOLD,\n",
492
+ " \"model\": MODEL,\n",
493
+ " \"temperature\": TEMPERATURE,\n",
494
+ " \"tasks\": TASKS,\n",
495
+ " \"env_url\": ENV_BASE_URL,\n",
496
+ " }\n",
497
+ " )\n",
498
+ " except Exception as e:\n",
499
+ " self.logger.warning(f\"Wandb init failed: {e}. Continuing without W&B.\")\n",
500
+ "\n",
501
+ " # Health check\n",
502
+ " if not self.env.health():\n",
503
+ " raise RuntimeError(f\"Environment not reachable at {ENV_BASE_URL}\")\n",
504
+ " self.logger.info(f\"Environment: OK ({ENV_BASE_URL})\\n\")\n",
505
+ "\n",
506
+ " global_episode = 0\n",
507
+ "\n",
508
  " for task in TASKS:\n",
509
+ " self.logger.info(f\"\\n{'\u00e2\u201d\u20ac'*40}\")\n",
510
+ " self.logger.info(f\"TASK: {task}\")\n",
511
+ " self.logger.info(f\"{'\u00e2\u201d\u20ac'*40}\")\n",
512
+ "\n",
513
  " task_rewards = []\n",
514
+ " task_accuracies = []\n",
515
+ "\n",
516
  " for ep in range(1, NUM_EPISODES_PER_TASK + 1):\n",
517
+ " global_episode += 1\n",
518
+ " trajectory = self.run_episode(task, ep)\n",
519
+ "\n",
520
+ " # Store in bank\n",
521
+ " self.bank.store(trajectory)\n",
522
+ "\n",
523
+ " try:\n",
524
+ " wandb.log({\n",
525
+ " f\"{task}/total_reward\": trajectory.total_reward,\n",
526
+ " f\"{task}/final_accuracy\": trajectory.final_accuracy,\n",
527
+ " f\"{task}/num_steps\": len(trajectory.steps),\n",
528
+ " f\"{task}/success\": int(trajectory.success),\n",
529
+ " f\"{task}/few_shots_used\": len(self.bank.get_examples(task)),\n",
530
+ " \"global/total_reward\": trajectory.total_reward,\n",
531
+ " \"global/final_accuracy\": trajectory.final_accuracy,\n",
532
+ " \"episode\": global_episode,\n",
533
+ " })\n",
534
+ " except Exception:\n",
535
+ " pass\n",
536
+ "\n",
537
+ " # Record metrics\n",
538
+ " self.metrics.append({\n",
539
+ " \"global_episode\": global_episode,\n",
540
+ " \"task\": task,\n",
541
+ " \"episode_in_task\": ep,\n",
542
+ " \"total_reward\": trajectory.total_reward,\n",
543
+ " \"final_accuracy\": trajectory.final_accuracy,\n",
544
+ " \"success\": trajectory.success,\n",
545
+ " \"num_steps\": len(trajectory.steps),\n",
546
+ " \"few_shots_used\": len(self.bank.get_examples(task)) - (1 if trajectory.total_reward >= MIN_REWARD_THRESHOLD else 0)\n",
547
+ " })\n",
548
+ "\n",
549
+ " task_rewards.append(trajectory.total_reward)\n",
550
+ " task_accuracies.append(trajectory.final_accuracy)\n",
551
+ "\n",
552
+ " self.logger.info(f\" \u00e2\u2020\u2019 Episode {ep} complete: reward={trajectory.total_reward:.3f} accuracy={trajectory.final_accuracy:.2f} success={trajectory.success}\")\n",
553
+ " time.sleep(0.5) # Rate limiting\n",
554
+ "\n",
555
+ " self.logger.info(f\"\\n Task summary:\")\n",
556
+ " self.logger.info(f\" First episode reward: {task_rewards[0]:.3f}\")\n",
557
+ " self.logger.info(f\" Last episode reward: {task_rewards[-1]:.3f}\")\n",
558
+ " self.logger.info(f\" Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}\")\n",
559
+ "\n",
560
+ " self.logger.info(\"\\n\" + \"=\" * 60)\n",
561
+ " self.logger.info(\"TRAINING COMPLETE\")\n",
562
+ " self.logger.info(f\"Bank summary: {self.bank.summary()}\")\n",
563
+ " self.logger.info(\"=\" * 60)\n",
564
+ "\n",
565
+ " try:\n",
566
+ " wandb.finish()\n",
567
+ " except Exception:\n",
568
+ " pass\n",
569
+ "\n",
570
  " return self.metrics\n",
571
  "\n",
572
+ "# \u00e2\u201d\u20ac\u00e2\u201d\u20ac Plot Generation \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
573
+ "\n",
574
+ "def save_plots(metrics: list[dict]):\n",
575
+ " \"\"\"\n",
576
+ " Save reward curve and accuracy curve as PNG files.\n",
577
+ " These are REQUIRED for hackathon submission \u00e2\u20ac\u201d must be committed to repo.\n",
578
+ " \"\"\"\n",
579
+ " try:\n",
580
+ " import matplotlib\n",
581
+ " matplotlib.use(\"Agg\") # Non-interactive backend\n",
582
+ " import matplotlib.pyplot as plt\n",
583
+ " import numpy as np\n",
584
+ " except ImportError:\n",
585
+ " print(\"matplotlib not installed. Run: pip install matplotlib\")\n",
586
+ " return\n",
587
+ "\n",
588
  " os.makedirs(\"training/plots\", exist_ok=True)\n",
589
+ "\n",
590
  " episodes = [m[\"global_episode\"] for m in metrics]\n",
591
  " rewards = [m[\"total_reward\"] for m in metrics]\n",
592
+ " accuracies = [m[\"final_accuracy\"] for m in metrics]\n",
593
+ " tasks = [m[\"task\"] for m in metrics]\n",
594
+ "\n",
595
+ " colors = {\n",
596
+ " \"data_access\": \"#2196F3\",\n",
597
+ " \"resource_access\": \"#FF9800\",\n",
598
+ " \"transaction_approval\": \"#4CAF50\"\n",
599
+ " }\n",
600
+ "\n",
601
+ " # \u00e2\u201d\u20ac\u00e2\u201d\u20ac Plot 1: Reward Curve (per-task trend lines) \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
602
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
603
+ "\n",
604
  " for task in TASKS:\n",
605
+ " task_eps = [m[\"global_episode\"] for m in metrics if m[\"task\"] == task]\n",
606
+ " task_rews = [m[\"total_reward\"] for m in metrics if m[\"task\"] == task]\n",
607
+ " ax.plot(task_eps, task_rews, marker=\"o\", label=task,\n",
608
+ " color=colors.get(task, \"gray\"), linewidth=2, markersize=5)\n",
609
+ " # Per-task trend line\n",
610
+ " if len(task_eps) >= 2:\n",
611
+ " z = np.polyfit(task_eps, task_rews, 1)\n",
612
+ " p = np.poly1d(z)\n",
613
+ " ax.plot(task_eps, p(task_eps), \"--\",\n",
614
+ " color=colors.get(task, \"gray\"), alpha=0.4, linewidth=1.5)\n",
615
+ "\n",
616
+ " ax.set_xlabel(\"Episode\")\n",
617
+ " ax.set_ylabel(\"Total Reward\")\n",
618
+ " ax.set_title(\"Reward Curve \u00e2\u20ac\u201d Reward-Guided Trajectory Optimization\")\n",
619
+ " ax.legend()\n",
620
+ " ax.grid(True, alpha=0.3)\n",
621
+ " ax.set_ylim(bottom=0)\n",
622
+ "\n",
623
+ " plt.tight_layout()\n",
624
+ " plt.savefig(\"training/plots/reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
625
+ " plt.close()\n",
626
+ " print(\"Saved: training/plots/reward_curve.png\")\n",
627
+ "\n",
628
+ " # \u00e2\u201d\u20ac\u00e2\u201d\u20ac Plot 2: Accuracy Curve \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
629
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
630
+ "\n",
631
  " for task in TASKS:\n",
632
+ " task_eps = [m[\"global_episode\"] for m in metrics if m[\"task\"] == task]\n",
633
+ " task_accs = [m[\"final_accuracy\"] for m in metrics if m[\"task\"] == task]\n",
634
+ " ax.plot(task_eps, task_accs, marker=\"s\", label=task,\n",
635
+ " color=colors.get(task, \"gray\"), linewidth=2, markersize=5)\n",
636
+ "\n",
637
+ " ax.axhline(y=0.9, color=\"red\", linestyle=\"--\", alpha=0.7, label=\"success threshold (0.9)\")\n",
638
+ "\n",
639
+ " ax.set_xlabel(\"Episode\")\n",
640
+ " ax.set_ylabel(\"Final Accuracy\")\n",
641
+ " ax.set_title(\"Accuracy Curve \u00e2\u20ac\u201d Policy-to-Logic Agent\")\n",
642
+ " ax.legend()\n",
643
+ " ax.grid(True, alpha=0.3)\n",
644
+ " ax.set_ylim(0, 1.05)\n",
645
+ "\n",
646
+ " plt.tight_layout()\n",
647
+ " plt.savefig(\"training/plots/accuracy_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
648
+ " plt.close()\n",
649
+ " print(\"Saved: training/plots/accuracy_curve.png\")\n",
650
+ "\n",
651
+ " # \u00e2\u201d\u20ac\u00e2\u201d\u20ac Plot 3: Per-Task Summary (Accuracy + Efficiency) \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
652
+ " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
653
+ "\n",
654
+ " task_labels = []\n",
655
+ " acc_improvements = []\n",
656
+ " eff_improvements = []\n",
657
+ " best_accuracies = []\n",
658
+ "\n",
659
  " for task in TASKS:\n",
660
+ " task_data = [m for m in metrics if m[\"task\"] == task]\n",
661
+ " if len(task_data) >= 2:\n",
662
+ " task_labels.append(task.replace(\"_\", \"\\n\"))\n",
663
+ " acc_improvements.append(task_data[-1][\"final_accuracy\"] - task_data[0][\"final_accuracy\"])\n",
664
+ " # Efficiency: steps saved (first vs best)\n",
665
+ " first_steps = task_data[0][\"num_steps\"]\n",
666
+ " best_steps = min(m[\"num_steps\"] for m in task_data)\n",
667
+ " eff_pct = ((first_steps - best_steps) / first_steps * 100) if first_steps > 0 else 0\n",
668
+ " eff_improvements.append(eff_pct)\n",
669
+ " best_accuracies.append(max(m[\"final_accuracy\"] for m in task_data))\n",
670
+ "\n",
671
+ " # Left: Best accuracy per task\n",
672
+ " bars1 = axes[0].bar(task_labels, best_accuracies,\n",
673
+ " color=[\"#2196F3\", \"#FF9800\", \"#4CAF50\"][:len(task_labels)],\n",
674
+ " edgecolor=\"white\", linewidth=1.5)\n",
675
+ " axes[0].axhline(y=0.9, color=\"red\", linestyle=\"--\", alpha=0.7, label=\"success threshold\")\n",
676
+ " axes[0].set_ylabel(\"Best Accuracy Achieved\")\n",
677
+ " axes[0].set_title(\"Best Accuracy Per Task\")\n",
678
+ " axes[0].set_ylim(0, 1.1)\n",
679
+ " axes[0].grid(True, axis=\"y\", alpha=0.3)\n",
680
+ " axes[0].legend()\n",
681
+ " for bar, val in zip(bars1, best_accuracies):\n",
682
+ " axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,\n",
683
+ " f\"{val:.0%}\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\n",
684
+ "\n",
685
+ " # Right: Efficiency improvement (% steps saved)\n",
686
+ " bars2 = axes[1].bar(task_labels, eff_improvements,\n",
687
+ " color=[\"#2196F3\", \"#FF9800\", \"#4CAF50\"][:len(task_labels)],\n",
688
+ " edgecolor=\"white\", linewidth=1.5)\n",
689
+ " axes[1].axhline(y=0, color=\"black\", linewidth=0.8)\n",
690
+ " axes[1].set_ylabel(\"Steps Saved (%)\")\n",
691
+ " axes[1].set_title(\"Efficiency Improvement (First \u00e2\u2020\u2019 Best Episode)\")\n",
692
+ " axes[1].grid(True, axis=\"y\", alpha=0.3)\n",
693
+ " for bar, val in zip(bars2, eff_improvements):\n",
694
+ " y_pos = max(bar.get_height() + 1, 2)\n",
695
+ " axes[1].text(bar.get_x() + bar.get_width() / 2, y_pos,\n",
696
+ " f\"{val:.0f}%\", ha=\"center\", va=\"bottom\", fontweight=\"bold\")\n",
697
+ "\n",
698
+ " plt.tight_layout()\n",
699
+ " plt.savefig(\"training/plots/improvement_chart.png\", dpi=150, bbox_inches=\"tight\")\n",
700
+ " plt.close()\n",
701
+ " print(\"Saved: training/plots/improvement_chart.png\")\n",
702
+ "\n",
703
+ " # \u00e2\u201d\u20ac\u00e2\u201d\u20ac Save raw metrics as JSON \u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\u00e2\u201d\u20ac\n",
704
+ " timestamp = int(time.time())\n",
705
+ " with open(f\"training/plots/metrics_{timestamp}.json\", \"w\") as f:\n",
706
+ " json.dump(metrics, f, indent=2)\n",
707
+ " with open(\"training/plots/metrics_latest.json\", \"w\") as f:\n",
708
+ " json.dump(metrics, f, indent=2)\n",
709
+ " print(f\"Saved: training/plots/metrics_{timestamp}.json and metrics_latest.json\")\n",
710
+ "\n",
711
+ "# ── Entry Point ───────────────────────────────────────────────────────────────\n",
712
+ "\n",
713
+ "if __name__ == \"__main__\":\n",
714
+ " hf_token = os.getenv(\"HF_TOKEN\", \"\")\n",
715
+ " if not hf_token:\n",
716
+ " raise ValueError(\"HF_TOKEN environment variable not set\")\n",
717
+ "\n",
718
+ " loop = TrainingLoop(ENV_BASE_URL, hf_token)\n",
719
+ " metrics = loop.run()\n",
720
+ " save_plots(metrics)\n",
721
+ "\n",
722
+ " print(\"\\nNext step: commit training/plots/*.png to repo for submission.\")\n",
723
+ ""
724
  ]
725
  },
726
  {
 
729
  "metadata": {},
730
  "outputs": [],
731
  "source": [
732
+ "# Cell 6: Run training loop\n",
733
  "loop = TrainingLoop(ENV_URL, HF_TOKEN)\n",
734
  "metrics = loop.run()\n",
735
  "print(f\"\\nTotal episodes run: {len(metrics)}\")"
 
741
  "metadata": {},
742
  "outputs": [],
743
  "source": [
744
+ "# Cell 7: Generate plots and display inline\n",
745
  "save_plots(metrics)\n",
746
  "\n",
747
  "from IPython.display import Image, display\n",
 
756
  "metadata": {},
757
  "outputs": [],
758
  "source": [
759
+ "# Cell 8: Display wandb run link\n",
760
+ "print(\"Wandb run: https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl\")\n",
761
+ "print(\"Add this link to your README under Deliverables.\")"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "execution_count": null,
767
+ "metadata": {},
768
+ "outputs": [],
769
+ "source": [
770
+ "# Cell 9: Download plots to commit to repo\n",
771
  "from google.colab import files\n",
772
  "\n",
773
  "files.download(\"training/plots/reward_curve.png\")\n",
774
  "files.download(\"training/plots/accuracy_curve.png\")\n",
775
  "files.download(\"training/plots/improvement_chart.png\")\n",
776
+ "files.download(\"training/plots/metrics_latest.json\")\n",
777
  "\n",
778
  "print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
779
  ]
780
  }
781
  ]
782
+ }
training/results-iteration1/accuracy_curve (1).png ADDED
training/results-iteration1/improvement_chart (1).png ADDED
training/results-iteration1/metrics (1).json ADDED
@@ -0,0 +1,218 @@
1
+ [
2
+ {
3
+ "global_episode": 1,
4
+ "task": "data_access",
5
+ "episode_in_task": 1,
6
+ "total_reward": 1.1988333333333334,
7
+ "final_accuracy": 0.92,
8
+ "success": true,
9
+ "num_steps": 4
10
+ },
11
+ {
12
+ "global_episode": 2,
13
+ "task": "data_access",
14
+ "episode_in_task": 2,
15
+ "total_reward": 0.7585,
16
+ "final_accuracy": 0.9600000000000001,
17
+ "success": true,
18
+ "num_steps": 2
19
+ },
20
+ {
21
+ "global_episode": 3,
22
+ "task": "data_access",
23
+ "episode_in_task": 3,
24
+ "total_reward": 0.7585,
25
+ "final_accuracy": 0.9600000000000001,
26
+ "success": true,
27
+ "num_steps": 2
28
+ },
29
+ {
30
+ "global_episode": 4,
31
+ "task": "data_access",
32
+ "episode_in_task": 4,
33
+ "total_reward": 0.7585,
34
+ "final_accuracy": 0.9600000000000001,
35
+ "success": true,
36
+ "num_steps": 2
37
+ },
38
+ {
39
+ "global_episode": 5,
40
+ "task": "data_access",
41
+ "episode_in_task": 5,
42
+ "total_reward": 0.7585,
43
+ "final_accuracy": 0.9600000000000001,
44
+ "success": true,
45
+ "num_steps": 2
46
+ },
47
+ {
48
+ "global_episode": 6,
49
+ "task": "data_access",
50
+ "episode_in_task": 6,
51
+ "total_reward": 0.7585,
52
+ "final_accuracy": 0.9600000000000001,
53
+ "success": true,
54
+ "num_steps": 2
55
+ },
56
+ {
57
+ "global_episode": 7,
58
+ "task": "data_access",
59
+ "episode_in_task": 7,
60
+ "total_reward": 0.7585,
61
+ "final_accuracy": 0.9600000000000001,
62
+ "success": true,
63
+ "num_steps": 2
64
+ },
65
+ {
66
+ "global_episode": 8,
67
+ "task": "data_access",
68
+ "episode_in_task": 8,
69
+ "total_reward": 0.7585,
70
+ "final_accuracy": 0.9600000000000001,
71
+ "success": true,
72
+ "num_steps": 2
73
+ },
74
+ {
75
+ "global_episode": 9,
76
+ "task": "resource_access",
77
+ "episode_in_task": 1,
78
+ "total_reward": 1.041,
79
+ "final_accuracy": 0.8931428571428572,
80
+ "success": true,
81
+ "num_steps": 3
82
+ },
83
+ {
84
+ "global_episode": 10,
85
+ "task": "resource_access",
86
+ "episode_in_task": 2,
87
+ "total_reward": 1.041,
88
+ "final_accuracy": 0.8931428571428572,
89
+ "success": true,
90
+ "num_steps": 3
91
+ },
92
+ {
93
+ "global_episode": 11,
94
+ "task": "resource_access",
95
+ "episode_in_task": 3,
96
+ "total_reward": 0.7335,
97
+ "final_accuracy": 0.9074285714285715,
98
+ "success": true,
99
+ "num_steps": 2
100
+ },
101
+ {
102
+ "global_episode": 12,
103
+ "task": "resource_access",
104
+ "episode_in_task": 4,
105
+ "total_reward": 0.7335,
106
+ "final_accuracy": 0.9074285714285715,
107
+ "success": true,
108
+ "num_steps": 2
109
+ },
110
+ {
111
+ "global_episode": 13,
112
+ "task": "resource_access",
113
+ "episode_in_task": 5,
114
+ "total_reward": 0.7434999999999999,
115
+ "final_accuracy": 0.9234285714285714,
116
+ "success": true,
117
+ "num_steps": 2
118
+ },
119
+ {
120
+ "global_episode": 14,
121
+ "task": "resource_access",
122
+ "episode_in_task": 6,
123
+ "total_reward": 0.7434999999999999,
124
+ "final_accuracy": 0.9234285714285714,
125
+ "success": true,
126
+ "num_steps": 2
127
+ },
128
+ {
129
+ "global_episode": 15,
130
+ "task": "resource_access",
131
+ "episode_in_task": 7,
132
+ "total_reward": 1.929,
133
+ "final_accuracy": 0.8645714285714287,
134
+ "success": true,
135
+ "num_steps": 5
136
+ },
137
+ {
138
+ "global_episode": 16,
139
+ "task": "resource_access",
140
+ "episode_in_task": 8,
141
+ "total_reward": 0.7335,
142
+ "final_accuracy": 0.9074285714285715,
143
+ "success": true,
144
+ "num_steps": 2
145
+ },
146
+ {
147
+ "global_episode": 17,
148
+ "task": "transaction_approval",
149
+ "episode_in_task": 1,
150
+ "total_reward": 0.10799999999999998,
151
+ "final_accuracy": 0.0,
152
+ "success": false,
153
+ "num_steps": 7
154
+ },
155
+ {
156
+ "global_episode": 18,
157
+ "task": "transaction_approval",
158
+ "episode_in_task": 2,
159
+ "total_reward": 0.10799999999999998,
160
+ "final_accuracy": 0.0,
161
+ "success": false,
162
+ "num_steps": 7
163
+ },
164
+ {
165
+ "global_episode": 19,
166
+ "task": "transaction_approval",
167
+ "episode_in_task": 3,
168
+ "total_reward": 0.10799999999999998,
169
+ "final_accuracy": 0.0,
170
+ "success": false,
171
+ "num_steps": 7
172
+ },
173
+ {
174
+ "global_episode": 20,
175
+ "task": "transaction_approval",
176
+ "episode_in_task": 4,
177
+ "total_reward": 0.10799999999999998,
178
+ "final_accuracy": 0.0,
179
+ "success": false,
180
+ "num_steps": 7
181
+ },
182
+ {
183
+ "global_episode": 21,
184
+ "task": "transaction_approval",
185
+ "episode_in_task": 5,
186
+ "total_reward": 0.10799999999999998,
187
+ "final_accuracy": 0.0,
188
+ "success": false,
189
+ "num_steps": 7
190
+ },
191
+ {
192
+ "global_episode": 22,
193
+ "task": "transaction_approval",
194
+ "episode_in_task": 6,
195
+ "total_reward": 0.10799999999999998,
196
+ "final_accuracy": 0.0,
197
+ "success": false,
198
+ "num_steps": 7
199
+ },
200
+ {
201
+ "global_episode": 23,
202
+ "task": "transaction_approval",
203
+ "episode_in_task": 7,
204
+ "total_reward": 0.10799999999999998,
205
+ "final_accuracy": 0.0,
206
+ "success": false,
207
+ "num_steps": 7
208
+ },
209
+ {
210
+ "global_episode": 24,
211
+ "task": "transaction_approval",
212
+ "episode_in_task": 8,
213
+ "total_reward": 0.10799999999999998,
214
+ "final_accuracy": 0.0,
215
+ "success": false,
216
+ "num_steps": 7
217
+ }
218
+ ]
training/results-iteration1/reward_curve (1).png ADDED
training/trajectory_optimizer.py CHANGED
@@ -13,6 +13,8 @@ import json
13
  import os
14
  import time
15
  import requests
16
  from dataclasses import dataclass, field
17
  from typing import Optional
18
  from openai import OpenAI
@@ -106,14 +108,15 @@ class Agent:
106
  observation: dict,
107
  step_number: int,
108
  episode_history: list[str],
109
- few_shot_examples: list[Trajectory]
110
  ) -> tuple[str, str]:
111
  """
112
  Returns (action_type, content_json_string).
113
  action_type: one of ask_clarification | propose_rules | refine_rules
114
  content: JSON string appropriate for that action
115
  """
116
- system_prompt = self._build_system_prompt(few_shot_examples)
117
  user_prompt = self._build_user_prompt(observation, step_number, episode_history)
118
 
119
  try:
@@ -132,7 +135,7 @@ class Agent:
132
  print(f" [LLM ERROR] {e}")
133
  return "propose_rules", json.dumps({"rules": [], "default": "DENY"})
134
 
135
- def _build_system_prompt(self, few_shot_examples: list[Trajectory]) -> str:
136
  base = """You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.
137
 
138
  AVAILABLE ACTIONS:
@@ -164,6 +167,39 @@ STRATEGY:
164
  OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.
165
  {"action_type": "propose_rules", "content": "{...escaped json string...}"}
166
  """
167
  if few_shot_examples:
168
  base += "\n\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\n"
169
  for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:
@@ -255,6 +291,19 @@ class TrainingLoop:
255
  self.bank = TrajectoryBank()
256
  self.metrics = [] # List of {episode, task, reward, accuracy, success}
257
 
258
  def run_episode(self, task_name: str, episode_id: int) -> Trajectory:
259
  """Run a single episode and return the trajectory."""
260
  few_shots = self.bank.get_examples(task_name)
@@ -266,7 +315,7 @@ class TrainingLoop:
266
  done = result.get("done", False)
267
  history = []
268
 
269
- print(f" [Episode {episode_id}] task={task_name} few_shots={len(few_shots)}")
270
 
271
  step_num = 0
272
  while not done and step_num < obs.get("max_steps", 7):
@@ -277,7 +326,8 @@ class TrainingLoop:
277
  observation=obs,
278
  step_number=step_num,
279
  episode_history=history,
280
- few_shot_examples=few_shots
281
  )
282
 
283
  # Execute action
@@ -303,7 +353,7 @@ class TrainingLoop:
303
  # Update history
304
  history.append(f"Step {step_num}: {action_type} → reward={reward:.2f} acc={step.accuracy:.2f}")
305
 
306
- print(f" step={step_num} action={action_type} reward={reward:.3f} acc={step.accuracy:.2f}")
307
 
308
  if done:
309
  episode_score = info.get("episode_score", obs.get("current_accuracy", 0.0))
@@ -314,28 +364,48 @@ class TrainingLoop:
314
  if not trajectory.steps:
315
  trajectory.final_accuracy = 0.0
316
 
317
  return trajectory
318
 
319
  def run(self):
320
  """Run full training loop across all tasks."""
321
- print("=" * 60)
322
- print("REWARD-GUIDED TRAJECTORY OPTIMIZATION")
323
- print(f"Tasks: {TASKS}")
324
- print(f"Episodes per task: {NUM_EPISODES_PER_TASK}")
325
- print(f"Top-K trajectories: {TOP_K_TRAJECTORIES}")
326
- print("=" * 60)
327
 
328
  # Health check
329
  if not self.env.health():
330
  raise RuntimeError(f"Environment not reachable at {ENV_BASE_URL}")
331
- print(f"Environment: OK ({ENV_BASE_URL})\n")
332
 
333
  global_episode = 0
334
 
335
  for task in TASKS:
336
- print(f"\n{'─'*40}")
337
- print(f"TASK: {task}")
338
- print(f"{'─'*40}")
339
 
340
  task_rewards = []
341
  task_accuracies = []
@@ -347,6 +417,20 @@ class TrainingLoop:
347
  # Store in bank
348
  self.bank.store(trajectory)
349
 
350
  # Record metrics
351
  self.metrics.append({
352
  "global_episode": global_episode,
@@ -362,18 +446,23 @@ class TrainingLoop:
362
  task_rewards.append(trajectory.total_reward)
363
  task_accuracies.append(trajectory.final_accuracy)
364
 
365
- print(f" → Episode {ep} complete: reward={trajectory.total_reward:.3f} accuracy={trajectory.final_accuracy:.2f} success={trajectory.success}")
366
  time.sleep(0.5) # Rate limiting
367
 
368
- print(f"\n Task summary:")
369
- print(f" First episode reward: {task_rewards[0]:.3f}")
370
- print(f" Last episode reward: {task_rewards[-1]:.3f}")
371
- print(f" Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}")
372
 
373
- print("\n" + "=" * 60)
374
- print("TRAINING COMPLETE")
375
- print(f"Bank summary: {self.bank.summary()}")
376
- print("=" * 60)
377
 
378
  return self.metrics
379
 
@@ -406,7 +495,7 @@ def save_plots(metrics: list[dict]):
406
  "transaction_approval": "#4CAF50"
407
  }
408
 
409
- # ── Plot 1: Reward Curve ──────────────────────────────────────────────────
410
  fig, ax = plt.subplots(figsize=(10, 5))
411
 
412
  for task in TASKS:
@@ -414,11 +503,12 @@ def save_plots(metrics: list[dict]):
414
  task_rews = [m["total_reward"] for m in metrics if m["task"] == task]
415
  ax.plot(task_eps, task_rews, marker="o", label=task,
416
  color=colors.get(task, "gray"), linewidth=2, markersize=5)
417
-
418
- # Trend line
419
- z = np.polyfit(episodes, rewards, 1)
420
- p = np.poly1d(z)
421
- ax.plot(episodes, p(episodes), "--", color="red", alpha=0.5, linewidth=1.5, label="overall trend")
422
 
423
  ax.set_xlabel("Episode")
424
  ax.set_ylabel("Total Reward")
@@ -455,32 +545,52 @@ def save_plots(metrics: list[dict]):
455
  plt.close()
456
  print("Saved: training/plots/accuracy_curve.png")
457
 
458
- # ── Plot 3: Per-Task Improvement Bar Chart ────────────────────────────────
459
- fig, ax = plt.subplots(figsize=(8, 5))
460
 
461
- task_names = []
462
- improvements = []
463
 
464
  for task in TASKS:
465
- task_accs = [m["final_accuracy"] for m in metrics if m["task"] == task]
466
- if len(task_accs) >= 2:
467
- first = task_accs[0]
468
- last = task_accs[-1]
469
- task_names.append(task.replace("_", "\n"))
470
- improvements.append(last - first)
471
-
472
- bars = ax.bar(task_names, improvements,
473
- color=["#2196F3", "#FF9800", "#4CAF50"][:len(task_names)],
474
- edgecolor="white", linewidth=1.5)
475
-
476
- ax.axhline(y=0, color="black", linewidth=0.8)
477
- ax.set_ylabel("Accuracy Improvement (last - first episode)")
478
- ax.set_title("Per-Task Improvement from Trajectory Accumulation")
479
- ax.grid(True, axis="y", alpha=0.3)
480
-
481
- for bar, val in zip(bars, improvements):
482
- ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
483
- f"{val:+.2f}", ha="center", va="bottom", fontweight="bold")
484
 
485
  plt.tight_layout()
486
  plt.savefig("training/plots/improvement_chart.png", dpi=150, bbox_inches="tight")
@@ -488,9 +598,12 @@ def save_plots(metrics: list[dict]):
488
  print("Saved: training/plots/improvement_chart.png")
489
 
490
  # ── Save raw metrics as JSON ──────────────────────────────────────────────
491
- with open("training/plots/metrics.json", "w") as f:
492
  json.dump(metrics, f, indent=2)
493
- print("Saved: training/plots/metrics.json")
494
 
495
  # ── Entry Point ───────────────────────────────────────────────────────────────
496
 
 
13
  import os
14
  import time
15
  import requests
16
+ import logging
17
+ import wandb
18
  from dataclasses import dataclass, field
19
  from typing import Optional
20
  from openai import OpenAI
 
108
  observation: dict,
109
  step_number: int,
110
  episode_history: list[str],
111
+ few_shot_examples: list[Trajectory],
112
+ task_name: str = ""
113
  ) -> tuple[str, str]:
114
  """
115
  Returns (action_type, content_json_string).
116
  action_type: one of ask_clarification | propose_rules | refine_rules
117
  content: JSON string appropriate for that action
118
  """
119
+ system_prompt = self._build_system_prompt(few_shot_examples, task_name)
120
  user_prompt = self._build_user_prompt(observation, step_number, episode_history)
121
 
122
  try:
 
135
  print(f" [LLM ERROR] {e}")
136
  return "propose_rules", json.dumps({"rules": [], "default": "DENY"})
137
 
138
+ def _build_system_prompt(self, few_shot_examples: list[Trajectory], task_name: str = "") -> str:
139
  base = """You are a policy-to-logic agent. Your job is to convert natural language policies into executable rules.
140
 
141
  AVAILABLE ACTIONS:
 
167
  OUTPUT FORMAT: Respond ONLY with valid JSON. No markdown. No explanation.
168
  {"action_type": "propose_rules", "content": "{...escaped json string...}"}
169
  """
170
+ # Task-specific guidance for complex tasks
171
+ if task_name == "transaction_approval":
172
+ base += """
173
+ IMPORTANT — TRANSACTION APPROVAL TASK:
174
+ This task has 4 possible decisions: APPROVE, REQUIRE_APPROVAL, COMPLIANCE_REVIEW, HOLD
175
+ Rules are evaluated TOP-TO-BOTTOM. Order matters critically. You MUST order rules by priority:
176
+ 1. FIRST: Check if transfer_type == "international" → then COMPLIANCE_REVIEW (always, overrides everything)
177
+ 2. SECOND: Check if amount >= 10000 AND time is outside business hours (hour < 9 or hour >= 17) → then HOLD
178
+ 3. THIRD: Check if amount > 5000 AND initiator_role != "manager" → then REQUIRE_APPROVAL
179
+ 4. DEFAULT: APPROVE
180
+
181
+ Key details:
182
+ - Standard limit is $5,000 (amount > 5000 triggers approval, NOT >=)
183
+ - High-value threshold is $10,000 (amount >= 10000)
184
+ - Business hours: hour >= 9 AND hour < 17
185
+ - Manager exemption ONLY applies to the standard $5,000 limit, NOT to international or high-value HOLD rules
186
+ - "system" role follows the same rules as "employee"
187
+
188
+ Here is a working example of valid rules for this task:
189
+ {"rules": [{"if": [{"field": "transfer_type", "op": "==", "value": "international"}], "then": "COMPLIANCE_REVIEW"}, {"if": [{"field": "amount", "op": ">=", "value": 10000}, {"field": "time", "op": ">=", "value": 17}], "then": "HOLD"}, {"if": [{"field": "amount", "op": ">=", "value": 10000}, {"field": "time", "op": "<", "value": 9}], "then": "HOLD"}, {"if": [{"field": "amount", "op": ">", "value": 5000}, {"field": "initiator_role", "op": "!=", "value": "manager"}], "then": "REQUIRE_APPROVAL"}], "default": "APPROVE"}
190
+ """
191
+ elif task_name == "resource_access":
192
+ base += """
193
+ IMPORTANT — RESOURCE ACCESS TASK:
194
+ This task has roles: junior, senior, contractor. Document types: public, internal, confidential.
195
+ - Senior employees: ALLOW everything always
196
+ - Contractors: ALLOW only public, DENY everything else
197
+ - Junior + confidential: ALWAYS DENY (regardless of time — the policy is misleading about this)
198
+ - Junior + internal: ALLOW only during business hours (hour >= 8 AND hour < 17)
199
+ - Junior + public: ALLOW always
200
+ - Business hours: hour >= 8 AND hour < 17
201
+ """
202
+
203
  if few_shot_examples:
204
  base += "\n\nLEARNED FROM PREVIOUS EPISODES (high-reward strategies):\n"
205
  for traj in few_shot_examples[-TOP_K_TRAJECTORIES:]:
 
291
  self.bank = TrajectoryBank()
292
  self.metrics = [] # List of {episode, task, reward, accuracy, success}
293
 
294
+ os.makedirs("training/logs", exist_ok=True)
295
+ log_filename = f"training/logs/run_{int(time.time())}.log"
296
+ logging.basicConfig(
297
+ level=logging.INFO,
298
+ format="%(asctime)s [%(levelname)s] %(message)s",
299
+ handlers=[
300
+ logging.FileHandler(log_filename),
301
+ logging.StreamHandler() # also print to console
302
+ ]
303
+ )
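+ # Illustrative log line produced by this format (the timestamp shown is hypothetical):
+ # 2024-01-01 12:00:00,000 [INFO] START episode=1 task=data_access few_shots_available=0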
304
+ self.logger = logging.getLogger("TrainingLoop")
305
+ self.log_file = log_filename
306
+
307
  def run_episode(self, task_name: str, episode_id: int) -> Trajectory:
308
  """Run a single episode and return the trajectory."""
309
  few_shots = self.bank.get_examples(task_name)
 
315
  done = result.get("done", False)
316
  history = []
317
 
318
+ self.logger.info(f"START episode={episode_id} task={task_name} few_shots_available={len(few_shots)}")
319
 
320
  step_num = 0
321
  while not done and step_num < obs.get("max_steps", 7):
 
326
  observation=obs,
327
  step_number=step_num,
328
  episode_history=history,
329
+ few_shot_examples=few_shots,
330
+ task_name=task_name
331
  )
332
 
333
  # Execute action
 
353
  # Update history
354
  history.append(f"Step {step_num}: {action_type} → reward={reward:.2f} acc={step.accuracy:.2f}")
355
 
356
+ self.logger.info(f"STEP episode={episode_id} step={step_num} action={action_type} reward={reward:.4f} accuracy={step.accuracy:.4f}")
357
 
358
  if done:
359
  episode_score = info.get("episode_score", obs.get("current_accuracy", 0.0))
 
364
  if not trajectory.steps:
365
  trajectory.final_accuracy = 0.0
366
 
367
+ self.logger.info(f"END episode={episode_id} task={task_name} total_reward={trajectory.total_reward:.4f} final_accuracy={trajectory.final_accuracy:.4f} success={trajectory.success} steps={len(trajectory.steps)}")
368
+
369
  return trajectory
370
 
371
  def run(self):
372
  """Run full training loop across all tasks."""
373
+ self.logger.info("=" * 60)
374
+ self.logger.info("REWARD-GUIDED TRAJECTORY OPTIMIZATION")
375
+ self.logger.info(f"Tasks: {TASKS}")
376
+ self.logger.info(f"Episodes per task: {NUM_EPISODES_PER_TASK}")
377
+ self.logger.info(f"Top-K trajectories: {TOP_K_TRAJECTORIES}")
378
+ self.logger.info("=" * 60)
379
+ self.logger.info(f"Log file: {self.log_file}")
380
+
381
+ try:
382
+ wandb.init(
383
+ project="policy-to-logic-rl",
384
+ name=f"trajectory-opt-{int(time.time())}",
385
+ config={
386
+ "num_episodes_per_task": NUM_EPISODES_PER_TASK,
387
+ "top_k_trajectories": TOP_K_TRAJECTORIES,
388
+ "min_reward_threshold": MIN_REWARD_THRESHOLD,
389
+ "model": MODEL,
390
+ "temperature": TEMPERATURE,
391
+ "tasks": TASKS,
392
+ "env_url": ENV_BASE_URL,
393
+ }
394
+ )
395
+ except Exception as e:
396
+ self.logger.warning(f"Wandb init failed: {e}. Continuing without W&B.")
397
 
398
  # Health check
399
  if not self.env.health():
400
  raise RuntimeError(f"Environment not reachable at {ENV_BASE_URL}")
401
+ self.logger.info(f"Environment: OK ({ENV_BASE_URL})\n")
402
 
403
  global_episode = 0
404
 
405
  for task in TASKS:
406
+ self.logger.info(f"\n{'─'*40}")
407
+ self.logger.info(f"TASK: {task}")
408
+ self.logger.info(f"{'─'*40}")
409
 
410
  task_rewards = []
411
  task_accuracies = []
 
417
  # Store in bank
418
  self.bank.store(trajectory)
419
 
420
+ try:
421
+ wandb.log({
422
+ f"{task}/total_reward": trajectory.total_reward,
423
+ f"{task}/final_accuracy": trajectory.final_accuracy,
424
+ f"{task}/num_steps": len(trajectory.steps),
425
+ f"{task}/success": int(trajectory.success),
426
+ f"{task}/few_shots_used": len(self.bank.get_examples(task)),
427
+ "global/total_reward": trajectory.total_reward,
428
+ "global/final_accuracy": trajectory.final_accuracy,
429
+ "episode": global_episode,
430
+ })
431
+ except Exception:
432
+ pass
433
+
434
  # Record metrics
435
  self.metrics.append({
436
  "global_episode": global_episode,
 
446
  task_rewards.append(trajectory.total_reward)
447
  task_accuracies.append(trajectory.final_accuracy)
448
 
449
+ self.logger.info(f" → Episode {ep} complete: reward={trajectory.total_reward:.3f} accuracy={trajectory.final_accuracy:.2f} success={trajectory.success}")
450
  time.sleep(0.5) # Rate limiting
451
 
452
+ self.logger.info(f"\n Task summary:")
453
+ self.logger.info(f" First episode reward: {task_rewards[0]:.3f}")
454
+ self.logger.info(f" Last episode reward: {task_rewards[-1]:.3f}")
455
+ self.logger.info(f" Improvement: {task_rewards[-1] - task_rewards[0]:+.3f}")
456
 
457
+ self.logger.info("\n" + "=" * 60)
458
+ self.logger.info("TRAINING COMPLETE")
459
+ self.logger.info(f"Bank summary: {self.bank.summary()}")
460
+ self.logger.info("=" * 60)
461
+
462
+ try:
463
+ wandb.finish()
464
+ except Exception:
465
+ pass
466
 
467
  return self.metrics
468
 
 
495
  "transaction_approval": "#4CAF50"
496
  }
497
 
498
+ # ── Plot 1: Reward Curve (per-task trend lines) ────────────────────────────
499
  fig, ax = plt.subplots(figsize=(10, 5))
500
 
501
  for task in TASKS:
 
503
  task_rews = [m["total_reward"] for m in metrics if m["task"] == task]
504
  ax.plot(task_eps, task_rews, marker="o", label=task,
505
  color=colors.get(task, "gray"), linewidth=2, markersize=5)
506
+ # Per-task trend line
507
+ if len(task_eps) >= 2:
508
+ z = np.polyfit(task_eps, task_rews, 1)
509
+ p = np.poly1d(z)
510
+ ax.plot(task_eps, p(task_eps), "--",
511
+ color=colors.get(task, "gray"), alpha=0.4, linewidth=1.5)
512
 
513
  ax.set_xlabel("Episode")
514
  ax.set_ylabel("Total Reward")
 
545
  plt.close()
546
  print("Saved: training/plots/accuracy_curve.png")
547
 
548
+ # ── Plot 3: Per-Task Summary (Accuracy + Efficiency) ──────────────────────
549
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
550
 
551
+ task_labels = []
552
+ acc_improvements = []
553
+ eff_improvements = []
554
+ best_accuracies = []
555
 
556
  for task in TASKS:
557
+ task_data = [m for m in metrics if m["task"] == task]
558
+ if len(task_data) >= 2:
559
+ task_labels.append(task.replace("_", "\n"))
560
+ acc_improvements.append(task_data[-1]["final_accuracy"] - task_data[0]["final_accuracy"])
561
+ # Efficiency: steps saved (first vs best)
562
+ first_steps = task_data[0]["num_steps"]
563
+ best_steps = min(m["num_steps"] for m in task_data)
564
+ eff_pct = ((first_steps - best_steps) / first_steps * 100) if first_steps > 0 else 0
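+ # Worked example from the iteration-1 metrics: data_access fell from 4 steps to 2, i.e. (4 - 2) / 4 * 100 = 50% steps saved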
565
+ eff_improvements.append(eff_pct)
566
+ best_accuracies.append(max(m["final_accuracy"] for m in task_data))
567
+
568
+ # Left: Best accuracy per task
569
+ bars1 = axes[0].bar(task_labels, best_accuracies,
570
+ color=["#2196F3", "#FF9800", "#4CAF50"][:len(task_labels)],
571
+ edgecolor="white", linewidth=1.5)
572
+ axes[0].axhline(y=0.9, color="red", linestyle="--", alpha=0.7, label="success threshold")
573
+ axes[0].set_ylabel("Best Accuracy Achieved")
574
+ axes[0].set_title("Best Accuracy Per Task")
575
+ axes[0].set_ylim(0, 1.1)
576
+ axes[0].grid(True, axis="y", alpha=0.3)
577
+ axes[0].legend()
578
+ for bar, val in zip(bars1, best_accuracies):
579
+ axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
580
+ f"{val:.0%}", ha="center", va="bottom", fontweight="bold")
581
+
582
+ # Right: Efficiency improvement (% steps saved)
583
+ bars2 = axes[1].bar(task_labels, eff_improvements,
584
+ color=["#2196F3", "#FF9800", "#4CAF50"][:len(task_labels)],
585
+ edgecolor="white", linewidth=1.5)
586
+ axes[1].axhline(y=0, color="black", linewidth=0.8)
587
+ axes[1].set_ylabel("Steps Saved (%)")
588
+ axes[1].set_title("Efficiency Improvement (First → Best Episode)")
589
+ axes[1].grid(True, axis="y", alpha=0.3)
590
+ for bar, val in zip(bars2, eff_improvements):
591
+ y_pos = max(bar.get_height() + 1, 2)
592
+ axes[1].text(bar.get_x() + bar.get_width() / 2, y_pos,
593
+ f"{val:.0f}%", ha="center", va="bottom", fontweight="bold")
594
 
595
  plt.tight_layout()
596
  plt.savefig("training/plots/improvement_chart.png", dpi=150, bbox_inches="tight")
 
598
  print("Saved: training/plots/improvement_chart.png")
599
 
600
  # ── Save raw metrics as JSON ──────────────────────────────────────────────
601
+ timestamp = int(time.time())
602
+ with open(f"training/plots/metrics_{timestamp}.json", "w") as f:
603
+ json.dump(metrics, f, indent=2)
604
+ with open("training/plots/metrics_latest.json", "w") as f:
605
  json.dump(metrics, f, indent=2)
606
+ print(f"Saved: training/plots/metrics_{timestamp}.json and metrics_latest.json")
607
 
608
  # ── Entry Point ───────────────────────────────────────────────────────────────
609
 
training/update_colab.py ADDED
@@ -0,0 +1,173 @@
1
+ import json
2
+ import os
3
+
4
+ with open("training/trajectory_optimizer.py", "r") as f:
5
+ traj_code = f.read()
6
+
7
+ notebook = {
8
+ "nbformat": 4,
9
+ "nbformat_minor": 0,
10
+ "metadata": {
11
+ "colab": {
12
+ "provenance": [],
13
+ "name": "Policy-to-Logic Training"
14
+ },
15
+ "kernelspec": {
16
+ "name": "python3",
17
+ "display_name": "Python 3"
18
+ },
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "cells": [
24
+ {
25
+ "cell_type": "markdown",
26
+ "metadata": {},
27
+ "source": [
28
+ "# Policy-to-Logic RL Environment \u2014 Training Notebook\n",
29
+ "\n",
30
+ "This notebook runs the **reward-guided trajectory optimization loop** against the deployed environment.\n",
31
+ "\n",
32
+ "**What it does:**\n",
33
+ "1. Connects to the live HF Spaces environment\n",
34
+ "2. Runs 8 episodes per task (3 tasks = 24 total episodes)\n",
35
+ "3. Accumulates high-reward trajectories as few-shot examples\n",
36
+ "4. Generates training evidence plots (reward curve, accuracy curve, improvement chart)\n",
37
+ "5. Logs everything to Weights & Biases"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": None,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "# Cell 1: Install dependencies\n",
47
+ "!pip install openai requests matplotlib numpy wandb"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": None,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# Cell 2: Configuration\n",
57
+ "import os\n",
58
+ "\n",
59
+ "# SET THESE BEFORE RUNNING\n",
60
+ "HF_TOKEN = \"\" # Your Hugging Face token with inference access\n",
61
+ "ENV_URL = \"https://godreign-policy2logic.hf.space\" # Your deployed environment URL\n",
62
+ "WANDB_API_KEY = \"\" # Your Wandb API key\n",
63
+ "\n",
64
+ "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
65
+ "os.environ[\"ENV_BASE_URL\"] = ENV_URL\n",
66
+ "if WANDB_API_KEY:\n",
67
+ " os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
68
+ "\n",
69
+ "print(f\"Environment URL: {ENV_URL}\")\n",
70
+ "print(f\"HF Token set: {'Yes' if HF_TOKEN else 'NO - MUST SET THIS'}\")\n",
71
+ "print(f\"Wandb Token set: {'Yes' if WANDB_API_KEY else 'NO - WILL PROMPT'}\")"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": None,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "# Cell 3: Verify environment is reachable\n",
81
+ "import requests\n",
82
+ "\n",
83
+ "r = requests.get(f\"{ENV_URL}/health\")\n",
84
+ "print(f\"Status: {r.status_code}\")\n",
85
+ "print(f\"Response: {r.json()}\")\n",
86
+ "\n",
87
+ "r2 = requests.get(f\"{ENV_URL}/tasks\")\n",
88
+ "tasks = r2.json()\n",
89
+ "print(f\"\\nAvailable tasks: {list(tasks['tasks'].keys())}\")"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": None,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "# Cell 4: Wandb login — run this before training\n",
99
+ "import wandb\n",
100
+ "wandb.login() # Will prompt for API key if WANDB_API_KEY is not set"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": None,
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "# Cell 5: Training loop implementation (full trajectory_optimizer.py)\n"
110
+ ] + traj_code.splitlines(keepends=True)
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": None,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "# Cell 6: Run training loop\n",
119
+ "loop = TrainingLoop(ENV_URL, HF_TOKEN)\n",
120
+ "metrics = loop.run()\n",
121
+ "print(f\"\\nTotal episodes run: {len(metrics)}\")"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": None,
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "# Cell 7: Generate plots and display inline\n",
131
+ "save_plots(metrics)\n",
132
+ "\n",
133
+ "from IPython.display import Image, display\n",
134
+ "display(Image(\"training/plots/reward_curve.png\"))\n",
135
+ "display(Image(\"training/plots/accuracy_curve.png\"))\n",
136
+ "display(Image(\"training/plots/improvement_chart.png\"))"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": None,
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "# Cell 8: Display wandb run link\n",
146
+ "print(\"Wandb run: https://wandb.ai/YOUR_USERNAME/policy-to-logic-rl\")\n",
147
+ "print(\"Add this link to your README under Deliverables.\")"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": None,
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "# Cell 9: Download plots to commit to repo\n",
157
+ "from google.colab import files\n",
158
+ "\n",
159
+ "files.download(\"training/plots/reward_curve.png\")\n",
160
+ "files.download(\"training/plots/accuracy_curve.png\")\n",
161
+ "files.download(\"training/plots/improvement_chart.png\")\n",
162
+ "files.download(\"training/plots/metrics_latest.json\")\n",
163
+ "\n",
164
+ "print(\"Downloaded. Now commit these files to: training/plots/ in your repo.\")"
165
+ ]
166
+ }
167
+ ]
168
+ }
169
+
170
+ with open("training/colab_training.ipynb", "w") as f:
171
+ json.dump(notebook, f, indent=1)
172
+
173
+ print("Colab Notebook updated successfully")
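+
+ # Usage note (assumed invocation; run from the repo root so the relative paths above resolve):
+ #   python training/update_colab.py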
uv.lock CHANGED
@@ -449,6 +449,30 @@ wheels = [
449
  { url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" },
450
  ]
451
 
452
  [[package]]
453
  name = "h11"
454
  version = "0.16.0"
@@ -1161,6 +1185,15 @@ wheels = [
1161
  { url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" },
1162
  ]
1163
 
1164
  [[package]]
1165
  name = "pluggy"
1166
  version = "1.6.0"
@@ -1185,6 +1218,7 @@ dependencies = [
1185
  { name = "pydantic" },
1186
  { name = "requests" },
1187
  { name = "uvicorn" },
1188
  ]
1189
 
1190
  [package.optional-dependencies]
@@ -1206,9 +1240,25 @@ requires-dist = [
1206
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
1207
  { name = "requests", specifier = ">=2.25.0" },
1208
  { name = "uvicorn", specifier = ">=0.24.0" },
1209
  ]
1210
  provides-extras = ["dev"]
1211
 
1212
  [[package]]
1213
  name = "pydantic"
1214
  version = "2.13.3"
@@ -1480,6 +1530,19 @@ wheels = [
1480
  { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
1481
  ]
1482
 
1483
  [[package]]
1484
  name = "shellingham"
1485
  version = "1.5.4"
@@ -1498,6 +1561,15 @@ wheels = [
1498
  { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
1499
  ]
1500
 
1501
  [[package]]
1502
  name = "sniffio"
1503
  version = "1.3.1"
@@ -1644,3 +1716,32 @@ sdist = { url = "https://files.pythonhosted.org/packages/1f/93/041fca8274050e40e
1644
  wheels = [
1645
  { url = "https://files.pythonhosted.org/packages/31/a3/5b1562db76a5a488274b2332a97199b32d0442aca0ed193697fd47786316/uvicorn-0.46.0-py3-none-any.whl", hash = "sha256:bbebbcbed972d162afca128605223022bedd345b7bc7855ce66deb31487a9048", size = 70926, upload-time = "2026-04-23T07:15:58.355Z" },
1646
  ]
 
449
  { url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595, upload-time = "2026-03-27T19:11:13.595Z" },
450
  ]
451
 
452
+ [[package]]
453
+ name = "gitdb"
454
+ version = "4.0.12"
455
+ source = { registry = "https://pypi.org/simple" }
456
+ dependencies = [
457
+ { name = "smmap" },
458
+ ]
459
+ sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" }
460
+ wheels = [
461
+ { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" },
462
+ ]
463
+
464
+ [[package]]
465
+ name = "gitpython"
466
+ version = "3.1.47"
467
+ source = { registry = "https://pypi.org/simple" }
468
+ dependencies = [
469
+ { name = "gitdb" },
470
+ ]
471
+ sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" }
472
+ wheels = [
473
+ { url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" },
474
+ ]
475
+
476
  [[package]]
477
  name = "h11"
478
  version = "0.16.0"
 
1185
  { url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" },
1186
  ]
1187
 
1188
+ [[package]]
1189
+ name = "platformdirs"
1190
+ version = "4.9.6"
1191
+ source = { registry = "https://pypi.org/simple" }
1192
+ sdist = { url = "https://files.pythonhosted.org/packages/9f/4a/0883b8e3802965322523f0b200ecf33d31f10991d0401162f4b23c698b42/platformdirs-4.9.6.tar.gz", hash = "sha256:3bfa75b0ad0db84096ae777218481852c0ebc6c727b3168c1b9e0118e458cf0a", size = 29400, upload-time = "2026-04-09T00:04:10.812Z" }
1193
+ wheels = [
1194
+ { url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" },
1195
+ ]
1196
+
1197
  [[package]]
1198
  name = "pluggy"
1199
  version = "1.6.0"
 
1218
  { name = "pydantic" },
1219
  { name = "requests" },
1220
  { name = "uvicorn" },
1221
+ { name = "wandb" },
1222
  ]
1223
 
1224
  [package.optional-dependencies]
 
1240
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" },
1241
  { name = "requests", specifier = ">=2.25.0" },
1242
  { name = "uvicorn", specifier = ">=0.24.0" },
1243
+ { name = "wandb", specifier = ">=0.16.0" },
1244
  ]
1245
  provides-extras = ["dev"]
1246
 
1247
+ [[package]]
1248
+ name = "protobuf"
1249
+ version = "7.34.1"
1250
+ source = { registry = "https://pypi.org/simple" }
1251
+ sdist = { url = "https://files.pythonhosted.org/packages/6b/6b/a0e95cad1ad7cc3f2c6821fcab91671bd5b78bd42afb357bb4765f29bc41/protobuf-7.34.1.tar.gz", hash = "sha256:9ce42245e704cc5027be797c1db1eb93184d44d1cdd71811fb2d9b25ad541280", size = 454708, upload-time = "2026-03-20T17:34:47.036Z" }
1252
+ wheels = [
1253
+ { url = "https://files.pythonhosted.org/packages/ec/11/3325d41e6ee15bf1125654301211247b042563bcc898784351252549a8ad/protobuf-7.34.1-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:d8b2cc79c4d8f62b293ad9b11ec3aebce9af481fa73e64556969f7345ebf9fc7", size = 429247, upload-time = "2026-03-20T17:34:37.024Z" },
1254
+ { url = "https://files.pythonhosted.org/packages/eb/9d/aa69df2724ff63efa6f72307b483ce0827f4347cc6d6df24b59e26659fef/protobuf-7.34.1-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:5185e0e948d07abe94bb76ec9b8416b604cfe5da6f871d67aad30cbf24c3110b", size = 325753, upload-time = "2026-03-20T17:34:38.751Z" },
1255
+ { url = "https://files.pythonhosted.org/packages/92/e8/d174c91fd48e50101943f042b09af9029064810b734e4160bbe282fa1caa/protobuf-7.34.1-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:403b093a6e28a960372b44e5eb081775c9b056e816a8029c61231743d63f881a", size = 340198, upload-time = "2026-03-20T17:34:39.871Z" },
1256
+ { url = "https://files.pythonhosted.org/packages/53/1b/3b431694a4dc6d37b9f653f0c64b0a0d9ec074ee810710c0c3da21d67ba7/protobuf-7.34.1-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:8ff40ce8cd688f7265326b38d5a1bed9bfdf5e6723d49961432f83e21d5713e4", size = 324267, upload-time = "2026-03-20T17:34:41.1Z" },
1257
+ { url = "https://files.pythonhosted.org/packages/85/29/64de04a0ac142fb685fd09999bc3d337943fb386f3a0ec57f92fd8203f97/protobuf-7.34.1-cp310-abi3-win32.whl", hash = "sha256:34b84ce27680df7cca9f231043ada0daa55d0c44a2ddfaa58ec1d0d89d8bf60a", size = 426628, upload-time = "2026-03-20T17:34:42.536Z" },
1258
+ { url = "https://files.pythonhosted.org/packages/4d/87/cb5e585192a22b8bd457df5a2c16a75ea0db9674c3a0a39fc9347d84e075/protobuf-7.34.1-cp310-abi3-win_amd64.whl", hash = "sha256:e97b55646e6ce5cbb0954a8c28cd39a5869b59090dfaa7df4598a7fba869468c", size = 437901, upload-time = "2026-03-20T17:34:44.112Z" },
1259
+ { url = "https://files.pythonhosted.org/packages/88/95/608f665226bca68b736b79e457fded9a2a38c4f4379a4a7614303d9db3bc/protobuf-7.34.1-py3-none-any.whl", hash = "sha256:bb3812cd53aefea2b028ef42bd780f5b96407247f20c6ef7c679807e9d188f11", size = 170715, upload-time = "2026-03-20T17:34:45.384Z" },
1260
+ ]
1261
+
1262
  [[package]]
1263
  name = "pydantic"
1264
  version = "2.13.3"
 
1530
  { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" },
1531
  ]
1532
 
1533
+ [[package]]
1534
+ name = "sentry-sdk"
1535
+ version = "2.58.0"
1536
+ source = { registry = "https://pypi.org/simple" }
1537
+ dependencies = [
1538
+ { name = "certifi" },
1539
+ { name = "urllib3" },
1540
+ ]
1541
+ sdist = { url = "https://files.pythonhosted.org/packages/26/b3/fb8291170d0e844173164709fc0fa0c221ed75a5da740c8746f2a83b4eb1/sentry_sdk-2.58.0.tar.gz", hash = "sha256:c1144d947352d54e5b7daa63596d9f848adf684989c06c4f5a659f0c85a18f6f", size = 438764, upload-time = "2026-04-13T17:23:26.265Z" }
1542
+ wheels = [
1543
+ { url = "https://files.pythonhosted.org/packages/fa/eb/d875669993b762556ae8b2efd86219943b4c0864d22204d622a9aee3052b/sentry_sdk-2.58.0-py2.py3-none-any.whl", hash = "sha256:688d1c704ddecf382ea3326f21a67453d4caa95592d722b7c780a36a9d23109e", size = 460919, upload-time = "2026-04-13T17:23:24.675Z" },
1544
+ ]
1545
+
1546
  [[package]]
1547
  name = "shellingham"
1548
  version = "1.5.4"
 
1561
  { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" },
1562
  ]
1563
 
1564
+ [[package]]
1565
+ name = "smmap"
1566
+ version = "5.0.3"
1567
+ source = { registry = "https://pypi.org/simple" }
1568
+ sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" }
1569
+ wheels = [
1570
+ { url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = "sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" },
1571
+ ]
1572
+
1573
  [[package]]
1574
  name = "sniffio"
1575
  version = "1.3.1"
 
1716
  wheels = [
1717
  { url = "https://files.pythonhosted.org/packages/31/a3/5b1562db76a5a488274b2332a97199b32d0442aca0ed193697fd47786316/uvicorn-0.46.0-py3-none-any.whl", hash = "sha256:bbebbcbed972d162afca128605223022bedd345b7bc7855ce66deb31487a9048", size = 70926, upload-time = "2026-04-23T07:15:58.355Z" },
1718
  ]
1719
+
1720
+ [[package]]
1721
+ name = "wandb"
1722
+ version = "0.26.1"
1723
+ source = { registry = "https://pypi.org/simple" }
1724
+ dependencies = [
1725
+ { name = "click" },
1726
+ { name = "gitpython" },
1727
+ { name = "packaging" },
1728
+ { name = "platformdirs" },
1729
+ { name = "protobuf" },
1730
+ { name = "pydantic" },
1731
+ { name = "pyyaml" },
1732
+ { name = "requests" },
1733
+ { name = "sentry-sdk" },
1734
+ { name = "typing-extensions" },
1735
+ ]
1736
+ sdist = { url = "https://files.pythonhosted.org/packages/6a/a4/72a6640e1f566e81f184a426e3e45298d4c6672664de41adb7eb6f64370a/wandb-0.26.1.tar.gz", hash = "sha256:eef2dbaea06f0b1c0cdc5d76f544ae4c2b8848fc512442a00bd59f0502fc8aa1", size = 42159814, upload-time = "2026-04-23T16:27:34.033Z" }
1737
+ wheels = [
1738
+ { url = "https://files.pythonhosted.org/packages/8c/09/3296235f3906e904f06f2df29eed4d672fb23c0932c9486e2af64f2f2a66/wandb-0.26.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2955fe190c005fb83ee6d73f066c8a33f09f3212a1f2eb53faa6581440e456be", size = 24857204, upload-time = "2026-04-23T16:26:58.576Z" },
1739
+ { url = "https://files.pythonhosted.org/packages/a1/ad/e39ca3086534129e42208ba00ed2c6247ce425f890219eeec33b4f162864/wandb-0.26.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:55d91cabde98162d7116a5e19ddd052bd9848556243f1da4cbb9ffb7ad435bfc", size = 26014649, upload-time = "2026-04-23T16:27:02.559Z" },
1740
+ { url = "https://files.pythonhosted.org/packages/56/af/400d84a3bdce0b062b4baa70acb6becd2c8018697f4fbf5af9a9e1e406e5/wandb-0.26.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:7c78bc2454cfe1ffa1c3a256060a387356eed8a4488e024d9d2eba8f2b5bd51d", size = 25421317, upload-time = "2026-04-23T16:27:06.411Z" },
1741
+ { url = "https://files.pythonhosted.org/packages/7b/e9/b4bf8f3509dcea1cec52233a38991459654635b5a8e6a494eb912e1b9cfb/wandb-0.26.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:a2c8eeec8706dcd2872e69c3b4d20ec523082fdb4440295491556e219ad2aa67", size = 27192831, upload-time = "2026-04-23T16:27:10.308Z" },
1742
+ { url = "https://files.pythonhosted.org/packages/62/cf/4a6dce0c782223ef0eeea7139daee73418a7322befcf083512c31cebaa18/wandb-0.26.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2fa768ee0636a569afb7541cf996e56309c47070566a38916823f94e02afe586", size = 25593326, upload-time = "2026-04-23T16:27:14.259Z" },
1743
+ { url = "https://files.pythonhosted.org/packages/df/99/58c3d8c36ae8e2b7d70bf6493eb5daa1cca0231a04b025717b4cd1a78f1e/wandb-0.26.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5854928725cfeff1f284d5c043cd353f810e5da02eead2c120ef5056ad026fea", size = 27535542, upload-time = "2026-04-23T16:27:18.473Z" },
1744
+ { url = "https://files.pythonhosted.org/packages/7c/d0/4e846ffc1d0cc435518dfa581ce73ac82cfd0ebbf35f3853c9277f632e5f/wandb-0.26.1-py3-none-win32.whl", hash = "sha256:5c2bd44e575ae9944e2764d1aaa031461178276bf2636d5558399c2816ef5cfe", size = 24968151, upload-time = "2026-04-23T16:27:22.086Z" },
1745
+ { url = "https://files.pythonhosted.org/packages/e3/9b/487413eaccefdb58799a226726e24b486e9192d2671c75a4550c160aba23/wandb-0.26.1-py3-none-win_amd64.whl", hash = "sha256:5817785467d3f1676f1812ec19a89f77f6e56dfe67d9f47080075af95f705d3e", size = 24968155, upload-time = "2026-04-23T16:27:25.731Z" },
1746
+ { url = "https://files.pythonhosted.org/packages/04/dc/5baf3e99b3eeb709d6f75124b5bec8cb73d4b38d2b10df7fdcfde4966200/wandb-0.26.1-py3-none-win_arm64.whl", hash = "sha256:f848b7744f896bc04cabbb28360a2814d1551a91fa2c456243e06435729c8a2e", size = 22912416, upload-time = "2026-04-23T16:27:29.456Z" },
1747
+ ]