# Llama-3.2-3B-Thinking

A Llama 3.2 3B model fine-tuned with Chain-of-Thought (CoT) Supervised Fine-Tuning (SFT) to produce explicit step-by-step reasoning enclosed in `<think> … </think>` tags before giving a final answer.
## Base Model

meta-llama/Llama-3.2-3B
## Training Data

PursuitOfDataScience/0.5M-thinking

~500K multi-turn examples where the assistant turn contains structured chain-of-thought reasoning wrapped in `<think>` / `</think>` tags, followed by the final answer.
## SFT Process

### Overview

The model was trained for 1 epoch on the 0.5M-thinking dataset on a single NVIDIA H100 GPU using the Hugging Face `Trainer`.
### Key Training Choices
| Hyperparameter | Value |
|---|---|
| Base model | meta-llama/Llama-3.2-3B |
| Precision | bfloat16 |
| Max context length | 4096 tokens |
| Per-device batch size | 8 |
| Gradient accumulation | 8 steps (effective batch ≈ 64) |
| Optimizer | AdamW (fused) |
| Learning rate | 2e-5 |
| LR scheduler | Cosine |
| Warmup steps | 100 |
| Epochs | 1 |
| Attention | SDPA (no Flash-Attention) |
| torch.compile | ✅ (Hopper / H100) |
| Gradient checkpointing | ✅ |
### Label Masking

Only the assistant turn (everything after the hardcoded `<think>` prefix) contributes to the cross-entropy loss. The user / prompt tokens are masked with `-100` so the model learns to reason, not to parrot the prompt.
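The masking step can be sketched as follows, using illustrative token-id lists (the helper name is hypothetical; only the `-100` ignore index comes from the description above):

```python
IGNORE_INDEX = -100  # value ignored by PyTorch cross-entropy loss

def mask_prompt_labels(input_ids, prompt_len):
    """Build labels from input_ids, masking the prompt tokens with -100
    so only the assistant completion contributes to the loss."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels

# Illustrative ids: 5 prompt tokens followed by 3 completion tokens.
ids = [11, 12, 13, 14, 15, 21, 22, 23]
labels = mask_prompt_labels(ids, prompt_len=5)
# The first 5 labels are -100; the last 3 mirror the completion tokens.
```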
### Filtering

Examples whose full tokenized length (prompt + complete assistant response + special tokens) exceeds 4096 tokens are filtered out rather than truncated, ensuring the model never trains on a response with a cut-off `</think>` or missing final answer.
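A minimal sketch of this filter over already-tokenized examples (function name and the toy dataset are illustrative, not the training code):

```python
MAX_LEN = 4096  # max context length used during SFT

def keep_example(full_token_ids, max_len=MAX_LEN):
    """Keep an example only if the fully tokenized sequence
    (prompt + assistant response + special tokens) fits in the
    context window; over-length examples are dropped, never truncated."""
    return len(full_token_ids) <= max_len

# Toy dataset: three tokenized examples of lengths 100, 5000, and 4096.
dataset = [list(range(100)), list(range(5000)), list(range(4096))]
filtered = [ex for ex in dataset if keep_example(ex)]
```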
## Prompt Format

The model was trained with a plain-text, role-prefixed format in which `<think>` is hardcoded into the prompt, so the model always begins its response with chain-of-thought reasoning.
### Training format

```
user: {question}
assistant: <think>
{chain-of-thought reasoning}
</think>
{final answer}
```
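The template above can be assembled programmatically; a small sketch under the assumption that loss is computed on the completion only (function and argument names are illustrative):

```python
def format_training_example(question, reasoning, final_answer):
    """Render one example in the role-prefixed training format.
    Returns (prompt, completion); the prompt ends with the hardcoded
    <think> tag so the model always opens with reasoning."""
    prompt = f"user: {question}\nassistant: <think>\n"
    completion = f"{reasoning}\n</think>\n{final_answer}"
    return prompt, completion

prompt, completion = format_training_example(
    "What is 7 * 6?", "7 * 6 = 42.", "#### 42"
)
```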
### Inference format (recommended)

```
user: {question}
Think briefly, then give the final numerical answer after ####.
assistant: <think>
```

The model completes the `<think>` block and then produces the final answer after `</think>`.
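Because generation starts inside an already-open `<think>` block, the completion can be split at the closing tag to separate reasoning from answer; a minimal sketch (helper name and example text are illustrative):

```python
def split_response(completion):
    """Split a generated completion into (reasoning, final_answer).
    The prompt already opened <think>, so the completion contains the
    reasoning, a closing </think>, then the final answer."""
    reasoning, _, answer = completion.partition("</think>")
    return reasoning.strip(), answer.strip()

reasoning, answer = split_response(
    "16 - 3 - 4 = 9 eggs; 9 * $2 = $18.\n</think>\n#### 18"
)
```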
## GSM8K Benchmark Results (Pass@1)

Evaluated on the full GSM8K test set (1,319 examples) using vLLM with:

- Temperature: 0.6
- Top-p: 0.9
- Max new tokens: 4096
- 3 independent runs per checkpoint; results reported as mean ± std
- Strict numerical comparison: gold `#### answer` vs. extracted prediction
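One way to implement the strict numerical comparison described above (the regex and comma normalization are assumptions, not the exact evaluation script):

```python
import re

# Captures the number following '####', e.g. '#### 1,234.5'.
ANSWER_RE = re.compile(r"####\s*([-+]?[\d,]*\.?\d+)")

def extract_answer(text):
    """Return the number after '####' (commas stripped), or None."""
    m = ANSWER_RE.search(text)
    return m.group(1).replace(",", "") if m else None

def is_correct(prediction, gold):
    """Strict numerical match between extracted prediction and gold answer."""
    p, g = extract_answer(prediction), extract_answer(gold)
    return p is not None and g is not None and float(p) == float(g)
```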
| Model | Total | Mean Acc | Std |
|---|---|---|---|
| final_model | 1319 | 39.22% | 0.98% |
| checkpoint-1000 | 1319 | 31.13% | 0.72% |
| checkpoint-2000 | 1319 | 33.46% | 0.36% |
| checkpoint-3000 | 1319 | 35.30% | 0.13% |
| checkpoint-4000 | 1319 | 39.93% | 0.95% |
| checkpoint-5000 | 1319 | 38.49% | 0.93% |
| checkpoint-6000 | 1319 | 39.15% | 1.16% |
| checkpoint-7000 | 1319 | 38.77% | 0.48% |
| base_model (no SFT) | 1319 | 7.23% | 0.70% |
Best checkpoint: `checkpoint-4000` at 39.93% mean accuracy.
Final merged model: 39.22%, within 1 pp of the best checkpoint.
SFT improved GSM8K accuracy by ~32 percentage points over the base model.
## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "PursuitOfDataScience/llama3.2-3b-thinking"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

question = (
    "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes "
    "muffins with 4. She sells the remainder at $2 per egg. "
    "How much does she make per day?"
)
prompt = (
    f"user: {question}\n"
    f"Think briefly, then give the final numerical answer after ####.\n"
    f"assistant: <think>\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
    )

response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(response)
```
## Citation

If you use this model, please cite the base model and training dataset:

- Base model: meta-llama/Llama-3.2-3B
- Training data: PursuitOfDataScience/0.5M-thinking