nihalaninihal committed on
Commit
ee8c2d4
·
1 Parent(s): 803c93e

Update Colab notebook: 1.5B model, scaled rewards, tuned hyperparameters


- Cell 9: Change default model from Qwen2.5-0.5B to Qwen2.5-1.5B-Instruct
- Cell 11: max_prompt_length 512->768, num_generations 4->8, max_completion_length=1280
- Cell 10/8: Update markdown to document reward scaling weights and 1.5B rationale
- Cell 11: Add print for max_prompt/completion_length for training visibility
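The new length budget in Cell 11 is plain arithmetic on the notebook's existing `max_seq_length` constant; a minimal sketch of the bookkeeping (values taken from this diff):

```python
# Length budget from this commit: completion length is derived, not hand-set.
max_seq_length = 2048        # notebook constant (unchanged)
max_prompt_length = 768      # was 512; room for system prompt + observations
num_generations = 8          # was 4; more samples per prompt for GRPO advantages

max_completion_length = max_seq_length - max_prompt_length
print(max_completion_length)  # 1280, matching max_completion_length=1280 above
```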

Files changed (1)
  1. training/colab_training.ipynb +4 -4
training/colab_training.ipynb CHANGED
@@ -81,7 +81,7 @@
  },
  {
  "cell_type": "markdown",
- "source": "## 4. Load Model with Unsloth (BF16 + vLLM)\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern:\n- `load_in_4bit=False` — BF16 precision on H100\n- `fast_inference=True` — vLLM for fast GRPO generation\n- `lora_rank=64`, `lora_alpha=lora_rank` — official LoRA configuration\n- `gpu_memory_utilization=0.9` — maximize GPU usage\n- `random_state=3407` — reproducibility",
+ "source": "## 4. Load Model with Unsloth (BF16 + vLLM)\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern:\n- `load_in_4bit=False` — BF16 precision on H100\n- `fast_inference=True` — vLLM for fast GRPO generation\n- `lora_rank=64`, `lora_alpha=lora_rank` — official LoRA configuration\n- `gpu_memory_utilization=0.9` — maximize GPU usage\n- `random_state=3407` — reproducibility\n\nDefault model: **Qwen2.5-1.5B-Instruct** (minimum recommended for GRPO — 0.5B lacks capacity for multi-step reasoning)",
  "metadata": {
  "id": "train-header"
  }
@@ -93,11 +93,11 @@
  "id": "train"
  },
  "outputs": [],
- "source": "from unsloth import FastLanguageModel\nimport torch\n\nmodel_name = \"unsloth/Qwen2.5-0.5B-Instruct\"\nmax_seq_length = 2048\nlora_rank = 64\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=model_name,\n max_seq_length=max_seq_length,\n load_in_4bit=False, # BF16 for H100 (reference pattern)\n fast_inference=True, # vLLM fast inference\n max_lora_rank=lora_rank,\n gpu_memory_utilization=0.9,\n)\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_rank,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=lora_rank, # Reference: lora_alpha = lora_rank\n use_gradient_checkpointing=\"unsloth\",\n random_state=3407,\n)\nprint(f\"Model loaded: BF16 + vLLM + LoRA (r={lora_rank}, alpha={lora_rank})\")"
+ "source": "from unsloth import FastLanguageModel\nimport torch\n\nmodel_name = \"unsloth/Qwen2.5-1.5B-Instruct\"\nmax_seq_length = 2048\nlora_rank = 64\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=model_name,\n max_seq_length=max_seq_length,\n load_in_4bit=False, # BF16 for H100 (reference pattern)\n fast_inference=True, # vLLM fast inference\n max_lora_rank=lora_rank,\n gpu_memory_utilization=0.9,\n)\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_rank,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=lora_rank, # Reference: lora_alpha = lora_rank\n use_gradient_checkpointing=\"unsloth\",\n random_state=3407,\n)\nprint(f\"Model loaded: BF16 + vLLM + LoRA (r={lora_rank}, alpha={lora_rank})\")"
  },
  {
  "cell_type": "markdown",
- "source": "## 5. GRPO Training with 4 Reward Functions\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern with **4 separate reward functions**:\n1. `match_json_format_exactly` — strict JSON format validation (+3.0)\n2. `match_json_format_approximately` — partial format credit\n3. `check_action` — role-specific action correctness\n4. `check_env` — **environment-executing reward** (OpenEnv pattern)\n\nPlus reference hyperparameters: `adamw_8bit`, cosine scheduler, `weight_decay=0.1`, `warmup_ratio=0.1`.",
+ "source": "## 5. GRPO Training with 4 Scaled Reward Functions\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern with **4 separate reward functions** and scaling to prevent R1 domination:\n1. `match_json_format_exactly` — strict JSON format validation (weight=0.3)\n2. `match_json_format_approximately` — partial format credit (weight=0.2)\n3. `check_action` — role-specific action correctness (weight=0.5)\n4. `check_env` — **environment-executing reward** (weight=1.0, full impact)\n\nUpdated hyperparameters: `max_prompt_length=768` (room for system prompt + observations), `num_generations=8` (stable advantage estimation), `adamw_8bit`, cosine scheduler, `weight_decay=0.1`, `warmup_ratio=0.1`.",
  "metadata": {
  "id": "save-header"
  }
@@ -109,7 +109,7 @@
  "id": "save"
  },
  "outputs": [],
- "source": "from trl import GRPOConfig, GRPOTrainer\nfrom train import make_reward_functions\n\n# 4 separate reward functions (reference pattern)\nreward_fns = make_reward_functions(TARGET_AGENT)\nprint(f\"Reward functions: {len(reward_fns)}\")\nfor i, fn in enumerate(reward_fns):\n print(f\" [{i}] {fn.__name__ if hasattr(fn, '__name__') else type(fn).__name__}\")\n\nmax_prompt_length = 512\ngrpo_config = GRPOConfig(\n output_dir=f\"./sentinelops-grpo-{TARGET_AGENT}\",\n max_steps=500, # Reference: 500\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=4, # Reference: 4\n max_prompt_length=max_prompt_length,\n max_completion_length=max_seq_length - max_prompt_length,\n learning_rate=5e-6, # Reference: 5e-6\n weight_decay=0.1, # Reference: 0.1\n warmup_ratio=0.1, # Reference: 0.1\n lr_scheduler_type=\"cosine\", # Reference: cosine\n optim=\"adamw_8bit\", # Reference: adamw_8bit\n max_grad_norm=1.0, # Reference: 1.0\n logging_steps=1,\n save_steps=250, # Reference: 250\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n tokenizer=tokenizer, # Reference uses tokenizer= not processing_class=\n reward_funcs=reward_fns, # 4 reward functions (reference pattern)\n args=grpo_config,\n train_dataset=train_dataset,\n)\n\nprint(f\"\\nStarting GRPO training for {TARGET_AGENT}...\")\nprint(f\" max_steps={grpo_config.max_steps}, lr={grpo_config.learning_rate}\")\nprint(f\" num_generations={grpo_config.num_generations}, optim={grpo_config.optim}\")\ntrainer.train()"
+ "source": "from trl import GRPOConfig, GRPOTrainer\nfrom train import make_reward_functions\n\n# 4 separate reward functions with scaling (reference pattern)\nreward_fns = make_reward_functions(TARGET_AGENT)\nprint(f\"Reward functions: {len(reward_fns)}\")\nfor i, fn in enumerate(reward_fns):\n print(f\" [{i}] {fn.__name__ if hasattr(fn, '__name__') else type(fn).__name__}\")\n\nmax_prompt_length = 768 # System prompt ~350 tokens + observation needs room\ngrpo_config = GRPOConfig(\n output_dir=f\"./sentinelops-grpo-{TARGET_AGENT}\",\n max_steps=500, # Reference: 500\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=8, # 8 generations for stable advantage estimation\n max_prompt_length=max_prompt_length,\n max_completion_length=max_seq_length - max_prompt_length, # 2048 - 768 = 1280\n learning_rate=5e-6, # Reference: 5e-6\n weight_decay=0.1, # Reference: 0.1\n warmup_ratio=0.1, # Reference: 0.1\n lr_scheduler_type=\"cosine\", # Reference: cosine\n optim=\"adamw_8bit\", # Reference: adamw_8bit\n max_grad_norm=1.0, # Reference: 1.0\n logging_steps=1,\n save_steps=250, # Reference: 250\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n tokenizer=tokenizer, # Reference uses tokenizer= not processing_class=\n reward_funcs=reward_fns, # 4 scaled reward functions (reference pattern)\n args=grpo_config,\n train_dataset=train_dataset,\n)\n\nprint(f\"\\nStarting GRPO training for {TARGET_AGENT}...\")\nprint(f\" max_steps={grpo_config.max_steps}, lr={grpo_config.learning_rate}\")\nprint(f\" num_generations={grpo_config.num_generations}, optim={grpo_config.optim}\")\nprint(f\" max_prompt_length={max_prompt_length}, max_completion_length={max_seq_length - max_prompt_length}\")\ntrainer.train()"
  },
  {
  "cell_type": "markdown",