epinfomax commited on
Commit
ef5974f
·
verified ·
1 Parent(s): 4c821d6

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +27 -7
train.py CHANGED
@@ -1,34 +1,54 @@
1
  # /// script
2
- # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate"]
3
  # ///
4
 
5
  from datasets import load_dataset
6
  from peft import LoraConfig
7
  from trl import SFTTrainer, SFTConfig
 
8
  import trackio
9
  import os
10
 
11
- print("🚀 Starting FunctionGemma 270M Fine-tuning")
 
 
 
12
 
13
  # Load dataset
14
  dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Training configuration
17
  config = SFTConfig(
 
 
18
  output_dir="vn-function-gemma-270m-finetuned",
19
  push_to_hub=True,
20
  hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
21
  hub_strategy="every_save",
22
- num_train_epochs=5, # Increased epochs for the smaller model
23
- per_device_train_batch_size=8, # Increased batch size for the smaller model
24
  gradient_accumulation_steps=2,
25
- learning_rate=5e-5, # Slightly higher LR for smaller model
26
  logging_steps=5,
27
  save_strategy="steps",
28
  save_steps=50,
29
  report_to="trackio",
30
  project="vn-function-calling",
31
- run_name="function-gemma-270m-v1"
32
  )
33
 
34
  # LoRA configuration
@@ -41,7 +61,7 @@ peft_config = LoraConfig(
41
 
42
  # Initialize and train
43
  trainer = SFTTrainer(
44
- model="google/functiongemma-270m-it",
45
  train_dataset=dataset,
46
  peft_config=peft_config,
47
  args=config,
 
1
  # /// script
2
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "datasets", "transformers", "accelerate", "jinja2"]
3
  # ///
4
 
5
  from datasets import load_dataset
6
  from peft import LoraConfig
7
  from trl import SFTTrainer, SFTConfig
8
+ from transformers import AutoTokenizer
9
  import trackio
10
  import os
11
 
12
# Announce the run, then pull down the base model's tokenizer and the
# training data from the Hugging Face Hub.
print("🚀 Starting FunctionGemma 270M Fine-tuning (V2 with Template Fix)")

model_id = "google/functiongemma-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load dataset
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")
19
 
20
def format_conversation(example):
    """Render one dataset row into a single training string.

    Uses the module-level ``tokenizer``'s chat template — which, on modern
    transformers, accepts a ``tools`` argument — so the conversation is
    rendered to plain text up front and SFTTrainer doesn't have to guess
    the format.  Returns a dict with a single ``"text"`` key.
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": rendered}
30
+
31
# Pre-render every example once, before training, and drop the raw columns
# so only the rendered "text" field reaches the trainer.
print("🔄 Pre-processing dataset with chat template...")
dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
33
+
34
# Training configuration.  Checkpoints are pushed to the Hub on every save;
# metrics are reported to trackio under the "vn-function-calling" project.
config = SFTConfig(
    # Data handling: consume the pre-rendered "text" column as-is.
    dataset_text_field="text",
    max_seq_length=1024,
    # Output / Hub publishing.
    output_dir="vn-function-gemma-270m-finetuned",
    push_to_hub=True,
    hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
    hub_strategy="every_save",
    # Optimization schedule.
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    # Logging / checkpointing cadence.
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    report_to="trackio",
    project="vn-function-calling",
    run_name="function-gemma-270m-v2-fixed",
)
53
 
54
  # LoRA configuration
 
61
 
62
  # Initialize and train
63
  trainer = SFTTrainer(
64
+ model=model_id,
65
  train_dataset=dataset,
66
  peft_config=peft_config,
67
  args=config,