adityss committed on
Commit
27d3504
·
1 Parent(s): 19ba2eb

feat: implement Unsloth GRPO training pipeline with environment-backed reward functions and balanced dataset generation

Browse files
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -272,88 +272,113 @@
272
  "outputs": [],
273
  "source": [
274
  "import json as _json\n",
275
- "import requests as _requests\n",
276
- "import random as _random\n",
277
  "import math as _math\n",
 
 
 
 
278
  "\n",
 
279
  "call_count = [0]\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  "\n",
281
  "def gridmind_reward_fn(completions, **kwargs):\n",
282
  " \"\"\"\n",
283
- " Reward function for GridMind-RL GRPO training.\n",
284
- " - Parses JSON action from LLM output\n",
285
- " - Executes against environment\n",
286
- " - Returns normalized reward signal\n",
287
  " \"\"\"\n",
288
  " rewards = []\n",
289
- " task_id = _random.choice([1, 2, 3, 4])\n",
290
  "\n",
291
- " try:\n",
292
- " _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=10)\n",
293
- " except:\n",
294
- " return [-0.1] * len(completions)\n",
295
  "\n",
296
- " for completion in completions:\n",
297
  " try:\n",
298
- " text = str(completion[0]) if isinstance(completion, list) and completion else str(completion)\n",
299
- " text = text.strip()\n",
300
- "\n",
301
- " # Extract JSON from completion\n",
302
- " start = text.rfind('{')\n",
303
- " end = text.rfind('}') + 1\n",
304
- " if start < 0 or end <= start:\n",
305
- " rewards.append(-0.3)\n",
306
- " try:\n",
307
- " _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=6)\n",
308
- " except:\n",
309
- " pass\n",
310
- " continue\n",
 
311
  "\n",
312
- " try:\n",
313
- " action = _json.loads(text[start:end])\n",
314
- " except _json.JSONDecodeError:\n",
315
- " rewards.append(-0.2)\n",
316
- " try:\n",
317
- " _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=6)\n",
318
- " except:\n",
319
- " pass\n",
320
- " continue\n",
321
- "\n",
322
- " # Validate and clamp action fields\n",
323
- " cleaned = {\n",
324
- " \"hvac_power_level\": max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5)))),\n",
325
- " \"thermal_charge_rate\": max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0)))),\n",
326
- " \"batch_job_slot\": max(0, min(4, int(action.get(\"batch_job_slot\", 0)))),\n",
327
- " \"load_shed_fraction\": max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0)))),\n",
328
- " \"building_id\": int(action.get(\"building_id\", 0)),\n",
329
- " }\n",
330
  "\n",
331
- " try:\n",
332
- " step_r = _requests.post(f\"{ENV_URL}/step\", json=cleaned, timeout=8)\n",
333
- " data = step_r.json()\n",
 
 
 
 
334
  " if isinstance(data, list):\n",
335
  " data = data[0]\n",
336
- " env_reward = float(data.get(\"reward\", 0.0))\n",
337
- " reward_signal = _math.tanh(env_reward * 1.5) * 0.5\n",
338
- " rewards.append(reward_signal)\n",
339
- " except:\n",
340
- " rewards.append(-0.15)\n",
341
- "\n",
342
- " try:\n",
343
- " _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=6)\n",
344
- " except:\n",
345
- " pass\n",
346
- "\n",
 
 
 
 
347
  " except Exception:\n",
348
- " rewards.append(-0.15)\n",
349
  "\n",
350
- " call_count[0] += 1\n",
351
- " if call_count[0] % 5 == 0:\n",
352
- " print(f\" Step {call_count[0]}: Avg reward = {sum(rewards)/len(rewards):+.3f}\")\n",
 
353
  "\n",
354
  " return rewards\n",
355
  "\n",
356
- "print(\"Reward function ready\")"
357
  ]
358
  },
359
  {
@@ -373,6 +398,7 @@
373
  "source": [
374
  "from trl import GRPOTrainer, GRPOConfig\n",
375
  "from peft import LoraConfig, prepare_model_for_kbit_training\n",
 
376
  "import inspect\n",
377
  "import os\n",
378
  "\n",
@@ -384,30 +410,106 @@
384
  "peft_config = LoraConfig(\n",
385
  " r=16,\n",
386
  " lora_alpha=32,\n",
387
- " target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
388
  " lora_dropout=0.05,\n",
389
  " bias=\"none\",\n",
390
  " task_type=\"CAUSAL_LM\",\n",
391
  ")\n",
392
  "\n",
393
- "# Configure GRPO training\n",
394
- "# Note: Disable AMP (fp16/bf16) when using quantized models to avoid gradient scaler issues\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  "grpo_config_dict = {\n",
396
- " \"output_dir\": \"./gridmind-grpo-output\",\n",
397
  " \"num_train_epochs\": 1,\n",
398
  " \"max_steps\": 60,\n",
399
  " \"per_device_train_batch_size\": 1,\n",
400
  " \"gradient_accumulation_steps\": 4,\n",
401
- " \"max_prompt_length\": 400,\n",
402
- " \"max_completion_length\": 80,\n",
403
  " \"num_generations\": 4,\n",
 
 
404
  " \"learning_rate\": 5e-5,\n",
 
 
405
  " \"fp16\": False,\n",
406
  " \"bf16\": False,\n",
407
  " \"max_grad_norm\": 0.0,\n",
408
- " \"logging_steps\": 1,\n",
409
- " \"log_completions\": True,\n",
410
- " \"num_completions_to_print\": 1,\n",
411
  " \"save_steps\": 60,\n",
412
  " \"report_to\": \"none\",\n",
413
  " \"disable_tqdm\": True,\n",
@@ -434,14 +536,16 @@
434
  " train_dataset=dataset,\n",
435
  " reward_funcs=gridmind_reward_fn,\n",
436
  " peft_config=peft_config,\n",
 
437
  ")\n",
 
438
  "\n",
439
  "print(\"\\nStarting GRPO training (estimated 25-35 min on T4)...\\n\")\n",
440
  "train_result = trainer.train()\n",
441
  "\n",
442
- "print(f\"\\n✔ Training complete!\")\n",
443
  "print(f\" Total steps: {train_result.global_step}\")\n",
444
- "print(f\" Final loss: {train_result.training_loss:.6f}\")"
445
  ]
446
  },
447
  {
 
272
  "outputs": [],
273
  "source": [
274
  "import json as _json\n",
 
 
275
  "import math as _math\n",
276
+ "import random as _random\n",
277
+ "import re as _re\n",
278
+ "import requests as _requests\n",
279
+ "import numpy as _np\n",
280
  "\n",
281
+ "training_rewards = []\n",
282
  "call_count = [0]\n",
283
+ "group_count = [0]\n",
284
+ "NUM_GENERATIONS_FOR_REWARD = 4\n",
285
+ "\n",
286
+ "_REQUIRED_ACTION_KEYS = {\"hvac_power_level\", \"thermal_charge_rate\", \"batch_job_slot\", \"load_shed_fraction\", \"building_id\"}\n",
287
+ "\n",
288
+ "def _extract_action(text):\n",
289
+ " match = _re.search(r\"\\{.*?\\}\", text, _re.DOTALL)\n",
290
+ " if not match:\n",
291
+ " raise ValueError(\"completion did not contain a JSON object\")\n",
292
+ " action = _json.loads(match.group())\n",
293
+ " missing = _REQUIRED_ACTION_KEYS - set(action)\n",
294
+ " if missing:\n",
295
+ " raise ValueError(f\"missing action fields: {sorted(missing)}\")\n",
296
+ " return {\n",
297
+ " \"hvac_power_level\": float(max(0, min(1, action.get(\"hvac_power_level\", 0.5)))),\n",
298
+ " \"thermal_charge_rate\": float(max(-1, min(1, action.get(\"thermal_charge_rate\", 0.0)))),\n",
299
+ " \"batch_job_slot\": int(max(0, min(4, action.get(\"batch_job_slot\", 0)))),\n",
300
+ " \"load_shed_fraction\": float(max(0, min(0.5, action.get(\"load_shed_fraction\", 0.0)))),\n",
301
+ " \"building_id\": int(max(0, min(2, action.get(\"building_id\", 0)))),\n",
302
+ " }\n",
303
  "\n",
304
  "def gridmind_reward_fn(completions, **kwargs):\n",
305
  " \"\"\"\n",
306
+ " Environment-backed GRPO reward.\n",
307
+ " Generations from the same prompt are evaluated on the same task/seed, so\n",
308
+ " advantages reflect real action quality instead of random episode noise.\n",
 
309
  " \"\"\"\n",
310
  " rewards = []\n",
311
+ " batch_start = group_count[0]\n",
312
  "\n",
313
+ " for i, completion in enumerate(completions):\n",
314
+ " call_count[0] += 1\n",
315
+ " group_id = batch_start + (i // NUM_GENERATIONS_FOR_REWARD)\n",
316
+ " text = completion[0][\"content\"] if isinstance(completion, list) else completion\n",
317
  "\n",
 
318
  " try:\n",
319
+ " action = _extract_action(text)\n",
320
+ " except _json.JSONDecodeError:\n",
321
+ " reward = -0.8\n",
322
+ " rewards.append(reward)\n",
323
+ " training_rewards.append(reward)\n",
324
+ " continue\n",
325
+ " except ValueError:\n",
326
+ " reward = -1.0\n",
327
+ " rewards.append(reward)\n",
328
+ " training_rewards.append(reward)\n",
329
+ " continue\n",
330
+ "\n",
331
+ " task_id = (group_id % 4) + 1\n",
332
+ " seed = 10_000 + group_id\n",
333
  "\n",
334
+ " try:\n",
335
+ " reset_resp = _requests.post(\n",
336
+ " f\"{ENV_URL}/reset\",\n",
337
+ " json={\"task_id\": task_id, \"seed\": seed, \"num_buildings\": 1},\n",
338
+ " timeout=15,\n",
339
+ " )\n",
340
+ " reset_resp.raise_for_status()\n",
341
+ " except Exception:\n",
342
+ " reward = -0.5\n",
343
+ " rewards.append(reward)\n",
344
+ " training_rewards.append(reward)\n",
345
+ " continue\n",
 
 
 
 
 
 
346
  "\n",
347
+ " total_env_reward = 0.0\n",
348
+ " completed_steps = 0\n",
349
+ " try:\n",
350
+ " for _ in range(8):\n",
351
+ " step_resp = _requests.post(f\"{ENV_URL}/step\", json=action, timeout=15)\n",
352
+ " step_resp.raise_for_status()\n",
353
+ " data = step_resp.json()\n",
354
  " if isinstance(data, list):\n",
355
  " data = data[0]\n",
356
+ " if \"data\" in data and isinstance(data[\"data\"], dict):\n",
357
+ " data = data[\"data\"]\n",
358
+ " total_env_reward += float(data.get(\"reward\", 0.0) or 0.0)\n",
359
+ " completed_steps += 1\n",
360
+ " if data.get(\"done\", False):\n",
361
+ " break\n",
362
+ "\n",
363
+ " avg_step_reward = total_env_reward / max(completed_steps, 1)\n",
364
+ " normalized_step_reward = max(-1.0, min(1.0, avg_step_reward / 10.0))\n",
365
+ " grade_resp = _requests.get(f\"{ENV_URL}/grade\", timeout=15)\n",
366
+ " if grade_resp.status_code == 200:\n",
367
+ " normalized_grade = max(0.0, min(1.0, float(grade_resp.json().get(\"score\", 0.0))))\n",
368
+ " reward = 0.7 * normalized_grade + 0.3 * normalized_step_reward\n",
369
+ " else:\n",
370
+ " reward = normalized_step_reward\n",
371
  " except Exception:\n",
372
+ " reward = -0.5\n",
373
  "\n",
374
+ " rewards.append(reward)\n",
375
+ " training_rewards.append(reward)\n",
376
+ "\n",
377
+ " group_count[0] += _math.ceil(len(completions) / NUM_GENERATIONS_FOR_REWARD)\n",
378
  "\n",
379
  " return rewards\n",
380
  "\n",
381
+ "print(\"Environment-backed reward function ready\")\n"
382
  ]
383
  },
384
  {
 
398
  "source": [
399
  "from trl import GRPOTrainer, GRPOConfig\n",
400
  "from peft import LoraConfig, prepare_model_for_kbit_training\n",
401
+ "from transformers import PrinterCallback, TrainerCallback\n",
402
  "import inspect\n",
403
  "import os\n",
404
  "\n",
 
410
  "peft_config = LoraConfig(\n",
411
  " r=16,\n",
412
  " lora_alpha=32,\n",
413
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
414
  " lora_dropout=0.05,\n",
415
  " bias=\"none\",\n",
416
  " task_type=\"CAUSAL_LM\",\n",
417
  ")\n",
418
  "\n",
419
+ "class MetricsTableCallback(TrainerCallback):\n",
420
+ " columns = [\n",
421
+ " (\"step\", \"Step\", 6),\n",
422
+ " (\"loss\", \"Loss\", 10),\n",
423
+ " (\"reward\", \"Reward\", 10),\n",
424
+ " (\"reward_std\", \"RewardStd\", 10),\n",
425
+ " (\"entropy\", \"Entropy\", 10),\n",
426
+ " (\"learning_rate\", \"LR\", 11),\n",
427
+ " (\"num_tokens\", \"Tokens\", 8),\n",
428
+ " (\"step_time\", \"StepTime\", 10),\n",
429
+ " ]\n",
430
+ "\n",
431
+ " def __init__(self):\n",
432
+ " self.header_printed = False\n",
433
+ " self.rewards = []\n",
434
+ "\n",
435
+ " def _format_value(self, key, value):\n",
436
+ " if value is None:\n",
437
+ " return \"-\"\n",
438
+ " try:\n",
439
+ " if key in {\"step\", \"num_tokens\"}:\n",
440
+ " return str(int(float(value)))\n",
441
+ " if key == \"learning_rate\":\n",
442
+ " return f\"{float(value):.2e}\"\n",
443
+ " return f\"{float(value):.4f}\"\n",
444
+ " except (TypeError, ValueError):\n",
445
+ " return str(value)\n",
446
+ "\n",
447
+ " def _print_header(self):\n",
448
+ " separator = \"+\" + \"+\".join(\"-\" * (width + 2) for _, _, width in self.columns) + \"+\"\n",
449
+ " header = \"|\" + \"|\".join(f\" {title:<{width}} \" for _, title, width in self.columns) + \"|\"\n",
450
+ " print(separator)\n",
451
+ " print(header)\n",
452
+ " print(separator)\n",
453
+ " self.header_printed = True\n",
454
+ "\n",
455
+ " def on_log(self, args, state, control, logs=None, **kwargs):\n",
456
+ " if not logs or (\"loss\" not in logs and \"reward\" not in logs):\n",
457
+ " return\n",
458
+ " if not self.header_printed:\n",
459
+ " self._print_header()\n",
460
+ " row_values = []\n",
461
+ " for key, _, width in self.columns:\n",
462
+ " value = state.global_step if key == \"step\" else logs.get(key)\n",
463
+ " row_values.append(f\" {self._format_value(key, value):>{width}} \")\n",
464
+ " print(\"|\" + \"|\".join(row_values) + \"|\")\n",
465
+ "\n",
466
+ " if \"reward\" in logs:\n",
467
+ " try:\n",
468
+ " self.rewards.append(float(logs[\"reward\"]))\n",
469
+ " except (TypeError, ValueError):\n",
470
+ " pass\n",
471
+ "\n",
472
+ " def on_train_end(self, args, state, control, **kwargs):\n",
473
+ " if not self.rewards:\n",
474
+ " return\n",
475
+ " first_window = self.rewards[: min(5, len(self.rewards))]\n",
476
+ " last_window = self.rewards[-min(5, len(self.rewards)) :]\n",
477
+ " first_avg = float(_np.mean(first_window))\n",
478
+ " last_avg = float(_np.mean(last_window))\n",
479
+ " overall_avg = float(_np.mean(self.rewards))\n",
480
+ " best_reward = float(_np.max(self.rewards))\n",
481
+ " print(\"+----------------------+------------+\")\n",
482
+ " print(\"| Reward Summary | Value |\")\n",
483
+ " print(\"+----------------------+------------+\")\n",
484
+ " print(f\"| Logged rows | {len(self.rewards):>10} |\")\n",
485
+ " print(f\"| First rows avg | {first_avg:>+10.4f} |\")\n",
486
+ " print(f\"| Last rows avg | {last_avg:>+10.4f} |\")\n",
487
+ " print(f\"| Improvement | {last_avg - first_avg:>+10.4f} |\")\n",
488
+ " print(f\"| Overall avg | {overall_avg:>+10.4f} |\")\n",
489
+ " print(f\"| Best row reward | {best_reward:>+10.4f} |\")\n",
490
+ " print(\"+----------------------+------------+\")\n",
491
+ "\n",
492
+ "# GRPO config - stable for T4 / Colab\n",
493
+ "output_dir = \"gridmind-grpo-trained\"\n",
494
+ "os.makedirs(output_dir, exist_ok=True)\n",
495
+ "\n",
496
  "grpo_config_dict = {\n",
497
+ " \"output_dir\": output_dir,\n",
498
  " \"num_train_epochs\": 1,\n",
499
  " \"max_steps\": 60,\n",
500
  " \"per_device_train_batch_size\": 1,\n",
501
  " \"gradient_accumulation_steps\": 4,\n",
 
 
502
  " \"num_generations\": 4,\n",
503
+ " \"max_prompt_length\": 512,\n",
504
+ " \"max_completion_length\": 80,\n",
505
  " \"learning_rate\": 5e-5,\n",
506
+ " \"lr_scheduler_type\": \"cosine\",\n",
507
+ " \"warmup_ratio\": 0.1,\n",
508
  " \"fp16\": False,\n",
509
  " \"bf16\": False,\n",
510
  " \"max_grad_norm\": 0.0,\n",
511
+ " \"logging_steps\": 5,\n",
512
+ " \"log_completions\": False,\n",
 
513
  " \"save_steps\": 60,\n",
514
  " \"report_to\": \"none\",\n",
515
  " \"disable_tqdm\": True,\n",
 
536
  " train_dataset=dataset,\n",
537
  " reward_funcs=gridmind_reward_fn,\n",
538
  " peft_config=peft_config,\n",
539
+ " callbacks=[MetricsTableCallback()],\n",
540
  ")\n",
541
+ "trainer.remove_callback(PrinterCallback)\n",
542
  "\n",
543
  "print(\"\\nStarting GRPO training (estimated 25-35 min on T4)...\\n\")\n",
544
  "train_result = trainer.train()\n",
545
  "\n",
546
+ "print(f\"\\nTraining complete!\")\n",
547
  "print(f\" Total steps: {train_result.global_step}\")\n",
548
+ "print(f\" Final loss: {train_result.training_loss:.6f}\")\n"
549
  ]
550
  },
551
  {
scripts/train_unsloth.py CHANGED
@@ -33,7 +33,7 @@ import matplotlib.gridspec as gridspec
33
  from datasets import Dataset
34
  from trl import GRPOTrainer, GRPOConfig
35
  from unsloth import FastLanguageModel
36
- from transformers import TrainerCallback
37
 
38
  os.makedirs("results", exist_ok=True)
39
 
@@ -65,69 +65,39 @@ def make_prompt(i):
65
  }]
66
 
67
 
68
- def reward_valid_json(completions, **kwargs):
69
- rewards = []
70
- for completion in completions:
71
- text = completion[0]["content"] if isinstance(completion, list) else completion
72
- try:
73
- match = re.search(r'\{.*?\}', text, re.DOTALL)
74
- if match:
75
- json.loads(match.group())
76
- rewards.append(0.3)
77
- else:
78
- rewards.append(0.0)
79
- except Exception:
80
- rewards.append(0.0)
81
- return rewards
82
-
83
-
84
- def reward_has_required_keys(completions, **kwargs):
85
- required = {"hvac_power_level", "thermal_charge_rate", "batch_job_slot", "load_shed_fraction"}
86
- rewards = []
87
- for completion in completions:
88
- text = completion[0]["content"] if isinstance(completion, list) else completion
89
- try:
90
- match = re.search(r'\{.*?\}', text, re.DOTALL)
91
- if match:
92
- action = json.loads(match.group())
93
- if required.issubset(action.keys()):
94
- rewards.append(0.3)
95
- else:
96
- rewards.append(0.1)
97
- else:
98
- rewards.append(0.0)
99
- except Exception:
100
- rewards.append(0.0)
101
- return rewards
102
-
103
-
104
  ENV_URL = "https://prajwal782007-gridmind.hf.space"
105
 
106
 
107
  class GridMindRewardFn:
108
- """Fixed reward function with environment reset per completion call."""
109
 
110
- def __init__(self, env_url, num_steps=8):
111
  self.env_url = env_url
112
  self.num_steps = num_steps
 
113
  self.call_count = [0]
114
  self.reward_variance_log = []
115
  self.training_rewards = []
 
116
 
117
  def __call__(self, completions, **kwargs):
118
  rewards = []
119
  batch_rewards = []
120
 
 
121
  for i, completion in enumerate(completions):
122
  self.call_count[0] += 1
 
123
 
124
  text = completion[0]["content"] if isinstance(completion, list) else completion
125
 
126
  try:
127
  match = re.search(r'\{.*?\}', text, re.DOTALL)
128
  if not match:
129
- rewards.append(-1.0)
130
- batch_rewards.append(-1.0)
 
 
131
  continue
132
 
133
  action = json.loads(match.group())
@@ -140,8 +110,10 @@ class GridMindRewardFn:
140
  "building_id": 0
141
  }
142
 
143
- seed = 1000 + self.call_count[0]
144
- task_id = (self.call_count[0] % 3) + 1
 
 
145
 
146
  reset_resp = requests.post(
147
  f"{self.env_url}/reset",
@@ -149,11 +121,14 @@ class GridMindRewardFn:
149
  timeout=30
150
  )
151
  if reset_resp.status_code != 200:
152
- rewards.append(-0.5)
153
- batch_rewards.append(-0.5)
 
 
154
  continue
155
 
156
  total_reward = 0.0
 
157
  for _ in range(self.num_steps):
158
  step_resp = requests.post(
159
  f"{self.env_url}/step",
@@ -166,36 +141,39 @@ class GridMindRewardFn:
166
  if isinstance(step_data, list):
167
  step_data = step_data[0]
168
  total_reward += float(step_data.get("reward", 0))
 
169
 
170
- avg_reward = total_reward / self.num_steps if self.num_steps > 0 else 0
 
171
 
172
  grade_resp = requests.get(f"{self.env_url}/grade", timeout=30)
173
  if grade_resp.status_code == 200:
174
  episode_score = float(grade_resp.json().get("score", 0.5))
175
- normalized = max(0.0, min(1.0, (episode_score - 0.4) / 0.32))
176
- final_reward = normalized
177
  else:
178
- final_reward = max(-1.0, min(1.0, avg_reward / 10.0))
179
 
180
  rewards.append(final_reward)
181
  batch_rewards.append(final_reward)
182
  self.training_rewards.append(final_reward)
183
 
184
  except json.JSONDecodeError:
185
- rewards.append(-0.8)
186
- batch_rewards.append(-0.8)
 
 
187
  except Exception as e:
188
  print(f"Reward error: {e}", file=sys.stderr)
189
- rewards.append(-0.5)
190
- batch_rewards.append(-0.5)
 
 
191
 
192
- if len(batch_rewards) > 1 and self.call_count[0] % 10 == 0:
193
- try:
194
- variance = np.var(batch_rewards)
195
- print(f" [Step {self.call_count[0]}] Reward variance: {variance:.4f} | Avg: {np.mean(batch_rewards):.3f}")
196
- self.reward_variance_log.append(variance)
197
- except:
198
- pass
199
 
200
  return rewards
201
 
@@ -627,6 +605,85 @@ class CSVLogCallback(TrainerCallback):
627
  pd.DataFrame(self.log_history).to_csv(self.output_path, index=False)
628
 
629
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  def main():
631
  parser = argparse.ArgumentParser(description="Train GridMind-RL agent with Unsloth GRPO")
632
  parser.add_argument("--env-url", type=str, default="http://localhost:7860", help="OpenEnv server URL")
@@ -683,9 +740,8 @@ def main():
683
  "learning_rate": 5e-6, # FIXED: was 5e-5, too high
684
  "lr_scheduler_type": "cosine",
685
  "warmup_ratio": 0.1,
686
- "logging_steps": 1, # Log every step to produce dense table
687
- "log_completions": True, # Enable completion metrics in table
688
- "num_completions_to_print": 1, # Print 1 completion per step
689
  "save_steps": 100,
690
  "fp16": False, # Disable AMP with quantized models (avoid grad scaler issues)
691
  "bf16": False,
@@ -708,20 +764,21 @@ def main():
708
  print(f"Skipping unsupported GRPOConfig args: {skipped_training_args}")
709
  training_args = GRPOConfig(**training_arg_kwargs)
710
 
711
- reward_fn = GridMindRewardFn(args.env_url, num_steps=8)
 
 
 
 
712
 
713
  trainer = GRPOTrainer(
714
  model=model,
715
  processing_class=tokenizer,
716
  args=training_args,
717
  train_dataset=dataset,
718
- reward_funcs=[
719
- reward_valid_json,
720
- reward_has_required_keys,
721
- reward_fn,
722
- ],
723
- callbacks=[CSVLogCallback(args.output_csv)]
724
  )
 
725
 
726
  print("🚀 Starting GRPO training...")
727
  trainer.train()
 
33
  from datasets import Dataset
34
  from trl import GRPOTrainer, GRPOConfig
35
  from unsloth import FastLanguageModel
36
+ from transformers import PrinterCallback, TrainerCallback
37
 
38
  os.makedirs("results", exist_ok=True)
39
 
 
65
  }]
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ENV_URL = "https://prajwal782007-gridmind.hf.space"
69
 
70
 
71
  class GridMindRewardFn:
72
+ """Environment-backed reward function with comparable rollouts per GRPO group."""
73
 
74
+ def __init__(self, env_url, num_steps=8, num_generations=4):
75
  self.env_url = env_url
76
  self.num_steps = num_steps
77
+ self.num_generations = max(1, num_generations)
78
  self.call_count = [0]
79
  self.reward_variance_log = []
80
  self.training_rewards = []
81
+ self.group_count = 0
82
 
83
  def __call__(self, completions, **kwargs):
84
  rewards = []
85
  batch_rewards = []
86
 
87
+ batch_start = self.group_count
88
  for i, completion in enumerate(completions):
89
  self.call_count[0] += 1
90
+ group_id = batch_start + (i // self.num_generations)
91
 
92
  text = completion[0]["content"] if isinstance(completion, list) else completion
93
 
94
  try:
95
  match = re.search(r'\{.*?\}', text, re.DOTALL)
96
  if not match:
97
+ final_reward = -1.0
98
+ rewards.append(final_reward)
99
+ batch_rewards.append(final_reward)
100
+ self.training_rewards.append(final_reward)
101
  continue
102
 
103
  action = json.loads(match.group())
 
110
  "building_id": 0
111
  }
112
 
113
+ # Evaluate all generations for the same prompt on the same scenario.
114
+ # This keeps GRPO advantages tied to action quality instead of seed noise.
115
+ seed = 10_000 + group_id
116
+ task_id = (group_id % 4) + 1
117
 
118
  reset_resp = requests.post(
119
  f"{self.env_url}/reset",
 
121
  timeout=30
122
  )
123
  if reset_resp.status_code != 200:
124
+ final_reward = -0.5
125
+ rewards.append(final_reward)
126
+ batch_rewards.append(final_reward)
127
+ self.training_rewards.append(final_reward)
128
  continue
129
 
130
  total_reward = 0.0
131
+ completed_steps = 0
132
  for _ in range(self.num_steps):
133
  step_resp = requests.post(
134
  f"{self.env_url}/step",
 
141
  if isinstance(step_data, list):
142
  step_data = step_data[0]
143
  total_reward += float(step_data.get("reward", 0))
144
+ completed_steps += 1
145
 
146
+ avg_step_reward = total_reward / max(completed_steps, 1)
147
+ normalized_step_reward = max(-1.0, min(1.0, avg_step_reward / 10.0))
148
 
149
  grade_resp = requests.get(f"{self.env_url}/grade", timeout=30)
150
  if grade_resp.status_code == 200:
151
  episode_score = float(grade_resp.json().get("score", 0.5))
152
+ normalized_grade = max(0.0, min(1.0, episode_score))
153
+ final_reward = 0.7 * normalized_grade + 0.3 * normalized_step_reward
154
  else:
155
+ final_reward = normalized_step_reward
156
 
157
  rewards.append(final_reward)
158
  batch_rewards.append(final_reward)
159
  self.training_rewards.append(final_reward)
160
 
161
  except json.JSONDecodeError:
162
+ final_reward = -0.8
163
+ rewards.append(final_reward)
164
+ batch_rewards.append(final_reward)
165
+ self.training_rewards.append(final_reward)
166
  except Exception as e:
167
  print(f"Reward error: {e}", file=sys.stderr)
168
+ final_reward = -0.5
169
+ rewards.append(final_reward)
170
+ batch_rewards.append(final_reward)
171
+ self.training_rewards.append(final_reward)
172
 
173
+ self.group_count += math.ceil(len(completions) / self.num_generations)
174
+
175
+ if len(batch_rewards) > 1:
176
+ self.reward_variance_log.append(float(np.var(batch_rewards)))
 
 
 
177
 
178
  return rewards
179
 
 
605
  pd.DataFrame(self.log_history).to_csv(self.output_path, index=False)
606
 
607
 
608
+ class MetricsTableCallback(TrainerCallback):
609
+ """Print compact GRPO metrics without dumping prompts or completions."""
610
+
611
+ columns = [
612
+ ("step", "Step", 6),
613
+ ("loss", "Loss", 10),
614
+ ("reward", "Reward", 10),
615
+ ("reward_std", "RewardStd", 10),
616
+ ("entropy", "Entropy", 10),
617
+ ("learning_rate", "LR", 11),
618
+ ("num_tokens", "Tokens", 8),
619
+ ("step_time", "StepTime", 10),
620
+ ]
621
+
622
+ def __init__(self):
623
+ self.header_printed = False
624
+ self.rewards = []
625
+
626
+ def _format_value(self, key, value):
627
+ if value is None:
628
+ return "-"
629
+ try:
630
+ if key in {"step", "num_tokens"}:
631
+ return str(int(float(value)))
632
+ if key == "learning_rate":
633
+ return f"{float(value):.2e}"
634
+ return f"{float(value):.4f}"
635
+ except (TypeError, ValueError):
636
+ return str(value)
637
+
638
+ def _print_header(self):
639
+ separator = "+" + "+".join("-" * (width + 2) for _, _, width in self.columns) + "+"
640
+ header = "|" + "|".join(f" {title:<{width}} " for _, title, width in self.columns) + "|"
641
+ print(separator)
642
+ print(header)
643
+ print(separator)
644
+ self.header_printed = True
645
+
646
+ def on_log(self, args, state, control, logs=None, **kwargs):
647
+ if not logs or ("loss" not in logs and "reward" not in logs):
648
+ return
649
+ if not self.header_printed:
650
+ self._print_header()
651
+
652
+ row_values = []
653
+ for key, _, width in self.columns:
654
+ value = state.global_step if key == "step" else logs.get(key)
655
+ row_values.append(f" {self._format_value(key, value):>{width}} ")
656
+ print("|" + "|".join(row_values) + "|")
657
+
658
+ if "reward" in logs:
659
+ try:
660
+ self.rewards.append(float(logs["reward"]))
661
+ except (TypeError, ValueError):
662
+ pass
663
+
664
+ def on_train_end(self, args, state, control, **kwargs):
665
+ if not self.rewards:
666
+ return
667
+
668
+ first_window = self.rewards[: min(5, len(self.rewards))]
669
+ last_window = self.rewards[-min(5, len(self.rewards)) :]
670
+ first_avg = float(np.mean(first_window))
671
+ last_avg = float(np.mean(last_window))
672
+ overall_avg = float(np.mean(self.rewards))
673
+ best_reward = float(np.max(self.rewards))
674
+
675
+ print("+----------------------+------------+")
676
+ print("| Reward Summary | Value |")
677
+ print("+----------------------+------------+")
678
+ print(f"| Logged rows | {len(self.rewards):>10} |")
679
+ print(f"| First rows avg | {first_avg:>+10.4f} |")
680
+ print(f"| Last rows avg | {last_avg:>+10.4f} |")
681
+ print(f"| Improvement | {last_avg - first_avg:>+10.4f} |")
682
+ print(f"| Overall avg | {overall_avg:>+10.4f} |")
683
+ print(f"| Best row reward | {best_reward:>+10.4f} |")
684
+ print("+----------------------+------------+")
685
+
686
+
687
  def main():
688
  parser = argparse.ArgumentParser(description="Train GridMind-RL agent with Unsloth GRPO")
689
  parser.add_argument("--env-url", type=str, default="http://localhost:7860", help="OpenEnv server URL")
 
740
  "learning_rate": 5e-6, # FIXED: was 5e-5, too high
741
  "lr_scheduler_type": "cosine",
742
  "warmup_ratio": 0.1,
743
+ "logging_steps": 5, # Keep 60-step output clean: rows at 5, 10, ..., 60
744
+ "log_completions": False,
 
745
  "save_steps": 100,
746
  "fp16": False, # Disable AMP with quantized models (avoid grad scaler issues)
747
  "bf16": False,
 
764
  print(f"Skipping unsupported GRPOConfig args: {skipped_training_args}")
765
  training_args = GRPOConfig(**training_arg_kwargs)
766
 
767
+ reward_fn = GridMindRewardFn(
768
+ args.env_url,
769
+ num_steps=8,
770
+ num_generations=requested_training_args["num_generations"],
771
+ )
772
 
773
  trainer = GRPOTrainer(
774
  model=model,
775
  processing_class=tokenizer,
776
  args=training_args,
777
  train_dataset=dataset,
778
+ reward_funcs=[reward_fn],
779
+ callbacks=[CSVLogCallback(args.output_csv), MetricsTableCallback()]
 
 
 
 
780
  )
781
+ trainer.remove_callback(PrinterCallback)
782
 
783
  print("🚀 Starting GRPO training...")
784
  trainer.train()