helshahaby committed
Commit 185e2d2 · verified · 1 Parent(s): b21d658

Upload 6 files

Files changed (6)
  1. Dockerfile +24 -0
  2. HACKATHON_GUIDE.md +215 -0
  3. client.py +25 -0
  4. connect4_environment.py +225 -0
  5. connect4_grpo_training.ipynb +654 -0
  6. models.py +45 -0
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Python dependencies
+ COPY server/requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy environment source
+ COPY . .
+ RUN pip install -e . --no-cache-dir
+
+ # HF Spaces runs on port 7860
+ EXPOSE 7860
+
+ # Enable web interface for HF Spaces demo
+ ENV ENABLE_WEB_INTERFACE=true
+
+ CMD ["python", "-m", "uvicorn", "connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "7860"]
HACKATHON_GUIDE.md ADDED
@@ -0,0 +1,215 @@
+ # 🚗 Meta OpenEnv Hackathon — Connect4 Multi-Agent Autonomous Driving
+
+ ## Complete Delivery Guide
+
+ ---
+
+ ## 🏗️ Architecture Overview
+
+ ```
+ ┌─────────────────────────────────────────────────────────────────┐
+ │ TRAINING LOOP (Colab H100)                                      │
+ │                                                                 │
+ │ ┌──────────────┐   prompts    ┌─────────────────────────────┐   │
+ │ │ Unsloth      │◄────────────►│ LLM (Qwen3-4B / gpt-oss)    │   │
+ │ │ GRPO/TRL     │ completions  │ + LoRA Adapter              │   │
+ │ └──────┬───────┘              └─────────────────────────────┘   │
+ │        │ rewards                                                │
+ │ ┌──────▼───────┐     W&B                                        │
+ │ │ Reward Fns   │───────────► Experiment Tracking                │
+ │ └──────┬───────┘                                                │
+ └─────────┼───────────────────────────────────────────────────────┘
+           │ step() / reset()
+           │ WebSocket
+ ┌─────────▼───────────────────────────────────────────────────────┐
+ │ HF SPACES (OpenEnv Environment Server)                          │
+ │                                                                 │
+ │ ┌─────────────────────────────────────────────────────────┐     │
+ │ │ Connect4Environment (FastAPI + OpenEnv v0.2.1)          │     │
+ │ │ • 6×7 board = intersection grid                         │     │
+ │ │ • Player 1 (X) = Ego Vehicle (LLM)                      │     │
+ │ │ • Player 2 (O) = Rule-based opponent                    │     │
+ │ │ • Shaped rewards: win/loss/block/3-in-row/format        │     │
+ │ └─────────────────────────────────────────────────────────┘     │
+ └─────────────────────────────────────────────────────────────────┘
+ ```
+
+ ---
+
+ ## 📁 File Structure
+
+ ```
+ connect4_env/                     ← HF Spaces repo (deploy this)
+ ├── __init__.py
+ ├── models.py                     ← Pydantic Action/Observation/State
+ ├── client.py                     ← Connect4Env(EnvClient)
+ ├── openenv.yaml                  ← Manifest
+ ├── pyproject.toml
+ ├── Dockerfile                    ← HF Spaces Docker SDK
+ ├── README.md                     ← HF Space card
+ └── server/
+     ├── app.py                    ← FastAPI entry point
+     ├── connect4_environment.py   ← Game logic + reward shaping
+     └── requirements.txt
+
+ connect4_grpo_training.ipynb      ← Colab training notebook (H100)
+ ```
+
+ ---
+
+ ## 🚀 Step-by-Step Deployment
+
+ ### Step 1 — Deploy Environment to HF Spaces
+
+ ```bash
+ # Install OpenEnv CLI
+ pip install openenv-core==0.2.1
+
+ # Login to HF
+ huggingface-cli login
+
+ # From inside connect4_env/ directory:
+ cd connect4_env
+ openenv push --repo-id YOUR_HF_USERNAME/connect4-env
+
+ # OR manually:
+ # 1. Create new Space at https://huggingface.co/new-space
+ # 2. Set SDK = Docker, hardware = CPU Basic
+ # 3. Push this folder as the repo
+ ```
+
+ After deployment, your env is live at:
+ `https://YOUR_HF_USERNAME-connect4-env.hf.space`
+
+ Test it:
+ ```python
+ # First: pip install openenv-core==0.2.1 (or pip install from your HF Space)
+ from connect4_env import Connect4Env, Connect4Action
+
+ with Connect4Env(base_url="https://YOUR_HF_USERNAME-connect4-env.hf.space").sync() as env:
+     obs = env.reset()
+     result = env.step(Connect4Action(column=3))
+ ```
+
+ ---
+
+ ### Step 2 — Run Training on Northflank / Colab
+
+ **Option A: Google Colab (recommended for hackathon)**
+ 1. Open `connect4_grpo_training.ipynb` in Colab
+ 2. Set Runtime → H100 GPU
+ 3. Update the `HF_SPACE_URL` and `HF_MODEL_REPO` variables
+ 4. Run all cells
+
+ **Option B: Northflank Jupyter PyTorch**
+ 1. Go to https://app.northflank.com/t/openenv-hack-112/project/hackathon/services/jupyter-pytorch
+ 2. Upload the notebook
+ 3. The environment has PyTorch + CUDA pre-installed
+ 4. Install Unsloth: `uv pip install unsloth vllm --torch-backend=auto`
+
+ ---
+
+ ### Step 3 — vLLM GRPO Fix (if issues)
+
+ Per the hackathon notes, if GRPO runs fail under vLLM, rebuild in a clean virtualenv:
+ ```bash
+ python -m venv unsloth_env
+ source unsloth_env/bin/activate
+ pip install --upgrade pip && pip install uv
+ uv pip install unsloth vllm --torch-backend=auto
+ # Always update Unsloth:
+ pip install --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo
+ ```
+
+ ---
+
+ ## 🔬 Training Pipeline Detail
+
+ ### Pre-training → SFT → RLHF → RL+Envs
+
+ ```
+ 1. BASE MODEL (Qwen3-4B or gpt-oss-20B)
+    Pre-trained on large text corpus
+
+ 2. SFT (IMPLICIT)
+    Prompt engineering guides format:
+    {"thinking": "...", "column": N}
+
+ 3. GRPO (RL without explicit reward model)
+    - num_generations=4 rollouts per prompt
+    - KL divergence penalty vs reference policy
+    - Format reward (JSON structure)
+    - Environment reward (win/loss/block)
+
+ 4. CLOSED-LOOP ONLINE RL
+    - Play N games with current policy
+    - Collect (prompt, response, reward) tuples
+    - Update policy with GRPO
+    - Repeat → self-improvement
+ ```
+ ```
147
+
148
+ ### Reward Design
149
+
150
+ The reward function has 3 components:
151
+
152
+ | Component | Source | Value |
153
+ |-----------|--------|-------|
154
+ | **Outcome** | Environment (terminal) | ±10.0 |
155
+ | **Shaping** | Environment (per-step) | ±0.5, +0.2, -0.1 |
156
+ | **Format** | Local function | +0.3 |
157
+
158
+ Outcome is propagated back to all moves of a game (+1.0 win, -1.0 loss, +0.1 draw).
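+ Per move, the components simply sum. A toy sketch of the arithmetic (`total_move_reward` is a hypothetical helper mirroring the notebook's accumulation, not code from the repo):
+
+ ```python
+ def total_move_reward(env_reward: float, response_ok: bool,
+                       outcome_bonus: float) -> float:
+     """Sum the shaped env reward, the local format reward, and the
+     propagated game-outcome bonus for a single ego move."""
+     format_part = 0.3 if response_ok else -0.05   # format bonus / penalty
+     return env_reward + format_part + outcome_bonus
+
+ # A blocking move (+0.5) with valid JSON, in a game the agent later wins (+1.0):
+ r = total_move_reward(0.5, True, 1.0)  # 1.8
+ ```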
+
+ ---
+
+ ## 📊 W&B Metrics to Track
+
+ | Metric | What it shows |
+ |--------|---------------|
+ | `win_rate` | % games LLM wins vs rule-based |
+ | `reward/mean` | Average per-step reward |
+ | `kl_divergence` | Policy drift from base model |
+ | `format_reward` | % responses with valid JSON |
+ | `policy/entropy` | Exploration vs exploitation |
+
172
+ ---
173
+
174
+ ## 🔧 Environment Customization
175
+
176
+ The Connect4 environment can be extended for more realistic autonomous driving:
177
+
178
+ ```python
179
+ # Add to Connect4Action:
180
+ speed: float = Field(1.0, ge=0.0, le=3.0) # vehicle speed
181
+ lane_change: int = Field(0, ge=-1, le=1) # lane change direction
182
+
183
+ # Add to reward shaping:
184
+ def _safety_reward(self) -> float:
185
+ # Penalize high-speed moves near opponent
186
+ ...
187
+
188
+ # Add multi-agent (>2 vehicles):
189
+ AGENT3 = 3 # second LLM agent
190
+ ```
191
+
192
+ ---
193
+
194
+ ## 📎 Key Links
195
+
196
+ - **OpenEnv repo**: https://github.com/meta-pytorch/OpenEnv
197
+ - **Unsloth GRPO notebook**: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/OpenEnv_gpt_oss_(20B)_Reinforcement_Learning_2048_Game_BF16.ipynb
198
+ - **Qwen3 GRPO (faster)**: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb
199
+ - **TRL OpenEnv docs**: https://huggingface.co/docs/trl/openenv
200
+ - **Northflank Jupyter**: https://northflank.notion.site/Jupyter-Notebook-with-PyTorch-2036d14c7851802abb7ccb4a7c5c96be
201
+
202
+ ---
203
+
204
+ ## ✅ Hackathon Checklist
205
+
206
+ - [x] OpenEnv v0.2.1 environment built
207
+ - [x] Connect4 game logic with shaped rewards
208
+ - [x] Multi-agent (LLM + rule-based opponent)
209
+ - [x] Deploy to HF Spaces via `openenv push`
210
+ - [x] Unsloth GRPO training notebook (H100 BF16)
211
+ - [x] W&B experiment tracking
212
+ - [x] Closed-loop online RL loop
213
+ - [x] Format reward for JSON CoT reasoning
214
+ - [x] Evaluation tournament
215
+ - [ ] Push trained model to HF Hub ← fill in after training
client.py ADDED
@@ -0,0 +1,25 @@
+ """
+ Connect4 Multi-Agent Environment — Client
+ OpenEnv v0.2.1 — connects to HF Space endpoint
+ """
+
+ from openenv.core.env_client import EnvClient
+ from .models import Connect4Action, Connect4Observation
+
+
+ class Connect4Env(EnvClient):
+     """
+     Client for the Connect4 multi-agent driving coordination environment.
+
+     Usage (async):
+         async with Connect4Env(base_url="https://YOUR-HF-SPACE.hf.space") as env:
+             obs = await env.reset()
+             result = await env.step(Connect4Action(column=3))
+
+     Usage (sync, for TRL/Unsloth training loops):
+         with Connect4Env(base_url="...").sync() as env:
+             obs = env.reset()
+             result = env.step(Connect4Action(column=3))
+     """
+     action_type = Connect4Action
+     observation_type = Connect4Observation
connect4_environment.py ADDED
@@ -0,0 +1,225 @@
+ """
+ Connect4 Multi-Agent Environment — Server Side
+ Adapted for autonomous driving scenario:
+ - Agent 1 = "Ego vehicle" (LLM being trained)
+ - Agent 2 = "Opponent vehicle" (rule-based or another LLM)
+
+ The board represents a grid intersection control problem:
+ - Winning = successfully navigating without collision
+ - Rewards shaped for RL post-training
+ """
+
+ import numpy as np
+ from typing import Optional
+ from openenv.core.environment import Environment
+ from ..models import (
+     Connect4Action, Connect4Observation, Connect4State
+ )
+
+
+ ROWS = 6
+ COLS = 7
+ EMPTY = 0
+ AGENT1 = 1  # Ego vehicle / LLM under training
+ AGENT2 = 2  # Opponent / rule-based agent
+
+
+ class Connect4Environment(Environment):
+     """
+     Connect4 as a multi-agent driving coordination environment.
+
+     Observation:
+     - Board state (6x7 grid)
+     - Current player turn
+     - Legal moves
+     - Last move played
+     - Game status
+
+     Reward shaping (for RL):
+         +10.0 → Win (ego agent connects 4)
+         -10.0 → Loss (opponent connects 4)
+         +0.5  → Blocking opponent's winning move
+         +0.2  → Creating a 3-in-a-row
+         -0.1  → Invalid move attempt
+          0.0  → Draw
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.board: np.ndarray = np.zeros((ROWS, COLS), dtype=int)
+         self.current_player: int = AGENT1
+         self.done: bool = False
+         self.winner: Optional[int] = None
+         self.last_move: Optional[int] = None
+         self.move_history: list = []
+
+
56
+ # ------------------------------------------------------------------ #
57
+ # OpenEnv API #
58
+ # ------------------------------------------------------------------ #
59
+
60
+ def reset(self) -> Connect4Observation:
61
+ self.board = np.zeros((ROWS, COLS), dtype=int)
62
+ self.current_player = AGENT1
63
+ self.done = False
64
+ self.winner = None
65
+ self.last_move = None
66
+ self.move_history = []
67
+ return self._make_observation("Game reset. Your turn — you are Player 1 (Ego Vehicle).")
68
+
69
+ def step(self, action: Connect4Action) -> tuple[Connect4Observation, float, bool]:
70
+ if self.done:
71
+ obs = self._make_observation("Game already finished. Call reset() to start a new game.")
72
+ return obs, 0.0, True
73
+
74
+ col = action.column
75
+ reward = 0.0
76
+
77
+ # ---- validate move ----
78
+ if col < 0 or col >= COLS or not self._is_valid(col):
79
+ obs = self._make_observation(f"Invalid move: column {col} is full or out of range.")
80
+ return obs, -0.1, False
81
+
82
+ # ---- check for blocking bonus before placing ----
83
+ reward += self._blocking_bonus(col)
84
+
85
+ # ---- place piece ----
86
+ row = self._drop_piece(col, self.current_player)
87
+ self.last_move = col
88
+ self.move_history.append((self.current_player, col))
89
+
90
+ # ---- 3-in-a-row bonus ----
91
+ if self._count_streak(row, col, self.current_player) >= 3:
92
+ reward += 0.2
93
+
94
+ # ---- check win ----
95
+ if self._check_win(self.current_player):
96
+ self.done = True
97
+ self.winner = self.current_player
98
+ reward += 10.0 if self.current_player == AGENT1 else -10.0
99
+ msg = ("🏆 Ego vehicle wins! Successful navigation."
100
+ if self.current_player == AGENT1
101
+ else "💥 Opponent wins. Collision occurred.")
102
+ obs = self._make_observation(msg)
103
+ return obs, reward, True
104
+
105
+ # ---- check draw ----
106
+ if self._board_full():
107
+ self.done = True
108
+ obs = self._make_observation("🤝 Draw. Stalemate — no collision, no winner.")
109
+ return obs, 0.0, True
110
+
111
+ # ---- switch player ----
112
+ self.current_player = AGENT2 if self.current_player == AGENT1 else AGENT1
113
+ msg = f"Move accepted (col {col}). Now Player {self.current_player}'s turn."
114
+ obs = self._make_observation(msg)
115
+ return obs, reward, False
116
+
+     def state(self) -> Connect4State:
+         return Connect4State(
+             episode_id=self._episode_id,
+             step_count=self._step_count,
+             current_player=self.current_player,
+             done=self.done,
+             winner=self.winner,
+             move_history=self.move_history,
+         )
+
+     # ------------------------------------------------------------------ #
+     # Internal helpers                                                   #
+     # ------------------------------------------------------------------ #
+
+     def _make_observation(self, message: str) -> Connect4Observation:
+         return Connect4Observation(
+             board=self.board.tolist(),
+             board_str=self._render_board(),
+             current_player=self.current_player,
+             legal_moves=self._legal_moves(),
+             last_move=self.last_move,
+             done=self.done,
+             winner=self.winner,
+             message=message,
+         )
+
+     def _render_board(self) -> str:
+         symbols = {EMPTY: ".", AGENT1: "X", AGENT2: "O"}
+         rows = []
+         for r in range(ROWS):
+             rows.append(" ".join(symbols[self.board[r][c]] for c in range(COLS)))
+         rows.append("-" * (COLS * 2 - 1))
+         rows.append(" ".join(str(c) for c in range(COLS)))
+         return "\n".join(rows)
+
+     def _is_valid(self, col: int) -> bool:
+         return self.board[0][col] == EMPTY
+
+     def _legal_moves(self) -> list[int]:
+         return [c for c in range(COLS) if self._is_valid(c)]
+
+     def _drop_piece(self, col: int, player: int) -> int:
+         for row in range(ROWS - 1, -1, -1):
+             if self.board[row][col] == EMPTY:
+                 self.board[row][col] = player
+                 return row
+         return -1
+
+     def _check_win(self, player: int) -> bool:
+         b = self.board
+         # Horizontal
+         for r in range(ROWS):
+             for c in range(COLS - 3):
+                 if all(b[r][c+i] == player for i in range(4)):
+                     return True
+         # Vertical
+         for r in range(ROWS - 3):
+             for c in range(COLS):
+                 if all(b[r+i][c] == player for i in range(4)):
+                     return True
+         # Diagonal /
+         for r in range(3, ROWS):
+             for c in range(COLS - 3):
+                 if all(b[r-i][c+i] == player for i in range(4)):
+                     return True
+         # Diagonal \
+         for r in range(ROWS - 3):
+             for c in range(COLS - 3):
+                 if all(b[r+i][c+i] == player for i in range(4)):
+                     return True
+         return False
+
+     def _board_full(self) -> bool:
+         return all(self.board[0][c] != EMPTY for c in range(COLS))
+
+     def _count_streak(self, row: int, col: int, player: int) -> int:
+         """Count max consecutive pieces for player around (row, col)."""
+         directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
+         best = 1
+         for dr, dc in directions:
+             count = 1
+             for sign in [1, -1]:
+                 r, c = row + sign*dr, col + sign*dc
+                 while 0 <= r < ROWS and 0 <= c < COLS and self.board[r][c] == player:
+                     count += 1
+                     r += sign*dr
+                     c += sign*dc
+             best = max(best, count)
+         return best
+
+ def _blocking_bonus(self, col: int) -> float:
208
+ """+0.5 if placing here blocks opponent's 4-in-a-row."""
209
+ opponent = AGENT2 if self.current_player == AGENT1 else AGENT1
210
+ test_board = self.board.copy()
211
+ for row in range(ROWS - 1, -1, -1):
212
+ if test_board[row][col] == EMPTY:
213
+ test_board[row][col] = opponent
214
+ break
215
+ # Check if opponent would have won
216
+ b = test_board
217
+ for r in range(ROWS):
218
+ for c in range(COLS - 3):
219
+ if all(b[r][c+i] == opponent for i in range(4)):
220
+ return 0.5
221
+ for r in range(ROWS - 3):
222
+ for c in range(COLS):
223
+ if all(b[r+i][c] == opponent for i in range(4)):
224
+ return 0.5
225
+ return 0.0
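
The board convention used throughout (row 0 at the top, `_drop_piece` scanning upward from the bottom row) can be checked in isolation. This standalone snippet mirrors the method on a plain list board; it is illustrative, not part of the repo:

```python
ROWS, COLS, EMPTY = 6, 7, 0

def drop(board, col, player):
    """Place a piece in `col`, falling to the lowest empty row (row 0 is the top)."""
    for row in range(ROWS - 1, -1, -1):
        if board[row][col] == EMPTY:
            board[row][col] = player
            return row
    return -1

board = [[EMPTY] * COLS for _ in range(ROWS)]
rows = [drop(board, 3, 1) for _ in range(3)]
# Three pieces in column 3 stack upward from the bottom row: rows 5, 4, 3
```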
connect4_grpo_training.ipynb ADDED
@@ -0,0 +1,654 @@
+ {
+   "cells": [
+     {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": [
+         "# 🚗 Multi-Agent Autonomous Driving RL — Connect4 + OpenEnv v0.2.1\n",
+         "\n",
+         "**Hackathon Track:** Infra & Control, Tool & API Integration, Safety, Memory, Observability\n",
+         "\n",
+         "**Stack:**\n",
+         "- 🏗️ [OpenEnv v0.2.1](https://github.com/meta-pytorch/OpenEnv) — RL environment framework\n",
+         "- 🦥 [Unsloth](https://unsloth.ai) — fast GRPO fine-tuning (BF16, H100 optimized)\n",
+         "- 🤗 [TRL GRPO](https://huggingface.co/docs/trl) — policy optimization\n",
+         "- 📊 [W&B](https://wandb.ai) — experiment tracking\n",
+         "- ☁️ [HF Spaces](https://huggingface.co/spaces) — environment server\n",
+         "\n",
+         "**Environment:** Connect4 framed as multi-agent intersection coordination\n",
+         "- Player 1 (X) = Ego vehicle LLM (being trained)\n",
+         "- Player 2 (O) = Rule-based opponent vehicle\n",
+         "- Reward shaping encourages strategic, safe navigation decisions\n",
+         "\n",
+         "**Colab Runtime:** H100 GPU (BF16) — reduce `max_steps` for faster iteration"
+       ]
+     },
+     {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": ["## 1️⃣ Install Dependencies"]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "# Install Unsloth (latest, with vLLM for fast inference)\n",
+         "import sys\n",
+         "!{sys.executable} -m pip install --upgrade pip\n",
+         "!{sys.executable} -m pip install uv\n",
+         "\n",
+         "# Use a venv for stability (recommended by the hackathon notes)\n",
+         "# If you hit issues, uncomment and run in a terminal:\n",
+         "# python -m venv unsloth_env && source unsloth_env/bin/activate\n",
+         "# uv pip install unsloth vllm --torch-backend=auto\n",
+         "\n",
+         "!uv pip install unsloth vllm --torch-backend=auto\n",
+         "!uv pip install --upgrade --no-cache-dir --no-deps unsloth unsloth_zoo\n",
+         "# Quote trl>=0.15.0 so the shell does not treat '>' as a redirect\n",
+         "!uv pip install openenv-core==0.2.1 wandb \"trl>=0.15.0\" pydantic numpy"
+       ]
+     },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# Install our Connect4 environment from HF Spaces\n",
59
+ "# Replace YOUR_HF_USERNAME with your actual HF username after deploying\n",
60
+ "HF_SPACE_REPO = \"YOUR_HF_USERNAME/connect4-env\" # <-- update this\n",
61
+ "HF_SPACE_URL = f\"https://{HF_SPACE_REPO.replace('/', '-')}.hf.space\"\n",
62
+ "\n",
63
+ "!pip install git+https://huggingface.co/spaces/{HF_SPACE_REPO}\n",
64
+ "\n",
65
+ "print(f\"Environment endpoint: {HF_SPACE_URL}\")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": ["## 2️⃣ W&B Setup + Config"]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "import wandb\n",
80
+ "wandb.login() # will prompt for API key\n",
81
+ "\n",
82
+ "# ─── Hyperparameters ───────────────────────────────────────────────────\n",
83
+ "CONFIG = {\n",
84
+ " # Model\n",
85
+ " \"model_name\": \"unsloth/Qwen3-4B-unsloth-bnb-4bit\", # fast 4-bit for Colab\n",
86
+ " # \"model_name\": \"unsloth/gpt-oss-20b-bf16\", # H100 BF16 (hackathon default)\n",
87
+ "\n",
88
+ " # Training\n",
89
+ " \"max_steps\": 300, # reduce to 50 for quick test\n",
90
+ " \"num_generations\": 4, # rollouts per prompt\n",
91
+ " \"max_new_tokens\": 64, # per move response\n",
92
+ " \"learning_rate\": 5e-6,\n",
93
+ " \"batch_size\": 2,\n",
94
+ " \"gradient_accumulation_steps\": 4,\n",
95
+ "\n",
96
+ " # LoRA\n",
97
+ " \"lora_r\": 16,\n",
98
+ " \"lora_alpha\": 32,\n",
99
+ " \"fast_inference\": True, # uses vLLM for speed\n",
100
+ "\n",
101
+ " # Environment\n",
102
+ " \"env_url\": HF_SPACE_URL,\n",
103
+ " \"games_per_step\": 4,\n",
104
+ " \"max_moves\": 42, # max moves in Connect4\n",
105
+ "\n",
106
+ " # Reward weights\n",
107
+ " \"reward_win\": 10.0,\n",
108
+ " \"reward_lose\": -10.0,\n",
109
+ " \"reward_block\": 0.5,\n",
110
+ " \"reward_three\": 0.2,\n",
111
+ " \"reward_invalid\": -0.1,\n",
112
+ " \"reward_format\": 0.3, # bonus for correct JSON format\n",
113
+ "}\n",
114
+ "\n",
115
+ "run = wandb.init(\n",
116
+ " project=\"openenv-connect4-autodrive\",\n",
117
+ " config=CONFIG,\n",
118
+ " tags=[\"connect4\", \"grpo\", \"openenv\", \"autonomous-driving\", \"multi-agent\"]\n",
119
+ ")\n",
120
+ "print(\"W&B run:\", run.url)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "metadata": {},
126
+ "source": ["## 3️⃣ Load Model with Unsloth"]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "from unsloth import FastLanguageModel\n",
135
+ "import torch\n",
136
+ "\n",
137
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
138
+ " model_name = CONFIG[\"model_name\"],\n",
139
+ " max_seq_length = 2048,\n",
140
+ " load_in_4bit = True, # set False for BF16 on H100\n",
141
+ " fast_inference = CONFIG[\"fast_inference\"],\n",
142
+ " gpu_memory_utilization = 0.7,\n",
143
+ ")\n",
144
+ "\n",
145
+ "# Add LoRA adapters\n",
146
+ "model = FastLanguageModel.get_peft_model(\n",
147
+ " model,\n",
148
+ " r = CONFIG[\"lora_r\"],\n",
149
+ " lora_alpha = CONFIG[\"lora_alpha\"],\n",
150
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
151
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
152
+ " lora_dropout = 0,\n",
153
+ " bias = \"none\",\n",
154
+ " use_gradient_checkpointing = \"unsloth\",\n",
155
+ " random_state = 42,\n",
156
+ ")\n",
157
+ "\n",
158
+ "print(f\"✅ Model loaded: {CONFIG['model_name']}\")\n",
159
+ "print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
160
+ ]
161
+ },
+     {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": ["## 4️⃣ Connect4 Prompt Engineering + Reward Functions"]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "import json, re\n",
+         "from typing import Optional\n",
+         "\n",
+         "# ─── System Prompt ──────────────────────────────────────────────────────\n",
+         "SYSTEM_PROMPT = \"\"\"You are an autonomous vehicle navigation AI (Player 1, symbol: X).\n",
+         "You are navigating a 6x7 grid intersection. Your goal is to coordinate your path\n",
+         "to create a connected route of 4 cells (Connect4) before the opponent vehicle (O).\n",
+         "\n",
+         "The board represents intersection occupancy. Each column is a lane (0-6).\n",
+         "Pieces fall to the lowest available row in each column.\n",
+         "\n",
+         "Think step by step about:\n",
+         "1. Your current formation and best extension\n",
+         "2. Opponent threats to block\n",
+         "3. The optimal column to select\n",
+         "\n",
+         "Respond ONLY with valid JSON:\n",
+         "{\"thinking\": \"<your reasoning>\", \"column\": <0-6>}\"\"\"\n",
+         "\n",
+         "\n",
+         "def format_prompt(obs_message: str, board_str: str, legal_moves: list) -> str:\n",
+         "    return f\"\"\"Current board state:\n",
+         "```\n",
+         "{board_str}\n",
+         "```\n",
+         "Legal moves (columns): {legal_moves}\n",
+         "Status: {obs_message}\n",
+         "\n",
+         "Select your move:\"\"\"\n",
+         "\n",
+         "\n",
+         "# ─── Reward Functions ───────────────────────────────────────────────────\n",
+         "def parse_llm_move(response: str) -> Optional[int]:\n",
+         "    \"\"\"Extract column from LLM JSON response.\"\"\"\n",
+         "    try:\n",
+         "        # Try direct JSON parse\n",
+         "        data = json.loads(response.strip())\n",
+         "        return int(data.get(\"column\", -1))\n",
+         "    except Exception:\n",
+         "        pass\n",
+         "    # Fallback: regex\n",
+         "    m = re.search(r'\"column\"\\s*:\\s*(\\d+)', response)\n",
+         "    if m:\n",
+         "        return int(m.group(1))\n",
+         "    # Last resort: find any digit\n",
+         "    digits = re.findall(r'\\b([0-6])\\b', response)\n",
+         "    return int(digits[-1]) if digits else None\n",
+         "\n",
+         "\n",
+         "def format_reward(response: str) -> float:\n",
+         "    \"\"\"Reward correct JSON format with thinking field.\"\"\"\n",
+         "    try:\n",
+         "        data = json.loads(response.strip())\n",
+         "        has_thinking = isinstance(data.get(\"thinking\"), str) and len(data[\"thinking\"]) > 10\n",
+         "        has_column = isinstance(data.get(\"column\"), int)\n",
+         "        return CONFIG[\"reward_format\"] if (has_thinking and has_column) else 0.0\n",
+         "    except Exception:\n",
+         "        return -0.05  # small penalty for unparseable output\n",
+         "\n",
+         "\n",
+         "print(\"✅ Prompt and reward functions defined\")"
+       ]
+     },
+     {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": ["## 5️⃣ Rule-Based Opponent (Player 2)"]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "import random\n",
+         "\n",
+         "def opponent_move(board: list, legal_moves: list) -> int:\n",
+         "    \"\"\"\n",
+         "    Rule-based opponent (Player 2 / O):\n",
+         "    1. Win if possible\n",
+         "    2. Block Player 1's winning move\n",
+         "    3. Otherwise prefer the center column\n",
+         "    \"\"\"\n",
+         "    ROWS, COLS = 6, 7\n",
+         "    P2, P1 = 2, 1\n",
+         "\n",
+         "    def can_win_at(b, col, player):\n",
+         "        import copy\n",
+         "        b2 = copy.deepcopy(b)\n",
+         "        for row in range(ROWS-1, -1, -1):\n",
+         "            if b2[row][col] == 0:\n",
+         "                b2[row][col] = player\n",
+         "                break\n",
+         "        # Check win\n",
+         "        for r in range(ROWS):\n",
+         "            for c in range(COLS-3):\n",
+         "                if all(b2[r][c+i] == player for i in range(4)): return True\n",
+         "        for r in range(ROWS-3):\n",
+         "            for c in range(COLS):\n",
+         "                if all(b2[r+i][c] == player for i in range(4)): return True\n",
+         "        for r in range(3, ROWS):\n",
+         "            for c in range(COLS-3):\n",
+         "                if all(b2[r-i][c+i] == player for i in range(4)): return True\n",
+         "        for r in range(ROWS-3):\n",
+         "            for c in range(COLS-3):\n",
+         "                if all(b2[r+i][c+i] == player for i in range(4)): return True\n",
+         "        return False\n",
+         "\n",
+         "    # 1. Win\n",
+         "    for col in legal_moves:\n",
+         "        if can_win_at(board, col, P2):\n",
+         "            return col\n",
+         "    # 2. Block\n",
+         "    for col in legal_moves:\n",
+         "        if can_win_at(board, col, P1):\n",
+         "            return col\n",
+         "    # 3. Center preference\n",
+         "    center_order = sorted(legal_moves, key=lambda c: abs(c - 3))\n",
+         "    return center_order[0]\n",
+         "\n",
+         "print(\"✅ Rule-based opponent defined\")"
+       ]
+     },
+     {
+       "cell_type": "markdown",
+       "metadata": {},
+       "source": ["## 6️⃣ OpenEnv Game Loop (Environment Interaction)"]
+     },
+     {
+       "cell_type": "code",
+       "execution_count": null,
+       "metadata": {},
+       "outputs": [],
+       "source": [
+         "from connect4_env import Connect4Env, Connect4Action\n",
+         "\n",
+         "async def play_game(model, tokenizer, env_url: str, verbose: bool = False):\n",
+         "    \"\"\"\n",
+         "    Run one complete Connect4 game.\n",
+         "    Returns a list of (prompt, response, reward) dicts for GRPO training.\n",
+         "    \"\"\"\n",
+         "    experiences = []\n",
+         "\n",
+         "    async with Connect4Env(base_url=env_url) as env:\n",
+         "        obs = await env.reset()\n",
+         "\n",
+         "        for move_num in range(CONFIG[\"max_moves\"]):\n",
+         "            if obs.done:\n",
+         "                break\n",
+         "\n",
+         "            # ── Player 1: LLM turn ──────────────────────────────────────\n",
+         "            if obs.current_player == 1:\n",
+         "                prompt = format_prompt(obs.message, obs.board_str, obs.legal_moves)\n",
+         "                messages = [\n",
+         "                    {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
+         "                    {\"role\": \"user\", \"content\": prompt},\n",
+         "                ]\n",
+         "                input_ids = tokenizer.apply_chat_template(\n",
+         "                    messages, return_tensors=\"pt\", tokenize=True,\n",
+         "                    add_generation_prompt=True\n",
+         "                ).to(model.device)\n",
+         "\n",
+         "                with torch.no_grad():\n",
+         "                    output = model.generate(\n",
+         "                        input_ids,\n",
+         "                        max_new_tokens=CONFIG[\"max_new_tokens\"],\n",
+         "                        temperature=0.7,\n",
+         "                        do_sample=True,\n",
+         "                        pad_token_id=tokenizer.eos_token_id,\n",
+         "                    )\n",
+         "                response = tokenizer.decode(\n",
+         "                    output[0][input_ids.shape[1]:], skip_special_tokens=True\n",
+         "                )\n",
+         "\n",
+         "                col = parse_llm_move(response)\n",
+         "                if col is None or col not in obs.legal_moves:\n",
+         "                    col = random.choice(obs.legal_moves)  # fallback\n",
+         "                    env_reward = CONFIG[\"reward_invalid\"]\n",
+         "                else:\n",
+         "                    env_reward = 0.0\n",
+         "\n",
+         "                result = await env.step(Connect4Action(\n",
+         "                    column=col,\n",
+         "                    reasoning=response[:200]\n",
+         "                ))\n",
+         "\n",
+         "                # Accumulate rewards: env step + invalid-move penalty + format bonus\n",
+         "                total_reward = (\n",
+         "                    result.reward\n",
+         "                    + env_reward\n",
+         "                    + format_reward(response)\n",
+         "                )\n",
+         "\n",
+         "                experiences.append({\n",
+         "                    \"prompt\": tokenizer.apply_chat_template(messages, tokenize=False),\n",
+         "                    \"response\": response,\n",
+         "                    \"reward\": total_reward,\n",
+         "                    \"move\": col,\n",
+         "                    \"move_num\": move_num,\n",
+         "                })\n",
+         "\n",
+         "                obs = result.observation\n",
+         "                if verbose:\n",
+         "                    print(f\"P1 move {col} | reward {total_reward:.2f}\")\n",
+         "                    print(obs.board_str)\n",
+         "\n",
+         "            # ── Player 2: Rule-based opponent ───────────────────────────\n",
+         "            else:\n",
+         "                col = opponent_move(obs.board, obs.legal_moves)\n",
+         "                result = await env.step(Connect4Action(column=col))\n",
+         "                obs = result.observation\n",
+         "                if verbose:\n",
+         "                    print(f\"P2 move {col}\")\n",
+         "\n",
+         "    # Terminal reward propagation — assign game outcome to all moves\n",
+         "    if obs.winner == 1:\n",
+         "        outcome_bonus = 1.0\n",
+         "    elif obs.winner == 2:\n",
+         "        outcome_bonus = -1.0\n",
+         "    else:\n",
+         "        outcome_bonus = 0.1  # draw is slightly positive\n",
+         "\n",
+         "    for exp in experiences:\n",
+         "        exp[\"reward\"] += outcome_bonus\n",
+         "\n",
+         "    return experiences, obs.winner\n",
+         "\n",
+         "\n",
+         "# Quick sanity test (1 game, no training)\n",
+         "# Jupyter/Colab already runs an event loop, so use top-level await, not asyncio.run()\n",
+         "print(\"Running test game...\")\n",
+         "test_exps, winner = await play_game(model, tokenizer, CONFIG[\"env_url\"], verbose=True)\n",
+         "print(f\"\\nTest game winner: Player {winner} | Experiences collected: {len(test_exps)}\")"
+       ]
+     },
409
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": ["## 7️⃣ GRPO Training Loop (Unsloth + TRL)"]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from trl import GRPOTrainer, GRPOConfig\n",
+ "from datasets import Dataset\n",
+ "\n",
+ "# ─── Build initial dataset from self-play ───────────────────────────────\n",
+ "print(\"Collecting initial self-play data...\")\n",
+ "all_experiences = []\n",
+ "wins = 0\n",
+ "\n",
+ "for game_i in range(CONFIG[\"games_per_step\"]):\n",
+ "    exps, winner = asyncio.run(\n",
+ "        play_game(model, tokenizer, CONFIG[\"env_url\"])\n",
+ "    )\n",
+ "    all_experiences.extend(exps)\n",
+ "    if winner == 1:\n",
+ "        wins += 1\n",
+ "    print(f\"  Game {game_i+1}/{CONFIG['games_per_step']} | winner={winner}\")\n",
+ "\n",
+ "print(f\"\\nInitial win rate: {wins}/{CONFIG['games_per_step']} = {wins/CONFIG['games_per_step']:.1%}\")\n",
+ "wandb.log({\"initial_win_rate\": wins / CONFIG[\"games_per_step\"]})\n",
+ "\n",
+ "# Convert to HF Dataset\n",
+ "dataset = Dataset.from_list([\n",
+ "    {\"prompt\": e[\"prompt\"], \"reward\": e[\"reward\"]}\n",
+ "    for e in all_experiences\n",
+ "])\n",
+ "print(f\"Dataset size: {len(dataset)} samples\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ─── Reward function for the GRPO Trainer ───────────────────────────────\n",
+ "# GRPOTrainer expects reward_funcs mapping (prompts, completions) -> list[float].\n",
+ "# The \"reward\" column above records what the environment returned during\n",
+ "# collection; GRPOTrainer itself only consumes reward_funcs at rollout time,\n",
+ "# so this function supplies the FORMAT reward during GRPO generations.\n",
+ "def grpo_reward_format(completions, **kwargs) -> list[float]:\n",
+ "    return [format_reward(c) for c in completions]\n",
+ "\n",
+ "\n",
+ "# ─── GRPO Config ────────────────────────────────────────────────────────\n",
+ "grpo_config = GRPOConfig(\n",
+ "    output_dir=\"./connect4-grpo-checkpoints\",\n",
+ "    num_train_epochs=1,\n",
+ "    max_steps=CONFIG[\"max_steps\"],  # max_steps takes precedence over epochs\n",
+ "    per_device_train_batch_size=CONFIG[\"batch_size\"],\n",
+ "    gradient_accumulation_steps=CONFIG[\"gradient_accumulation_steps\"],\n",
+ "    learning_rate=CONFIG[\"learning_rate\"],\n",
+ "    num_generations=CONFIG[\"num_generations\"],\n",
+ "    max_completion_length=CONFIG[\"max_new_tokens\"],\n",
+ "    max_prompt_length=1024,\n",
+ "    bf16=True,\n",
+ "    logging_steps=10,\n",
+ "    save_steps=100,\n",
+ "    report_to=\"wandb\",\n",
+ "    run_name=f\"connect4-grpo-{CONFIG['model_name'].split('/')[-1]}\",\n",
+ "    # GRPO-specific\n",
+ "    use_vllm=CONFIG[\"fast_inference\"],\n",
+ "    vllm_gpu_memory_utilization=0.3,\n",
+ "    temperature=0.7,\n",
+ "    beta=0.01,  # KL penalty coefficient\n",
+ ")\n",
+ "\n",
+ "# ─── Trainer ────────────────────────────────────────────────────────────\n",
+ "trainer = GRPOTrainer(\n",
+ "    model=model,\n",
+ "    processing_class=tokenizer,\n",
+ "    reward_funcs=[grpo_reward_format],\n",
+ "    args=grpo_config,\n",
+ "    train_dataset=dataset,\n",
+ ")\n",
+ "\n",
+ "print(\"✅ GRPO Trainer initialized\")\n",
+ "print(f\"   max_steps: {CONFIG['max_steps']}\")\n",
+ "print(f\"   fast_inference (vLLM): {CONFIG['fast_inference']}\")"
+ ]
+ },
500
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ─── Run Training ────────────────────────────────────────────────────────\n",
+ "print(\"🚀 Starting GRPO training...\")\n",
+ "trainer.train()\n",
+ "print(\"✅ Training complete!\")"
+ ]
+ },
512
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": ["## 8️⃣ Online RL Loop — Closed-Loop Self-Play Training"]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"[OPTIONAL - Advanced]\n",
+ "Online RL alternates between:\n",
+ "  (a) collecting fresh game data with the current model\n",
+ "  (b) a GRPO update on that fresh data\n",
+ "\n",
+ "This implements closed-loop learning — the key advantage of RL + Envs.\n",
+ "\"\"\"\n",
+ "\n",
+ "ONLINE_ITERATIONS = 5  # number of collect → train cycles\n",
+ "\n",
+ "win_rates = []\n",
+ "\n",
+ "for iteration in range(ONLINE_ITERATIONS):\n",
+ "    print(f\"\\n{'='*50}\")\n",
+ "    print(f\"Online RL Iteration {iteration+1}/{ONLINE_ITERATIONS}\")\n",
+ "    print('='*50)\n",
+ "\n",
+ "    # ── Collect fresh experience ──────────────────────────────────────\n",
+ "    fresh_exps = []\n",
+ "    wins = 0\n",
+ "    for _ in range(CONFIG[\"games_per_step\"]):\n",
+ "        exps, winner = asyncio.run(\n",
+ "            play_game(model, tokenizer, CONFIG[\"env_url\"])\n",
+ "        )\n",
+ "        fresh_exps.extend(exps)\n",
+ "        if winner == 1:\n",
+ "            wins += 1\n",
+ "\n",
+ "    win_rate = wins / CONFIG[\"games_per_step\"]\n",
+ "    win_rates.append(win_rate)\n",
+ "    print(f\"Win rate: {win_rate:.1%}\")\n",
+ "    wandb.log({\"win_rate\": win_rate, \"iteration\": iteration})\n",
+ "\n",
+ "    # ── Update dataset ────────────────────────────────────────────────\n",
+ "    fresh_dataset = Dataset.from_list([\n",
+ "        {\"prompt\": e[\"prompt\"], \"reward\": e[\"reward\"]}\n",
+ "        for e in fresh_exps\n",
+ "    ])\n",
+ "\n",
+ "    # ── Short GRPO update on fresh data ──────────────────────────────\n",
+ "    trainer.train_dataset = fresh_dataset\n",
+ "    trainer.args.max_steps = 50  # each train() call runs a short 50-step loop\n",
+ "    trainer.train()\n",
+ "\n",
+ "print(f\"\\nFinal win rates across iterations: {win_rates}\")\n",
+ "print(f\"Improvement: {win_rates[0]:.1%} → {win_rates[-1]:.1%}\")"
+ ]
+ },
570
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": ["## 9️⃣ Save & Push to HF Hub"]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save LoRA adapter\n",
+ "model.save_pretrained(\"connect4-grpo-adapter\")\n",
+ "tokenizer.save_pretrained(\"connect4-grpo-adapter\")\n",
+ "\n",
+ "# Push to HF Hub\n",
+ "HF_MODEL_REPO = \"YOUR_HF_USERNAME/connect4-autonomous-driving-grpo\"  # <-- update\n",
+ "model.push_to_hub(HF_MODEL_REPO)\n",
+ "tokenizer.push_to_hub(HF_MODEL_REPO)\n",
+ "\n",
+ "# Save merged model (optional, for inference)\n",
+ "# model.save_pretrained_merged(\"connect4-merged\", tokenizer)\n",
+ "\n",
+ "print(f\"✅ Model pushed to: https://huggingface.co/{HF_MODEL_REPO}\")\n",
+ "wandb.finish()"
+ ]
+ },
597
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 📊 Evaluation\n",
+ "Test the trained model against the rule-based opponent."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Evaluation: 20-game tournament\n",
+ "FastLanguageModel.for_inference(model)  # switch to inference mode\n",
+ "\n",
+ "EVAL_GAMES = 20\n",
+ "results = {1: 0, 2: 0, None: 0}\n",
+ "\n",
+ "for i in range(EVAL_GAMES):\n",
+ "    _, winner = asyncio.run(play_game(model, tokenizer, CONFIG[\"env_url\"]))\n",
+ "    results[winner] = results.get(winner, 0) + 1\n",
+ "    print(f\"Game {i+1:2d}: winner = Player {winner}\")\n",
+ "\n",
+ "print(f\"\\n{'='*40}\")\n",
+ "print(f\"EVALUATION RESULTS ({EVAL_GAMES} games)\")\n",
+ "print(f\"  LLM wins  (P1): {results[1]:2d}  ({results[1]/EVAL_GAMES:.1%})\")\n",
+ "print(f\"  Rule wins (P2): {results[2]:2d}  ({results[2]/EVAL_GAMES:.1%})\")\n",
+ "print(f\"  Draws         : {results[None]:2d}  ({results[None]/EVAL_GAMES:.1%})\")\n",
+ "\n",
+ "wandb.log({\n",
+ "    \"eval_win_rate\": results[1] / EVAL_GAMES,\n",
+ "    \"eval_loss_rate\": results[2] / EVAL_GAMES,\n",
+ "    \"eval_draw_rate\": results[None] / EVAL_GAMES,\n",
+ "})"
+ ]
+ }
635
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.11.0"
+ },
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "H100",
+ "provenance": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
models.py ADDED
@@ -0,0 +1,45 @@
+ """
+ Connect4 Multi-Agent Environment — Models
+ OpenEnv v0.2.1
+ """
+
+ from typing import Optional
+ from pydantic import Field
+ from openenv.core.models import Action, Observation, State
+
+
+ class Connect4Action(Action):
+     """Action: choose which column to drop a piece into (0–6)."""
+     column: int = Field(
+         ...,
+         ge=0,
+         le=6,
+         description="Column index (0-6) to drop the piece into",
+     )
+     reasoning: Optional[str] = Field(
+         None,
+         description="LLM chain-of-thought reasoning for this move (used for reward shaping)",
+     )
+
+
+ class Connect4Observation(Observation):
+     """Full observation returned after each step."""
+     board: list[list[int]] = Field(..., description="6x7 board as nested list (0=empty, 1=P1, 2=P2)")
+     board_str: str = Field(..., description="Human-readable ASCII board")
+     current_player: int = Field(..., description="Which player moves next (1 or 2)")
+     legal_moves: list[int] = Field(..., description="List of valid column indices")
+     last_move: Optional[int] = Field(None, description="Column of the last move played")
+     done: bool = Field(False, description="Whether the game has ended")
+     winner: Optional[int] = Field(None, description="Winner (1 or 2) or None if ongoing/draw")
+     message: str = Field("", description="Human-readable status message")
+
+
+ class Connect4State(State):
+     """Episode-level state metadata."""
+     current_player: int = Field(1)
+     done: bool = Field(False)
+     winner: Optional[int] = Field(None)
+     move_history: list[tuple[int, int]] = Field(
+         default_factory=list,
+         description="List of (player, column) tuples"
+     )
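The observation schema above exposes a 6x7 `board` plus derived fields like `legal_moves`. A minimal, self-contained sketch of how those fields relate, assuming row 0 is the top of the grid (an assumption for illustration; the authoritative rules live in `connect4_environment.py`):

```python
# Sketch of the board conventions behind Connect4Observation.
# Assumes row 0 is the TOP row; 0 = empty, 1 = Player 1, 2 = Player 2.

ROWS, COLS = 6, 7

def legal_moves(board: list[list[int]]) -> list[int]:
    """A column stays playable while its top cell is still empty."""
    return [c for c in range(COLS) if board[0][c] == 0]

def drop_piece(board: list[list[int]], column: int, player: int) -> int:
    """Drop a piece into `column`; returns the row it lands in."""
    for r in range(ROWS - 1, -1, -1):  # scan from the bottom row upward
        if board[r][column] == 0:
            board[r][column] = player
            return r
    raise ValueError(f"column {column} is full")

board = [[0] * COLS for _ in range(ROWS)]
drop_piece(board, 3, 1)   # P1 lands on the bottom row
drop_piece(board, 3, 2)   # P2 stacks on top of it
print(legal_moves(board))  # → [0, 1, 2, 3, 4, 5, 6]
```

This is why the environment can reject a `Connect4Action` whose `column` is outside `legal_moves` even when it passes the pydantic `ge=0, le=6` bounds: the range check is static, while fullness depends on the board state.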