File size: 771 Bytes
8676835 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
# Supervised fine-tuning of Qwen/Qwen2.5-0.5B with a LoRA adapter on the
# "train" split of trl-lib/Capybara; the result is pushed to the Hub.
dataset = load_dataset("trl-lib/Capybara", split="train")

# LoRA adapter settings (r=16, lora_alpha=32); everything else is left at
# the library defaults.
lora_config = LoraConfig(r=16, lora_alpha=32)

# Training / logging / Hub settings. One epoch, per-device batch size 1 with
# 4 gradient-accumulation steps; metrics are reported to trackio.
training_args = SFTConfig(
    output_dir="my-model",
    push_to_hub=True,
    hub_model_id="gemini-user/qwen-sft-test",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    report_to="trackio",
    run_name="qwen-sft-test-run",
    project="qwen-sft-test-project",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=lora_config,
    args=training_args,
)

# Run training, then upload the result to the Hub repo named in
# hub_model_id.
trainer.train()
trainer.push_to_hub()
|