sunkencity committed on
Commit
29ee62e
·
verified ·
1 Parent(s): 3468e66

Upload train_survival.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_survival.py +5 -3
train_survival.py CHANGED
@@ -44,6 +44,7 @@ peft_config = LoraConfig(
44
  )
45
 
46
  # Training Arguments
 
47
  training_args = SFTConfig(
48
  output_dir="./results",
49
  num_train_epochs=3,
@@ -54,8 +55,7 @@ training_args = SFTConfig(
54
  push_to_hub=True,
55
  hub_model_id=OUTPUT_MODEL_ID,
56
  fp16=True,
57
- max_seq_length=1024,
58
- dataset_text_field="text", # We need to format the data first if it's not in 'text'
59
  packing=False
60
  )
61
 
@@ -72,6 +72,7 @@ def formatting_prompts_func(example):
72
  return output_texts
73
 
74
  # Trainer
 
75
  trainer = SFTTrainer(
76
  model=model,
77
  train_dataset=dataset,
@@ -79,6 +80,7 @@ trainer = SFTTrainer(
79
  formatting_func=formatting_prompts_func,
80
  args=training_args,
81
  tokenizer=tokenizer,
 
82
  )
83
 
84
  print("Starting training...")
@@ -86,4 +88,4 @@ trainer.train()
86
 
87
  print("Pushing to hub...")
88
  trainer.push_to_hub()
89
- print("Done!")
 
44
  )
45
 
46
  # Training Arguments
47
+ # Removed max_seq_length from SFTConfig as it caused a TypeError
48
  training_args = SFTConfig(
49
  output_dir="./results",
50
  num_train_epochs=3,
 
55
  push_to_hub=True,
56
  hub_model_id=OUTPUT_MODEL_ID,
57
  fp16=True,
58
+ dataset_text_field="text",
 
59
  packing=False
60
  )
61
 
 
72
  return output_texts
73
 
74
  # Trainer
75
+ # max_seq_length is passed to SFTTrainer directly
76
  trainer = SFTTrainer(
77
  model=model,
78
  train_dataset=dataset,
 
80
  formatting_func=formatting_prompts_func,
81
  args=training_args,
82
  tokenizer=tokenizer,
83
+ max_seq_length=1024
84
  )
85
 
86
  print("Starting training...")
 
88
 
89
  print("Pushing to hub...")
90
  trainer.push_to_hub()
91
+ print("Done!")