ligaments-dev committed on
Commit
cb6f87d
·
verified ·
1 Parent(s): 1642b2a

Fix max_seq_length placement and torch_dtype deprecation

Browse files
Files changed (1) hide show
  1. train.py +2 -3
train.py CHANGED
@@ -63,7 +63,7 @@ print(f"Total conversations: {len(train_dataset)}")
63
  # Tokenizer
64
  # ------------------------------------------------------------------
65
  print("Loading tokenizer...")
66
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
67
  if tokenizer.pad_token is None:
68
  tokenizer.pad_token = tokenizer.eos_token
69
  tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -76,7 +76,6 @@ model = AutoModelForCausalLM.from_pretrained(
76
  MODEL_ID,
77
  torch_dtype=torch.bfloat16,
78
  device_map="auto",
79
- trust_remote_code=True,
80
  )
81
 
82
  model.gradient_checkpointing_enable()
@@ -92,7 +91,6 @@ args = SFTConfig(
92
  per_device_train_batch_size=1,
93
  gradient_accumulation_steps=4,
94
  learning_rate=2e-5,
95
- max_seq_length=MAX_SEQ_LENGTH,
96
  logging_strategy="steps",
97
  logging_steps=10,
98
  logging_first_step=True,
@@ -113,6 +111,7 @@ trainer = SFTTrainer(
113
  args=args,
114
  train_dataset=train_dataset,
115
  processing_class=tokenizer,
 
116
  )
117
 
118
  # ------------------------------------------------------------------
 
63
  # Tokenizer
64
  # ------------------------------------------------------------------
65
  print("Loading tokenizer...")
66
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
67
  if tokenizer.pad_token is None:
68
  tokenizer.pad_token = tokenizer.eos_token
69
  tokenizer.pad_token_id = tokenizer.eos_token_id
 
76
  MODEL_ID,
77
  torch_dtype=torch.bfloat16,
78
  device_map="auto",
 
79
  )
80
 
81
  model.gradient_checkpointing_enable()
 
91
  per_device_train_batch_size=1,
92
  gradient_accumulation_steps=4,
93
  learning_rate=2e-5,
 
94
  logging_strategy="steps",
95
  logging_steps=10,
96
  logging_first_step=True,
 
111
  args=args,
112
  train_dataset=train_dataset,
113
  processing_class=tokenizer,
114
+ max_seq_length=MAX_SEQ_LENGTH,
115
  )
116
 
117
  # ------------------------------------------------------------------