---
library_name: transformers
license: gpl-3.0
datasets:
- HuggingFaceTB/Countdown-Task-GOLD
language:
- en
metrics:
- accuracy
base_model:
- google/gemma-3-1b-it
pipeline_tag: text-generation
---

# Countdown Distillation on Gemma 3 1B

## Overview

`google/gemma-3-1b-it` is a compact student model and a good fit for distillation. We trained it to solve Countdown-style arithmetic tasks: given a set of numbers and the basic operators `(+, -, *, /)`, the model must create an equation that reaches a target value.

Example:

- Numbers: `[75, 80, 90, 24]`
- Target: `61`
- Solution: `90 - 80 + 75 - 24 = 61`

The student is supervised on reasoning traces generated by `Qwen2.5-7B-Instruct` from the Countdown [dataset](https://huggingface.co/datasets/HuggingFaceTB/Countdown-Task-GOLD) and learns to produce the final equation in `` and `` format.

## Dataset

The training data contains verified Countdown solutions with the following fields: `target`, `nums`, and `messages`.

The maximum sequence length is `1024` tokens and the train/validation split is `95/5`:

- Train: `27,809` samples
- Validation: `1,464` samples

The token-length distribution:

![output_8_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/lL9oZ0bfrX71-lBEC9-PC.png)

## Training

Distillation was performed on the following hardware:

- GPU: NVIDIA GeForce RTX 5090
- VRAM: 31.35 GB
- CPU: Ryzen 9 9950X
- RAM: 64 GB

Training settings:

- max sequence length: `1024`
- batch size: `4`
- gradient accumulation: `8`
- epochs: `1`
- learning rate: `2e-4`
- warmup ratio: `0.1`
- scheduler: cosine
- optimiser: `adamw_torch`
- LoRA rank: `16`
- LoRA alpha: `32`
- LoRA dropout: `0.05`

The best checkpoint is selected by validation loss.

## Loss and accuracy curves

The training and validation losses decrease steadily and then settle at a stable plateau.

![output_20_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/_o5DA1SvydEIoxFMe4Puk.png)

The same curves on a logarithmic scale:

![output_21_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/iF5cntbkGrLWdo4JafYhe.png)

Validation accuracy rises gradually, with small oscillations:

![output_22_0](https://cdn-uploads.huggingface.co/production/uploads/6957bafe54c6b170be4df9cb/rR8jUCL7lBRZ9SrPb0_LJ.png)

## Evaluation

Evaluation was run on the first `1,000` examples of the validation split with batch size `200`.

The validation accuracy is:

- Original model: `0.1310` (`131/1000`)
- Reasoning fine-tuning: `0.8200` (`820/1000`)

## Inference

Use these two cells for inference. The first loads the base model and attaches the LoRA adapter; the second samples a solution for one prompt.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base_model_id = "google/gemma-3-1b-it"
adapter_id = "pymlex/gemma3-1b-countdown"

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
# Fall back to the EOS token for padding if the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Prefer bfloat16 when the GPU supports it, otherwise use float16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    trust_remote_code=True,
)

# Attach the LoRA adapter to the base model and switch to evaluation mode.
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()
```

```python
def generate_continuation(model, tokenizer, prompt, max_new_tokens=850):
    """Sample a completion for `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Drop the prompt tokens so that only the continuation is decoded.
    decoded = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return decoded.strip()


sample_prompt = (
    "Using the numbers [78, 46, 93], create an equation that equals 61. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once."
)

output = generate_continuation(model, tokenizer, sample_prompt, max_new_tokens=850)

print("Prompt:")
print(sample_prompt)
print("\nGenerated continuation:")
print(output)
```