HuggingFaceTB/Countdown-Task-GOLD
How to use Ilyayaya/gemma_kontur with PEFT:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
model = PeftModel.from_pretrained(base_model, "Ilyayaya/gemma_kontur")

This model is a fine-tuned version of google/gemma-3-1b-it. It contains LoRA adapters trained to solve the Countdown math puzzle as part of a Knowledge Distillation competition.
Student (base) model: google/gemma-3-1b-it
Teacher model: Qwen/Qwen2.5-Math-7B-Instruct

Instead of learning from raw equations, this model was trained to replicate the reasoning steps of the Teacher model in a strict Telegraphic Style. This reduces linguistic noise, so the 1B student model does not waste attention capacity on conversational filler.

DPO was then applied to penalize arithmetic hallucinations and logic shortcuts identified in the student's own generations.
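To make the idea of penalizing arithmetic hallucinations concrete, here is a minimal sketch of how flawed generations could be detected and paired with clean ones for preference training. It is not the author's pipeline. The step format (`a op b = c`, one binary operation per step), the helper names `find_arithmetic_errors` and `build_preference_pairs`, and the `prompt`/`chosen`/`rejected` record layout are all assumptions made for illustration.

```python
import re
from fractions import Fraction

# Assumed step format: "<a> <op> <b> = <c>", e.g. "20 * 50 = 1000".
# The lookbehind stops a match from starting in the middle of a number.
NUM = r"-?\d+(?:\.\d+)?"
STEP_RE = re.compile(
    rf"(?<![\d.])({NUM})\s*([+\-*/])\s*({NUM})\s*=\s*({NUM})"
)

OPS = {
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "*": lambda a, b: a * b,
    "/": lambda a, b: a / b if b != 0 else None,  # None marks division by zero
}


def find_arithmetic_errors(reasoning: str) -> list[str]:
    """Return every step whose stated result does not match the real result."""
    errors = []
    for match in STEP_RE.finditer(reasoning):
        a, op, b, claimed = match.groups()
        actual = OPS[op](Fraction(a), Fraction(b))  # exact arithmetic, no float drift
        if actual is None or actual != Fraction(claimed):
            errors.append(match.group(0))
    return errors


def build_preference_pairs(prompt: str, generations: list[str]) -> list[dict]:
    """Pair generations with clean arithmetic against ones containing errors."""
    clean = [g for g in generations if not find_arithmetic_errors(g)]
    flawed = [g for g in generations if find_arithmetic_errors(g)]
    return [
        {"prompt": prompt, "chosen": good, "rejected": bad}
        for good, bad in zip(clean, flawed)
    ]
```

A checker like this only catches wrong arithmetic inside individual steps. Logic shortcuts, such as skipping steps or reusing a number, would need a separate check.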
Since these are LoRA weights, you need to load the base Gemma model first and then apply the peft adapters.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# 1. Load the base model
base_model_name = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# 2. Load the LoRA adapters on top of the base model
lora_repo = "Ilyayaya/gemma_kontur"
model = PeftModel.from_pretrained(model, lora_repo)

# 3. Inference example
nums = [20, 50, 2, 4, 10, 4]
target = 1007
instruction = f"Using the numbers {nums}, create an equation that equals {target}. Show work in <think> and result in <answer>."
# The prompt pre-fills the start of the model turn so generation continues the reasoning
prompt = f"<start_of_turn>user\n{instruction}<end_of_turn>\n<start_of_turn>model\n<think>\nTarget: {target}.\nStep 1:"
# Move inputs to whichever device the model was placed on (works with or without CUDA)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Note: The model performs best with Best-of-N sampling
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=400,
        temperature=0.8,
        do_sample=True,
    )
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
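The Best-of-N note in the example can be made concrete with a small selection step: sample several completions, then keep the first one whose answer verifies. The sketch below is one possible way to do that, not the author's evaluation code. The helper names `is_valid_solution` and `best_of_n` are hypothetical. It also assumes that the final equation sits inside `<answer>...</answer>`, that any `= target` suffix can be ignored, and that each given number may be used at most once.

```python
import ast
import re
from collections import Counter
from fractions import Fraction

_OPS = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,
}


def _eval(node):
    """Evaluate a parsed expression that uses only integers and + - * /."""
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    if isinstance(node, ast.Constant) and type(node.value) is int:
        return Fraction(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval(node.operand)
    raise ValueError("unsupported expression")


def is_valid_solution(text: str, nums: list[int], target: int) -> bool:
    """Check the <answer> equation: allowed numbers only, and it hits the target."""
    match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if not match:
        return False
    expr = match.group(1).split("=")[0].strip()  # drop an optional "= target" suffix
    try:
        tree = ast.parse(expr, mode="eval")
        used = [n.value for n in ast.walk(tree) if isinstance(n, ast.Constant)]
        if Counter(used) - Counter(nums):  # a number is unknown or used too often
            return False
        return _eval(tree) == target
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False


def best_of_n(candidates: list[str], nums: list[int], target: int):
    """Return the first candidate that verifies, or None if none does."""
    for text in candidates:
        if is_valid_solution(text, nums, target):
            return text
    return None
```

The candidates could come from a single `model.generate` call with `num_return_sequences` set to N and `do_sample=True`, decoded with `tokenizer.batch_decode`. For the numbers in the example, `20 * 50 + 10 - 2 - 4 / 4` is one expression that reaches 1007.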