Dat1710's picture
Upload folder using huggingface_hub
00db46c verified

GRPO Countdown Problem

A project for training language models to solve arithmetic countdown problems using Supervised Fine-Tuning (SFT) followed by Group Relative Policy Optimization (GRPO).

Overview

This project implements a two-stage training pipeline:

  1. SFT (Supervised Fine-Tuning): Train the model on arithmetic problems with correct solutions
  2. GRPO (Group Relative Policy Optimization): Further optimize the model using reward-based learning

The goal is to train a language model to solve arithmetic countdown problems where you must use exactly four given numbers with basic arithmetic operations (+, -, *, /) to reach a target value.

Project Structure

grpo-countdown-problem/
β”œβ”€β”€ data/                           # Training and test datasets
β”œβ”€β”€ models/                         # Saved model checkpoints
β”‚   β”œβ”€β”€ sft/                       # SFT model outputs
β”‚   └── grpo/                      # GRPO model outputs
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ config/                    # Configuration files
β”‚   β”‚   β”œβ”€β”€ grpo/                  # GRPO training configs
β”‚   β”‚   └── sft/                   # SFT training configs
β”‚   β”œβ”€β”€ dataset/                   # Dataset loading and processing
β”‚   β”œβ”€β”€ examples/                  # Example scripts for inference
β”‚   β”œβ”€β”€ scripts/                   # Data generation and processing
β”‚   β”œβ”€β”€ training/                  # Training scripts
β”‚   β”‚   β”œβ”€β”€ grpo/                  # GRPO training
β”‚   β”‚   └── sft/                   # SFT training
β”‚   └── utils/                     # Utility functions
β”œβ”€β”€ main.py                        # Main entry point
β”œβ”€β”€ pyproject.toml                 # Project dependencies
└── README.md                      # This file

Requirements

  • Python 3.12+
  • CUDA-capable GPU (recommended)
  • At least 8GB GPU memory for Qwen2.5-Math-1.5B model

Installation

  1. Clone the repository:

    git clone <repository-url>
    cd grpo-countdown-problem
    
  2. Install dependencies using uv (recommended):

    # Install uv if you haven't already
    curl -LsSf https://astral.sh/uv/install.sh | sh
    
    # Install project dependencies
    uv sync
    

    Or using pip:

    pip install -e .
    
  3. Set up environment variables (if using OpenAI for data generation):

    cp .env.example .env
    # Edit .env and add your OpenAI API key
    

Data Preparation

Generate Training Data

  1. Generate SFT training data:

    python src/scripts/generate_training_dataset_sft.py \
      --output_path data/sft/train.csv \
      --num_problems 10000 \
      --num_workers 4
    
  2. Generate GRPO training data:

    python src/scripts/generate_training_dataset_grpo.py \
      --output_path data/grpo/train.csv \
      --num_problems 10000 \
      --num_workers 4
    
  3. Generate test data:

    python src/scripts/generate_training_dataset_grpo.py \
      --output_path data/grpo/test.csv \
      --num_problems 1000 \
      --num_workers 4
    

Data Format

The CSV files contain the following columns:

  • id: Unique problem identifier
  • problem_description: Natural language description of the problem
  • correct_answer: The target arithmetic expression
  • num1, num2, num3, num4: The four numbers to use
  • reasoning (SFT only): Step-by-step solution explanation

Training

Stage 1: Supervised Fine-Tuning (SFT)

Train the base model on arithmetic problems with supervised learning:

python src/training/sft/train_sft_hydra.py

Configuration: The training uses Hydra configuration files in src/config/sft/:

  • config.yaml: Main configuration
  • dataset/default.yaml: Dataset settings
  • model/qwen2.5-3b.yaml: Model and LoRA settings
  • training/default.yaml: Training hyperparameters

Key parameters:

  • Base model: Qwen/Qwen2.5-Math-1.5B
  • LoRA rank: 64
  • Learning rate: 2e-5
  • Batch size: 4 (per device)
  • Epochs: 2

Output: Trained SFT model saved to models/sft/

Stage 2: Group Relative Policy Optimization (GRPO)

Further optimize the SFT model using reward-based learning:

python src/training/grpo/train_grpo_hydra.py

Configuration: Uses Hydra configuration files in src/config/grpo/:

  • config.yaml: Main configuration (includes SFT model path)
  • dataset/default.yaml: Dataset settings
  • model/qwen2.5-3b.yaml: Model and LoRA settings
  • training/default.yaml: Training hyperparameters

Key parameters:

  • Builds on SFT model from models/sft/
  • Learning rate: 1e-5
  • Batch size: 2 (per device)
  • Epochs: 1
  • Generations per prompt: 8
  • Reward function: Mathematical correctness

Output: Trained GRPO model saved to models/grpo/

Custom Configuration

You can override configuration parameters:

# Override dataset size
python src/training/sft/train_sft_hydra.py dataset.max_rows=5000

# Override learning rate and batch size
python src/training/grpo/train_grpo_hydra.py \
  training.learning_rate=5e-6 \
  training.per_device_train_batch_size=1

# Use different output directory
python src/training/sft/train_sft_hydra.py output_dir=models/sft_experiment

Inference

Interactive Problem Solving

Use the trained model to solve individual problems:

python src/examples/run_model.py

This will load both SFT and GRPO models and solve a sample problem.

Batch Evaluation

Evaluate model accuracy on a test dataset:

python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --sft_model_path models/sft/ \
  --grpo_model_path models/grpo/ \
  --max_samples 100 \
  --output_path results/evaluation_results.csv

Parameters:

  • --csv_path: Path to test CSV file
  • --sft_model_path: Path to SFT model directory
  • --grpo_model_path: Path to GRPO model directory
  • --max_samples: Limit number of test samples (optional)
  • --output_path: Save detailed results to CSV (optional)
  • --temperature: Sampling temperature (default: 1.0)
  • --max_new_tokens: Maximum tokens to generate (default: 4096)

Evaluation Metrics:

  • Accuracy: Percentage of problems solved correctly
  • Valid Format Rate: Percentage of responses in valid arithmetic format
  • Uses All Numbers Rate: Percentage of responses using all four numbers

Model-only Evaluation

Evaluate specific model stages:

# Evaluate only SFT model (no GRPO)
python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --sft_model_path models/sft/ \
  --no_grpo

# Evaluate only base model (no SFT or GRPO)
python src/examples/calculate_accuracy.py \
  --csv_path data/grpo/test.csv \
  --no_sft --no_grpo

Configuration Details

Model Configuration

The project uses Qwen2.5-Math-1.5B as the base model with LoRA (Low-Rank Adaptation) for efficient fine-tuning:

  • LoRA rank: 64
  • LoRA alpha: 128
  • Target modules: All attention and MLP layers
  • LoRA dropout: 0.1

Training Configuration

SFT Training:

  • Optimizer: AdamW 8-bit
  • Learning rate: 2e-5 with linear scheduler
  • Warmup ratio: 0.1
  • Weight decay: 0.01
  • Max sequence length: 4096

GRPO Training:

  • Optimizer: AdamW 8-bit
  • Learning rate: 1e-5 with cosine scheduler
  • Warmup ratio: 0.1
  • Weight decay: 0.0
  • Temperature: 1.0
  • Generations per prompt: 8

Monitoring Training

Both training scripts log to TensorBoard:

# View training logs
tensorboard --logdir models/sft/runs    # For SFT training
tensorboard --logdir models/grpo/runs   # For GRPO training

Example Problem

Input: "Use 53, 3, 47, and 36 exactly once each with only +, -, *, and / operators to create an expression equal to 133."

Expected Output: A valid arithmetic expression like 53 + 47 + 36 - 3