ceperaltab committed on
Commit
5ca3265
·
verified ·
1 Parent(s): fe938a1

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +17 -8
train.py CHANGED
@@ -5,9 +5,10 @@ from transformers import (
5
  AutoModelForCausalLM,
6
  AutoTokenizer,
7
  BitsAndBytesConfig,
 
8
  )
9
  from peft import LoraConfig
10
- from trl import SFTTrainer, SFTConfig
11
 
12
  # --- CONFIGURATION ---
13
  MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
@@ -54,12 +55,19 @@ def main():
54
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
55
  )
56
 
 
 
 
 
 
 
 
 
57
  print("Starting SFTTrainer setup...")
58
 
59
- # 6. SFTConfig (new TRL API - all config goes here)
60
- sft_config = SFTConfig(
61
  output_dir=OUTPUT_DIR,
62
- max_seq_length=2048,
63
  per_device_train_batch_size=2,
64
  gradient_accumulation_steps=4,
65
  learning_rate=2e-4,
@@ -74,14 +82,15 @@ def main():
74
  hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
75
  )
76
 
77
- # 7. Trainer - new API: use processing_class instead of tokenizer
78
- # The trainer automatically handles conversational datasets with "messages" field
79
  trainer = SFTTrainer(
80
  model=model,
81
- args=sft_config,
82
  train_dataset=dataset,
83
- processing_class=tokenizer, # New API: processing_class replaces tokenizer
84
  peft_config=peft_config,
 
 
 
 
85
  )
86
 
87
  print("Starting training...")
 
5
  AutoModelForCausalLM,
6
  AutoTokenizer,
7
  BitsAndBytesConfig,
8
+ TrainingArguments,
9
  )
10
  from peft import LoraConfig
11
+ from trl import SFTTrainer
12
 
13
  # --- CONFIGURATION ---
14
  MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
55
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
56
  )
57
 
58
+ # 6. Formatting Function for Chat Dataset (TRL v0.8.6 API)
59
+ def formatting_prompts_func(examples):
60
+ output_texts = []
61
+ for messages in examples['messages']:
62
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
63
+ output_texts.append(text)
64
+ return output_texts
65
+
66
  print("Starting SFTTrainer setup...")
67
 
68
+ # 7. Training Arguments (TRL v0.8.6 uses TrainingArguments from transformers)
69
+ training_args = TrainingArguments(
70
  output_dir=OUTPUT_DIR,
 
71
  per_device_train_batch_size=2,
72
  gradient_accumulation_steps=4,
73
  learning_rate=2e-4,
 
82
  hub_model_id=f"ceperaltab/{OUTPUT_DIR}",
83
  )
84
 
85
+ # 8. SFTTrainer (TRL v0.8.6 API)
 
86
  trainer = SFTTrainer(
87
  model=model,
 
88
  train_dataset=dataset,
 
89
  peft_config=peft_config,
90
+ formatting_func=formatting_prompts_func,
91
+ max_seq_length=2048,
92
+ tokenizer=tokenizer,
93
+ args=training_args,
94
  )
95
 
96
  print("Starting training...")