Spaces:
Running
Running
Refactor fine-tuning process to include batch size and learning rate parameters in finetune_from_chat_history function
Browse files- src/training/fine_tuner.py +10 -2
- web/training_interface.py +3 -6
src/training/fine_tuner.py
CHANGED
|
@@ -386,12 +386,16 @@ def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_ra
|
|
| 386 |
except Exception as e:
|
| 387 |
return False, f"Error uploading model to Hub: {str(e)}"
|
| 388 |
|
| 389 |
-
def finetune_from_chat_history(epochs: int = 3
|
|
|
|
|
|
|
| 390 |
"""
|
| 391 |
Function to start fine-tuning process based on chat history
|
| 392 |
|
| 393 |
Args:
|
| 394 |
epochs: Number of training epochs
|
|
|
|
|
|
|
| 395 |
|
| 396 |
Returns:
|
| 397 |
(success, message)
|
|
@@ -406,7 +410,11 @@ def finetune_from_chat_history(epochs: int = 3) -> Tuple[bool, str]:
|
|
| 406 |
|
| 407 |
# Create and start fine-tuning process
|
| 408 |
tuner = FineTuner()
|
| 409 |
-
success, message = tuner.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
return success, message
|
| 412 |
|
|
|
|
| 386 |
except Exception as e:
|
| 387 |
return False, f"Error uploading model to Hub: {str(e)}"
|
| 388 |
|
| 389 |
+
def finetune_from_chat_history(epochs: int = 3,
|
| 390 |
+
batch_size: int = 4,
|
| 391 |
+
learning_rate: float = 2e-4) -> Tuple[bool, str]:
|
| 392 |
"""
|
| 393 |
Function to start fine-tuning process based on chat history
|
| 394 |
|
| 395 |
Args:
|
| 396 |
epochs: Number of training epochs
|
| 397 |
+
batch_size: Training batch size
|
| 398 |
+
learning_rate: Learning rate
|
| 399 |
|
| 400 |
Returns:
|
| 401 |
(success, message)
|
|
|
|
| 410 |
|
| 411 |
# Create and start fine-tuning process
|
| 412 |
tuner = FineTuner()
|
| 413 |
+
success, message = tuner.prepare_and_train(
|
| 414 |
+
num_train_epochs=epochs,
|
| 415 |
+
per_device_train_batch_size=batch_size,
|
| 416 |
+
learning_rate=learning_rate
|
| 417 |
+
)
|
| 418 |
|
| 419 |
return success, message
|
| 420 |
|
web/training_interface.py
CHANGED
|
@@ -169,13 +169,10 @@ def delete_model_action(model_row_index, models_df):
|
|
| 169 |
def start_finetune_action(epochs, batch_size, learning_rate):
|
| 170 |
"""Start model fine-tuning"""
|
| 171 |
try:
|
| 172 |
-
from src.training.fine_tuner import
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
num_train_epochs=epochs,
|
| 177 |
-
per_device_train_batch_size=batch_size,
|
| 178 |
-
learning_rate=learning_rate
|
| 179 |
)
|
| 180 |
|
| 181 |
return f"Training {'completed' if success else 'failed'}: {message}"
|
|
|
|
| 169 |
def start_finetune_action(epochs, batch_size, learning_rate):
|
| 170 |
"""Start model fine-tuning"""
|
| 171 |
try:
|
| 172 |
+
from src.training.fine_tuner import finetune_from_chat_history
|
| 173 |
|
| 174 |
+
success, message = finetune_from_chat_history(
|
| 175 |
+
epochs=epochs
|
|
|
|
|
|
|
|
|
|
| 176 |
)
|
| 177 |
|
| 178 |
return f"Training {'completed' if success else 'failed'}: {message}"
|