---
library_name: transformers
license: gpl-3.0
datasets:
- HuggingFaceTB/Countdown-Task-GOLD
language:
- en
metrics:
- accuracy
base_model:
- google/gemma-3-1b-it
pipeline_tag: text-generation
---

# Countdown Distillation on Gemma 3 1B

## Overview

`google/gemma-3-1b-it` is a compact student model and a good fit for distillation. We trained it to solve Countdown-style arithmetic tasks: given a set of numbers and basic operators `(+, -, *, /)`, the model must create an equation that reaches a target value. Example:

- Numbers: `[75, 80, 90, 24]`
- Target: `61`
- Solution: `90 - 80 + 75 - 24 = 61`
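A candidate solution like the one above can be checked mechanically. The following is a minimal sketch, not the scoring code used for this model: the helper name `check_countdown` is hypothetical, and it assumes that each given number may be used at most once and that only the four listed operators and parentheses are allowed.

```python
import ast
import operator
from collections import Counter
from fractions import Fraction

# Only the four operators named in the task description are accepted.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _evaluate(node, used):
    """Recursively evaluate an arithmetic AST, recording every integer literal."""
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return Fraction(node.value)  # exact arithmetic, no float rounding
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


def check_countdown(equation, nums, target):
    """Return True if the left-hand side reaches `target` using only `nums`."""
    expression = equation.split("=")[0].strip()
    used = []
    try:
        tree = ast.parse(expression, mode="eval")
        value = _evaluate(tree.body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False
    # Assumption: a number may be left unused but never used more often than given.
    over_used = Counter(used) - Counter(nums)
    return not over_used and value == target


print(check_countdown("90 - 80 + 75 - 24 = 61", [75, 80, 90, 24], 61))
```

Parsing with `ast` instead of calling `eval` keeps arbitrary model output from being executed.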

The student is fine-tuned on reasoning traces from the Countdown [dataset](https://huggingface.co/datasets/HuggingFaceTB/Countdown-Task-GOLD), which were generated by `Qwen2.5-7B-Instruct`, and learns to produce the final equation in the `<think>` and `<answer>` format.
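To score a completion, the final equation first has to be pulled out of the response. The sketch below assumes the answer is wrapped in a closed `<answer>...</answer>` pair; the exact tag layout and the helper name `extract_answer` are assumptions, not something this card specifies.

```python
import re

# Non-greedy match so that several answer blocks are captured separately.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(completion):
    """Return the text of the last <answer>...</answer> block, or None if absent."""
    matches = _ANSWER_RE.findall(completion)
    return matches[-1].strip() if matches else None


example = "<think>90 - 80 = 10, 10 + 75 = 85, 85 - 24 = 61</think>\n<answer>90 - 80 + 75 - 24 = 61</answer>"
print(extract_answer(example))
```

Taking the last match is a design choice: if the reasoning text happens to mention an earlier attempt in answer tags, the final one is treated as the model's answer.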

## Dataset

The training data contains verified Countdown solutions with the fields `target`, `nums`, and `messages`. The maximum sequence length is `1024` tokens, and the data is split `95/5` into training and validation sets:

- Train: `27,809` samples
- Validation: `1,464` samples
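The split sizes above are consistent with holding out 5% of `29,273` examples, rounded up. The sketch below reproduces that arithmetic on plain indices; the shuffling, the seed, and the helper name `split_indices` are illustrative assumptions, not the procedure actually used.

```python
import math
import random


def split_indices(n_samples, val_fraction=0.05, seed=0):
    """Shuffle indices and hold out a rounded-up fraction for validation."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # the seed is an arbitrary choice
    n_val = math.ceil(n_samples * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(27_809 + 1_464)
print(len(train_idx), len(val_idx))  # 27809 1464
```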

The token-length distribution:

![output_8_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/lL9oZ0bfrX71-lBEC9-PC.png)

## Training

Distillation was performed on the following hardware:

- GPU: NVIDIA GeForce RTX 5090
- VRAM: 31.35 GB
- CPU: Ryzen 9 9950X
- RAM: 64 GB

Training settings:

- max sequence length: `1024`
- batch size: `4`
- gradient accumulation: `8`
- epochs: `1`
- learning rate: `2e-4`
- warmup ratio: `0.1`
- scheduler: cosine
- optimiser: `adamw_torch`
- LoRA rank: `16`
- LoRA alpha: `32`
- LoRA dropout: `0.05`
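These settings imply an effective batch size of `32` and, for one epoch over `27,809` samples, roughly `870` optimiser steps. The sketch below works through that arithmetic and a linear-warmup, cosine-decay learning rate curve. It is an illustration of the schedule shape, not the trainer's own code, and the exact step count can differ by one depending on how the training framework rounds partial accumulation steps.

```python
import math

SAMPLES = 27_809
BATCH_SIZE = 4
GRAD_ACCUM = 8
EPOCHS = 1
BASE_LR = 2e-4
WARMUP_RATIO = 0.1

effective_batch = BATCH_SIZE * GRAD_ACCUM                    # 32
total_steps = math.ceil(SAMPLES / effective_batch) * EPOCHS  # 870
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)         # 87


def lr_at(step):
    """Learning rate at an optimiser step: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        return BASE_LR * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, f"{lr_at(step):.2e}")
```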

The best checkpoint is selected by validation loss.

## Loss and accuracy curves

The training and validation losses decrease steadily and then settle at a plateau.

![output_20_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/_o5DA1SvydEIoxFMe4Puk.png)

The same curves on a logarithmic scale:

![output_21_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/iF5cntbkGrLWdo4JafYhe.png)

Validation accuracy rises gradually, with small oscillations:

![output_22_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/rR8jUCL7lBRZ9SrPb0_LJ.png)

## Evaluation

Evaluation was run on the first `1,000` examples of the validation split with batch size `200`. The resulting accuracy is:

- Original model: `0.1310` (`131/1000`)
- After reasoning fine-tuning: `0.8200` (`820/1000`)
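The evaluation loop described above can be sketched as follows. This is a minimal outline under stated assumptions: `generate_batch` and `is_correct` are hypothetical callables standing in for the model's batched generation and for whatever correctness check was actually used, neither of which is specified in this card.

```python
def batched(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def evaluate_accuracy(examples, generate_batch, is_correct, limit=1000, batch_size=200):
    """Score the first `limit` examples in batches and return the fraction correct."""
    subset = examples[:limit]
    correct = 0
    for batch in batched(subset, batch_size):
        outputs = generate_batch(batch)  # hypothetical: one completion per example
        correct += sum(is_correct(ex, out) for ex, out in zip(batch, outputs))
    return correct / len(subset) if subset else 0.0
```

With `limit=1000` and `batch_size=200`, `generate_batch` is called five times, and `820` correct completions give `0.82`.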

## Inference

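Use the two cells below for inference. The first loads the base model and attaches the LoRA adapter; the second samples a completion for an example prompt.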

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "google/gemma-3-1b-it"
adapter_id = "pymlex/gemma3-1b-countdown"

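# Load the base model's tokenizer; fall back to EOS as the pad token if none is set.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Left padding keeps each prompt adjacent to its generated tokens in batched decoding.
tokenizer.padding_side = "left"

# Load the base model in bf16 when the GPU supports it, otherwise in fp16.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
    trust_remote_code=True,
)

# Attach the LoRA adapter and switch to inference mode.
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()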
```

```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=850):
    """Sample a completion for `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Drop the prompt tokens so that only the continuation is decoded.
    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()


sample_prompt = (
    "Using the numbers [78, 46, 93], create an equation that equals 61. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once."
)

output = generate_continuation(model, tokenizer, sample_prompt, max_new_tokens=850)
print("Prompt:")
print(sample_prompt)
print("\nGenerated continuation:")
print(output)
```
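The dataset stores each example as `messages`, which suggests, but does not confirm, that the adapter was trained on chat-formatted text. If that is the case, formatting the prompt with the tokenizer's chat template may match the training distribution more closely than the raw string used above. The helper below is a hedged sketch of that variant; `build_chat_inputs` is a hypothetical name, and whether it improves results for this adapter is untested here.

```python
def build_chat_inputs(tokenizer, user_prompt):
    """Wrap the prompt as a single user turn and tokenize it with the chat template."""
    messages = [{"role": "user", "content": user_prompt}]
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # append the marker that opens the model's turn
        tokenize=True,
        return_dict=True,            # return input_ids and attention_mask together
        return_tensors="pt",
    )
```

The result can be moved to the model's device with `.to(model.device)` and passed to `model.generate(**inputs, ...)` in place of the `tokenizer(prompt, ...)` call. Tokenizing inside `apply_chat_template` avoids adding the beginning-of-sequence token a second time.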