Fix max_seq_length placement and torch_dtype deprecation
Browse files
train.py
CHANGED
|
@@ -63,7 +63,7 @@ print(f"Total conversations: {len(train_dataset)}")
|
|
| 63 |
# Tokenizer
|
| 64 |
# ------------------------------------------------------------------
|
| 65 |
print("Loading tokenizer...")
|
| 66 |
-
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
|
| 67 |
if tokenizer.pad_token is None:
|
| 68 |
tokenizer.pad_token = tokenizer.eos_token
|
| 69 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
@@ -76,7 +76,6 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 76 |
MODEL_ID,
|
| 77 |
torch_dtype=torch.bfloat16,
|
| 78 |
device_map="auto",
|
| 79 |
-
trust_remote_code=True,
|
| 80 |
)
|
| 81 |
|
| 82 |
model.gradient_checkpointing_enable()
|
|
@@ -92,7 +91,6 @@ args = SFTConfig(
|
|
| 92 |
per_device_train_batch_size=1,
|
| 93 |
gradient_accumulation_steps=4,
|
| 94 |
learning_rate=2e-5,
|
| 95 |
-
max_seq_length=MAX_SEQ_LENGTH,
|
| 96 |
logging_strategy="steps",
|
| 97 |
logging_steps=10,
|
| 98 |
logging_first_step=True,
|
|
@@ -113,6 +111,7 @@ trainer = SFTTrainer(
|
|
| 113 |
args=args,
|
| 114 |
train_dataset=train_dataset,
|
| 115 |
processing_class=tokenizer,
|
|
|
|
| 116 |
)
|
| 117 |
|
| 118 |
# ------------------------------------------------------------------
|
|
|
|
| 63 |
# Tokenizer
|
| 64 |
# ------------------------------------------------------------------
|
| 65 |
print("Loading tokenizer...")
|
| 66 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 67 |
if tokenizer.pad_token is None:
|
| 68 |
tokenizer.pad_token = tokenizer.eos_token
|
| 69 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
|
| 76 |
MODEL_ID,
|
| 77 |
torch_dtype=torch.bfloat16,
|
| 78 |
device_map="auto",
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
model.gradient_checkpointing_enable()
|
|
|
|
| 91 |
per_device_train_batch_size=1,
|
| 92 |
gradient_accumulation_steps=4,
|
| 93 |
learning_rate=2e-5,
|
|
|
|
| 94 |
logging_strategy="steps",
|
| 95 |
logging_steps=10,
|
| 96 |
logging_first_step=True,
|
|
|
|
| 111 |
args=args,
|
| 112 |
train_dataset=train_dataset,
|
| 113 |
processing_class=tokenizer,
|
| 114 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 115 |
)
|
| 116 |
|
| 117 |
# ------------------------------------------------------------------
|