helshahaby committed
Commit 8032ad5 · verified · 1 Parent(s): 7ec7ae6

Upload train_grpo_colab.ipynb

Files changed (1)
  1. training/train_grpo_colab.ipynb +497 -336
training/train_grpo_colab.ipynb CHANGED
@@ -1,339 +1,500 @@
  {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 🚗 Autonomous Driving Multi-Agent RL — GRPO Training\n",
- "**OpenEnv v0.2.1 · Unsloth · vLLM · W&B**\n",
- "\n",
- "Runtime: **H100 GPU (BF16)** | Reduce `max_steps` to `50` for a quick test run.\n",
- "\n",
- "Pipeline:\n",
- "```\n",
- "Environment State → LLM Reasoning → Action → Reward → GRPO Update\n",
- "```"
- ]
  },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 1. Install ────────────────────────────────────────────────────────────\n",
- "import sys\n",
- "!{sys.executable} -m pip install uv -q\n",
- "# vLLM + Unsloth (use venv if this errors — see hackathon notes)\n",
- "!uv pip install unsloth vllm --torch-backend=auto -q\n",
- "!uv pip install --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo -q\n",
- "!uv pip install openenv-core==0.2.1 trl>=0.15.0 wandb gymnasium numpy pydantic -q\n",
- "print('✅ Install complete')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 2. Clone project from HF Space ───────────────────────────────────────\n",
- "HF_SPACE = 'YOUR_HF_USERNAME/autonomous-driving-env' # <-- update\n",
- "HF_MODEL_REPO = 'YOUR_HF_USERNAME/autonomous-driving-grpo' # <-- update\n",
- "\n",
- "!git clone https://huggingface.co/spaces/{HF_SPACE} project\n",
- "import os; os.chdir('project')\n",
- "print('✅ Project cloned')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 3. W&B + Config ──────────────────────────────────────────────────────\n",
- "import wandb\n",
- "wandb.login()\n",
- "\n",
- "CONFIG = {\n",
- "    'model_name': 'unsloth/Qwen3-4B-unsloth-bnb-4bit',\n",
- "    # H100 BF16 (faster): 'unsloth/gpt-oss-20b-bf16'\n",
- "    'max_steps': 300, # reduce to 50 for quick test\n",
- "    'num_generations': 4,\n",
- "    'max_new_tokens': 128,\n",
- "    'learning_rate': 5e-6,\n",
- "    'batch_size': 2,\n",
- "    'grad_accum': 4,\n",
- "    'lora_r': 16,\n",
- "    'lora_alpha': 32,\n",
- "    'fast_inference': True, # uses vLLM\n",
- "    'games_per_step': 4,\n",
- "    'max_episode_steps': 30,\n",
- "}\n",
- "\n",
- "run = wandb.init(\n",
- "    project='openenv-autonomous-driving',\n",
- "    config=CONFIG,\n",
- "    tags=['grpo', 'openenv', 'autonomous-driving', 'negotiation', 'multi-agent']\n",
- ")\n",
- "print('W&B run:', run.url)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 4. Load Model (Unsloth) ──────────────────────────────────────────────\n",
- "from unsloth import FastLanguageModel\n",
- "import torch\n",
- "\n",
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
- "    model_name=CONFIG['model_name'],\n",
- "    max_seq_length=2048,\n",
- "    load_in_4bit=True,\n",
- "    fast_inference=CONFIG['fast_inference'],\n",
- "    gpu_memory_utilization=0.7,\n",
- ")\n",
- "\n",
- "model = FastLanguageModel.get_peft_model(\n",
- "    model,\n",
- "    r=CONFIG['lora_r'],\n",
- "    lora_alpha=CONFIG['lora_alpha'],\n",
- "    target_modules=['q_proj','k_proj','v_proj','o_proj',\n",
- "                    'gate_proj','up_proj','down_proj'],\n",
- "    lora_dropout=0,\n",
- "    bias='none',\n",
- "    use_gradient_checkpointing='unsloth',\n",
- "    random_state=42,\n",
- ")\n",
- "print(f'✅ Model loaded: {CONFIG[\"model_name\"]}')\n",
- "print(f'Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 5. Load Environment + Agent ──────────────────────────────────────────\n",
- "import sys; sys.path.insert(0, '.')\n",
- "from env.negotiation_env import NegotiationDrivingEnv\n",
- "from agents.negotiation_agent import NegotiationAgent\n",
- "\n",
- "# Sanity check\n",
- "env = NegotiationDrivingEnv()\n",
- "obs, _ = env.reset()\n",
- "print('Obs:', obs)\n",
- "print('Lidar:', env.lidar_scan())\n",
- "print('Collision prediction:', env.predict_collision())\n",
- "print()\n",
- "print(env.render())\n",
- "\n",
- "# Test negotiation\n",
- "resp = env.negotiate('blocker', 'Please yield, I need to pass safely')\n",
- "print('Blocker responds:', resp)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 6. Reward Functions for GRPO ─────────────────────────────────────────\n",
- "import json\n",
- "\n",
- "def format_reward(completions, **kwargs):\n",
- "    rewards = []\n",
- "    for c in completions:\n",
- "        try:\n",
- "            data = json.loads(c.strip())\n",
- "            score = 0.0\n",
- "            if len(data.get('thinking', '')) > 15: score += 0.2\n",
- "            if isinstance(data.get('action'), int): score += 0.2\n",
- "            if 'negotiate' in data: score += 0.1\n",
- "            rewards.append(score)\n",
- "        except Exception:\n",
- "            rewards.append(-0.05)\n",
- "    return rewards\n",
- "\n",
- "def action_validity_reward(completions, **kwargs):\n",
- "    rewards = []\n",
- "    for c in completions:\n",
- "        try:\n",
- "            action = json.loads(c.strip()).get('action', -1)\n",
- "            rewards.append(0.1 if action in [0, 1, 2, 3] else -0.1)\n",
- "        except Exception:\n",
- "            rewards.append(-0.1)\n",
- "    return rewards\n",
- "\n",
- "print('✅ Reward functions ready')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 7. Rollout Collection ────────────────────────────────────────────────\n",
- "from datasets import Dataset\n",
- "from agents.negotiation_agent import SYSTEM_PROMPT\n",
- "\n",
- "def collect_episodes(model, tokenizer, n_games=4):\n",
- "    agent = NegotiationAgent(model=model, tokenizer=tokenizer)\n",
- "    experiences = []\n",
- "    wins = 0\n",
- "\n",
- "    for game_i in range(n_games):\n",
- "        env = NegotiationDrivingEnv()\n",
- "        obs, _ = env.reset()\n",
- "        agent.reset_episode()\n",
- "        done = False\n",
- "        episode_start = len(experiences)\n",
- "\n",
- "        for step_i in range(CONFIG['max_episode_steps']):\n",
- "            if done:\n",
- "                break\n",
- "            action, response, fmt_reward = agent.act(env)\n",
- "            obs, env_reward, done, _, info = env.step(action)\n",
- "            total_reward = env_reward + fmt_reward\n",
- "            agent.record(obs.tolist(), action, total_reward)\n",
- "\n",
- "            prompt_str = tokenizer.apply_chat_template(\n",
- "                [{'role': 'system', 'content': SYSTEM_PROMPT},\n",
- "                 {'role': 'user', 'content': agent.build_prompt(env)}],\n",
- "                tokenize=False\n",
- "            )\n",
- "            experiences.append({\n",
- "                'prompt': prompt_str,\n",
- "                'response': response,\n",
- "                'reward': total_reward,\n",
- "            })\n",
- "\n",
- "        outcome = info.get('outcome', '')\n",
- "        bonus = 1.0 if outcome == 'goal_reached' else (-1.0 if outcome == 'collision' else 0.0)\n",
- "        for exp in experiences[episode_start:]:\n",
- "            exp['reward'] += bonus\n",
- "        if outcome == 'goal_reached':\n",
- "            wins += 1\n",
- "        print(f' Game {game_i+1}/{n_games} | {outcome} | steps={step_i}')\n",
- "\n",
- "    return experiences, wins / n_games\n",
- "\n",
- "print('Collecting initial rollouts...')\n",
- "exps, win_rate = collect_episodes(model, tokenizer, CONFIG['games_per_step'])\n",
- "print(f'Win rate: {win_rate:.1%} | Samples: {len(exps)}')\n",
- "wandb.log({'initial_win_rate': win_rate})"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 8. GRPO Training ─────────────────────────────────────────────────────\n",
- "from trl import GRPOTrainer, GRPOConfig\n",
- "\n",
- "dataset = Dataset.from_list([\n",
- "    {'prompt': e['prompt'], 'reward': e['reward']} for e in exps\n",
- "])\n",
- "\n",
- "grpo_config = GRPOConfig(\n",
- "    output_dir='./autodrive-grpo-checkpoints',\n",
- "    max_steps=CONFIG['max_steps'],\n",
- "    per_device_train_batch_size=CONFIG['batch_size'],\n",
- "    gradient_accumulation_steps=CONFIG['grad_accum'],\n",
- "    learning_rate=CONFIG['learning_rate'],\n",
- "    num_generations=CONFIG['num_generations'],\n",
- "    max_new_tokens=CONFIG['max_new_tokens'],\n",
- "    max_prompt_length=1024,\n",
- "    bf16=True,\n",
- "    logging_steps=10,\n",
- "    save_steps=100,\n",
- "    report_to='wandb',\n",
- "    use_vllm=CONFIG['fast_inference'],\n",
- "    vllm_gpu_memory_utilization=0.3,\n",
- "    temperature=0.7,\n",
- "    kl_coef=0.01,\n",
- ")\n",
- "\n",
- "trainer = GRPOTrainer(\n",
- "    model=model,\n",
- "    processing_class=tokenizer,\n",
- "    reward_funcs=[format_reward, action_validity_reward],\n",
- "    args=grpo_config,\n",
- "    train_dataset=dataset,\n",
- ")\n",
- "\n",
- "print('🚀 Starting GRPO training...')\n",
- "trainer.train()\n",
- "print('✅ Training complete!')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 9. Online RL — Closed-Loop Self-Improvement ──────────────────────────\n",
- "ONLINE_ITERS = 5\n",
- "win_rates = [win_rate]\n",
- "\n",
- "for i in range(ONLINE_ITERS):\n",
- "    print(f'=== Online RL Iteration {i+1}/{ONLINE_ITERS} ===')\n",
- "    fresh_exps, wr = collect_episodes(model, tokenizer, CONFIG['games_per_step'])\n",
- "    win_rates.append(wr)\n",
- "    wandb.log({'win_rate': wr, 'iteration': i + 1})\n",
- "    print(f'Win rate: {wr:.1%}')\n",
- "\n",
- "    trainer.train_dataset = Dataset.from_list(\n",
- "        [{'prompt': e['prompt'], 'reward': e['reward']} for e in fresh_exps]\n",
- "    )\n",
- "    trainer.args.max_steps = 50\n",
- "    trainer.train()\n",
- "\n",
- "print(f'Win rate progression: {[f\"{r:.1%}\" for r in win_rates]}')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# ── 10. Save & Push to HF Hub ────────────────────────────────────────────\n",
- "model.save_pretrained('autodrive-adapter')\n",
- "tokenizer.save_pretrained('autodrive-adapter')\n",
- "model.push_to_hub(HF_MODEL_REPO)\n",
- "tokenizer.push_to_hub(HF_MODEL_REPO)\n",
- "print(f'✅ Model: https://huggingface.co/{HF_MODEL_REPO}')\n",
- "wandb.finish()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.11.0"
- },
- "accelerator": "GPU",
- "colab": {
- "gpuType": "H100",
- "provenance": []
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
  }

  {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f0rjS2t06uOu"
+ },
+ "source": [
+ "# 🚗 Autonomous Driving Multi-Agent RL — GRPO Training\n",
+ "**OpenEnv v0.2.1 · Unsloth · W&B**\n",
+ "\n",
+ "Runtime: **Colab T4 GPU (FP16, vLLM disabled)** | Reduce `max_steps` to `50` for a quick test run.\n",
+ "\n",
+ "Pipeline:\n",
+ "```\n",
+ "Environment State → LLM Reasoning → Action → Reward → GRPO Update\n",
+ "```"
+ ]
+ },
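
The "GRPO Update" stage scores each prompt's group of sampled completions together: a completion's advantage is its reward measured against the group mean, so no learned value baseline is needed. A minimal standalone sketch of that computation (illustrative only; TRL's `GRPOTrainer` computes this internally):

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-8):
    """GRPO-style advantages: normalize each reward within its group.

    rewards: shape (num_prompts, num_generations); one row per prompt,
    one column per sampled completion of that prompt.
    """
    r = np.asarray(rewards, dtype=np.float64)
    mean = r.mean(axis=1, keepdims=True)  # per-group baseline
    std = r.std(axis=1, keepdims=True)    # per-group spread
    return (r - mean) / (std + eps)       # positive => better than the group average

# Four completions of one prompt: the 0.9 sample is reinforced and the
# 0.1 sample penalized, both relative to their shared mean of 0.45.
print(group_relative_advantages([[0.9, 0.4, 0.4, 0.1]]))
```

Because the baseline is the group mean rather than a critic's estimate, this is GRPO's main simplification over PPO.
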
+ {
+ "cell_type": "code",
+ "source": [
+ "!unzip final_project.zip\n",
+ "%cd final_project\n",
+ "!pip install -r requirements.txt"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "v1hP_bjO7Qz_",
+ "outputId": "b8abfec3-bd18-40a9-957e-f1b66933e97c"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "unzip: cannot find or open final_project.zip, final_project.zip.zip or final_project.zip.ZIP.\n",
+ "[Errno 2] No such file or directory: 'final_project'\n",
+ "/content\n",
+ "\u001b[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'\u001b[0m\u001b[31m\n",
+ "\u001b[0m"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "hr_s_1mZ6uOw",
+ "outputId": "6f7b3119-ddd6-4737-ad2b-18ce8cb368e0"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.4/23.4 MB\u001b[0m \u001b[31m47.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h✅ Install complete\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ── 1. Install ────────────────────────────────────────────────────────────\n",
+ "import sys\n",
+ "!{sys.executable} -m pip install uv -q\n",
+ "# vLLM + Unsloth (use venv if this errors — see hackathon notes)\n",
+ "!uv pip install unsloth vllm --torch-backend=auto -q\n",
+ "!uv pip install --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo -q\n",
+ "# quote 'trl>=0.15.0' so the shell doesn't treat '>' as a redirect\n",
+ "!uv pip install openenv-core==0.2.1 'trl>=0.15.0' wandb gymnasium numpy pydantic -q\n",
+ "print('✅ Install complete')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2i4rnzlG6uOx",
+ "outputId": "2272397d-8b25-4a16-b692-e3b732e1a459"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'AD'...\n",
+ "remote: Enumerating objects: 587, done.\u001b[K\n",
+ "remote: Total 587 (delta 0), reused 0 (delta 0), pack-reused 587 (from 1)\u001b[K\n",
+ "Receiving objects: 100% (587/587), 183.43 KiB | 2.78 MiB/s, done.\n",
+ "Resolving deltas: 100% (368/368), done.\n",
+ "✅ Project cloned\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ── 2. Clone project from HF Space ───────────────────────────────────────\n",
+ "HF_SPACE = 'helshahaby/AD'\n",
+ "HF_MODEL_REPO = 'helshahaby/AD-grpo'\n",
+ "\n",
+ "!git clone https://huggingface.co/spaces/{HF_SPACE}\n",
+ "import os; os.chdir('AD')  # enter the cloned repo so env/ and agents/ import below\n",
+ "print('✅ Project cloned')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dOUoZU9k6uOy",
+ "outputId": "88149668-c236-4195-92b3-8aa7340bc067"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.12/dist-packages/notebook/notebookapp.py:191: SyntaxWarning: invalid escape sequence '\\/'\n",
+ " | |_| | '_ \\/ _` / _` | _/ -_)\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: (1) Create a W&B account\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: (2) Use an existing W&B account\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: (3) Don't visualize my results\n",
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Enter your choice:"
+ ]
+ }
+ ],
+ "source": [
+ "# ── 3. W&B + Config ──────────────────────────────────────────────────────\n",
+ "import wandb\n",
+ "wandb.login()\n",
+ "\n",
+ "CONFIG = {\n",
+ "    'model_name': 'unsloth/Qwen3-4B-unsloth-bnb-4bit',\n",
+ "    # H100 BF16 (faster): 'unsloth/gpt-oss-20b-bf16'\n",
+ "    'max_steps': 300, # reduce to 50 for quick test\n",
+ "    'num_generations': 4,\n",
+ "    'max_new_tokens': 128,\n",
+ "    'learning_rate': 5e-6,\n",
+ "    'batch_size': 2,\n",
+ "    'grad_accum': 4,\n",
+ "    'lora_r': 16,\n",
+ "    'lora_alpha': 32,\n",
+ "    'fast_inference': False, # vLLM disabled (T4 lacks sm_80)\n",
+ "    'games_per_step': 4,\n",
+ "    'max_episode_steps': 30,\n",
+ "}\n",
+ "\n",
+ "run = wandb.init(\n",
+ "    project='openenv-autonomous-driving',\n",
+ "    config=CONFIG,\n",
+ "    tags=['grpo', 'openenv', 'autonomous-driving', 'negotiation', 'multi-agent']\n",
+ ")\n",
+ "print('W&B run:', run.url)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "aBRub7Tw6uOy"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 4. Load Model (Unsloth) ──────────────────────────────────────────────\n",
+ "from unsloth import FastLanguageModel\n",
+ "import torch\n",
+ "\n",
+ "# T4 GPU detected (sm_75) — must disable vLLM (requires sm_80+)\n",
+ "# vLLM needs Ampere or newer (e.g. A100, L4, H100); use fast_inference=False on T4.\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ "    model_name=\"unsloth/Qwen3-4B-unsloth-bnb-4bit\",\n",
+ "    max_seq_length=2048,\n",
+ "    load_in_4bit=True,\n",
+ "    fast_inference=False, # ← CHANGED: vLLM broken on T4 sm_75\n",
+ ")\n",
+ "\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ "    model,\n",
+ "    r=16,\n",
+ "    lora_alpha=32,\n",
+ "    target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\n",
+ "                    \"gate_proj\",\"up_proj\",\"down_proj\"],\n",
+ "    lora_dropout=0,\n",
+ "    bias=\"none\",\n",
+ "    use_gradient_checkpointing=\"unsloth\",\n",
+ "    random_state=42,\n",
+ ")\n",
+ "print(\"✅ Model loaded on T4 (no vLLM)\")"
+ ]
+ },
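
The previous revision printed the trainable-parameter count after `get_peft_model`; that check is still useful to confirm only the LoRA matrices are trainable. A quick sketch, assuming `model` from the cell above is in scope:

```python
# Count trainable (LoRA) vs. total parameters on the PEFT-wrapped model.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```
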
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "pYeHREDK6uOy"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 5. Load Environment + Agent ──────────────────────────────────────────\n",
+ "import sys; sys.path.insert(0, '.')\n",
+ "from env.negotiation_env import NegotiationDrivingEnv\n",
+ "from agents.negotiation_agent import NegotiationAgent\n",
+ "\n",
+ "# Sanity check\n",
+ "env = NegotiationDrivingEnv()\n",
+ "obs, _ = env.reset()\n",
+ "print('Obs:', obs)\n",
+ "print('Lidar:', env.lidar_scan())\n",
+ "print('Collision prediction:', env.predict_collision())\n",
+ "print()\n",
+ "print(env.render())\n",
+ "\n",
+ "# Test negotiation\n",
+ "resp = env.negotiate('blocker', 'Please yield, I need to pass safely')\n",
+ "print('Blocker responds:', resp)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FSsnaoxZ6uOz"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 6. Reward Functions for GRPO ─────────────────────────────────────────\n",
+ "import json\n",
+ "\n",
+ "def format_reward(completions, **kwargs):\n",
+ "    rewards = []\n",
+ "    for c in completions:\n",
+ "        try:\n",
+ "            data = json.loads(c.strip())\n",
+ "            score = 0.0\n",
+ "            if len(data.get('thinking', '')) > 15: score += 0.2\n",
+ "            if isinstance(data.get('action'), int): score += 0.2\n",
+ "            if 'negotiate' in data: score += 0.1\n",
+ "            rewards.append(score)\n",
+ "        except Exception:\n",
+ "            rewards.append(-0.05)\n",
+ "    return rewards\n",
+ "\n",
+ "def action_validity_reward(completions, **kwargs):\n",
+ "    rewards = []\n",
+ "    for c in completions:\n",
+ "        try:\n",
+ "            action = json.loads(c.strip()).get('action', -1)\n",
+ "            rewards.append(0.1 if action in [0, 1, 2, 3] else -0.1)\n",
+ "        except Exception:\n",
+ "            rewards.append(-0.1)\n",
+ "    return rewards\n",
+ "\n",
+ "print('✅ Reward functions ready')"
+ ]
+ },
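
GRPO only sees these functions through the rewards they return, so it is worth spot-checking them on hand-written completions before training. A minimal sketch, assuming `format_reward` and `action_validity_reward` from the cell above:

```python
samples = [
    '{"thinking": "Blocker is close; slowing avoids a collision.", "action": 1, "negotiate": "Please yield"}',
    '{"action": 7}',    # valid JSON, but an out-of-range action id
    'not json at all',  # parse failure
]
print('format:  ', format_reward(samples))           # expect [0.5, 0.2, -0.05]
print('validity:', action_validity_reward(samples))  # expect [0.1, -0.1, -0.1]
```
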
+ {
+ "cell_type": "code",
+ "source": [
+ "# ── Paste this cell BEFORE the rollout collection cell ──────────────────\n",
+ "import json, numpy as np\n",
+ "from json import JSONEncoder\n",
+ "\n",
+ "_orig = JSONEncoder.default\n",
+ "def _patched(self, o):\n",
+ "    if isinstance(o, (np.integer,)): return int(o)\n",
+ "    if isinstance(o, (np.floating,)): return float(o)\n",
+ "    if isinstance(o, (np.bool_,)): return bool(o)\n",
+ "    if isinstance(o, np.ndarray): return o.tolist()\n",
+ "    return _orig(self, o)\n",
+ "JSONEncoder.default = _patched\n",
+ "print(\"✅ numpy JSON serialization patched globally\")"
+ ],
+ "metadata": {
+ "id": "jtt3fBe2EQ5_"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
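
Without a patch like this, `json.dumps` raises `TypeError: Object of type int64 is not JSON serializable` the first time a numpy observation is serialized. A quick check that the global patch behaves as intended (assumes the cell above has run):

```python
import json
import numpy as np

payload = {'action': np.int64(2), 'speed': np.float32(0.5),
           'safe': np.bool_(True), 'obs': np.arange(3)}
# With the patched JSONEncoder.default, every numpy type converts cleanly:
print(json.dumps(payload))  # {"action": 2, "speed": 0.5, "safe": true, "obs": [0, 1, 2]}
```
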
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lHGC94fN6uOz"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 7. Rollout Collection ────────────────────────────────────────────────\n",
+ "from datasets import Dataset\n",
+ "from agents.negotiation_agent import SYSTEM_PROMPT\n",
+ "\n",
+ "# ── FIX: JSON serializer that handles numpy/bool types ──────────────────\n",
+ "import json\n",
+ "import numpy as np\n",
+ "\n",
+ "class SafeEncoder(json.JSONEncoder):\n",
+ "    def default(self, obj):\n",
+ "        if isinstance(obj, np.integer): return int(obj)\n",
+ "        if isinstance(obj, np.floating): return float(obj)\n",
+ "        if isinstance(obj, np.bool_): return bool(obj)\n",
+ "        if isinstance(obj, np.ndarray): return obj.tolist()\n",
+ "        if isinstance(obj, bool): return bool(obj)\n",
+ "        return super().default(obj)\n",
+ "\n",
+ "# Monkey-patch json.dumps to always use SafeEncoder\n",
+ "_original_dumps = json.dumps\n",
+ "def safe_dumps(obj, **kwargs):\n",
+ "    kwargs.setdefault('cls', SafeEncoder)\n",
+ "    return _original_dumps(obj, **kwargs)\n",
+ "json.dumps = safe_dumps\n",
+ "\n",
+ "print(\"✅ JSON serializer patched for numpy/bool types\")\n",
+ "\n",
+ "def collect_episodes(model, tokenizer, n_games=4):\n",
+ "    agent = NegotiationAgent(model=model, tokenizer=tokenizer)\n",
+ "    experiences = []\n",
+ "    wins = 0\n",
+ "\n",
+ "    for game_i in range(n_games):\n",
+ "        env = NegotiationDrivingEnv()\n",
+ "        obs, _ = env.reset()\n",
+ "        agent.reset_episode()\n",
+ "        done = False\n",
+ "        episode_start = len(experiences)\n",
+ "\n",
+ "        for step_i in range(CONFIG['max_episode_steps']):\n",
+ "            if done:\n",
+ "                break\n",
+ "            action, response, fmt_reward = agent.act(env)\n",
+ "            obs, env_reward, done, _, info = env.step(action)\n",
+ "            total_reward = env_reward + fmt_reward\n",
+ "            # convert numpy scalars explicitly:\n",
+ "            agent.record(\n",
+ "                [int(x) for x in obs],  # numpy int32 → plain int\n",
+ "                int(action),\n",
+ "                float(total_reward)\n",
+ "            )\n",
+ "\n",
+ "            prompt_str = tokenizer.apply_chat_template(\n",
+ "                [{'role': 'system', 'content': SYSTEM_PROMPT},\n",
+ "                 {'role': 'user', 'content': agent.build_prompt(env)}],\n",
+ "                tokenize=False\n",
+ "            )\n",
+ "            experiences.append({\n",
+ "                'prompt': prompt_str,\n",
+ "                'response': response,\n",
+ "                'reward': total_reward,\n",
+ "            })\n",
+ "\n",
+ "        outcome = info.get('outcome', '')\n",
+ "        bonus = 1.0 if outcome == 'goal_reached' else (-1.0 if outcome == 'collision' else 0.0)\n",
+ "        for exp in experiences[episode_start:]:\n",
+ "            exp['reward'] += bonus\n",
+ "        if outcome == 'goal_reached':\n",
+ "            wins += 1\n",
+ "        print(f' Game {game_i+1}/{n_games} | {outcome} | steps={step_i}')\n",
+ "\n",
+ "    return experiences, wins / n_games\n",
+ "\n",
+ "print('Collecting initial rollouts...')\n",
+ "exps, win_rate = collect_episodes(model, tokenizer, CONFIG['games_per_step'])\n",
+ "print(f'Win rate: {win_rate:.1%} | Samples: {len(exps)}')\n",
+ "wandb.log({'initial_win_rate': win_rate})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "CONFIG['fast_inference'] = False\n",
+ "print(\"fast_inference:\", CONFIG['fast_inference']) # must print False"
+ ],
+ "metadata": {
+ "id": "PXTY00MTK9XW"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "H5YbGlpM6uOz"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 8. GRPO Training ─────────────────────────────────────────────────────\n",
+ "from trl import GRPOTrainer, GRPOConfig\n",
+ "from datasets import Dataset\n",
+ "\n",
+ "dataset = Dataset.from_list([\n",
+ "    {'prompt': e['prompt'], 'reward': e['reward']} for e in exps\n",
+ "])\n",
+ "\n",
+ "grpo_config = GRPOConfig(\n",
+ "    output_dir='./autodrive-grpo-checkpoints',\n",
+ "    max_steps=50,\n",
+ "    per_device_train_batch_size=1,\n",
+ "    gradient_accumulation_steps=8,\n",
+ "    learning_rate=5e-6,\n",
+ "    num_generations=2,\n",
+ "    max_completion_length=64,\n",
+ "    max_prompt_length=512,\n",
+ "    bf16=False,\n",
+ "    fp16=True,\n",
+ "    logging_steps=5,\n",
+ "    save_steps=50,\n",
+ "    report_to='wandb',\n",
+ "    use_vllm=False, # ← hardcoded False, no CONFIG reference\n",
+ "    temperature=0.7,\n",
+ "    beta=0.01,\n",
+ ")\n",
+ "\n",
+ "trainer = GRPOTrainer(\n",
+ "    model=model,\n",
+ "    processing_class=tokenizer,\n",
+ "    reward_funcs=[format_reward, action_validity_reward],\n",
+ "    args=grpo_config,\n",
+ "    train_dataset=dataset,\n",
+ ")\n",
+ "\n",
+ "print('🚀 Starting GRPO training...')\n",
+ "trainer.train()\n",
+ "print('✅ Training complete!')"
+ ]
+ },
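
Before trusting the aggregate numbers, it helps to eyeball one raw completion from the tuned model. A minimal sketch, assuming `model`, `tokenizer`, and `SYSTEM_PROMPT` from earlier cells; the observation string here is made up for illustration:

```python
from unsloth import FastLanguageModel

FastLanguageModel.for_inference(model)  # switch Unsloth to inference mode

messages = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': 'Obs: [3, 1, 0, 2]. Blocker ahead. Choose an action.'},
]
inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors='pt'
).to(model.device)
out = model.generate(inputs, max_new_tokens=128, temperature=0.7, do_sample=True)
# Decode only the newly generated tokens, not the echoed prompt.
print(tokenizer.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))
```
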
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FImz1duK6uOz"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 9. Online RL — Closed-Loop Self-Improvement ──────────────────────────\n",
+ "ONLINE_ITERS = 5\n",
+ "win_rates = [win_rate]\n",
+ "\n",
+ "for i in range(ONLINE_ITERS):\n",
+ "    print(f'=== Online RL Iteration {i+1}/{ONLINE_ITERS} ===')\n",
+ "    fresh_exps, wr = collect_episodes(model, tokenizer, CONFIG['games_per_step'])\n",
+ "    win_rates.append(wr)\n",
+ "    wandb.log({'win_rate': wr, 'iteration': i + 1})\n",
+ "    print(f'Win rate: {wr:.1%}')\n",
+ "\n",
+ "    trainer.train_dataset = Dataset.from_list(\n",
+ "        [{'prompt': e['prompt'], 'reward': e['reward']} for e in fresh_exps]\n",
+ "    )\n",
+ "    trainer.args.max_steps = 50\n",
+ "    trainer.train()\n",
+ "\n",
+ "print(f'Win rate progression: {[f\"{r:.1%}\" for r in win_rates]}')"
+ ]
+ },
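
The per-iteration win rates already stream to W&B; a local plot of the same list is a cheap offline sanity check. A sketch, assuming `win_rates` from the loop above:

```python
import matplotlib.pyplot as plt

plt.plot(range(len(win_rates)), win_rates, marker='o')
plt.xlabel('online RL iteration (0 = initial rollouts)')
plt.ylabel('win rate')
plt.title('Negotiation win rate across online GRPO iterations')
plt.ylim(0, 1)
plt.show()
```
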
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "vbyEXn_p6uOz"
+ },
+ "outputs": [],
+ "source": [
+ "# ── 10. Save & Push to HF Hub ────────────────────────────────────────────\n",
+ "model.save_pretrained('autodrive-adapter')\n",
+ "tokenizer.save_pretrained('autodrive-adapter')\n",
+ "model.push_to_hub(HF_MODEL_REPO)\n",
+ "tokenizer.push_to_hub(HF_MODEL_REPO)\n",
+ "print(f'✅ Model: https://huggingface.co/{HF_MODEL_REPO}')\n",
+ "wandb.finish()"
+ ]
+ }
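
To reuse the pushed adapter later without rerunning this notebook, it can be attached to the same 4-bit base model. A minimal sketch using PEFT, assuming the repo pushed above contains a standard LoRA adapter:

```python
from unsloth import FastLanguageModel
from peft import PeftModel

# Reload the 4-bit base, then attach the trained LoRA adapter from the Hub.
base, tok = FastLanguageModel.from_pretrained(
    model_name='unsloth/Qwen3-4B-unsloth-bnb-4bit',
    max_seq_length=2048,
    load_in_4bit=True,
)
model = PeftModel.from_pretrained(base, 'helshahaby/AD-grpo')
FastLanguageModel.for_inference(model)
```
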
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ },
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ }
  },
+ "nbformat": 4,
+ "nbformat_minor": 0
  }