Prajwal782007 commited on
Commit
3d49e8a
·
1 Parent(s): 08731ee

feat: add GridMind GRPO training environment and Unsloth training script

Browse files
scripts/gridmind_grpo_colab.ipynb CHANGED
@@ -332,24 +332,42 @@
332
  "outputs": [],
333
  "source": [
334
  "import torch\n",
335
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
 
 
 
 
 
 
 
 
336
  "\n",
337
  "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
338
- "print(f\"Loading {MODEL_NAME}...\")\n",
339
  "\n",
340
- "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
341
  "if tokenizer.pad_token is None:\n",
342
  " tokenizer.pad_token = tokenizer.eos_token\n",
 
 
 
 
 
 
 
 
 
343
  "\n",
344
  "model = AutoModelForCausalLM.from_pretrained(\n",
345
  " MODEL_NAME,\n",
346
- " torch_dtype=torch.float16,\n",
347
- " device_map=\"cuda\"\n",
 
348
  ")\n",
349
  "\n",
350
- "total_params = sum(p.numel() for p in model.parameters())\n",
351
- "print(f\"Model loaded. Parameters: {total_params/1e6:.0f}M\")\n",
352
- "print(f\"Device: {next(model.parameters()).device}\")"
353
  ]
354
  },
355
  {
@@ -368,53 +386,103 @@
368
  "outputs": [],
369
  "source": [
370
  "import json as _json\n",
 
 
 
371
  "\n",
372
  "training_rewards = []\n",
373
- "\n",
374
- "def gridmind_reward_fn(completions, **kwargs):\n",
375
- " \"\"\"Reward function that calls the real environment.\"\"\"\n",
 
 
 
 
 
 
376
  " rewards = []\n",
377
- " \n",
 
378
  " for completion in completions:\n",
 
 
379
  " try:\n",
380
- " # Extract JSON action from completion\n",
381
- " text = str(completion).strip()\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
  " start = text.rfind('{')\n",
383
  " end = text.rfind('}') + 1\n",
384
  " if start < 0 or end <= start:\n",
385
  " rewards.append(-1.0)\n",
 
386
  " continue\n",
387
- " \n",
388
- " action_str = text[start:end]\n",
389
- " action = _json.loads(action_str)\n",
390
- " \n",
391
- " # Clamp action to valid ranges\n",
392
- " action[\"hvac_power_level\"] = max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5))))\n",
393
- " action[\"thermal_charge_rate\"] = max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0))))\n",
394
- " action[\"batch_job_slot\"] = max(0, min(4, int(action.get(\"batch_job_slot\", 0))))\n",
395
- " action[\"load_shed_fraction\"] = max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0))))\n",
396
- " action[\"building_id\"] = int(action.get(\"building_id\", 0))\n",
397
- " \n",
398
- " # Call environment\n",
399
- " r = requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
400
- " if r.status_code != 200:\n",
401
  " rewards.append(-0.5)\n",
 
402
  " continue\n",
403
- " \n",
404
- " step_data = r.json()\n",
405
- " if isinstance(step_data, list):\n",
406
- " step_data = step_data[0]\n",
407
- " \n",
408
- " reward = float(step_data.get(\"reward\", 0))\n",
409
- " rewards.append(max(-1.0, min(1.0, reward))) # Clamp to [-1, 1]\n",
410
- " training_rewards.append(reward)\n",
411
- " \n",
412
- " except Exception as e:\n",
413
- " rewards.append(-1.0)\n",
414
- " \n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  " return rewards\n",
416
  "\n",
417
- "print(\"Reward function defined.\")"
418
  ]
419
  },
420
  {
@@ -433,49 +501,133 @@
433
  "outputs": [],
434
  "source": [
435
  "from trl import GRPOTrainer, GRPOConfig\n",
 
436
  "from datasets import Dataset\n",
 
 
 
 
 
437
  "\n",
438
  "# Prepare dataset\n",
439
  "train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
440
  "train_ds = Dataset.from_list(train_data)\n",
441
- "\n",
442
  "print(f\"Training dataset: {len(train_ds)} prompts\")\n",
443
- "print(f\"Sample prompt:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
444
  "\n",
445
- "# GRPO config for free T4 GPU\n",
446
- "config = GRPOConfig(\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  " output_dir=\"./gridmind-grpo-output\",\n",
448
  " num_train_epochs=1,\n",
449
- " max_steps=60, # Complete in ~30-40 min on T4\n",
450
- " per_device_train_batch_size=2,\n",
451
- " gradient_accumulation_steps=2,\n",
452
- " max_prompt_length=512,\n",
453
- " learning_rate=5e-6,\n",
454
- " logging_steps=5,\n",
455
- " save_steps=60,\n",
456
  " fp16=True,\n",
457
- " dataloader_num_workers=0,\n",
 
458
  " report_to=\"none\",\n",
459
- " num_generations=2, # 2 generations per prompt for speed\n",
 
460
  ")\n",
461
  "\n",
462
- "print(\"\\nStarting GRPO training...\")\n",
463
- "print(f\"Estimated time: 30-40 minutes on Colab T4 GPU\")\n",
464
- "print(f\"Steps: {config.max_steps}, Batch size: {config.per_device_train_batch_size * config.gradient_accumulation_steps}\\n\")\n",
465
- "\n",
466
- "# Initialize trainer\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  "trainer = GRPOTrainer(\n",
468
  " model=model,\n",
 
469
  " processing_class=tokenizer,\n",
470
- " config=config,\n",
471
  " train_dataset=train_ds,\n",
472
  " reward_funcs=gridmind_reward_fn,\n",
473
- " generation_kwargs={\"max_new_tokens\": 100},\n",
474
  ")\n",
475
  "\n",
476
- "# Train\n",
477
- "trainer.train()\n",
478
- "print(\"\\n\u00e2\u0153\u201c Training complete!\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
  ]
480
  },
481
  {
@@ -598,7 +750,7 @@
598
  " },\n",
599
  " \"improvement_percent\": overall_improvement,\n",
600
  " \"model\": MODEL_NAME,\n",
601
- " \"training_steps\": config.max_steps,\n",
602
  " \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
603
  " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
604
  "}\n",
@@ -624,4 +776,4 @@
624
  },
625
  "nbformat": 4,
626
  "nbformat_minor": 5
627
- }
 
332
  "outputs": [],
333
  "source": [
334
  "import torch\n",
335
+ "import gc\n",
336
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
337
+ "\n",
338
+ "# Clear any previous model from memory\n",
339
+ "for var in ['model', 'trainer']:\n",
340
+ " if var in dir():\n",
341
+ " del var\n",
342
+ "gc.collect()\n",
343
+ "torch.cuda.empty_cache()\n",
344
  "\n",
345
  "MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
346
+ "print(f\"Loading {MODEL_NAME} with 4-bit quantization for T4 16GB...\")\n",
347
  "\n",
348
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
349
  "if tokenizer.pad_token is None:\n",
350
  " tokenizer.pad_token = tokenizer.eos_token\n",
351
+ "tokenizer.padding_side = \"left\" # required for GRPO\n",
352
+ "\n",
353
+ "# 4-bit quantization - fits safely on T4 16GB\n",
354
+ "bnb_config = BitsAndBytesConfig(\n",
355
+ " load_in_4bit=True,\n",
356
+ " bnb_4bit_compute_dtype=torch.float16,\n",
357
+ " bnb_4bit_quant_type=\"nf4\",\n",
358
+ " bnb_4bit_use_double_quant=True,\n",
359
+ ")\n",
360
  "\n",
361
  "model = AutoModelForCausalLM.from_pretrained(\n",
362
  " MODEL_NAME,\n",
363
+ " quantization_config=bnb_config,\n",
364
+ " device_map=\"auto\",\n",
365
+ " trust_remote_code=True,\n",
366
  ")\n",
367
  "\n",
368
+ "print(f\"Model loaded on: {next(model.parameters()).device}\")\n",
369
+ "print(f\"Memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB / 16 GB\")\n",
370
+ "print(f\"Memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB / 16 GB\")"
371
  ]
372
  },
373
  {
 
386
  "outputs": [],
387
  "source": [
388
  "import json as _json\n",
389
+ "import requests as _requests\n",
390
+ "import random as _random\n",
391
+ "import statistics as _statistics\n",
392
  "\n",
393
  "training_rewards = []\n",
394
+ "_reward_variance_log = []\n",
395
+ "_call_count = [0]\n",
396
+ "\n",
397
+ "def gridmind_reward_fn(completions, prompts=None, **kwargs):\n",
398
+ " \"\"\"\n",
399
+ " Reward function compatible with trl 0.23.0.\n",
400
+ " Called with positional completions list.\n",
401
+ " Must return list of floats same length as completions.\n",
402
+ " \"\"\"\n",
403
  " rewards = []\n",
404
+ " batch_raw = []\n",
405
+ "\n",
406
  " for completion in completions:\n",
407
+ " _call_count[0] += 1\n",
408
+ "\n",
409
  " try:\n",
410
+ " # Handle both string and list completion formats\n",
411
+ " if isinstance(completion, list):\n",
412
+ " text = str(completion[0]) if completion else \"\"\n",
413
+ " else:\n",
414
+ " text = str(completion)\n",
415
+ " text = text.strip()\n",
416
+ "\n",
417
+ " # Reset env before each reward call for variance\n",
418
+ " task_id = _random.choice([1, 2, 3, 4])\n",
419
+ " reset_r = _requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": task_id}, timeout=8)\n",
420
+ " if reset_r.status_code != 200:\n",
421
+ " rewards.append(-0.5)\n",
422
+ " batch_raw.append(-0.5)\n",
423
+ " continue\n",
424
+ "\n",
425
+ " # Extract JSON from completion\n",
426
  " start = text.rfind('{')\n",
427
  " end = text.rfind('}') + 1\n",
428
  " if start < 0 or end <= start:\n",
429
  " rewards.append(-1.0)\n",
430
+ " batch_raw.append(-1.0)\n",
431
  " continue\n",
432
+ "\n",
433
+ " action = _json.loads(text[start:end])\n",
434
+ " action = {\n",
435
+ " \"hvac_power_level\": max(0.0, min(1.0, float(action.get(\"hvac_power_level\", 0.5)))),\n",
436
+ " \"thermal_charge_rate\": max(-1.0, min(1.0, float(action.get(\"thermal_charge_rate\", 0.0)))),\n",
437
+ " \"batch_job_slot\": max(0, min(4, int(action.get(\"batch_job_slot\", 0)))),\n",
438
+ " \"load_shed_fraction\": max(0.0, min(0.5, float(action.get(\"load_shed_fraction\", 0.0)))),\n",
439
+ " \"building_id\": int(action.get(\"building_id\", 0)),\n",
440
+ " }\n",
441
+ "\n",
442
+ " step_r = _requests.post(f\"{ENV_URL}/step\", json=action, timeout=8)\n",
443
+ " if step_r.status_code != 200:\n",
 
 
444
  " rewards.append(-0.5)\n",
445
+ " batch_raw.append(-0.5)\n",
446
  " continue\n",
447
+ "\n",
448
+ " data = step_r.json()\n",
449
+ " if isinstance(data, list):\n",
450
+ " data = data[0]\n",
451
+ "\n",
452
+ " base = float(data.get(\"reward\", 0.0))\n",
453
+ " comps = data.get(\"rewards\", {})\n",
454
+ " bonus = (\n",
455
+ " float(comps.get(\"cost_savings\", 0)) * 0.3 +\n",
456
+ " float(comps.get(\"task_satisfaction\", 0)) * 0.2 +\n",
457
+ " float(comps.get(\"efficiency_bonus\", 0)) * 0.1 +\n",
458
+ " float(comps.get(\"temperature_constraint\", 0)) * 0.15\n",
459
+ " )\n",
460
+ " final = max(-1.0, min(1.0, base + bonus))\n",
461
+ " rewards.append(final)\n",
462
+ " batch_raw.append(final)\n",
463
+ " training_rewards.append(final)\n",
464
+ "\n",
465
+ " except _json.JSONDecodeError:\n",
466
+ " rewards.append(-0.8)\n",
467
+ " batch_raw.append(-0.8)\n",
468
+ " except Exception:\n",
469
+ " rewards.append(-0.5)\n",
470
+ " batch_raw.append(-0.5)\n",
471
+ "\n",
472
+ " # Log variance every 10 calls\n",
473
+ " if len(batch_raw) > 1 and _call_count[0] % 10 == 0:\n",
474
+ " try:\n",
475
+ " var = _statistics.variance(batch_raw)\n",
476
+ " _reward_variance_log.append(var)\n",
477
+ " print(f\" [Call {_call_count[0]}] Rewards: {[f'{r:.3f}' for r in batch_raw]} | Variance: {var:.4f}\")\n",
478
+ " if var < 0.001:\n",
479
+ " print(\" Zero variance - no learning signal!\")\n",
480
+ " except Exception:\n",
481
+ " pass\n",
482
+ "\n",
483
  " return rewards\n",
484
  "\n",
485
+ "print(\"Reward function defined (trl 0.23.0 compatible)\")"
486
  ]
487
  },
488
  {
 
501
  "outputs": [],
502
  "source": [
503
  "from trl import GRPOTrainer, GRPOConfig\n",
504
+ "from peft import LoraConfig, prepare_model_for_kbit_training\n",
505
  "from datasets import Dataset\n",
506
+ "import inspect\n",
507
+ "import os\n",
508
+ "import requests as _requests\n",
509
+ "import statistics\n",
510
+ "import torch, gc\n",
511
  "\n",
512
  "# Prepare dataset\n",
513
  "train_data = [{\"prompt\": d[\"prompt\"]} for d in dataset]\n",
514
  "train_ds = Dataset.from_list(train_data)\n",
 
515
  "print(f\"Training dataset: {len(train_ds)} prompts\")\n",
 
516
  "\n",
517
+ "theme_dist = {}\n",
518
+ "for d in dataset:\n",
519
+ " t = d.get(\"theme\", \"unknown\")\n",
520
+ " theme_dist[t] = theme_dist.get(t, 0) + 1\n",
521
+ "print(f\"Theme distribution: {theme_dist}\")\n",
522
+ "print(f\"Sample prompt preview:\\n{train_data[0]['prompt'][:200]}...\\n\")\n",
523
+ "\n",
524
+ "# Prepare model for QLoRA training\n",
525
+ "model.config.use_cache = False\n",
526
+ "model.gradient_checkpointing_enable()\n",
527
+ "model = prepare_model_for_kbit_training(model)\n",
528
+ "\n",
529
+ "peft_config = LoraConfig(\n",
530
+ " r=16,\n",
531
+ " lora_alpha=32,\n",
532
+ " target_modules=[\"q_proj\", \"v_proj\", \"k_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
533
+ " lora_dropout=0.05,\n",
534
+ " bias=\"none\",\n",
535
+ " task_type=\"CAUSAL_LM\",\n",
536
+ ")\n",
537
+ "\n",
538
+ "# GRPOConfig - trl==0.23.0 compatible. Pass this as args=, not config=.\n",
539
+ "# generation_kwargs is not a GRPOTrainer init parameter in trl 0.23.0.\n",
540
+ "grpo_config = GRPOConfig(\n",
541
  " output_dir=\"./gridmind-grpo-output\",\n",
542
  " num_train_epochs=1,\n",
543
+ " max_steps=60,\n",
544
+ " per_device_train_batch_size=1,\n",
545
+ " gradient_accumulation_steps=4,\n",
546
+ " max_prompt_length=400,\n",
547
+ " max_completion_length=80,\n",
548
+ " num_generations=4,\n",
549
+ " learning_rate=5e-5,\n",
550
  " fp16=True,\n",
551
+ " logging_steps=1,\n",
552
+ " save_steps=60,\n",
553
  " report_to=\"none\",\n",
554
+ " dataloader_num_workers=0,\n",
555
+ " remove_unused_columns=False,\n",
556
  ")\n",
557
  "\n",
558
+ "print(\"=== PRE-TRAINING DIAGNOSTIC ===\\n\")\n",
559
+ "import trl\n",
560
+ "print(f\"TRL version: {trl.__version__}\")\n",
561
+ "sig = inspect.signature(GRPOTrainer.__init__)\n",
562
+ "params = list(sig.parameters.keys())\n",
563
+ "print(f\"GRPOTrainer params: {params[:8]}\")\n",
564
+ "print(f\"Uses 'args=': {'args' in params}\")\n",
565
+ "print(f\"Uses 'config=': {'config' in params}\")\n",
566
+ "\n",
567
+ "print(\"\\nTesting reward function...\")\n",
568
+ "test_completions = [\n",
569
+ " '{\"hvac_power_level\": 0.2, \"thermal_charge_rate\": 0.8, \"batch_job_slot\": 2, \"load_shed_fraction\": 0.0, \"building_id\": 0}',\n",
570
+ " '{\"hvac_power_level\": 1.0, \"thermal_charge_rate\": -1.0, \"batch_job_slot\": 0, \"load_shed_fraction\": 0.5, \"building_id\": 0}',\n",
571
+ " '{\"hvac_power_level\": 0.5, \"thermal_charge_rate\": 0.0, \"batch_job_slot\": 0, \"load_shed_fraction\": 0.0, \"building_id\": 0}',\n",
572
+ " 'not valid json at all',\n",
573
+ "]\n",
574
+ "test_rewards = gridmind_reward_fn(test_completions)\n",
575
+ "print(f\"Test rewards: {[f'{r:.3f}' for r in test_rewards]}\")\n",
576
+ "reward_var = statistics.variance(test_rewards) if len(set(test_rewards)) > 1 else 0.0\n",
577
+ "if reward_var <= 0.001:\n",
578
+ " print(\"CRITICAL: Reward variance is too low - fix reward function before training\")\n",
579
+ "else:\n",
580
+ " print(f\"Reward variance: {reward_var:.4f} - sufficient for GRPO\")\n",
581
+ "\n",
582
+ "print(f\"\\nGPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB used / 16 GB total\")\n",
583
+ "print(f\"Free: {(16 - torch.cuda.memory_allocated()/1e9):.2f} GB\")\n",
584
+ "print(\"\\n=== READY TO TRAIN ===\" if reward_var > 0.001 else \"\\n=== FIX REWARD FUNCTION FIRST ===\")\n",
585
+ "\n",
586
+ "# Reset environment before training\n",
587
+ "_requests.post(f\"{ENV_URL}/reset\", json={\"task_id\": 1}, timeout=10)\n",
588
+ "print(\"Environment reset before training.\")\n",
589
+ "\n",
590
+ "# Initialize GRPOTrainer - trl 0.23.0 API\n",
591
  "trainer = GRPOTrainer(\n",
592
  " model=model,\n",
593
+ " args=grpo_config,\n",
594
  " processing_class=tokenizer,\n",
 
595
  " train_dataset=train_ds,\n",
596
  " reward_funcs=gridmind_reward_fn,\n",
597
+ " peft_config=peft_config,\n",
598
  ")\n",
599
  "\n",
600
+ "print(\"\\nStarting GRPO training with QLoRA...\")\n",
601
+ "print(f\"Steps: {grpo_config.max_steps} | Batch: {grpo_config.per_device_train_batch_size} | Generations: {grpo_config.num_generations}\")\n",
602
+ "print(\"Estimated time: ~25-35 min on T4\\n\")\n",
603
+ "\n",
604
+ "train_result = trainer.train()\n",
605
+ "\n",
606
+ "print(\"\\nTraining complete!\")\n",
607
+ "print(f\" Total steps: {train_result.global_step}\")\n",
608
+ "print(f\" Training loss: {train_result.training_loss:.6f}\")\n",
609
+ "\n",
610
+ "if train_result.training_loss == 0.0:\n",
611
+ " print(\"\\nWARNING: Loss is 0.0 - reward function may have zero variance.\")\n",
612
+ " print(\"Check reward diagnostic output above. This means the model saw no learning signal.\")\n",
613
+ "else:\n",
614
+ " print(\"\\nNon-zero loss confirmed - model received learning signal.\")\n",
615
+ "\n",
616
+ "print(f\"\\nMemory after training: {torch.cuda.memory_allocated()/1e9:.2f} GB\")\n",
617
+ "\n",
618
+ "# Save LoRA adapter (much smaller than full model)\n",
619
+ "adapter_path = \"./gridmind-lora-adapter\"\n",
620
+ "trainer.model.save_pretrained(adapter_path)\n",
621
+ "tokenizer.save_pretrained(adapter_path)\n",
622
+ "print(f\"LoRA adapter saved to {adapter_path}\")\n",
623
+ "\n",
624
+ "total_size = sum(\n",
625
+ " os.path.getsize(os.path.join(adapter_path, f))\n",
626
+ " for f in os.listdir(adapter_path)\n",
627
+ " if os.path.isfile(os.path.join(adapter_path, f))\n",
628
+ ")\n",
629
+ "print(f\"Adapter size: {total_size/1e6:.1f} MB\")\n",
630
+ "print(\"Full model would be ~3 GB - adapter is the diff only\")"
631
  ]
632
  },
633
  {
 
750
  " },\n",
751
  " \"improvement_percent\": overall_improvement,\n",
752
  " \"model\": MODEL_NAME,\n",
753
+ " \"training_steps\": grpo_config.max_steps,\n",
754
  " \"themes_covered\": [\"multi_agent\", \"instruction_following\", \"world_modeling\", \"curriculum\"],\n",
755
  " \"training_rewards_log\": training_rewards[-20:] if training_rewards else [],\n",
756
  "}\n",
 
776
  },
777
  "nbformat": 4,
778
  "nbformat_minor": 5
779
+ }
scripts/train_unsloth.py CHANGED
@@ -690,7 +690,7 @@ def main():
690
 
691
  trainer = GRPOTrainer(
692
  model=model,
693
- tokenizer=tokenizer,
694
  args=training_args,
695
  train_dataset=dataset,
696
  reward_funcs=[
@@ -746,4 +746,4 @@ def main():
746
 
747
 
748
  if __name__ == "__main__":
749
- main()
 
690
 
691
  trainer = GRPOTrainer(
692
  model=model,
693
+ processing_class=tokenizer,
694
  args=training_args,
695
  train_dataset=dataset,
696
  reward_funcs=[
 
746
 
747
 
748
  if __name__ == "__main__":
749
+ main()