Prajwal782007 commited on
Commit
26e9b86
·
1 Parent(s): 7d89faf

feat: add GRPO training pipeline for GridMind-RL environment via Unsloth and TRL

Browse files
Files changed (1) hide show
  1. scripts/gridmind_grpo_colab.ipynb +261 -44
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -343,14 +343,16 @@
343
  "torch.cuda.empty_cache()\n",
344
  "\n",
345
  "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
346
- "print(f\"Loading {MODEL_NAME} with 4-bit quantization for T4 16GB...\")\n",
 
 
347
  "\n",
348
  "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
349
  "if tokenizer.pad_token is None:\n",
350
  " tokenizer.pad_token = tokenizer.eos_token\n",
351
  "tokenizer.padding_side = \"left\" # required for GRPO\n",
352
  "\n",
353
- "# 4-bit quantization - fits safely on T4 16GB\n",
354
  "bnb_config = BitsAndBytesConfig(\n",
355
  " load_in_4bit=True,\n",
356
  " bnb_4bit_compute_dtype=torch.float16,\n",
@@ -366,8 +368,8 @@
366
  ")\n",
367
  "\n",
368
  "print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
369
- "print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB / 16 GB\")\n",
370
- "print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB / 16 GB\")"
371
  ]
372
  },
373
  {
@@ -417,8 +419,15 @@
417
  " reset_payload = {\"task_id\": task_id, \"seed\": batch_seed}\n",
418
  " reset_r = _requests.post(f\"{ENV_URL}/reset\", json=reset_payload, timeout=10)\n",
419
  " reset_ok = reset_r.status_code == 200\n",
 
 
 
 
 
 
420
  " except Exception:\n",
421
  " reset_ok = False\n",
 
422
  "\n",
423
  " if not reset_ok:\n",
424
  " return [-0.1] * len(completions)\n",
@@ -450,25 +459,33 @@
450
  " cleaned_action = {}\n",
451
  "\n",
452
  " try:\n",
453
- " cleaned_action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
 
 
454
  " valid_fields += 1\n",
455
  " except Exception:\n",
456
  " cleaned_action[\"hvac_power_level\"] = 0.5\n",
457
  "\n",
458
  " try:\n",
459
- " cleaned_action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
 
 
460
  " valid_fields += 1\n",
461
  " except Exception:\n",
462
  " cleaned_action[\"thermal_charge_rate\"] = 0.0\n",
463
  "\n",
464
  " try:\n",
465
- " cleaned_action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
 
 
466
  " valid_fields += 1\n",
467
  " except Exception:\n",
468
  " cleaned_action[\"batch_job_slot\"] = 0\n",
469
  "\n",
470
  " try:\n",
471
- " cleaned_action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
 
 
472
  " valid_fields += 1\n",
473
  " except Exception:\n",
474
  " cleaned_action[\"load_shed_fraction\"] = 0.0\n",
@@ -496,18 +513,32 @@
496
  " grid_r = float(comps.get(\"grid_response\", 0.0))\n",
497
  " task_r = float(comps.get(\"task_satisfaction\", 0.0))\n",
498
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  " if comps:\n",
500
- " composite = (\n",
501
- " cost_r * 0.40 +\n",
502
- " comfort_r * 0.25 +\n",
503
- " grid_r * 0.15 +\n",
504
- " task_r * 0.20 +\n",
505
- " completeness_bonus\n",
506
- " )\n",
507
  " else:\n",
508
- " composite = env_reward * 0.5 + completeness_bonus\n",
509
  "\n",
510
- " composite = max(-0.6, min(0.6, composite))\n",
 
 
511
  "\n",
512
  " rewards.append(composite)\n",
513
  " batch_raw.append(composite)\n",
@@ -641,25 +672,37 @@
641
  " task_type=\"CAUSAL_LM\",\n",
642
  ")\n",
643
  "\n",
644
- "# GRPOConfig - trl==0.23.0 compatible. Pass this as args=, not config=.\n",
645
- "# generation_kwargs is not a GRPOTrainer init parameter in trl 0.23.0.\n",
646
- "grpo_config = GRPOConfig(\n",
647
- " output_dir=\"./gridmind-grpo-output\",\n",
648
- " num_train_epochs=1,\n",
649
- " max_steps=60,\n",
650
- " per_device_train_batch_size=1,\n",
651
- " gradient_accumulation_steps=4,\n",
652
- " max_prompt_length=400,\n",
653
- " max_completion_length=80,\n",
654
- " num_generations=4,\n",
655
- " learning_rate=5e-5,\n",
656
- " fp16=True,\n",
657
- " logging_steps=1,\n",
658
- " save_steps=60,\n",
659
- " report_to=\"none\",\n",
660
- " dataloader_num_workers=0,\n",
661
- " remove_unused_columns=False,\n",
662
- ")\n",
 
 
 
 
 
 
 
 
 
 
 
 
663
  "\n",
664
  "# Confirm the installed TRL API before constructing the trainer.\n",
665
  "import trl\n",
@@ -671,8 +714,10 @@
671
  "print(f\"Uses 'args=': {'args' in params}\")\n",
672
  "print(f\"Uses 'config=': {'config' in params}\")\n",
673
  "\n",
674
- "print(f\"\\nGPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB used / 16 GB total\")\n",
675
- "print(f\"Free: {(16 - torch.cuda.memory_allocated()/1e9):.2f} GB\")\n",
 
 
676
  "\n",
677
  "# Custom callback to capture loss at every step for graphing.\n",
678
  "from transformers import TrainerCallback\n",
@@ -680,12 +725,16 @@
680
  "step_losses = []\n",
681
  "step_numbers = []\n",
682
  "step_reward_means = []\n",
 
683
  "\n",
684
  "class LossCaptureCallback(TrainerCallback):\n",
685
  " def on_log(self, args, state, control, logs=None, **kwargs):\n",
686
  " if not logs:\n",
687
  " return\n",
688
  " step = state.global_step\n",
 
 
 
689
  " loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
690
  " if loss is not None:\n",
691
  " step_losses.append(float(loss))\n",
@@ -714,7 +763,7 @@
714
  "\n",
715
  "print(\"\\nStarting GRPO training with QLoRA...\")\n",
716
  "print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
717
- "print(f\"Steps: {grpo_config.max_steps} | Batch: {grpo_config.per_device_train_batch_size} | Generations: {grpo_config.num_generations}\")\n",
718
  "print(\"Estimated time: ~25-35 min on T4\\n\")\n",
719
  "\n",
720
  "train_result = trainer.train()\n",
@@ -732,6 +781,38 @@
732
  "else:\n",
733
  " print(f\"\\nTraining produced gradient signal on {len(non_zero_losses)} steps.\")\n",
734
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  "print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
736
  "\n",
737
  "# Save LoRA adapter (much smaller than full model)\n",
@@ -863,9 +944,35 @@
863
  "import matplotlib.pyplot as plt\n",
864
  "import matplotlib.gridspec as gridspec\n",
865
  "import numpy as np\n",
 
866
  "import os\n",
867
  "\n",
868
  "os.makedirs(\"results\", exist_ok=True)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  "\n",
870
  "tasks = [1, 2, 3, 4]\n",
871
  "task_labels = [\n",
@@ -963,23 +1070,39 @@
963
  "# Panel B: reward signal over time.\n",
964
  "ax_rew = fig.add_subplot(gs[1, 0])\n",
965
  "style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
966
- "if training_rewards and len(training_rewards) >= 4:\n",
967
- " raw = training_rewards\n",
968
- " steps_r = list(range(1, len(raw) + 1))\n",
969
- " ax_rew.plot(steps_r, raw, alpha=0.20, color=C['reward'], linewidth=1)\n",
 
 
 
970
  " ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
 
 
 
 
 
971
  " if len(steps_r) > 8:\n",
972
  " z = np.polyfit(steps_r, raw, 1)\n",
973
  " p = np.poly1d(z)\n",
974
  " ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
975
  " label=f'Trend ({z[0]:+.5f}/step)')\n",
976
- " ax_rew.set_xlabel('Reward Function Call')\n",
977
  " ax_rew.set_ylabel('Reward Value')\n",
978
  " ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
979
  " if np.var(raw) < 0.01:\n",
980
  " ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
981
  " transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
982
  " bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
 
 
 
 
 
 
 
 
983
  "else:\n",
984
  " ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
985
  " transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
@@ -1027,6 +1150,92 @@
1027
  "fig.savefig(dashboard_path, dpi=180, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
1028
  "plt.close(fig)\n",
1029
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1030
  "# Separate before/after comparison graph for quick judge inspection.\n",
1031
  "fig2, ax2 = plt.subplots(figsize=(11, 6))\n",
1032
  "fig2.patch.set_facecolor(C['bg'])\n",
@@ -1048,6 +1257,10 @@
1048
  "plt.close(fig2)\n",
1049
  "\n",
1050
  "print(f\"Saved dashboard graph to {dashboard_path}\")\n",
 
 
 
 
1051
  "print(f\"Saved before/after graph to {comparison_path}\")\n",
1052
  "\n",
1053
  "results = {\n",
@@ -1066,8 +1279,12 @@
1066
  " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
1067
  " \"training_step_logs\": training_steps_log[-20:] if training_steps_log else [],\n",
1068
  " \"step_losses\": step_losses if 'step_losses' in globals() else [],\n",
 
1069
  " \"graphs\": {\n",
1070
  " \"dashboard\": dashboard_path,\n",
 
 
 
1071
  " \"before_after\": comparison_path,\n",
1072
  " },\n",
1073
  "}\n",
 
343
  "torch.cuda.empty_cache()\n",
344
  "\n",
345
  "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
346
+ "gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"CPU\"\n",
347
+ "gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0\n",
348
+ "print(f\"Loading {MODEL_NAME} with 4-bit quantization on {gpu_name} ({gpu_total_gb:.2f} GB VRAM)...\")\n",
349
  "\n",
350
  "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
351
  "if tokenizer.pad_token is None:\n",
352
  " tokenizer.pad_token = tokenizer.eos_token\n",
353
  "tokenizer.padding_side = \"left\" # required for GRPO\n",
354
  "\n",
355
+ "# 4-bit quantization for memory-efficient QLoRA training\n",
356
  "bnb_config = BitsAndBytesConfig(\n",
357
  " load_in_4bit=True,\n",
358
  " bnb_4bit_compute_dtype=torch.float16,\n",
 
368
  ")\n",
369
  "\n",
370
  "print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
371
+ "print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB / {gpu_total_gb:.2f} GB\")\n",
372
+ "print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB / {gpu_total_gb:.2f} GB\")"
373
  ]
374
  },
375
  {
 
419
  " reset_payload = {\"task_id\": task_id, \"seed\": batch_seed}\n",
420
  " reset_r = _requests.post(f\"{ENV_URL}/reset\", json=reset_payload, timeout=10)\n",
421
  " reset_ok = reset_r.status_code == 200\n",
422
+ " reset_data = reset_r.json() if reset_ok else {}\n",
423
+ " reset_obs = reset_data.get(\"observations\", [reset_data.get(\"observation\", {})])\n",
424
+ " if isinstance(reset_obs, list):\n",
425
+ " base_obs = reset_obs[0] if reset_obs else {}\n",
426
+ " else:\n",
427
+ " base_obs = reset_obs or {}\n",
428
  " except Exception:\n",
429
  " reset_ok = False\n",
430
+ " base_obs = {}\n",
431
  "\n",
432
  " if not reset_ok:\n",
433
  " return [-0.1] * len(completions)\n",
 
459
  " cleaned_action = {}\n",
460
  "\n",
461
  " try:\n",
462
+ " if \"hvac_power_level\" not in action:\n",
463
+ " raise KeyError(\"hvac_power_level\")\n",
464
+ " cleaned_action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action[\"hvac_power_level\"])))\n",
465
  " valid_fields += 1\n",
466
  " except Exception:\n",
467
  " cleaned_action[\"hvac_power_level\"] = 0.5\n",
468
  "\n",
469
  " try:\n",
470
+ " if \"thermal_charge_rate\" not in action:\n",
471
+ " raise KeyError(\"thermal_charge_rate\")\n",
472
+ " cleaned_action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action[\"thermal_charge_rate\"])))\n",
473
  " valid_fields += 1\n",
474
  " except Exception:\n",
475
  " cleaned_action[\"thermal_charge_rate\"] = 0.0\n",
476
  "\n",
477
  " try:\n",
478
+ " if \"batch_job_slot\" not in action:\n",
479
+ " raise KeyError(\"batch_job_slot\")\n",
480
+ " cleaned_action[\"batch_job_slot\"] = max(0, min(4, int(action[\"batch_job_slot\"])))\n",
481
  " valid_fields += 1\n",
482
  " except Exception:\n",
483
  " cleaned_action[\"batch_job_slot\"] = 0\n",
484
  "\n",
485
  " try:\n",
486
+ " if \"load_shed_fraction\" not in action:\n",
487
+ " raise KeyError(\"load_shed_fraction\")\n",
488
+ " cleaned_action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action[\"load_shed_fraction\"])))\n",
489
  " valid_fields += 1\n",
490
  " except Exception:\n",
491
  " cleaned_action[\"load_shed_fraction\"] = 0.0\n",
 
513
  " grid_r = float(comps.get(\"grid_response\", 0.0))\n",
514
  " task_r = float(comps.get(\"task_satisfaction\", 0.0))\n",
515
  "\n",
516
+ " price = float(base_obs.get(\"current_price\", base_obs.get(\"price\", 0.10)))\n",
517
+ " stress = float(base_obs.get(\"grid_stress_signal\", base_obs.get(\"grid_stress\", 0.0)))\n",
518
+ " temp = float(base_obs.get(\"indoor_temperature\", 21.0))\n",
519
+ " charge = cleaned_action[\"thermal_charge_rate\"]\n",
520
+ " hvac = cleaned_action[\"hvac_power_level\"]\n",
521
+ " shed = cleaned_action[\"load_shed_fraction\"]\n",
522
+ "\n",
523
+ " price_signal = 0.0\n",
524
+ " if price < 0.08:\n",
525
+ " price_signal += 0.08 * charge\n",
526
+ " elif price > 0.15:\n",
527
+ " price_signal += 0.08 * (-charge)\n",
528
+ " price_signal -= 0.03 * abs(charge)\n",
529
+ "\n",
530
+ " stress_signal = 0.12 * shed if stress > 0.65 else -0.04 * shed\n",
531
+ " comfort_signal = -0.04 * abs(temp - 21.0) * abs(hvac - 0.5)\n",
532
+ " action_signal = price_signal + stress_signal + comfort_signal\n",
533
+ "\n",
534
  " if comps:\n",
535
+ " component_signal = 0.04 * cost_r + 0.03 * comfort_r + 0.03 * grid_r + 0.03 * task_r\n",
 
 
 
 
 
 
536
  " else:\n",
537
+ " component_signal = 0.0\n",
538
  "\n",
539
+ " # Center raw env reward to avoid saturating all valid JSON at the clip boundary.\n",
540
+ " composite = (env_reward - 0.5) * 0.35 + component_signal + action_signal + completeness_bonus\n",
541
+ " composite = max(-0.45, min(0.45, composite))\n",
542
  "\n",
543
  " rewards.append(composite)\n",
544
  " batch_raw.append(composite)\n",
 
672
  " task_type=\"CAUSAL_LM\",\n",
673
  ")\n",
674
  "\n",
675
+ "# GRPOConfig compatibility shim. HF/Colab images can have TRL builds whose\n",
676
+ "# GRPOConfig fields differ, so only pass arguments accepted by this runtime.\n",
677
+ "grpo_config_requested = {\n",
678
+ " \"output_dir\": \"./gridmind-grpo-output\",\n",
679
+ " \"num_train_epochs\": 1,\n",
680
+ " \"max_steps\": 60,\n",
681
+ " \"per_device_train_batch_size\": 1,\n",
682
+ " \"gradient_accumulation_steps\": 4,\n",
683
+ " \"max_prompt_length\": 400,\n",
684
+ " \"max_completion_length\": 80,\n",
685
+ " \"max_new_tokens\": 80,\n",
686
+ " \"num_generations\": 4,\n",
687
+ " \"learning_rate\": 5e-5,\n",
688
+ " \"fp16\": True,\n",
689
+ " \"logging_steps\": 1,\n",
690
+ " \"save_steps\": 60,\n",
691
+ " \"report_to\": \"none\",\n",
692
+ " \"dataloader_num_workers\": 0,\n",
693
+ " \"remove_unused_columns\": False,\n",
694
+ "}\n",
695
+ "\n",
696
+ "grpo_config_sig = inspect.signature(GRPOConfig.__init__)\n",
697
+ "grpo_config_params = set(grpo_config_sig.parameters.keys()) - {\"self\"}\n",
698
+ "grpo_config_kwargs = {k: v for k, v in grpo_config_requested.items() if k in grpo_config_params}\n",
699
+ "if \"max_completion_length\" in grpo_config_kwargs and \"max_new_tokens\" in grpo_config_kwargs:\n",
700
+ " grpo_config_kwargs.pop(\"max_new_tokens\")\n",
701
+ "skipped_config_keys = [k for k in grpo_config_requested if k not in grpo_config_params]\n",
702
+ "print(f\"GRPOConfig accepted keys: {sorted(grpo_config_kwargs.keys())}\")\n",
703
+ "print(f\"GRPOConfig skipped unsupported keys: {skipped_config_keys}\")\n",
704
+ "\n",
705
+ "grpo_config = GRPOConfig(**grpo_config_kwargs)\n",
706
  "\n",
707
  "# Confirm the installed TRL API before constructing the trainer.\n",
708
  "import trl\n",
 
714
  "print(f\"Uses 'args=': {'args' in params}\")\n",
715
  "print(f\"Uses 'config=': {'config' in params}\")\n",
716
  "\n",
717
+ "gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0\n",
718
+ "gpu_used_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0\n",
719
+ "print(f\"\\nGPU memory: {gpu_used_gb:.2f} GB used / {gpu_total_gb:.2f} GB total\")\n",
720
+ "print(f\"Free: {max(0, gpu_total_gb - gpu_used_gb):.2f} GB\")\n",
721
  "\n",
722
  "# Custom callback to capture loss at every step for graphing.\n",
723
  "from transformers import TrainerCallback\n",
 
725
  "step_losses = []\n",
726
  "step_numbers = []\n",
727
  "step_reward_means = []\n",
728
+ "training_log_history = []\n",
729
  "\n",
730
  "class LossCaptureCallback(TrainerCallback):\n",
731
  " def on_log(self, args, state, control, logs=None, **kwargs):\n",
732
  " if not logs:\n",
733
  " return\n",
734
  " step = state.global_step\n",
735
+ " row = {\"step\": step}\n",
736
+ " row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
737
+ " training_log_history.append(row)\n",
738
  " loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
739
  " if loss is not None:\n",
740
  " step_losses.append(float(loss))\n",
 
763
  "\n",
764
  "print(\"\\nStarting GRPO training with QLoRA...\")\n",
765
  "print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
766
+ "print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
767
  "print(\"Estimated time: ~25-35 min on T4\\n\")\n",
768
  "\n",
769
  "train_result = trainer.train()\n",
 
781
  "else:\n",
782
  " print(f\"\\nTraining produced gradient signal on {len(non_zero_losses)} steps.\")\n",
783
  "\n",
784
+ "# Preserve the exact tabular statistics that TRL prints during training.\n",
785
+ "try:\n",
786
+ " import pandas as pd\n",
787
+ " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
788
+ " if trainer_log_rows:\n",
789
+ " training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
790
+ " if \"step\" not in training_metrics_df.columns:\n",
791
+ " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
792
+ " elif training_log_history:\n",
793
+ " training_metrics_df = pd.DataFrame(training_log_history)\n",
794
+ " else:\n",
795
+ " training_metrics_df = pd.DataFrame({\"step\": step_numbers, \"loss\": step_losses, \"reward\": step_reward_means[:len(step_numbers)]})\n",
796
+ "\n",
797
+ " os.makedirs(\"results\", exist_ok=True)\n",
798
+ " training_metrics_path = \"results/gridmind_training_metrics.csv\"\n",
799
+ " training_metrics_df.to_csv(training_metrics_path, index=False)\n",
800
+ " print(f\"\\nSaved TRL training metrics table to {training_metrics_path}\")\n",
801
+ "\n",
802
+ " preferred_cols = [\n",
803
+ " \"step\", \"loss\", \"reward\", \"reward_std\",\n",
804
+ " \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
805
+ " \"completions / clipped_ratio\", \"kl\", \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
806
+ " ]\n",
807
+ " display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
808
+ " if display_cols:\n",
809
+ " print(\"\\nTraining metrics preview:\")\n",
810
+ " display(training_metrics_df[display_cols].tail(10))\n",
811
+ "except Exception as e:\n",
812
+ " training_metrics_df = None\n",
813
+ " training_metrics_path = None\n",
814
+ " print(f\"Could not build training metrics table: {e}\")\n",
815
+ "\n",
816
  "print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
817
  "\n",
818
  "# Save LoRA adapter (much smaller than full model)\n",
 
944
  "import matplotlib.pyplot as plt\n",
945
  "import matplotlib.gridspec as gridspec\n",
946
  "import numpy as np\n",
947
+ "import pandas as pd\n",
948
  "import os\n",
949
  "\n",
950
  "os.makedirs(\"results\", exist_ok=True)\n",
951
+ "os.makedirs(\"plots\", exist_ok=True)\n",
952
+ "\n",
953
+ "# Build a TRL-style metrics table from trainer logs. This matches the tabular\n",
954
+ "# output with columns like reward, reward_std, completion lengths, tools, and KL.\n",
955
+ "if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
956
+ " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
957
+ " training_metrics_df = pd.DataFrame(trainer_log_rows if trainer_log_rows else training_log_history)\n",
958
+ " if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
959
+ " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
960
+ "\n",
961
+ "training_metrics_path = \"results/gridmind_training_metrics.csv\"\n",
962
+ "if not training_metrics_df.empty:\n",
963
+ " training_metrics_df.to_csv(training_metrics_path, index=False)\n",
964
+ " print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
965
+ " preferred_cols = [\n",
966
+ " \"step\", \"loss\", \"reward\", \"reward_std\",\n",
967
+ " \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
968
+ " \"completions / clipped_ratio\", \"completions / mean_terminated_length\",\n",
969
+ " \"completions / min_terminated_length\", \"completions / max_terminated_length\",\n",
970
+ " \"tools / call_frequency\", \"tools / failure_frequency\", \"kl\",\n",
971
+ " \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
972
+ " ]\n",
973
+ " display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
974
+ " if display_cols:\n",
975
+ " display(training_metrics_df[display_cols].tail(20))\n",
976
  "\n",
977
  "tasks = [1, 2, 3, 4]\n",
978
  "task_labels = [\n",
 
1070
  "# Panel B: reward signal over time.\n",
1071
  "ax_rew = fig.add_subplot(gs[1, 0])\n",
1072
  "style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
1073
+ "if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
1074
+ " reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
1075
+ " std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
1076
+ " reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
1077
+ " steps_r = reward_df[\"step\"].astype(float).tolist()\n",
1078
+ " raw = reward_df[reward_col].astype(float).tolist()\n",
1079
+ " ax_rew.plot(steps_r, raw, alpha=0.28, color=C['reward'], linewidth=1.2, marker='o', markersize=2, label='Logged reward')\n",
1080
  " ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
1081
+ " if std_col in reward_df.columns:\n",
1082
+ " std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
1083
+ " raw_np = np.array(raw)\n",
1084
+ " steps_np = np.array(steps_r)\n",
1085
+ " ax_rew.fill_between(steps_np, raw_np - std, raw_np + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
1086
  " if len(steps_r) > 8:\n",
1087
  " z = np.polyfit(steps_r, raw, 1)\n",
1088
  " p = np.poly1d(z)\n",
1089
  " ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
1090
  " label=f'Trend ({z[0]:+.5f}/step)')\n",
1091
+ " ax_rew.set_xlabel('Training Step')\n",
1092
  " ax_rew.set_ylabel('Reward Value')\n",
1093
  " ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1094
  " if np.var(raw) < 0.01:\n",
1095
  " ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
1096
  " transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
1097
  " bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
1098
+ "elif training_rewards and len(training_rewards) >= 4:\n",
1099
+ " raw = training_rewards\n",
1100
+ " steps_r = list(range(1, len(raw) + 1))\n",
1101
+ " ax_rew.plot(steps_r, raw, alpha=0.20, color=C['reward'], linewidth=1)\n",
1102
+ " ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
1103
+ " ax_rew.set_xlabel('Reward Function Call')\n",
1104
+ " ax_rew.set_ylabel('Reward Value')\n",
1105
+ " ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1106
  "else:\n",
1107
  " ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
1108
  " transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
 
1150
  "fig.savefig(dashboard_path, dpi=180, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
1151
  "plt.close(fig)\n",
1152
  "\n",
1153
+ "# Standalone training reward curve for reports/slides.\n",
1154
+ "reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
1155
+ "fig_reward, ax_reward = plt.subplots(figsize=(11, 6))\n",
1156
+ "fig_reward.patch.set_facecolor(C['bg'])\n",
1157
+ "style_ax(ax_reward, 'Training Reward Curve')\n",
1158
+ "if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
1159
+ " reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
1160
+ " std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
1161
+ " reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
1162
+ " xs = reward_df[\"step\"].astype(float).to_numpy()\n",
1163
+ " ys = reward_df[reward_col].astype(float).to_numpy()\n",
1164
+ " ax_reward.plot(xs, ys, color=C['reward'], alpha=0.35, linewidth=1.2, marker='o', markersize=2, label='Reward')\n",
1165
+ " ax_reward.plot(xs, smooth(ys.tolist(), window=6), color=C['trained'], linewidth=2.5, label='Smoothed reward')\n",
1166
+ " if std_col in reward_df.columns:\n",
1167
+ " std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
1168
+ " ax_reward.fill_between(xs, ys - std, ys + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
1169
+ " if len(xs) > 8:\n",
1170
+ " z = np.polyfit(xs, ys, 1)\n",
1171
+ " p = np.poly1d(z)\n",
1172
+ " ax_reward.plot(xs, p(xs), '--', color=C['text'], alpha=0.45, linewidth=1.5, label=f'Trend ({z[0]:+.5f}/step)')\n",
1173
+ " ax_reward.set_xlabel('Training Step')\n",
1174
+ " ax_reward.set_ylabel('Reward')\n",
1175
+ " ax_reward.legend(facecolor=C['grid'], labelcolor=C['text'], edgecolor=C['border'])\n",
1176
+ "else:\n",
1177
+ " ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes,\n",
1178
+ " ha='center', va='center', color=C['subtext'])\n",
1179
+ "fig_reward.savefig(reward_curve_path, dpi=180, facecolor=fig_reward.get_facecolor(), bbox_inches='tight')\n",
1180
+ "plt.close(fig_reward)\n",
1181
+ "\n",
1182
+ "# Reference-style simple plots from trainer.state.log_history.\n",
1183
+ "log_history = trainer.state.log_history\n",
1184
+ "simple_steps = []\n",
1185
+ "simple_rewards = []\n",
1186
+ "simple_losses = []\n",
1187
+ "simple_loss_steps = []\n",
1188
+ "\n",
1189
+ "for entry in log_history:\n",
1190
+ " reward_key = \"reward\" if \"reward\" in entry else (\"rewards / reward_func / mean\" if \"rewards / reward_func / mean\" in entry else None)\n",
1191
+ " if reward_key is not None:\n",
1192
+ " simple_steps.append(entry.get(\"step\", len(simple_steps) + 1))\n",
1193
+ " simple_rewards.append(float(entry[reward_key]))\n",
1194
+ " if \"loss\" in entry:\n",
1195
+ " simple_loss_steps.append(entry.get(\"step\", len(simple_loss_steps) + 1))\n",
1196
+ " simple_losses.append(float(entry[\"loss\"]))\n",
1197
+ "\n",
1198
+ "# Plot 1: Reward over training\n",
1199
+ "simple_reward_curve_path = \"plots/reward_curve.png\"\n",
1200
+ "fig_simple_reward, ax_simple_reward = plt.subplots(1, 1, figsize=(10, 5))\n",
1201
+ "if simple_rewards:\n",
1202
+ " ax_simple_reward.plot(simple_steps[:len(simple_rewards)], simple_rewards, color=\"#4285f4\", linewidth=2, label=\"GRPO Reward\")\n",
1203
+ " if len(simple_rewards) > 5:\n",
1204
+ " window = max(3, len(simple_rewards) // 10)\n",
1205
+ " smoothed = [\n",
1206
+ " sum(simple_rewards[max(0, i-window):i+1]) / len(simple_rewards[max(0, i-window):i+1])\n",
1207
+ " for i in range(len(simple_rewards))\n",
1208
+ " ]\n",
1209
+ " ax_simple_reward.plot(simple_steps[:len(smoothed)], smoothed, color=\"#ea4335\", linewidth=2, linestyle=\"--\", label=f\"Smoothed (window={window})\")\n",
1210
+ "else:\n",
1211
+ " ax_simple_reward.text(0.5, 0.5, \"No reward logs found\", transform=ax_simple_reward.transAxes, ha=\"center\", va=\"center\")\n",
1212
+ "ax_simple_reward.set_xlabel(\"Training Step\", fontsize=12)\n",
1213
+ "ax_simple_reward.set_ylabel(\"Reward\", fontsize=12)\n",
1214
+ "ax_simple_reward.set_title(\"GridMind-RL GRPO Training - Reward Curve\", fontsize=14, fontweight=\"bold\")\n",
1215
+ "ax_simple_reward.legend()\n",
1216
+ "ax_simple_reward.grid(True, alpha=0.3)\n",
1217
+ "fig_simple_reward.tight_layout()\n",
1218
+ "fig_simple_reward.savefig(simple_reward_curve_path, dpi=150)\n",
1219
+ "plt.show()\n",
1220
+ "print(f\"Saved: {simple_reward_curve_path}\")\n",
1221
+ "\n",
1222
+ "# Plot 2: Loss over training\n",
1223
+ "simple_loss_curve_path = \"plots/loss_curve.png\"\n",
1224
+ "if simple_losses:\n",
1225
+ " fig_simple_loss, ax_simple_loss = plt.subplots(1, 1, figsize=(10, 5))\n",
1226
+ " ax_simple_loss.plot(simple_loss_steps[:len(simple_losses)], simple_losses, color=\"#34a853\", linewidth=2)\n",
1227
+ " ax_simple_loss.set_xlabel(\"Training Step\", fontsize=12)\n",
1228
+ " ax_simple_loss.set_ylabel(\"Loss\", fontsize=12)\n",
1229
+ " ax_simple_loss.set_title(\"GridMind-RL GRPO Training - Loss Curve\", fontsize=14, fontweight=\"bold\")\n",
1230
+ " ax_simple_loss.grid(True, alpha=0.3)\n",
1231
+ " fig_simple_loss.tight_layout()\n",
1232
+ " fig_simple_loss.savefig(simple_loss_curve_path, dpi=150)\n",
1233
+ " plt.show()\n",
1234
+ " print(f\"Saved: {simple_loss_curve_path}\")\n",
1235
+ "else:\n",
1236
+ " simple_loss_curve_path = None\n",
1237
+ " print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
1238
+ "\n",
1239
  "# Separate before/after comparison graph for quick judge inspection.\n",
1240
  "fig2, ax2 = plt.subplots(figsize=(11, 6))\n",
1241
  "fig2.patch.set_facecolor(C['bg'])\n",
 
1257
  "plt.close(fig2)\n",
1258
  "\n",
1259
  "print(f\"Saved dashboard graph to {dashboard_path}\")\n",
1260
+ "print(f\"Saved training reward curve to {reward_curve_path}\")\n",
1261
+ "print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
1262
+ "if simple_loss_curve_path:\n",
1263
+ " print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
1264
  "print(f\"Saved before/after graph to {comparison_path}\")\n",
1265
  "\n",
1266
  "results = {\n",
 
1279
  " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
1280
  " \"training_step_logs\": training_steps_log[-20:] if training_steps_log else [],\n",
1281
  " \"step_losses\": step_losses if 'step_losses' in globals() else [],\n",
1282
+ " \"training_metrics_table\": training_metrics_path,\n",
1283
  " \"graphs\": {\n",
1284
  " \"dashboard\": dashboard_path,\n",
1285
+ " \"training_reward_curve\": reward_curve_path,\n",
1286
+ " \"simple_reward_curve\": simple_reward_curve_path,\n",
1287
+ " \"simple_loss_curve\": simple_loss_curve_path,\n",
1288
  " \"before_after\": comparison_path,\n",
1289
  " },\n",
1290
  "}\n",