# GRPO Countdown Problem

A project for training language models to solve arithmetic countdown problems using Supervised Fine-Tuning (SFT) followed by Group Relative Policy Optimization (GRPO).

## Overview

This project implements a two-stage training pipeline:

1. **SFT (Supervised Fine-Tuning)**: Train the model on arithmetic problems with correct solutions
2. **GRPO (Group Relative Policy Optimization)**: Further optimize the model using reward-based learning

The goal is to train a language model to solve arithmetic countdown problems: given four numbers and a target value, combine each number exactly once using the basic arithmetic operators (+, -, *, /) so that the resulting expression equals the target.



## Project Structure



```
grpo-countdown-problem/
├── data/                          # Training and test datasets
├── models/                        # Saved model checkpoints
│   ├── sft/                       # SFT model outputs
│   └── grpo/                      # GRPO model outputs
├── src/
│   ├── config/                    # Configuration files
│   │   ├── grpo/                  # GRPO training configs
│   │   └── sft/                   # SFT training configs
│   ├── dataset/                   # Dataset loading and processing
│   ├── examples/                  # Example scripts for inference
│   ├── scripts/                   # Data generation and processing
│   ├── training/                  # Training scripts
│   │   ├── grpo/                  # GRPO training
│   │   └── sft/                   # SFT training
│   └── utils/                     # Utility functions
├── main.py                        # Main entry point
├── pyproject.toml                 # Project dependencies
└── README.md                      # This file
```



## Requirements



- Python 3.12+
- CUDA-capable GPU (recommended)
- At least 8 GB of GPU memory for the Qwen2.5-Math-1.5B model



## Installation



1. **Clone the repository:**

   ```bash
   git clone <repository-url>
   cd grpo-countdown-problem
   ```

2. **Install dependencies using uv (recommended):**

   ```bash
   # Install uv if you haven't already
   curl -LsSf https://astral.sh/uv/install.sh | sh

   # Install project dependencies
   uv sync
   ```

   **Or using pip:**

   ```bash
   pip install -e .
   ```

3. **Set up environment variables (if using OpenAI for data generation):**

   ```bash
   cp .env.example .env
   # Edit .env and add your OpenAI API key
   ```



## Data Preparation



### Generate Training Data



1. **Generate SFT training data:**

   ```bash
   python src/scripts/generate_training_dataset_sft.py \
     --output_path data/sft/train.csv \
     --num_problems 10000 \
     --num_workers 4
   ```

2. **Generate GRPO training data:**

   ```bash
   python src/scripts/generate_training_dataset_grpo.py \
     --output_path data/grpo/train.csv \
     --num_problems 10000 \
     --num_workers 4
   ```

3. **Generate test data:**

   ```bash
   python src/scripts/generate_training_dataset_grpo.py \
     --output_path data/grpo/test.csv \
     --num_problems 1000 \
     --num_workers 4
   ```
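The internals of the generation scripts are not described here. As a hypothetical sketch only (not the project's `generate_training_dataset_*.py` code), one way to produce a guaranteed-solvable problem is to sample four numbers and three operators, evaluate the expression exactly, and keep it only if the result is a positive integer. The number range (1 to 100), the left-to-right nesting, and the function names are all assumptions; the output dict mirrors the columns listed under Data Format.

```python
import random
from fractions import Fraction

# Hypothetical sketch: NOT the project's generate_training_dataset_*.py scripts.
OPS = ("+", "-", "*", "/")


def _apply(a: Fraction, op: str, b: Fraction) -> Fraction:
    if op == "+":
        return a + b
    if op == "-":
        return a - b
    if op == "*":
        return a * b
    return a / b  # b is a sampled number >= 1, so never zero


def generate_problem(rng: random.Random, problem_id: int, low: int = 1, high: int = 100) -> dict:
    """Sample four numbers plus operators until the expression hits a positive integer."""
    if low < 1:
        raise ValueError("low must be >= 1 so that division is always defined")
    while True:
        nums = [rng.randint(low, high) for _ in range(4)]
        ops = [rng.choice(OPS) for _ in range(3)]
        value, expr = Fraction(nums[0]), str(nums[0])
        for op, n in zip(ops, nums[1:]):
            # Left-to-right nesting keeps the evaluation order explicit in the string.
            value = _apply(value, op, Fraction(n))
            expr = f"({expr} {op} {n})"
        if value.denominator == 1 and value > 0:
            break
    shown = nums[:]
    rng.shuffle(shown)  # avoid leaking the solution order into the prompt
    return {
        "id": problem_id,
        "problem_description": (
            f"Use {shown[0]}, {shown[1]}, {shown[2]}, and {shown[3]} exactly once each "
            f"with only +, -, *, and / operators to create an expression equal to {int(value)}."
        ),
        "correct_answer": expr[1:-1],  # drop the redundant outermost parentheses
        **{f"num{i}": n for i, n in enumerate(shown, start=1)},
    }


rng = random.Random(0)
rows = [generate_problem(rng, i) for i in range(3)]
```

Exact `Fraction` arithmetic avoids accepting a target that only looks like an integer because of floating-point rounding.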



### Data Format



The CSV files contain the following columns:

- `id`: Unique problem identifier
- `problem_description`: Natural-language description of the problem
- `correct_answer`: A reference arithmetic expression that reaches the target
- `num1`, `num2`, `num3`, `num4`: The four numbers to use
- `reasoning` (SFT only): Step-by-step solution explanation



## Training



### Stage 1: Supervised Fine-Tuning (SFT)



Train the base model on arithmetic problems with supervised learning:



```bash
python src/training/sft/train_sft_hydra.py
```

**Configuration:** The training uses Hydra configuration files in `src/config/sft/`:

- `config.yaml`: Main configuration
- `dataset/default.yaml`: Dataset settings
- `model/qwen2.5-3b.yaml`: Model and LoRA settings
- `training/default.yaml`: Training hyperparameters

**Key parameters:**

- Base model: `Qwen/Qwen2.5-Math-1.5B`
- LoRA rank: 64
- Learning rate: 2e-5
- Batch size: 4 (per device)
- Epochs: 2



**Output:** Trained SFT model saved to `models/sft/`



### Stage 2: Group Relative Policy Optimization (GRPO)



Further optimize the SFT model using reward-based learning:



```bash
python src/training/grpo/train_grpo_hydra.py
```

**Configuration:** Uses Hydra configuration files in `src/config/grpo/`:

- `config.yaml`: Main configuration (includes the SFT model path)
- `dataset/default.yaml`: Dataset settings
- `model/qwen2.5-3b.yaml`: Model and LoRA settings
- `training/default.yaml`: Training hyperparameters

**Key parameters:**

- Builds on the SFT model from `models/sft/`
- Learning rate: 1e-5
- Batch size: 2 (per device)
- Epochs: 1
- Generations per prompt: 8
- Reward function: Mathematical correctness



**Output:** Trained GRPO model saved to `models/grpo/`
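The "group relative" part of GRPO is how rewards become advantages: each prompt is sampled several times (8 here), and each completion's reward is normalized against the mean and standard deviation of its own group, so no learned value model is needed. A minimal sketch follows; the binary reward values, the population standard deviation, and `eps` are assumptions (some implementations use the sample standard deviation or a different epsilon), and this is not the project's training code.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize each completion's reward against the other completions for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; eps guards against a zero-variance group
    return [(r - mean) / (std + eps) for r in rewards]


# Eight completions for one prompt ("Generations per prompt: 8"), scored with a
# hypothetical binary correctness reward: 1.0 if the expression is right, else 0.0.
rewards = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
```

Correct completions get positive advantages and incorrect ones negative. If every completion in a group earns the same reward, all advantages are zero and that prompt contributes no reward-driven update, which is one reason sampling several diverse generations per prompt matters.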



### Custom Configuration



You can override configuration parameters:



```bash
# Override dataset size
python src/training/sft/train_sft_hydra.py dataset.max_rows=5000

# Override learning rate and batch size
python src/training/grpo/train_grpo_hydra.py \
  training.learning_rate=5e-6 \
  training.per_device_train_batch_size=1

# Use a different output directory
python src/training/sft/train_sft_hydra.py output_dir=models/sft_experiment
```
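Each `key=value` argument addresses a field in the composed Hydra config: dots descend into config groups, so `training.learning_rate` targets `learning_rate` inside the `training` group. The following is a deliberately simplified model of that behavior, not Hydra's parser (Hydra has its own override grammar and type handling); the `apply_override` name and the example config dict are hypothetical, with values copied from this README.

```python
import ast
from typing import Any


def apply_override(config: dict[str, Any], override: str) -> None:
    """Apply one `dotted.key=value` override in place (a simplified model, not Hydra's parser)."""
    key, raw_value = override.split("=", 1)
    *parents, leaf = key.split(".")
    node = config
    for part in parents:
        node = node[part]  # descend into config groups such as `training`
    if leaf not in node:
        # Hydra likewise rejects unknown keys unless the override is prefixed with "+".
        raise KeyError(f"unknown config key: {key}")
    try:
        value = ast.literal_eval(raw_value)  # "5e-6" -> 5e-06, "1" -> 1
    except (ValueError, SyntaxError):
        value = raw_value  # not a Python literal: keep as a string (e.g. a path)
    node[leaf] = value


# Illustrative values only; the project's real YAML files may be structured differently.
config = {
    "output_dir": "models/grpo",
    "training": {"learning_rate": 1e-5, "per_device_train_batch_size": 2},
}
for override in ("training.learning_rate=5e-6", "training.per_device_train_batch_size=1"):
    apply_override(config, override)
```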



## Inference



### Interactive Problem Solving



Use the trained model to solve individual problems:



```bash
python src/examples/run_model.py
```



This will load both SFT and GRPO models and solve a sample problem.



### Batch Evaluation



Evaluate model accuracy on a test dataset:



```bash
python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --sft_model_path models/sft/ \
  --grpo_model_path models/grpo/ \
  --max_samples 100 \
  --output_path results/evaluation_results.csv
```

**Parameters:**

- `--csv_path`: Path to the test CSV file
- `--sft_model_path`: Path to the SFT model directory
- `--grpo_model_path`: Path to the GRPO model directory
- `--max_samples`: Limit the number of test samples (optional)
- `--output_path`: Save detailed results to CSV (optional)
- `--temperature`: Sampling temperature (default: 1.0)
- `--max_new_tokens`: Maximum number of tokens to generate (default: 4096)

**Evaluation Metrics:**

- **Accuracy**: Percentage of problems solved correctly
- **Valid Format Rate**: Percentage of responses in a valid arithmetic format
- **Uses All Numbers Rate**: Percentage of responses that use all four numbers
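The exact definitions used by `calculate_accuracy.py` are not spelled out here, so the following is a hedged sketch of one reasonable reading of the three metrics, applied to an already-extracted expression (pulling that expression out of the full model response is project-specific and not shown). Assumptions: "valid format" means the text parses as `+ - * /` over integer literals and parentheses (division by zero counts as invalid), "uses all numbers" compares the multiset of literals, and a response is correct only if it also evaluates exactly to the target. All names are hypothetical.

```python
import ast
from fractions import Fraction

_BINOPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,  # Fraction division raises ZeroDivisionError on 0
}


def _evaluate(node: ast.AST, literals: list[int]) -> Fraction:
    """Evaluate +, -, *, / over integer literals only, collecting each literal used."""
    if isinstance(node, ast.BinOp) and type(node.op) in _BINOPS:
        left = _evaluate(node.left, literals)
        right = _evaluate(node.right, literals)
        return _BINOPS[type(node.op)](left, right)
    if isinstance(node, ast.Constant) and type(node.value) is int:
        literals.append(node.value)
        return Fraction(node.value)
    raise ValueError("not a plain +, -, *, / expression over integers")


def score_response(expression: str, numbers: list[int], target: int) -> dict[str, bool]:
    """Per-response flags; correctness requires valid format, all numbers, and the target."""
    literals: list[int] = []
    try:
        value = _evaluate(ast.parse(expression.strip(), mode="eval").body, literals)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return {"valid_format": False, "uses_all_numbers": False, "correct": False}
    uses_all = sorted(literals) == sorted(numbers)
    return {"valid_format": True, "uses_all_numbers": uses_all, "correct": uses_all and value == target}


def summarize(scores: list[dict[str, bool]]) -> dict[str, float]:
    """Turn per-response flags into the three percentages."""
    n = len(scores)
    return {
        "accuracy": 100 * sum(s["correct"] for s in scores) / n,
        "valid_format_rate": 100 * sum(s["valid_format"] for s in scores) / n,
        "uses_all_numbers_rate": 100 * sum(s["uses_all_numbers"] for s in scores) / n,
    }


# Four hypothetical responses to the Example Problem (numbers 53, 3, 47, 36; target 133).
responses = ["53 + 47 + 36 - 3", "53 + 47 + 36", "53 + 47 + 36 - 3 = 133", "(53 - 47) * 36 / 3"]
metrics = summarize([score_response(r, [53, 3, 47, 36], 133) for r in responses])
```

Parsing with `ast` instead of calling `eval` on model output avoids executing arbitrary code, and `Fraction` keeps division exact.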



### Model-only Evaluation



Evaluate specific model stages:



```bash
# Evaluate only the SFT model (no GRPO)
python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --sft_model_path models/sft/ \
  --no_grpo

# Evaluate only the base model (no SFT or GRPO)
python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --no_sft --no_grpo
```



## Configuration Details



### Model Configuration



The project uses **Qwen2.5-Math-1.5B** as the base model with LoRA (Low-Rank Adaptation) for efficient fine-tuning:



- **LoRA rank**: 64
- **LoRA alpha**: 128
- **Target modules**: All attention and MLP layers
- **LoRA dropout**: 0.1
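LoRA freezes each targeted weight matrix W (d_out x d_in) and trains two small matrices, A (rank x d_in) and B (d_out x rank), whose product is added to W's output. The arithmetic below shows what rank 64 and alpha 128 imply, assuming the classic alpha / rank scaling from the LoRA paper (variants such as rsLoRA scale differently). The 1536 x 1536 layer size is an illustrative assumption, not a value read from the project's config.

```python
def lora_scaling(rank: int, alpha: int) -> float:
    """Classic LoRA multiplies the low-rank update B @ A by alpha / rank."""
    return alpha / rank


def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    """Extra trainable weights for one d_out x d_in linear layer: A is rank x d_in, B is d_out x rank."""
    return rank * d_in + d_out * rank


# Illustrative square projection; 1536 is an assumed size.
D = 1536
frozen = D * D                                 # 2,359,296 frozen weights
added = lora_trainable_params(D, D, rank=64)   # 196,608 trainable weights (1/12 of frozen)
scale = lora_scaling(rank=64, alpha=128)       # 2.0
```

Setting alpha to twice the rank, as here, means the adapter's update is scaled by 2 regardless of which rank you pick.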



### Training Configuration



**SFT Training:**

- **Optimizer**: AdamW 8-bit
- **Learning rate**: 2e-5 with linear scheduler
- **Warmup ratio**: 0.1
- **Weight decay**: 0.01
- **Max sequence length**: 4096

**GRPO Training:**

- **Optimizer**: AdamW 8-bit
- **Learning rate**: 1e-5 with cosine scheduler
- **Warmup ratio**: 0.1
- **Weight decay**: 0.0
- **Temperature**: 1.0
- **Generations per prompt**: 8
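Both stages warm up linearly over the first 10% of steps and then decay, linearly for SFT and along a half-cosine for GRPO. The sketch below reproduces the common formulation of these multipliers (the same shape as Hugging Face's `get_linear_schedule_with_warmup` and `get_cosine_schedule_with_warmup`); that the project uses exactly these schedulers is an assumption, and 1,000 total steps is an arbitrary illustration.

```python
import math


def _warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    return math.ceil(total_steps * warmup_ratio)


def linear_with_warmup(step: int, total_steps: int, warmup_ratio: float = 0.1) -> float:
    """LR multiplier: ramp 0 -> 1 during warmup, then decay linearly to 0."""
    warmup = _warmup_steps(total_steps, warmup_ratio)
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup))


def cosine_with_warmup(step: int, total_steps: int, warmup_ratio: float = 0.1) -> float:
    """LR multiplier: ramp 0 -> 1 during warmup, then follow half a cosine down to 0."""
    warmup = _warmup_steps(total_steps, warmup_ratio)
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Peak rates from the lists above; 1,000 total steps is an arbitrary illustration.
TOTAL = 1000
sft_lrs = [2e-5 * linear_with_warmup(s, TOTAL) for s in range(TOTAL)]
grpo_lrs = [1e-5 * cosine_with_warmup(s, TOTAL) for s in range(TOTAL)]
```

Halfway through the decay phase both multipliers equal 0.5, but the cosine schedule stays closer to the peak early on and drops faster near the end.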



## Monitoring Training



Both training scripts log to TensorBoard:



```bash
# View training logs
tensorboard --logdir models/sft/runs    # For SFT training
tensorboard --logdir models/grpo/runs   # For GRPO training
```



## Example Problem



**Input:** "Use 53, 3, 47, and 36 exactly once each with only +, -, *, and / operators to create an expression equal to 133."

**Expected Output:** A valid arithmetic expression like `53 + 47 + 36 - 3`