---
library_name: transformers
license: gpl-3.0
datasets:
- HuggingFaceTB/Countdown-Task-GOLD
language:
- en
metrics:
- accuracy
base_model:
- google/gemma-3-1b-it
pipeline_tag: text-generation
---
# Countdown Distillation on Gemma 3 1B
## Overview
`google/gemma-3-1b-it` is a compact student model and a good fit for distillation. We trained it to solve Countdown-style arithmetic tasks: given a set of numbers and basic operators `(+, -, *, /)`, the model must create an equation that reaches a target value. Example:
- Numbers: `[75, 80, 90, 24]`
- Target: `61`
- Solution: `90 - 80 + 75 - 24 = 61`
The student is supervised with reasoning traces from the Countdown [dataset](https://huggingface.co/datasets/HuggingFaceTB/Countdown-Task-GOLD), which were generated by `Qwen2.5-7B-Instruct`. It learns to respond in `<think>` / `<answer>` format, ending with the final equation.
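The task rules above can be checked mechanically. The following is a minimal sketch, not the verifier used for this model: `check_countdown` is a hypothetical helper, and it assumes that a number may be left unused as long as none is used more often than it appears in the list.

```python
import ast
import operator
from collections import Counter
from fractions import Fraction

# Only the four operators allowed by the task.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _evaluate(node, used):
    """Evaluate an arithmetic AST exactly, recording every integer literal in `used`."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return Fraction(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


def check_countdown(equation, nums, target):
    """Return True if `equation` reaches `target` using each number at most once."""
    expression = equation.split("=")[0].strip()  # ignore an optional "= result" suffix
    used = []
    try:
        tree = ast.parse(expression, mode="eval")
        value = _evaluate(tree.body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False
    # Assumption: numbers may be left unused, but none may be reused.
    if Counter(used) - Counter(nums):
        return False
    return value == target


print(check_countdown("90 - 80 + 75 - 24 = 61", [75, 80, 90, 24], 61))  # True
```

Using `Fraction` keeps division exact, so a result such as `1/3 * 3` is not rejected because of floating-point rounding.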
## Dataset
The training data contains verified Countdown solutions with the following fields: `target`, `nums`, and `messages`. The maximum sequence length is `1024` tokens, and the data is split `95/5` into train and validation:
- Train: `27,809` samples
- Validation: `1,464` samples
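As a quick sanity check, the two counts are consistent with a 5% validation split rounded up. The sketch below only reproduces the arithmetic; `split_sizes` is a hypothetical helper, and the rounding convention is an assumption rather than a statement about how the split was actually made.

```python
def split_sizes(total, val_percent=5):
    """Return (train, validation) sizes, rounding the validation share up."""
    val = -(-total * val_percent // 100)  # ceiling division on integers
    return total - val, val


total = 27_809 + 1_464  # samples listed above
train, val = split_sizes(total)
print(train, val)  # 27809 1464
```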
The token-length distribution:
![output_8_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/lL9oZ0bfrX71-lBEC9-PC.png)
## Training
Distillation was performed with the following setup:
- GPU: NVIDIA GeForce RTX 5090
- VRAM: 31.35 GB
- CPU: Ryzen 9 9950X
- RAM: 64 GB
Training settings:
- max sequence length: `1024`
- batch size: `4`
- gradient accumulation: `8`
- epochs: `1`
- learning rate: `2e-4`
- warmup ratio: `0.1`
- scheduler: cosine
- optimiser: `adamw_torch`
- LoRA rank: `16`
- LoRA alpha: `32`
- LoRA dropout: `0.05`
The best checkpoint is selected by validation loss.
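To make the settings above concrete, here is a small sketch of what they imply for the effective batch size, the number of optimiser steps, and the learning rate over time. It is plain Python rather than the training script; the step rounding and the exact warmup and cosine formula are assumptions modelled on the usual linear-warmup, cosine-decay schedule.

```python
import math

TRAIN_SAMPLES = 27_809
BATCH_SIZE = 4
GRAD_ACCUM = 8
EPOCHS = 1
PEAK_LR = 2e-4
WARMUP_RATIO = 0.1
LORA_RANK = 16
LORA_ALPHA = 32

# Samples contributing to each optimiser update.
effective_batch = BATCH_SIZE * GRAD_ACCUM

# Assumption: a final partial batch still counts as one update.
total_steps = math.ceil(TRAIN_SAMPLES / effective_batch) * EPOCHS
warmup_steps = round(total_steps * WARMUP_RATIO)

# Standard LoRA scales the low-rank update by alpha / rank.
lora_scale = LORA_ALPHA / LORA_RANK


def lr_at(step):
    """Learning rate at `step`: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return PEAK_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return PEAK_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


print(effective_batch, total_steps, warmup_steps, lora_scale)
for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, lr_at(step))
```

Under these assumptions, one epoch is roughly 870 updates of 32 samples each, with the learning rate peaking at `2e-4` after the first 10% of steps.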
## Loss and accuracy curves
The training and validation losses show a steady downward trend and then settle near a stable plateau.
![output_20_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/_o5DA1SvydEIoxFMe4Puk.png)
Also available as a logarithmic plot:
![output_21_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/iF5cntbkGrLWdo4JafYhe.png)
Validation accuracy rises gradually, with small oscillations:
![output_22_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/rR8jUCL7lBRZ9SrPb0_LJ.png)
## Evaluation
Evaluation was run on the first `1,000` examples of the validation split with batch size `200`. Accuracy:
- Original model: `0.1310` (`131/1000`)
- After reasoning fine-tuning: `0.8200` (`820/1000`)
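Scoring of this kind can be sketched as follows. This is an illustration, not the evaluation code behind the numbers above: `extract_answer` and `accuracy` are hypothetical helpers, the regular expression assumes the equation is wrapped in `<answer>...</answer>`, and the string comparison in the usage example is a placeholder for a real arithmetic check.

```python
import re

# Assumption: the final equation is wrapped in <answer>...</answer>.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(completion):
    """Return the content of the last <answer> tag, or None if there is none."""
    matches = ANSWER_RE.findall(completion)
    return matches[-1].strip() if matches else None


def accuracy(completions, is_correct):
    """Fraction of completions whose extracted answer passes `is_correct(index, answer)`."""
    hits = 0
    for index, completion in enumerate(completions):
        answer = extract_answer(completion)
        if answer is not None and is_correct(index, answer):
            hits += 1
    return hits / len(completions) if completions else 0.0


completions = [
    "<think>93 - 78 = 15, 15 + 46 = 61</think><answer>93 - 78 + 46</answer>",
    "<think>no valid combination found</think>",
]
expected = ["93 - 78 + 46", "1 + 1"]  # placeholder references for the demo
print(accuracy(completions, lambda i, answer: answer == expected[i]))  # 0.5
```

A completion without an `<answer>` tag counts as incorrect, so formatting failures lower the score in the same way as wrong equations.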
## Inference
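Use the two cells below for inference. The first loads the base model and attaches the adapter; the second generates a solution for a sample prompt.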
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "google/gemma-3-1b-it"
adapter_id = "pymlex/gemma3-1b-countdown"

# Load the tokenizer and make sure a pad token is defined.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Prefer bfloat16 when the GPU supports it, otherwise fall back to float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    trust_remote_code=True,
)

# Attach the fine-tuned adapter to the base model.
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()
```
```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=850):
    """Sample a continuation for `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            repetition_penalty=1.05,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Drop the prompt tokens so that only the continuation is decoded.
    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()


sample_prompt = (
    "Using the numbers [78, 46, 93], create an equation that equals 61. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once."
)

output = generate_continuation(model, tokenizer, sample_prompt, max_new_tokens=850)
print("Prompt:")
print(sample_prompt)
print("\nGenerated continuation:")
print(output)
```