Spaces:
Sleeping
Sleeping
Commit ·
26e9b86
1
Parent(s): 7d89faf
feat: add GRPO training pipeline for GridMind-RL environment via Unsloth and TRL
Browse files- scripts/gridmind_grpo_colab.ipynb +261 -44
scripts/gridmind_grpo_colab.ipynb
CHANGED
|
@@ -343,14 +343,16 @@
|
|
| 343 |
"torch.cuda.empty_cache()\n",
|
| 344 |
"\n",
|
| 345 |
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
| 346 |
-
"
|
|
|
|
|
|
|
| 347 |
"\n",
|
| 348 |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
|
| 349 |
"if tokenizer.pad_token is None:\n",
|
| 350 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 351 |
"tokenizer.padding_side = \"left\" # required for GRPO\n",
|
| 352 |
"\n",
|
| 353 |
-
"# 4-bit quantization -
|
| 354 |
"bnb_config = BitsAndBytesConfig(\n",
|
| 355 |
" load_in_4bit=True,\n",
|
| 356 |
" bnb_4bit_compute_dtype=torch.float16,\n",
|
|
@@ -366,8 +368,8 @@
|
|
| 366 |
")\n",
|
| 367 |
"\n",
|
| 368 |
"print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
|
| 369 |
-
"print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB /
|
| 370 |
-
"print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB /
|
| 371 |
]
|
| 372 |
},
|
| 373 |
{
|
|
@@ -417,8 +419,15 @@
|
|
| 417 |
" reset_payload = {\"task_id\": task_id, \"seed\": batch_seed}\n",
|
| 418 |
" reset_r = _requests.post(f\"{ENV_URL}/reset\", json=reset_payload, timeout=10)\n",
|
| 419 |
" reset_ok = reset_r.status_code == 200\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
" except Exception:\n",
|
| 421 |
" reset_ok = False\n",
|
|
|
|
| 422 |
"\n",
|
| 423 |
" if not reset_ok:\n",
|
| 424 |
" return [-0.1] * len(completions)\n",
|
|
@@ -450,25 +459,33 @@
|
|
| 450 |
" cleaned_action = {}\n",
|
| 451 |
"\n",
|
| 452 |
" try:\n",
|
| 453 |
-
"
|
|
|
|
|
|
|
| 454 |
" valid_fields += 1\n",
|
| 455 |
" except Exception:\n",
|
| 456 |
" cleaned_action[\"hvac_power_level\"] = 0.5\n",
|
| 457 |
"\n",
|
| 458 |
" try:\n",
|
| 459 |
-
"
|
|
|
|
|
|
|
| 460 |
" valid_fields += 1\n",
|
| 461 |
" except Exception:\n",
|
| 462 |
" cleaned_action[\"thermal_charge_rate\"] = 0.0\n",
|
| 463 |
"\n",
|
| 464 |
" try:\n",
|
| 465 |
-
"
|
|
|
|
|
|
|
| 466 |
" valid_fields += 1\n",
|
| 467 |
" except Exception:\n",
|
| 468 |
" cleaned_action[\"batch_job_slot\"] = 0\n",
|
| 469 |
"\n",
|
| 470 |
" try:\n",
|
| 471 |
-
"
|
|
|
|
|
|
|
| 472 |
" valid_fields += 1\n",
|
| 473 |
" except Exception:\n",
|
| 474 |
" cleaned_action[\"load_shed_fraction\"] = 0.0\n",
|
|
@@ -496,18 +513,32 @@
|
|
| 496 |
" grid_r = float(comps.get(\"grid_response\", 0.0))\n",
|
| 497 |
" task_r = float(comps.get(\"task_satisfaction\", 0.0))\n",
|
| 498 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
" if comps:\n",
|
| 500 |
-
"
|
| 501 |
-
" cost_r * 0.40 +\n",
|
| 502 |
-
" comfort_r * 0.25 +\n",
|
| 503 |
-
" grid_r * 0.15 +\n",
|
| 504 |
-
" task_r * 0.20 +\n",
|
| 505 |
-
" completeness_bonus\n",
|
| 506 |
-
" )\n",
|
| 507 |
" else:\n",
|
| 508 |
-
"
|
| 509 |
"\n",
|
| 510 |
-
"
|
|
|
|
|
|
|
| 511 |
"\n",
|
| 512 |
" rewards.append(composite)\n",
|
| 513 |
" batch_raw.append(composite)\n",
|
|
@@ -641,25 +672,37 @@
|
|
| 641 |
" task_type=\"CAUSAL_LM\",\n",
|
| 642 |
")\n",
|
| 643 |
"\n",
|
| 644 |
-
"# GRPOConfig
|
| 645 |
-
"#
|
| 646 |
-
"
|
| 647 |
-
" output_dir
|
| 648 |
-
" num_train_epochs
|
| 649 |
-
" max_steps
|
| 650 |
-
" per_device_train_batch_size
|
| 651 |
-
" gradient_accumulation_steps
|
| 652 |
-
" max_prompt_length
|
| 653 |
-
" max_completion_length
|
| 654 |
-
"
|
| 655 |
-
"
|
| 656 |
-
"
|
| 657 |
-
"
|
| 658 |
-
"
|
| 659 |
-
"
|
| 660 |
-
"
|
| 661 |
-
"
|
| 662 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
"\n",
|
| 664 |
"# Confirm the installed TRL API before constructing the trainer.\n",
|
| 665 |
"import trl\n",
|
|
@@ -671,8 +714,10 @@
|
|
| 671 |
"print(f\"Uses 'args=': {'args' in params}\")\n",
|
| 672 |
"print(f\"Uses 'config=': {'config' in params}\")\n",
|
| 673 |
"\n",
|
| 674 |
-
"
|
| 675 |
-
"
|
|
|
|
|
|
|
| 676 |
"\n",
|
| 677 |
"# Custom callback to capture loss at every step for graphing.\n",
|
| 678 |
"from transformers import TrainerCallback\n",
|
|
@@ -680,12 +725,16 @@
|
|
| 680 |
"step_losses = []\n",
|
| 681 |
"step_numbers = []\n",
|
| 682 |
"step_reward_means = []\n",
|
|
|
|
| 683 |
"\n",
|
| 684 |
"class LossCaptureCallback(TrainerCallback):\n",
|
| 685 |
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
| 686 |
" if not logs:\n",
|
| 687 |
" return\n",
|
| 688 |
" step = state.global_step\n",
|
|
|
|
|
|
|
|
|
|
| 689 |
" loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
|
| 690 |
" if loss is not None:\n",
|
| 691 |
" step_losses.append(float(loss))\n",
|
|
@@ -714,7 +763,7 @@
|
|
| 714 |
"\n",
|
| 715 |
"print(\"\\nStarting GRPO training with QLoRA...\")\n",
|
| 716 |
"print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
|
| 717 |
-
"print(f\"Steps: {grpo_config
|
| 718 |
"print(\"Estimated time: ~25-35 min on T4\\n\")\n",
|
| 719 |
"\n",
|
| 720 |
"train_result = trainer.train()\n",
|
|
@@ -732,6 +781,38 @@
|
|
| 732 |
"else:\n",
|
| 733 |
" print(f\"\\nTraining produced gradient signal on {len(non_zero_losses)} steps.\")\n",
|
| 734 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
"print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
|
| 736 |
"\n",
|
| 737 |
"# Save LoRA adapter (much smaller than full model)\n",
|
|
@@ -863,9 +944,35 @@
|
|
| 863 |
"import matplotlib.pyplot as plt\n",
|
| 864 |
"import matplotlib.gridspec as gridspec\n",
|
| 865 |
"import numpy as np\n",
|
|
|
|
| 866 |
"import os\n",
|
| 867 |
"\n",
|
| 868 |
"os.makedirs(\"results\", exist_ok=True)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
"\n",
|
| 870 |
"tasks = [1, 2, 3, 4]\n",
|
| 871 |
"task_labels = [\n",
|
|
@@ -963,23 +1070,39 @@
|
|
| 963 |
"# Panel B: reward signal over time.\n",
|
| 964 |
"ax_rew = fig.add_subplot(gs[1, 0])\n",
|
| 965 |
"style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
|
| 966 |
-
"if
|
| 967 |
-
"
|
| 968 |
-
"
|
| 969 |
-
"
|
|
|
|
|
|
|
|
|
|
| 970 |
" ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
" if len(steps_r) > 8:\n",
|
| 972 |
" z = np.polyfit(steps_r, raw, 1)\n",
|
| 973 |
" p = np.poly1d(z)\n",
|
| 974 |
" ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
|
| 975 |
" label=f'Trend ({z[0]:+.5f}/step)')\n",
|
| 976 |
-
" ax_rew.set_xlabel('
|
| 977 |
" ax_rew.set_ylabel('Reward Value')\n",
|
| 978 |
" ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 979 |
" if np.var(raw) < 0.01:\n",
|
| 980 |
" ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
|
| 981 |
" transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
|
| 982 |
" bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
"else:\n",
|
| 984 |
" ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
|
| 985 |
" transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
|
|
@@ -1027,6 +1150,92 @@
|
|
| 1027 |
"fig.savefig(dashboard_path, dpi=180, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
|
| 1028 |
"plt.close(fig)\n",
|
| 1029 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1030 |
"# Separate before/after comparison graph for quick judge inspection.\n",
|
| 1031 |
"fig2, ax2 = plt.subplots(figsize=(11, 6))\n",
|
| 1032 |
"fig2.patch.set_facecolor(C['bg'])\n",
|
|
@@ -1048,6 +1257,10 @@
|
|
| 1048 |
"plt.close(fig2)\n",
|
| 1049 |
"\n",
|
| 1050 |
"print(f\"Saved dashboard graph to {dashboard_path}\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1051 |
"print(f\"Saved before/after graph to {comparison_path}\")\n",
|
| 1052 |
"\n",
|
| 1053 |
"results = {\n",
|
|
@@ -1066,8 +1279,12 @@
|
|
| 1066 |
" \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
|
| 1067 |
" \"training_step_logs\": training_steps_log[-20:] if training_steps_log else [],\n",
|
| 1068 |
" \"step_losses\": step_losses if 'step_losses' in globals() else [],\n",
|
|
|
|
| 1069 |
" \"graphs\": {\n",
|
| 1070 |
" \"dashboard\": dashboard_path,\n",
|
|
|
|
|
|
|
|
|
|
| 1071 |
" \"before_after\": comparison_path,\n",
|
| 1072 |
" },\n",
|
| 1073 |
"}\n",
|
|
|
|
| 343 |
"torch.cuda.empty_cache()\n",
|
| 344 |
"\n",
|
| 345 |
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
| 346 |
+
"gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"CPU\"\n",
|
| 347 |
+
"gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0\n",
|
| 348 |
+
"print(f\"Loading {MODEL_NAME} with 4-bit quantization on {gpu_name} ({gpu_total_gb:.2f} GB VRAM)...\")\n",
|
| 349 |
"\n",
|
| 350 |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
|
| 351 |
"if tokenizer.pad_token is None:\n",
|
| 352 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 353 |
"tokenizer.padding_side = \"left\" # required for GRPO\n",
|
| 354 |
"\n",
|
| 355 |
+
"# 4-bit quantization for memory-efficient QLoRA training\n",
|
| 356 |
"bnb_config = BitsAndBytesConfig(\n",
|
| 357 |
" load_in_4bit=True,\n",
|
| 358 |
" bnb_4bit_compute_dtype=torch.float16,\n",
|
|
|
|
| 368 |
")\n",
|
| 369 |
"\n",
|
| 370 |
"print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
|
| 371 |
+
"print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB / {gpu_total_gb:.2f} GB\")\n",
|
| 372 |
+
"print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB / {gpu_total_gb:.2f} GB\")"
|
| 373 |
]
|
| 374 |
},
|
| 375 |
{
|
|
|
|
| 419 |
" reset_payload = {\"task_id\": task_id, \"seed\": batch_seed}\n",
|
| 420 |
" reset_r = _requests.post(f\"{ENV_URL}/reset\", json=reset_payload, timeout=10)\n",
|
| 421 |
" reset_ok = reset_r.status_code == 200\n",
|
| 422 |
+
" reset_data = reset_r.json() if reset_ok else {}\n",
|
| 423 |
+
" reset_obs = reset_data.get(\"observations\", [reset_data.get(\"observation\", {})])\n",
|
| 424 |
+
" if isinstance(reset_obs, list):\n",
|
| 425 |
+
" base_obs = reset_obs[0] if reset_obs else {}\n",
|
| 426 |
+
" else:\n",
|
| 427 |
+
" base_obs = reset_obs or {}\n",
|
| 428 |
" except Exception:\n",
|
| 429 |
" reset_ok = False\n",
|
| 430 |
+
" base_obs = {}\n",
|
| 431 |
"\n",
|
| 432 |
" if not reset_ok:\n",
|
| 433 |
" return [-0.1] * len(completions)\n",
|
|
|
|
| 459 |
" cleaned_action = {}\n",
|
| 460 |
"\n",
|
| 461 |
" try:\n",
|
| 462 |
+
" if \"hvac_power_level\" not in action:\n",
|
| 463 |
+
" raise KeyError(\"hvac_power_level\")\n",
|
| 464 |
+
" cleaned_action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action[\"hvac_power_level\"])))\n",
|
| 465 |
" valid_fields += 1\n",
|
| 466 |
" except Exception:\n",
|
| 467 |
" cleaned_action[\"hvac_power_level\"] = 0.5\n",
|
| 468 |
"\n",
|
| 469 |
" try:\n",
|
| 470 |
+
" if \"thermal_charge_rate\" not in action:\n",
|
| 471 |
+
" raise KeyError(\"thermal_charge_rate\")\n",
|
| 472 |
+
" cleaned_action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action[\"thermal_charge_rate\"])))\n",
|
| 473 |
" valid_fields += 1\n",
|
| 474 |
" except Exception:\n",
|
| 475 |
" cleaned_action[\"thermal_charge_rate\"] = 0.0\n",
|
| 476 |
"\n",
|
| 477 |
" try:\n",
|
| 478 |
+
" if \"batch_job_slot\" not in action:\n",
|
| 479 |
+
" raise KeyError(\"batch_job_slot\")\n",
|
| 480 |
+
" cleaned_action[\"batch_job_slot\"] = max(0, min(4, int(action[\"batch_job_slot\"])))\n",
|
| 481 |
" valid_fields += 1\n",
|
| 482 |
" except Exception:\n",
|
| 483 |
" cleaned_action[\"batch_job_slot\"] = 0\n",
|
| 484 |
"\n",
|
| 485 |
" try:\n",
|
| 486 |
+
" if \"load_shed_fraction\" not in action:\n",
|
| 487 |
+
" raise KeyError(\"load_shed_fraction\")\n",
|
| 488 |
+
" cleaned_action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action[\"load_shed_fraction\"])))\n",
|
| 489 |
" valid_fields += 1\n",
|
| 490 |
" except Exception:\n",
|
| 491 |
" cleaned_action[\"load_shed_fraction\"] = 0.0\n",
|
|
|
|
| 513 |
" grid_r = float(comps.get(\"grid_response\", 0.0))\n",
|
| 514 |
" task_r = float(comps.get(\"task_satisfaction\", 0.0))\n",
|
| 515 |
"\n",
|
| 516 |
+
" price = float(base_obs.get(\"current_price\", base_obs.get(\"price\", 0.10)))\n",
|
| 517 |
+
" stress = float(base_obs.get(\"grid_stress_signal\", base_obs.get(\"grid_stress\", 0.0)))\n",
|
| 518 |
+
" temp = float(base_obs.get(\"indoor_temperature\", 21.0))\n",
|
| 519 |
+
" charge = cleaned_action[\"thermal_charge_rate\"]\n",
|
| 520 |
+
" hvac = cleaned_action[\"hvac_power_level\"]\n",
|
| 521 |
+
" shed = cleaned_action[\"load_shed_fraction\"]\n",
|
| 522 |
+
"\n",
|
| 523 |
+
" price_signal = 0.0\n",
|
| 524 |
+
" if price < 0.08:\n",
|
| 525 |
+
" price_signal += 0.08 * charge\n",
|
| 526 |
+
" elif price > 0.15:\n",
|
| 527 |
+
" price_signal += 0.08 * (-charge)\n",
|
| 528 |
+
" price_signal -= 0.03 * abs(charge)\n",
|
| 529 |
+
"\n",
|
| 530 |
+
" stress_signal = 0.12 * shed if stress > 0.65 else -0.04 * shed\n",
|
| 531 |
+
" comfort_signal = -0.04 * abs(temp - 21.0) * abs(hvac - 0.5)\n",
|
| 532 |
+
" action_signal = price_signal + stress_signal + comfort_signal\n",
|
| 533 |
+
"\n",
|
| 534 |
" if comps:\n",
|
| 535 |
+
" component_signal = 0.04 * cost_r + 0.03 * comfort_r + 0.03 * grid_r + 0.03 * task_r\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
" else:\n",
|
| 537 |
+
" component_signal = 0.0\n",
|
| 538 |
"\n",
|
| 539 |
+
" # Center raw env reward to avoid saturating all valid JSON at the clip boundary.\n",
|
| 540 |
+
" composite = (env_reward - 0.5) * 0.35 + component_signal + action_signal + completeness_bonus\n",
|
| 541 |
+
" composite = max(-0.45, min(0.45, composite))\n",
|
| 542 |
"\n",
|
| 543 |
" rewards.append(composite)\n",
|
| 544 |
" batch_raw.append(composite)\n",
|
|
|
|
| 672 |
" task_type=\"CAUSAL_LM\",\n",
|
| 673 |
")\n",
|
| 674 |
"\n",
|
| 675 |
+
"# GRPOConfig compatibility shim. HF/Colab images can have TRL builds whose\n",
|
| 676 |
+
"# GRPOConfig fields differ, so only pass arguments accepted by this runtime.\n",
|
| 677 |
+
"grpo_config_requested = {\n",
|
| 678 |
+
" \"output_dir\": \"./gridmind-grpo-output\",\n",
|
| 679 |
+
" \"num_train_epochs\": 1,\n",
|
| 680 |
+
" \"max_steps\": 60,\n",
|
| 681 |
+
" \"per_device_train_batch_size\": 1,\n",
|
| 682 |
+
" \"gradient_accumulation_steps\": 4,\n",
|
| 683 |
+
" \"max_prompt_length\": 400,\n",
|
| 684 |
+
" \"max_completion_length\": 80,\n",
|
| 685 |
+
" \"max_new_tokens\": 80,\n",
|
| 686 |
+
" \"num_generations\": 4,\n",
|
| 687 |
+
" \"learning_rate\": 5e-5,\n",
|
| 688 |
+
" \"fp16\": True,\n",
|
| 689 |
+
" \"logging_steps\": 1,\n",
|
| 690 |
+
" \"save_steps\": 60,\n",
|
| 691 |
+
" \"report_to\": \"none\",\n",
|
| 692 |
+
" \"dataloader_num_workers\": 0,\n",
|
| 693 |
+
" \"remove_unused_columns\": False,\n",
|
| 694 |
+
"}\n",
|
| 695 |
+
"\n",
|
| 696 |
+
"grpo_config_sig = inspect.signature(GRPOConfig.__init__)\n",
|
| 697 |
+
"grpo_config_params = set(grpo_config_sig.parameters.keys()) - {\"self\"}\n",
|
| 698 |
+
"grpo_config_kwargs = {k: v for k, v in grpo_config_requested.items() if k in grpo_config_params}\n",
|
| 699 |
+
"if \"max_completion_length\" in grpo_config_kwargs and \"max_new_tokens\" in grpo_config_kwargs:\n",
|
| 700 |
+
" grpo_config_kwargs.pop(\"max_new_tokens\")\n",
|
| 701 |
+
"skipped_config_keys = [k for k in grpo_config_requested if k not in grpo_config_params]\n",
|
| 702 |
+
"print(f\"GRPOConfig accepted keys: {sorted(grpo_config_kwargs.keys())}\")\n",
|
| 703 |
+
"print(f\"GRPOConfig skipped unsupported keys: {skipped_config_keys}\")\n",
|
| 704 |
+
"\n",
|
| 705 |
+
"grpo_config = GRPOConfig(**grpo_config_kwargs)\n",
|
| 706 |
"\n",
|
| 707 |
"# Confirm the installed TRL API before constructing the trainer.\n",
|
| 708 |
"import trl\n",
|
|
|
|
| 714 |
"print(f\"Uses 'args=': {'args' in params}\")\n",
|
| 715 |
"print(f\"Uses 'config=': {'config' in params}\")\n",
|
| 716 |
"\n",
|
| 717 |
+
"gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0\n",
|
| 718 |
+
"gpu_used_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0\n",
|
| 719 |
+
"print(f\"\\nGPU memory: {gpu_used_gb:.2f} GB used / {gpu_total_gb:.2f} GB total\")\n",
|
| 720 |
+
"print(f\"Free: {max(0, gpu_total_gb - gpu_used_gb):.2f} GB\")\n",
|
| 721 |
"\n",
|
| 722 |
"# Custom callback to capture loss at every step for graphing.\n",
|
| 723 |
"from transformers import TrainerCallback\n",
|
|
|
|
| 725 |
"step_losses = []\n",
|
| 726 |
"step_numbers = []\n",
|
| 727 |
"step_reward_means = []\n",
|
| 728 |
+
"training_log_history = []\n",
|
| 729 |
"\n",
|
| 730 |
"class LossCaptureCallback(TrainerCallback):\n",
|
| 731 |
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
| 732 |
" if not logs:\n",
|
| 733 |
" return\n",
|
| 734 |
" step = state.global_step\n",
|
| 735 |
+
" row = {\"step\": step}\n",
|
| 736 |
+
" row.update({k: float(v) if isinstance(v, (int, float)) else v for k, v in logs.items()})\n",
|
| 737 |
+
" training_log_history.append(row)\n",
|
| 738 |
" loss = logs.get(\"loss\", logs.get(\"train_loss\", None))\n",
|
| 739 |
" if loss is not None:\n",
|
| 740 |
" step_losses.append(float(loss))\n",
|
|
|
|
| 763 |
"\n",
|
| 764 |
"print(\"\\nStarting GRPO training with QLoRA...\")\n",
|
| 765 |
"print(\"Watch for non-zero loss values. If all zeros, reward variance is still too low.\\n\")\n",
|
| 766 |
+
"print(f\"Steps: {getattr(grpo_config, 'max_steps', 60)} | Batch: {getattr(grpo_config, 'per_device_train_batch_size', 1)} | Generations: {getattr(grpo_config, 'num_generations', 4)}\")\n",
|
| 767 |
"print(\"Estimated time: ~25-35 min on T4\\n\")\n",
|
| 768 |
"\n",
|
| 769 |
"train_result = trainer.train()\n",
|
|
|
|
| 781 |
"else:\n",
|
| 782 |
" print(f\"\\nTraining produced gradient signal on {len(non_zero_losses)} steps.\")\n",
|
| 783 |
"\n",
|
| 784 |
+
"# Preserve the exact tabular statistics that TRL prints during training.\n",
|
| 785 |
+
"try:\n",
|
| 786 |
+
" import pandas as pd\n",
|
| 787 |
+
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 788 |
+
" if trainer_log_rows:\n",
|
| 789 |
+
" training_metrics_df = pd.DataFrame(trainer_log_rows)\n",
|
| 790 |
+
" if \"step\" not in training_metrics_df.columns:\n",
|
| 791 |
+
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
| 792 |
+
" elif training_log_history:\n",
|
| 793 |
+
" training_metrics_df = pd.DataFrame(training_log_history)\n",
|
| 794 |
+
" else:\n",
|
| 795 |
+
" training_metrics_df = pd.DataFrame({\"step\": step_numbers, \"loss\": step_losses, \"reward\": step_reward_means[:len(step_numbers)]})\n",
|
| 796 |
+
"\n",
|
| 797 |
+
" os.makedirs(\"results\", exist_ok=True)\n",
|
| 798 |
+
" training_metrics_path = \"results/gridmind_training_metrics.csv\"\n",
|
| 799 |
+
" training_metrics_df.to_csv(training_metrics_path, index=False)\n",
|
| 800 |
+
" print(f\"\\nSaved TRL training metrics table to {training_metrics_path}\")\n",
|
| 801 |
+
"\n",
|
| 802 |
+
" preferred_cols = [\n",
|
| 803 |
+
" \"step\", \"loss\", \"reward\", \"reward_std\",\n",
|
| 804 |
+
" \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
|
| 805 |
+
" \"completions / clipped_ratio\", \"kl\", \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
|
| 806 |
+
" ]\n",
|
| 807 |
+
" display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
|
| 808 |
+
" if display_cols:\n",
|
| 809 |
+
" print(\"\\nTraining metrics preview:\")\n",
|
| 810 |
+
" display(training_metrics_df[display_cols].tail(10))\n",
|
| 811 |
+
"except Exception as e:\n",
|
| 812 |
+
" training_metrics_df = None\n",
|
| 813 |
+
" training_metrics_path = None\n",
|
| 814 |
+
" print(f\"Could not build training metrics table: {e}\")\n",
|
| 815 |
+
"\n",
|
| 816 |
"print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
|
| 817 |
"\n",
|
| 818 |
"# Save LoRA adapter (much smaller than full model)\n",
|
|
|
|
| 944 |
"import matplotlib.pyplot as plt\n",
|
| 945 |
"import matplotlib.gridspec as gridspec\n",
|
| 946 |
"import numpy as np\n",
|
| 947 |
+
"import pandas as pd\n",
|
| 948 |
"import os\n",
|
| 949 |
"\n",
|
| 950 |
"os.makedirs(\"results\", exist_ok=True)\n",
|
| 951 |
+
"os.makedirs(\"plots\", exist_ok=True)\n",
|
| 952 |
+
"\n",
|
| 953 |
+
"# Build a TRL-style metrics table from trainer logs. This matches the tabular\n",
|
| 954 |
+
"# output with columns like reward, reward_std, completion lengths, tools, and KL.\n",
|
| 955 |
+
"if 'training_metrics_df' not in globals() or training_metrics_df is None:\n",
|
| 956 |
+
" trainer_log_rows = [r for r in trainer.state.log_history if \"loss\" in r or \"reward\" in r or \"rewards / reward_func / mean\" in r]\n",
|
| 957 |
+
" training_metrics_df = pd.DataFrame(trainer_log_rows if trainer_log_rows else training_log_history)\n",
|
| 958 |
+
" if not training_metrics_df.empty and \"step\" not in training_metrics_df.columns:\n",
|
| 959 |
+
" training_metrics_df.insert(0, \"step\", range(1, len(training_metrics_df) + 1))\n",
|
| 960 |
+
"\n",
|
| 961 |
+
"training_metrics_path = \"results/gridmind_training_metrics.csv\"\n",
|
| 962 |
+
"if not training_metrics_df.empty:\n",
|
| 963 |
+
" training_metrics_df.to_csv(training_metrics_path, index=False)\n",
|
| 964 |
+
" print(f\"Saved TRL metrics table to {training_metrics_path}\")\n",
|
| 965 |
+
" preferred_cols = [\n",
|
| 966 |
+
" \"step\", \"loss\", \"reward\", \"reward_std\",\n",
|
| 967 |
+
" \"completions / mean_length\", \"completions / min_length\", \"completions / max_length\",\n",
|
| 968 |
+
" \"completions / clipped_ratio\", \"completions / mean_terminated_length\",\n",
|
| 969 |
+
" \"completions / min_terminated_length\", \"completions / max_terminated_length\",\n",
|
| 970 |
+
" \"tools / call_frequency\", \"tools / failure_frequency\", \"kl\",\n",
|
| 971 |
+
" \"rewards / reward_func / mean\", \"rewards / reward_func / std\",\n",
|
| 972 |
+
" ]\n",
|
| 973 |
+
" display_cols = [c for c in preferred_cols if c in training_metrics_df.columns]\n",
|
| 974 |
+
" if display_cols:\n",
|
| 975 |
+
" display(training_metrics_df[display_cols].tail(20))\n",
|
| 976 |
"\n",
|
| 977 |
"tasks = [1, 2, 3, 4]\n",
|
| 978 |
"task_labels = [\n",
|
|
|
|
| 1070 |
"# Panel B: reward signal over time.\n",
|
| 1071 |
"ax_rew = fig.add_subplot(gs[1, 0])\n",
|
| 1072 |
"style_ax(ax_rew, 'GRPO Training: Reward Signal per Step')\n",
|
| 1073 |
+
"if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
|
| 1074 |
+
" reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
|
| 1075 |
+
" std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
|
| 1076 |
+
" reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
|
| 1077 |
+
" steps_r = reward_df[\"step\"].astype(float).tolist()\n",
|
| 1078 |
+
" raw = reward_df[reward_col].astype(float).tolist()\n",
|
| 1079 |
+
" ax_rew.plot(steps_r, raw, alpha=0.28, color=C['reward'], linewidth=1.2, marker='o', markersize=2, label='Logged reward')\n",
|
| 1080 |
" ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
|
| 1081 |
+
" if std_col in reward_df.columns:\n",
|
| 1082 |
+
" std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
|
| 1083 |
+
" raw_np = np.array(raw)\n",
|
| 1084 |
+
" steps_np = np.array(steps_r)\n",
|
| 1085 |
+
" ax_rew.fill_between(steps_np, raw_np - std, raw_np + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
|
| 1086 |
" if len(steps_r) > 8:\n",
|
| 1087 |
" z = np.polyfit(steps_r, raw, 1)\n",
|
| 1088 |
" p = np.poly1d(z)\n",
|
| 1089 |
" ax_rew.plot(steps_r, p(steps_r), '--', color='white', alpha=0.35, linewidth=1.5,\n",
|
| 1090 |
" label=f'Trend ({z[0]:+.5f}/step)')\n",
|
| 1091 |
+
" ax_rew.set_xlabel('Training Step')\n",
|
| 1092 |
" ax_rew.set_ylabel('Reward Value')\n",
|
| 1093 |
" ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1094 |
" if np.var(raw) < 0.01:\n",
|
| 1095 |
" ax_rew.text(0.5, 0.5, 'Low reward variance detected.\\nThis graph exposes weak learning signal.',\n",
|
| 1096 |
" transform=ax_rew.transAxes, ha='center', va='center', color=C['random'], fontsize=10,\n",
|
| 1097 |
" bbox=dict(boxstyle='round', facecolor=C['panel'], alpha=0.8))\n",
|
| 1098 |
+
"elif training_rewards and len(training_rewards) >= 4:\n",
|
| 1099 |
+
" raw = training_rewards\n",
|
| 1100 |
+
" steps_r = list(range(1, len(raw) + 1))\n",
|
| 1101 |
+
" ax_rew.plot(steps_r, raw, alpha=0.20, color=C['reward'], linewidth=1)\n",
|
| 1102 |
+
" ax_rew.plot(steps_r, smooth(raw, window=6), color=C['reward'], linewidth=2.5, label='Smoothed reward')\n",
|
| 1103 |
+
" ax_rew.set_xlabel('Reward Function Call')\n",
|
| 1104 |
+
" ax_rew.set_ylabel('Reward Value')\n",
|
| 1105 |
+
" ax_rew.legend(fontsize=9, facecolor=C['grid'], labelcolor=C['text'], framealpha=0.9, edgecolor=C['border'])\n",
|
| 1106 |
"else:\n",
|
| 1107 |
" ax_rew.text(0.5, 0.5, 'No training rewards captured.\\nRe-run with fixed reward function.',\n",
|
| 1108 |
" transform=ax_rew.transAxes, ha='center', va='center', color=C['subtext'], fontsize=11)\n",
|
|
|
|
| 1150 |
"fig.savefig(dashboard_path, dpi=180, facecolor=fig.get_facecolor(), bbox_inches='tight')\n",
|
| 1151 |
"plt.close(fig)\n",
|
| 1152 |
"\n",
|
| 1153 |
+
"# Standalone training reward curve for reports/slides.\n",
|
| 1154 |
+
"reward_curve_path = 'results/gridmind_training_reward_curve.png'\n",
|
| 1155 |
+
"fig_reward, ax_reward = plt.subplots(figsize=(11, 6))\n",
|
| 1156 |
+
"fig_reward.patch.set_facecolor(C['bg'])\n",
|
| 1157 |
+
"style_ax(ax_reward, 'Training Reward Curve')\n",
|
| 1158 |
+
"if not training_metrics_df.empty and (\"reward\" in training_metrics_df.columns or \"rewards / reward_func / mean\" in training_metrics_df.columns):\n",
|
| 1159 |
+
" reward_col = \"reward\" if \"reward\" in training_metrics_df.columns else \"rewards / reward_func / mean\"\n",
|
| 1160 |
+
" std_col = \"reward_std\" if \"reward_std\" in training_metrics_df.columns else \"rewards / reward_func / std\"\n",
|
| 1161 |
+
" reward_df = training_metrics_df[[\"step\", reward_col] + ([std_col] if std_col in training_metrics_df.columns else [])].dropna(subset=[reward_col])\n",
|
| 1162 |
+
" xs = reward_df[\"step\"].astype(float).to_numpy()\n",
|
| 1163 |
+
" ys = reward_df[reward_col].astype(float).to_numpy()\n",
|
| 1164 |
+
" ax_reward.plot(xs, ys, color=C['reward'], alpha=0.35, linewidth=1.2, marker='o', markersize=2, label='Reward')\n",
|
| 1165 |
+
" ax_reward.plot(xs, smooth(ys.tolist(), window=6), color=C['trained'], linewidth=2.5, label='Smoothed reward')\n",
|
| 1166 |
+
" if std_col in reward_df.columns:\n",
|
| 1167 |
+
" std = reward_df[std_col].fillna(0).astype(float).to_numpy()\n",
|
| 1168 |
+
" ax_reward.fill_between(xs, ys - std, ys + std, color=C['reward'], alpha=0.12, label='Reward std')\n",
|
| 1169 |
+
" if len(xs) > 8:\n",
|
| 1170 |
+
" z = np.polyfit(xs, ys, 1)\n",
|
| 1171 |
+
" p = np.poly1d(z)\n",
|
| 1172 |
+
" ax_reward.plot(xs, p(xs), '--', color=C['text'], alpha=0.45, linewidth=1.5, label=f'Trend ({z[0]:+.5f}/step)')\n",
|
| 1173 |
+
" ax_reward.set_xlabel('Training Step')\n",
|
| 1174 |
+
" ax_reward.set_ylabel('Reward')\n",
|
| 1175 |
+
" ax_reward.legend(facecolor=C['grid'], labelcolor=C['text'], edgecolor=C['border'])\n",
|
| 1176 |
+
"else:\n",
|
| 1177 |
+
" ax_reward.text(0.5, 0.5, 'No logged reward data available.', transform=ax_reward.transAxes,\n",
|
| 1178 |
+
" ha='center', va='center', color=C['subtext'])\n",
|
| 1179 |
+
"fig_reward.savefig(reward_curve_path, dpi=180, facecolor=fig_reward.get_facecolor(), bbox_inches='tight')\n",
|
| 1180 |
+
"plt.close(fig_reward)\n",
|
| 1181 |
+
"\n",
|
| 1182 |
+
"# Reference-style simple plots from trainer.state.log_history.\n",
|
| 1183 |
+
"log_history = trainer.state.log_history\n",
|
| 1184 |
+
"simple_steps = []\n",
|
| 1185 |
+
"simple_rewards = []\n",
|
| 1186 |
+
"simple_losses = []\n",
|
| 1187 |
+
"simple_loss_steps = []\n",
|
| 1188 |
+
"\n",
|
| 1189 |
+
"for entry in log_history:\n",
|
| 1190 |
+
" reward_key = \"reward\" if \"reward\" in entry else (\"rewards / reward_func / mean\" if \"rewards / reward_func / mean\" in entry else None)\n",
|
| 1191 |
+
" if reward_key is not None:\n",
|
| 1192 |
+
" simple_steps.append(entry.get(\"step\", len(simple_steps) + 1))\n",
|
| 1193 |
+
" simple_rewards.append(float(entry[reward_key]))\n",
|
| 1194 |
+
" if \"loss\" in entry:\n",
|
| 1195 |
+
" simple_loss_steps.append(entry.get(\"step\", len(simple_loss_steps) + 1))\n",
|
| 1196 |
+
" simple_losses.append(float(entry[\"loss\"]))\n",
|
| 1197 |
+
"\n",
|
| 1198 |
+
"# Plot 1: Reward over training\n",
|
| 1199 |
+
"simple_reward_curve_path = \"plots/reward_curve.png\"\n",
|
| 1200 |
+
"fig_simple_reward, ax_simple_reward = plt.subplots(1, 1, figsize=(10, 5))\n",
|
| 1201 |
+
"if simple_rewards:\n",
|
| 1202 |
+
" ax_simple_reward.plot(simple_steps[:len(simple_rewards)], simple_rewards, color=\"#4285f4\", linewidth=2, label=\"GRPO Reward\")\n",
|
| 1203 |
+
" if len(simple_rewards) > 5:\n",
|
| 1204 |
+
" window = max(3, len(simple_rewards) // 10)\n",
|
| 1205 |
+
" smoothed = [\n",
|
| 1206 |
+
" sum(simple_rewards[max(0, i-window):i+1]) / len(simple_rewards[max(0, i-window):i+1])\n",
|
| 1207 |
+
" for i in range(len(simple_rewards))\n",
|
| 1208 |
+
" ]\n",
|
| 1209 |
+
" ax_simple_reward.plot(simple_steps[:len(smoothed)], smoothed, color=\"#ea4335\", linewidth=2, linestyle=\"--\", label=f\"Smoothed (window={window})\")\n",
|
| 1210 |
+
"else:\n",
|
| 1211 |
+
" ax_simple_reward.text(0.5, 0.5, \"No reward logs found\", transform=ax_simple_reward.transAxes, ha=\"center\", va=\"center\")\n",
|
| 1212 |
+
"ax_simple_reward.set_xlabel(\"Training Step\", fontsize=12)\n",
|
| 1213 |
+
"ax_simple_reward.set_ylabel(\"Reward\", fontsize=12)\n",
|
| 1214 |
+
"ax_simple_reward.set_title(\"GridMind-RL GRPO Training - Reward Curve\", fontsize=14, fontweight=\"bold\")\n",
|
| 1215 |
+
"ax_simple_reward.legend()\n",
|
| 1216 |
+
"ax_simple_reward.grid(True, alpha=0.3)\n",
|
| 1217 |
+
"fig_simple_reward.tight_layout()\n",
|
| 1218 |
+
"fig_simple_reward.savefig(simple_reward_curve_path, dpi=150)\n",
|
| 1219 |
+
"plt.show()\n",
|
| 1220 |
+
"print(f\"Saved: {simple_reward_curve_path}\")\n",
|
| 1221 |
+
"\n",
|
| 1222 |
+
"# Plot 2: Loss over training\n",
|
| 1223 |
+
"simple_loss_curve_path = \"plots/loss_curve.png\"\n",
|
| 1224 |
+
"if simple_losses:\n",
|
| 1225 |
+
" fig_simple_loss, ax_simple_loss = plt.subplots(1, 1, figsize=(10, 5))\n",
|
| 1226 |
+
" ax_simple_loss.plot(simple_loss_steps[:len(simple_losses)], simple_losses, color=\"#34a853\", linewidth=2)\n",
|
| 1227 |
+
" ax_simple_loss.set_xlabel(\"Training Step\", fontsize=12)\n",
|
| 1228 |
+
" ax_simple_loss.set_ylabel(\"Loss\", fontsize=12)\n",
|
| 1229 |
+
" ax_simple_loss.set_title(\"GridMind-RL GRPO Training - Loss Curve\", fontsize=14, fontweight=\"bold\")\n",
|
| 1230 |
+
" ax_simple_loss.grid(True, alpha=0.3)\n",
|
| 1231 |
+
" fig_simple_loss.tight_layout()\n",
|
| 1232 |
+
" fig_simple_loss.savefig(simple_loss_curve_path, dpi=150)\n",
|
| 1233 |
+
" plt.show()\n",
|
| 1234 |
+
" print(f\"Saved: {simple_loss_curve_path}\")\n",
|
| 1235 |
+
"else:\n",
|
| 1236 |
+
" simple_loss_curve_path = None\n",
|
| 1237 |
+
" print(\"No loss logs found; skipped plots/loss_curve.png\")\n",
|
| 1238 |
+
"\n",
|
| 1239 |
"# Separate before/after comparison graph for quick judge inspection.\n",
|
| 1240 |
"fig2, ax2 = plt.subplots(figsize=(11, 6))\n",
|
| 1241 |
"fig2.patch.set_facecolor(C['bg'])\n",
|
|
|
|
| 1257 |
"plt.close(fig2)\n",
|
| 1258 |
"\n",
|
| 1259 |
"print(f\"Saved dashboard graph to {dashboard_path}\")\n",
|
| 1260 |
+
"print(f\"Saved training reward curve to {reward_curve_path}\")\n",
|
| 1261 |
+
"print(f\"Saved simple reward curve to {simple_reward_curve_path}\")\n",
|
| 1262 |
+
"if simple_loss_curve_path:\n",
|
| 1263 |
+
" print(f\"Saved simple loss curve to {simple_loss_curve_path}\")\n",
|
| 1264 |
"print(f\"Saved before/after graph to {comparison_path}\")\n",
|
| 1265 |
"\n",
|
| 1266 |
"results = {\n",
|
|
|
|
| 1279 |
" \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
|
| 1280 |
" \"training_step_logs\": training_steps_log[-20:] if training_steps_log else [],\n",
|
| 1281 |
" \"step_losses\": step_losses if 'step_losses' in globals() else [],\n",
|
| 1282 |
+
" \"training_metrics_table\": training_metrics_path,\n",
|
| 1283 |
" \"graphs\": {\n",
|
| 1284 |
" \"dashboard\": dashboard_path,\n",
|
| 1285 |
+
" \"training_reward_curve\": reward_curve_path,\n",
|
| 1286 |
+
" \"simple_reward_curve\": simple_reward_curve_path,\n",
|
| 1287 |
+
" \"simple_loss_curve\": simple_loss_curve_path,\n",
|
| 1288 |
" \"before_after\": comparison_path,\n",
|
| 1289 |
" },\n",
|
| 1290 |
"}\n",
|