Rulga committed on
Commit
6e1f5f4
·
1 Parent(s): 72f65c8

Refactor fine-tuning process to include batch size and learning rate parameters in finetune_from_chat_history function

Browse files
src/training/fine_tuner.py CHANGED
@@ -386,12 +386,16 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
386
  except Exception as e:
387
  return False, f"Error uploading model to Hub: {str(e)}"
388
 
389
- def finetune_from_chat_history(epochs: int = 3) -> Tuple[bool, str]:
 
 
390
  """
391
  Function to start fine-tuning process based on chat history
392
 
393
  Args:
394
  epochs: Number of training epochs
 
 
395
 
396
  Returns:
397
  (success, message)
@@ -406,7 +410,11 @@ def finetune_from_chat_history(epochs: int = 3) -> Tuple[bool, str]:
406
 
407
  # Create and start fine-tuning process
408
  tuner = FineTuner()
409
- success, message = tuner.train(num_train_epochs=epochs)
 
 
 
 
410
 
411
  return success, message
412
 
 
386
  except Exception as e:
387
  return False, f"Error uploading model to Hub: {str(e)}"
388
 
389
+ def finetune_from_chat_history(epochs: int = 3,
390
+ batch_size: int = 4,
391
+ learning_rate: float = 2e-4) -> Tuple[bool, str]:
392
  """
393
  Function to start fine-tuning process based on chat history
394
 
395
  Args:
396
  epochs: Number of training epochs
397
+ batch_size: Training batch size
398
+ learning_rate: Learning rate
399
 
400
  Returns:
401
  (success, message)
 
410
 
411
  # Create and start fine-tuning process
412
  tuner = FineTuner()
413
+ success, message = tuner.prepare_and_train(
414
+ num_train_epochs=epochs,
415
+ per_device_train_batch_size=batch_size,
416
+ learning_rate=learning_rate
417
+ )
418
 
419
  return success, message
420
 
web/training_interface.py CHANGED
@@ -169,13 +169,10 @@ def delete_model_action(model_row_index, models_df):
169
  def start_finetune_action(epochs, batch_size, learning_rate):
170
  """Start model fine-tuning"""
171
  try:
172
- from src.training.fine_tuner import FineTuner
173
 
174
- tuner = FineTuner()
175
- success, message = tuner.train(
176
- num_train_epochs=epochs,
177
- per_device_train_batch_size=batch_size,
178
- learning_rate=learning_rate
179
  )
180
 
181
  return f"Training {'completed' if success else 'failed'}: {message}"
 
169
  def start_finetune_action(epochs, batch_size, learning_rate):
170
  """Start model fine-tuning"""
171
  try:
172
+ from src.training.fine_tuner import finetune_from_chat_history
173
 
174
+ success, message = finetune_from_chat_history(
175
+ epochs=epochs
 
 
 
176
  )
177
 
178
  return f"Training {'completed' if success else 'failed'}: {message}"