Prajwal782007 committed on
Commit
505323f
·
1 Parent(s): 29b9cd0

feat: add GRPO training notebook for GridMind-RL environment

Browse files
Files changed (1) hide show
  1. scripts/gridmind_grpo_colab.ipynb +54 -32
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -1007,30 +1007,39 @@
1007
  " except:\n",
1008
  " return 0.0\n",
1009
  "\n",
1010
- "print(\"Evaluating trained model (2 episodes per task)...\")\n",
1011
- "trained_scores = {}\n",
1012
- "for task_id in [1, 2, 3, 4]:\n",
1013
- " scores = []\n",
1014
- " for ep in range(2):\n",
1015
- " score = run_llm_episode(task_id=task_id)\n",
1016
- " scores.append(score)\n",
1017
- " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
1018
- " trained_scores[task_id] = sum(scores) / len(scores)\n",
1019
- "\n",
1020
- "print(f\"\\nTrained Model Scores:\")\n",
1021
- "for task_id, avg in trained_scores.items():\n",
1022
- " baseline = baseline_scores[task_id]\n",
1023
- " improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
1024
- " print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
1025
  "\n",
1026
- "trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
1027
  "baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
1028
- "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
1029
- "\n",
1030
- "print(f\"\\nOverall Scores:\")\n",
1031
- "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
1032
- "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
1033
- "print(f\" Improvement: {overall_improvement:+.1f}%\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1034
  ]
1035
  },
1036
  {
@@ -1084,16 +1093,17 @@
1084
  "\n",
1085
  "random_by_task = {1: 0.35, 2: 0.28, 3: 0.21, 4: 0.25}\n",
1086
  "heuristic_by_task = baseline_scores\n",
1087
- "trained_by_task = trained_scores\n",
1088
  "\n",
1089
  "random_vals = [random_by_task.get(t, 0.3) for t in tasks]\n",
1090
  "heuristic_vals = [heuristic_by_task.get(t, 0.5) for t in tasks]\n",
1091
- "trained_vals = [trained_by_task.get(t, 0.5) for t in tasks]\n",
1092
  "\n",
1093
  "baseline_avg = sum(heuristic_vals) / len(heuristic_vals)\n",
1094
- "trained_avg = sum(trained_vals) / len(trained_vals)\n",
 
1095
  "random_avg = sum(random_vals) / len(random_vals)\n",
1096
- "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
1097
  "\n",
1098
  "def smooth(values, window=5):\n",
1099
  " if not values or len(values) < 2:\n",
@@ -1193,7 +1203,9 @@
1193
  "x = np.arange(len(tasks))\n",
1194
  "w = 0.35\n",
1195
  "ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
1196
- "ax2.bar(x + w/2, trained_vals, w, label='Trained LLM (GRPO)', color=\"#3fb950\", alpha=0.9)\n",
 
 
1197
  "ax2.set_xticks(x)\n",
1198
  "ax2.set_xticklabels(task_labels)\n",
1199
  "ax2.set_ylim(0, 1.05)\n",
@@ -1203,14 +1215,20 @@
1203
  "ax2.grid(axis='y', alpha=0.3)\n",
1204
  "fig2.tight_layout()\n",
1205
  "comparison_path = 'results/gridmind_before_after_comparison.png'\n",
1206
- "fig2.savefig(comparison_path, dpi=100)\n",
 
 
 
1207
  "plt.close(fig2)\n",
1208
  "\n",
1209
  "print(f\"Saved training reward curve to {reward_curve_path}\")\n",
1210
  "print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
1211
  "if simple_loss_curve_path:\n",
1212
  " print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
1213
- "print(f\"Saved before/after graph to {comparison_path}\")\n",
 
 
 
1214
  "\n",
1215
  "results = {\n",
1216
  " \"heuristic_baseline\": {\n",
@@ -1218,7 +1236,7 @@
1218
  " \"average\": baseline_avg\n",
1219
  " },\n",
1220
  " \"trained_llm\": {\n",
1221
- " \"scores_by_task\": {str(k): v for k, v in trained_scores.items()},\n",
1222
  " \"average\": trained_avg\n",
1223
  " },\n",
1224
  " \"improvement_percent\": overall_improvement,\n",
@@ -1248,8 +1266,12 @@
1248
  "print(f\" Model: {MODEL_NAME}\")\n",
1249
  "print(f\" Themes: {results['themes_covered']}\")\n",
1250
  "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
1251
- "print(f\" Trained LLM: {trained_avg:.3f}\")\n",
1252
- "print(f\" Improvement: {overall_improvement:+.1f}%\")"
 
 
 
 
1253
  ]
1254
  }
1255
  ],
 
1007
  " except:\n",
1008
  " return 0.0\n",
1009
  "\n",
1010
+ "RUN_EVALUATION = False\n",
1011
+ "EVAL_EPISODES_PER_TASK = 1\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
  "\n",
1013
+ "trained_scores = {}\n",
1014
  "baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
1015
+ "trained_avg = None\n",
1016
+ "overall_improvement = None\n",
1017
+ "\n",
1018
+ "if RUN_EVALUATION:\n",
1019
+ " print(f\"Evaluating trained model ({EVAL_EPISODES_PER_TASK} episode(s) per task)...\")\n",
1020
+ " for task_id in [1, 2, 3, 4]:\n",
1021
+ " scores = []\n",
1022
+ " for ep in range(EVAL_EPISODES_PER_TASK):\n",
1023
+ " score = run_llm_episode(task_id=task_id)\n",
1024
+ " scores.append(score)\n",
1025
+ " print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
1026
+ " trained_scores[task_id] = sum(scores) / len(scores)\n",
1027
+ "\n",
1028
+ " print(f\"\\nTrained Model Scores:\")\n",
1029
+ " for task_id, avg in trained_scores.items():\n",
1030
+ " baseline = baseline_scores[task_id]\n",
1031
+ " improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
1032
+ " print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
1033
+ "\n",
1034
+ " trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
1035
+ " overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
1036
+ "\n",
1037
+ " print(f\"\\nOverall Scores:\")\n",
1038
+ " print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
1039
+ " print(f\" Trained LLM: {trained_avg:.3f}\")\n",
1040
+ " print(f\" Improvement: {overall_improvement:+.1f}%\")\n",
1041
+ "else:\n",
1042
+ " print(\"Skipping trained-model evaluation. Set RUN_EVALUATION = True to generate trained_scores and improvement metrics.\")"
1043
  ]
1044
  },
1045
  {
 
1093
  "\n",
1094
  "random_by_task = {1: 0.35, 2: 0.28, 3: 0.21, 4: 0.25}\n",
1095
  "heuristic_by_task = baseline_scores\n",
1096
+ "trained_by_task = trained_scores if trained_scores else {}\n",
1097
  "\n",
1098
  "random_vals = [random_by_task.get(t, 0.3) for t in tasks]\n",
1099
  "heuristic_vals = [heuristic_by_task.get(t, 0.5) for t in tasks]\n",
1100
+ "trained_vals = [trained_by_task.get(t, np.nan) for t in tasks]\n",
1101
  "\n",
1102
  "baseline_avg = sum(heuristic_vals) / len(heuristic_vals)\n",
1103
+ "valid_trained_vals = [v for v in trained_vals if not np.isnan(v)]\n",
1104
+ "trained_avg = (sum(valid_trained_vals) / len(valid_trained_vals)) if valid_trained_vals else None\n",
1105
  "random_avg = sum(random_vals) / len(random_vals)\n",
1106
+ "overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if (trained_avg is not None and baseline_avg > 0) else None\n",
1107
  "\n",
1108
  "def smooth(values, window=5):\n",
1109
  " if not values or len(values) < 2:\n",
 
1203
  "x = np.arange(len(tasks))\n",
1204
  "w = 0.35\n",
1205
  "ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
1206
+ "if valid_trained_vals:\n",
1207
+ " trained_plot_vals = [0.0 if np.isnan(v) else v for v in trained_vals]\n",
1208
+ " ax2.bar(x + w/2, trained_plot_vals, w, label='Trained LLM (GRPO)', color=\"#3fb950\", alpha=0.9)\n",
1209
  "ax2.set_xticks(x)\n",
1210
  "ax2.set_xticklabels(task_labels)\n",
1211
  "ax2.set_ylim(0, 1.05)\n",
 
1215
  "ax2.grid(axis='y', alpha=0.3)\n",
1216
  "fig2.tight_layout()\n",
1217
  "comparison_path = 'results/gridmind_before_after_comparison.png'\n",
1218
+ "if valid_trained_vals:\n",
1219
+ " fig2.savefig(comparison_path, dpi=100)\n",
1220
+ "else:\n",
1221
+ " comparison_path = None\n",
1222
  "plt.close(fig2)\n",
1223
  "\n",
1224
  "print(f\"Saved training reward curve to {reward_curve_path}\")\n",
1225
  "print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
1226
  "if simple_loss_curve_path:\n",
1227
  " print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
1228
+ "if comparison_path:\n",
1229
+ " print(f\"Saved before/after graph to {comparison_path}\")\n",
1230
+ "else:\n",
1231
+ " print(\"Skipped before/after graph because RUN_EVALUATION is False.\")\n",
1232
  "\n",
1233
  "results = {\n",
1234
  " \"heuristic_baseline\": {\n",
 
1236
  " \"average\": baseline_avg\n",
1237
  " },\n",
1238
  " \"trained_llm\": {\n",
1239
+ " \"scores_by_task\": {str(k): v for k, v in trained_scores.items()} if trained_scores else {},\n",
1240
  " \"average\": trained_avg\n",
1241
  " },\n",
1242
  " \"improvement_percent\": overall_improvement,\n",
 
1266
  "print(f\" Model: {MODEL_NAME}\")\n",
1267
  "print(f\" Themes: {results['themes_covered']}\")\n",
1268
  "print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
1269
+ "if trained_avg is not None:\n",
1270
+ " print(f\" Trained LLM: {trained_avg:.3f}\")\n",
1271
+ "if overall_improvement is not None:\n",
1272
+ " print(f\" Improvement: {overall_improvement:+.1f}%\")\n",
1273
+ "else:\n",
1274
+ " print(\" Improvement: evaluation skipped\")"
1275
  ]
1276
  }
1277
  ],