Llama-3.2-1B-GRPO

A Llama-3.2-1B model fine-tuned with Chain-of-Thought reasoning and GRPO (Group Relative Policy Optimization) for mathematical problem solving.

Model Description

This model is a two-stage fine-tuned version of meta-llama/Llama-3.2-1B:

  1. Stage 1: Chain-of-Thought (CoT) Supervised Fine-Tuning

    • Supervised fine-tuning of the base model on PursuitOfDataScience/MiniMax-M2.1-Mixture-of-Thoughts
    • Taught the model to emit <think>…</think> reasoning traces before the final answer

  2. Stage 2: GRPO Reinforcement Learning

    • Starting from the CoT model, applied GRPO training on openai/gsm8k
    • Used verifiable rewards based on correct mathematical answers
    • Optimized to maintain Chain-of-Thought reasoning while improving accuracy

Training Details

Stage 1: Chain-of-Thought Fine-Tuning

Base Model: meta-llama/Llama-3.2-1B

Dataset: PursuitOfDataScience/MiniMax-M2.1-Mixture-of-Thoughts

Training Configuration:

  • Maximum context length: 4096 tokens
  • Batch size: 16 (dynamically reduced to a minimum of 4 when needed)
  • Gradient accumulation steps: 8
  • Learning rate: 2e-5
  • Epochs: 1
  • Warmup steps: 100
  • Optimizer: AdamW (fused)
  • Precision: bfloat16
  • Hardware: H100 GPU with SDPA attention
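The Stage 1 hyperparameters above can be expressed as Hugging Face TrainingArguments. This is a hedged reconstruction, not the released training script; the output directory is a placeholder.

```python
from transformers import TrainingArguments

# Hedged sketch of the Stage 1 SFT hyperparameters listed above.
# The output_dir is an assumption; the actual training script was not released.
args = TrainingArguments(
    output_dir="llama-3.2-1b-cot-sft",  # placeholder path
    per_device_train_batch_size=16,      # dynamically reduced to 4 minimum
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=1,
    warmup_steps=100,
    optim="adamw_torch_fused",           # AdamW (fused)
    bf16=True,                           # bfloat16 precision
)
```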

Format: The model was trained with a simple chat format:

user: {question}
assistant: <think>{reasoning}</think>{answer}
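A minimal helper illustrating how a training example would be rendered in this format; the function name is an assumption, not part of the released training code.

```python
# Illustrative helper for the chat format above; the function name and
# signature are assumptions, not part of the released training code.
def format_example(question: str, reasoning: str, answer: str) -> str:
    """Render one CoT training example in the simple chat format."""
    return f"user: {question}\nassistant: <think>{reasoning}</think>{answer}"

text = format_example("What is 2 + 3?", "2 plus 3 equals 5.", "5")
```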

Stage 2: GRPO Training

Starting Model: CoT fine-tuned model from Stage 1 (/project/rcc/youzhi/Llama-3.2-1B-reasoning/final_model)

Dataset: openai/gsm8k (train split)

GRPO Configuration:

  • Number of generations per prompt: 16
  • Max completion length: 1024 tokens
  • Max prompt length: 512 tokens
  • Batch size: 4
  • Gradient accumulation steps: 8 (effective batch: 32 prompts)
  • Learning rate: 5e-7
  • KL coefficient (beta): 0.05
  • Epochs: 1
  • Warmup steps: 50
  • Temperature: 0.8
  • Top-p: 0.95
  • Top-k: 50
  • Precision: bfloat16
  • Gradient checkpointing: Enabled
  • Hardware: H100 GPU with SDPA attention
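The GRPO settings above map naturally onto TRL's GRPOConfig. The sketch below is a reconstruction under that assumption; field names follow recent TRL releases and may differ in older versions, and the output directory is a placeholder. Sampling used top-k = 50 in addition to the fields shown.

```python
from trl import GRPOConfig

# Hedged reconstruction of the GRPO configuration listed above using TRL.
# Field names follow recent TRL releases; output_dir is a placeholder.
config = GRPOConfig(
    output_dir="llama-3.2-1b-grpo",
    num_generations=16,              # completions sampled per prompt
    max_completion_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,   # effective batch: 32 prompts
    learning_rate=5e-7,
    beta=0.05,                       # KL coefficient
    num_train_epochs=1,
    warmup_steps=50,
    temperature=0.8,
    top_p=0.95,
    bf16=True,
    gradient_checkpointing=True,
)
```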

Reward Function:

  • 1.0 if the correct answer appears in the last 300 characters of the completion
  • 0.5 if the correct answer appears elsewhere in the completion
  • 0.0 for an incorrect or missing answer

The reward function checks if the ground truth numerical answer appears in the generated text using word boundary matching to avoid partial number matches.
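The reward scheme described above can be sketched as follows. This is an illustrative reconstruction, not the released code; the function name and exact regex are assumptions, but the tiered scoring and word-boundary matching follow the description.

```python
import re

# Illustrative reconstruction of the reward scheme described above.
# Function name and exact regex are assumptions, not the released code.
def answer_reward(completion: str, ground_truth: str) -> float:
    """Return 1.0 if the ground-truth number appears in the last 300
    characters, 0.5 if it appears earlier, 0.0 otherwise. Word-boundary
    matching keeps '5' from matching inside '15' or '50'."""
    pattern = re.compile(rf"\b{re.escape(ground_truth)}\b")
    if pattern.search(completion[-300:]):
        return 1.0
    if pattern.search(completion):
        return 0.5
    return 0.0
```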

Usage

The model uses a simple prompt format:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "PursuitOfDataScience/Llama-3.2-1B-GRPO"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Format the prompt
question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

prompt = f"""user: Solve this math problem step by step. Show your reasoning inside <think></think> tags, then give the final answer after ####.

Question: {question}
assistant:"""

# Generate
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id
)

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)

Expected output format:

<think>
[Step-by-step reasoning here]
</think>
#### [final answer]
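Given that format, the final answer can be pulled out of a generation with a small parser. The helper name and regex are assumptions for illustration; the `####` marker follows the GSM8K answer convention used in training.

```python
import re

# Minimal sketch for extracting the final answer after the '####' marker
# (GSM8K convention); the helper name and regex are assumptions.
def extract_answer(response: str):
    match = re.search(r"####\s*(-?[\d.,]+)", response)
    if match is None:
        return None
    return match.group(1).replace(",", "").rstrip(".")

sample = "<think>\n16 - 3 - 4 = 9 eggs; 9 * 2 = 18.\n</think>\n#### 18"
```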

Model Architecture

  • Base Architecture: Llama-3.2-1B (1.23B parameters)
  • Attention: SDPA (Scaled Dot-Product Attention)
  • Precision: bfloat16
  • Vocabulary Size: Same as base Llama-3.2-1B

Evaluation

The model was trained and evaluated on GSM8K mathematical reasoning tasks. The GRPO training used a verifiable reward function that checks for correct numerical answers, with higher rewards for answers appearing in the expected location (end of response).

Limitations

  • Trained specifically for mathematical reasoning tasks similar to GSM8K
  • May not generalize well to other domains without additional fine-tuning
  • Reasoning is optimized for problems that can be broken down into step-by-step solutions
  • Limited to English language tasks

Training Environment

  • Hardware: NVIDIA H100 GPU
  • Framework: PyTorch with Hugging Face Transformers and TRL
  • Optimization: Flash attention disabled, SDPA enabled, TF32 enabled
  • Special Handling: Flash attention and torchvision modules mocked to avoid import issues

Citation

If you use this model, please cite:

@misc{llama32-1b-grpo,
  author = {PursuitOfDataScience},
  title = {Llama-3.2-1B-GRPO: Chain-of-Thought Reasoning with GRPO},
  year = {2026},
  publisher = {HuggingFace},
  url = {https://huggingface.co/PursuitOfDataScience/Llama-3.2-1B-GRPO}
}

Acknowledgments

  • Base model: Meta's Llama-3.2-1B
  • CoT training data: MiniMax-M2.1-Mixture-of-Thoughts
  • GRPO training data: OpenAI's GSM8K dataset
  • Training framework: Hugging Face TRL (Transformer Reinforcement Learning)