moos124's picture
download
raw
1.19 kB
# /// script
# dependencies = [
# "trl",
# "peft",
# "datasets",
# "transformers",
# "accelerate"
# ]
# ///
from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer
# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
remove_columns=dataset.column_names,
)
# 2. Configure distillation
config = DistillationConfig(
output_dir="results/distill-qwen-gsm8k",
num_train_epochs=1,
bf16=True,
save_strategy="no",
# Distillation
lmbda=1.0, # fully on-policy (student generates)
beta=1.0, # reverse KL
# Teacher
teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
push_to_hub=True,
hub_model_id="Qwen2.5-1.5B-Instruct-gsm8k",
)
# 3. Train
trainer = DistillationTrainer(
model="Qwen/Qwen2.5-1.5B-Instruct",
teacher_model="Qwen/Qwen2.5-7B-Instruct",
args=config,
train_dataset=dataset,
)
trainer.train()
trainer.push_to_hub()
trainer.save_model()

Xet Storage Details

Size:
1.19 kB
·
Xet hash:
1e300d890c2236e4d628cb4cd39693cac2ecd33ed737b45ea611a65583db6b14

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.