# /// script
# dependencies = [
#     "trl",
#     "peft",
#     "datasets",
#     "transformers",
#     "accelerate",
# ]
# ///
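
# Because of the PEP 723 inline metadata above, the script can be run directly
# with `uv run <script>.py`; uv resolves the listed dependencies automatically.
# (The filename is whatever you save this script as.)
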
from datasets import load_dataset
from trl.experimental.distillation import DistillationConfig, DistillationTrainer

# 1. Load dataset and format as prompt-only chat messages
dataset = load_dataset("openai/gsm8k", "main", split="train")
dataset = dataset.map(
    lambda x: {"messages": [{"role": "user", "content": x["question"]}]},
    remove_columns=dataset.column_names,
)
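# Each example now has the shape (illustrative):
# {"messages": [{"role": "user", "content": "<a GSM8K question>"}]}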

# 2. Configure distillation
config = DistillationConfig(
    output_dir="results/distill-qwen-gsm8k",
    num_train_epochs=1,
    bf16=True,
    save_strategy="no",
    # Distillation
    lmbda=1.0,  # fully on-policy (student generates the completions)
    beta=1.0,  # reverse KL
    # Teacher
    teacher_model_init_kwargs={"torch_dtype": "bfloat16"},
    push_to_hub=True,
    hub_model_id="Qwen2.5-1.5B-Instruct-gsm8k",
)

# 3. Train
trainer = DistillationTrainer(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    teacher_model="Qwen/Qwen2.5-7B-Instruct",
    args=config,
    train_dataset=dataset,
)
trainer.train()
trainer.push_to_hub()
trainer.save_model()
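
# For intuition on `beta`: a minimal, illustrative sketch of the generalized
# Jensen-Shannon divergence it interpolates, assuming the GKD-style convention
# (beta=0 -> forward KL(teacher || student); beta=1 -> reverse
# KL(student || teacher), the loss selected in the config above). This helper
# is not used by the trainer; it only shows what the knob controls.
import torch
import torch.nn.functional as F


def generalized_jsd(student_logits, teacher_logits, beta=1.0):
    s = F.log_softmax(student_logits, dim=-1)  # student log-probs
    t = F.log_softmax(teacher_logits, dim=-1)  # teacher log-probs
    if beta == 0.0:
        # Forward KL: KL(teacher || student)
        return F.kl_div(s, t, log_target=True, reduction="batchmean")
    if beta == 1.0:
        # Reverse KL: KL(student || teacher)
        return F.kl_div(t, s, log_target=True, reduction="batchmean")
    # Log of the mixture beta * p_teacher + (1 - beta) * p_student
    b = torch.tensor(beta, dtype=s.dtype)
    m = torch.logsumexp(torch.stack([t + torch.log(b), s + torch.log(1 - b)]), dim=0)
    return beta * F.kl_div(m, t, log_target=True, reduction="batchmean") + (
        1 - beta
    ) * F.kl_div(m, s, log_target=True, reduction="batchmean")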