---
language:
- en
license: llama3.2
base_model: meta-llama/Llama-3.2-3B
tags:
- llama
- chain-of-thought
- reasoning
- sft
- math
datasets:
- PursuitOfDataScience/0.5M-thinking
---

# Llama-3.2-3B-Thinking

A Llama 3.2 3B model fine-tuned with **Chain-of-Thought (CoT) Supervised Fine-Tuning (SFT)**
to produce explicit step-by-step reasoning enclosed in `<think>` … `</think>` tags before
giving a final answer.

---

## Base Model

[meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)

---

## Training Data

[PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)

~500K multi-turn examples in which the assistant turn contains structured chain-of-thought
reasoning wrapped in `<think>` / `</think>` tags, followed by the final answer.


---

## SFT Process

### Overview

The model was trained for **1 epoch** on the 0.5M-thinking dataset on a single **NVIDIA H100**
GPU using the Hugging Face `Trainer`.

### Key Training Choices

| Hyperparameter | Value |
|---|---|
| Base model | meta-llama/Llama-3.2-3B |
| Precision | bfloat16 |
| Max context length | 4096 tokens |
| Per-device batch size | 8 |
| Gradient accumulation | 8 steps (effective batch ≈ 64) |
| Optimizer | AdamW (fused) |
| Learning rate | 2e-5 |
| LR scheduler | Cosine |
| Warmup steps | 100 |
| Epochs | 1 |
| Attention | SDPA (no Flash-Attention) |
| `torch.compile` | ✅ (Hopper / H100) |
| Gradient checkpointing | ✅ |
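
The table above maps onto Hugging Face `TrainingArguments` roughly as follows. This is a hedged sketch, not the actual training script: the output directory is a placeholder, and SDPA attention is selected when loading the model (`attn_implementation="sdpa"`), not here.

```python
from transformers import TrainingArguments

# Sketch of the hyperparameters listed above; paths are illustrative.
training_args = TrainingArguments(
    output_dir="llama3.2-3b-thinking-sft",  # placeholder
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,          # effective batch ≈ 64 on one GPU
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    optim="adamw_torch_fused",
    bf16=True,
    gradient_checkpointing=True,
    torch_compile=True,                     # Hopper / H100
)
```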

### Label Masking

Only the **assistant turn** (everything after the hardcoded `<think>` prefix) contributes
to the cross-entropy loss. The user / prompt tokens are masked with `-100` so the model
learns to reason, not to parrot the prompt.
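
A minimal illustration of this masking scheme, assuming the tokenized prompt length is known; the function name is ours, not from the training code:

```python
IGNORE_INDEX = -100  # ignored by PyTorch's cross-entropy loss

def mask_prompt_labels(input_ids, prompt_len):
    """Copy input_ids into labels, masking the prompt tokens so only
    the assistant turn (reasoning + answer) contributes to the loss."""
    labels = list(input_ids)
    for i in range(min(prompt_len, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels
```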

### Filtering

Examples whose full tokenized length (prompt + complete assistant response + special tokens)
exceeds **4096 tokens** are **filtered out** rather than truncated, ensuring the model never
trains on a response with a cut-off `</think>` or missing final answer.
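
In pseudocode-like Python, the filter keeps an example only when its full tokenized length fits the context window; the token counts below are made up for illustration:

```python
MAX_LEN = 4096

def keep_example(token_count, max_len=MAX_LEN):
    """Drop (rather than truncate) any example whose full tokenized
    length exceeds the context window."""
    return token_count <= max_len

# Hypothetical precomputed lengths for three examples.
examples = [{"tokens": 1200}, {"tokens": 4096}, {"tokens": 5000}]
kept = [ex for ex in examples if keep_example(ex["tokens"])]
```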

---

## Prompt Format

The model was trained using a **plain-text role-prefixed** format with `<think>` hardcoded
into the prompt so the model always begins its response with chain-of-thought reasoning.

### Training format

```
user: {question}
assistant: <think>
{chain-of-thought reasoning}
</think>
{final answer}
```

### Inference format (recommended)

```
user: {question}
Think briefly, then give the final numerical answer after ####.
assistant: <think>
```

The model will complete the `<think>` block and then produce the final answer after `</think>`.
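
A small helper for splitting such a completion into the reasoning and the final answer; this is a hypothetical convenience function, not part of the released code:

```python
def split_response(completion):
    """Split a completion (which begins inside the <think> block)
    into (reasoning, final_answer) at the closing </think> tag."""
    reasoning, sep, answer = completion.partition("</think>")
    if not sep:  # model never closed the think block
        return completion.strip(), ""
    return reasoning.strip(), answer.strip()
```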

---

## GSM8K Benchmark Results (Pass@1)

Evaluated on the full GSM8K test set (1,319 examples) using **vLLM** with:
- Temperature: 0.6
- Top-p: 0.9
- Max new tokens: 4096
- **3 independent runs** per checkpoint; results reported as mean ± std.
- **Strict numerical comparison**: gold `####` answer vs. extracted prediction.
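
The strict comparison can be sketched as follows; this mirrors the GSM8K `####` convention but is our illustration, not the exact evaluation script:

```python
import re

def extract_gsm8k_answer(text):
    """Extract the number after the last '####' marker, stripping
    thousands-separator commas (GSM8K convention)."""
    if "####" not in text:
        return None
    tail = text.rsplit("####", 1)[1]
    match = re.search(r"-?[\d,]*\.?\d+", tail)
    if match is None:
        return None
    return float(match.group().replace(",", ""))

def is_correct(prediction, gold):
    """Strict numerical match between extracted prediction and gold answer."""
    p, g = extract_gsm8k_answer(prediction), extract_gsm8k_answer(gold)
    return p is not None and g is not None and p == g
```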

| Model | Total | Mean Acc | Std |
|---|---|---|---|
| final_model | 1319 | 39.22% | 0.98% |
| checkpoint-1000 | 1319 | 31.13% | 0.72% |
| checkpoint-2000 | 1319 | 33.46% | 0.36% |
| checkpoint-3000 | 1319 | 35.30% | 0.13% |
| checkpoint-4000 | 1319 | 39.93% | 0.95% |
| checkpoint-5000 | 1319 | 38.49% | 0.93% |
| checkpoint-6000 | 1319 | 39.15% | 1.16% |
| checkpoint-7000 | 1319 | 38.77% | 0.48% |
| base_model (no SFT) | 1319 | 7.23% | 0.70% |

> **Best checkpoint**: checkpoint-4000 at **39.93%** mean accuracy.
> **Final model**: **39.22%**, within 1 percentage point of the best checkpoint.
> SFT improved GSM8K accuracy by **~32 percentage points** over the base model.

---

## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "PursuitOfDataScience/llama3.2-3b-thinking"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

question = "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes muffins with 4. She sells the remainder at $2 per egg. How much does she make per day?"

prompt = (
    f"user: {question}\n"
    "Think briefly, then give the final numerical answer after ####.\n"
    "assistant: <think>\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
    )

response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
```

---

## Citation

If you use this model, please cite the base model and training dataset:

- Base model: [meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
- Training data: [PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)