Spaces:
Running
Running
Refactor training method in FineTuner: Update training_data_path to be required, enhance gradient accumulation steps, and improve documentation for training parameters.
Browse files — src/training/fine_tuner.py (+19 −26)
src/training/fine_tuner.py
CHANGED
|
@@ -276,35 +276,35 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
|
|
| 276 |
|
| 277 |
def train(
|
| 278 |
self,
|
| 279 |
-
training_data_path:
|
| 280 |
num_train_epochs: int = 3,
|
| 281 |
per_device_train_batch_size: int = 4,
|
| 282 |
-
gradient_accumulation_steps: int =
|
| 283 |
learning_rate: float = 2e-4,
|
| 284 |
logging_steps: int = 10,
|
| 285 |
save_strategy: str = "epoch"
|
| 286 |
) -> Tuple[bool, str]:
|
| 287 |
"""
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
"""
|
| 290 |
try:
|
| 291 |
-
# Prepare training data if path not specified
|
| 292 |
-
if training_data_path is None:
|
| 293 |
-
training_data_path = self.prepare_training_data()
|
| 294 |
-
temp_data = True
|
| 295 |
-
else:
|
| 296 |
-
temp_data = False
|
| 297 |
-
|
| 298 |
-
# Load model and tokenizer if not loaded
|
| 299 |
-
if self.model is None or self.tokenizer is None:
|
| 300 |
-
self.load_model_and_tokenizer()
|
| 301 |
-
|
| 302 |
# Prepare model for training
|
| 303 |
self.prepare_model_for_training()
|
| 304 |
|
| 305 |
# Load dataset
|
| 306 |
-
dataset = load_dataset(
|
| 307 |
-
logger.info(f"Loaded {len(dataset)} examples from {training_data_path}")
|
| 308 |
|
| 309 |
# Tokenize dataset
|
| 310 |
tokenized_dataset = self.tokenize_dataset(dataset)
|
|
@@ -343,22 +343,15 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
|
|
| 343 |
)
|
| 344 |
|
| 345 |
# Start training
|
| 346 |
-
logger.info("Starting model training...")
|
| 347 |
trainer.train()
|
| 348 |
|
| 349 |
# Save model
|
| 350 |
-
|
| 351 |
-
trainer.save_model(self.output_dir)
|
| 352 |
-
self.tokenizer.save_pretrained(self.output_dir)
|
| 353 |
-
|
| 354 |
-
# Remove temporary file if created
|
| 355 |
-
if temp_data and os.path.exists(training_data_path):
|
| 356 |
-
os.remove(training_data_path)
|
| 357 |
|
| 358 |
return True, f"Model successfully trained and saved to {self.output_dir}"
|
|
|
|
| 359 |
except Exception as e:
|
| 360 |
-
|
| 361 |
-
return False, f"Error during training: {str(e)}"
|
| 362 |
|
| 363 |
def upload_model_to_hub(
|
| 364 |
self,
|
|
|
|
| 276 |
|
| 277 |
def train(
|
| 278 |
self,
|
| 279 |
+
training_data_path: str,
|
| 280 |
num_train_epochs: int = 3,
|
| 281 |
per_device_train_batch_size: int = 4,
|
| 282 |
+
gradient_accumulation_steps: int = 8,
|
| 283 |
learning_rate: float = 2e-4,
|
| 284 |
logging_steps: int = 10,
|
| 285 |
save_strategy: str = "epoch"
|
| 286 |
) -> Tuple[bool, str]:
|
| 287 |
"""
|
| 288 |
+
Train the model using provided data
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
training_data_path: Path to training data file
|
| 292 |
+
num_train_epochs: Number of training epochs
|
| 293 |
+
per_device_train_batch_size: Batch size per device
|
| 294 |
+
gradient_accumulation_steps: Number of steps to accumulate gradients
|
| 295 |
+
learning_rate: Learning rate
|
| 296 |
+
logging_steps: Number of steps between logging
|
| 297 |
+
save_strategy: When to save checkpoints
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
(success, message)
|
| 301 |
"""
|
| 302 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
# Prepare model for training
|
| 304 |
self.prepare_model_for_training()
|
| 305 |
|
| 306 |
# Load dataset
|
| 307 |
+
dataset = load_dataset('json', data_files=training_data_path)['train']
|
|
|
|
| 308 |
|
| 309 |
# Tokenize dataset
|
| 310 |
tokenized_dataset = self.tokenize_dataset(dataset)
|
|
|
|
| 343 |
)
|
| 344 |
|
| 345 |
# Start training
|
|
|
|
| 346 |
trainer.train()
|
| 347 |
|
| 348 |
# Save model
|
| 349 |
+
trainer.save_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
return True, f"Model successfully trained and saved to {self.output_dir}"
|
| 352 |
+
|
| 353 |
except Exception as e:
|
| 354 |
+
return False, f"Training failed: {str(e)}"
|
|
|
|
| 355 |
|
| 356 |
def upload_model_to_hub(
|
| 357 |
self,
|