Spaces:
Sleeping
Sleeping
Update train_model.py
Browse files- train_model.py +31 -10
train_model.py
CHANGED
|
@@ -67,8 +67,20 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
|
|
| 67 |
logging.info("Dataset loaded successfully.")
|
| 68 |
|
| 69 |
def tokenize_function(examples):
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
# Tokenize the dataset using the modified tokenize_function
|
| 74 |
tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
|
|
@@ -182,7 +194,6 @@ def main():
|
|
| 182 |
logging.error(f"Error initializing tokenizer or model: {str(e)}")
|
| 183 |
raise e
|
| 184 |
|
| 185 |
-
# Load and prepare dataset
|
| 186 |
# Load and prepare dataset
|
| 187 |
try:
|
| 188 |
tokenized_datasets = load_and_prepare_dataset(
|
|
@@ -194,26 +205,38 @@ def main():
|
|
| 194 |
except Exception as e:
|
| 195 |
logging.error("Failed to load and prepare dataset.")
|
| 196 |
raise e
|
| 197 |
-
|
| 198 |
# Define data collator
|
| 199 |
if args.task == "generation":
|
| 200 |
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
| 201 |
elif args.task == "classification":
|
| 202 |
-
data_collator = DataCollatorWithPadding(tokenizer=tokenizer
|
| 203 |
else:
|
| 204 |
logging.error("Unsupported task type for data collator.")
|
| 205 |
raise ValueError("Unsupported task type for data collator.")
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
# Initialize Trainer with the data collator
|
| 208 |
trainer = Trainer(
|
| 209 |
model=model,
|
| 210 |
args=training_args,
|
| 211 |
train_dataset=tokenized_datasets,
|
| 212 |
data_collator=data_collator,
|
| 213 |
-
optimizers=(get_optimizer(model, training_args.learning_rate), None)
|
| 214 |
)
|
| 215 |
|
| 216 |
-
|
| 217 |
# Start training
|
| 218 |
logging.info("Starting training...")
|
| 219 |
try:
|
|
@@ -253,5 +276,3 @@ def main():
|
|
| 253 |
|
| 254 |
if __name__ == "__main__":
|
| 255 |
main()
|
| 256 |
-
|
| 257 |
-
|
|
|
|
| 67 |
logging.info("Dataset loaded successfully.")
|
| 68 |
|
| 69 |
def tokenize_function(examples):
|
| 70 |
+
try:
|
| 71 |
+
# Tokenize with truncation, defer padding to DataCollator
|
| 72 |
+
tokens = tokenizer(
|
| 73 |
+
examples['text'],
|
| 74 |
+
truncation=True,
|
| 75 |
+
max_length=sequence_length, # Set maximum length
|
| 76 |
+
padding=False, # Padding will be handled by the DataCollatorWithPadding
|
| 77 |
+
return_tensors=None # Let the DataCollator handle tensor creation
|
| 78 |
+
)
|
| 79 |
+
return tokens
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logging.error(f"Error during tokenization: {e}")
|
| 82 |
+
logging.error(f"Example data: {examples}")
|
| 83 |
+
raise e
|
| 84 |
|
| 85 |
# Tokenize the dataset using the modified tokenize_function
|
| 86 |
tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
|
|
|
|
| 194 |
logging.error(f"Error initializing tokenizer or model: {str(e)}")
|
| 195 |
raise e
|
| 196 |
|
|
|
|
| 197 |
# Load and prepare dataset
|
| 198 |
try:
|
| 199 |
tokenized_datasets = load_and_prepare_dataset(
|
|
|
|
| 205 |
except Exception as e:
|
| 206 |
logging.error("Failed to load and prepare dataset.")
|
| 207 |
raise e
|
| 208 |
+
|
| 209 |
# Define data collator
|
| 210 |
if args.task == "generation":
|
| 211 |
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
| 212 |
elif args.task == "classification":
|
| 213 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # Dynamic padding during batch creation
|
| 214 |
else:
|
| 215 |
logging.error("Unsupported task type for data collator.")
|
| 216 |
raise ValueError("Unsupported task type for data collator.")
|
| 217 |
+
|
| 218 |
+
# Define training arguments
|
| 219 |
+
training_args = TrainingArguments(
|
| 220 |
+
output_dir=f"./models/{args.model_name}",
|
| 221 |
+
num_train_epochs=3,
|
| 222 |
+
per_device_train_batch_size=8 if args.task == "generation" else 16,
|
| 223 |
+
save_steps=5000,
|
| 224 |
+
save_total_limit=2,
|
| 225 |
+
logging_steps=500,
|
| 226 |
+
learning_rate=5e-4 if args.task == "generation" else 5e-5,
|
| 227 |
+
remove_unused_columns=False,
|
| 228 |
+
push_to_hub=False
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
# Initialize Trainer with the data collator
|
| 232 |
trainer = Trainer(
|
| 233 |
model=model,
|
| 234 |
args=training_args,
|
| 235 |
train_dataset=tokenized_datasets,
|
| 236 |
data_collator=data_collator,
|
| 237 |
+
optimizers=(get_optimizer(model, training_args.learning_rate), None)
|
| 238 |
)
|
| 239 |
|
|
|
|
| 240 |
# Start training
|
| 241 |
logging.info("Starting training...")
|
| 242 |
try:
|
|
|
|
| 276 |
|
| 277 |
if __name__ == "__main__":
|
| 278 |
main()
|
|
|
|
|
|