epinfomax committed
Commit 4c821d6 · verified · 1 Parent(s): b41204b

Upload train.py with huggingface_hub
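The commit message says the script was pushed with huggingface_hub. A minimal sketch of how such an upload is typically done with the HfApi client follows; the repo_id is a placeholder assumption, since the target repository is not shown in this commit view:

from huggingface_hub import HfApi

api = HfApi()
# Push the local training script to the Hub under the given commit message.
api.upload_file(
    path_or_fileobj="train.py",
    path_in_repo="train.py",
    repo_id="epinfomax/your-target-repo",  # hypothetical placeholder, not from this page
    commit_message="Upload train.py with huggingface_hub",
)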

Files changed (1):
  1. train.py +10 -10
train.py CHANGED
@@ -8,27 +8,27 @@ from trl import SFTTrainer, SFTConfig
  import trackio
  import os

- print("🚀 Starting FunctionGemma 2B Fine-tuning")
+ print("🚀 Starting FunctionGemma 270M Fine-tuning")

  # Load dataset
  dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")

  # Training configuration
  config = SFTConfig(
-     output_dir="vn-function-gemma-finetuned",
+     output_dir="vn-function-gemma-270m-finetuned",
      push_to_hub=True,
-     hub_model_id="epinfomax/vn-function-gemma-finetuned",
+     hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
      hub_strategy="every_save",
-     num_train_epochs=3,
-     per_device_train_batch_size=4,
-     gradient_accumulation_steps=4,
-     learning_rate=2e-5,
-     logging_steps=10,
+     num_train_epochs=5,  # Increased epochs for the smaller model
+     per_device_train_batch_size=8,  # Increased batch size for the smaller model
+     gradient_accumulation_steps=2,
+     learning_rate=5e-5,  # Slightly higher LR for smaller model
+     logging_steps=5,
      save_strategy="steps",
      save_steps=50,
      report_to="trackio",
      project="vn-function-calling",
-     run_name="function-gemma-2b-baseline"
+     run_name="function-gemma-270m-v1"
  )

  # LoRA configuration
@@ -41,7 +41,7 @@ peft_config = LoraConfig(

  # Initialize and train
  trainer = SFTTrainer(
-     model="google/function-gemma-2b",
+     model="google/functiongemma-270m-it",
      train_dataset=dataset,
      peft_config=peft_config,
      args=config,
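For reference, the pieces this diff touches assemble roughly as in the sketch below after the change. Note that the effective batch size is unchanged by the edit: the old settings gave 4 × 4 = 16 samples per optimizer step, the new ones give 8 × 2 = 16; only the per-device/accumulation split moves. The LoraConfig body falls between the two hunks and is not shown in this commit view, so the r, lora_alpha, and target_modules values below are illustrative assumptions rather than the repository's actual settings, and the trailing trainer.train() call is likewise assumed.

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
import os

print("🚀 Starting FunctionGemma 270M Fine-tuning")

# Load dataset
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")

# Training configuration (values as committed in this diff)
config = SFTConfig(
    output_dir="vn-function-gemma-270m-finetuned",
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # effective batch size stays 8 * 2 = 16
    learning_rate=5e-5,
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v1",
)

# LoRA configuration -- illustrative values only; the real ones sit outside this diff
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)

# Initialize and train
trainer = SFTTrainer(
    model="google/functiongemma-270m-it",
    train_dataset=dataset,
    peft_config=peft_config,
    args=config,
)
trainer.train()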