---
language:
- en
license: llama3.2
base_model: meta-llama/Llama-3.2-3B
tags:
- llama
- chain-of-thought
- reasoning
- sft
- math
datasets:
- PursuitOfDataScience/0.5M-thinking
---
# Llama-3.2-3B-Thinking
A Llama 3.2 3B model fine-tuned with **Chain-of-Thought (CoT) Supervised Fine-Tuning (SFT)**
to produce explicit step-by-step reasoning, enclosed in dedicated thinking tags, before
giving a final answer.
---
## Base Model
[meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
---
## Training Data
[PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)
~500K multi-turn examples in which the assistant turn contains structured chain-of-thought
reasoning wrapped in opening and closing thinking tags, followed by the final answer.
---
## SFT Process
### Overview
The model was trained for **1 epoch** on the 0.5M-thinking dataset with the Hugging Face
`Trainer`, on a single **NVIDIA H100** GPU.
### Key Training Choices
| Hyperparameter | Value |
|---|---|
| Base model | meta-llama/Llama-3.2-3B |
| Precision | bfloat16 |
| Max context length | 4096 tokens |
| Per-device batch size | 8 |
| Gradient accumulation | 8 steps (effective batch = 8 × 8 = 64 sequences on 1 GPU) |
| Optimizer | AdamW (fused) |
| Learning rate | 2e-5 |
| LR scheduler | Cosine |
| Warmup steps | 100 |
| Epochs | 1 |
| Attention | SDPA (no Flash-Attention) |
| `torch.compile` | ✅ (Hopper / H100) |
| Gradient checkpointing | ✅ |
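The effective batch size and learning-rate schedule in the table can be worked through numerically. The sketch below assumes the standard linear-warmup + cosine-decay shape (the same shape as `transformers`' `get_cosine_schedule_with_warmup` with its default half cycle) and estimates the step count from the approximate ~500K dataset size before length filtering. `effective_batch_size`, `optimizer_steps`, and `cosine_lr` are hypothetical helpers, not part of the actual training code.

```python
import math

# Values from the hyperparameter table above.
PER_DEVICE_BATCH = 8
GRAD_ACCUM = 8
NUM_GPUS = 1          # single H100
PEAK_LR = 2e-5
WARMUP_STEPS = 100

# Assumption: ~500K examples before length filtering, so the step count
# computed below is an approximate upper bound, not a logged value.
NUM_EXAMPLES = 500_000


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Sequences consumed per optimizer step."""
    return per_device * grad_accum * num_gpus


def optimizer_steps(num_examples: int, batch: int) -> int:
    """Approximate optimizer steps for one epoch (partial last batch counted)."""
    return math.ceil(num_examples / batch)


def cosine_lr(step: int, total_steps: int,
              peak_lr: float = PEAK_LR, warmup: int = WARMUP_STEPS) -> float:
    """Linear warmup to peak_lr, then half-cycle cosine decay to zero."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


batch = effective_batch_size(PER_DEVICE_BATCH, GRAD_ACCUM, NUM_GPUS)
total = optimizer_steps(NUM_EXAMPLES, batch)
print(f"effective batch = {batch}, ~{total} optimizer steps")
for step in (0, WARMUP_STEPS, total // 2, total):
    print(step, f"{cosine_lr(step, total):.2e}")
```

An estimate of roughly 7.8K optimizer steps for one epoch is consistent with the saved checkpoints reaching checkpoint-7000 in the results below; the exact count depends on how many examples survived filtering.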
### Label Masking
Only the **assistant turn** (everything after the hardcoded opening thinking tag) contributes
to the cross-entropy loss. The user/prompt tokens are masked with `-100` so that the model
learns to reason rather than to parrot the prompt.
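A minimal sketch of this masking on already-tokenized IDs, assuming the prompt IDs cover everything up to and including the hardcoded opening thinking tag and that an EOS token ends each response (both assumptions; `build_example` is a hypothetical helper, not the training script):

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def build_example(prompt_ids: List[int], response_ids: List[int],
                  eos_id: int) -> Dict[str, List[int]]:
    """Concatenate prompt and response token IDs; mask the prompt in `labels`.

    Assumption: `prompt_ids` covers everything up to and including the
    hardcoded opening thinking tag, and an EOS token ends the response.
    """
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    return {"input_ids": input_ids, "labels": labels}


example = build_example(prompt_ids=[1, 10, 11], response_ids=[20, 21], eos_id=2)
print(example["labels"])
```

`labels` stays aligned position-for-position with `input_ids`; Hugging Face causal-LM models shift the labels internally when computing the next-token loss.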
### Filtering
Examples whose full tokenized length (prompt + complete assistant response + special tokens)
exceeds **4096 tokens** are **filtered out** rather than truncated, so the model never
trains on a response with a cut-off reasoning block or a missing final answer.
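The drop-rather-than-truncate rule can be sketched as a simple length predicate. This assumes each example's `input_ids` already holds the full prompt, complete response, and special tokens; `filter_by_length` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

MAX_LEN = 4096  # max context length from the hyperparameter table

Example = Dict[str, List[int]]


def filter_by_length(examples: List[Example],
                     max_len: int = MAX_LEN) -> Tuple[List[Example], int]:
    """Drop (never truncate) examples whose full token sequence exceeds max_len.

    Returns the kept examples and the number of dropped ones.
    """
    kept = [ex for ex in examples if len(ex["input_ids"]) <= max_len]
    return kept, len(examples) - len(kept)


examples = [{"input_ids": [0] * n} for n in (10, 4096, 4097)]
kept, dropped = filter_by_length(examples)
print(len(kept), dropped)
```

With 🤗 Datasets, the same predicate could be passed to `Dataset.filter`, e.g. `ds.filter(lambda ex: len(ex["input_ids"]) <= MAX_LEN)`.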
---
## Prompt Format
The model was trained on a **plain-text, role-prefixed** format with the opening thinking tag
hardcoded into the prompt, so the model always begins its response with chain-of-thought reasoning.
### Training format
```
user: {question}
assistant:
{chain-of-thought reasoning}
{final answer}
```
### Inference format (recommended)
```
user: {question}
Think briefly, then give the final numerical answer after ####.
assistant:
```
The model completes the thinking block and then produces the final answer, which the inference prompt asks it to place after `####`.
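On the client side, a completion can be split into reasoning and answer at the `####` marker. This sketch assumes the model follows the `####` instruction and treats the last occurrence as the answer boundary; `split_response` is a hypothetical helper, and it leaves any thinking-tag text inside `reasoning` untouched.

```python
from typing import Optional, Tuple


def split_response(text: str, marker: str = "####") -> Tuple[str, Optional[str]]:
    """Split a completion into (reasoning, final_answer) at the LAST marker.

    Returns (text, None) when the model never emitted the marker.
    """
    head, sep, tail = text.rpartition(marker)
    if not sep:
        return text.strip(), None
    return head.strip(), tail.strip()


reasoning, answer = split_response("16 - 3 - 4 = 9 eggs; 9 * 2 = 18.\n#### 18")
print(repr(answer))
```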
---
## GSM8K Benchmark Results (Pass@1)
Evaluated on the full GSM8K test set (1,319 examples) using **vLLM** with:
- Temperature: 0.6
- Top-p: 0.9
- Max new tokens: 4096
- **3 independent runs** per checkpoint; results reported as mean ± std.
- **Strict numerical comparison**: gold `####` answer vs. extracted prediction.
| Model | Examples | Mean accuracy | Std |
|---|---|---|---|
| final_model | 1319 | 39.22% | 0.98% |
| checkpoint-1000 | 1319 | 31.13% | 0.72% |
| checkpoint-2000 | 1319 | 33.46% | 0.36% |
| checkpoint-3000 | 1319 | 35.30% | 0.13% |
| checkpoint-4000 | 1319 | 39.93% | 0.95% |
| checkpoint-5000 | 1319 | 38.49% | 0.93% |
| checkpoint-6000 | 1319 | 39.15% | 1.16% |
| checkpoint-7000 | 1319 | 38.77% | 0.48% |
| base_model (no SFT) | 1319 | 7.23% | 0.70% |
> **Best checkpoint**: checkpoint-4000 at **39.93%** mean accuracy.
>
> **Final merged model**: **39.22%**, 0.71 pp below the best checkpoint.
>
> SFT improved GSM8K accuracy by **~32 percentage points** over the base model (7.23% → 39.22%).
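The strict numerical comparison and mean ± std aggregation described in the evaluation setup can be sketched as below. This is an assumed reconstruction, not the script that produced the table: the extraction rule (first number after the last `####`, with commas and `$` removed), the float tolerance, and the use of the sample standard deviation (`statistics.stdev`) are all assumptions, and the helper names are hypothetical.

```python
import re
import statistics
from typing import List, Optional, Tuple

# Assumed extraction rule: an optionally negative integer or decimal.
_NUMBER = re.compile(r"-?\d+(?:\.\d+)?")


def extract_number(text: str) -> Optional[float]:
    """Parse the first number after the last '####'; None if absent."""
    _, sep, tail = text.rpartition("####")
    if not sep:
        return None  # strict: no marker, no answer
    match = _NUMBER.search(tail.replace(",", "").replace("$", ""))
    return float(match.group()) if match else None


def is_correct(pred_text: str, gold_text: str, tol: float = 1e-6) -> bool:
    """Numerical match between the predicted and gold '####' answers."""
    pred, gold = extract_number(pred_text), extract_number(gold_text)
    return pred is not None and gold is not None and abs(pred - gold) <= tol


def accuracy(preds: List[str], golds: List[str]) -> float:
    """Pass@1 accuracy in percent for a single sampled run."""
    hits = sum(is_correct(p, g) for p, g in zip(preds, golds))
    return 100.0 * hits / len(golds)


def mean_std(run_accuracies: List[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation across independent runs (assumed estimator)."""
    return statistics.mean(run_accuracies), statistics.stdev(run_accuracies)


print(accuracy(["9 * 2 = 18\n#### 18", "#### 3"], ["#### 18", "#### 4"]))
```

Requiring the `####` marker in the prediction keeps the comparison strict: a completion that never states a marked answer is scored as wrong rather than guessed from its reasoning.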
---
## Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "PursuitOfDataScience/llama3.2-3b-thinking"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # matches the bf16 training precision
    device_map="auto",           # requires the `accelerate` package
)

question = "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes muffins with 4. She sells the remainder at $2 per egg. How much does she make per day?"

# Plain-text, role-prefixed prompt in the recommended inference format.
prompt = (
    f"user: {question}\n"
    "Think briefly, then give the final numerical answer after ####.\n"
    "assistant: \n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        do_sample=True,    # sampling settings match the GSM8K evaluation
        temperature=0.6,
        top_p=0.9,
    )

# Decode only the newly generated tokens, skipping the prompt.
response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(response)
```
---
## Citation
If you use this model, please cite the base model and training dataset:
- Base model: [meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
- Training data: [PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)