ceperaltab committed
Commit 3bbfe8c · verified · 1 Parent(s): 1b5e5e7

Upload train.py with huggingface_hub

Files changed (1)
  1. train.py +4 -3
train.py CHANGED
@@ -68,14 +68,15 @@ def main():
     # 7. Training Arguments (TRL v0.8.6 uses TrainingArguments from transformers)
     training_args = TrainingArguments(
         output_dir=OUTPUT_DIR,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
+        per_device_train_batch_size=1,
+        gradient_accumulation_steps=8,  # Compensate for smaller batch
         learning_rate=2e-4,
         logging_steps=10,
         num_train_epochs=1,
         optim="paged_adamw_32bit",
         fp16=True,
         group_by_length=True,
+        gradient_checkpointing=True,  # Save memory
         save_strategy="epoch",
         report_to="none",
         push_to_hub=True,
@@ -88,7 +89,7 @@ def main():
         train_dataset=dataset,
         peft_config=peft_config,
         formatting_func=formatting_prompts_func,
-        max_seq_length=2048,
+        max_seq_length=1024,  # Reduced for T4 memory
         tokenizer=tokenizer,
         args=training_args,
     )
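
Taken together, the commit trades per-step batch size for gradient accumulation so the effective batch size is unchanged (1 × 8 = 8, same as the old 2 × 4), while gradient checkpointing and the shorter 1024-token context reduce activation memory on a 16 GB T4. Below is a minimal sketch of how the two changed blocks fit together once assembled; the model argument and the surrounding definitions (OUTPUT_DIR, dataset, peft_config, tokenizer, formatting_prompts_func) are assumed from parts of train.py not shown in this diff.

from transformers import TrainingArguments
from trl import SFTTrainer

# Sketch only: OUTPUT_DIR, model, dataset, peft_config, tokenizer, and
# formatting_prompts_func are defined elsewhere in train.py (not in the hunk).
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,   # halved to fit T4 memory
    gradient_accumulation_steps=8,   # doubled, so effective batch stays 1 * 8 = 8
    learning_rate=2e-4,
    logging_steps=10,
    num_train_epochs=1,
    optim="paged_adamw_32bit",       # paged 32-bit AdamW (requires bitsandbytes)
    fp16=True,
    group_by_length=True,
    gradient_checkpointing=True,     # recompute activations in backward to save memory
    save_strategy="epoch",
    report_to="none",
    push_to_hub=True,
)

trainer = SFTTrainer(
    model=model,                     # assumed: the loaded base model (not shown in the diff)
    train_dataset=dataset,
    peft_config=peft_config,
    formatting_func=formatting_prompts_func,
    max_seq_length=1024,             # activation memory scales with sequence length
    tokenizer=tokenizer,
    args=training_args,
)

Because the learning rate and effective batch size are untouched, the optimization schedule is essentially the same; the changes only lower peak GPU memory, at the cost of slower steps from checkpoint recomputation.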