lightita committed (verified)
Commit 15236b2 · 1 Parent(s): 3628aa5

Update train_seallm_khm_sum.py

Files changed (1):
  1. train_seallm_khm_sum.py  +10 -10
train_seallm_khm_sum.py CHANGED
@@ -5,8 +5,9 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     BitsAndBytesConfig,
+    TrainingArguments,
 )
-from trl import SFTTrainer, SFTConfig
+from trl import SFTTrainer
 from peft import LoraConfig
 
 MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
@@ -101,8 +102,8 @@ def main():
         task_type="CAUSAL_LM",
     )
 
-    # NOTE: no max_seq_length here
-    sft_config = SFTConfig(
+    # Use standard TrainingArguments instead of SFTConfig
+    training_args = TrainingArguments(
         output_dir="seallm-khm-sum-lora",
         num_train_epochs=2,
         per_device_train_batch_size=2,
@@ -110,16 +111,14 @@ def main():
         gradient_accumulation_steps=8,
         learning_rate=2e-4,
         logging_steps=10,
-        evaluation_strategy="steps",  # <- was eval_strategy
+        evaluation_strategy="steps",  # eval every eval_steps
         eval_steps=200,
         save_steps=200,
         save_total_limit=2,
-        packing=True,
         lr_scheduler_type="cosine",
         warmup_ratio=0.03,
-        bf16=True,
-        gradient_checkpointing=True,
-        report_to="none",  # or "wandb"
+        bf16=True,  # ok on modern GPUs; set False if it crashes
+        report_to="none",  # or "wandb"
     )
 
     trainer = SFTTrainer(
@@ -128,9 +127,10 @@ def main():
         train_dataset=train_ds,
         eval_dataset=eval_ds,
         peft_config=lora_config,
-        args=sft_config,
+        args=training_args,
         dataset_text_field="text",
-        max_seq_length=1024,  # <- moved here
+        max_seq_length=1024,  # set here instead of in config
+        # packing=False,  # keep off for compatibility
     )
 
     trainer.train()
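
For readers who want the patched configuration in one place, here is a minimal, self-contained sketch of how the training setup reads after this commit. Only the TrainingArguments/SFTTrainer wiring mirrors the diff; the quantized model load, the LoRA hyperparameters other than task_type, and the toy datasets are stand-in assumptions, since those parts of the script are not shown in the hunks above. Passing dataset_text_field and max_seq_length directly to SFTTrainer matches older TRL releases; newer TRL expects them in SFTConfig, and recent transformers renames evaluation_strategy to eval_strategy.

# Sketch reconstructed from the diff above; parts not shown in the hunks are assumptions.
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer
from peft import LoraConfig

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"

# Assumed 4-bit quantized load; the actual script's quantization settings are not in the diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config)

# Toy datasets with a "text" column, standing in for the real Khmer summarization data.
train_ds = Dataset.from_dict({"text": ["<document> ... <summary> ..."]})
eval_ds = Dataset.from_dict({"text": ["<document> ... <summary> ..."]})

# LoRA rank/alpha/targets are placeholders; only task_type appears in the diff.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir="seallm-khm-sum-lora",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    logging_steps=10,
    evaluation_strategy="steps",  # eval_strategy on recent transformers versions
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    bf16=True,       # requires a bf16-capable GPU
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=lora_config,
    dataset_text_field="text",  # older-TRL style; newer TRL takes this via SFTConfig
    max_seq_length=1024,
)
trainer.train()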