Spaces:
Running
Running
Enhance finetune_from_chat_history function to improve chat history loading, QA pair extraction, and add temporary file cleanup for training data
Browse files- src/training/fine_tuner.py +26 -4
src/training/fine_tuner.py
CHANGED
|
@@ -402,20 +402,42 @@ def finetune_from_chat_history(epochs: int = 3,
|
|
| 402 |
"""
|
| 403 |
# Analyze chats and prepare data
|
| 404 |
analyzer = ChatAnalyzer()
|
| 405 |
-
report = analyzer.
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
# Create and start fine-tuning process
|
| 412 |
tuner = FineTuner()
|
| 413 |
success, message = tuner.prepare_and_train(
|
|
|
|
| 414 |
num_train_epochs=epochs,
|
| 415 |
per_device_train_batch_size=batch_size,
|
| 416 |
learning_rate=learning_rate
|
| 417 |
)
|
| 418 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
return success, message
|
| 420 |
|
| 421 |
if __name__ == "__main__":
|
|
|
|
| 402 |
"""
|
| 403 |
# Analyze chats and prepare data
|
| 404 |
analyzer = ChatAnalyzer()
|
| 405 |
+
report = analyzer.analyze_chats()
|
| 406 |
|
| 407 |
+
if not report or "Failed to load chat history" in report:
|
| 408 |
+
return False, "Failed to load chat history for training"
|
| 409 |
+
|
| 410 |
+
# Extract QA pairs for training
|
| 411 |
+
qa_pairs = analyzer.extract_question_answer_pairs()
|
| 412 |
+
|
| 413 |
+
if len(qa_pairs) < 10:
|
| 414 |
+
return False, f"Insufficient data for fine-tuning. Only {len(qa_pairs)} QA pairs found."
|
| 415 |
+
|
| 416 |
+
# Create temporary file for training data
|
| 417 |
+
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.jsonl') as f:
|
| 418 |
+
for pair in qa_pairs:
|
| 419 |
+
json.dump({
|
| 420 |
+
"messages": [
|
| 421 |
+
{"role": "user", "content": pair["question"]},
|
| 422 |
+
{"role": "assistant", "content": pair["answer"]}
|
| 423 |
+
]
|
| 424 |
+
}, f, ensure_ascii=False)
|
| 425 |
+
f.write('\n')
|
| 426 |
+
training_data_path = f.name
|
| 427 |
|
| 428 |
# Create and start fine-tuning process
|
| 429 |
tuner = FineTuner()
|
| 430 |
success, message = tuner.prepare_and_train(
|
| 431 |
+
training_data_path=training_data_path,
|
| 432 |
num_train_epochs=epochs,
|
| 433 |
per_device_train_batch_size=batch_size,
|
| 434 |
learning_rate=learning_rate
|
| 435 |
)
|
| 436 |
|
| 437 |
+
# Cleanup
|
| 438 |
+
if os.path.exists(training_data_path):
|
| 439 |
+
os.remove(training_data_path)
|
| 440 |
+
|
| 441 |
return success, message
|
| 442 |
|
| 443 |
if __name__ == "__main__":
|