epinfomax committed on
Commit e3473e4 · verified · 1 Parent(s): 994600a

Upload train.py with huggingface_hub

Files changed (1)
  1. train.py +32 -58
train.py CHANGED
@@ -5,16 +5,11 @@
  from datasets import load_dataset
  from peft import LoraConfig
  from trl import SFTTrainer, SFTConfig
- from transformers import AutoTokenizer, TrainingArguments
- import trl
- import transformers
+ from transformers import AutoTokenizer
  import trackio
  import os
- import inspect

- print(f"🚀 Starting FunctionGemma 270M Fine-tuning (V4 - Diagnostic)")
- print(f"📦 TRL Version: {trl.__version__}")
- print(f"📦 Transformers Version: {transformers.__version__}")
+ print("🚀 Starting FunctionGemma 270M Fine-tuning (V5 - Final)")

  model_id = "google/functiongemma-270m-it"
  tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -23,6 +18,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
  dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")

  def format_conversation(example):
+     # Pre-render the conversation using the model's chat template
      text = tokenizer.apply_chat_template(
          example["messages"],
          tools=example["tools"],
@@ -34,64 +30,42 @@ def format_conversation(example):
  print("🔄 Pre-processing dataset with chat template...")
  dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)

- # Training configuration
- # Trying max_seq_length again but checking if it exists in SFTConfig first
- sft_config_args = {
-     "dataset_text_field": "text",
-     "output_dir": "vn-function-gemma-270m-finetuned",
-     "push_to_hub": True,
-     "hub_model_id": "epinfomax/vn-function-gemma-270m-finetuned",
-     "hub_strategy": "every_save",
-     "num_train_epochs": 5,
-     "per_device_train_batch_size": 4,
-     "gradient_accumulation_steps": 4,
-     "learning_rate": 5e-5,
-     "logging_steps": 5,
-     "save_strategy": "steps",
-     "save_steps": 50,
-     "report_to": "trackio",
-     "project": "vn-function-calling",
-     "run_name": "function-gemma-270m-v4-diag"
- }
-
- # Check which parameter to use
- sft_fields = SFTConfig.__dataclass_fields__
- if "max_seq_length" in sft_fields:
-     print("✅ Using max_seq_length in SFTConfig")
-     sft_config_args["max_seq_length"] = 1024
- elif "max_length" in sft_fields:
-     print("✅ Using max_length in SFTConfig")
-     sft_config_args["max_length"] = 1024
- else:
-     print("⚠️ Neither max_seq_length nor max_length found in SFTConfig fields!")
-     print("Fields:", list(sft_fields.keys()))
-
- config = SFTConfig(**sft_config_args)
-
- # Initialize and train
- print("🎯 Initializing SFTTrainer...")
- trainer_kwargs = {
-     "model": model_id,
-     "train_dataset": dataset,
-     "peft_config": peft_config,
-     "args": config,
- }
-
- # Check SFTTrainer init signature
- trainer_params = inspect.signature(SFTTrainer.__init__).parameters
- if "max_seq_length" in trainer_params and "max_seq_length" not in sft_config_args:
-     print("✅ Adding max_seq_length to SFTTrainer")
-     trainer_kwargs["max_seq_length"] = 1024
-
+ # LoRA configuration - Define early to avoid NameError
  peft_config = LoraConfig(
      r=16,
      lora_alpha=32,
      target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
      task_type="CAUSAL_LM",
  )
- trainer_kwargs["peft_config"] = peft_config

- trainer = SFTTrainer(**trainer_kwargs)
+ # Training configuration (TRL 0.26.2 style)
+ config = SFTConfig(
+     dataset_text_field="text",
+     max_length=1024,  # Confirmed correct for TRL 0.26.2
+     output_dir="vn-function-gemma-270m-finetuned",
+     push_to_hub=True,
+     hub_model_id="epinfomax/vn-function-gemma-270m-finetuned",
+     hub_strategy="every_save",
+     num_train_epochs=5,
+     per_device_train_batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=5e-5,
+     logging_steps=5,
+     save_strategy="steps",
+     save_steps=50,
+     report_to="trackio",
+     project="vn-function-calling",
+     run_name="function-gemma-270m-final"
+ )
+
+ # Initialize and train
+ print("🎯 Initializing SFTTrainer...")
+ trainer = SFTTrainer(
+     model=model_id,
+     train_dataset=dataset,
+     peft_config=peft_config,
+     args=config,
+ )

  trainer.train()
  trainer.push_to_hub()
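
The hunk around format_conversation is truncated, so the function's return value is not visible in the diff. As a point of reference only, a minimal sketch of that pre-rendering step, assuming the function renders with tokenize=False and returns {"text": ...} to match dataset_text_field="text", could look like this:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/functiongemma-270m-it")
dataset = load_dataset("epinfomax/vn-function-calling-dataset", split="train")

def format_conversation(example):
    # Render the messages plus tool schemas into one training string
    # using the model's own chat template (tokenize=False and the return
    # value are assumptions; the diff truncates the original body).
    text = tokenizer.apply_chat_template(
        example["messages"],
        tools=example["tools"],
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}

dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)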
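
The main change in this commit is dropping the V4 diagnostic that probed SFTConfig for the right sequence-length field and hardcoding max_length=1024 for TRL 0.26.2 (older TRL releases exposed max_seq_length instead). If a script still has to run across TRL versions, a compact form of the removed probe, sketched here with the same __dataclass_fields__ check the old code used, might be:

from trl import SFTConfig

# Choose whichever sequence-length field this TRL release exposes:
# newer releases use max_length, older ones used max_seq_length.
length_kwargs = {}
if "max_length" in SFTConfig.__dataclass_fields__:
    length_kwargs["max_length"] = 1024
elif "max_seq_length" in SFTConfig.__dataclass_fields__:
    length_kwargs["max_seq_length"] = 1024

config = SFTConfig(
    dataset_text_field="text",
    output_dir="vn-function-gemma-270m-finetuned",
    **length_kwargs,
)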
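
The trainer is constructed with model=model_id, i.e. a repo id string that SFTTrainer loads itself. If loading needs to be controlled explicitly, the model can be instantiated first; the following is a sketch assuming AutoModelForCausalLM is the right auto class for this checkpoint and with dataset, peft_config, and config as defined in the script above:

import torch
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

# Load the base model explicitly instead of passing the repo id string.
model = AutoModelForCausalLM.from_pretrained(
    "google/functiongemma-270m-it",
    torch_dtype=torch.bfloat16,  # assumption: dtype chosen for illustration
)

trainer = SFTTrainer(
    model=model,              # equivalent to model="google/functiongemma-270m-it"
    train_dataset=dataset,    # pre-rendered dataset from the script above
    peft_config=peft_config,  # LoRA config from the script above
    args=config,              # SFTConfig from the script above
)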