
Countdown Distillation on Gemma 3 1B

Overview

google/gemma-3-1b-it is a compact student model and a good fit for distillation. We trained it to solve Countdown-style arithmetic tasks: given a set of numbers and basic operators (+, -, *, /), the model must create an equation that reaches a target value. Example:

  • Numbers: [75, 80, 90, 24]
  • Target: 61
  • Solution: 90 - 80 + 75 - 24 = 61
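A valid solution must evaluate to the target and use each given number at most once. A minimal checker for this (the function name and structure are illustrative, not part of the training code):

```python
import ast

def check_solution(expr: str, nums: list[int], target: int) -> bool:
    """Validate a Countdown equation: correct value, each number used at most once."""
    tree = ast.parse(expr, mode="eval")
    # Collect every numeric literal appearing in the expression.
    used = [n.value for n in ast.walk(tree) if isinstance(n, ast.Constant)]
    pool = list(nums)
    for n in used:
        if n not in pool:
            return False  # number not available, or used more often than given
        pool.remove(n)
    return eval(compile(tree, "<expr>", "eval")) == target

print(check_solution("90 - 80 + 75 - 24", [75, 80, 90, 24], 61))  # True
```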

The student is supervised on answers generated by Qwen3-4B-Instruct-2507 for problems from the Countdown dataset, and learns to produce the final equation in the <answer> format.

Dataset

The training data contains verified Countdown solutions with the following fields: target, nums, and messages. Samples are capped at a maximum sequence length of 1024 tokens, and the data is split 95/5 into train and validation:

  • Train: 20,216 samples
  • Validation: 1,064 samples
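The 95/5 split (1,064 of 21,280 samples is exactly 5%) can be reproduced with a simple shuffle-and-slice; this is a sketch with toy records using the card's field names (target, nums, messages), not the actual preprocessing script:

```python
import random

def split_dataset(samples, val_frac=0.05, seed=42):
    """Shuffle and split verified Countdown samples 95/5 into train/validation."""
    samples = list(samples)
    rng = random.Random(seed)
    rng.shuffle(samples)
    n_val = int(len(samples) * val_frac)
    return samples[n_val:], samples[:n_val]

# Toy records with the card's fields: target, nums, messages.
data = [{"target": i, "nums": [i], "messages": []} for i in range(100)]
train, val = split_dataset(data)
print(len(train), len(val))  # 95 5
```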

The token-length distribution:

(Figure: token-length distribution of the training samples.)

Training

Distillation was performed with the following setup:

  • GPU: NVIDIA GeForce RTX 5090
  • VRAM: 31.35 GB
  • CPU: Ryzen 9 9950X
  • RAM: 64 GB

Training settings:

  • max sequence length: 1024
  • batch size: 4
  • gradient accumulation: 8
  • epochs: 1
  • learning rate: 2e-4
  • warmup ratio: 0.1
  • scheduler: cosine
  • optimiser: adamw_torch
  • LoRA rank: 16
  • LoRA alpha: 32
  • LoRA dropout: 0.05
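The hyperparameters above map directly onto a peft LoraConfig and transformers TrainingArguments. The fragment below is a sketch of that configuration, not the exact training script; output_dir is a placeholder:

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA adapter matching the settings listed above.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

# Trainer hyperparameters matching the settings listed above.
training_args = TrainingArguments(
    output_dir="gemma3-1b-countdown",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    eval_strategy="steps",
    load_best_model_at_end=True,      # keep the checkpoint with the
    metric_for_best_model="eval_loss",  # lowest validation loss
    greater_is_better=False,
)
```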

The best checkpoint is selected by validation loss.

Loss curves

The training and validation losses show a steady downward trend and then settle near a stable plateau:

(Figure: training and validation loss curves.)

A log-scaled variant of the same plot:

(Figure: training and validation loss curves, log scale.)

Evaluation

Accuracy was measured on the validation split (the base model on a 1,000-example subset; the fine-tuned model on the full 1,064 examples):

  • Base model: 0.1310 (131/1000)
  • Fine-tuned model (non-reasoning): 0.8412 (895/1064)
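Accuracy here means the equation extracted from the model's <answer> block evaluates to the target. A minimal scoring sketch under that assumption (the helper names are illustrative):

```python
import re

def extract_answer(text: str):
    """Pull the equation out of the model's <answer>...</answer> block."""
    m = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return m.group(1).strip() if m else None

def is_correct(text: str, target: int) -> bool:
    """Score one generation: the extracted expression must evaluate to the target."""
    ans = extract_answer(text)
    if ans is None:
        return False
    lhs = ans.split("=")[0]  # tolerate a trailing "= 61"
    if not re.fullmatch(r"[\d\s+\-*/().]+", lhs):
        return False  # reject anything that is not pure arithmetic
    try:
        return abs(eval(lhs) - target) < 1e-6
    except (SyntaxError, ZeroDivisionError):
        return False

print(is_correct("<answer>90 - 80 + 75 - 24 = 61</answer>", 61))  # True
```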

Inference

Use the following two cells for inference: the first loads the base model and attaches the LoRA adapter, the second generates a continuation for a sample prompt.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "google/gemma-3-1b-it"
adapter_id = "pymlex/gemma3-1b-countdown"

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()

def generate_continuation(model, tokenizer, prompt, max_new_tokens=850):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()


sample_prompt = (
    "Using the numbers [78, 46, 93], create an equation that equals 61. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once."
)

output = generate_continuation(model, tokenizer, sample_prompt, max_new_tokens=850)
print("Prompt:")
print(sample_prompt)
print("\nGenerated continuation:")
print(output)