Spaces:
Running
Running
Commit ·
505323f
1
Parent(s): 29b9cd0
feat: add GRPO training notebook for GridMind-RL environment
Browse files
scripts/gridmind_grpo_colab.ipynb
CHANGED
|
@@ -1007,30 +1007,39 @@
|
|
| 1007 |
" except:\n",
|
| 1008 |
" return 0.0\n",
|
| 1009 |
"\n",
|
| 1010 |
-
"
|
| 1011 |
-
"
|
| 1012 |
-
"for task_id in [1, 2, 3, 4]:\n",
|
| 1013 |
-
" scores = []\n",
|
| 1014 |
-
" for ep in range(2):\n",
|
| 1015 |
-
" score = run_llm_episode(task_id=task_id)\n",
|
| 1016 |
-
" scores.append(score)\n",
|
| 1017 |
-
" print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
|
| 1018 |
-
" trained_scores[task_id] = sum(scores) / len(scores)\n",
|
| 1019 |
-
"\n",
|
| 1020 |
-
"print(f\"\\nTrained Model Scores:\")\n",
|
| 1021 |
-
"for task_id, avg in trained_scores.items():\n",
|
| 1022 |
-
" baseline = baseline_scores[task_id]\n",
|
| 1023 |
-
" improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
|
| 1024 |
-
" print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
|
| 1025 |
"\n",
|
| 1026 |
-
"
|
| 1027 |
"baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
|
| 1028 |
-
"
|
| 1029 |
-
"\n",
|
| 1030 |
-
"
|
| 1031 |
-
"
|
| 1032 |
-
"print(f\"
|
| 1033 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
]
|
| 1035 |
},
|
| 1036 |
{
|
|
@@ -1084,16 +1093,17 @@
|
|
| 1084 |
"\n",
|
| 1085 |
"random_by_task = {1: 0.35, 2: 0.28, 3: 0.21, 4: 0.25}\n",
|
| 1086 |
"heuristic_by_task = baseline_scores\n",
|
| 1087 |
-
"trained_by_task = trained_scores\n",
|
| 1088 |
"\n",
|
| 1089 |
"random_vals = [random_by_task.get(t, 0.3) for t in tasks]\n",
|
| 1090 |
"heuristic_vals = [heuristic_by_task.get(t, 0.5) for t in tasks]\n",
|
| 1091 |
-
"trained_vals = [trained_by_task.get(t,
|
| 1092 |
"\n",
|
| 1093 |
"baseline_avg = sum(heuristic_vals) / len(heuristic_vals)\n",
|
| 1094 |
-
"
|
|
|
|
| 1095 |
"random_avg = sum(random_vals) / len(random_vals)\n",
|
| 1096 |
-
"overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else
|
| 1097 |
"\n",
|
| 1098 |
"def smooth(values, window=5):\n",
|
| 1099 |
" if not values or len(values) < 2:\n",
|
|
@@ -1193,7 +1203,9 @@
|
|
| 1193 |
"x = np.arange(len(tasks))\n",
|
| 1194 |
"w = 0.35\n",
|
| 1195 |
"ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
|
| 1196 |
-
"
|
|
|
|
|
|
|
| 1197 |
"ax2.set_xticks(x)\n",
|
| 1198 |
"ax2.set_xticklabels(task_labels)\n",
|
| 1199 |
"ax2.set_ylim(0, 1.05)\n",
|
|
@@ -1203,14 +1215,20 @@
|
|
| 1203 |
"ax2.grid(axis='y', alpha=0.3)\n",
|
| 1204 |
"fig2.tight_layout()\n",
|
| 1205 |
"comparison_path = 'results/gridmind_before_after_comparison.png'\n",
|
| 1206 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1207 |
"plt.close(fig2)\n",
|
| 1208 |
"\n",
|
| 1209 |
"print(f\"Saved training reward curve to {reward_curve_path}\")\n",
|
| 1210 |
"print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
|
| 1211 |
"if simple_loss_curve_path:\n",
|
| 1212 |
" print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
|
| 1213 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1214 |
"\n",
|
| 1215 |
"results = {\n",
|
| 1216 |
" \"heuristic_baseline\": {\n",
|
|
@@ -1218,7 +1236,7 @@
|
|
| 1218 |
" \"average\": baseline_avg\n",
|
| 1219 |
" },\n",
|
| 1220 |
" \"trained_llm\": {\n",
|
| 1221 |
-
" \"scores_by_task\": {str(k): v for k, v in trained_scores.items()},\n",
|
| 1222 |
" \"average\": trained_avg\n",
|
| 1223 |
" },\n",
|
| 1224 |
" \"improvement_percent\": overall_improvement,\n",
|
|
@@ -1248,8 +1266,12 @@
|
|
| 1248 |
"print(f\" Model: {MODEL_NAME}\")\n",
|
| 1249 |
"print(f\" Themes: {results['themes_covered']}\")\n",
|
| 1250 |
"print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
|
| 1251 |
-
"
|
| 1252 |
-
"print(f\"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1253 |
]
|
| 1254 |
}
|
| 1255 |
],
|
|
|
|
| 1007 |
" except:\n",
|
| 1008 |
" return 0.0\n",
|
| 1009 |
"\n",
|
| 1010 |
+
"RUN_EVALUATION = False\n",
|
| 1011 |
+
"EVAL_EPISODES_PER_TASK = 1\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
"\n",
|
| 1013 |
+
"trained_scores = {}\n",
|
| 1014 |
"baseline_avg = sum(baseline_scores.values()) / len(baseline_scores)\n",
|
| 1015 |
+
"trained_avg = None\n",
|
| 1016 |
+
"overall_improvement = None\n",
|
| 1017 |
+
"\n",
|
| 1018 |
+
"if RUN_EVALUATION:\n",
|
| 1019 |
+
" print(f\"Evaluating trained model ({EVAL_EPISODES_PER_TASK} episode(s) per task)...\")\n",
|
| 1020 |
+
" for task_id in [1, 2, 3, 4]:\n",
|
| 1021 |
+
" scores = []\n",
|
| 1022 |
+
" for ep in range(EVAL_EPISODES_PER_TASK):\n",
|
| 1023 |
+
" score = run_llm_episode(task_id=task_id)\n",
|
| 1024 |
+
" scores.append(score)\n",
|
| 1025 |
+
" print(f\" Task {task_id} Episode {ep+1}: {score:.3f}\")\n",
|
| 1026 |
+
" trained_scores[task_id] = sum(scores) / len(scores)\n",
|
| 1027 |
+
"\n",
|
| 1028 |
+
" print(f\"\\nTrained Model Scores:\")\n",
|
| 1029 |
+
" for task_id, avg in trained_scores.items():\n",
|
| 1030 |
+
" baseline = baseline_scores[task_id]\n",
|
| 1031 |
+
" improvement = ((avg - baseline) / baseline * 100) if baseline > 0 else 0\n",
|
| 1032 |
+
" print(f\" Task {task_id}: {avg:.3f} (baseline: {baseline:.3f}, {improvement:+.1f}%)\")\n",
|
| 1033 |
+
"\n",
|
| 1034 |
+
" trained_avg = sum(trained_scores.values()) / len(trained_scores)\n",
|
| 1035 |
+
" overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if baseline_avg > 0 else 0\n",
|
| 1036 |
+
"\n",
|
| 1037 |
+
" print(f\"\\nOverall Scores:\")\n",
|
| 1038 |
+
" print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
|
| 1039 |
+
" print(f\" Trained LLM: {trained_avg:.3f}\")\n",
|
| 1040 |
+
" print(f\" Improvement: {overall_improvement:+.1f}%\")\n",
|
| 1041 |
+
"else:\n",
|
| 1042 |
+
" print(\"Skipping trained-model evaluation. Set RUN_EVALUATION = True to generate trained_scores and improvement metrics.\")"
|
| 1043 |
]
|
| 1044 |
},
|
| 1045 |
{
|
|
|
|
| 1093 |
"\n",
|
| 1094 |
"random_by_task = {1: 0.35, 2: 0.28, 3: 0.21, 4: 0.25}\n",
|
| 1095 |
"heuristic_by_task = baseline_scores\n",
|
| 1096 |
+
"trained_by_task = trained_scores if trained_scores else {}\n",
|
| 1097 |
"\n",
|
| 1098 |
"random_vals = [random_by_task.get(t, 0.3) for t in tasks]\n",
|
| 1099 |
"heuristic_vals = [heuristic_by_task.get(t, 0.5) for t in tasks]\n",
|
| 1100 |
+
"trained_vals = [trained_by_task.get(t, np.nan) for t in tasks]\n",
|
| 1101 |
"\n",
|
| 1102 |
"baseline_avg = sum(heuristic_vals) / len(heuristic_vals)\n",
|
| 1103 |
+
"valid_trained_vals = [v for v in trained_vals if not np.isnan(v)]\n",
|
| 1104 |
+
"trained_avg = (sum(valid_trained_vals) / len(valid_trained_vals)) if valid_trained_vals else None\n",
|
| 1105 |
"random_avg = sum(random_vals) / len(random_vals)\n",
|
| 1106 |
+
"overall_improvement = ((trained_avg - baseline_avg) / baseline_avg * 100) if (trained_avg is not None and baseline_avg > 0) else None\n",
|
| 1107 |
"\n",
|
| 1108 |
"def smooth(values, window=5):\n",
|
| 1109 |
" if not values or len(values) < 2:\n",
|
|
|
|
| 1203 |
"x = np.arange(len(tasks))\n",
|
| 1204 |
"w = 0.35\n",
|
| 1205 |
"ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
|
| 1206 |
+
"if valid_trained_vals:\n",
|
| 1207 |
+
" trained_plot_vals = [0.0 if np.isnan(v) else v for v in trained_vals]\n",
|
| 1208 |
+
" ax2.bar(x + w/2, trained_plot_vals, w, label='Trained LLM (GRPO)', color=\"#3fb950\", alpha=0.9)\n",
|
| 1209 |
"ax2.set_xticks(x)\n",
|
| 1210 |
"ax2.set_xticklabels(task_labels)\n",
|
| 1211 |
"ax2.set_ylim(0, 1.05)\n",
|
|
|
|
| 1215 |
"ax2.grid(axis='y', alpha=0.3)\n",
|
| 1216 |
"fig2.tight_layout()\n",
|
| 1217 |
"comparison_path = 'results/gridmind_before_after_comparison.png'\n",
|
| 1218 |
+
"if valid_trained_vals:\n",
|
| 1219 |
+
" fig2.savefig(comparison_path, dpi=100)\n",
|
| 1220 |
+
"else:\n",
|
| 1221 |
+
" comparison_path = None\n",
|
| 1222 |
"plt.close(fig2)\n",
|
| 1223 |
"\n",
|
| 1224 |
"print(f\"Saved training reward curve to {reward_curve_path}\")\n",
|
| 1225 |
"print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
|
| 1226 |
"if simple_loss_curve_path:\n",
|
| 1227 |
" print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
|
| 1228 |
+
"if comparison_path:\n",
|
| 1229 |
+
" print(f\"Saved before/after graph to {comparison_path}\")\n",
|
| 1230 |
+
"else:\n",
|
| 1231 |
+
" print(\"Skipped before/after graph because RUN_EVALUATION is False.\")\n",
|
| 1232 |
"\n",
|
| 1233 |
"results = {\n",
|
| 1234 |
" \"heuristic_baseline\": {\n",
|
|
|
|
| 1236 |
" \"average\": baseline_avg\n",
|
| 1237 |
" },\n",
|
| 1238 |
" \"trained_llm\": {\n",
|
| 1239 |
+
" \"scores_by_task\": {str(k): v for k, v in trained_scores.items()} if trained_scores else {},\n",
|
| 1240 |
" \"average\": trained_avg\n",
|
| 1241 |
" },\n",
|
| 1242 |
" \"improvement_percent\": overall_improvement,\n",
|
|
|
|
| 1266 |
"print(f\" Model: {MODEL_NAME}\")\n",
|
| 1267 |
"print(f\" Themes: {results['themes_covered']}\")\n",
|
| 1268 |
"print(f\" Heuristic baseline: {baseline_avg:.3f}\")\n",
|
| 1269 |
+
"if trained_avg is not None:\n",
|
| 1270 |
+
" print(f\" Trained LLM: {trained_avg:.3f}\")\n",
|
| 1271 |
+
"if overall_improvement is not None:\n",
|
| 1272 |
+
" print(f\" Improvement: {overall_improvement:+.1f}%\")\n",
|
| 1273 |
+
"else:\n",
|
| 1274 |
+
" print(\" Improvement: evaluation skipped\")"
|
| 1275 |
]
|
| 1276 |
}
|
| 1277 |
],
|