adityss commited on
Commit
e890cbb
·
1 Parent(s): 3b977fc

feat: add GRPO training notebook for GridMind-RL environment

Browse files
Files changed (1) hide show
  1. scripts/gridmind_grpo_colab.ipynb +624 -624
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -1,626 +1,626 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "193da661",
6
- "metadata": {},
7
- "source": [
8
- "# GridMind-RL: GRPO Training for Industrial Energy Management\n",
9
- "\n",
10
- "**Meta PyTorch OpenEnv Hackathon GridMind-RL Team**\n",
11
- "\n",
12
- "This notebook trains a small LLM (Qwen2.5-1.5B) using TRL GRPO on the GridMind-RL environment.\n",
13
- "The environment covers all 4 hackathon themes:\n",
14
- "\n",
15
- "1. **Theme 1: Multi-Agent** 3 buildings share a grid feeder; each agent makes independent decisions\n",
16
- "2. **Theme 2: Instruction Following** Task 4 provides natural language objectives that must be satisfied\n",
17
- "3. **Theme 3: World Modeling** `/simulate` endpoint predicts outcomes before committing actions\n",
18
- "4. **Theme 4: Self-Improvement** Curriculum automatically advances difficulty as agent performance improves\n",
19
- "\n",
20
- "| | |\n",
21
- "|---|---|\n",
22
- "| **Environment** | https://lo-kyu-gridmind.hf.space |\n",
23
- "| **Method** | GRPO (Group Relative Policy Optimization) |\n",
24
- "| **Model** | Qwen2.5-1.5B-Instruct |\n",
25
- "| **Training Time** | ~30-40 minutes on free Colab T4 GPU |\n",
26
- "| **Expected Improvement** | 20-40% score gain over heuristic baseline |"
27
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  },
29
- {
30
- "cell_type": "code",
31
- "execution_count": null,
32
- "id": "f28e2f2c",
33
- "metadata": {},
34
- "outputs": [],
35
- "source": [
36
- "# Install dependencies\n",
37
- "!pip install trl==0.8.6 transformers==4.40.0 torch accelerate datasets requests -q\n",
38
- "\n",
39
- "import torch\n",
40
- "import sys\n",
41
- "\n",
42
- "print(f\"PyTorch: {torch.__version__}\")\n",
43
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
44
- "if torch.cuda.is_available():\n",
45
- " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
46
- " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
47
- ]
48
- },
49
- {
50
- "cell_type": "markdown",
51
- "id": "5021a299",
52
- "metadata": {},
53
- "source": [
54
- "## Step 1: Connect to Environment and Verify Connectivity"
55
- ]
56
- },
57
- {
58
- "cell_type": "code",
59
- "execution_count": null,
60
- "id": "4cdf0f35",
61
- "metadata": {},
62
- "outputs": [],
63
- "source": [
64
- "import requests\n",
65
- "import json\n",
66
- "import time\n",
67
- "\n",
68
- "ENV_URL = \"https://lo-kyu-gridmind.hf.space\"\n",
69
- "\n",
70
- "# Test connectivity\n",
71
- "print(\"Testing environment connectivity...\")\n",
72
- "try:\n",
73
- " health = requests.get(f\"{ENV_URL}/health\", timeout=10).json()\n",
74
- " print(f\"✓ Health check: {health}\")\n",
75
- "except Exception as e:\n",
76
- " print(f\"✗ Health check failed: {e}\")\n",
77
- " sys.exit(1)\n",
78
- "\n",
79
- "# Test each task reset\n",
80
- "print(\"\\nTesting all 4 tasks...\")\n",
81
- "for task_id in [1, 2, 3, 4]:\n",
82
- " try:\n",
83
- " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
84
- " obs = r.json()\n",
85
- " has_card = \"instruction_card\" in obs or \"observations\" in obs and obs[\"observations\"][0].get(\"instruction_card\")\n",
86
- " print(f\"✓ Task {task_id}: status={r.status_code}, has_instruction_card={has_card}\")\n",
87
- " except Exception as e:\n",
88
- " print(f\"✗ Task {task_id} failed: {e}\")\n",
89
- "\n",
90
- "# Test coordinator (multi-agent)\n",
91
- "print(\"\\nTesting multi-agent coordinator...\")\n",
92
- "try:\n",
93
- " r = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10)\n",
94
- " obs = r.json()\n",
95
- " n_buildings = len(obs.get(\"observations\", []))\n",
96
- " print(f\"✓ Coordinator reset: {n_buildings} buildings\")\n",
97
- "except Exception as e:\n",
98
- " print(f\"✗ Coordinator failed: {e}\")\n",
99
- "\n",
100
- "# Test world modeling\n",
101
- "print(\"\\nTesting world modeling (/simulate)...\")\n",
102
- "try:\n",
103
- " r = requests.post(f\"{ENV_URL}/simulate\", \n",
104
- " json=[{\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \n",
105
- " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
106
- " timeout=10)\n",
107
- " sim = r.json()\n",
108
- " has_results = \"results\" in sim\n",
109
- " print(f\"✓ Simulate: has_results={has_results}\")\n",
110
- "except Exception as e:\n",
111
- " print(f\"✗ Simulate failed: {e}\")\n",
112
- "\n",
113
- "print(\"\\n✓ All connectivity checks passed!\")"
114
- ]
115
- },
116
- {
117
- "cell_type": "markdown",
118
- "id": "4a5b58c2",
119
- "metadata": {},
120
- "source": [
121
- "## Step 2: Measure Baseline Performance (Before Training)"
122
- ]
123
- },
124
- {
125
- "cell_type": "code",
126
- "execution_count": null,
127
- "id": "42cecadb",
128
- "metadata": {},
129
- "outputs": [],
130
- "source": [
131
- "import random\n",
132
- "\n",
133
- "def run_heuristic_episode(task_id=1, max_steps=96):\n",
134
- " \"\"\"Run an episode using a rule-based heuristic policy.\"\"\"\n",
135
- " try:\n",
136
- " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
137
- " obs_data = r.json()\n",
138
- " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
139
- " except:\n",
140
- " return 0.0\n",
141
- " \n",
142
- " for step in range(max_steps):\n",
143
- " # Simple heuristic: charge off-peak, discharge peak\n",
144
- " hour = step // 4\n",
145
- " hvac = 0.7 if 8 <= hour <= 18 else 0.3\n",
146
- " charge = 0.6 if hour < 6 else (-0.4 if 14 <= hour <= 18 else 0.0)\n",
147
- " shed = 0.3 if 14 <= hour <= 17 else 0.0\n",
148
- " \n",
149
- " action = {\n",
150
- " \"hvac_power_level\": hvac,\n",
151
- " \"thermal_charge_rate\": charge,\n",
152
- " \"batch_job_slot\": 1 if 22 <= hour or hour <= 5 else 0,\n",
153
- " \"load_shed_fraction\": shed,\n",
154
- " \"building_id\": 0\n",
155
- " }\n",
156
- " \n",
157
- " try:\n",
158
- " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
159
- " step_data = r.json()\n",
160
- " if isinstance(step_data, list):\n",
161
- " step_data = step_data[0]\n",
162
- " obs = step_data.get(\"observation\", obs)\n",
163
- " if step_data.get(\"done\", False):\n",
164
- " break\n",
165
- " except:\n",
166
- " break\n",
167
- " \n",
168
- " # Get final grade\n",
169
- " try:\n",
170
- " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
171
- " return float(grade.get(\"score\", 0))\n",
172
- " except:\n",
173
- " return 0.0\n",
174
- "\n",
175
- "print(\"Measuring heuristic baseline (2 episodes per task)...\")\n",
176
- "baseline_scores = {}\n",
177
- "for task_id in [1, 2, 3, 4]:\n",
178
- " scores = []\n",
179
- " for ep in range(2):\n",
180
- " score = run_heuristic_episode(task_id=task_id)\n",
181
- " scores.append(score)\n",
182
- " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
183
- " baseline_scores[task_id] = sum(scores) / len(scores)\n",
184
- "\n",
185
- "print(f\"\\nHeuristic Baseline Averages:\")\n",
186
- "for task_id, avg in baseline_scores.items():\n",
187
- " print(f\" Task {task_id}: {avg:.3f}\")\n",
188
- "print(f\" Overall: {sum(baseline_scores.values()) / len(baseline_scores):.3f}\")"
189
- ]
190
- },
191
- {
192
- "cell_type": "markdown",
193
- "id": "7abdd330",
194
- "metadata": {},
195
- "source": [
196
- "## Step 3: Build Multi-Theme Training Dataset"
197
- ]
198
- },
199
- {
200
- "cell_type": "code",
201
- "execution_count": null,
202
- "id": "1c496af9",
203
- "metadata": {},
204
- "outputs": [],
205
- "source": [
206
- "# Build a dataset that covers all 4 themes\n",
207
- "dataset = []\n",
208
- "\n",
209
- "# Theme 1: Multi-Agent (3 buildings cooperating)\n",
210
- "print(\"Building multi-agent theme examples...\")\n",
211
- "for i in range(20):\n",
212
- " try:\n",
213
- " resp = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10).json()\n",
214
- " if \"observations\" in resp:\n",
215
- " for b_idx, b_obs in enumerate(resp[\"observations\"]):\n",
216
- " prompt = f\"\"\"You control Building {b_idx} in a 3-building facility.\n",
217
- "All buildings share one grid connection (feeder limit: 250 kW).\n",
218
- "Your current state: temp={b_obs.get('indoor_temperature', 21):.1f}°C, \n",
219
- "storage={b_obs.get('thermal_storage_level', 0.5):.2f}, \n",
220
- "price=${b_obs.get('current_price', 0.1):.3f}/kWh\n",
221
- "Grid stress signal: {b_obs.get('grid_stress_signal', 0):.2f}\n",
222
- "\n",
223
- "You must coordinate with other buildings to keep total feeder load under 250 kW.\n",
224
- "Each building decides independently. Respond with your JSON action:\n",
225
- "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
226
- "\"load_shed_fraction\": <0-0.5>, \"building_id\": {b_idx}}}\"\"\"\n",
227
- " dataset.append({\"prompt\": prompt, \"theme\": \"multi_agent\"})\n",
228
- " except:\n",
229
- " pass\n",
230
- "\n",
231
- "print(f\"Multi-agent examples: {len([d for d in dataset if d.get('theme')=='multi_agent'])}\")\n",
232
- "\n",
233
- "# Theme 2: Instruction Following (Task 4 with explicit objectives)\n",
234
- "print(\"Building instruction-following theme examples...\")\n",
235
- "for i in range(20):\n",
236
- " try:\n",
237
- " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 4}, timeout=10).json()\n",
238
- " if \"observations\" in resp:\n",
239
- " obs = resp[\"observations\"][0]\n",
240
- " instruction = resp.get(\"instruction_card\", obs.get(\"instruction_card\", {}))\n",
241
- " instruction_text = instruction.get(\"text\", \"Minimize cost\") if isinstance(instruction, dict) else str(instruction)\n",
242
- " prompt = f\"\"\"INSTRUCTION CARD: {instruction_text}\n",
243
- "\n",
244
- "Current state: temp={obs.get('indoor_temperature', 21):.1f}°C, \n",
245
- "storage={obs.get('thermal_storage_level', 0.5):.2f}, \n",
246
- "cost_so_far=${obs.get('cumulative_cost', 0):.2f}, \n",
247
- "step={obs.get('step', 0)}/96\n",
248
- "\n",
249
- "You MUST satisfy the instruction. Output JSON action:\n",
250
- "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
251
- "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
252
- " dataset.append({\"prompt\": prompt, \"theme\": \"instruction_following\"})\n",
253
- " except:\n",
254
- " pass\n",
255
- "\n",
256
- "print(f\"Instruction-following examples: {len([d for d in dataset if d.get('theme')=='instruction_following'])}\")\n",
257
- "\n",
258
- "# Theme 3: World Modeling (use /simulate)\n",
259
- "print(\"Building world-modeling theme examples...\")\n",
260
- "for task_id in [1, 2]:\n",
261
- " for i in range(10):\n",
262
- " try:\n",
263
- " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10).json()\n",
264
- " if \"observations\" in resp:\n",
265
- " obs = resp[\"observations\"][0]\n",
266
- " # Simulate 2 candidate actions\n",
267
- " try:\n",
268
- " sim_a = requests.post(f\"{ENV_URL}/simulate\",\n",
269
- " json=[{\"hvac_power_level\": 0.8, \"thermal_charge_rate\": 0.3,\n",
270
- " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
271
- " timeout=10).json()\n",
272
- " sim_b = requests.post(f\"{ENV_URL}/simulate\",\n",
273
- " json=[{\"hvac_power_level\": 0.3, \"thermal_charge_rate\": -0.2,\n",
274
- " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.2, \"building_id\": 0}],\n",
275
- " timeout=10).json()\n",
276
- " sim_context = \"\\nPredicted outcomes:\\nOption A (high HVAC): efficient\\nOption B (low HVAC): economical\"\n",
277
- " except:\n",
278
- " sim_context = \"\"\n",
279
- " \n",
280
- " prompt = f\"\"\"Plan your actions using simulation of future outcomes.\n",
281
- "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f}{sim_context}\n",
282
- "\n",
283
- "Output your best JSON action:\n",
284
- "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
285
- "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
286
- " dataset.append({\"prompt\": prompt, \"theme\": \"world_modeling\"})\n",
287
- " except:\n",
288
- " pass\n",
289
- "\n",
290
- "print(f\"World-modeling examples: {len([d for d in dataset if d.get('theme')=='world_modeling'])}\")\n",
291
- "\n",
292
- "# Theme 4: Self-Improvement (curriculum across difficulties)\n",
293
- "print(\"Building self-improvement theme examples...\")\n",
294
- "for difficulty in [1, 1, 2, 2, 3, 3]:\n",
295
- " try:\n",
296
- " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": difficulty}, timeout=10).json()\n",
297
- " if \"observations\" in resp:\n",
298
- " obs = resp[\"observations\"][0]\n",
299
- " prompt = f\"\"\"Difficulty Level {difficulty}/3 - Control building energy system.\n",
300
- "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f},\n",
301
- "price=${obs.get('current_price', 0.1):.3f}/kWh\n",
302
- "\n",
303
- "Output JSON action:\n",
304
- "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
305
- "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
306
- " dataset.append({\"prompt\": prompt, \"theme\": \"curriculum\", \"difficulty\": difficulty})\n",
307
- " except:\n",
308
- " pass\n",
309
- "\n",
310
- "print(f\"Self-improvement examples: {len([d for d in dataset if d.get('theme')=='curriculum'])}\")\n",
311
- "\n",
312
- "print(f\"\\nTotal dataset: {len(dataset)} prompts\")\n",
313
- "theme_counts = {}\n",
314
- "for d in dataset:\n",
315
- " theme = d.get(\"theme\", \"unknown\")\n",
316
- " theme_counts[theme] = theme_counts.get(theme, 0) + 1\n",
317
- "print(f\"Theme distribution: {theme_counts}\")"
318
- ]
319
- },
320
- {
321
- "cell_type": "markdown",
322
- "id": "2ed46c06",
323
- "metadata": {},
324
- "source": [
325
- "## Step 4: Load Model and Tokenizer"
326
- ]
327
- },
328
- {
329
- "cell_type": "code",
330
- "execution_count": null,
331
- "id": "5e5826e4",
332
- "metadata": {},
333
- "outputs": [],
334
- "source": [
335
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
336
- "\n",
337
- "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
338
- "print(f\"Loading {MODEL_NAME}...\")\n",
339
- "\n",
340
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
341
- "if tokenizer.pad_token is None:\n",
342
- " tokenizer.pad_token = tokenizer.eos_token\n",
343
- "\n",
344
- "model = AutoModelForCausalLM.from_pretrained(\n",
345
- " MODEL_NAME,\n",
346
- " torch_dtype=torch.float16,\n",
347
- " device_map=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
348
- ")\n",
349
- "\n",
350
- "total_params = sum(p.numel() for p in model.parameters())\n",
351
- "print(f\"Model loaded. Parameters: {total_params/1e6:.0f}M\")\n",
352
- "print(f\"Device: {next(model.parameters()).device}\")"
353
- ]
354
- },
355
- {
356
- "cell_type": "markdown",
357
- "id": "ba6645a6",
358
- "metadata": {},
359
- "source": [
360
- "## Step 5: Define Reward Function"
361
- ]
362
- },
363
- {
364
- "cell_type": "code",
365
- "execution_count": null,
366
- "id": "02686008",
367
- "metadata": {},
368
- "outputs": [],
369
- "source": [
370
- "import json as _json\n",
371
- "\n",
372
- "training_rewards = []\n",
373
- "\n",
374
- "def gridmind_reward_fn(completions, **kwargs):\n",
375
- " \"\"\"Reward function that calls the real environment.\"\"\"\n",
376
- " rewards = []\n",
377
- " \n",
378
- " for completion in completions:\n",
379
- " try:\n",
380
- " # Extract JSON action from completion\n",
381
- " text = str(completion).strip()\n",
382
- " start = text.rfind('{')\n",
383
- " end = text.rfind('}') + 1\n",
384
- " if start < 0 or end <= start:\n",
385
- " rewards.append(-1.0)\n",
386
- " continue\n",
387
- " \n",
388
- " action_str = text[start:end]\n",
389
- " action = _json.loads(action_str)\n",
390
- " \n",
391
- " # Clamp action to valid ranges\n",
392
- " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
393
- " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
394
- " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
395
- " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
396
- " action[\"building_id\"] = int(action.get(\"building_id\", 0))\n",
397
- " \n",
398
- " # Call environment\n",
399
- " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
400
- " if r.status_code != 200:\n",
401
- " rewards.append(-0.5)\n",
402
- " continue\n",
403
- " \n",
404
- " step_data = r.json()\n",
405
- " if isinstance(step_data, list):\n",
406
- " step_data = step_data[0]\n",
407
- " \n",
408
- " reward = float(step_data.get(\"reward\", 0))\n",
409
- " rewards.append(max(-1.0, min(1.0, reward))) # Clamp to [-1, 1]\n",
410
- " training_rewards.append(reward)\n",
411
- " \n",
412
- " except Exception as e:\n",
413
- " rewards.append(-1.0)\n",
414
- " \n",
415
- " return rewards\n",
416
- "\n",
417
- "print(\"Reward function defined.\")"
418
- ]
419
- },
420
- {
421
- "cell_type": "markdown",
422
- "id": "adae3837",
423
- "metadata": {},
424
- "source": [
425
- "## Step 6: Configure and Run GRPO Training"
426
- ]
427
- },
428
- {
429
- "cell_type": "code",
430
- "execution_count": null,
431
- "id": "ceac8c9d",
432
- "metadata": {},
433
- "outputs": [],
434
- "source": [
435
- "from trl import GRPOTrainer, GRPOConfig\n",
436
- "from datasets import Dataset\n",
437
- "\n",
438
- "# Prepare dataset\n",
439
- "train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
440
- "train_ds = Dataset.from_list(train_data)\n",
441
- "\n",
442
- "print(f\"Training dataset: {len(train_ds)} prompts\")\n",
443
- "print(f\"Sample prompt:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
444
- "\n",
445
- "# GRPO config for free T4 GPU\n",
446
- "config = GRPOConfig(\n",
447
- " output_dir=\"./gridmind-grpo-output\",\n",
448
- " num_train_epochs=1,\n",
449
- " max_steps=60, # Complete in ~30-40 min on T4\n",
450
- " per_device_train_batch_size=2,\n",
451
- " gradient_accumulation_steps=2,\n",
452
- " max_new_tokens=100,\n",
453
- " max_prompt_length=512,\n",
454
- " learning_rate=5e-6,\n",
455
- " logging_steps=5,\n",
456
- " save_steps=60,\n",
457
- " fp16=True,\n",
458
- " dataloader_num_workers=0,\n",
459
- " report_to=\"none\",\n",
460
- " num_generations=2, # 2 generations per prompt for speed\n",
461
- ")\n",
462
- "\n",
463
- "print(\"\\nStarting GRPO training...\")\n",
464
- "print(f\"Estimated time: 30-40 minutes on Colab T4 GPU\")\n",
465
- "print(f\"Steps: {config.max_steps}, Batch size: {config.per_device_train_batch_size * config.gradient_accumulation_steps}\\n\")\n",
466
- "\n",
467
- "# Initialize trainer\n",
468
- "trainer = GRPOTrainer(\n",
469
- " model=model,\n",
470
- " tokenizer=tokenizer,\n",
471
- " config=config,\n",
472
- " train_dataset=train_ds,\n",
473
- " reward_funcs=gridmind_reward_fn,\n",
474
- ")\n",
475
- "\n",
476
- "# Train\n",
477
- "trainer.train()\n",
478
- "print(\"\\n✓ Training complete!\")"
479
- ]
480
- },
481
- {
482
- "cell_type": "markdown",
483
- "id": "c145c8c6",
484
- "metadata": {},
485
- "source": [
486
- "## Step 7: Evaluate Trained Model"
487
- ]
488
- },
489
- {
490
- "cell_type": "code",
491
- "execution_count": null,
492
- "id": "dac005cc",
493
- "metadata": {},
494
- "outputs": [],
495
- "source": [
496
- "def run_llm_episode(task_id=1, max_steps=96):\n",
497
- " \"\"\"Run an episode using the trained LLM.\"\"\"\n",
498
- " try:\n",
499
- " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
500
- " obs_data = r.json()\n",
501
- " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
502
- " except:\n",
503
- " return 0.0\n",
504
- " \n",
505
- " model.eval()\n",
506
- " \n",
507
- " for step in range(max_steps):\n",
508
- " prompt = f\"\"\"Control industrial building energy system.\n",
509
- "State: temp={obs.get('indoor_temperature', 21):.1f}°C, storage={obs.get('thermal_storage_level', 0.5):.2f}\n",
510
- "Output JSON action (hvac_power_level 0-1, thermal_charge_rate -1 to 1, batch_job_slot 0-4,\n",
511
- "load_shed_fraction 0-0.5, building_id 0):\"\"\"\n",
512
- " \n",
513
- " try:\n",
514
- " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=400).to(model.device)\n",
515
- " with torch.no_grad():\n",
516
- " outputs = model.generate(**inputs, max_new_tokens=80, do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
517
- " generated = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
518
- " \n",
519
- " start = generated.rfind('{')\n",
520
- " end = generated.rfind('}') + 1\n",
521
- " if start >= 0 and end > start:\n",
522
- " action = _json.loads(generated[start:end])\n",
523
- " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
524
- " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
525
- " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
526
- " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
527
- " action[\"building_id\"] = 0\n",
528
- " else:\n",
529
- " action = {\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \"batch_job_slot\": 0,\n",
530
- " \"load_shed_fraction\": 0.0, \"building_id\": 0}\n",
531
- " \n",
532
- " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
533
- " step_data = r.json()\n",
534
- " if isinstance(step_data, list):\n",
535
- " step_data = step_data[0]\n",
536
- " obs = step_data.get(\"observation\", obs)\n",
537
- " if step_data.get(\"done\", False):\n",
538
- " break\n",
539
- " except:\n",
540
- " break\n",
541
- " \n",
542
- " try:\n",
543
- " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
544
- " return float(grade.get(\"score\", 0))\n",
545
- " except:\n",
546
- " return 0.0\n",
547
- "\n",
548
- "print(\"Evaluating trained model (2 episodes per task)...\")\n",
549
- "trained_scores = {}\n",
550
- "for task_id in [1, 2, 3, 4]:\n",
551
- " scores = []\n",
552
- " for ep in range(2):\n",
553
- " score = run_llm_episode(task_id=task_id)\n",
554
- " scores.append(score)\n",
555
- " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
556
- " trained_scores[task_id] = sum(scores) / len(scores)\n",
557
- "\n",
558
- "print(f\"\\nTrained Model Scores:\")\n",
559
- "for task_id, avg in trained_scores.items():\n",
560
- " baseline = baseline_scores[task_id]\n",
561
- " improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
562
- " print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
563
- "\n",
564
- "trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
565
- "baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
566
- "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
567
- "\n",
568
- "print(f\"\\nOverall Scores:\")\n",
569
- "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
570
- "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
571
- "print(f\" Improvement: {overall_improvement:+.1f}%\")"
572
- ]
573
- },
574
- {
575
- "cell_type": "markdown",
576
- "id": "0f955e71",
577
- "metadata": {},
578
- "source": [
579
- "## Step 8: Save Results"
580
- ]
581
- },
582
- {
583
- "cell_type": "code",
584
- "execution_count": null,
585
- "id": "00844cb1",
586
- "metadata": {},
587
- "outputs": [],
588
- "source": [
589
- "results = {\n",
590
- " \"heuristic_baseline\": {\n",
591
- " \"scores_by_task\": {str(k): v for k, v in baseline_scores.items()},\n",
592
- " \"average\": baseline_avg\n",
593
- " },\n",
594
- " \"trained_llm\": {\n",
595
- " \"scores_by_task\": {str(k): v for k, v in trained_scores.items()},\n",
596
- " \"average\": trained_avg\n",
597
- " },\n",
598
- " \"improvement_percent\": overall_improvement,\n",
599
- " \"model\": MODEL_NAME,\n",
600
- " \"training_steps\": config.max_steps,\n",
601
- " \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
602
- " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
603
- "}\n",
604
- "\n",
605
- "print(\"Saving results...\")\n",
606
- "with open(\"gridmind_training_results.json\", \"w\") as f:\n",
607
- " _json.dump(results, f, indent=2)\n",
608
- "\n",
609
- "print(\"✓ Results saved to gridmind_training_results.json\")\n",
610
- "print(f\"\\nSummary:\")\n",
611
- "print(f\" Model: {MODEL_NAME}\")\n",
612
- "print(f\" Themes: {results['themes_covered']}\")\n",
613
- "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
614
- "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
615
- "print(f\" Improvement: {overall_improvement:+.1f}%\")"
616
- ]
617
- }
618
- ],
619
- "metadata": {
620
- "language_info": {
621
- "name": "python"
622
- }
623
- },
624
- "nbformat": 4,
625
- "nbformat_minor": 5
626
- }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "193da661",
6
+ "metadata": {},
7
+ "source": [
8
+ "# GridMind-RL: GRPO Training for Industrial Energy Management\n",
9
+ "\n",
10
+ "**Meta PyTorch OpenEnv Hackathon \u00e2\u20ac\u201d GridMind-RL Team**\n",
11
+ "\n",
12
+ "This notebook trains a small LLM (Qwen2.5-1.5B) using TRL GRPO on the GridMind-RL environment.\n",
13
+ "The environment covers all 4 hackathon themes:\n",
14
+ "\n",
15
+ "1. **Theme 1: Multi-Agent** \u00e2\u20ac\u201d 3 buildings share a grid feeder; each agent makes independent decisions\n",
16
+ "2. **Theme 2: Instruction Following** \u00e2\u20ac\u201d Task 4 provides natural language objectives that must be satisfied\n",
17
+ "3. **Theme 3: World Modeling** \u00e2\u20ac\u201d `/simulate` endpoint predicts outcomes before committing actions\n",
18
+ "4. **Theme 4: Self-Improvement** \u00e2\u20ac\u201d Curriculum automatically advances difficulty as agent performance improves\n",
19
+ "\n",
20
+ "| | |\n",
21
+ "|---|---|\n",
22
+ "| **Environment** | https://lo-kyu-gridmind.hf.space |\n",
23
+ "| **Method** | GRPO (Group Relative Policy Optimization) |\n",
24
+ "| **Model** | Qwen2.5-1.5B-Instruct |\n",
25
+ "| **Training Time** | ~30-40 minutes on free Colab T4 GPU |\n",
26
+ "| **Expected Improvement** | 20-40% score gain over heuristic baseline |"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "f28e2f2c",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "# Install dependencies\n",
37
+ "!pip install trl==0.8.6 transformers>=4.41.0 torch accelerate datasets requests -q\n",
38
+ "\n",
39
+ "import torch\n",
40
+ "import sys\n",
41
+ "\n",
42
+ "print(f\"PyTorch: {torch.__version__}\")\n",
43
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
44
+ "if torch.cuda.is_available():\n",
45
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
46
+ " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "5021a299",
52
+ "metadata": {},
53
+ "source": [
54
+ "## Step 1: Connect to Environment and Verify Connectivity"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "4cdf0f35",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "import requests\n",
65
+ "import json\n",
66
+ "import time\n",
67
+ "\n",
68
+ "ENV_URL = \"https://lo-kyu-gridmind.hf.space\"\n",
69
+ "\n",
70
+ "# Test connectivity\n",
71
+ "print(\"Testing environment connectivity...\")\n",
72
+ "try:\n",
73
+ " health = requests.get(f\"{ENV_URL}/health\", timeout=10).json()\n",
74
+ " print(f\"\u00e2\u0153\u201c Health check: {health}\")\n",
75
+ "except Exception as e:\n",
76
+ " print(f\"\u00e2\u0153\u2014 Health check failed: {e}\")\n",
77
+ " sys.exit(1)\n",
78
+ "\n",
79
+ "# Test each task reset\n",
80
+ "print(\"\\nTesting all 4 tasks...\")\n",
81
+ "for task_id in [1, 2, 3, 4]:\n",
82
+ " try:\n",
83
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
84
+ " obs = r.json()\n",
85
+ " has_card = \"instruction_card\" in obs or \"observations\" in obs and obs[\"observations\"][0].get(\"instruction_card\")\n",
86
+ " print(f\"\u00e2\u0153\u201c Task {task_id}: status={r.status_code}, has_instruction_card={has_card}\")\n",
87
+ " except Exception as e:\n",
88
+ " print(f\"\u00e2\u0153\u2014 Task {task_id} failed: {e}\")\n",
89
+ "\n",
90
+ "# Test coordinator (multi-agent)\n",
91
+ "print(\"\\nTesting multi-agent coordinator...\")\n",
92
+ "try:\n",
93
+ " r = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10)\n",
94
+ " obs = r.json()\n",
95
+ " n_buildings = len(obs.get(\"observations\", []))\n",
96
+ " print(f\"\u00e2\u0153\u201c Coordinator reset: {n_buildings} buildings\")\n",
97
+ "except Exception as e:\n",
98
+ " print(f\"\u00e2\u0153\u2014 Coordinator failed: {e}\")\n",
99
+ "\n",
100
+ "# Test world modeling\n",
101
+ "print(\"\\nTesting world modeling (/simulate)...\")\n",
102
+ "try:\n",
103
+ " r = requests.post(f\"{ENV_URL}/simulate\", \n",
104
+ " json=[{\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \n",
105
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
106
+ " timeout=10)\n",
107
+ " sim = r.json()\n",
108
+ " has_results = \"results\" in sim\n",
109
+ " print(f\"\u00e2\u0153\u201c Simulate: has_results={has_results}\")\n",
110
+ "except Exception as e:\n",
111
+ " print(f\"\u00e2\u0153\u2014 Simulate failed: {e}\")\n",
112
+ "\n",
113
+ "print(\"\\n\u00e2\u0153\u201c All connectivity checks passed!\")"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "4a5b58c2",
119
+ "metadata": {},
120
+ "source": [
121
+ "## Step 2: Measure Baseline Performance (Before Training)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "42cecadb",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "import random\n",
132
+ "\n",
133
+ "def run_heuristic_episode(task_id=1, max_steps=96):\n",
134
+ " \"\"\"Run an episode using a rule-based heuristic policy.\"\"\"\n",
135
+ " try:\n",
136
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
137
+ " obs_data = r.json()\n",
138
+ " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
139
+ " except:\n",
140
+ " return 0.0\n",
141
+ " \n",
142
+ " for step in range(max_steps):\n",
143
+ " # Simple heuristic: charge off-peak, discharge peak\n",
144
+ " hour = step // 4\n",
145
+ " hvac = 0.7 if 8 <= hour <= 18 else 0.3\n",
146
+ " charge = 0.6 if hour < 6 else (-0.4 if 14 <= hour <= 18 else 0.0)\n",
147
+ " shed = 0.3 if 14 <= hour <= 17 else 0.0\n",
148
+ " \n",
149
+ " action = {\n",
150
+ " \"hvac_power_level\": hvac,\n",
151
+ " \"thermal_charge_rate\": charge,\n",
152
+ " \"batch_job_slot\": 1 if 22 <= hour or hour <= 5 else 0,\n",
153
+ " \"load_shed_fraction\": shed,\n",
154
+ " \"building_id\": 0\n",
155
+ " }\n",
156
+ " \n",
157
+ " try:\n",
158
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
159
+ " step_data = r.json()\n",
160
+ " if isinstance(step_data, list):\n",
161
+ " step_data = step_data[0]\n",
162
+ " obs = step_data.get(\"observation\", obs)\n",
163
+ " if step_data.get(\"done\", False):\n",
164
+ " break\n",
165
+ " except:\n",
166
+ " break\n",
167
+ " \n",
168
+ " # Get final grade\n",
169
+ " try:\n",
170
+ " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
171
+ " return float(grade.get(\"score\", 0))\n",
172
+ " except:\n",
173
+ " return 0.0\n",
174
+ "\n",
175
+ "print(\"Measuring heuristic baseline (2 episodes per task)...\")\n",
176
+ "baseline_scores = {}\n",
177
+ "for task_id in [1, 2, 3, 4]:\n",
178
+ " scores = []\n",
179
+ " for ep in range(2):\n",
180
+ " score = run_heuristic_episode(task_id=task_id)\n",
181
+ " scores.append(score)\n",
182
+ " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
183
+ " baseline_scores[task_id] = sum(scores) / len(scores)\n",
184
+ "\n",
185
+ "print(f\"\\nHeuristic Baseline Averages:\")\n",
186
+ "for task_id, avg in baseline_scores.items():\n",
187
+ " print(f\" Task {task_id}: {avg:.3f}\")\n",
188
+ "print(f\" Overall: {sum(baseline_scores.values()) / len(baseline_scores):.3f}\")"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "markdown",
193
+ "id": "7abdd330",
194
+ "metadata": {},
195
+ "source": [
196
+ "## Step 3: Build Multi-Theme Training Dataset"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "id": "1c496af9",
203
+ "metadata": {},
204
+ "outputs": [],
205
+ "source": [
206
+ "# Build a dataset that covers all 4 themes\n",
207
+ "dataset = []\n",
208
+ "\n",
209
+ "# Theme 1: Multi-Agent (3 buildings cooperating)\n",
210
+ "print(\"Building multi-agent theme examples...\")\n",
211
+ "for i in range(20):\n",
212
+ " try:\n",
213
+ " resp = requests.post(f\"{ENV_URL}/coordinator/reset\", json={}, timeout=10).json()\n",
214
+ " if \"observations\" in resp:\n",
215
+ " for b_idx, b_obs in enumerate(resp[\"observations\"]):\n",
216
+ " prompt = f\"\"\"You control Building {b_idx} in a 3-building facility.\n",
217
+ "All buildings share one grid connection (feeder limit: 250 kW).\n",
218
+ "Your current state: temp={b_obs.get('indoor_temperature', 21):.1f}\u00c2\u00b0C, \n",
219
+ "storage={b_obs.get('thermal_storage_level', 0.5):.2f}, \n",
220
+ "price=${b_obs.get('current_price', 0.1):.3f}/kWh\n",
221
+ "Grid stress signal: {b_obs.get('grid_stress_signal', 0):.2f}\n",
222
+ "\n",
223
+ "You must coordinate with other buildings to keep total feeder load under 250 kW.\n",
224
+ "Each building decides independently. Respond with your JSON action:\n",
225
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
226
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": {b_idx}}}\"\"\"\n",
227
+ " dataset.append({\"prompt\": prompt, \"theme\": \"multi_agent\"})\n",
228
+ " except:\n",
229
+ " pass\n",
230
+ "\n",
231
+ "print(f\"Multi-agent examples: {len([d for d in dataset if d.get('theme')=='multi_agent'])}\")\n",
232
+ "\n",
233
+ "# Theme 2: Instruction Following (Task 4 with explicit objectives)\n",
234
+ "print(\"Building instruction-following theme examples...\")\n",
235
+ "for i in range(20):\n",
236
+ " try:\n",
237
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 4}, timeout=10).json()\n",
238
+ " if \"observations\" in resp:\n",
239
+ " obs = resp[\"observations\"][0]\n",
240
+ " instruction = resp.get(\"instruction_card\", obs.get(\"instruction_card\", {}))\n",
241
+ " instruction_text = instruction.get(\"text\", \"Minimize cost\") if isinstance(instruction, dict) else str(instruction)\n",
242
+ " prompt = f\"\"\"INSTRUCTION CARD: {instruction_text}\n",
243
+ "\n",
244
+ "Current state: temp={obs.get('indoor_temperature', 21):.1f}\u00c2\u00b0C, \n",
245
+ "storage={obs.get('thermal_storage_level', 0.5):.2f}, \n",
246
+ "cost_so_far=${obs.get('cumulative_cost', 0):.2f}, \n",
247
+ "step={obs.get('step', 0)}/96\n",
248
+ "\n",
249
+ "You MUST satisfy the instruction. Output JSON action:\n",
250
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
251
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
252
+ " dataset.append({\"prompt\": prompt, \"theme\": \"instruction_following\"})\n",
253
+ " except:\n",
254
+ " pass\n",
255
+ "\n",
256
+ "print(f\"Instruction-following examples: {len([d for d in dataset if d.get('theme')=='instruction_following'])}\")\n",
257
+ "\n",
258
+ "# Theme 3: World Modeling (use /simulate)\n",
259
+ "print(\"Building world-modeling theme examples...\")\n",
260
+ "for task_id in [1, 2]:\n",
261
+ " for i in range(10):\n",
262
+ " try:\n",
263
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10).json()\n",
264
+ " if \"observations\" in resp:\n",
265
+ " obs = resp[\"observations\"][0]\n",
266
+ " # Simulate 2 candidate actions\n",
267
+ " try:\n",
268
+ " sim_a = requests.post(f\"{ENV_URL}/simulate\",\n",
269
+ " json=[{\"hvac_power_level\": 0.8, \"thermal_charge_rate\": 0.3,\n",
270
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}],\n",
271
+ " timeout=10).json()\n",
272
+ " sim_b = requests.post(f\"{ENV_URL}/simulate\",\n",
273
+ " json=[{\"hvac_power_level\": 0.3, \"thermal_charge_rate\": -0.2,\n",
274
+ " \"batch_job_slot\": 0, \"load_shed_fraction\": 0.2, \"building_id\": 0}],\n",
275
+ " timeout=10).json()\n",
276
+ " sim_context = \"\\nPredicted outcomes:\\nOption A (high HVAC): efficient\\nOption B (low HVAC): economical\"\n",
277
+ " except:\n",
278
+ " sim_context = \"\"\n",
279
+ " \n",
280
+ " prompt = f\"\"\"Plan your actions using simulation of future outcomes.\n",
281
+ "State: temp={obs.get('indoor_temperature', 21):.1f}\u00c2\u00b0C, storage={obs.get('thermal_storage_level', 0.5):.2f}{sim_context}\n",
282
+ "\n",
283
+ "Output your best JSON action:\n",
284
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
285
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
286
+ " dataset.append({\"prompt\": prompt, \"theme\": \"world_modeling\"})\n",
287
+ " except:\n",
288
+ " pass\n",
289
+ "\n",
290
+ "print(f\"World-modeling examples: {len([d for d in dataset if d.get('theme')=='world_modeling'])}\")\n",
291
+ "\n",
292
+ "# Theme 4: Self-Improvement (curriculum across difficulties)\n",
293
+ "print(\"Building self-improvement theme examples...\")\n",
294
+ "for difficulty in [1, 1, 2, 2, 3, 3]:\n",
295
+ " try:\n",
296
+ " resp = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": difficulty}, timeout=10).json()\n",
297
+ " if \"observations\" in resp:\n",
298
+ " obs = resp[\"observations\"][0]\n",
299
+ " prompt = f\"\"\"Difficulty Level {difficulty}/3 - Control building energy system.\n",
300
+ "State: temp={obs.get('indoor_temperature', 21):.1f}\u00c2\u00b0C, storage={obs.get('thermal_storage_level', 0.5):.2f},\n",
301
+ "price=${obs.get('current_price', 0.1):.3f}/kWh\n",
302
+ "\n",
303
+ "Output JSON action:\n",
304
+ "{{\"hvac_power_level\": <0-1>, \"thermal_charge_rate\": <-1 to 1>, \"batch_job_slot\": <0-4>, \n",
305
+ "\"load_shed_fraction\": <0-0.5>, \"building_id\": 0}}\"\"\"\n",
306
+ " dataset.append({\"prompt\": prompt, \"theme\": \"curriculum\", \"difficulty\": difficulty})\n",
307
+ " except:\n",
308
+ " pass\n",
309
+ "\n",
310
+ "print(f\"Self-improvement examples: {len([d for d in dataset if d.get('theme')=='curriculum'])}\")\n",
311
+ "\n",
312
+ "print(f\"\\nTotal dataset: {len(dataset)} prompts\")\n",
313
+ "theme_counts = {}\n",
314
+ "for d in dataset:\n",
315
+ " theme = d.get(\"theme\", \"unknown\")\n",
316
+ " theme_counts[theme] = theme_counts.get(theme, 0) + 1\n",
317
+ "print(f\"Theme distribution: {theme_counts}\")"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "markdown",
322
+ "id": "2ed46c06",
323
+ "metadata": {},
324
+ "source": [
325
+ "## Step 4: Load Model and Tokenizer"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "5e5826e4",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
336
+ "\n",
337
+ "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
338
+ "print(f\"Loading {MODEL_NAME}...\")\n",
339
+ "\n",
340
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
341
+ "if tokenizer.pad_token is None:\n",
342
+ " tokenizer.pad_token = tokenizer.eos_token\n",
343
+ "\n",
344
+ "model = AutoModelForCausalLM.from_pretrained(\n",
345
+ " MODEL_NAME,\n",
346
+ " torch_dtype=torch.float16,\n",
347
+ " device_map=\"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
348
+ ")\n",
349
+ "\n",
350
+ "total_params = sum(p.numel() for p in model.parameters())\n",
351
+ "print(f\"Model loaded. Parameters: {total_params/1e6:.0f}M\")\n",
352
+ "print(f\"Device: {next(model.parameters()).device}\")"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "id": "ba6645a6",
358
+ "metadata": {},
359
+ "source": [
360
+ "## Step 5: Define Reward Function"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "02686008",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "import json as _json\n",
371
+ "\n",
372
+ "training_rewards = []\n",
373
+ "\n",
374
+ "def gridmind_reward_fn(completions, **kwargs):\n",
375
+ " \"\"\"Reward function that calls the real environment.\"\"\"\n",
376
+ " rewards = []\n",
377
+ " \n",
378
+ " for completion in completions:\n",
379
+ " try:\n",
380
+ " # Extract JSON action from completion\n",
381
+ " text = str(completion).strip()\n",
382
+ " start = text.rfind('{')\n",
383
+ " end = text.rfind('}') + 1\n",
384
+ " if start < 0 or end <= start:\n",
385
+ " rewards.append(-1.0)\n",
386
+ " continue\n",
387
+ " \n",
388
+ " action_str = text[start:end]\n",
389
+ " action = _json.loads(action_str)\n",
390
+ " \n",
391
+ " # Clamp action to valid ranges\n",
392
+ " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
393
+ " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
394
+ " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
395
+ " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
396
+ " action[\"building_id\"] = int(action.get(\"building_id\", 0))\n",
397
+ " \n",
398
+ " # Call environment\n",
399
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
400
+ " if r.status_code != 200:\n",
401
+ " rewards.append(-0.5)\n",
402
+ " continue\n",
403
+ " \n",
404
+ " step_data = r.json()\n",
405
+ " if isinstance(step_data, list):\n",
406
+ " step_data = step_data[0]\n",
407
+ " \n",
408
+ " reward = float(step_data.get(\"reward\", 0))\n",
409
+ " rewards.append(max(-1.0, min(1.0, reward))) # Clamp to [-1, 1]\n",
410
+ " training_rewards.append(reward)\n",
411
+ " \n",
412
+ " except Exception as e:\n",
413
+ " rewards.append(-1.0)\n",
414
+ " \n",
415
+ " return rewards\n",
416
+ "\n",
417
+ "print(\"Reward function defined.\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "markdown",
422
+ "id": "adae3837",
423
+ "metadata": {},
424
+ "source": [
425
+ "## Step 6: Configure and Run GRPO Training"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "id": "ceac8c9d",
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "from trl import GRPOTrainer, GRPOConfig\n",
436
+ "from datasets import Dataset\n",
437
+ "\n",
438
+ "# Prepare dataset\n",
439
+ "train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
440
+ "train_ds = Dataset.from_list(train_data)\n",
441
+ "\n",
442
+ "print(f\"Training dataset: {len(train_ds)} prompts\")\n",
443
+ "print(f\"Sample prompt:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
444
+ "\n",
445
+ "# GRPO config for free T4 GPU\n",
446
+ "config = GRPOConfig(\n",
447
+ " output_dir=\"./gridmind-grpo-output\",\n",
448
+ " num_train_epochs=1,\n",
449
+ " max_steps=60, # Complete in ~30-40 min on T4\n",
450
+ " per_device_train_batch_size=2,\n",
451
+ " gradient_accumulation_steps=2,\n",
452
+ " max_new_tokens=100,\n",
453
+ " max_prompt_length=512,\n",
454
+ " learning_rate=5e-6,\n",
455
+ " logging_steps=5,\n",
456
+ " save_steps=60,\n",
457
+ " fp16=True,\n",
458
+ " dataloader_num_workers=0,\n",
459
+ " report_to=\"none\",\n",
460
+ " num_generations=2, # 2 generations per prompt for speed\n",
461
+ ")\n",
462
+ "\n",
463
+ "print(\"\\nStarting GRPO training...\")\n",
464
+ "print(f\"Estimated time: 30-40 minutes on Colab T4 GPU\")\n",
465
+ "print(f\"Steps: {config.max_steps}, Batch size: {config.per_device_train_batch_size * config.gradient_accumulation_steps}\\n\")\n",
466
+ "\n",
467
+ "# Initialize trainer\n",
468
+ "trainer = GRPOTrainer(\n",
469
+ " model=model,\n",
470
+ " tokenizer=tokenizer,\n",
471
+ " config=config,\n",
472
+ " train_dataset=train_ds,\n",
473
+ " reward_funcs=gridmind_reward_fn,\n",
474
+ ")\n",
475
+ "\n",
476
+ "# Train\n",
477
+ "trainer.train()\n",
478
+ "print(\"\\n\u00e2\u0153\u201c Training complete!\")"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "markdown",
483
+ "id": "c145c8c6",
484
+ "metadata": {},
485
+ "source": [
486
+ "## Step 7: Evaluate Trained Model"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "id": "dac005cc",
493
+ "metadata": {},
494
+ "outputs": [],
495
+ "source": [
496
+ "def run_llm_episode(task_id=1, max_steps=96):\n",
497
+ " \"\"\"Run an episode using the trained LLM.\"\"\"\n",
498
+ " try:\n",
499
+ " r = requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
500
+ " obs_data = r.json()\n",
501
+ " obs = obs_data[\"observations\"][0] if \"observations\" in obs_data else obs_data\n",
502
+ " except:\n",
503
+ " return 0.0\n",
504
+ " \n",
505
+ " model.eval()\n",
506
+ " \n",
507
+ " for step in range(max_steps):\n",
508
+ " prompt = f\"\"\"Control industrial building energy system.\n",
509
+ "State: temp={obs.get('indoor_temperature', 21):.1f}\u00c2\u00b0C, storage={obs.get('thermal_storage_level', 0.5):.2f}\n",
510
+ "Output JSON action (hvac_power_level 0-1, thermal_charge_rate -1 to 1, batch_job_slot 0-4,\n",
511
+ "load_shed_fraction 0-0.5, building_id 0):\"\"\"\n",
512
+ " \n",
513
+ " try:\n",
514
+ " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=400).to(model.device)\n",
515
+ " with torch.no_grad():\n",
516
+ " outputs = model.generate(**inputs, max_new_tokens=80, do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
517
+ " generated = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)\n",
518
+ " \n",
519
+ " start = generated.rfind('{')\n",
520
+ " end = generated.rfind('}') + 1\n",
521
+ " if start >= 0 and end > start:\n",
522
+ " action = _json.loads(generated[start:end])\n",
523
+ " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
524
+ " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
525
+ " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
526
+ " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
527
+ " action[\"building_id\"] = 0\n",
528
+ " else:\n",
529
+ " action = {\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \"batch_job_slot\": 0,\n",
530
+ " \"load_shed_fraction\": 0.0, \"building_id\": 0}\n",
531
+ " \n",
532
+ " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
533
+ " step_data = r.json()\n",
534
+ " if isinstance(step_data, list):\n",
535
+ " step_data = step_data[0]\n",
536
+ " obs = step_data.get(\"observation\", obs)\n",
537
+ " if step_data.get(\"done\", False):\n",
538
+ " break\n",
539
+ " except:\n",
540
+ " break\n",
541
+ " \n",
542
+ " try:\n",
543
+ " grade = requests.get(f\"{ENV_URL}/grade\", timeout=10).json()\n",
544
+ " return float(grade.get(\"score\", 0))\n",
545
+ " except:\n",
546
+ " return 0.0\n",
547
+ "\n",
548
+ "print(\"Evaluating trained model (2 episodes per task)...\")\n",
549
+ "trained_scores = {}\n",
550
+ "for task_id in [1, 2, 3, 4]:\n",
551
+ " scores = []\n",
552
+ " for ep in range(2):\n",
553
+ " score = run_llm_episode(task_id=task_id)\n",
554
+ " scores.append(score)\n",
555
+ " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
556
+ " trained_scores[task_id] = sum(scores) / len(scores)\n",
557
+ "\n",
558
+ "print(f\"\\nTrained Model Scores:\")\n",
559
+ "for task_id, avg in trained_scores.items():\n",
560
+ " baseline = baseline_scores[task_id]\n",
561
+ " improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
562
+ " print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
563
+ "\n",
564
+ "trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
565
+ "baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
566
+ "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
567
+ "\n",
568
+ "print(f\"\\nOverall Scores:\")\n",
569
+ "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
570
+ "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
571
+ "print(f\" Improvement: {overall_improvement:+.1f}%\")"
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "markdown",
576
+ "id": "0f955e71",
577
+ "metadata": {},
578
+ "source": [
579
+ "## Step 8: Save Results"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "code",
584
+ "execution_count": null,
585
+ "id": "00844cb1",
586
+ "metadata": {},
587
+ "outputs": [],
588
+ "source": [
589
+ "results = {\n",
590
+ " \"heuristic_baseline\": {\n",
591
+ " \"scores_by_task\": {str(k): v for k, v in baseline_scores.items()},\n",
592
+ " \"average\": baseline_avg\n",
593
+ " },\n",
594
+ " \"trained_llm\": {\n",
595
+ " \"scores_by_task\": {str(k): v for k, v in trained_scores.items()},\n",
596
+ " \"average\": trained_avg\n",
597
+ " },\n",
598
+ " \"improvement_percent\": overall_improvement,\n",
599
+ " \"model\": MODEL_NAME,\n",
600
+ " \"training_steps\": config.max_steps,\n",
601
+ " \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
602
+ " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
603
+ "}\n",
604
+ "\n",
605
+ "print(\"Saving results...\")\n",
606
+ "with open(\"gridmind_training_results.json\", \"w\") as f:\n",
607
+ " _json.dump(results, f, indent=2)\n",
608
+ "\n",
609
+ "print(\"\u00e2\u0153\u201c Results saved to gridmind_training_results.json\")\n",
610
+ "print(f\"\\nSummary:\")\n",
611
+ "print(f\" Model: {MODEL_NAME}\")\n",
612
+ "print(f\" Themes: {results['themes_covered']}\")\n",
613
+ "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
614
+ "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
615
+ "print(f\" Improvement: {overall_improvement:+.1f}%\")"
616
+ ]
617
+ }
618
+ ],
619
+ "metadata": {
620
+ "language_info": {
621
+ "name": "python"
622
+ }
623
  },
624
+ "nbformat": 4,
625
+ "nbformat_minor": 5
626
+ }