Countdown Distillation on Gemma 3 1B
Overview
google/gemma-3-1b-it is a compact student model and a good fit for distillation. We trained it to solve Countdown-style arithmetic tasks: given a set of numbers and basic operators (+, -, *, /), the model must create an equation that reaches a target value. Example:
- Numbers: [75, 80, 90, 24]
- Target: 61
- Solution: 90 - 80 + 75 - 24 = 61
The student is supervised on reasoning traces generated by Qwen2.5-7B-Instruct over the Countdown dataset, and learns to produce its reasoning and the final equation in `<think>` and `<answer>` format.
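A Countdown solution is easy to check mechanically: the equation must use each provided number at most once and evaluate to the target. The following is a minimal, illustrative verifier (not part of the training pipeline); `verify_solution` is a hypothetical helper name:

```python
import re

ALLOWED_CHARS = set("+-*/() ")

def verify_solution(expr, nums, target, tol=1e-6):
    """Check that expr uses each provided number at most once and equals target."""
    # Reject anything other than digits, operators, parentheses and spaces.
    if not all(ch.isdigit() or ch in ALLOWED_CHARS for ch in expr):
        return False
    # Every number token must come from the provided pool, used at most once.
    pool = list(nums)
    for tok in re.findall(r"\d+", expr):
        n = int(tok)
        if n not in pool:
            return False
        pool.remove(n)
    try:
        # eval is safe here because the character set was restricted above.
        return abs(eval(expr) - target) < tol
    except ZeroDivisionError:
        return False
```

For the example above, `verify_solution("90 - 80 + 75 - 24", [75, 80, 90, 24], 61)` returns `True`, while an equation that reuses a number or misses the target is rejected.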
Dataset
The training data contains verified Countdown solutions with the following fields: `target`, `nums`, and `messages`. The maximum sequence length is 1024 and the train/validation split is 95/5:
- Train: 27,809 samples
- Validation: 1,464 samples
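A minimal sketch of a 95/5 split, assuming the verified solutions live in a Python sequence; the real pipeline may use `datasets.Dataset.train_test_split`, but the stdlib suffices to illustrate the idea:

```python
import random

def split_dataset(samples, val_fraction=0.05, seed=42):
    """Shuffle deterministically, then split into train/validation parts."""
    samples = list(samples)
    random.Random(seed).shuffle(samples)
    n_val = int(len(samples) * val_fraction)
    return samples[n_val:], samples[:n_val]

# Illustration with the total sample count from this dataset.
train, val = split_dataset(range(27_809 + 1_464))
```

The exact train/validation counts depend on rounding and on how the real split was drawn, so the numbers above are illustrative only.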
The token-length distribution:
Training
Distillation was performed with the following setup:
- GPU: NVIDIA GeForce RTX 5090
- VRAM: 31.35 GB
- CPU: Ryzen 9 9950X
- RAM: 64 GB
Training settings:
- max sequence length: 1024
- batch size: 4
- gradient accumulation: 8
- epochs: 1
- learning rate: 2e-4
- warmup ratio: 0.1
- scheduler: cosine
- optimiser: adamw_torch
- LoRA rank: 16
- LoRA alpha: 32
- LoRA dropout: 0.05
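As a sketch, the hyperparameters above map onto `peft.LoraConfig` and `transformers.TrainingArguments` roughly as follows. The `target_modules` list and the output directory are assumptions (the attention projections commonly targeted for LoRA), not taken from the actual training script; exact argument names can vary across `transformers` versions:

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA settings from the list above; target_modules is an assumption.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

training_args = TrainingArguments(
    output_dir="gemma3-1b-countdown",       # assumed name
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,          # effective batch size: 4 * 8 = 32
    num_train_epochs=1,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    load_best_model_at_end=True,            # keep the best checkpoint
    metric_for_best_model="eval_loss",      # selected by validation loss
)
```

With gradient accumulation, the effective batch size is 4 × 8 = 32 sequences per optimiser step.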
The best checkpoint is selected by validation loss.
Loss and accuracy curves
The training and validation losses show a steady downward trend and then settle near a stable plateau.
Also available as a logarithmic plot:
Validation accuracy gradually grows with small oscillations:
Evaluation
Validation was run on the first 1,000 examples of the validation split with batch size 200. The validation accuracy is:
- Original model: 0.1310 (131/1000)
- After reasoning fine-tuning: 0.8200 (820/1000)
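Accuracy here means the fraction of completions whose `<answer>` equation actually reaches the target. A minimal sketch of such a scorer, assuming completions carry the `<answer>...</answer>` format described above (`extract_equation` and `accuracy` are hypothetical helper names, and a fuller scorer would also check that each number is used at most once):

```python
import re

def extract_equation(completion):
    """Pull the equation out of the model's <answer>...</answer> block."""
    m = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    return m.group(1).strip() if m else None

def accuracy(completions, problems, tol=1e-6):
    """Fraction of completions whose answer equation hits the target."""
    correct = 0
    for completion, (nums, target) in zip(completions, problems):
        expr = extract_equation(completion)
        if expr is None:
            continue
        # Keep only safe arithmetic characters before eval.
        if not re.fullmatch(r"[\d+\-*/() .=]+", expr):
            continue
        expr = expr.split("=")[0]  # tolerate "90 - 80 + 75 - 24 = 61"
        try:
            if abs(eval(expr) - target) < tol:
                correct += 1
        except (ZeroDivisionError, SyntaxError):
            continue
    return correct / len(problems)
```

For example, a completion ending in `<answer>90 - 80 + 75 - 24 = 61</answer>` for target 61 counts as correct; a missing or wrong answer block counts as incorrect.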
Inference
Use the following two cells for inference: the first loads the base model and attaches the adapter, the second generates a continuation for a sample prompt.
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "google/gemma-3-1b-it"
adapter_id = "pymlex/gemma3-1b-countdown"

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Load the base model in bf16 where supported, then attach the LoRA adapter.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()
```
```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=850):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode only the newly generated tokens, dropping the prompt.
    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()

sample_prompt = (
    "Using the numbers [78, 46, 93], create an equation that equals 61. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once."
)

output = generate_continuation(model, tokenizer, sample_prompt, max_new_tokens=850)
print("Prompt:")
print(sample_prompt)
print("\nGenerated continuation:")
print(output)
```