Llama-3.2-3B-Thinking

A Llama 3.2 3B model trained with supervised fine-tuning (SFT) on chain-of-thought (CoT) data, so that it produces explicit step-by-step reasoning enclosed in <think> … </think> tags before giving a final answer.


Base Model

meta-llama/Llama-3.2-3B


Training Data

PursuitOfDataScience/0.5M-thinking

~500K multi-turn examples in which the assistant turn contains structured chain-of-thought reasoning wrapped in <think> / </think> tags, followed by the final answer.
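The sketch below checks that one assistant turn has the layout described above: one <think> block, then a non-empty final answer. It assumes the turn is a single string that begins with <think>. The function name split_assistant_turn and the example turn are hypothetical, and the dataset's actual field names and whitespace conventions are not documented here.

```python
import re
from typing import Optional, Tuple

# Assumed layout of one assistant turn: "<think>" reasoning "</think>" answer.
THINK_PATTERN = re.compile(r"^\s*<think>(.*?)</think>(.*)$", re.DOTALL)


def split_assistant_turn(text: str) -> Optional[Tuple[str, str]]:
    """Return (reasoning, answer) if the turn is well-formed, else None."""
    # Exactly one opening and one closing tag are expected.
    if text.count("<think>") != 1 or text.count("</think>") != 1:
        return None
    match = THINK_PATTERN.match(text)
    if match is None:
        return None
    reasoning, answer = match.group(1).strip(), match.group(2).strip()
    # Both the reasoning and the final answer must be non-empty.
    if not reasoning or not answer:
        return None
    return reasoning, answer


# Made-up example turn, for illustration only.
example = "<think>\n16 - 3 - 4 = 9 eggs; 9 * 2 = 18.\n</think>\nShe makes $18 per day."
print(split_assistant_turn(example))
```

A turn that fails this check (missing tags, duplicated tags, or an empty answer) returns None and could be skipped during preprocessing.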


SFT Process

Overview

The model was trained for 1 epoch on the 0.5M-thinking dataset on a single NVIDIA H100 GPU using the Hugging Face Trainer.

Key Training Choices

Hyperparameter | Value
--- | ---
Base model | meta-llama/Llama-3.2-3B
Precision | bfloat16
Max context length | 4096 tokens
Per-device batch size | 8
Gradient accumulation | 8 steps (effective batch size 8 × 8 = 64 on one GPU)
Optimizer | AdamW (fused)
Learning rate | 2e-5
LR scheduler | Cosine
Warmup steps | 100
Epochs | 1
Attention | SDPA (no Flash-Attention)
torch.compile | ✅ (Hopper / H100)
Gradient checkpointing | ✅
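For reference, the hyperparameters in the table map naturally onto keyword arguments that transformers.TrainingArguments accepts. This is a sketch, not the author's training script: output_dir is a hypothetical placeholder, and the 4096-token limit is assumed to be enforced during preprocessing rather than by the Trainer.

```python
# Hyperparameters from the table, expressed as TrainingArguments keyword arguments.
MAX_SEQ_LEN = 4096  # enforced when building examples, not a TrainingArguments field

training_kwargs = dict(
    output_dir="llama3.2-3b-thinking-sft",  # hypothetical output path
    bf16=True,                              # bfloat16 mixed precision
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    optim="adamw_torch_fused",              # fused AdamW
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    num_train_epochs=1,
    gradient_checkpointing=True,
    torch_compile=True,
)

# SDPA attention would be requested when loading the model, e.g.
# AutoModelForCausalLM.from_pretrained(..., attn_implementation="sdpa")

NUM_GPUS = 1
effective_batch = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
    * NUM_GPUS
)
print(effective_batch)  # 64 sequences per optimizer step
```

With transformers installed, TrainingArguments(**training_kwargs) would build the corresponding config.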

Label Masking

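Only the assistant turn (everything after the hardcoded <think> prefix) contributes to the cross-entropy loss. The labels for the user / prompt tokens, including the <think> prefix, are set to -100, the index the loss ignores, so the model learns to generate the reasoning and answer rather than to reproduce the prompt.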
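A minimal sketch of this masking on already-tokenized IDs follows. It assumes the prompt (ending in the <think> prefix) and the response are tokenized separately and that an EOS token is appended and kept in the loss so the model learns to stop; the function name build_masked_example is hypothetical.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_masked_example(
    prompt_ids: List[int], response_ids: List[int], eos_id: int
) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and response; only response tokens (plus EOS) get labels."""
    input_ids = prompt_ids + response_ids + [eos_id]
    # Prompt positions are ignored by the loss; response positions keep their IDs.
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    return input_ids, labels


# Toy token IDs, for illustration only.
ids, labels = build_masked_example([1, 10, 11], [20, 21], eos_id=2)
print(ids)     # [1, 10, 11, 20, 21, 2]
print(labels)  # [-100, -100, -100, 20, 21, 2]
```

Hugging Face causal-LM models shift labels internally, so labels stay aligned position by position with input_ids here.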

Filtering

Examples whose full tokenized length (prompt + complete assistant response + special tokens) exceeds 4096 tokens are filtered out rather than truncated. This ensures the model never trains on a response whose </think> tag or final answer has been cut off.
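The filter amounts to a length predicate applied before training. This sketch assumes each example is a dict whose "input_ids" already include the prompt, the full response, and all special tokens; the helper names are hypothetical.

```python
from typing import Dict, List

MAX_LEN = 4096


def fits_context(example: Dict[str, List[int]], max_len: int = MAX_LEN) -> bool:
    # Full length = prompt + complete response + special tokens (e.g. BOS/EOS).
    return len(example["input_ids"]) <= max_len


def filter_examples(
    examples: List[Dict[str, List[int]]], max_len: int = MAX_LEN
) -> List[Dict[str, List[int]]]:
    # Drop over-long examples entirely instead of truncating them.
    return [ex for ex in examples if fits_context(ex, max_len)]


kept = filter_examples([{"input_ids": [0] * 4096}, {"input_ids": [0] * 4097}])
print(len(kept))  # 1: the 4097-token example is dropped
```

With a Hugging Face datasets.Dataset, the same predicate could be passed directly as dataset.filter(fits_context).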


Prompt Format

The model was trained using a plain-text role-prefixed format with <think> hardcoded into the prompt so the model always begins its response with chain-of-thought reasoning.

Training format

user: {question}
assistant: <think>
{chain-of-thought reasoning}
</think>
{final answer}
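The training format above can be sketched as two string builders. The exact newline placement is an assumption read off the template, and the function names are hypothetical. The prompt part (through "<think>\n") is the span that label masking ignores; note that tokenizing prompt and response separately can differ slightly at the boundary from tokenizing the joined string.

```python
def training_prompt(question: str) -> str:
    # Plain-text role prefixes; <think> is hardcoded so every response starts with reasoning.
    return f"user: {question}\nassistant: <think>\n"


def format_training_text(question: str, reasoning: str, answer: str) -> str:
    # Prompt, then reasoning, the closing tag, and the final answer.
    return training_prompt(question) + f"{reasoning}\n</think>\n{answer}"


text = format_training_text("What is 2+3?", "2 + 3 = 5.", "5")
print(text)
```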

Inference format (recommended)

user: {question}
Think briefly, then give the final numerical answer after ####.
assistant: <think>

The model will complete the <think> block and then produce the final answer after </think>.
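Because the prompt already ends with <think>, the generated continuation holds the rest of the reasoning, a closing </think>, and then the answer. A minimal sketch for splitting it, assuming </think> survives decoding (if the tags were registered as special tokens, decode with skip_special_tokens=False); split_completion is a hypothetical name.

```python
from typing import Optional, Tuple


def split_completion(completion: str) -> Tuple[str, Optional[str]]:
    """Split a generated continuation into (reasoning, answer)."""
    reasoning, sep, answer = completion.partition("</think>")
    if not sep:
        # No closing tag (e.g. generation hit max_new_tokens): no final answer yet.
        return completion.strip(), None
    return reasoning.strip(), answer.strip()


reasoning, answer = split_completion("9 eggs left, 9 * 2 = 18.\n</think>\nShe makes $18.\n#### 18")
print(answer)
```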


GSM8K Benchmark Results (Pass@1)

Evaluated on the full GSM8K test set (1,319 examples) using vLLM with:

  • Temperature: 0.6
  • Top-p: 0.9
  • Max new tokens: 4096
  • 3 independent runs per checkpoint; results reported as mean ± std.
  • Strict numerical comparison between the gold answer (the number after ####) and the number extracted from the model's output.
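The comparison and aggregation steps can be sketched as follows. The extraction rule (first number after the last ####, commas removed) and the use of the sample standard deviation are assumptions; the card does not specify the exact extraction logic or which standard deviation was reported. All function names are hypothetical.

```python
import re
import statistics
from typing import List, Optional, Tuple

NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_answer(text: str) -> Optional[float]:
    """Return the first number after the last '####', or None if absent."""
    if "####" not in text:
        return None
    match = NUMBER.search(text.rsplit("####", 1)[1])
    if match is None:
        return None
    return float(match.group(0).replace(",", ""))


def is_correct(prediction: str, gold: str) -> bool:
    # Strict numerical match; a missing number on either side counts as wrong.
    pred, ref = extract_answer(prediction), extract_answer(gold)
    return pred is not None and ref is not None and pred == ref


def run_accuracy(predictions: List[str], golds: List[str]) -> float:
    # Pass@1 accuracy for one run: one sampled prediction per question.
    return sum(is_correct(p, g) for p, g in zip(predictions, golds)) / len(golds)


def mean_std(accuracies: List[float]) -> Tuple[float, float]:
    # Mean and sample standard deviation across independent runs.
    return statistics.mean(accuracies), statistics.stdev(accuracies)
```

With three runs per checkpoint, mean_std would receive three run_accuracy values.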
Model | Examples | Mean accuracy | Std
--- | --- | --- | ---
final_model | 1,319 | 39.22% | 0.98%
checkpoint-1000 | 1,319 | 31.13% | 0.72%
checkpoint-2000 | 1,319 | 33.46% | 0.36%
checkpoint-3000 | 1,319 | 35.30% | 0.13%
checkpoint-4000 | 1,319 | 39.93% | 0.95%
checkpoint-5000 | 1,319 | 38.49% | 0.93%
checkpoint-6000 | 1,319 | 39.15% | 1.16%
checkpoint-7000 | 1,319 | 38.77% | 0.48%
base_model (no SFT) | 1,319 | 7.23% | 0.70%

Best checkpoint: checkpoint-4000 at 39.93% mean accuracy.
Final merged model: 39.22%, within 1 pp of the best checkpoint (0.71 pp lower).
SFT improved GSM8K accuracy by ~32 percentage points over the base model (39.22% vs. 7.23%).
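The summary figures follow directly from the mean accuracies in the table; this snippet simply recomputes them from those values.

```python
# Mean accuracies (%) copied from the GSM8K results table.
results = {
    "final_model": 39.22,
    "checkpoint-1000": 31.13,
    "checkpoint-2000": 33.46,
    "checkpoint-3000": 35.30,
    "checkpoint-4000": 39.93,
    "checkpoint-5000": 38.49,
    "checkpoint-6000": 39.15,
    "checkpoint-7000": 38.77,
    "base_model": 7.23,
}

checkpoints = {k: v for k, v in results.items() if k.startswith("checkpoint-")}
best = max(checkpoints, key=checkpoints.get)
gap_to_best = checkpoints[best] - results["final_model"]        # percentage points
gain_over_base = results["final_model"] - results["base_model"]  # percentage points
print(best, round(gap_to_best, 2), round(gain_over_base, 2))
# checkpoint-4000 0.71 31.99
```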


Usage

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "PursuitOfDataScience/llama3.2-3b-thinking"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

question = "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes muffins with 4. She sells the remainder at $2 per egg. How much does she make per day?"

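# Plain-text prompt matching the training format, with "<think>" pre-filled
prompt = (
    f"user: {question}\n"
    "Think briefly, then give the final numerical answer after ####.\n"
    "assistant: <think>\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
        # Avoids the missing-pad-token warning if the tokenizer defines no pad token
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode only the newly generated tokens (the continuation after "<think>")
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)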
print(response)

Citation

If you use this model, please cite the base model and training dataset:
