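# Hugging Face Hub page residue, kept as comments so this file is valid Python:
# author: davidsmts
# commit message: "Upload train.py with huggingface_hub"
# commit: 8676835 (verified)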
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio"]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
# Training data: the "train" split of the trl-lib/Capybara dataset from the Hub.
dataset = load_dataset("trl-lib/Capybara", split="train")

# LoRA adapter settings: rank 16, alpha 32.
lora_config = LoraConfig(r=16, lora_alpha=32)

# Run configuration, grouped by concern:
#   - schedule: one epoch; a per-device batch of 1 with 4 gradient-accumulation
#     steps, i.e. 4 examples per optimizer step on each device;
#   - Hub upload: enabled, targeting the `hub_model_id` repository;
#   - tracking: metrics reported to trackio under the given project/run names.
training_args = SFTConfig(
    output_dir="my-model",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    push_to_hub=True,
    hub_model_id="gemini-user/qwen-sft-test",
    report_to="trackio",
    project="qwen-sft-test-project",
    run_name="qwen-sft-test-run",
)

# The base model is passed as a Hub ID string; SFTTrainer resolves it itself
# and trains it together with the LoRA config above.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
    peft_config=lora_config,
)

# Train, then explicitly upload the final result to the Hub repository.
trainer.train()
trainer.push_to_hub()