Gemma3Think - Teaching a Small Model to Think using a TPU
Training a small model to reason properly is a challenging task, particularly when working on unfamiliar hardware (TPU) using a new library (Tunix). This section details our systematic approach to this challenge.
Research Foundation
Our first step was to research current techniques for building SOTA reasoning models. A key insight from the DeepSeek-R1 paper is the two-phase training recipe: an initial SFT phase that teaches the model the correct output format and basic reasoning patterns, followed by RL to improve reasoning quality and structure.
Just as importantly, the Small Models Struggle to Learn from Strong Reasoners paper demonstrated that models under 3B parameters perform better with short, simple reasoning traces than with the long, verbose ones produced by larger models like DeepSeek-R1. This led to our crucial decision to synthetically generate most of our reasoning traces, since most existing reasoning datasets consist of long CoT.
Dataset Composition and Generation
We created a custom multi-domain dataset covering 9 categories with about 32k rows.
| Category | Strategy | Description | Rows |
|---|---|---|---|
| creative_writing | semantic | Open-ended creative tasks | 4709 |
| brainstorming | semantic | Ideation and suggestions | 3766 |
| summarization | semantic | Text condensation | 2188 |
| math | exact | Numerical problems | 7473 |
| classification | exact | Label prediction | 2136 |
| code | keyword | Python programming | 427 |
| information_extraction | keyword | Structured data extraction | 1506 |
| general_qa | hybrid | Factual questions | 7706 |
| science_qa | hybrid | Scientific concepts | 3000 |
The Strategy column determines how each category is scored in the reward function.
We used Qwen3-14B to synthetically generate the reasoning traces. In our testing it was the best model at producing high-quality short reasoning traces while remaining extremely cost-effective (~70 tok/s on a 5070 Ti).
This is an example of the prompt fed to Qwen:

```
You will be given a question and a passage of relevant information. Generate the reasoning traces that lead to the final answer.
Your reasoning should logically lead to your final answer; think about how you will structure the reasoning traces.
Your final answer should follow this format:
<reasoning>reasoning goes here</reasoning><answer>answer goes in here</answer>
```
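The tag format above is also what the reward code has to recover at training time. A minimal sketch of extracting the two fields (our own illustration, not the actual training code):

```python
import re

# Pull the <reasoning> and <answer> fields out of a model completion.
# The tag layout matches the prompt format; the helper is illustrative.
TAG_RE = re.compile(
    r"<reasoning>(.*?)</reasoning>\s*<answer>(.*?)</answer>", re.DOTALL
)

def parse_completion(text: str):
    """Return (reasoning, answer), or (None, None) if the format is broken."""
    m = TAG_RE.search(text)
    if m is None:
        return None, None
    return m.group(1).strip(), m.group(2).strip()

sample = "<reasoning>2+2 is basic addition.</reasoning><answer>4</answer>"
print(parse_completion(sample))  # ('2+2 is basic addition.', '4')
```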
Besides the 32k version, we also provide 20k and 10k versions of the dataset. The seed dataset can be accessed via the dataset card.
Each dataset also has an `sft` and an `rl` subset, comprising 20% and 80% of the full dataset respectively.
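The 20/80 split can be sketched as below. This is our illustration; the published subsets were produced offline, not with this exact code:

```python
import random

def split_sft_rl(rows, sft_frac=0.2, seed=0):
    """Shuffle and split rows into an SFT subset (sft_frac) and an RL subset."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # fixed seed for reproducibility
    cut = int(len(rows) * sft_frac)
    return rows[:cut], rows[cut:]

sft, rl = split_sft_rl(range(32000))
print(len(sft), len(rl))  # 6400 25600
```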
Two-Phase Training Pipeline
As mentioned above, we use a two-phase training pipeline: SFT first, then RL.
SFT Phase
Goal: familiarize the model with the <reasoning>...</reasoning><answer>...</answer> format and the type of responses we expect to see from it.
Training Config:
- Base model: Gemma3-1b-it
- Lora:
- Rank: 32
- Alpha: 64
- Learning Rate: 2e-4 with cosine decay and 10% warm-up
- Batch size: 2, Max length: 512 tokens
- Gradient Clipping: 1.0
- 1 epoch over the `main` dataset
Final checkpoint (safetensors): chimbiwide/gemma-3-1b-it-thinking-32k-sft-base
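The SFT schedule (2e-4 peak, cosine decay, 10% warm-up) can be sketched as a plain function. This is our illustration of the schedule shape, not the Tunix implementation:

```python
import math

def lr_at(step, total_steps, peak_lr=2e-4, warmup_frac=0.10):
    """Linear warm-up for the first warmup_frac of steps, then cosine decay to 0."""
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))

print(lr_at(0, 1000))    # small warm-up value
print(lr_at(100, 1000))  # peak: 2e-4
print(lr_at(999, 1000))  # near 0 at the end
```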
GRPO Phase
Goal: Refine reasoning quality and format with verifiable rewards.
Our SFT model becomes the base model that we fine-tune with GRPO.
Training Config:
- Base model: chimbiwide/gemma-3-1b-it-thinking-32k-sft-base
- Lora:
- Rank: 32
- Alpha: 64
- Learning Rate: 5e-6 with cosine decay and 15% warm-up
- KL penalty (β): 0.2 (prevents policy drift)
- Clip epsilon (ε): 0.1
- Generations per prompt: 4
- Max gradient norm: 0.5
- Temperature: 1.0
- 1 epoch over the `rl` dataset
Final model: chimbiwide/gemma-3-1b-it-thinking-32k-grpo-merged
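With 4 generations per prompt, GRPO scores each completion relative to its own group: the advantage is the reward normalized by the group mean and standard deviation. A minimal sketch of this group-relative baseline (our illustration, not Tunix's code):

```python
import statistics

def group_advantages(rewards, eps=1e-6):
    """Normalize each reward by the mean and std of its generation group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Four sampled completions for one prompt, scored by the reward function:
adv = group_advantages([0.9, 0.4, 0.4, 0.3])
print(adv)  # the best completion gets a positive advantage, the rest negative
```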
Reward Function Design
The most important part is our category-aware reward system, which determines the evaluation method of each prompt based on its category.
This is also our only reward function. We found that the model learns the format so well during the SFT phase that it gets it correct about 99% of the time, rendering separate format rewards useless.
For which strategy each category uses, check the table in the datasets section.
Category-Specific Answer Rewards
We have 6 individual scoring functions, each returning a value in [0, 1]:
- semantic_score: uses `all-mpnet-base-v2` sentence embeddings to compare the generated answer to the reference answer via cosine similarity. Our replacement for LLM-as-Judge: more predictable, quantifiable, and faster.
- exact_score: exact match for math problems
- keyword_score: removes common stopwords such as "a" and "the", then checks how many keywords in the generated answer match keywords from the reference. Used for tasks like coding, where we don't want to actually execute the code, since the model is small and can't reliably generate code that runs or compiles.
- length_score: checks whether the answer is far too long or far too short compared to the reference; either extreme often indicates a poor answer in creative-writing tasks.
- reasoning_score: checks for appropriate reasoning length (too long or too short indicates poor quality), with bonus points when the model uses structural words such as "first", "then", and "therefore".
- format_bonus: a small remaining bonus for following the format
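Two of these components can be sketched in a few lines. The stopword set, length thresholds, and structure-word list below are illustrative choices of ours, not the exact values used in training:

```python
STOPWORDS = {"a", "an", "the", "is", "are", "of", "to", "and", "in"}
STRUCTURE_WORDS = {"first", "then", "next", "therefore", "finally"}

def keyword_score(answer: str, reference: str) -> float:
    """Fraction of the reference's keywords that also appear in the answer."""
    ref_kw = {w for w in reference.lower().split() if w not in STOPWORDS}
    if not ref_kw:
        return 0.0
    ans_kw = {w for w in answer.lower().split() if w not in STOPWORDS}
    return len(ref_kw & ans_kw) / len(ref_kw)

def reasoning_score(reasoning: str, lo: int = 20, hi: int = 150) -> float:
    """Reward reasoning of moderate length, plus a bonus for structure words."""
    words = reasoning.lower().split()
    base = 1.0 if lo <= len(words) <= hi else 0.5
    bonus = 0.1 * sum(w.strip(".,") in STRUCTURE_WORDS for w in words)
    return min(base + bonus, 1.0)

print(keyword_score("sort the list then return it", "sort a list and return"))  # 1.0
```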
Semantic Strategy
- 40% Semantic + 15% length + 25% reasoning + 20% format
Exact Strategy
- 50% exact + 30% reasoning + 20% format
Hybrid Strategy
- 35% Semantic + 20% keyword + 25% reasoning + 20% format
Keyword Strategy
- 40% keyword + 20% semantic + 20% reasoning + 20% format
Compute Allocation
Because Kaggle TPU runtimes were difficult to obtain, most of our testing was done in Colab on a v6e-1 with High-RAM.
The SFT phase takes about 30 minutes on the 32k dataset; GRPO takes about 7 hours on the 28k RL subset, costing about 30 compute units per training run.
Design Decisions and Trade-offs
Our biggest problem was how to accurately grade open-ended tasks. We considered three approaches: LLM-as-Judge, BERTScore, and embedding models.
LLM-as-Judge
Theoretically, LLM-as-Judge is ideal for evaluating creative responses, but it is impractical for RL training on Kaggle:
- Tunix supports a limited number of models, and loading a large judge model would greatly slow down training.
- Unreliable: scores vary considerably across runs.
- Far too slow for a limited TPU runtime
BERTScore
BERTScore is another alternative to LLM-as-Judge that should be faster while returning numerical scores, making rewards easier to quantify. However, it fails to capture when the generated answer is completely different from the reference. It is also slow, in many cases taking over a second per comparison.
Embedding Models
With our prior experience using embedding models with vector databases, we decided to use all-mpnet-base-v2 to compute embedding vectors and compare the generated answer to the reference. This offered an excellent solution:
- Fast Inference (~20ms per comparison)
- Does not interfere with the TPU (runs on CPU)
- Reliably captures the semantic meaning
- Quantifiable returns, easy to integrate into a reward function.
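The comparison itself is just cosine similarity between the two embedding vectors. In the real pipeline the vectors come from all-mpnet-base-v2 (e.g. via the sentence-transformers library); the toy vectors below stand in so the sketch runs without a model download:

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

ref_vec = [0.2, 0.7, 0.1]   # embedding of the reference answer (toy values)
gen_vec = [0.25, 0.6, 0.1]  # embedding of the generated answer (toy values)
print(round(cosine(ref_vec, gen_vec), 3))  # close to 1.0 for similar answers
```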
Challenges and Solutions
As with any LLM training, wrong hyperparameters can lead to devastating results. With SFT, the variables are mostly limited to gradient clipping and the learning rate; RL/GRPO introduces many more variables that can corrupt the model. We experienced a few issues, listed here:
- Rather than the usual failure mode of a learning rate or gradient-clipping value that is too high and eventually causes NaNs, our original learning rate (1e-6) and gradient clipping (0.02) were far too low, causing the model to barely learn.
- We originally had three reward functions (one category reward, one complete-format reward, and one partial-format reward). However, the loss hovered around 0 and perplexity stayed at 1, indicating the model was almost perfectly predicting the tokens. We later found that our SFT base model was already following the format exactly most of the time, gaining most of the reward from the start, which caused it to stop exploring and learning. That is why we removed the two format rewards and kept only a 20% format bonus.
- Because the final submission must be an Orbax checkpoint rather than a full model, our strategy of SFT, then save as safetensors, then reload for GRPO would not work: the GRPO checkpoint would be applied to the base Gemma model instead of our SFT model. We then tried using a single LoRA adapter for the entire training run (both SFT and GRPO), but this caused catastrophic forgetting, and the final model could no longer even generate the tags consistently.
Limitations and Future Work
Limitations:
- There is room for improvement in the reward function for math and coding, which resulted in mediocre performance in these areas.
- Using semantic similarity is fast, but it may miss subtle creative nuances.
Future Improvements:
- Using a better model to generate the reasoning traces could improve performance.
- Experimenting with LLM-as-Judge in a less compute-constrained environment, perhaps fine-tuning a small model to act as the judge.
- Expand the diversity of the dataset with more high-quality data.