Spaces:
Sleeping
Sleeping
Commit ·
3d49e8a
1
Parent(s): 08731ee
feat: add GridMind GRPO training environment and Unsloth training script
Browse files- scripts/gridmind_grpo_colab.ipynb +218 -66
- scripts/train_unsloth.py +2 -2
scripts/gridmind_grpo_colab.ipynb
CHANGED
|
@@ -332,24 +332,42 @@
|
|
| 332 |
"outputs": [],
|
| 333 |
"source": [
|
| 334 |
"import torch\n",
|
| 335 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
"\n",
|
| 337 |
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
| 338 |
-
"print(f\"Loading {MODEL_NAME}...\")\n",
|
| 339 |
"\n",
|
| 340 |
-
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 341 |
"if tokenizer.pad_token is None:\n",
|
| 342 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
"\n",
|
| 344 |
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 345 |
" MODEL_NAME,\n",
|
| 346 |
-
"
|
| 347 |
-
" device_map=\"
|
|
|
|
| 348 |
")\n",
|
| 349 |
"\n",
|
| 350 |
-
"
|
| 351 |
-
"print(f\"
|
| 352 |
-
"print(f\"
|
| 353 |
]
|
| 354 |
},
|
| 355 |
{
|
|
@@ -368,53 +386,103 @@
|
|
| 368 |
"outputs": [],
|
| 369 |
"source": [
|
| 370 |
"import json as _json\n",
|
|
|
|
|
|
|
|
|
|
| 371 |
"\n",
|
| 372 |
"training_rewards = []\n",
|
| 373 |
-
"\n",
|
| 374 |
-
"
|
| 375 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
" rewards = []\n",
|
| 377 |
-
" \n",
|
|
|
|
| 378 |
" for completion in completions:\n",
|
|
|
|
|
|
|
| 379 |
" try:\n",
|
| 380 |
-
" #
|
| 381 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
" start = text.rfind('{')\n",
|
| 383 |
" end = text.rfind('}') + 1\n",
|
| 384 |
" if start < 0 or end <= start:\n",
|
| 385 |
" rewards.append(-1.0)\n",
|
|
|
|
| 386 |
" continue\n",
|
| 387 |
-
"
|
| 388 |
-
"
|
| 389 |
-
" action =
|
| 390 |
-
"
|
| 391 |
-
"
|
| 392 |
-
"
|
| 393 |
-
"
|
| 394 |
-
"
|
| 395 |
-
"
|
| 396 |
-
"
|
| 397 |
-
" \n",
|
| 398 |
-
"
|
| 399 |
-
" r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
|
| 400 |
-
" if r.status_code != 200:\n",
|
| 401 |
" rewards.append(-0.5)\n",
|
|
|
|
| 402 |
" continue\n",
|
| 403 |
-
"
|
| 404 |
-
"
|
| 405 |
-
" if isinstance(
|
| 406 |
-
"
|
| 407 |
-
"
|
| 408 |
-
"
|
| 409 |
-
"
|
| 410 |
-
"
|
| 411 |
-
"
|
| 412 |
-
"
|
| 413 |
-
"
|
| 414 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
" return rewards\n",
|
| 416 |
"\n",
|
| 417 |
-
"print(\"Reward function defined.\")"
|
| 418 |
]
|
| 419 |
},
|
| 420 |
{
|
|
@@ -433,49 +501,133 @@
|
|
| 433 |
"outputs": [],
|
| 434 |
"source": [
|
| 435 |
"from trl import GRPOTrainer, GRPOConfig\n",
|
|
|
|
| 436 |
"from datasets import Dataset\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
"\n",
|
| 438 |
"# Prepare dataset\n",
|
| 439 |
"train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
|
| 440 |
"train_ds = Dataset.from_list(train_data)\n",
|
| 441 |
-
"\n",
|
| 442 |
"print(f\"Training dataset: {len(train_ds)} prompts\")\n",
|
| 443 |
-
"print(f\"Sample prompt:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
|
| 444 |
"\n",
|
| 445 |
-
"
|
| 446 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
" output_dir=\"./gridmind-grpo-output\",\n",
|
| 448 |
" num_train_epochs=1,\n",
|
| 449 |
-
" max_steps=60,
|
| 450 |
-
" per_device_train_batch_size=
|
| 451 |
-
" gradient_accumulation_steps=
|
| 452 |
-
" max_prompt_length=
|
| 453 |
-
"
|
| 454 |
-
"
|
| 455 |
-
"
|
| 456 |
" fp16=True,\n",
|
| 457 |
-
"
|
|
|
|
| 458 |
" report_to=\"none\",\n",
|
| 459 |
-
"
|
|
|
|
| 460 |
")\n",
|
| 461 |
"\n",
|
| 462 |
-
"print(\"
|
| 463 |
-
"
|
| 464 |
-
"print(f\"
|
| 465 |
-
"\n",
|
| 466 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 467 |
"trainer = GRPOTrainer(\n",
|
| 468 |
" model=model,\n",
|
|
|
|
| 469 |
" processing_class=tokenizer,\n",
|
| 470 |
-
" config=config,\n",
|
| 471 |
" train_dataset=train_ds,\n",
|
| 472 |
" reward_funcs=gridmind_reward_fn,\n",
|
| 473 |
-
"
|
| 474 |
")\n",
|
| 475 |
"\n",
|
| 476 |
-
"
|
| 477 |
-
"
|
| 478 |
-
"print(\"\\n\
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
]
|
| 480 |
},
|
| 481 |
{
|
|
@@ -598,7 +750,7 @@
|
|
| 598 |
" },\n",
|
| 599 |
" \"improvement_percent\": overall_improvement,\n",
|
| 600 |
" \"model\": MODEL_NAME,\n",
|
| 601 |
-
" \"training_steps\":
|
| 602 |
" \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
|
| 603 |
" \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
|
| 604 |
"}\n",
|
|
@@ -624,4 +776,4 @@
|
|
| 624 |
},
|
| 625 |
"nbformat": 4,
|
| 626 |
"nbformat_minor": 5
|
| 627 |
-
}
|
|
|
|
| 332 |
"outputs": [],
|
| 333 |
"source": [
|
| 334 |
"import torch\n",
|
| 335 |
+
"import gc\n",
|
| 336 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
|
| 337 |
+
"\n",
|
| 338 |
+
"# Clear any previous model from memory\n",
|
| 339 |
+
"for var in ['model', 'trainer']:\n",
|
| 340 |
+
" if var in dir():\n",
|
| 341 |
+
" del var\n",
|
| 342 |
+
"gc.collect()\n",
|
| 343 |
+
"torch.cuda.empty_cache()\n",
|
| 344 |
"\n",
|
| 345 |
"MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
| 346 |
+
"print(f\"Loading {MODEL_NAME} with 4-bit quantization for T4 16GB...\")\n",
|
| 347 |
"\n",
|
| 348 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
|
| 349 |
"if tokenizer.pad_token is None:\n",
|
| 350 |
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 351 |
+
"tokenizer.padding_side = \"left\" # required for GRPO\n",
|
| 352 |
+
"\n",
|
| 353 |
+
"# 4-bit quantization - fits safely on T4 16GB\n",
|
| 354 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
| 355 |
+
" load_in_4bit=True,\n",
|
| 356 |
+
" bnb_4bit_compute_dtype=torch.float16,\n",
|
| 357 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
| 358 |
+
" bnb_4bit_use_double_quant=True,\n",
|
| 359 |
+
")\n",
|
| 360 |
"\n",
|
| 361 |
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 362 |
" MODEL_NAME,\n",
|
| 363 |
+
" quantization_config=bnb_config,\n",
|
| 364 |
+
" device_map=\"auto\",\n",
|
| 365 |
+
" trust_remote_code=True,\n",
|
| 366 |
")\n",
|
| 367 |
"\n",
|
| 368 |
+
"print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
|
| 369 |
+
"print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB / 16 GB\")\n",
|
| 370 |
+
"print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB / 16 GB\")"
|
| 371 |
]
|
| 372 |
},
|
| 373 |
{
|
|
|
|
| 386 |
"outputs": [],
|
| 387 |
"source": [
|
| 388 |
"import json as _json\n",
|
| 389 |
+
"import requests as _requests\n",
|
| 390 |
+
"import random as _random\n",
|
| 391 |
+
"import statistics as _statistics\n",
|
| 392 |
"\n",
|
| 393 |
"training_rewards = []\n",
|
| 394 |
+
"_reward_variance_log = []\n",
|
| 395 |
+
"_call_count = [0]\n",
|
| 396 |
+
"\n",
|
| 397 |
+
"def gridmind_reward_fn(completions, prompts=None, **kwargs):\n",
|
| 398 |
+
" \"\"\"\n",
|
| 399 |
+
" Reward function compatible with trl 0.23.0.\n",
|
| 400 |
+
" Called with positional completions list.\n",
|
| 401 |
+
" Must return list of floats same length as completions.\n",
|
| 402 |
+
" \"\"\"\n",
|
| 403 |
" rewards = []\n",
|
| 404 |
+
" batch_raw = []\n",
|
| 405 |
+
"\n",
|
| 406 |
" for completion in completions:\n",
|
| 407 |
+
" _call_count[0] += 1\n",
|
| 408 |
+
"\n",
|
| 409 |
" try:\n",
|
| 410 |
+
" # Handle both string and list completion formats\n",
|
| 411 |
+
" if isinstance(completion, list):\n",
|
| 412 |
+
" text = str(completion[0]) if completion else \"\"\n",
|
| 413 |
+
" else:\n",
|
| 414 |
+
" text = str(completion)\n",
|
| 415 |
+
" text = text.strip()\n",
|
| 416 |
+
"\n",
|
| 417 |
+
" # Reset env before each reward call for variance\n",
|
| 418 |
+
" task_id = _random.choice([1, 2, 3, 4])\n",
|
| 419 |
+
" reset_r = _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=8)\n",
|
| 420 |
+
" if reset_r.status_code != 200:\n",
|
| 421 |
+
" rewards.append(-0.5)\n",
|
| 422 |
+
" batch_raw.append(-0.5)\n",
|
| 423 |
+
" continue\n",
|
| 424 |
+
"\n",
|
| 425 |
+
" # Extract JSON from completion\n",
|
| 426 |
" start = text.rfind('{')\n",
|
| 427 |
" end = text.rfind('}') + 1\n",
|
| 428 |
" if start < 0 or end <= start:\n",
|
| 429 |
" rewards.append(-1.0)\n",
|
| 430 |
+
" batch_raw.append(-1.0)\n",
|
| 431 |
" continue\n",
|
| 432 |
+
"\n",
|
| 433 |
+
" action = _json.loads(text[start:end])\n",
|
| 434 |
+
" action = {\n",
|
| 435 |
+
" \"hvac_power_level\": max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5)))),\n",
|
| 436 |
+
" \"thermal_charge_rate\": max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0)))),\n",
|
| 437 |
+
" \"batch_job_slot\": max(0, min(4, int(action.get(\"batch_job_slot\", 0)))),\n",
|
| 438 |
+
" \"load_shed_fraction\": max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0)))),\n",
|
| 439 |
+
" \"building_id\": int(action.get(\"building_id\", 0)),\n",
|
| 440 |
+
" }\n",
|
| 441 |
+
"\n",
|
| 442 |
+
" step_r = _requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
|
| 443 |
+
" if step_r.status_code != 200:\n",
|
|
|
|
|
|
|
| 444 |
" rewards.append(-0.5)\n",
|
| 445 |
+
" batch_raw.append(-0.5)\n",
|
| 446 |
" continue\n",
|
| 447 |
+
"\n",
|
| 448 |
+
" data = step_r.json()\n",
|
| 449 |
+
" if isinstance(data, list):\n",
|
| 450 |
+
" data = data[0]\n",
|
| 451 |
+
"\n",
|
| 452 |
+
" base = float(data.get(\"reward\", 0.0))\n",
|
| 453 |
+
" comps = data.get(\"rewards\", {})\n",
|
| 454 |
+
" bonus = (\n",
|
| 455 |
+
" float(comps.get(\"cost_savings\", 0)) * 0.3 +\n",
|
| 456 |
+
" float(comps.get(\"task_satisfaction\", 0)) * 0.2 +\n",
|
| 457 |
+
" float(comps.get(\"efficiency_bonus\", 0)) * 0.1 +\n",
|
| 458 |
+
" float(comps.get(\"temperature_constraint\", 0)) * 0.15\n",
|
| 459 |
+
" )\n",
|
| 460 |
+
" final = max(-1.0, min(1.0, base + bonus))\n",
|
| 461 |
+
" rewards.append(final)\n",
|
| 462 |
+
" batch_raw.append(final)\n",
|
| 463 |
+
" training_rewards.append(final)\n",
|
| 464 |
+
"\n",
|
| 465 |
+
" except _json.JSONDecodeError:\n",
|
| 466 |
+
" rewards.append(-0.8)\n",
|
| 467 |
+
" batch_raw.append(-0.8)\n",
|
| 468 |
+
" except Exception:\n",
|
| 469 |
+
" rewards.append(-0.5)\n",
|
| 470 |
+
" batch_raw.append(-0.5)\n",
|
| 471 |
+
"\n",
|
| 472 |
+
" # Log variance every 10 calls\n",
|
| 473 |
+
" if len(batch_raw) > 1 and _call_count[0] % 10 == 0:\n",
|
| 474 |
+
" try:\n",
|
| 475 |
+
" var = _statistics.variance(batch_raw)\n",
|
| 476 |
+
" _reward_variance_log.append(var)\n",
|
| 477 |
+
" print(f\" [Call {_call_count[0]}] Rewards: {[f'{r:.3f}' for r in batch_raw]} | Variance: {var:.4f}\")\n",
|
| 478 |
+
" if var < 0.001:\n",
|
| 479 |
+
" print(\" Zero variance - no learning signal!\")\n",
|
| 480 |
+
" except Exception:\n",
|
| 481 |
+
" pass\n",
|
| 482 |
+
"\n",
|
| 483 |
" return rewards\n",
|
| 484 |
"\n",
|
| 485 |
+
"print(\"Reward function defined (trl 0.23.0 compatible)\")"
|
| 486 |
]
|
| 487 |
},
|
| 488 |
{
|
|
|
|
| 501 |
"outputs": [],
|
| 502 |
"source": [
|
| 503 |
"from trl import GRPOTrainer, GRPOConfig\n",
|
| 504 |
+
"from peft import LoraConfig, prepare_model_for_kbit_training\n",
|
| 505 |
"from datasets import Dataset\n",
|
| 506 |
+
"import inspect\n",
|
| 507 |
+
"import os\n",
|
| 508 |
+
"import requests as _requests\n",
|
| 509 |
+
"import statistics\n",
|
| 510 |
+
"import torch, gc\n",
|
| 511 |
"\n",
|
| 512 |
"# Prepare dataset\n",
|
| 513 |
"train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
|
| 514 |
"train_ds = Dataset.from_list(train_data)\n",
|
|
|
|
| 515 |
"print(f\"Training dataset: {len(train_ds)} prompts\")\n",
|
|
|
|
| 516 |
"\n",
|
| 517 |
+
"theme_dist = {}\n",
|
| 518 |
+
"for d in dataset:\n",
|
| 519 |
+
" t = d.get(\"theme\", \"unknown\")\n",
|
| 520 |
+
" theme_dist[t] = theme_dist.get(t, 0) + 1\n",
|
| 521 |
+
"print(f\"Theme distribution: {theme_dist}\")\n",
|
| 522 |
+
"print(f\"Sample prompt preview:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
|
| 523 |
+
"\n",
|
| 524 |
+
"# Prepare model for QLoRA training\n",
|
| 525 |
+
"model.config.use_cache = False\n",
|
| 526 |
+
"model.gradient_checkpointing_enable()\n",
|
| 527 |
+
"model = prepare_model_for_kbit_training(model)\n",
|
| 528 |
+
"\n",
|
| 529 |
+
"peft_config = LoraConfig(\n",
|
| 530 |
+
" r=16,\n",
|
| 531 |
+
" lora_alpha=32,\n",
|
| 532 |
+
" target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 533 |
+
" lora_dropout=0.05,\n",
|
| 534 |
+
" bias=\"none\",\n",
|
| 535 |
+
" task_type=\"CAUSAL_LM\",\n",
|
| 536 |
+
")\n",
|
| 537 |
+
"\n",
|
| 538 |
+
"# GRPOConfig - trl==0.23.0 compatible. Pass this as args=, not config=.\n",
|
| 539 |
+
"# generation_kwargs is not a GRPOTrainer init parameter in trl 0.23.0.\n",
|
| 540 |
+
"grpo_config = GRPOConfig(\n",
|
| 541 |
" output_dir=\"./gridmind-grpo-output\",\n",
|
| 542 |
" num_train_epochs=1,\n",
|
| 543 |
+
" max_steps=60,\n",
|
| 544 |
+
" per_device_train_batch_size=1,\n",
|
| 545 |
+
" gradient_accumulation_steps=4,\n",
|
| 546 |
+
" max_prompt_length=400,\n",
|
| 547 |
+
" max_completion_length=80,\n",
|
| 548 |
+
" num_generations=4,\n",
|
| 549 |
+
" learning_rate=5e-5,\n",
|
| 550 |
" fp16=True,\n",
|
| 551 |
+
" logging_steps=1,\n",
|
| 552 |
+
" save_steps=60,\n",
|
| 553 |
" report_to=\"none\",\n",
|
| 554 |
+
" dataloader_num_workers=0,\n",
|
| 555 |
+
" remove_unused_columns=False,\n",
|
| 556 |
")\n",
|
| 557 |
"\n",
|
| 558 |
+
"print(\"=== PRE-TRAINING DIAGNOSTIC ===\\n\")\n",
|
| 559 |
+
"import trl\n",
|
| 560 |
+
"print(f\"TRL version: {trl.__version__}\")\n",
|
| 561 |
+
"sig = inspect.signature(GRPOTrainer.__init__)\n",
|
| 562 |
+
"params = list(sig.parameters.keys())\n",
|
| 563 |
+
"print(f\"GRPOTrainer params: {params[:8]}\")\n",
|
| 564 |
+
"print(f\"Uses 'args=': {'args' in params}\")\n",
|
| 565 |
+
"print(f\"Uses 'config=': {'config' in params}\")\n",
|
| 566 |
+
"\n",
|
| 567 |
+
"print(\"\\nTesting reward function...\")\n",
|
| 568 |
+
"test_completions = [\n",
|
| 569 |
+
" '{\"hvac_power_level\": 0.2, \"thermal_charge_rate\": 0.8, \"batch_job_slot\": 2, \"load_shed_fraction\": 0.0, \"building_id\": 0}',\n",
|
| 570 |
+
" '{\"hvac_power_level\": 1.0, \"thermal_charge_rate\": -1.0, \"batch_job_slot\": 0, \"load_shed_fraction\": 0.5, \"building_id\": 0}',\n",
|
| 571 |
+
" '{\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}',\n",
|
| 572 |
+
" 'not valid json at all',\n",
|
| 573 |
+
"]\n",
|
| 574 |
+
"test_rewards = gridmind_reward_fn(test_completions)\n",
|
| 575 |
+
"print(f\"Test rewards: {[f'{r:.3f}' for r in test_rewards]}\")\n",
|
| 576 |
+
"reward_var = statistics.variance(test_rewards) if len(set(test_rewards)) > 1 else 0.0\n",
|
| 577 |
+
"if reward_var <= 0.001:\n",
|
| 578 |
+
" print(\"CRITICAL: Reward variance is too low - fix reward function before training\")\n",
|
| 579 |
+
"else:\n",
|
| 580 |
+
" print(f\"Reward variance: {reward_var:.4f} - sufficient for GRPO\")\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"print(f\"\\nGPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB used / 16 GB total\")\n",
|
| 583 |
+
"print(f\"Free: {(16 - torch.cuda.memory_allocated()/1e9):.2f} GB\")\n",
|
| 584 |
+
"print(\"\\n=== READY TO TRAIN ===\" if reward_var > 0.001 else \"\\n=== FIX REWARD FUNCTION FIRST ===\")\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"# Reset environment before training\n",
|
| 587 |
+
"_requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 1}, timeout=10)\n",
|
| 588 |
+
"print(\"Environment reset before training.\")\n",
|
| 589 |
+
"\n",
|
| 590 |
+
"# Initialize GRPOTrainer - trl 0.23.0 API\n",
|
| 591 |
"trainer = GRPOTrainer(\n",
|
| 592 |
" model=model,\n",
|
| 593 |
+
" args=grpo_config,\n",
|
| 594 |
" processing_class=tokenizer,\n",
|
|
|
|
| 595 |
" train_dataset=train_ds,\n",
|
| 596 |
" reward_funcs=gridmind_reward_fn,\n",
|
| 597 |
+
" peft_config=peft_config,\n",
|
| 598 |
")\n",
|
| 599 |
"\n",
|
| 600 |
+
"print(\"\\nStarting GRPO training with QLoRA...\")\n",
|
| 601 |
+
"print(f\"Steps: {grpo_config.max_steps} | Batch: {grpo_config.per_device_train_batch_size} | Generations: {grpo_config.num_generations}\")\n",
|
| 602 |
+
"print(\"Estimated time: ~25-35 min on T4\\n\")\n",
|
| 603 |
+
"\n",
|
| 604 |
+
"train_result = trainer.train()\n",
|
| 605 |
+
"\n",
|
| 606 |
+
"print(\"\\nTraining complete!\")\n",
|
| 607 |
+
"print(f\" Total steps: {train_result.global_step}\")\n",
|
| 608 |
+
"print(f\" Training loss: {train_result.training_loss:.6f}\")\n",
|
| 609 |
+
"\n",
|
| 610 |
+
"if train_result.training_loss == 0.0:\n",
|
| 611 |
+
" print(\"\\nWARNING: Loss is 0.0 - reward function may have zero variance.\")\n",
|
| 612 |
+
" print(\"Check reward diagnostic output above. This means the model saw no learning signal.\")\n",
|
| 613 |
+
"else:\n",
|
| 614 |
+
" print(\"\\nNon-zero loss confirmed - model received learning signal.\")\n",
|
| 615 |
+
"\n",
|
| 616 |
+
"print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"# Save LoRA adapter (much smaller than full model)\n",
|
| 619 |
+
"adapter_path = \"./gridmind-lora-adapter\"\n",
|
| 620 |
+
"trainer.model.save_pretrained(adapter_path)\n",
|
| 621 |
+
"tokenizer.save_pretrained(adapter_path)\n",
|
| 622 |
+
"print(f\"LoRA adapter saved to {adapter_path}\")\n",
|
| 623 |
+
"\n",
|
| 624 |
+
"total_size = sum(\n",
|
| 625 |
+
" os.path.getsize(os.path.join(adapter_path, f))\n",
|
| 626 |
+
" for f in os.listdir(adapter_path)\n",
|
| 627 |
+
" if os.path.isfile(os.path.join(adapter_path, f))\n",
|
| 628 |
+
")\n",
|
| 629 |
+
"print(f\"Adapter size: {total_size/1e6:.1f} MB\")\n",
|
| 630 |
+
"print(\"Full model would be ~3 GB - adapter is the diff only\")"
|
| 631 |
]
|
| 632 |
},
|
| 633 |
{
|
|
|
|
| 750 |
" },\n",
|
| 751 |
" \"improvement_percent\": overall_improvement,\n",
|
| 752 |
" \"model\": MODEL_NAME,\n",
|
| 753 |
+
" \"training_steps\": grpo_config.max_steps,\n",
|
| 754 |
" \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
|
| 755 |
" \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
|
| 756 |
"}\n",
|
|
|
|
| 776 |
},
|
| 777 |
"nbformat": 4,
|
| 778 |
"nbformat_minor": 5
|
| 779 |
+
}
|
scripts/train_unsloth.py
CHANGED
|
@@ -690,7 +690,7 @@ def main():
|
|
| 690 |
|
| 691 |
trainer = GRPOTrainer(
|
| 692 |
model=model,
|
| 693 |
-
|
| 694 |
args=training_args,
|
| 695 |
train_dataset=dataset,
|
| 696 |
reward_funcs=[
|
|
@@ -746,4 +746,4 @@ def main():
|
|
| 746 |
|
| 747 |
|
| 748 |
if __name__ == "__main__":
|
| 749 |
-
main()
|
|
|
|
| 690 |
|
| 691 |
trainer = GRPOTrainer(
|
| 692 |
model=model,
|
| 693 |
+
processing_class=tokenizer,
|
| 694 |
args=training_args,
|
| 695 |
train_dataset=dataset,
|
| 696 |
reward_funcs=[
|
|
|
|
| 746 |
|
| 747 |
|
| 748 |
if __name__ == "__main__":
|
| 749 |
+
main()
|