Llama-3.2-1B-GRPO
A Llama-3.2-1B model fine-tuned with Chain-of-Thought reasoning and GRPO (Group Relative Policy Optimization) for mathematical problem solving.
Model Description
This model is a two-stage fine-tuned version of meta-llama/Llama-3.2-1B:
Stage 1: Chain-of-Thought (CoT) Supervised Fine-Tuning
- Fine-tuned on PursuitOfDataScience/MiniMax-M2.1-Mixture-of-Thoughts
- Trained the model to use <think> reasoning tags for step-by-step problem solving
- Output model saved to /project/rcc/youzhi/Llama-3.2-1B-reasoning/final_model
Stage 2: GRPO Reinforcement Learning
- Starting from the CoT model, applied GRPO training on openai/gsm8k
- Used verifiable rewards based on correct mathematical answers
- Optimized to maintain Chain-of-Thought reasoning while improving accuracy
Training Details
Stage 1: Chain-of-Thought Fine-Tuning
Base Model: meta-llama/Llama-3.2-1B
Dataset: PursuitOfDataScience/MiniMax-M2.1-Mixture-of-Thoughts
Training Configuration:
- Maximum context length: 4096 tokens
- Batch size: 16 (dynamically reduced to a minimum of 4)
- Gradient accumulation steps: 8
- Learning rate: 2e-5
- Epochs: 1
- Warmup steps: 100
- Optimizer: AdamW (fused)
- Precision: bfloat16
- Hardware: H100 GPU with SDPA attention
Format: The model was trained with a simple chat format:
user: {question}
assistant: <think>{reasoning}</think>{answer}
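The chat format above can be assembled from a dataset row with a small helper (the function name and example values are illustrative, not the authors' code):

```python
def build_example(question: str, reasoning: str, answer: str) -> str:
    # Mirror the training format: a user turn with the question, then an
    # assistant turn with <think> reasoning followed by the final answer.
    return (
        f"user: {question}\n"
        f"assistant: <think>{reasoning}</think>{answer}"
    )

text = build_example("What is 2 + 3?", "2 plus 3 equals 5.", "5")
print(text)
```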
Stage 2: GRPO Training
Starting Model: CoT fine-tuned model from Stage 1 (/project/rcc/youzhi/Llama-3.2-1B-reasoning/final_model)
Dataset: openai/gsm8k (train split)
GRPO Configuration:
- Number of generations per prompt: 16
- Max completion length: 1024 tokens
- Max prompt length: 512 tokens
- Batch size: 4
- Gradient accumulation steps: 8 (effective batch: 32 prompts)
- Learning rate: 5e-7
- KL coefficient (beta): 0.05
- Epochs: 1
- Warmup steps: 50
- Temperature: 0.8
- Top-p: 0.95
- Top-k: 50
- Precision: bfloat16
- Gradient checkpointing: Enabled
- Hardware: H100 GPU with SDPA attention
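The hyperparameters above map onto TRL's GRPOConfig roughly as follows; this is a config sketch, and exact field names may vary across TRL versions:

```python
from trl import GRPOConfig

# Sketch of the GRPO configuration listed above (field names assume a
# recent TRL release; not the authors' exact training script).
config = GRPOConfig(
    num_generations=16,            # completions sampled per prompt
    max_completion_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8, # effective batch: 32 prompts
    learning_rate=5e-7,
    beta=0.05,                     # KL coefficient
    num_train_epochs=1,
    warmup_steps=50,
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    bf16=True,
    gradient_checkpointing=True,
)
```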
Reward Function:
- 1.0 for correct answer appearing in last 300 characters
- 0.5 for correct answer appearing elsewhere
- 0.0 for incorrect or missing answer
The reward function checks if the ground truth numerical answer appears in the generated text using word boundary matching to avoid partial number matches.
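A minimal sketch of such a reward function, assuming word-boundary regex matching (the function name and exact pattern are assumptions, not the authors' code):

```python
import re

def reward(completion: str, ground_truth: str) -> float:
    # \b prevents partial number matches, e.g. "72" inside "172".
    pattern = re.compile(r"\b" + re.escape(ground_truth) + r"\b")
    if pattern.search(completion[-300:]):
        return 1.0  # correct answer near the end of the response
    if pattern.search(completion):
        return 0.5  # correct answer, but not in the expected location
    return 0.0      # incorrect or missing answer
```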
Usage
The model uses a simple prompt format:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "PursuitOfDataScience/Llama-3.2-1B-GRPO"

# Load in bfloat16 to match the training precision
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Format the prompt
question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
prompt = f"""user: Solve this math problem step by step. Show your reasoning inside <think></think> tags, then give the final answer after ####.
Question: {question}
assistant:"""

# Generate
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    # Llama tokenizers may not define a pad token; fall back to EOS
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)

# Decode only the newly generated tokens, skipping the prompt
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
Expected output format:
<think>
[Step-by-step reasoning here]
</think>
#### [final answer]
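A response in this format can be split into its reasoning and final answer with a small parser (a sketch; the helper name is an assumption):

```python
import re

def parse_response(text: str):
    # Reasoning lives inside the <think>...</think> tags.
    think = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    reasoning = think.group(1).strip() if think else ""
    # The final answer follows the #### delimiter.
    ans = re.search(r"####\s*([^\n]+)", text)
    answer = ans.group(1).strip() if ans else None
    return reasoning, answer

reasoning, answer = parse_response(
    "<think>16 - 3 - 4 = 9 eggs; 9 * $2 = $18.</think>\n#### 18"
)
```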
Model Architecture
- Base Architecture: Llama-3.2-1B (1.23B parameters)
- Attention: SDPA (Scaled Dot-Product Attention)
- Precision: bfloat16
- Vocabulary Size: Same as base Llama-3.2-1B
Evaluation
The model was trained and evaluated on GSM8K mathematical reasoning tasks. The GRPO training used a verifiable reward function that checks for correct numerical answers, with higher rewards for answers appearing in the expected location (end of response).
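An accuracy check in the spirit of this evaluation can be sketched as follows, assuming GSM8K's `####` answer delimiter (helper names are illustrative):

```python
import re

def extract_answer(text: str):
    # GSM8K references and this model's outputs put the answer after "####".
    m = re.search(r"####\s*([\-\d,\.]+)", text)
    return m.group(1).replace(",", "").rstrip(".") if m else None

def accuracy(predictions, references) -> float:
    # Fraction of predictions whose extracted answer matches the reference.
    correct = sum(
        extract_answer(p) == extract_answer(r)
        for p, r in zip(predictions, references)
    )
    return correct / len(references)
```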
Limitations
- Trained specifically for mathematical reasoning tasks similar to GSM8K
- May not generalize well to other domains without additional fine-tuning
- Reasoning is optimized for problems that can be broken down into step-by-step solutions
- Limited to English language tasks
Training Environment
- Hardware: NVIDIA H100 GPU
- Framework: PyTorch with Hugging Face Transformers and TRL
- Optimization: Flash attention disabled, SDPA enabled, TF32 enabled
- Special Handling: Flash attention and torchvision modules mocked to avoid import issues
Citation
If you use this model, please cite:
@misc{llama32-1b-grpo,
author = {PursuitOfDataScience},
title = {Llama-3.2-1B-GRPO: Chain-of-Thought Reasoning with GRPO},
year = {2026},
publisher = {HuggingFace},
url = {https://huggingface.co/PursuitOfDataScience/Llama-3.2-1B-GRPO}
}
Acknowledgments
- Base model: Meta's Llama-3.2-1B
- CoT training data: MiniMax-M2.1-Mixture-of-Thoughts
- GRPO training data: OpenAI's GSM8K dataset
- Training framework: Hugging Face TRL (Transformer Reinforcement Learning)