epinfomax committed (verified)
Commit b41204b · 1 Parent(s): 3468316

Upload train.py with huggingface_hub
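The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of such an upload, assuming the target repository and default token handling (neither is shown in this commit view), could look like this:

# Sketch only: push a local script to a Hub repo with huggingface_hub.
# The repo_id below is a hypothetical placeholder; the actual target repo is not shown here.
from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="train.py",            # local file to upload
    path_in_repo="train.py",               # destination path inside the repo
    repo_id="epinfomax/your-repo",         # hypothetical target repository
    commit_message="Upload train.py with huggingface_hub",
)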

Files changed (1):
  train.py +52 -0
train.py ADDED
@@ -0,0 +1,52 @@
+ # /// script
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate"]
+ # ///
+
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ import trackio
+ import os
+
+ print("🚀 Starting FunctionGemma 2B Fine-tuning")
+
+ # Load dataset
+ dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
+
+ # Training configuration
+ config = SFTConfig(
+     output_dir="vn-function-gemma-finetuned",
+     push_to_hub=True,
+     hub_model_id="epinfomax/vn-function-gemma-finetuned",
+     hub_strategy="every_save",
+     num_train_epochs=3,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=2e-5,
+     logging_steps=10,
+     save_strategy="steps",
+     save_steps=50,
+     report_to="trackio",
+     project="vn-function-calling",
+     run_name="function-gemma-2b-baseline"
+ )
+
+ # LoRA configuration
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+     task_type="CAUSAL_LM",
+ )
+
+ # Initialize and train
+ trainer = SFTTrainer(
+     model="google/function-gemma-2b",
+     train_dataset=dataset,
+     peft_config=peft_config,
+     args=config,
+ )
+
+ trainer.train()
+ trainer.push_to_hub()
+ print("✅ Training complete and pushed to Hub!")
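Because the script carries PEP 723 inline metadata (the `# /// script` … `# ///` block at the top), it can be executed directly with a compatible runner such as `uv run train.py`, which resolves the listed dependencies on the fly. Once training finishes and the adapter is pushed, it could be loaded for inference roughly as below; this is a sketch, not part of the committed script, and it assumes the Hub repo named in hub_model_id holds a PEFT (LoRA) adapter with its tokenizer pushed alongside, as SFTTrainer typically does:

# Sketch only: load the pushed LoRA adapter and generate with it.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model = AutoPeftModelForCausalLM.from_pretrained("epinfomax/vn-function-gemma-finetuned")
tokenizer = AutoTokenizer.from_pretrained("epinfomax/vn-function-gemma-finetuned")

inputs = tokenizer("Example function-calling prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))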