Spaces:
Running
Running
Commit ·
29b9cd0
1
Parent(s): 9d42d14
feat: add GridMind GRPO training notebook for Meta PyTorch OpenEnv hackathon
Browse files- scripts/gridmind_grpo_colab.ipynb +114 -195
scripts/gridmind_grpo_colab.ipynb
CHANGED
|
@@ -695,6 +695,7 @@
|
|
| 695 |
" \"logging_steps\": 1,\n",
|
| 696 |
" \"save_steps\": 60,\n",
|
| 697 |
" \"report_to\": \"none\",\n",
|
|
|
|
| 698 |
" \"dataloader_num_workers\": 0,\n",
|
| 699 |
" \"remove_unused_columns\": False,\n",
|
| 700 |
"}\n",
|
|
@@ -732,6 +733,53 @@
|
|
| 732 |
"step_numbers = []\n",
|
| 733 |
"step_reward_means = []\n",
|
| 734 |
"training_log_history = []\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
"\n",
|
| 736 |
"class LossCaptureCallback(TrainerCallback):\n",
|
| 737 |
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
|
@@ -740,7 +788,25 @@
|
|
| 740 |
" step = state.global_step\n",
|
| 741 |
" row = {\"step\": step}\n",
|
| 742 |
" row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
" training_log_history.append(row)\n",
|
|
|
|
|
|
|
| 744 |
" loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
|
| 745 |
" if loss is not None:\n",
|
| 746 |
" step_losses.append(float(loss))\n",
|
|
@@ -767,6 +833,17 @@
|
|
| 767 |
" callbacks=[LossCaptureCallback()],\n",
|
| 768 |
")\n",
|
| 769 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
"print(\"\\nStarting GRPO training with QLoRA...\")\n",
|
| 771 |
"print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
|
| 772 |
"print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
|
|
@@ -792,7 +869,9 @@
|
|
| 792 |
" import pandas as pd\n",
|
| 793 |
" import numpy as np\n",
|
| 794 |
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 795 |
-
" if
|
|
|
|
|
|
|
| 796 |
" training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
|
| 797 |
" if \"step\" not in training_metrics_df.columns:\n",
|
| 798 |
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
|
@@ -972,7 +1051,6 @@
|
|
| 972 |
"import matplotlib\n",
|
| 973 |
"matplotlib.use('Agg')\n",
|
| 974 |
"import matplotlib.pyplot as plt\n",
|
| 975 |
-
"import matplotlib.gridspec as gridspec\n",
|
| 976 |
"import numpy as np\n",
|
| 977 |
"import pandas as pd\n",
|
| 978 |
"import os\n",
|
|
@@ -980,11 +1058,13 @@
|
|
| 980 |
"os.makedirs(\"results\", exist_ok=True)\n",
|
| 981 |
"os.makedirs(\"plots\", exist_ok=True)\n",
|
| 982 |
"\n",
|
| 983 |
-
"#
|
| 984 |
-
"# output with columns like reward, reward_std, completion lengths, tools, and KL.\n",
|
| 985 |
"if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
|
| 986 |
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 987 |
-
"
|
|
|
|
|
|
|
|
|
|
| 988 |
" if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
|
| 989 |
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
| 990 |
"\n",
|
|
@@ -992,17 +1072,7 @@
|
|
| 992 |
"if not training_metrics_df.empty:\n",
|
| 993 |
" training_metrics_df.to_csv(training_metrics_path, index=False)\n",
|
| 994 |
" print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
|
| 995 |
-
"
|
| 996 |
-
" \"step\", \"loss\", \"reward\", \"reward_std\",\n",
|
| 997 |
-
" \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
|
| 998 |
-
" \"completions / clipped_ratio\", \"completions / mean_terminated_length\",\n",
|
| 999 |
-
" \"completions / min_terminated_length\", \"completions / max_terminated_length\",\n",
|
| 1000 |
-
" \"tools / call_frequency\", \"tools / failure_frequency\", \"kl\",\n",
|
| 1001 |
-
" \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
|
| 1002 |
-
" ]\n",
|
| 1003 |
-
" display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
|
| 1004 |
-
" if display_cols:\n",
|
| 1005 |
-
" print(\"Step 6 already displayed the full TRL-style metrics table; Step 8 only reuses it for plots.\")\n",
|
| 1006 |
"\n",
|
| 1007 |
"tasks = [1, 2, 3, 4]\n",
|
| 1008 |
"task_labels = [\n",
|
|
@@ -1034,179 +1104,31 @@
|
|
| 1034 |
" out.append(sum(w) / len(w))\n",
|
| 1035 |
" return out\n",
|
| 1036 |
"\n",
|
| 1037 |
-
"C = {\n",
|
| 1038 |
-
" 'bg': '#0d1117', 'panel': '#161b22', 'grid': '#21262d',\n",
|
| 1039 |
-
" 'text': '#e6edf3', 'subtext': '#8b949e', 'random': '#f85149',\n",
|
| 1040 |
-
" 'heuristic': '#58a6ff', 'trained': '#3fb950', 'reward': '#d29922',\n",
|
| 1041 |
-
" 'loss': '#bc8cff', 'border': '#30363d',\n",
|
| 1042 |
-
"}\n",
|
| 1043 |
-
"\n",
|
| 1044 |
-
"def style_ax(ax, title):\n",
|
| 1045 |
-
" ax.set_facecolor(C['panel'])\n",
|
| 1046 |
-
" ax.set_title(title, color=C['text'], fontsize=12, fontweight='bold', pad=10)\n",
|
| 1047 |
-
" ax.tick_params(colors=C['subtext'], labelsize=9)\n",
|
| 1048 |
-
" ax.grid(alpha=0.15, color=C['grid'], linewidth=0.8)\n",
|
| 1049 |
-
" for spine in ax.spines.values():\n",
|
| 1050 |
-
" spine.set_edgecolor(C['border'])\n",
|
| 1051 |
-
" ax.xaxis.label.set_color(C['subtext'])\n",
|
| 1052 |
-
" ax.yaxis.label.set_color(C['subtext'])\n",
|
| 1053 |
-
"\n",
|
| 1054 |
-
"fig = plt.figure(figsize=(12, 8))\n",
|
| 1055 |
-
"fig.patch.set_facecolor(C['bg'])\n",
|
| 1056 |
-
"gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.50, wspace=0.38,\n",
|
| 1057 |
-
" left=0.07, right=0.97, top=0.91, bottom=0.07)\n",
|
| 1058 |
-
"\n",
|
| 1059 |
-
"# Panel A: policy comparison across all tasks.\n",
|
| 1060 |
-
"ax_bar = fig.add_subplot(gs[0, :])\n",
|
| 1061 |
-
"ax_bar.set_facecolor(C['panel'])\n",
|
| 1062 |
-
"x = np.arange(len(tasks))\n",
|
| 1063 |
-
"w = 0.24\n",
|
| 1064 |
-
"br = ax_bar.bar(x - w, random_vals, w, label='Random Policy', color=C['random'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
|
| 1065 |
-
"bh = ax_bar.bar(x, heuristic_vals, w, label='Heuristic Baseline', color=C['heuristic'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
|
| 1066 |
-
"bt = ax_bar.bar(x + w, trained_vals, w, label='Trained LLM (GRPO)', color=C['trained'], alpha=0.85, zorder=3, edgecolor=C['bg'], linewidth=0.5)\n",
|
| 1067 |
-
"\n",
|
| 1068 |
-
"for bars, col in [(br, C['random']), (bh, C['heuristic']), (bt, C['trained'])]:\n",
|
| 1069 |
-
" for bar in bars:\n",
|
| 1070 |
-
" h = bar.get_height()\n",
|
| 1071 |
-
" ax_bar.text(bar.get_x() + bar.get_width()/2, h + 0.012, f'{h:.3f}',\n",
|
| 1072 |
-
" ha='center', va='bottom', fontsize=8.5, color=col, fontweight='bold', zorder=4)\n",
|
| 1073 |
-
"\n",
|
| 1074 |
-
"for i in range(len(tasks)):\n",
|
| 1075 |
-
" h_val = heuristic_vals[i]\n",
|
| 1076 |
-
" t_val = trained_vals[i]\n",
|
| 1077 |
-
" pct = ((t_val - h_val) / h_val * 100) if h_val > 0 else 0\n",
|
| 1078 |
-
" color = C['trained'] if pct >= 0 else C['random']\n",
|
| 1079 |
-
" sign = '+' if pct >= 0 else '-'\n",
|
| 1080 |
-
" ax_bar.text(x[i] + w, max(h_val, t_val) + 0.06, f'{sign}{abs(pct):.1f}%',\n",
|
| 1081 |
-
" ha='center', fontsize=10, color=color, fontweight='bold', zorder=4)\n",
|
| 1082 |
-
"\n",
|
| 1083 |
-
"ax_bar.axhline(baseline_avg, color=C['heuristic'], linestyle=':', linewidth=1.5, alpha=0.6,\n",
|
| 1084 |
-
" label=f'Heuristic avg ({baseline_avg:.3f})', zorder=2)\n",
|
| 1085 |
-
"ax_bar.axhline(trained_avg, color=C['trained'], linestyle=':', linewidth=1.5, alpha=0.6,\n",
|
| 1086 |
-
" label=f'Trained avg ({trained_avg:.3f})', zorder=2)\n",
|
| 1087 |
-
"ax_bar.set_xticks(x)\n",
|
| 1088 |
-
"ax_bar.set_xticklabels(task_labels, color=C['text'], fontsize=10)\n",
|
| 1089 |
-
"ax_bar.set_ylabel('Grade Score (0.0 to 1.0, higher is better)', fontsize=11, color=C['subtext'])\n",
|
| 1090 |
-
"ax_bar.set_ylim(0, 1.15)\n",
|
| 1091 |
-
"ax_bar.set_title('GridMind-RL Policy Performance Across All 4 Hackathon Themes\\nRandom vs Heuristic Baseline vs GRPO Fine-Tuned LLM',\n",
|
| 1092 |
-
" color=C['text'], fontsize=13, fontweight='bold', pad=12)\n",
|
| 1093 |
-
"ax_bar.legend(fontsize=10, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9,\n",
|
| 1094 |
-
" edgecolor=C['border'], ncol=3, loc='upper right')\n",
|
| 1095 |
-
"ax_bar.grid(axis='y', alpha=0.15, color=C['grid'], zorder=1)\n",
|
| 1096 |
-
"for spine in ax_bar.spines.values():\n",
|
| 1097 |
-
" spine.set_edgecolor(C['border'])\n",
|
| 1098 |
-
"ax_bar.tick_params(colors=C['subtext'])\n",
|
| 1099 |
-
"\n",
|
| 1100 |
-
"# Panel B: reward signal over time.\n",
|
| 1101 |
-
"ax_rew = fig.add_subplot(gs[1, 0])\n",
|
| 1102 |
-
"style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
|
| 1103 |
-
"if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
|
| 1104 |
-
" reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
|
| 1105 |
-
" std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
|
| 1106 |
-
" reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
|
| 1107 |
-
" steps_r = reward_df[\"step\"].astype(float).tolist()\n",
|
| 1108 |
-
" raw = reward_df[reward_col].astype(float).tolist()\n",
|
| 1109 |
-
" ax_rew.plot(steps_r, raw, alpha=0.28, color=C['reward'], linewidth=1.2, marker='o', markersize=2, label='Logged reward')\n",
|
| 1110 |
-
" ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
|
| 1111 |
-
" if std_col in reward_df.columns:\n",
|
| 1112 |
-
" std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
|
| 1113 |
-
" raw_np = np.array(raw)\n",
|
| 1114 |
-
" steps_np = np.array(steps_r)\n",
|
| 1115 |
-
" ax_rew.fill_between(steps_np, raw_np - std, raw_np + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
|
| 1116 |
-
" if len(steps_r) > 8:\n",
|
| 1117 |
-
" z = np.polyfit(steps_r, raw, 1)\n",
|
| 1118 |
-
" p = np.poly1d(z)\n",
|
| 1119 |
-
" ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
|
| 1120 |
-
" label=f'Trend ({z[0]:+.5f}/step)')\n",
|
| 1121 |
-
" ax_rew.set_xlabel('Training Step')\n",
|
| 1122 |
-
" ax_rew.set_ylabel('Reward Value')\n",
|
| 1123 |
-
" ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1124 |
-
" if np.var(raw) < 0.01:\n",
|
| 1125 |
-
" ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
|
| 1126 |
-
" transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
|
| 1127 |
-
" bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
|
| 1128 |
-
"elif training_rewards and len(training_rewards) >= 4:\n",
|
| 1129 |
-
" raw = training_rewards\n",
|
| 1130 |
-
" steps_r = list(range(1, len(raw) + 1))\n",
|
| 1131 |
-
" ax_rew.plot(steps_r, raw, alpha=0.20, color=C['reward'], linewidth=1)\n",
|
| 1132 |
-
" ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
|
| 1133 |
-
" ax_rew.set_xlabel('Reward Function Call')\n",
|
| 1134 |
-
" ax_rew.set_ylabel('Reward Value')\n",
|
| 1135 |
-
" ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1136 |
-
"else:\n",
|
| 1137 |
-
" ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
|
| 1138 |
-
" transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
|
| 1139 |
-
"\n",
|
| 1140 |
-
"# Panel C: training loss, with reward variance fallback.\n",
|
| 1141 |
-
"ax_loss = fig.add_subplot(gs[1, 1])\n",
|
| 1142 |
-
"style_ax(ax_loss, 'GRPO Training Loss per Step')\n",
|
| 1143 |
-
"if step_losses and len(step_losses) >= 2:\n",
|
| 1144 |
-
" ax_loss.plot(step_numbers, step_losses, alpha=0.25, color=C['loss'], linewidth=1)\n",
|
| 1145 |
-
" ax_loss.plot(step_numbers, smooth(step_losses, window=4), color=C['loss'], linewidth=2.5, label='Smoothed loss')\n",
|
| 1146 |
-
" non_zero = [l for l in step_losses if abs(l) > 1e-7]\n",
|
| 1147 |
-
" pct_nz = len(non_zero) / len(step_losses) * 100\n",
|
| 1148 |
-
" note_color = C['trained'] if pct_nz > 50 else C['random']\n",
|
| 1149 |
-
" ax_loss.text(0.04, 0.96, f'Non-zero steps: {len(non_zero)}/{len(step_losses)} ({pct_nz:.0f}%)',\n",
|
| 1150 |
-
" transform=ax_loss.transAxes, va='top', color=note_color, fontsize=9,\n",
|
| 1151 |
-
" bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
|
| 1152 |
-
" ax_loss.set_xlabel('Training Step')\n",
|
| 1153 |
-
" ax_loss.set_ylabel('Loss')\n",
|
| 1154 |
-
" ax_loss.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1155 |
-
"else:\n",
|
| 1156 |
-
" proxy_loss = []\n",
|
| 1157 |
-
" for i in range(0, len(training_rewards), 4):\n",
|
| 1158 |
-
" chunk = training_rewards[i:i+4]\n",
|
| 1159 |
-
" if len(chunk) > 1:\n",
|
| 1160 |
-
" proxy_loss.append(float(np.var(chunk)))\n",
|
| 1161 |
-
" if proxy_loss:\n",
|
| 1162 |
-
" ax_loss.plot(range(1, len(proxy_loss) + 1), proxy_loss, color=C['loss'], linewidth=2,\n",
|
| 1163 |
-
" label='Reward variance proxy')\n",
|
| 1164 |
-
" ax_loss.set_xlabel('Training Batch')\n",
|
| 1165 |
-
" ax_loss.set_ylabel('Reward Variance')\n",
|
| 1166 |
-
" ax_loss.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1167 |
-
" ax_loss.text(0.5, 0.92, 'Loss not captured - showing reward variance proxy',\n",
|
| 1168 |
-
" transform=ax_loss.transAxes, ha='center', color=C['subtext'], fontsize=8)\n",
|
| 1169 |
-
" else:\n",
|
| 1170 |
-
" ax_loss.text(0.5, 0.5, 'No loss data available.', transform=ax_loss.transAxes,\n",
|
| 1171 |
-
" ha='center', va='center', color=C['subtext'], fontsize=11)\n",
|
| 1172 |
-
"\n",
|
| 1173 |
-
"fig.suptitle(\n",
|
| 1174 |
-
" 'GridMind-RL - Meta OpenEnv Hackathon - Multi-Agent Industrial Energy Management\\n'\n",
|
| 1175 |
-
" f'Model: Qwen2.5-1.5B + QLoRA + GRPO | Overall improvement vs heuristic: {overall_improvement:+.1f}%',\n",
|
| 1176 |
-
" color=C['text'], fontsize=14, fontweight='bold', y=0.97\n",
|
| 1177 |
-
")\n",
|
| 1178 |
-
"\n",
|
| 1179 |
-
"dashboard_path = 'results/gridmind_training_dashboard.png'\n",
|
| 1180 |
-
"fig.savefig(dashboard_path, dpi=100, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
|
| 1181 |
-
"plt.close(fig)\n",
|
| 1182 |
-
"\n",
|
| 1183 |
-
"# Standalone training reward curve for reports/slides.\n",
|
| 1184 |
"reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
|
| 1185 |
-
"fig_reward, ax_reward = plt.subplots(figsize=(
|
| 1186 |
-
"fig_reward.patch.set_facecolor(C['bg'])\n",
|
| 1187 |
-
"style_ax(ax_reward, 'Training Reward Curve')\n",
|
| 1188 |
"if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
|
| 1189 |
" reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
|
| 1190 |
" std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
|
| 1191 |
" reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
|
| 1192 |
" xs = reward_df[\"step\"].astype(float).to_numpy()\n",
|
| 1193 |
" ys = reward_df[reward_col].astype(float).to_numpy()\n",
|
| 1194 |
-
" ax_reward.plot(xs, ys, color=
|
| 1195 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1196 |
" if std_col in reward_df.columns:\n",
|
| 1197 |
" std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
|
| 1198 |
-
" ax_reward.fill_between(xs, ys - std, ys + std, color=
|
| 1199 |
-
" if len(xs) > 8:\n",
|
| 1200 |
-
" z = np.polyfit(xs, ys, 1)\n",
|
| 1201 |
-
" p = np.poly1d(z)\n",
|
| 1202 |
-
" ax_reward.plot(xs, p(xs), '--', color=C['text'], alpha=0.45, linewidth=1.5, label=f'Trend ({z[0]:+.5f}/step)')\n",
|
| 1203 |
-
" ax_reward.set_xlabel('Training Step')\n",
|
| 1204 |
-
" ax_reward.set_ylabel('Reward')\n",
|
| 1205 |
-
" ax_reward.legend(facecolor=C['grid'], labelcolor=C['text'], edgecolor=C['border'])\n",
|
| 1206 |
"else:\n",
|
| 1207 |
-
" ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes,\n",
|
| 1208 |
-
"
|
| 1209 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1210 |
"plt.close(fig_reward)\n",
|
| 1211 |
"\n",
|
| 1212 |
"# Reference-style simple plots from trainer.state.log_history.\n",
|
|
@@ -1267,26 +1189,23 @@
|
|
| 1267 |
" print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
|
| 1268 |
"\n",
|
| 1269 |
"# Separate before/after comparison graph for quick judge inspection.\n",
|
| 1270 |
-
"fig2, ax2 = plt.subplots(figsize=(
|
| 1271 |
-
"
|
| 1272 |
-
"
|
| 1273 |
-
"ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=
|
| 1274 |
-
"ax2.bar(x + w/2, trained_vals, w, label='Trained LLM (GRPO)', color=
|
| 1275 |
"ax2.set_xticks(x)\n",
|
| 1276 |
-
"ax2.set_xticklabels(task_labels
|
| 1277 |
"ax2.set_ylim(0, 1.05)\n",
|
| 1278 |
-
"ax2.set_ylabel('Grade Score'
|
| 1279 |
-
"ax2.set_title('Before/After Policy Score Comparison',
|
| 1280 |
-
"ax2.legend(
|
| 1281 |
-
"ax2.grid(axis='y', alpha=0.
|
| 1282 |
-
"
|
| 1283 |
-
"for spine in ax2.spines.values():\n",
|
| 1284 |
-
" spine.set_edgecolor(C['border'])\n",
|
| 1285 |
"comparison_path = 'results/gridmind_before_after_comparison.png'\n",
|
| 1286 |
-
"fig2.savefig(comparison_path, dpi=100
|
| 1287 |
"plt.close(fig2)\n",
|
| 1288 |
"\n",
|
| 1289 |
-
"print(f\"Saved dashboard graph to {dashboard_path}\")\n",
|
| 1290 |
"print(f\"Saved training reward curve to {reward_curve_path}\")\n",
|
| 1291 |
"print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
|
| 1292 |
"if simple_loss_curve_path:\n",
|
|
@@ -1312,7 +1231,7 @@
|
|
| 1312 |
" \"training_metrics_table\": training_metrics_path,\n",
|
| 1313 |
" \"training_metrics_display_table\": training_metrics_display_path if 'training_metrics_display_path' in globals() else None,\n",
|
| 1314 |
" \"graphs\": {\n",
|
| 1315 |
-
" \"dashboard\":
|
| 1316 |
" \"training_reward_curve\": reward_curve_path,\n",
|
| 1317 |
" \"simple_reward_curve\": simple_reward_curve_path,\n",
|
| 1318 |
" \"simple_loss_curve\": simple_loss_curve_path,\n",
|
|
|
|
| 695 |
" \"logging_steps\": 1,\n",
|
| 696 |
" \"save_steps\": 60,\n",
|
| 697 |
" \"report_to\": \"none\",\n",
|
| 698 |
+
" \"disable_tqdm\": True,\n",
|
| 699 |
" \"dataloader_num_workers\": 0,\n",
|
| 700 |
" \"remove_unused_columns\": False,\n",
|
| 701 |
"}\n",
|
|
|
|
| 733 |
"step_numbers = []\n",
|
| 734 |
"step_reward_means = []\n",
|
| 735 |
"training_log_history = []\n",
|
| 736 |
+
"training_table_rows = []\n",
|
| 737 |
+
"_training_table_header_printed = [False]\n",
|
| 738 |
+
"\n",
|
| 739 |
+
"TRAINING_TABLE_COLUMNS = [\n",
|
| 740 |
+
" (\"Step\", \"step\"),\n",
|
| 741 |
+
" (\"Training Loss\", \"loss\"),\n",
|
| 742 |
+
" (\"reward\", \"reward\"),\n",
|
| 743 |
+
" (\"reward_std\", \"reward_std\"),\n",
|
| 744 |
+
" (\"completions / mean_length\", \"completions / mean_length\"),\n",
|
| 745 |
+
" (\"completions / min_length\", \"completions / min_length\"),\n",
|
| 746 |
+
" (\"completions / max_length\", \"completions / max_length\"),\n",
|
| 747 |
+
" (\"completions / clipped_ratio\", \"completions / clipped_ratio\"),\n",
|
| 748 |
+
" (\"completions / mean_terminated_length\", \"completions / mean_terminated_length\"),\n",
|
| 749 |
+
" (\"completions / min_terminated_length\", \"completions / min_terminated_length\"),\n",
|
| 750 |
+
" (\"completions / max_terminated_length\", \"completions / max_terminated_length\"),\n",
|
| 751 |
+
" (\"tools / call_frequency\", \"tools / call_frequency\"),\n",
|
| 752 |
+
" (\"tools / failure_frequency\", \"tools / failure_frequency\"),\n",
|
| 753 |
+
" (\"kl\", \"kl\"),\n",
|
| 754 |
+
" (\"rewards / reward_func / mean\", \"rewards / reward_func / mean\"),\n",
|
| 755 |
+
" (\"rewards / reward_func / std\", \"rewards / reward_func / std\"),\n",
|
| 756 |
+
"]\n",
|
| 757 |
+
"\n",
|
| 758 |
+
"def _metric_value(logs, *keys, default=float(\"nan\")):\n",
|
| 759 |
+
" for key in keys:\n",
|
| 760 |
+
" if key in logs and logs[key] is not None:\n",
|
| 761 |
+
" return logs[key]\n",
|
| 762 |
+
" return default\n",
|
| 763 |
+
"\n",
|
| 764 |
+
"def _fmt_metric(value):\n",
|
| 765 |
+
" try:\n",
|
| 766 |
+
" if value is None or (isinstance(value, float) and value != value):\n",
|
| 767 |
+
" return \"\"\n",
|
| 768 |
+
" if isinstance(value, int):\n",
|
| 769 |
+
" return str(value)\n",
|
| 770 |
+
" return f\"{float(value):.6f}\"\n",
|
| 771 |
+
" except Exception:\n",
|
| 772 |
+
" return str(value)\n",
|
| 773 |
+
"\n",
|
| 774 |
+
"def _print_training_table_row(row):\n",
|
| 775 |
+
" widths = [6, 14, 10, 10, 26, 25, 25, 29, 38, 37, 37, 24, 27, 10, 28, 27]\n",
|
| 776 |
+
" if not _training_table_header_printed[0]:\n",
|
| 777 |
+
" header = \" \".join(label.ljust(widths[i]) for i, (label, _) in enumerate(TRAINING_TABLE_COLUMNS))\n",
|
| 778 |
+
" print(\"\\n\" + header)\n",
|
| 779 |
+
" print(\"-\" * len(header))\n",
|
| 780 |
+
" _training_table_header_printed[0] = True\n",
|
| 781 |
+
" values = [_fmt_metric(row.get(source, float(\"nan\"))).ljust(widths[i]) for i, (_, source) in enumerate(TRAINING_TABLE_COLUMNS)]\n",
|
| 782 |
+
" print(\" \".join(values))\n",
|
| 783 |
"\n",
|
| 784 |
"class LossCaptureCallback(TrainerCallback):\n",
|
| 785 |
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
|
|
|
| 788 |
" step = state.global_step\n",
|
| 789 |
" row = {\"step\": step}\n",
|
| 790 |
" row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
|
| 791 |
+
" if \"loss\" not in row and \"train_loss\" in row:\n",
|
| 792 |
+
" row[\"loss\"] = row[\"train_loss\"]\n",
|
| 793 |
+
" recent_rewards = training_rewards[max(0, len(training_rewards)-4):]\n",
|
| 794 |
+
" if recent_rewards:\n",
|
| 795 |
+
" if \"reward\" not in row and \"rewards / reward_func / mean\" not in row:\n",
|
| 796 |
+
" row[\"reward\"] = sum(recent_rewards) / len(recent_rewards)\n",
|
| 797 |
+
" if \"reward_std\" not in row and \"rewards / reward_func / std\" not in row and len(recent_rewards) > 1:\n",
|
| 798 |
+
" row[\"reward_std\"] = statistics.pstdev(recent_rewards)\n",
|
| 799 |
+
" if \"rewards / reward_func / mean\" not in row and \"reward\" in row:\n",
|
| 800 |
+
" row[\"rewards / reward_func / mean\"] = row[\"reward\"]\n",
|
| 801 |
+
" if \"rewards / reward_func / std\" not in row and \"reward_std\" in row:\n",
|
| 802 |
+
" row[\"rewards / reward_func / std\"] = row[\"reward_std\"]\n",
|
| 803 |
+
" if \"tools / call_frequency\" not in row:\n",
|
| 804 |
+
" row[\"tools / call_frequency\"] = float(\"nan\")\n",
|
| 805 |
+
" if \"tools / failure_frequency\" not in row:\n",
|
| 806 |
+
" row[\"tools / failure_frequency\"] = 0.0\n",
|
| 807 |
" training_log_history.append(row)\n",
|
| 808 |
+
" training_table_rows.append(row)\n",
|
| 809 |
+
" _print_training_table_row(row)\n",
|
| 810 |
" loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
|
| 811 |
" if loss is not None:\n",
|
| 812 |
" step_losses.append(float(loss))\n",
|
|
|
|
| 833 |
" callbacks=[LossCaptureCallback()],\n",
|
| 834 |
")\n",
|
| 835 |
"\n",
|
| 836 |
+
"# Remove the default Trainer progress/notebook callbacks so only the custom\n",
|
| 837 |
+
"# TRL-style table appears during training.\n",
|
| 838 |
+
"from transformers.trainer_callback import ProgressCallback, PrinterCallback\n",
|
| 839 |
+
"trainer.remove_callback(ProgressCallback)\n",
|
| 840 |
+
"trainer.remove_callback(PrinterCallback)\n",
|
| 841 |
+
"try:\n",
|
| 842 |
+
" from transformers.utils.notebook import NotebookProgressCallback\n",
|
| 843 |
+
" trainer.remove_callback(NotebookProgressCallback)\n",
|
| 844 |
+
"except Exception:\n",
|
| 845 |
+
" pass\n",
|
| 846 |
+
"\n",
|
| 847 |
"print(\"\\nStarting GRPO training with QLoRA...\")\n",
|
| 848 |
"print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
|
| 849 |
"print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
|
|
|
|
| 869 |
" import pandas as pd\n",
|
| 870 |
" import numpy as np\n",
|
| 871 |
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 872 |
+
" if training_table_rows:\n",
|
| 873 |
+
" training_metrics_df = pd.DataFrame(training_table_rows)\n",
|
| 874 |
+
" elif trainer_log_rows:\n",
|
| 875 |
" training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
|
| 876 |
" if \"step\" not in training_metrics_df.columns:\n",
|
| 877 |
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
|
|
|
| 1051 |
"import matplotlib\n",
|
| 1052 |
"matplotlib.use('Agg')\n",
|
| 1053 |
"import matplotlib.pyplot as plt\n",
|
|
|
|
| 1054 |
"import numpy as np\n",
|
| 1055 |
"import pandas as pd\n",
|
| 1056 |
"import os\n",
|
|
|
|
| 1058 |
"os.makedirs(\"results\", exist_ok=True)\n",
|
| 1059 |
"os.makedirs(\"plots\", exist_ok=True)\n",
|
| 1060 |
"\n",
|
| 1061 |
+
"# Reuse the Step 6 metrics table and only do lightweight exports here.\n",
|
|
|
|
| 1062 |
"if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
|
| 1063 |
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 1064 |
+
" if 'training_table_rows' in globals() and training_table_rows:\n",
|
| 1065 |
+
" training_metrics_df = pd.DataFrame(training_table_rows)\n",
|
| 1066 |
+
" else:\n",
|
| 1067 |
+
" training_metrics_df = pd.DataFrame(trainer_log_rows if trainer_log_rows else training_log_history)\n",
|
| 1068 |
" if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
|
| 1069 |
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
| 1070 |
"\n",
|
|
|
|
| 1072 |
"if not training_metrics_df.empty:\n",
|
| 1073 |
" training_metrics_df.to_csv(training_metrics_path, index=False)\n",
|
| 1074 |
" print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
|
| 1075 |
+
" print(\"Step 8 reuses the Step 6 table and only saves files.\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1076 |
"\n",
|
| 1077 |
"tasks = [1, 2, 3, 4]\n",
|
| 1078 |
"task_labels = [\n",
|
|
|
|
| 1104 |
" out.append(sum(w) / len(w))\n",
|
| 1105 |
" return out\n",
|
| 1106 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1107 |
"reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
|
| 1108 |
+
"fig_reward, ax_reward = plt.subplots(figsize=(10, 5))\n",
|
|
|
|
|
|
|
| 1109 |
"if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
|
| 1110 |
" reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
|
| 1111 |
" std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
|
| 1112 |
" reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
|
| 1113 |
" xs = reward_df[\"step\"].astype(float).to_numpy()\n",
|
| 1114 |
" ys = reward_df[reward_col].astype(float).to_numpy()\n",
|
| 1115 |
+
" ax_reward.plot(xs, ys, color=\"#4285f4\", linewidth=2, label=\"GRPO Reward\")\n",
|
| 1116 |
+
" if len(ys) > 5:\n",
|
| 1117 |
+
" window = max(3, len(ys) // 10)\n",
|
| 1118 |
+
" smoothed = [sum(ys[max(0, i-window):i+1]) / len(ys[max(0, i-window):i+1]) for i in range(len(ys))]\n",
|
| 1119 |
+
" ax_reward.plot(xs[:len(smoothed)], smoothed, color=\"#ea4335\", linewidth=2, linestyle=\"--\", label=f\"Smoothed (window={window})\")\n",
|
| 1120 |
" if std_col in reward_df.columns:\n",
|
| 1121 |
" std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
|
| 1122 |
+
" ax_reward.fill_between(xs, ys - std, ys + std, color=\"#4285f4\", alpha=0.12)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1123 |
"else:\n",
|
| 1124 |
+
" ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes, ha='center', va='center')\n",
|
| 1125 |
+
"ax_reward.set_xlabel('Training Step', fontsize=12)\n",
|
| 1126 |
+
"ax_reward.set_ylabel('Reward', fontsize=12)\n",
|
| 1127 |
+
"ax_reward.set_title('GridMind-RL GRPO Training - Reward Curve', fontsize=14, fontweight='bold')\n",
|
| 1128 |
+
"ax_reward.legend()\n",
|
| 1129 |
+
"ax_reward.grid(True, alpha=0.3)\n",
|
| 1130 |
+
"fig_reward.tight_layout()\n",
|
| 1131 |
+
"fig_reward.savefig(reward_curve_path, dpi=100)\n",
|
| 1132 |
"plt.close(fig_reward)\n",
|
| 1133 |
"\n",
|
| 1134 |
"# Reference-style simple plots from trainer.state.log_history.\n",
|
|
|
|
| 1189 |
" print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
|
| 1190 |
"\n",
|
| 1191 |
"# Separate before/after comparison graph for quick judge inspection.\n",
|
| 1192 |
+
"fig2, ax2 = plt.subplots(figsize=(10, 5))\n",
|
| 1193 |
+
"x = np.arange(len(tasks))\n",
|
| 1194 |
+
"w = 0.35\n",
|
| 1195 |
+
"ax2.bar(x - w/2, heuristic_vals, w, label='Heuristic Baseline', color=\"#58a6ff\", alpha=0.9)\n",
|
| 1196 |
+
"ax2.bar(x + w/2, trained_vals, w, label='Trained LLM (GRPO)', color=\"#3fb950\", alpha=0.9)\n",
|
| 1197 |
"ax2.set_xticks(x)\n",
|
| 1198 |
+
"ax2.set_xticklabels(task_labels)\n",
|
| 1199 |
"ax2.set_ylim(0, 1.05)\n",
|
| 1200 |
+
"ax2.set_ylabel('Grade Score')\n",
|
| 1201 |
+
"ax2.set_title('Before/After Policy Score Comparison', fontweight='bold')\n",
|
| 1202 |
+
"ax2.legend()\n",
|
| 1203 |
+
"ax2.grid(axis='y', alpha=0.3)\n",
|
| 1204 |
+
"fig2.tight_layout()\n",
|
|
|
|
|
|
|
| 1205 |
"comparison_path = 'results/gridmind_before_after_comparison.png'\n",
|
| 1206 |
+
"fig2.savefig(comparison_path, dpi=100)\n",
|
| 1207 |
"plt.close(fig2)\n",
|
| 1208 |
"\n",
|
|
|
|
| 1209 |
"print(f\"Saved training reward curve to {reward_curve_path}\")\n",
|
| 1210 |
"print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
|
| 1211 |
"if simple_loss_curve_path:\n",
|
|
|
|
| 1231 |
" \"training_metrics_table\": training_metrics_path,\n",
|
| 1232 |
" \"training_metrics_display_table\": training_metrics_display_path if 'training_metrics_display_path' in globals() else None,\n",
|
| 1233 |
" \"graphs\": {\n",
|
| 1234 |
+
" \"dashboard\": None,\n",
|
| 1235 |
" \"training_reward_curve\": reward_curve_path,\n",
|
| 1236 |
" \"simple_reward_curve\": simple_reward_curve_path,\n",
|
| 1237 |
" \"simple_loss_curve\": simple_loss_curve_path,\n",
|