---
language:
- en
license: llama3.2
base_model: meta-llama/Llama-3.2-3B
tags:
- llama
- chain-of-thought
- reasoning
- sft
- math
datasets:
- PursuitOfDataScience/0.5M-thinking
---

# Llama-3.2-3B-Thinking

A Llama 3.2 3B model fine-tuned with **Chain-of-Thought (CoT) Supervised Fine-Tuning (SFT)** to produce explicit step-by-step reasoning enclosed in `<think>` … `</think>` tags before giving a final answer.

---

## Base Model

[meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)

---

## Training Data

[PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)

~500K multi-turn examples in which the assistant turn contains structured chain-of-thought reasoning wrapped in `<think>` / `</think>` tags, followed by the final answer.

---

## SFT Process

### Overview

The model was trained for **1 epoch** on the 0.5M-thinking dataset on a single **NVIDIA H100** GPU using the Hugging Face `Trainer`.

### Key Training Choices

| Hyperparameter | Value |
|---|---|
| Base model | meta-llama/Llama-3.2-3B |
| Precision | bfloat16 |
| Max context length | 4096 tokens |
| Per-device batch size | 8 |
| Gradient accumulation | 8 steps (effective batch ≈ 64) |
| Optimizer | AdamW (fused) |
| Learning rate | 2e-5 |
| LR scheduler | Cosine |
| Warmup steps | 100 |
| Epochs | 1 |
| Attention | SDPA (no Flash-Attention) |
| `torch.compile` | ✅ (Hopper / H100) |
| Gradient checkpointing | ✅ |

### Label Masking

Only the **assistant turn** (everything after the hardcoded `<think>` prefix) contributes to the cross-entropy loss. The user/prompt tokens are masked with `-100` so the model learns to reason, not to parrot the prompt.

### Filtering

Examples whose full tokenized length (prompt + complete assistant response + special tokens) exceeds **4096 tokens** are **filtered out** rather than truncated, ensuring the model never trains on a response with a cut-off `</think>` or a missing final answer.
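The label-masking step described above can be sketched as follows. This is an illustrative reconstruction, not the actual preprocessing code (which is not published); `mask_prompt_labels` and `prompt_len` are hypothetical names:

```python
def mask_prompt_labels(input_ids, prompt_len, ignore_index=-100):
    """Copy input_ids to labels, masking every prompt token with -100
    so only the assistant turn contributes to the cross-entropy loss.
    (Hypothetical helper; prompt_len = number of user/prompt tokens.)"""
    labels = list(input_ids)
    for i in range(min(prompt_len, len(labels))):
        labels[i] = ignore_index  # -100 is ignored by PyTorch CrossEntropyLoss
    return labels
```

With `ignore_index=-100`, PyTorch's cross-entropy loss skips the masked positions entirely, which is how `Trainer` sees loss only on the assistant's reasoning and answer tokens.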
---

## Prompt Format

The model was trained using a **plain-text role-prefixed** format with `<think>` hardcoded into the prompt so the model always begins its response with chain-of-thought reasoning.

### Training format

```
user: {question}
assistant: <think>
{chain-of-thought reasoning}
</think>
{final answer}
```

### Inference format (recommended)

```
user: {question}
Think briefly, then give the final numerical answer after ####.
assistant: <think>
```

The model will complete the `<think>` block and then produce the final answer after `</think>`.

---

## GSM8K Benchmark Results (Pass@1)

Evaluated on the full GSM8K test set (1,319 examples) using **vLLM** with:

- Temperature: 0.6
- Top-p: 0.9
- Max new tokens: 4096
- **3 independent runs** per checkpoint; results reported as mean ± std.
- **Strict numerical comparison**: gold `####` answer vs. extracted prediction.

| Model | Total | Mean Acc | Std |
|---|---|---|---|
| final_model | 1319 | 39.22% | 0.98% |
| checkpoint-1000 | 1319 | 31.13% | 0.72% |
| checkpoint-2000 | 1319 | 33.46% | 0.36% |
| checkpoint-3000 | 1319 | 35.30% | 0.13% |
| checkpoint-4000 | 1319 | 39.93% | 0.95% |
| checkpoint-5000 | 1319 | 38.49% | 0.93% |
| checkpoint-6000 | 1319 | 39.15% | 1.16% |
| checkpoint-7000 | 1319 | 38.77% | 0.48% |
| base_model (no SFT) | 1319 | 7.23% | 0.70% |

> **Best checkpoint**: checkpoint-4000 at **39.93%** mean accuracy.
> **Final merged model**: **39.22%**, within 1 pp of the best checkpoint.
> SFT improved GSM8K accuracy by **~32 percentage points** over the base model.

---

## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "PursuitOfDataScience/llama3.2-3b-thinking"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

question = "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast and bakes muffins with 4. She sells the remainder at $2 per egg. How much does she make per day?"
prompt = (
    f"user: {question}\n"
    f"Think briefly, then give the final numerical answer after ####.\n"
    f"assistant: <think>\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,
    )

response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)
print(response)
```

---

## Citation

If you use this model, please cite the base model and training dataset:

- Base model: [meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
- Training data: [PursuitOfDataScience/0.5M-thinking](https://huggingface.co/datasets/PursuitOfDataScience/0.5M-thinking)
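The strict `####` comparison used in the GSM8K evaluation can be sketched roughly as below. The actual evaluation script is not published; this is an illustrative version, with `extract_answer` and `is_correct` as hypothetical helper names:

```python
import re

def extract_answer(text):
    """Pull the last number following a '####' marker, GSM8K-style.
    Returns None when no marker/number is found."""
    matches = re.findall(r"####\s*([-+]?[\d,]*\.?\d+)", text)
    if not matches:
        return None
    # GSM8K gold answers may contain thousands separators, e.g. "#### 1,234".
    return float(matches[-1].replace(",", ""))

def is_correct(prediction_text, gold_text):
    """Strict numerical comparison: both sides must yield a number
    and the numbers must match exactly."""
    pred = extract_answer(prediction_text)
    gold = extract_answer(gold_text)
    return pred is not None and gold is not None and pred == gold
```

Counting `is_correct` over all 1,319 test examples gives one run's accuracy; the table above reports the mean and standard deviation over three such runs.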