davidsmts committed on
Commit
8676835
·
verified ·
1 Parent(s): 9d10b29

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +31 -0
train.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio"]
# ///

"""Supervised fine-tuning run: LoRA adapters on Qwen2.5-0.5B over Capybara.

The script loads the training split, builds an ``SFTTrainer`` with a LoRA
configuration, trains for one epoch, and uploads the result to the Hub
repository named in ``hub_model_id``.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio

# Only the "train" split of the dataset is requested.
dataset = load_dataset("trl-lib/Capybara", split="train")

# LoRA adapter settings handed to the trainer via ``peft_config``.
lora_config = LoraConfig(r=16, lora_alpha=32)

# Training / logging / Hub settings. With a per-device batch of 1 and
# 4 accumulation steps, gradients from 4 micro-batches are combined per
# optimizer step on each device.
training_args = SFTConfig(
    output_dir="my-model",
    push_to_hub=True,
    hub_model_id="gemini-user/qwen-sft-test",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    report_to="trackio",
    run_name="qwen-sft-test-run",
    # NOTE(review): presumably the Trackio project name -- confirm that the
    # installed transformers/trl version accepts ``project``.
    project="qwen-sft-test-project",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=lora_config,
    args=training_args,
)

trainer.train()

# Explicit final upload after training completes.
trainer.push_to_hub()