Rulga committed on
Commit
2984e21
·
1 Parent(s): 015ac8c

Refactor training method in FineTuner: make training_data_path a required parameter (no longer Optional with automatic data preparation), increase the gradient_accumulation_steps default from 4 to 8, and improve documentation for training parameters.

Browse files
Files changed (1) hide show
  1. src/training/fine_tuner.py +19 -26
src/training/fine_tuner.py CHANGED
@@ -276,35 +276,35 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
276
 
277
  def train(
278
  self,
279
- training_data_path: Optional[str] = None,
280
  num_train_epochs: int = 3,
281
  per_device_train_batch_size: int = 4,
282
- gradient_accumulation_steps: int = 4,
283
  learning_rate: float = 2e-4,
284
  logging_steps: int = 10,
285
  save_strategy: str = "epoch"
286
  ) -> Tuple[bool, str]:
287
  """
288
- Start model fine-tuning process
 
 
 
 
 
 
 
 
 
 
 
 
289
  """
290
  try:
291
- # Prepare training data if path not specified
292
- if training_data_path is None:
293
- training_data_path = self.prepare_training_data()
294
- temp_data = True
295
- else:
296
- temp_data = False
297
-
298
- # Load model and tokenizer if not loaded
299
- if self.model is None or self.tokenizer is None:
300
- self.load_model_and_tokenizer()
301
-
302
  # Prepare model for training
303
  self.prepare_model_for_training()
304
 
305
  # Load dataset
306
- dataset = load_dataset("json", data_files=training_data_path, split="train")
307
- logger.info(f"Loaded {len(dataset)} examples from {training_data_path}")
308
 
309
  # Tokenize dataset
310
  tokenized_dataset = self.tokenize_dataset(dataset)
@@ -343,22 +343,15 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
343
  )
344
 
345
  # Start training
346
- logger.info("Starting model training...")
347
  trainer.train()
348
 
349
  # Save model
350
- logger.info(f"Saving trained model to {self.output_dir}")
351
- trainer.save_model(self.output_dir)
352
- self.tokenizer.save_pretrained(self.output_dir)
353
-
354
- # Remove temporary file if created
355
- if temp_data and os.path.exists(training_data_path):
356
- os.remove(training_data_path)
357
 
358
  return True, f"Model successfully trained and saved to {self.output_dir}"
 
359
  except Exception as e:
360
- logger.error(f"Error during training: {str(e)}")
361
- return False, f"Error during training: {str(e)}"
362
 
363
  def upload_model_to_hub(
364
  self,
 
276
 
277
  def train(
278
  self,
279
+ training_data_path: str,
280
  num_train_epochs: int = 3,
281
  per_device_train_batch_size: int = 4,
282
+ gradient_accumulation_steps: int = 8,
283
  learning_rate: float = 2e-4,
284
  logging_steps: int = 10,
285
  save_strategy: str = "epoch"
286
  ) -> Tuple[bool, str]:
287
  """
288
+ Train the model using provided data
289
+
290
+ Args:
291
+ training_data_path: Path to training data file
292
+ num_train_epochs: Number of training epochs
293
+ per_device_train_batch_size: Batch size per device
294
+ gradient_accumulation_steps: Number of steps to accumulate gradients
295
+ learning_rate: Learning rate
296
+ logging_steps: Number of steps between logging
297
+ save_strategy: When to save checkpoints
298
+
299
+ Returns:
300
+ (success, message)
301
  """
302
  try:
 
 
 
 
 
 
 
 
 
 
 
303
  # Prepare model for training
304
  self.prepare_model_for_training()
305
 
306
  # Load dataset
307
+ dataset = load_dataset('json', data_files=training_data_path)['train']
 
308
 
309
  # Tokenize dataset
310
  tokenized_dataset = self.tokenize_dataset(dataset)
 
343
  )
344
 
345
  # Start training
 
346
  trainer.train()
347
 
348
  # Save model
349
+ trainer.save_model()
 
 
 
 
 
 
350
 
351
  return True, f"Model successfully trained and saved to {self.output_dir}"
352
+
353
  except Exception as e:
354
+ return False, f"Training failed: {str(e)}"
 
355
 
356
  def upload_model_to_hub(
357
  self,