Prajwal782007 committed on
Commit
29b9cd0
·
1 Parent(s): 9d42d14

feat: add GridMind GRPO training notebook for Meta PyTorch OpenEnv hackathon

Browse files
Files changed (1) hide show
  1. scripts/gridmind_grpo_colab.ipynb +114 -195
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -695,6 +695,7 @@
695
  " \"logging_steps\": 1,\n",
696
  " \"save_steps\": 60,\n",
697
  " \"report_to\": \"none\",\n",
 
698
  " \"dataloader_num_workers\": 0,\n",
699
  " \"remove_unused_columns\": False,\n",
700
  "}\n",
@@ -732,6 +733,53 @@
732
  "step_numbers = []\n",
733
  "step_reward_means = []\n",
734
  "training_log_history = []\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
  "\n",
736
  "class LossCaptureCallback(TrainerCallback):\n",
737
  " def on_log(self, args, state, control, logs=None, **kwargs):\n",
@@ -740,7 +788,25 @@
740
  " step = state.global_step\n",
741
  " row = {\"step\": step}\n",
742
  " row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
  " training_log_history.append(row)\n",
 
 
744
  " loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
745
  " if loss is not None:\n",
746
  " step_losses.append(float(loss))\n",
@@ -767,6 +833,17 @@
767
  " callbacks=[LossCaptureCallback()],\n",
768
  ")\n",
769
  "\n",
 
 
 
 
 
 
 
 
 
 
 
770
  "print(\"\\nStarting GRPO training with QLoRA...\")\n",
771
  "print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
772
  "print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
@@ -792,7 +869,9 @@
792
  " import pandas as pd\n",
793
  " import numpy as np\n",
794
  " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
795
- " if trainer_log_rows:\n",
 
 
796
  " training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
797
  " if \"step\" not in training_metrics_df.columns:\n",
798
  " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
@@ -972,7 +1051,6 @@
972
  "import matplotlib\n",
973
  "matplotlib.use('Agg')\n",
974
  "import matplotlib.pyplot as plt\n",
975
- "import matplotlib.gridspec as gridspec\n",
976
  "import numpy as np\n",
977
  "import pandas as pd\n",
978
  "import os\n",
@@ -980,11 +1058,13 @@
980
  "os.makedirs(\"results\", exist_ok=True)\n",
981
  "os.makedirs(\"plots\", exist_ok=True)\n",
982
  "\n",
983
- "# Build a TRL-style metrics table from trainer logs. This matches the tabular\n",
984
- "# output with columns like reward, reward_std, completion lengths, tools, and KL.\n",
985
  "if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
986
  " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
987
- " training_metrics_df = pd.DataFrame(trainer_log_rows if trainer_log_rows else training_log_history)\n",
 
 
 
988
  " if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
989
  " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
990
  "\n",
@@ -992,17 +1072,7 @@
992
  "if not training_metrics_df.empty:\n",
993
  " training_metrics_df.to_csv(training_metrics_path, index=False)\n",
994
  " print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
995
- " preferred_cols = [\n",
996
- " \"step\", \"loss\", \"reward\", \"reward_std\",\n",
997
- " \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
998
- " \"completions / clipped_ratio\", \"completions / mean_terminated_length\",\n",
999
- " \"completions / min_terminated_length\", \"completions / max_terminated_length\",\n",
1000
- " \"tools / call_frequency\", \"tools / failure_frequency\", \"kl\",\n",
1001
- " \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
1002
- " ]\n",
1003
- " display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
1004
- " if display_cols:\n",
1005
- " print(\"Step 6 already displayed the full TRL-style metrics table; Step 8 only reuses it for plots.\")\n",
1006
  "\n",
1007
  "tasks = [1, 2, 3, 4]\n",
1008
  "task_labels = [\n",
@@ -1034,179 +1104,31 @@
1034
  " out.append(sum(w) / len(w))\n",
1035
  " return out\n",
1036
  "\n",
1037
- "C = {\n",
1038
- " 'bg': '#0d1117', 'panel': '#161b22', 'grid': '#21262d',\n",
1039
- " 'text': '#e6edf3', 'subtext': '#8b949e', 'random': '#f85149',\n",
1040
- " 'heuristic': '#58a6ff', 'trained': '#3fb950', 'reward': '#d29922',\n",
1041
- " 'loss': '#bc8cff', 'border': '#30363d',\n",
1042
- "}\n",
1043
- "\n",
1044
- "def style_ax(ax, title):\n",
1045
- " ax.set_facecolor(C['panel'])\n",
1046
- " ax.set_title(title, color=C['text'], fontsize=12, fontweight='bold', pad=10)\n",
1047
- " ax.tick_params(colors=C['subtext'], labelsize=9)\n",
1048
- " ax.grid(alpha=0.15, color=C['grid'], linewidth=0.8)\n",
1049
- " for spine in ax.spines.values():\n",
1050
- " spine.set_edgecolor(C['border'])\n",
1051
- " ax.xaxis.label.set_color(C['subtext'])\n",
1052
- " ax.yaxis.label.set_color(C['subtext'])\n",
1053
- "\n",
1054
- "fig = plt.figure(figsize=(12, 8))\n",
1055
- "fig.patch.set_facecolor(C['bg'])\n",
1056
- "gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.50, wspace=0.38,\n",
1057
- " left=0.07, right=0.97, top=0.91, bottom=0.07)\n",
1058
- "\n",
1059
- "# Panel A: policy comparison across all tasks.\n",
1060
- "ax_bar = fig.add_subplot(gs[0, :])\n",
1061
- "ax_bar.set_facecolor(C['panel'])\n",
1062
- "x = np.arange(len(tasks))\n",
1063
- "w = 0.24\n",
1064
- "br = ax_bar.bar(x - w, random_vals, w, label='Random Policy', color=C['random'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
1065
- "bh = ax_bar.bar(x, heuristic_vals, w, label='Heuristic Baseline', color=C['heuristic'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
1066
- "bt = ax_bar.bar(x + w, trained_vals, w, label='Trained LLM (GRPO)', color=C['trained'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
1067
- "\n",
1068
- "for bars, col in [(br, C['random']), (bh, C['heuristic']), (bt, C['trained'])]:\n",
1069
- " for bar in bars:\n",
1070
- " h = bar.get_height()\n",
1071
- " ax_bar.text(bar.get_x() + bar.get_width()/2, h + 0.012, f'{h:.3f}',\n",
1072
- " ha='center', va='bottom', fontsize=8.5, color=col, fontweight='bold', zorder=4)\n",
1073
- "\n",
1074
- "for i in range(len(tasks)):\n",
1075
- " h_val = heuristic_vals[i]\n",
1076
- " t_val = trained_vals[i]\n",
1077
- " pct = ((t_val - h_val) / h_val * 100) if h_val > 0 else 0\n",
1078
- " color = C['trained'] if pct >= 0 else C['random']\n",
1079
- " sign = '+' if pct >= 0 else '-'\n",
1080
- " ax_bar.text(x[i] + w, max(h_val, t_val) + 0.06, f'{sign}{abs(pct):.1f}%',\n",
1081
- " ha='center', fontsize=10, color=color, fontweight='bold', zorder=4)\n",
1082
- "\n",
1083
- "ax_bar.axhline(baseline_avg, color=C['heuristic'], linestyle=':', linewidth=1.5, alpha=0.6,\n",
1084
- " label=f'Heuristic avg ({baseline_avg:.3f})', zorder=2)\n",
1085
- "ax_bar.axhline(trained_avg, color=C['trained'], linestyle=':', linewidth=1.5, alpha=0.6,\n",
1086
- " label=f'Trained avg ({trained_avg:.3f})', zorder=2)\n",
1087
- "ax_bar.set_xticks(x)\n",
1088
- "ax_bar.set_xticklabels(task_labels, color=C['text'], fontsize=10)\n",
1089
- "ax_bar.set_ylabel('Grade Score (0.0 to 1.0, higher is better)', fontsize=11, color=C['subtext'])\n",
1090
- "ax_bar.set_ylim(0, 1.15)\n",
1091
- "ax_bar.set_title('GridMind-RL Policy Performance Across All 4 Hackathon Themes\\nRandom vs Heuristic Baseline vs GRPO Fine-Tuned LLM',\n",
1092
- " color=C['text'], fontsize=13, fontweight='bold', pad=12)\n",
1093
- "ax_bar.legend(fontsize=10, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9,\n",
1094
- " edgecolor=C['border'], ncol=3, loc='upper right')\n",
1095
- "ax_bar.grid(axis='y', alpha=0.15, color=C['grid'], zorder=1)\n",
1096
- "for spine in ax_bar.spines.values():\n",
1097
- " spine.set_edgecolor(C['border'])\n",
1098
- "ax_bar.tick_params(colors=C['subtext'])\n",
1099
- "\n",
1100
- "# Panel B: reward signal over time.\n",
1101
- "ax_rew = fig.add_subplot(gs[1, 0])\n",
1102
- "style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
1103
- "if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
1104
- " reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
1105
- " std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
1106
- " reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
1107
- " steps_r = reward_df[\"step\"].astype(float).tolist()\n",
1108
- " raw = reward_df[reward_col].astype(float).tolist()\n",
1109
- " ax_rew.plot(steps_r, raw, alpha=0.28, color=C['reward'], linewidth=1.2, marker='o', markersize=2, label='Logged reward')\n",
1110
- " ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
1111
- " if std_col in reward_df.columns:\n",
1112
- " std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
1113
- " raw_np = np.array(raw)\n",
1114
- " steps_np = np.array(steps_r)\n",
1115
- " ax_rew.fill_between(steps_np, raw_np - std, raw_np + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
1116
- " if len(steps_r) > 8:\n",
1117
- " z = np.polyfit(steps_r, raw, 1)\n",
1118
- " p = np.poly1d(z)\n",
1119
- " ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
1120
- " label=f'Trend ({z[0]:+.5f}/step)')\n",
1121
- " ax_rew.set_xlabel('Training Step')\n",
1122
- " ax_rew.set_ylabel('Reward Value')\n",
1123
- " ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1124
- " if np.var(raw) < 0.01:\n",
1125
- " ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
1126
- " transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
1127
- " bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
1128
- "elif training_rewards and len(training_rewards) >= 4:\n",
1129
- " raw = training_rewards\n",
1130
- " steps_r = list(range(1, len(raw) + 1))\n",
1131
- " ax_rew.plot(steps_r, raw, alpha=0.20, color=C['reward'], linewidth=1)\n",
1132
- " ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
1133
- " ax_rew.set_xlabel('Reward Function Call')\n",
1134
- " ax_rew.set_ylabel('Reward Value')\n",
1135
- " ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1136
- "else:\n",
1137
- " ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
1138
- " transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
1139
- "\n",
1140
- "# Panel C: training loss, with reward variance fallback.\n",
1141
- "ax_loss = fig.add_subplot(gs[1, 1])\n",
1142
- "style_ax(ax_loss, 'GRPO Training Loss per Step')\n",
1143
- "if step_losses and len(step_losses) >= 2:\n",
1144
- " ax_loss.plot(step_numbers, step_losses, alpha=0.25, color=C['loss'], linewidth=1)\n",
1145
- " ax_loss.plot(step_numbers, smooth(step_losses, window=4), color=C['loss'], linewidth=2.5, label='Smoothed loss')\n",
1146
- " non_zero = [l for l in step_losses if abs(l) > 1e-7]\n",
1147
- " pct_nz = len(non_zero) / len(step_losses) * 100\n",
1148
- " note_color = C['trained'] if pct_nz > 50 else C['random']\n",
1149
- " ax_loss.text(0.04, 0.96, f'Non-zero steps: {len(non_zero)}/{len(step_losses)} ({pct_nz:.0f}%)',\n",
1150
- " transform=ax_loss.transAxes, va='top', color=note_color, fontsize=9,\n",
1151
- " bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
1152
- " ax_loss.set_xlabel('Training Step')\n",
1153
- " ax_loss.set_ylabel('Loss')\n",
1154
- " ax_loss.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1155
- "else:\n",
1156
- " proxy_loss = []\n",
1157
- " for i in range(0, len(training_rewards), 4):\n",
1158
- " chunk = training_rewards[i:i+4]\n",
1159
- " if len(chunk) > 1:\n",
1160
- " proxy_loss.append(float(np.var(chunk)))\n",
1161
- " if proxy_loss:\n",
1162
- " ax_loss.plot(range(1, len(proxy_loss) + 1), proxy_loss, color=C['loss'], linewidth=2,\n",
1163
- " label='Reward variance proxy')\n",
1164
- " ax_loss.set_xlabel('Training Batch')\n",
1165
- " ax_loss.set_ylabel('Reward Variance')\n",
1166
- " ax_loss.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
1167
- " ax_loss.text(0.5, 0.92, 'Loss not captured - showing reward variance proxy',\n",
1168
- " transform=ax_loss.transAxes, ha='center', color=C['subtext'], fontsize=8)\n",
1169
- " else:\n",
1170
- " ax_loss.text(0.5, 0.5, 'No loss data available.', transform=ax_loss.transAxes,\n",
1171
- " ha='center', va='center', color=C['subtext'], fontsize=11)\n",
1172
- "\n",
1173
- "fig.suptitle(\n",
1174
- " 'GridMind-RL - Meta OpenEnv Hackathon - Multi-Agent Industrial Energy Management\\n'\n",
1175
- " f'Model: Qwen2.5-1.5B + QLoRA + GRPO | Overall improvement vs heuristic: {overall_improvement:+.1f}%',\n",
1176
- " color=C['text'], fontsize=14, fontweight='bold', y=0.97\n",
1177
- ")\n",
1178
- "\n",
1179
- "dashboard_path = 'results/gridmind_training_dashboard.png'\n",
1180
- "fig.savefig(dashboard_path, dpi=100, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
1181
- "plt.close(fig)\n",
1182
- "\n",
1183
- "# Standalone training reward curve for reports/slides.\n",
1184
  "reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
1185
- "fig_reward, ax_reward = plt.subplots(figsize=(11, 6))\n",
1186
- "fig_reward.patch.set_facecolor(C['bg'])\n",
1187
- "style_ax(ax_reward, 'Training Reward Curve')\n",
1188
  "if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
1189
  " reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
1190
  " std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
1191
  " reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
1192
  " xs = reward_df[\"step\"].astype(float).to_numpy()\n",
1193
  " ys = reward_df[reward_col].astype(float).to_numpy()\n",
1194
- " ax_reward.plot(xs, ys, color=C['reward'], alpha=0.35, linewidth=1.2, marker='o', markersize=2, label='Reward')\n",
1195
- " ax_reward.plot(xs, smooth(ys.tolist(), window=6), color=C['trained'], linewidth=2.5, label='Smoothed reward')\n",
 
 
 
1196
  " if std_col in reward_df.columns:\n",
1197
  " std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
1198
- " ax_reward.fill_between(xs, ys - std, ys + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
1199
- " if len(xs) > 8:\n",
1200
- " z = np.polyfit(xs, ys, 1)\n",
1201
- " p = np.poly1d(z)\n",
1202
- " ax_reward.plot(xs, p(xs), '--', color=C['text'], alpha=0.45, linewidth=1.5, label=f'Trend ({z[0]:+.5f}/step)')\n",
1203
- " ax_reward.set_xlabel('Training Step')\n",
1204
- " ax_reward.set_ylabel('Reward')\n",
1205
- " ax_reward.legend(facecolor=C['grid'], labelcolor=C['text'], edgecolor=C['border'])\n",
1206
  "else:\n",
1207
- " ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes,\n",
1208
- " ha='center', va='center', color=C['subtext'])\n",
1209
- "fig_reward.savefig(reward_curve_path, dpi=100, facecolor=fig_reward.get_facecolor(), bbox_inches='tight')\n",
 
 
 
 
 
1210
  "plt.close(fig_reward)\n",
1211
  "\n",
1212
  "# Reference-style simple plots from trainer.state.log_history.\n",
@@ -1267,26 +1189,23 @@
1267
  " print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
1268
  "\n",
1269
  "# Separate before/after comparison graph for quick judge inspection.\n",
1270
- "fig2, ax2 = plt.subplots(figsize=(11, 6))\n",
1271
- "fig2.patch.set_facecolor(C['bg'])\n",
1272
- "ax2.set_facecolor(C['panel'])\n",
1273
- "ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=C['heuristic'], alpha=0.9)\n",
1274
- "ax2.bar(x + w/2, trained_vals, w, label='Trained LLM (GRPO)', color=C['trained'], alpha=0.9)\n",
1275
  "ax2.set_xticks(x)\n",
1276
- "ax2.set_xticklabels(task_labels, color=C['text'])\n",
1277
  "ax2.set_ylim(0, 1.05)\n",
1278
- "ax2.set_ylabel('Grade Score', color=C['subtext'])\n",
1279
- "ax2.set_title('Before/After Policy Score Comparison', color=C['text'], fontweight='bold')\n",
1280
- "ax2.legend(facecolor=C['grid'], labelcolor=C['text'], edgecolor=C['border'])\n",
1281
- "ax2.grid(axis='y', alpha=0.15, color=C['grid'])\n",
1282
- "ax2.tick_params(colors=C['subtext'])\n",
1283
- "for spine in ax2.spines.values():\n",
1284
- " spine.set_edgecolor(C['border'])\n",
1285
  "comparison_path = 'results/gridmind_before_after_comparison.png'\n",
1286
- "fig2.savefig(comparison_path, dpi=100, facecolor=fig2.get_facecolor(), bbox_inches='tight')\n",
1287
  "plt.close(fig2)\n",
1288
  "\n",
1289
- "print(f\"Saved dashboard graph to {dashboard_path}\")\n",
1290
  "print(f\"Saved training reward curve to {reward_curve_path}\")\n",
1291
  "print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
1292
  "if simple_loss_curve_path:\n",
@@ -1312,7 +1231,7 @@
1312
  " \"training_metrics_table\": training_metrics_path,\n",
1313
  " \"training_metrics_display_table\": training_metrics_display_path if 'training_metrics_display_path' in globals() else None,\n",
1314
  " \"graphs\": {\n",
1315
- " \"dashboard\": dashboard_path,\n",
1316
  " \"training_reward_curve\": reward_curve_path,\n",
1317
  " \"simple_reward_curve\": simple_reward_curve_path,\n",
1318
  " \"simple_loss_curve\": simple_loss_curve_path,\n",
 
695
  " \"logging_steps\": 1,\n",
696
  " \"save_steps\": 60,\n",
697
  " \"report_to\": \"none\",\n",
698
+ " \"disable_tqdm\": True,\n",
699
  " \"dataloader_num_workers\": 0,\n",
700
  " \"remove_unused_columns\": False,\n",
701
  "}\n",
 
733
  "step_numbers = []\n",
734
  "step_reward_means = []\n",
735
  "training_log_history = []\n",
736
+ "training_table_rows = []\n",
737
+ "_training_table_header_printed = [False]\n",
738
+ "\n",
739
+ "TRAINING_TABLE_COLUMNS = [\n",
740
+ " (\"Step\", \"step\"),\n",
741
+ " (\"Training Loss\", \"loss\"),\n",
742
+ " (\"reward\", \"reward\"),\n",
743
+ " (\"reward_std\", \"reward_std\"),\n",
744
+ " (\"completions / mean_length\", \"completions / mean_length\"),\n",
745
+ " (\"completions / min_length\", \"completions / min_length\"),\n",
746
+ " (\"completions / max_length\", \"completions / max_length\"),\n",
747
+ " (\"completions / clipped_ratio\", \"completions / clipped_ratio\"),\n",
748
+ " (\"completions / mean_terminated_length\", \"completions / mean_terminated_length\"),\n",
749
+ " (\"completions / min_terminated_length\", \"completions / min_terminated_length\"),\n",
750
+ " (\"completions / max_terminated_length\", \"completions / max_terminated_length\"),\n",
751
+ " (\"tools / call_frequency\", \"tools / call_frequency\"),\n",
752
+ " (\"tools / failure_frequency\", \"tools / failure_frequency\"),\n",
753
+ " (\"kl\", \"kl\"),\n",
754
+ " (\"rewards / reward_func / mean\", \"rewards / reward_func / mean\"),\n",
755
+ " (\"rewards / reward_func / std\", \"rewards / reward_func / std\"),\n",
756
+ "]\n",
757
+ "\n",
758
+ "def _metric_value(logs, *keys, default=float(\"nan\")):\n",
759
+ " for key in keys:\n",
760
+ " if key in logs and logs[key] is not None:\n",
761
+ " return logs[key]\n",
762
+ " return default\n",
763
+ "\n",
764
+ "def _fmt_metric(value):\n",
765
+ " try:\n",
766
+ " if value is None or (isinstance(value, float) and value != value):\n",
767
+ " return \"\"\n",
768
+ " if isinstance(value, int):\n",
769
+ " return str(value)\n",
770
+ " return f\"{float(value):.6f}\"\n",
771
+ " except Exception:\n",
772
+ " return str(value)\n",
773
+ "\n",
def _print_training_table_row(row):
    """Print one metrics row, preceded by the header on the first call.

    ``row`` maps metric names to values. Columns and their order come from
    ``TRAINING_TABLE_COLUMNS``; a metric missing from ``row`` is looked up as
    NaN and formatted by ``_fmt_metric``. The header is printed once, tracked
    through the module-level ``_training_table_header_printed`` flag.
    """
    preset_widths = [6, 14, 10, 10, 26, 25, 25, 29, 38, 37, 37, 24, 27, 10, 28, 27]
    # Resolve one width per column. A column with no preset entry, or whose
    # label is longer than its preset, uses the label length so that extending
    # TRAINING_TABLE_COLUMNS cannot raise IndexError or misalign the rows.
    widths = [
        max(preset_widths[i] if i < len(preset_widths) else 0, len(label))
        for i, (label, _) in enumerate(TRAINING_TABLE_COLUMNS)
    ]
    if not _training_table_header_printed[0]:
        header = " ".join(
            label.ljust(width)
            for (label, _), width in zip(TRAINING_TABLE_COLUMNS, widths)
        )
        print("\n" + header)
        print("-" * len(header))
        _training_table_header_printed[0] = True
    cells = [
        _fmt_metric(row.get(source, float("nan"))).ljust(width)
        for (_, source), width in zip(TRAINING_TABLE_COLUMNS, widths)
    ]
    print(" ".join(cells))
783
  "\n",
784
  "class LossCaptureCallback(TrainerCallback):\n",
785
  " def on_log(self, args, state, control, logs=None, **kwargs):\n",
 
788
  " step = state.global_step\n",
789
  " row = {\"step\": step}\n",
790
  " row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
791
+ " if \"loss\" not in row and \"train_loss\" in row:\n",
792
+ " row[\"loss\"] = row[\"train_loss\"]\n",
793
+ " recent_rewards = training_rewards[max(0, len(training_rewards)-4):]\n",
794
+ " if recent_rewards:\n",
795
+ " if \"reward\" not in row and \"rewards / reward_func / mean\" not in row:\n",
796
+ " row[\"reward\"] = sum(recent_rewards) / len(recent_rewards)\n",
797
+ " if \"reward_std\" not in row and \"rewards / reward_func / std\" not in row and len(recent_rewards) > 1:\n",
798
+ " row[\"reward_std\"] = statistics.pstdev(recent_rewards)\n",
799
+ " if \"rewards / reward_func / mean\" not in row and \"reward\" in row:\n",
800
+ " row[\"rewards / reward_func / mean\"] = row[\"reward\"]\n",
801
+ " if \"rewards / reward_func / std\" not in row and \"reward_std\" in row:\n",
802
+ " row[\"rewards / reward_func / std\"] = row[\"reward_std\"]\n",
803
+ " if \"tools / call_frequency\" not in row:\n",
804
+ " row[\"tools / call_frequency\"] = float(\"nan\")\n",
805
+ " if \"tools / failure_frequency\" not in row:\n",
806
+ " row[\"tools / failure_frequency\"] = 0.0\n",
807
  " training_log_history.append(row)\n",
808
+ " training_table_rows.append(row)\n",
809
+ " _print_training_table_row(row)\n",
810
  " loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
811
  " if loss is not None:\n",
812
  " step_losses.append(float(loss))\n",
 
833
  " callbacks=[LossCaptureCallback()],\n",
834
  ")\n",
835
  "\n",
836
+ "# Remove the default Trainer progress/notebook callbacks so only the custom\n",
837
+ "# TRL-style table appears during training.\n",
838
+ "from transformers.trainer_callback import ProgressCallback, PrinterCallback\n",
839
+ "trainer.remove_callback(ProgressCallback)\n",
840
+ "trainer.remove_callback(PrinterCallback)\n",
841
+ "try:\n",
842
+ " from transformers.utils.notebook import NotebookProgressCallback\n",
843
+ " trainer.remove_callback(NotebookProgressCallback)\n",
844
+ "except Exception:\n",
845
+ " pass\n",
846
+ "\n",
847
  "print(\"\\nStarting GRPO training with QLoRA...\")\n",
848
  "print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
849
  "print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
 
869
  " import pandas as pd\n",
870
  " import numpy as np\n",
871
  " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
872
+ " if training_table_rows:\n",
873
+ " training_metrics_df = pd.DataFrame(training_table_rows)\n",
874
+ " elif trainer_log_rows:\n",
875
  " training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
876
  " if \"step\" not in training_metrics_df.columns:\n",
877
  " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
 
1051
  "import matplotlib\n",
1052
  "matplotlib.use('Agg')\n",
1053
  "import matplotlib.pyplot as plt\n",
 
1054
  "import numpy as np\n",
1055
  "import pandas as pd\n",
1056
  "import os\n",
 
1058
  "os.makedirs(\"results\", exist_ok=True)\n",
1059
  "os.makedirs(\"plots\", exist_ok=True)\n",
1060
  "\n",
1061
+ "# Reuse the Step 6 metrics table and only do lightweight exports here.\n",
 
1062
  "if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
1063
  " trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
1064
+ " if 'training_table_rows' in globals() and training_table_rows:\n",
1065
+ " training_metrics_df = pd.DataFrame(training_table_rows)\n",
1066
+ " else:\n",
1067
+ " training_metrics_df = pd.DataFrame(trainer_log_rows if trainer_log_rows else training_log_history)\n",
1068
  " if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
1069
  " training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
1070
  "\n",
 
1072
  "if not training_metrics_df.empty:\n",
1073
  " training_metrics_df.to_csv(training_metrics_path, index=False)\n",
1074
  " print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
1075
+ " print(\"Step 8 reuses the Step 6 table and only saves files.\")\n",
 
 
 
 
 
 
 
 
 
 
1076
  "\n",
1077
  "tasks = [1, 2, 3, 4]\n",
1078
  "task_labels = [\n",
 
1104
  " out.append(sum(w) / len(w))\n",
1105
  " return out\n",
1106
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1107
  "reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
1108
+ "fig_reward, ax_reward = plt.subplots(figsize=(10, 5))\n",
 
 
1109
  "if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
1110
  " reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
1111
  " std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
1112
  " reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
1113
  " xs = reward_df[\"step\"].astype(float).to_numpy()\n",
1114
  " ys = reward_df[reward_col].astype(float).to_numpy()\n",
1115
+ " ax_reward.plot(xs, ys, color=\"#4285f4\", linewidth=2, label=\"GRPO Reward\")\n",
1116
+ " if len(ys) > 5:\n",
1117
+ " window = max(3, len(ys) // 10)\n",
1118
+ " smoothed = [sum(ys[max(0, i-window):i+1]) / len(ys[max(0, i-window):i+1]) for i in range(len(ys))]\n",
1119
+ " ax_reward.plot(xs[:len(smoothed)], smoothed, color=\"#ea4335\", linewidth=2, linestyle=\"--\", label=f\"Smoothed (window={window})\")\n",
1120
  " if std_col in reward_df.columns:\n",
1121
  " std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
1122
+ " ax_reward.fill_between(xs, ys - std, ys + std, color=\"#4285f4\", alpha=0.12)\n",
 
 
 
 
 
 
 
1123
  "else:\n",
1124
+ " ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes, ha='center', va='center')\n",
1125
+ "ax_reward.set_xlabel('Training Step', fontsize=12)\n",
1126
+ "ax_reward.set_ylabel('Reward', fontsize=12)\n",
1127
+ "ax_reward.set_title('GridMind-RL GRPO Training - Reward Curve', fontsize=14, fontweight='bold')\n",
1128
+ "ax_reward.legend()\n",
1129
+ "ax_reward.grid(True, alpha=0.3)\n",
1130
+ "fig_reward.tight_layout()\n",
1131
+ "fig_reward.savefig(reward_curve_path, dpi=100)\n",
1132
  "plt.close(fig_reward)\n",
1133
  "\n",
1134
  "# Reference-style simple plots from trainer.state.log_history.\n",
 
1189
  " print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
1190
  "\n",
1191
  "# Separate before/after comparison graph for quick judge inspection.\n",
1192
+ "fig2, ax2 = plt.subplots(figsize=(10, 5))\n",
1193
+ "x = np.arange(len(tasks))\n",
1194
+ "w = 0.35\n",
1195
+ "ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
1196
+ "ax2.bar(x + w/2, trained_vals, w, label='Trained LLM (GRPO)', color=\"#3fb950\", alpha=0.9)\n",
1197
  "ax2.set_xticks(x)\n",
1198
+ "ax2.set_xticklabels(task_labels)\n",
1199
  "ax2.set_ylim(0, 1.05)\n",
1200
+ "ax2.set_ylabel('Grade Score')\n",
1201
+ "ax2.set_title('Before/After Policy Score Comparison', fontweight='bold')\n",
1202
+ "ax2.legend()\n",
1203
+ "ax2.grid(axis='y', alpha=0.3)\n",
1204
+ "fig2.tight_layout()\n",
 
 
1205
  "comparison_path = 'results/gridmind_before_after_comparison.png'\n",
1206
+ "fig2.savefig(comparison_path, dpi=100)\n",
1207
  "plt.close(fig2)\n",
1208
  "\n",
 
1209
  "print(f\"Saved training reward curve to {reward_curve_path}\")\n",
1210
  "print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
1211
  "if simple_loss_curve_path:\n",
 
1231
  " \"training_metrics_table\": training_metrics_path,\n",
1232
  " \"training_metrics_display_table\": training_metrics_display_path if 'training_metrics_display_path' in globals() else None,\n",
1233
  " \"graphs\": {\n",
1234
+ " \"dashboard\": None,\n",
1235
  " \"training_reward_curve\": reward_curve_path,\n",
1236
  " \"simple_reward_curve\": simple_reward_curve_path,\n",
1237
  " \"simple_loss_curve\": simple_loss_curve_path,\n",