Upload folder using huggingface_hub
- .gitattributes +3 -0
- benchmark.py +1201 -0
- benchmark_report.txt +33 -0
- benchmark_report_2b-Copy1.txt +23 -0
- benchmark_report_2b.txt +23 -0
- benchmark_report_v2.txt +23 -0
- benchmark_report_wo_merging.txt +33 -0
- benchmark_results.png +3 -0
- benchmark_v1.py +803 -0
- benchmarking_v2.py +782 -0
- chat.py +339 -0
- data_preprocessor.py +524 -0
- finalmerged_model.zip +3 -0
- finetune_lfm.py +1311 -0
- finetune_lfm_complete_history.py +801 -0
- finetune_trl_supervised.py +221 -0
- merge_model.py +74 -0
- preprocess_kokoro_method.py +651 -0
- score_analysis_threshold_60.png +3 -0
- score_distribution.png +3 -0
- training_config.json +23 -0
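For context, a commit like this is normally produced with `huggingface_hub`'s `upload_folder`; a minimal sketch, where the repo id, folder path, and token handling are placeholders rather than values taken from this commit:

```python
from huggingface_hub import HfApi

api = HfApi()  # assumes a token is already configured, e.g. via `huggingface-cli login`
api.upload_folder(
    folder_path=".",                    # local directory containing the files listed above
    repo_id="your-username/your-repo",  # placeholder repository id
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```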
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+benchmark_results.png filter=lfs diff=lfs merge=lfs -text
+score_analysis_threshold_60.png filter=lfs diff=lfs merge=lfs -text
+score_distribution.png filter=lfs diff=lfs merge=lfs -text
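The three added rules route the new benchmark plots through Git LFS, alongside the archive and TensorBoard patterns already tracked. As a rough illustration of how these patterns match file names (an approximation that treats each rule as a plain glob; real `.gitattributes` resolution has more machinery):

```python
from fnmatch import fnmatch

lfs_patterns = [
    "*.zip", "*.zst", "*tfevents*",
    "benchmark_results.png",
    "score_analysis_threshold_60.png",
    "score_distribution.png",
]

def looks_lfs_tracked(name: str) -> bool:
    # Approximation: treat each .gitattributes pattern as a glob on the file name.
    return any(fnmatch(name, pattern) for pattern in lfs_patterns)

print(looks_lfs_tracked("benchmark_results.png"))  # True
print(looks_lfs_tracked("benchmark.py"))           # False
```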
benchmark.py ADDED
@@ -0,0 +1,1201 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json
from typing import Dict, List, Tuple
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
import evaluate
from datasets import load_dataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

class CounselorBenchmark:
    def __init__(self, base_model_path: str, finetuned_model_path: str):
        """
        Initialize benchmark suite for counselor models
        """
        self.base_model_path = base_model_path
        self.finetuned_model_path = finetuned_model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load evaluation metrics
        self.bleu = evaluate.load("sacrebleu")
        self.rouge = evaluate.load("rouge")
        self.bertscore = evaluate.load("bertscore")

    def load_models(self):
        """Load both base and fine-tuned models for comparison"""

        # Load base model
        print("Loading base model...")
        self.base_tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        # Load fine-tuned model
        print("Loading fine-tuned model...")
        self.ft_tokenizer = AutoTokenizer.from_pretrained(self.finetuned_model_path)
        self.ft_model = AutoModelForCausalLM.from_pretrained(
            self.finetuned_model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

    def generate_response(self, model, tokenizer, prompt: str, max_length: int = 256):
        """Generate response from model"""
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract only the generated part
        response = response[len(prompt):].strip()
        return response

    def evaluate_empathy_score(self, response: str) -> float:
        """
        Evaluate empathy in counselor response
        Custom metric based on Japanese counseling keywords
        """
        empathy_keywords = [
            'わかります', '理解', '共感', '気持ち', '感じ',
            'つらい', '大変', 'お察し', '心配', '支援'
        ]

        score = sum(1 for keyword in empathy_keywords if keyword in response)
        return min(score / len(empathy_keywords), 1.0)

    def evaluate_response_quality(self, response: str) -> Dict[str, float]:
        """
        Comprehensive response quality evaluation
        """
        metrics = {}

        # Length appropriateness (not too short, not too long)
        response_length = len(response)
        if 50 <= response_length <= 300:
            metrics['length_score'] = 1.0
        elif response_length < 50:
            metrics['length_score'] = response_length / 50
        else:
            metrics['length_score'] = max(0, 1 - (response_length - 300) / 500)

        # Question engagement (does counselor ask clarifying questions?)
        metrics['question_score'] = 1.0 if '?' in response or 'か?' in response else 0.0

        # Supportive language
        support_phrases = ['大丈夫', '一緒に', '支援', 'サポート', '助け']
        metrics['support_score'] = sum(1 for phrase in support_phrases if phrase in response) / len(support_phrases)

        # Empathy score
        metrics['empathy_score'] = self.evaluate_empathy_score(response)

        return metrics

    def benchmark_on_test_set(self, test_data_path: str, num_samples: int = 100):
        """
        Run comprehensive benchmark on test set
        """
        # Load test data
        test_dataset = load_dataset('json', data_files=test_data_path, split='train')
        test_samples = test_dataset.select(range(min(num_samples, len(test_dataset))))

        results = {
            'base_model': {'responses': [], 'metrics': []},
            'finetuned_model': {'responses': [], 'metrics': []}
        }

        print(f"Evaluating on {len(test_samples)} test samples...")

        for sample in tqdm(test_samples):
            prompt = sample['text'].split('### Response:')[0] + '### Response:'
            reference = sample['text'].split('### Response:')[1].strip() if '### Response:' in sample['text'] else ""

            # Generate responses
            base_response = self.generate_response(self.base_model, self.base_tokenizer, prompt)
            ft_response = self.generate_response(self.ft_model, self.ft_tokenizer, prompt)

            # Store responses
            results['base_model']['responses'].append(base_response)
            results['finetuned_model']['responses'].append(ft_response)

            # Evaluate quality
            base_metrics = self.evaluate_response_quality(base_response)
            ft_metrics = self.evaluate_response_quality(ft_response)

            results['base_model']['metrics'].append(base_metrics)
            results['finetuned_model']['metrics'].append(ft_metrics)

        return results

    def calculate_aggregate_metrics(self, results: Dict) -> Dict:
        """Calculate aggregate metrics for comparison"""
        aggregate = {}

        for model_name in ['base_model', 'finetuned_model']:
            model_metrics = results[model_name]['metrics']

            aggregate[model_name] = {}

            # Calculate average for each metric
            metric_names = model_metrics[0].keys() if model_metrics else []

            for metric in metric_names:
                values = [m[metric] for m in model_metrics]
                aggregate[model_name][metric] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }

        return aggregate

    def generate_comparison_report(self, results: Dict, aggregate: Dict):
        """Generate detailed comparison report"""

        report = []
        report.append("=" * 80)
        report.append("COUNSELOR MODEL BENCHMARK REPORT")
        report.append("=" * 80)
        report.append("")

        # Overall performance comparison
        report.append("PERFORMANCE COMPARISON:")
        report.append("-" * 40)

        for metric in aggregate['base_model'].keys():
            base_score = aggregate['base_model'][metric]['mean']
            ft_score = aggregate['finetuned_model'][metric]['mean']
            improvement = ((ft_score - base_score) / base_score * 100) if base_score > 0 else 0

            report.append(f"\n{metric.upper()}:")
            report.append(f"  Base Model: {base_score:.3f} (±{aggregate['base_model'][metric]['std']:.3f})")
            report.append(f"  Fine-tuned Model: {ft_score:.3f} (±{aggregate['finetuned_model'][metric]['std']:.3f})")
            report.append(f"  Improvement: {improvement:+.1f}%")

        # Calculate overall score
        base_overall = np.mean([aggregate['base_model'][m]['mean'] for m in aggregate['base_model']])
        ft_overall = np.mean([aggregate['finetuned_model'][m]['mean'] for m in aggregate['finetuned_model']])
        overall_improvement = ((ft_overall - base_overall) / base_overall * 100) if base_overall > 0 else 0

        report.append("\n" + "=" * 40)
        report.append("OVERALL PERFORMANCE:")
        report.append(f"  Base Model: {base_overall:.3f}")
        report.append(f"  Fine-tuned Model: {ft_overall:.3f}")
        report.append(f"  Overall Improvement: {overall_improvement:+.1f}%")
        report.append("=" * 40)

        return "\n".join(report)

    def visualize_results(self, aggregate: Dict):
        """Create visualization of benchmark results"""

        # Prepare data for plotting
        metrics = list(aggregate['base_model'].keys())
        base_scores = [aggregate['base_model'][m]['mean'] for m in metrics]
        ft_scores = [aggregate['finetuned_model'][m]['mean'] for m in metrics]

        # Create comparison plot
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Bar plot comparison
        x = np.arange(len(metrics))
        width = 0.35

        ax1.bar(x - width/2, base_scores, width, label='Base Model', color='lightblue')
        ax1.bar(x + width/2, ft_scores, width, label='Fine-tuned Model', color='darkblue')
        ax1.set_xlabel('Metrics')
        ax1.set_ylabel('Score')
        ax1.set_title('Model Performance Comparison')
        ax1.set_xticks(x)
        ax1.set_xticklabels(metrics, rotation=45, ha='right')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Improvement percentage plot
        improvements = [((ft - base) / base * 100) if base > 0 else 0
                        for base, ft in zip(base_scores, ft_scores)]

        colors = ['green' if imp > 0 else 'red' for imp in improvements]
        ax2.bar(metrics, improvements, color=colors, alpha=0.7)
        ax2.set_xlabel('Metrics')
        ax2.set_ylabel('Improvement (%)')
        ax2.set_title('Fine-tuning Improvement over Base Model')
        ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
        ax2.set_xticklabels(metrics, rotation=45, ha='right')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('benchmark_results.png', dpi=300, bbox_inches='tight')
        plt.show()

        print("Visualization saved as 'benchmark_results.png'")

# Run benchmarking
if __name__ == "__main__":
    # Initialize benchmark
    benchmark = CounselorBenchmark(
        base_model_path="./models/LFM2-2.6B",
        finetuned_model_path="./merged_counselor_mode_2b"
    )

    # Load models
    benchmark.load_models()

    # Run benchmark
    print("Running benchmark evaluation...")
    results = benchmark.benchmark_on_test_set("./processed_data_score80/test.jsonl", num_samples=100)

    # Calculate aggregate metrics
    aggregate = benchmark.calculate_aggregate_metrics(results)

    # Generate report
    report = benchmark.generate_comparison_report(results, aggregate)
    print(report)

    # Save report
    with open("benchmark_report_2b.txt", "w") as f:
        f.write(report)

    # Visualize results
    benchmark.visualize_results(aggregate)

    print("\nBenchmarking completed! Check 'benchmark_report.txt' for detailed results.")


####################

# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from peft import PeftModel, PeftConfig
# import numpy as np
# from typing import List, Dict, Tuple, Optional
# import json
# from tqdm import tqdm
# import os
# import gc
# import warnings
# from datetime import datetime
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
# from rouge_score import rouge_scorer
# import nltk
# from collections import defaultdict

# # Download required NLTK data
# try:
#     nltk.download('punkt', quiet=True)
# except:
#     pass

# warnings.filterwarnings('ignore')

# class AdvancedCounselorBenchmark:
#     def __init__(self,
#                  base_model_name: str = "LiquidAI/LFM2-1.2B",
#                  finetuned_model_path: str = "./counselor_model/best_model",
#                  merged_model_path: str = "./merged_counselor_model",
#                  test_data_path: str = "./processed_data_score70/test.jsonl",
#                  device: str = None):
#         """
#         Initialize advanced benchmark suite with BLEU and ROUGE metrics
#
#         Args:
#             base_model_name: Name/path of base model
#             finetuned_model_path: Path to fine-tuned LoRA adapter
#             merged_model_path: Path to save/load merged model
#             test_data_path: Path to test dataset with reference responses
#             device: Device to run on (cuda/cpu)
#         """
#         self.base_model_name = base_model_name
#         self.finetuned_model_path = finetuned_model_path
#         self.merged_model_path = merged_model_path
#         self.test_data_path = test_data_path
#         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
#
#         print(f"🔧 Initializing Advanced Benchmark Suite")
#         print(f"   Device: {self.device}")
#         if self.device == "cuda":
#             print(f"   GPU: {torch.cuda.get_device_name(0)}")
#             print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
#
#         # Initialize ROUGE scorer
#         self.rouge_scorer = rouge_scorer.RougeScorer(
#             ['rouge1', 'rouge2', 'rougeL'],
#             use_stemmer=False,  # Set to False for Japanese
#             lang='japanese'
#         )
#
#         # Smoothing function for BLEU scores
#         self.smoothing = SmoothingFunction().method1
#
#         self.results = {}

#     def load_test_data(self) -> List[Dict]:
#         """Load test dataset with reference responses"""
#         print(f"\n📚 Loading test data from {self.test_data_path}")
#
#         test_data = []
#         if os.path.exists(self.test_data_path):
#             with open(self.test_data_path, 'r', encoding='utf-8') as f:
#                 for line in f:
#                     data = json.loads(line)
#                     test_data.append(data)
#             print(f"   Loaded {len(test_data)} test examples")
#         else:
#             print(f"⚠️ Test data not found. Creating synthetic test data...")
#             test_data = self.create_synthetic_test_data()
#
#         return test_data

#     def create_synthetic_test_data(self) -> List[Dict]:
#         """Create synthetic test data if real data is not available"""
#         synthetic_data = [
#             {
#                 "text": "### Input:\n最近ストレスを感じています。\n\n### Response:\nストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが多いですか?お話を聞かせていただければ、一緒に対処法を考えることができます。",
#                 "input": "最近ストレスを感じています。",
#                 "reference": "ストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが多いですか?お話を聞かせていただければ、一緒に対処法を考えることができます。"
#             },
#             {
#                 "text": "### Input:\n仕事がうまくいかなくて悩んでいます。\n\n### Response:\n仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?一緒に整理してみましょう。",
#                 "input": "仕事がうまくいかなくて悩んでいます。",
#                 "reference": "仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?一緒に整理してみましょう。"
#             },
#             {
#                 "text": "### Input:\n人間関係で困っています。\n\n### Response:\n人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?職場、家族、友人関係など、もう少し詳しくお聞かせいただけますか?",
#                 "input": "人間関係で困っています。",
#                 "reference": "人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?職場、家族、友人関係など、もう少し詳しくお聞かせいただけますか?"
#             },
#             {
#                 "text": "### Input:\n将来が不安です。\n\n### Response:\n将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。特にどのような点について不安を感じていらっしゃいますか?",
#                 "input": "将来が不安です。",
#                 "reference": "将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。特にどのような点について不安を感じていらっしゃいますか?"
#             },
#             {
#                 "text": "### Input:\n自信が持てません。\n\n### Response:\n自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。どのような場面で特に自信が持てないと感じますか?あなたの強みも一緒に見つけていきましょう。",
#                 "input": "自信が持てません。",
#                 "reference": "自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。どのような場面で特に自信が持てないと感じますか?あなたの強みも一緒に見つけていきましょう。"
#             }
#         ]
#         return synthetic_data

#     def merge_and_save_model(self, force_merge: bool = False):
#         """Merge LoRA weights with base model and save"""
#         if os.path.exists(self.merged_model_path) and not force_merge:
#             print(f"✅ Merged model already exists at {self.merged_model_path}")
#             return
#
#         print("\n🔄 Merging LoRA adapter with base model...")
#
#         try:
#             # Load base model
#             print("   Loading base model...")
#             base_model = AutoModelForCausalLM.from_pretrained(
#                 self.base_model_name,
#                 torch_dtype=torch.float16,
#                 device_map="auto" if self.device == "cuda" else None,
#                 trust_remote_code=True,
#                 low_cpu_mem_usage=True
#             )
#
#             # Check if adapter exists
#             adapter_config_path = os.path.join(self.finetuned_model_path, "adapter_config.json")
#             if not os.path.exists(adapter_config_path):
#                 print(f"⚠️ No LoRA adapter found at {self.finetuned_model_path}")
#                 model = base_model
#             else:
#                 # Load LoRA adapter
#                 print("   Loading LoRA adapter...")
#                 model = PeftModel.from_pretrained(
#                     base_model,
#                     self.finetuned_model_path,
#                     torch_dtype=torch.float16
#                 )
#
#                 # Merge weights
#                 print("   Merging weights...")
#                 model = model.merge_and_unload()
#
#             # Save merged model
#             print(f"   Saving merged model to {self.merged_model_path}...")
#             model.save_pretrained(self.merged_model_path)
#
#             # Save tokenizer
#             tokenizer = AutoTokenizer.from_pretrained(
#                 self.finetuned_model_path
#                 if os.path.exists(os.path.join(self.finetuned_model_path, "tokenizer_config.json"))
#                 else self.base_model_name
#             )
#             tokenizer.save_pretrained(self.merged_model_path)
#
#             print("✅ Model merged and saved successfully!")
#
#             # Clean up memory
#             del base_model, model
#             gc.collect()
#             torch.cuda.empty_cache()
#
#         except Exception as e:
#             print(f"❌ Error during merging: {e}")
#             raise

#     def load_models(self):
#         """Load base and fine-tuned models for comparison"""
#         print("\n📚 Loading models for benchmarking...")
#
#         # Load tokenizer
#         self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
#         if self.tokenizer.pad_token is None:
#             self.tokenizer.pad_token = self.tokenizer.eos_token
#
#         # Load base model
#         print("   Loading base model...")
#         self.base_model = AutoModelForCausalLM.from_pretrained(
#             self.base_model_name,
#             torch_dtype=torch.float16,
#             device_map="auto" if self.device == "cuda" else None,
#             trust_remote_code=True,
#             low_cpu_mem_usage=True
#         )
#         self.base_model.eval()
#
#         # Load merged fine-tuned model
#         if os.path.exists(self.merged_model_path):
#             print("   Loading merged fine-tuned model...")
#             self.finetuned_model = AutoModelForCausalLM.from_pretrained(
#                 self.merged_model_path,
#                 torch_dtype=torch.float16,
#                 device_map="auto" if self.device == "cuda" else None,
#                 trust_remote_code=True,
#                 low_cpu_mem_usage=True
#             )
#         else:
#             print("   Loading fine-tuned model (attempting PEFT)...")
#             try:
#                 base_for_peft = AutoModelForCausalLM.from_pretrained(
#                     self.base_model_name,
#                     torch_dtype=torch.float16,
#                     device_map="auto" if self.device == "cuda" else None,
#                     trust_remote_code=True,
#                     low_cpu_mem_usage=True
#                 )
#                 self.finetuned_model = PeftModel.from_pretrained(
#                     base_for_peft,
#                     self.finetuned_model_path,
#                     torch_dtype=torch.float16
#                 )
#             except:
#                 self.finetuned_model = AutoModelForCausalLM.from_pretrained(
#                     self.finetuned_model_path,
#                     torch_dtype=torch.float16,
#                     device_map="auto" if self.device == "cuda" else None,
#                     trust_remote_code=True,
#                     low_cpu_mem_usage=True
#                 )
#
#         self.finetuned_model.eval()
#         print("✅ Models loaded successfully!")

#     def generate_response(self, model, prompt: str, max_length: int = 150) -> str:
#         """Generate response from model"""
#         inputs = self.tokenizer(
#             prompt,
#             return_tensors="pt",
#             truncation=True,
#             max_length=512
#         )
#
#         if self.device == "cuda":
#             inputs = {k: v.cuda() for k, v in inputs.items()}
#
#         with torch.no_grad():
#             outputs = model.generate(
#                 **inputs,
#                 max_new_tokens=max_length,
#                 temperature=0.7,
#                 do_sample=True,
#                 top_p=0.9,
#                 pad_token_id=self.tokenizer.pad_token_id,
#                 eos_token_id=self.tokenizer.eos_token_id
#             )
#
#         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
#         # Extract only the generated response
#         if "### Response:" in response:
#             response = response.split("### Response:")[-1].strip()
#         elif "Response:" in response:
#             response = response.split("Response:")[-1].strip()
#         else:
#             # Remove the input prompt from response
#             response = response[len(prompt):].strip()
#
#         return response

#     def tokenize_japanese(self, text: str) -> List[str]:
#         """Tokenize Japanese text for BLEU calculation"""
#         # Simple character-based tokenization for Japanese
#         # In production, use MeCab or similar for better tokenization
#         import re
#
#         # Remove special characters and split
#         text = re.sub(r'[。、!?\n]', ' ', text)
#         tokens = text.strip().split()
#
#         # Character-level tokenization as fallback
#         if not tokens:
#             tokens = list(text.strip())
#
#         return tokens

#     def calculate_bleu_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
#         """Calculate BLEU-1, BLEU-2, BLEU-3, BLEU-4 scores"""
#         # Tokenize texts
#         ref_tokens = self.tokenize_japanese(reference)
#         hyp_tokens = self.tokenize_japanese(hypothesis)
#
#         # Calculate BLEU scores with different n-grams
#         scores = {}
#
#         # BLEU-1 (unigram)
#         scores['BLEU-1'] = sentence_bleu(
#             [ref_tokens], hyp_tokens,
#             weights=(1.0, 0, 0, 0),
#             smoothing_function=self.smoothing
#         )
#
#         # BLEU-2 (bigram)
#         scores['BLEU-2'] = sentence_bleu(
#             [ref_tokens], hyp_tokens,
#             weights=(0.5, 0.5, 0, 0),
#             smoothing_function=self.smoothing
#         )
#
#         # BLEU-3 (trigram)
#         scores['BLEU-3'] = sentence_bleu(
#             [ref_tokens], hyp_tokens,
#             weights=(0.33, 0.33, 0.34, 0),
#             smoothing_function=self.smoothing
#         )
#
#         # BLEU-4 (4-gram)
#         scores['BLEU-4'] = sentence_bleu(
#             [ref_tokens], hyp_tokens,
#             weights=(0.25, 0.25, 0.25, 0.25),
#             smoothing_function=self.smoothing
#         )
#
#         return scores

#     def calculate_rouge_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
#         """Calculate ROUGE-1, ROUGE-2, ROUGE-L scores"""
#         scores = self.rouge_scorer.score(reference, hypothesis)
#
#         return {
#             'ROUGE-1': scores['rouge1'].fmeasure,
#             'ROUGE-2': scores['rouge2'].fmeasure,
#             'ROUGE-L': scores['rougeL'].fmeasure
#         }

#     def run_bleu_rouge_benchmark(self, num_samples: int = None):
#         """Run comprehensive BLEU and ROUGE benchmark"""
#         print("\n" + "="*70)
#         print("🏃 RUNNING BLEU & ROUGE BENCHMARK")
#         print("="*70)
#
#         # Load test data
#         test_data = self.load_test_data()
#
#         if num_samples:
#             test_data = test_data[:num_samples]
#             print(f"   Using {num_samples} samples for benchmarking")
#
#         # Initialize score collectors
#         base_scores = defaultdict(list)
#         finetuned_scores = defaultdict(list)
#
#         # Metrics to calculate
#         metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4',
#                    'ROUGE-1', 'ROUGE-2', 'ROUGE-L']
#
#         print(f"\n📊 Evaluating {len(test_data)} test examples...")
#         print("-" * 70)
#
#         detailed_results = []
#
#         for i, example in enumerate(tqdm(test_data, desc="Evaluating")):
#             # Extract input and reference
#             if 'input' in example:
#                 input_text = example['input']
#             else:
#                 # Try to extract from text field
#                 if "### Input:" in example['text']:
#                     input_text = example['text'].split("### Input:")[1].split("### Response:")[0].strip()
#                 else:
#                     input_text = example['text'].split("\n")[0].strip()
#
#             if 'reference' in example:
#                 reference = example['reference']
#             else:
#                 # Try to extract from text field
#                 if "### Response:" in example['text']:
#                     reference = example['text'].split("### Response:")[1].strip()
#                 else:
#                     parts = example['text'].split("\n")
#                     reference = parts[1] if len(parts) > 1 else parts[0]
#
#             # Format input for models
#             formatted_input = f"### Instruction:\nあなたは思いやりのある心理カウンセラーです。\n\n### Input:\n{input_text}\n\n### Response:\n"
#
#             # Generate responses
#             base_response = self.generate_response(self.base_model, formatted_input)
#             finetuned_response = self.generate_response(self.finetuned_model, formatted_input)
#
#             # Calculate BLEU scores
#             base_bleu = self.calculate_bleu_scores(reference, base_response)
#             finetuned_bleu = self.calculate_bleu_scores(reference, finetuned_response)
#
#             # Calculate ROUGE scores
#             base_rouge = self.calculate_rouge_scores(reference, base_response)
#             finetuned_rouge = self.calculate_rouge_scores(reference, finetuned_response)
#
#             # Combine scores
#             base_all_scores = {**base_bleu, **base_rouge}
#             finetuned_all_scores = {**finetuned_bleu, **finetuned_rouge}
#
#             # Collect scores
#             for metric in metrics:
#                 base_scores[metric].append(base_all_scores[metric])
#                 finetuned_scores[metric].append(finetuned_all_scores[metric])
#
#             # Store detailed results
#             detailed_results.append({
#                 'input': input_text,
#                 'reference': reference,
#                 'base_response': base_response,
#                 'finetuned_response': finetuned_response,
#                 'base_scores': base_all_scores,
#                 'finetuned_scores': finetuned_all_scores
#             })
#
#             # Print sample results
#             if i < 3:  # Show first 3 examples
#                 print(f"\n📝 Example {i+1}:")
#                 print(f"   Input: {input_text[:50]}...")
#                 print(f"   Reference: {reference[:50]}...")
#                 print(f"   Base response: {base_response[:50]}...")
#                 print(f"   Fine-tuned response: {finetuned_response[:50]}...")
#                 print(f"   Base BLEU-4: {base_bleu['BLEU-4']:.3f}")
#                 print(f"   Fine-tuned BLEU-4: {finetuned_bleu['BLEU-4']:.3f}")
#
#         # Calculate aggregate statistics
#         print("\n" + "="*70)
#         print("📈 BENCHMARK RESULTS")
#         print("="*70)
#
#         self.results = {
#             'detailed_results': detailed_results,
#             'aggregate_scores': {},
#             'improvements': {}
#         }
#
#         # Print and store results
#         print("\n" + "-"*70)
#         print(f"{'Metric':<12} {'Base Model':<20} {'Fine-tuned Model':<20} {'Improvement':<15}")
#         print("-"*70)
#
#         for metric in metrics:
#             base_mean = np.mean(base_scores[metric])
#             base_std = np.std(base_scores[metric])
#             finetuned_mean = np.mean(finetuned_scores[metric])
#             finetuned_std = np.std(finetuned_scores[metric])
#
#             # Calculate improvement
#             if base_mean > 0:
#                 improvement = ((finetuned_mean - base_mean) / base_mean) * 100
#             else:
#                 improvement = 0
#
#             # Store results
#             self.results['aggregate_scores'][metric] = {
#                 'base_mean': base_mean,
#                 'base_std': base_std,
#                 'finetuned_mean': finetuned_mean,
#                 'finetuned_std': finetuned_std
#             }
#             self.results['improvements'][metric] = improvement
#
#             # Print results
#             base_str = f"{base_mean:.3f} (±{base_std:.3f})"
#             finetuned_str = f"{finetuned_mean:.3f} (±{finetuned_std:.3f})"
#             imp_str = f"{improvement:+.1f}%"
#
#             # Color code improvement
#             if improvement > 0:
#                 imp_str = f"✅ {imp_str}"
#             elif improvement < 0:
#                 imp_str = f"⚠️ {imp_str}"
#             else:
#                 imp_str = f"➖ {imp_str}"
#
#             print(f"{metric:<12} {base_str:<20} {finetuned_str:<20} {imp_str:<15}")
#
#         # Calculate overall scores
#         print("\n" + "="*70)
#         print("🎯 OVERALL PERFORMANCE")
#         print("="*70)
#
#         # Average BLEU score
#         bleu_metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']
#         base_bleu_avg = np.mean([np.mean(base_scores[m]) for m in bleu_metrics])
#         finetuned_bleu_avg = np.mean([np.mean(finetuned_scores[m]) for m in bleu_metrics])
#         bleu_improvement = ((finetuned_bleu_avg - base_bleu_avg) / base_bleu_avg) * 100 if base_bleu_avg > 0 else 0
#
#         # Average ROUGE score
#         rouge_metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']
#         base_rouge_avg = np.mean([np.mean(base_scores[m]) for m in rouge_metrics])
#         finetuned_rouge_avg = np.mean([np.mean(finetuned_scores[m]) for m in rouge_metrics])
#         rouge_improvement = ((finetuned_rouge_avg - base_rouge_avg) / base_rouge_avg) * 100 if base_rouge_avg > 0 else 0
#
#         # Overall average
#         base_overall = np.mean([np.mean(base_scores[m]) for m in metrics])
#         finetuned_overall = np.mean([np.mean(finetuned_scores[m]) for m in metrics])
#         overall_improvement = ((finetuned_overall - base_overall) / base_overall) * 100 if base_overall > 0 else 0
#
#         self.results['summary'] = {
#             'bleu_average': {
#                 'base': base_bleu_avg,
#                 'finetuned': finetuned_bleu_avg,
#                 'improvement': bleu_improvement
#             },
#             'rouge_average': {
#                 'base': base_rouge_avg,
#                 'finetuned': finetuned_rouge_avg,
#                 'improvement': rouge_improvement
#             },
#             'overall': {
#                 'base': base_overall,
#                 'finetuned': finetuned_overall,
#                 'improvement': overall_improvement
#             }
#         }
#
#         print(f"\n📊 Average BLEU Score:")
#         print(f"   Base Model: {base_bleu_avg:.3f}")
#         print(f"   Fine-tuned Model: {finetuned_bleu_avg:.3f}")
#         print(f"   Improvement: {bleu_improvement:+.1f}%")
#
#         print(f"\n📊 Average ROUGE Score:")
#         print(f"   Base Model: {base_rouge_avg:.3f}")
#         print(f"   Fine-tuned Model: {finetuned_rouge_avg:.3f}")
#         print(f"   Improvement: {rouge_improvement:+.1f}%")
#
#         print(f"\n🎯 Overall Average:")
#         print(f"   Base Model: {base_overall:.3f}")
#         print(f"   Fine-tuned Model: {finetuned_overall:.3f}")
#         print(f"   Improvement: {overall_improvement:+.1f}%")
#
#         print("="*70)
#
#         return self.results

#     def visualize_results(self, save_path: str = "bleu_rouge_benchmark.png"):
#         """Create comprehensive visualization of BLEU and ROUGE results"""
#         if 'aggregate_scores' not in self.results:
#             print("❌ No results to visualize. Run benchmark first.")
#             return
#
#         print("\n📊 Creating visualizations...")
#
#         fig, axes = plt.subplots(2, 3, figsize=(18, 12))
#
#         # Color scheme
#         base_color = '#3498db'
#         finetuned_color = '#e74c3c'
#         improvement_positive = '#27ae60'
#         improvement_negative = '#c0392b'
#
#         # 1. BLEU Scores Comparison
#         bleu_metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']
#         bleu_base = [self.results['aggregate_scores'][m]['base_mean'] for m in bleu_metrics]
#         bleu_finetuned = [self.results['aggregate_scores'][m]['finetuned_mean'] for m in bleu_metrics]
#
#         x = np.arange(len(bleu_metrics))
#         width = 0.35
#
#         axes[0, 0].bar(x - width/2, bleu_base, width, label='Base Model',
#                        color=base_color, alpha=0.8)
#         axes[0, 0].bar(x + width/2, bleu_finetuned, width, label='Fine-tuned Model',
#                        color=finetuned_color, alpha=0.8)
#         axes[0, 0].set_xlabel('BLEU Metrics')
#         axes[0, 0].set_ylabel('Score')
#         axes[0, 0].set_title('BLEU Score Comparison')
#         axes[0, 0].set_xticks(x)
#         axes[0, 0].set_xticklabels(bleu_metrics)
#         axes[0, 0].legend()
#         axes[0, 0].grid(True, alpha=0.3)
#         axes[0, 0].set_ylim([0, max(max(bleu_base), max(bleu_finetuned)) * 1.2])
#
#         # 2. ROUGE Scores Comparison
#         rouge_metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']
#         rouge_base = [self.results['aggregate_scores'][m]['base_mean'] for m in rouge_metrics]
#         rouge_finetuned = [self.results['aggregate_scores'][m]['finetuned_mean'] for m in rouge_metrics]
#
#         x = np.arange(len(rouge_metrics))
#
#         axes[0, 1].bar(x - width/2, rouge_base, width, label='Base Model',
#                        color=base_color, alpha=0.8)
#         axes[0, 1].bar(x + width/2, rouge_finetuned, width, label='Fine-tuned Model',
#                        color=finetuned_color, alpha=0.8)
#         axes[0, 1].set_xlabel('ROUGE Metrics')
#         axes[0, 1].set_ylabel('Score')
#         axes[0, 1].set_title('ROUGE Score Comparison')
#         axes[0, 1].set_xticks(x)
#         axes[0, 1].set_xticklabels(rouge_metrics)
#         axes[0, 1].legend()
#         axes[0, 1].grid(True, alpha=0.3)
#         axes[0, 1].set_ylim([0, max(max(rouge_base), max(rouge_finetuned)) * 1.2])
#
#         # 3. Improvement Percentages
#         all_metrics = bleu_metrics + rouge_metrics
#         improvements = [self.results['improvements'][m] for m in all_metrics]
#         colors = [improvement_positive if imp > 0 else improvement_negative for imp in improvements]
#
#         axes[0, 2].barh(range(len(all_metrics)), improvements, color=colors, alpha=0.7)
#         axes[0, 2].set_yticks(range(len(all_metrics)))
#         axes[0, 2].set_yticklabels(all_metrics)
#         axes[0, 2].set_xlabel('Improvement (%)')
#         axes[0, 2].set_title('Performance Improvement by Metric')
#         axes[0, 2].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
#         axes[0, 2].grid(True, alpha=0.3, axis='x')
#
#         # 4. Line plot showing progression
#         axes[1, 0].plot(bleu_metrics, bleu_base, 'o-', label='Base Model',
#                         color=base_color, linewidth=2, markersize=8)
#         axes[1, 0].plot(bleu_metrics, bleu_finetuned, 's-', label='Fine-tuned Model',
#                         color=finetuned_color, linewidth=2, markersize=8)
#         axes[1, 0].set_xlabel('BLEU N-gram')
#         axes[1, 0].set_ylabel('Score')
#         axes[1, 0].set_title('BLEU Score Progression')
#         axes[1, 0].legend()
#         axes[1, 0].grid(True, alpha=0.3)
#
#         # 5. Summary Statistics
#         ax5 = axes[1, 1]
#         ax5.axis('off')
#
#         summary_text = f"""
#         BENCHMARK SUMMARY
#         {'='*30}
#
#         BLEU Average:
#           Base: {self.results['summary']['bleu_average']['base']:.3f}
#           Fine-tuned: {self.results['summary']['bleu_average']['finetuned']:.3f}
#           Improvement: {self.results['summary']['bleu_average']['improvement']:+.1f}%
#
#         ROUGE Average:
#           Base: {self.results['summary']['rouge_average']['base']:.3f}
#           Fine-tuned: {self.results['summary']['rouge_average']['finetuned']:.3f}
#           Improvement: {self.results['summary']['rouge_average']['improvement']:+.1f}%
#
#         Overall Performance:
#           Base: {self.results['summary']['overall']['base']:.3f}
#           Fine-tuned: {self.results['summary']['overall']['finetuned']:.3f}
#           Improvement: {self.results['summary']['overall']['improvement']:+.1f}%
#
#         Best Improvements:
#         """
#
#         # Find best improvements
#         sorted_metrics = sorted(all_metrics,
#                                 key=lambda m: self.results['improvements'][m],
#                                 reverse=True)
#
#         for m in sorted_metrics[:2]:
#             summary_text += f"  • {m}: {self.results['improvements'][m]:+.1f}%\n"
#
#         if any(self.results['improvements'][m] < 0 for m in all_metrics):
#             summary_text += f"\nNeeds Attention:\n"
#             for m in sorted_metrics[-2:]:
#                 if self.results['improvements'][m] < 0:
#                     summary_text += f"  • {m}: {self.results['improvements'][m]:+.1f}%\n"
#
#         ax5.text(0.1, 0.9, summary_text, transform=ax5.transAxes,
#                  fontsize=10, verticalalignment='top', fontfamily='monospace')
#
#         # 6. Heatmap of all scores
#         metrics_for_heatmap = all_metrics
#         models = ['Base', 'Fine-tuned']
#
#         heatmap_data = []
#         for metric in metrics_for_heatmap:
#             heatmap_data.append([
#                 self.results['aggregate_scores'][metric]['base_mean'],
#                 self.results['aggregate_scores'][metric]['finetuned_mean']
#             ])
#
#         im = axes[1, 2].imshow(heatmap_data, cmap='YlOrRd', aspect='auto')
#         axes[1, 2].set_xticks(np.arange(len(models)))
#         axes[1, 2].set_yticks(np.arange(len(metrics_for_heatmap)))
#         axes[1, 2].set_xticklabels(models)
#         axes[1, 2].set_yticklabels(metrics_for_heatmap)
#         axes[1, 2].set_title('Score Heatmap')
#
#         # Add text annotations
#         for i in range(len(metrics_for_heatmap)):
#             for j in range(len(models)):
#                 text = axes[1, 2].text(j, i, f'{heatmap_data[i][j]:.3f}',
#                                        ha="center", va="center", color="black", fontsize=8)
#
#         plt.colorbar(im, ax=axes[1, 2])
#
#         plt.suptitle('BLEU & ROUGE Benchmark Results', fontsize=16, fontweight='bold')
#         plt.tight_layout()
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"✅ Visualization saved to {save_path}")
#
#         plt.show()

#     def save_results(self, output_path: str = "bleu_rouge_results.json"):
#         """Save benchmark results to JSON"""
#         # Convert numpy types to Python native types for JSON serialization
#         def convert_to_native(obj):
#             if isinstance(obj, np.floating):
#                 return float(obj)
#             elif isinstance(obj, np.integer):
#                 return int(obj)
#             elif isinstance(obj, np.ndarray):
#                 return obj.tolist()
#             elif isinstance(obj, dict):
#                 return {k: convert_to_native(v) for k, v in obj.items()}
#             elif isinstance(obj, list):
#                 return [convert_to_native(item) for item in obj]
#             return obj
#
#         results_native = convert_to_native(self.results)
#
#         with open(output_path, 'w', encoding='utf-8') as f:
#             json.dump(results_native, f, ensure_ascii=False, indent=2)
#         print(f"✅ Results saved to {output_path}")

#     def generate_detailed_report(self, output_path: str = "bleu_rouge_report.md"):
#         """Generate detailed markdown report"""
#         if not self.results:
#             print("❌ No results to report. Run benchmark first.")
#             return
#
#         report = f"""# BLEU & ROUGE Benchmark Report
# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
#
# ## Executive Summary
#
# Comprehensive evaluation of the fine-tuned counseling model using BLEU and ROUGE metrics.
#
# ### Overall Performance
# - **Base Model Score**: {self.results['summary']['overall']['base']:.3f}
# - **Fine-tuned Model Score**: {self.results['summary']['overall']['finetuned']:.3f}
# - **Overall Improvement**: {self.results['summary']['overall']['improvement']:+.1f}%
#
# ## Detailed Metrics
#
# ### BLEU Scores
# | Metric | Base Model | Fine-tuned Model | Improvement |
# |--------|------------|------------------|-------------|
# """
#
#         for metric in ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']:
#             scores = self.results['aggregate_scores'][metric]
#             report += f"| {metric} | {scores['base_mean']:.3f} (±{scores['base_std']:.3f}) | "
#             report += f"{scores['finetuned_mean']:.3f} (±{scores['finetuned_std']:.3f}) | "
#             report += f"{self.results['improvements'][metric]:+.1f}% |\n"
#
#         report += f"""
#
# **BLEU Average**: {self.results['summary']['bleu_average']['improvement']:+.1f}% improvement
#
# ### ROUGE Scores
# | Metric | Base Model | Fine-tuned Model | Improvement |
# |--------|------------|------------------|-------------|
# """
#
#         for metric in ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']:
#             scores = self.results['aggregate_scores'][metric]
#             report += f"| {metric} | {scores['base_mean']:.3f} (±{scores['base_std']:.3f}) | "
#             report += f"{scores['finetuned_mean']:.3f} (±{scores['finetuned_std']:.3f}) | "
#             report += f"{self.results['improvements'][metric]:+.1f}% |\n"
#
#         report += f"""
#
# **ROUGE Average**: {self.results['summary']['rouge_average']['improvement']:+.1f}% improvement
#
# ## Sample Outputs
#
# """
#
#         # Add sample outputs
#         for i, result in enumerate(self.results['detailed_results'][:3]):
#             report += f"""### Example {i+1}
#
# **Input**: {result['input']}
#
# **Reference**: {result['reference'][:200]}...
#
# **Base Model Response**: {result['base_response'][:200]}...
#
# **Fine-tuned Model Response**: {result['finetuned_response'][:200]}...
#
# **Scores**:
# - Base BLEU-4: {result['base_scores']['BLEU-4']:.3f}, ROUGE-L: {result['base_scores']['ROUGE-L']:.3f}
# - Fine-tuned BLEU-4: {result['finetuned_scores']['BLEU-4']:.3f}, ROUGE-L: {result['finetuned_scores']['ROUGE-L']:.3f}
#
# ---
#
# """
#
#         report += """## Analysis & Recommendations
#
# """
#
#         overall_imp = self.results['summary']['overall']['improvement']
#
#         if overall_imp < -10:
#             report += """### ⚠️ Significant Performance Degradation
#
# The fine-tuned model shows significant degradation in BLEU/ROUGE scores. This indicates:
#
# 1. **Catastrophic Forgetting**: The model has lost its language generation capabilities
# 2. **Overfitting**: The model memorized training data instead of learning patterns
# 3. **Format Mismatch**: Training and inference formats may differ
#
# **Immediate Actions Required**:
# - ✅ Ensure proper model merging (LoRA weights with base model)
# - ✅ Reduce learning rate (try 1e-5 or 2e-5)
# - ✅ Use smaller LoRA rank (r=4 or r=8)
# - ✅ Mix general conversation data with counseling data (80/20 ratio)
# - ✅ Implement regularization (weight decay=0.1, dropout=0.1)
# - ✅ Use early stopping with patience=3
# """
#         elif overall_imp < 0:
#             report += """### ⚠️ Minor Performance Degradation
#
# The model shows slight degradation. Common causes:
#
# 1. **Aggressive Fine-tuning**: Parameters changed too much
# 2. **Limited Training Data**: Not enough diverse examples
# 3. **Domain Shift**: Counseling domain too different from base training
#
# **Recommended Actions**:
# - ✅ Fine-tune for fewer epochs (1-2 instead of 3)
# - ✅ Use gradient accumulation for larger effective batch size
# - ✅ Implement knowledge distillation from base model
# - ✅ Add more diverse training examples
# """
#         elif overall_imp < 10:
#             report += """### 📊 Modest Improvement
#
# The model shows small but positive improvements.
#
# **To Further Improve**:
# - ✅ Increase training data quality and quantity
# - ✅ Experiment with different generation parameters
# - ✅ Fine-tune on domain-specific pre-training
# - ✅ Use ensemble methods with base model
# """
#         else:
#             report += """### ✅ Significant Improvement
#
# Excellent results! The fine-tuned model shows substantial improvements.
#
# **Next Steps**:
# - ✅ Deploy for A/B testing with users
# - ✅ Monitor performance on edge cases
# - ✅ Consider model compression for deployment
# - ✅ Collect user feedback for iterative improvement
# """
#
#         with open(output_path, 'w', encoding='utf-8') as f:
#             f.write(report)
#
#         print(f"✅ Detailed report saved to {output_path}")

# # Main execution
# if __name__ == "__main__":
#     import argparse
#
#     parser = argparse.ArgumentParser(description='Advanced BLEU & ROUGE Benchmark')
#     parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-2.6B',
#                         help='Base model name')
#     parser.add_argument('--finetuned_path', type=str, default='./counselor_model/best_model',
#                         help='Path to fine-tuned model')
#     parser.add_argument('--merged_path', type=str, default='./merged_counselor_mode_2b',
#                         help='Path to save/load merged model')
#     parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl',
#                         help='Path to test data')
#     parser.add_argument('--num_samples', type=int, default=None,
#                         help='Number of samples to evaluate (None for all)')
#     parser.add_argument('--force_merge', action='store_true',
#                         help='Force re-merge even if merged model exists')
#     parser.add_argument('--skip_merge', action='store_true',
#                         help='Skip merging step')
+
# parser.add_argument('--output_dir', type=str, default='./benchmark_results',
|
| 1158 |
+
# help='Directory to save results')
|
| 1159 |
+
|
| 1160 |
+
# args = parser.parse_args()
|
| 1161 |
+
|
| 1162 |
+
# # Create output directory
|
| 1163 |
+
# os.makedirs(args.output_dir, exist_ok=True)
|
| 1164 |
+
|
| 1165 |
+
# try:
|
| 1166 |
+
# # Initialize benchmark
|
| 1167 |
+
# print("🚀 Initializing Advanced BLEU & ROUGE Benchmark")
|
| 1168 |
+
# benchmark = AdvancedCounselorBenchmark(
|
| 1169 |
+
# base_model_name=args.base_model,
|
| 1170 |
+
# finetuned_model_path=args.finetuned_path,
|
| 1171 |
+
# merged_model_path=args.merged_path,
|
| 1172 |
+
# test_data_path=args.test_data
|
| 1173 |
+
# )
|
| 1174 |
+
|
| 1175 |
+
# # Merge models if needed
|
| 1176 |
+
# if not args.skip_merge:
|
| 1177 |
+
# benchmark.merge_and_save_model(force_merge=args.force_merge)
|
| 1178 |
+
|
| 1179 |
+
# # Load models
|
| 1180 |
+
# benchmark.load_models()
|
| 1181 |
+
|
| 1182 |
+
# # Run BLEU & ROUGE benchmark
|
| 1183 |
+
# results = benchmark.run_bleu_rouge_benchmark(num_samples=args.num_samples)
|
| 1184 |
+
|
| 1185 |
+
# # Save results
|
| 1186 |
+
# benchmark.save_results(os.path.join(args.output_dir, "bleu_rouge_results_2b.json"))
|
| 1187 |
+
|
| 1188 |
+
# # Generate visualizations
|
| 1189 |
+
# benchmark.visualize_results(os.path.join(args.output_dir, "bleu_rouge_visualization_2b.png"))
|
| 1190 |
+
|
| 1191 |
+
# # Generate detailed report
|
| 1192 |
+
# benchmark.generate_detailed_report(os.path.join(args.output_dir, "bleu_rouge_report_2b.md"))
|
| 1193 |
+
|
| 1194 |
+
# print("\n✅ BLEU & ROUGE Benchmarking completed successfully!")
|
| 1195 |
+
# print(f"📁 Results saved to {args.output_dir}/")
|
| 1196 |
+
|
| 1197 |
+
# except Exception as e:
|
| 1198 |
+
# print(f"\n❌ Error during benchmarking: {e}")
|
| 1199 |
+
# import traceback
|
| 1200 |
+
# traceback.print_exc()
|
| 1201 |
+
|
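The recommendation block embedded in the report template above (reduce the learning rate to 1e-5 or 2e-5, drop the LoRA rank to 4 or 8, add weight decay and dropout of 0.1, stop early with patience 3, train for fewer epochs) maps onto a fairly standard PEFT/Transformers setup. Below is a minimal sketch of that configuration, assuming the peft and transformers APIs; the base model name is taken from the script's --base_model default, and target_modules is an illustrative assumption, not the repository's actual training script.

# Sketch only: regularized LoRA fine-tuning following the report's recommendations.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, TrainingArguments, EarlyStoppingCallback

model = AutoModelForCausalLM.from_pretrained("LiquidAI/LFM2-2.6B", trust_remote_code=True)

lora_config = LoraConfig(
    r=8,                                   # smaller rank (r=4 or r=8)
    lora_alpha=16,
    lora_dropout=0.1,                      # dropout regularization
    target_modules=["q_proj", "v_proj"],   # assumption; depends on the architecture
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="./counselor_model",
    learning_rate=2e-5,                    # reduced learning rate
    weight_decay=0.1,                      # weight decay regularization
    num_train_epochs=2,                    # fewer epochs (1-2 instead of 3)
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,           # required for early stopping
    metric_for_best_model="eval_loss",
)
# EarlyStoppingCallback(early_stopping_patience=3) would then be passed to the Trainer's callbacks.
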
benchmark_report.txt
ADDED
|
@@ -0,0 +1,33 @@
================================================================================
COUNSELOR MODEL BENCHMARK REPORT
================================================================================

PERFORMANCE COMPARISON:
----------------------------------------

LENGTH_SCORE:
Base Model: 0.809 (±0.153)
Fine-tuned Model: 0.840 (±0.174)
Improvement: +3.8%

QUESTION_SCORE:
Base Model: 0.660 (±0.474)
Fine-tuned Model: 0.850 (±0.357)
Improvement: +28.8%

SUPPORT_SCORE:
Base Model: 0.248 (±0.184)
Fine-tuned Model: 0.088 (±0.124)
Improvement: -64.5%

EMPATHY_SCORE:
Base Model: 0.262 (±0.086)
Fine-tuned Model: 0.152 (±0.114)
Improvement: -42.0%

========================================
OVERALL PERFORMANCE:
Base Model: 0.495
Fine-tuned Model: 0.483
Overall Improvement: -2.5%
========================================

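For reference, the percentages and overall scores in these reports are consistent with a simple relative-change formula over per-metric means (the same formula appears in benchmark_v1.py further below), with the overall score being the plain average of the metric means. A small sketch using the values from benchmark_report.txt above:

# How the report numbers fit together (values copied from benchmark_report.txt).
base      = {"length": 0.809, "question": 0.660, "support": 0.248, "empathy": 0.262}
finetuned = {"length": 0.840, "question": 0.850, "support": 0.088, "empathy": 0.152}

improvement = {k: (finetuned[k] - base[k]) / base[k] * 100 for k in base}
# {'length': +3.8, 'question': +28.8, 'support': -64.5, 'empathy': -42.0} (rounded)

overall_base      = sum(base.values()) / len(base)            # ≈ 0.495
overall_finetuned = sum(finetuned.values()) / len(finetuned)  # ≈ 0.483
overall_improvement = (overall_finetuned - overall_base) / overall_base * 100  # ≈ -2.5%
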
benchmark_report_2b-Copy1.txt
ADDED
|
@@ -0,0 +1,23 @@
================================================================================
COUNSELOR MODEL BENCHMARK REPORT
================================================================================

PERFORMANCE COMPARISON:
----------------------------------------

LENGTH_SCORE:
Base Model: 0.876 (±0.138)
Fine-tuned Model: 0.956 (±0.135)
Improvement: +9.2%

QUESTION_SCORE:
Base Model: 0.670 (±0.470)
Fine-tuned Model: 0.900 (±0.300)
Improvement: +34.3%

========================================
OVERALL PERFORMANCE:
Base Model: 0.773
Fine-tuned Model: 0.928
Overall Improvement: +20.1%
========================================

benchmark_report_2b.txt
ADDED
|
@@ -0,0 +1,23 @@
================================================================================
COUNSELOR MODEL BENCHMARK REPORT
================================================================================

PERFORMANCE COMPARISON:
----------------------------------------

LENGTH_SCORE:
Base Model: 0.876 (±0.138)
Fine-tuned Model: 0.956 (±0.135)
Improvement: +9.2%

QUESTION_SCORE:
Base Model: 0.670 (±0.470)
Fine-tuned Model: 0.900 (±0.300)
Improvement: +34.3%

========================================
OVERALL PERFORMANCE:
Base Model: 0.773
Fine-tuned Model: 0.928
Overall Improvement: +20.1%
========================================

benchmark_report_v2.txt
ADDED
|
@@ -0,0 +1,23 @@
================================================================================
COUNSELOR MODEL BENCHMARK REPORT
================================================================================

PERFORMANCE COMPARISON:
----------------------------------------

LENGTH_SCORE:
Base Model: 0.785 (±0.146)
Fine-tuned Model: 0.822 (±0.189)
Improvement: +4.8%

QUESTION_SCORE:
Base Model: 0.680 (±0.466)
Fine-tuned Model: 0.870 (±0.336)
Improvement: +27.9%

========================================
OVERALL PERFORMANCE:
Base Model: 0.732
Fine-tuned Model: 0.846
Overall Improvement: +15.5%
========================================

benchmark_report_wo_merging.txt
ADDED
|
@@ -0,0 +1,33 @@
================================================================================
COUNSELOR MODEL BENCHMARK REPORT
================================================================================

PERFORMANCE COMPARISON:
----------------------------------------

LENGTH_SCORE:
Base Model: 0.807 (±0.154)
Fine-tuned Model: 0.808 (±0.202)
Improvement: +0.1%

QUESTION_SCORE:
Base Model: 0.670 (±0.470)
Fine-tuned Model: 0.910 (±0.286)
Improvement: +35.8%

SUPPORT_SCORE:
Base Model: 0.236 (±0.186)
Fine-tuned Model: 0.082 (±0.120)
Improvement: -65.3%

EMPATHY_SCORE:
Base Model: 0.267 (±0.099)
Fine-tuned Model: 0.141 (±0.100)
Improvement: -47.2%

========================================
OVERALL PERFORMANCE:
Base Model: 0.495
Fine-tuned Model: 0.485
Overall Improvement: -2.0%
========================================

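The "_wo_merging" report above corresponds to evaluating without first folding the LoRA adapter into the base weights (the --skip_merge path in benchmark.py). Below is a minimal sketch of that merge step with PEFT, assuming the adapter and output paths from the script's argparse defaults; it is not necessarily how the repository itself performs the merge.

# Sketch only: merge a LoRA adapter into its base model so it can be loaded standalone.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "LiquidAI/LFM2-2.6B", torch_dtype=torch.float16, trust_remote_code=True
)
model = PeftModel.from_pretrained(base, "./counselor_model/best_model")  # --finetuned_path default
merged = model.merge_and_unload()        # folds the LoRA deltas into the base weights
merged.save_pretrained("./merged_counselor_mode_2b")                     # --merged_path default
AutoTokenizer.from_pretrained("LiquidAI/LFM2-2.6B").save_pretrained("./merged_counselor_mode_2b")
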
benchmark_results.png
ADDED
|
Git LFS Details
|
benchmark_v1.py
ADDED
|
@@ -0,0 +1,803 @@
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive Japanese Counseling Model Benchmark Script
|
| 3 |
+
Based on KokoroChat paper evaluation methodology
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import List, Dict, Tuple, Optional, Any
|
| 10 |
+
import json
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
import os
|
| 13 |
+
import gc
|
| 14 |
+
import warnings
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import seaborn as sns
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
import MeCab
|
| 21 |
+
from rouge_score import rouge_scorer
|
| 22 |
+
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
|
| 23 |
+
import sacrebleu
|
| 24 |
+
from bert_score import score as bert_score
|
| 25 |
+
import re
|
| 26 |
+
import statistics
|
| 27 |
+
|
| 28 |
+
warnings.filterwarnings('ignore')
|
| 29 |
+
|
| 30 |
+
# Set style for better visualizations
|
| 31 |
+
plt.style.use('seaborn-v0_8-darkgrid')
|
| 32 |
+
sns.set_palette("husl")
|
| 33 |
+
|
| 34 |
+
class JapaneseCounselingBenchmark:
|
| 35 |
+
"""
|
| 36 |
+
Comprehensive benchmark suite for Japanese counseling models
|
| 37 |
+
Following KokoroChat paper evaluation methodology
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self,
|
| 41 |
+
base_model_name: str = "LiquidAI/LFM2-1.2B",
|
| 42 |
+
finetuned_model_path: str = "./merged_counselor_model",
|
| 43 |
+
test_data_path: str = "./processed_data_score70/test.jsonl",
|
| 44 |
+
device: str = None):
|
| 45 |
+
"""
|
| 46 |
+
Initialize Japanese counseling benchmark
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
base_model_name: Name/path of base model
|
| 50 |
+
finetuned_model_path: Path to fine-tuned merged model
|
| 51 |
+
test_data_path: Path to test dataset
|
| 52 |
+
device: Device to run on (cuda/cpu)
|
| 53 |
+
"""
|
| 54 |
+
self.base_model_name = base_model_name
|
| 55 |
+
self.finetuned_model_path = finetuned_model_path
|
| 56 |
+
self.test_data_path = test_data_path
|
| 57 |
+
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
|
| 59 |
+
print("="*80)
|
| 60 |
+
print("🎌 Japanese Counseling Model Benchmark Suite")
|
| 61 |
+
print("="*80)
|
| 62 |
+
print(f"📍 Device: {self.device}")
|
| 63 |
+
if self.device == "cuda":
|
| 64 |
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 65 |
+
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
| 66 |
+
|
| 67 |
+
# Initialize MeCab for Japanese tokenization
|
| 68 |
+
try:
|
| 69 |
+
self.mecab = MeCab.Tagger("-Owakati") # Wakati-gaki mode for word segmentation
|
| 70 |
+
print("✅ MeCab initialized for Japanese tokenization")
|
| 71 |
+
except:
|
| 72 |
+
print("⚠️ MeCab not available. Install with: apt-get install mecab libmecab-dev mecab-ipadic-utf8")
|
| 73 |
+
print(" and: pip install mecab-python3")
|
| 74 |
+
print(" Using fallback character-level tokenization")
|
| 75 |
+
self.mecab = None
|
| 76 |
+
|
| 77 |
+
# Initialize ROUGE scorer (without lang parameter)
|
| 78 |
+
self.rouge_scorer = rouge_scorer.RougeScorer(
|
| 79 |
+
['rouge1', 'rouge2', 'rougeL'],
|
| 80 |
+
use_stemmer=False # Don't use stemming for Japanese
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Smoothing function for BLEU
|
| 84 |
+
self.smoothing = SmoothingFunction().method1
|
| 85 |
+
|
| 86 |
+
# Results storage
|
| 87 |
+
self.results = {}
|
| 88 |
+
self.detailed_results = []
|
| 89 |
+
|
| 90 |
+
def tokenize_japanese(self, text: str) -> List[str]:
|
| 91 |
+
"""
|
| 92 |
+
Tokenize Japanese text using MeCab or fallback method
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
text: Japanese text to tokenize
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
List of tokens
|
| 99 |
+
"""
|
| 100 |
+
if self.mecab:
|
| 101 |
+
try:
|
| 102 |
+
# Use MeCab for proper Japanese tokenization
|
| 103 |
+
tokens = self.mecab.parse(text).strip().split()
|
| 104 |
+
return tokens if tokens else list(text)
|
| 105 |
+
except:
|
| 106 |
+
# Fallback if MeCab fails
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
# Fallback to character-level tokenization
|
| 110 |
+
# Remove punctuation and split
|
| 111 |
+
text = re.sub(r'[。、!?\n\s]', ' ', text)
|
| 112 |
+
# Split by spaces and then into characters
|
| 113 |
+
words = text.split()
|
| 114 |
+
if words:
|
| 115 |
+
# Try to keep some word boundaries
|
| 116 |
+
tokens = []
|
| 117 |
+
for word in words:
|
| 118 |
+
if len(word) <= 4: # Keep short words together
|
| 119 |
+
tokens.append(word)
|
| 120 |
+
else: # Split longer words into characters
|
| 121 |
+
tokens.extend(list(word))
|
| 122 |
+
return tokens
|
| 123 |
+
else:
|
| 124 |
+
# Pure character-level tokenization
|
| 125 |
+
return list(text.replace(' ', ''))
|
| 126 |
+
|
| 127 |
+
def load_test_data(self, max_samples: Optional[int] = None) -> List[Dict]:
|
| 128 |
+
"""
|
| 129 |
+
Load test dataset
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
max_samples: Maximum number of samples to load
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
List of test examples
|
| 136 |
+
"""
|
| 137 |
+
print(f"\n📚 Loading test data from {self.test_data_path}")
|
| 138 |
+
|
| 139 |
+
test_data = []
|
| 140 |
+
|
| 141 |
+
if not os.path.exists(self.test_data_path):
|
| 142 |
+
print(f"❌ Test data not found at {self.test_data_path}")
|
| 143 |
+
print(" Creating synthetic test data for demonstration...")
|
| 144 |
+
return self.create_synthetic_test_data()
|
| 145 |
+
|
| 146 |
+
with open(self.test_data_path, 'r', encoding='utf-8') as f:
|
| 147 |
+
for i, line in enumerate(f):
|
| 148 |
+
if max_samples and i >= max_samples:
|
| 149 |
+
break
|
| 150 |
+
try:
|
| 151 |
+
data = json.loads(line)
|
| 152 |
+
|
| 153 |
+
# Parse the text field to extract input and response
|
| 154 |
+
text = data.get('text', '')
|
| 155 |
+
|
| 156 |
+
# Extract input and reference response
|
| 157 |
+
if "### Input:" in text and "### Response:" in text:
|
| 158 |
+
parts = text.split("### Input:")
|
| 159 |
+
if len(parts) > 1:
|
| 160 |
+
input_part = parts[1].split("### Response:")[0].strip()
|
| 161 |
+
response_part = text.split("### Response:")[1].strip()
|
| 162 |
+
|
| 163 |
+
test_data.append({
|
| 164 |
+
'input': input_part,
|
| 165 |
+
'reference': response_part,
|
| 166 |
+
'score': data.get('score', 0),
|
| 167 |
+
'topic': data.get('topic', 'Unknown')
|
| 168 |
+
})
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(f"⚠️ Error parsing line {i}: {e}")
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
if not test_data:
|
| 174 |
+
print("⚠️ No valid test data found. Creating synthetic data...")
|
| 175 |
+
return self.create_synthetic_test_data()
|
| 176 |
+
|
| 177 |
+
print(f"✅ Loaded {len(test_data)} test examples")
|
| 178 |
+
return test_data
|
| 179 |
+
|
| 180 |
+
def create_synthetic_test_data(self) -> List[Dict]:
|
| 181 |
+
"""Create synthetic test data for demonstration"""
|
| 182 |
+
synthetic_data = [
|
| 183 |
+
{
|
| 184 |
+
'input': '最近ストレスを感じています。',
|
| 185 |
+
'reference': 'ストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが多いですか?',
|
| 186 |
+
'score': 75,
|
| 187 |
+
'topic': 'ストレス'
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
'input': '仕事がうまくいかなくて悩んでいます。',
|
| 191 |
+
'reference': '仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?',
|
| 192 |
+
'score': 78,
|
| 193 |
+
'topic': '仕事'
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
'input': '人間関係で困っています。',
|
| 197 |
+
'reference': '人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?',
|
| 198 |
+
'score': 80,
|
| 199 |
+
'topic': '人間関係'
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
'input': '将来が不安です。',
|
| 203 |
+
'reference': '将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。',
|
| 204 |
+
'score': 72,
|
| 205 |
+
'topic': '不安'
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
'input': '自信が持てません。',
|
| 209 |
+
'reference': '自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。',
|
| 210 |
+
'score': 76,
|
| 211 |
+
'topic': '自信'
|
| 212 |
+
}
|
| 213 |
+
]
|
| 214 |
+
return synthetic_data
|
| 215 |
+
|
| 216 |
+
def load_models(self):
|
| 217 |
+
"""Load base and fine-tuned models"""
|
| 218 |
+
print("\n🤖 Loading models for benchmarking...")
|
| 219 |
+
|
| 220 |
+
# Load tokenizer
|
| 221 |
+
print(" Loading tokenizer...")
|
| 222 |
+
try:
|
| 223 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
|
| 224 |
+
except:
|
| 225 |
+
print(" Using GPT2 tokenizer as fallback...")
|
| 226 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 227 |
+
|
| 228 |
+
if self.tokenizer.pad_token is None:
|
| 229 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 230 |
+
|
| 231 |
+
# Load base model
|
| 232 |
+
print(" Loading base model...")
|
| 233 |
+
try:
|
| 234 |
+
self.base_model = AutoModelForCausalLM.from_pretrained(
|
| 235 |
+
self.base_model_name,
|
| 236 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 237 |
+
device_map="auto" if self.device == "cuda" else None,
|
| 238 |
+
trust_remote_code=True,
|
| 239 |
+
low_cpu_mem_usage=True
|
| 240 |
+
)
|
| 241 |
+
except Exception as e:
|
| 242 |
+
print(f" ⚠️ Could not load base model {self.base_model_name}: {e}")
|
| 243 |
+
print(" Using GPT2 as fallback base model...")
|
| 244 |
+
self.base_model = AutoModelForCausalLM.from_pretrained(
|
| 245 |
+
"gpt2",
|
| 246 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 247 |
+
device_map="auto" if self.device == "cuda" else None
|
| 248 |
+
)
|
| 249 |
+
self.base_model.eval()
|
| 250 |
+
|
| 251 |
+
# Load fine-tuned model
|
| 252 |
+
print(f" Loading fine-tuned model from {self.finetuned_model_path}...")
|
| 253 |
+
|
| 254 |
+
# Check if model exists
|
| 255 |
+
if not os.path.exists(self.finetuned_model_path):
|
| 256 |
+
print(f" ⚠️ Fine-tuned model not found at {self.finetuned_model_path}")
|
| 257 |
+
print(" Using base model for both comparisons (for demonstration)")
|
| 258 |
+
self.finetuned_model = self.base_model
|
| 259 |
+
else:
|
| 260 |
+
try:
|
| 261 |
+
self.finetuned_model = AutoModelForCausalLM.from_pretrained(
|
| 262 |
+
self.finetuned_model_path,
|
| 263 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 264 |
+
device_map="auto" if self.device == "cuda" else None,
|
| 265 |
+
trust_remote_code=True,
|
| 266 |
+
low_cpu_mem_usage=True,
|
| 267 |
+
local_files_only=True
|
| 268 |
+
)
|
| 269 |
+
self.finetuned_model.eval()
|
| 270 |
+
except Exception as e:
|
| 271 |
+
print(f" ⚠️ Error loading fine-tuned model: {e}")
|
| 272 |
+
print(" Using base model for comparison")
|
| 273 |
+
self.finetuned_model = self.base_model
|
| 274 |
+
|
| 275 |
+
print("✅ Models loaded successfully!")
|
| 276 |
+
|
| 277 |
+
def generate_response(self, model, prompt: str, max_length: int = 150) -> str:
|
| 278 |
+
"""
|
| 279 |
+
Generate response from model
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
model: Model to use for generation
|
| 283 |
+
prompt: Input prompt
|
| 284 |
+
max_length: Maximum length of generated response
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
Generated response text
|
| 288 |
+
"""
|
| 289 |
+
# Format prompt for counseling
|
| 290 |
+
formatted_prompt = f"""### Instruction:
|
| 291 |
+
あなたは思いやりのある心理カウンセラーです。
|
| 292 |
+
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
|
| 293 |
+
|
| 294 |
+
### Input:
|
| 295 |
+
{prompt}
|
| 296 |
+
|
| 297 |
+
### Response:
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
+
# Tokenize input
|
| 301 |
+
inputs = self.tokenizer(
|
| 302 |
+
formatted_prompt,
|
| 303 |
+
return_tensors="pt",
|
| 304 |
+
truncation=True,
|
| 305 |
+
max_length=512
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if self.device == "cuda":
|
| 309 |
+
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 310 |
+
|
| 311 |
+
# Generate response
|
| 312 |
+
try:
|
| 313 |
+
with torch.no_grad():
|
| 314 |
+
outputs = model.generate(
|
| 315 |
+
**inputs,
|
| 316 |
+
max_new_tokens=max_length,
|
| 317 |
+
temperature=0.7,
|
| 318 |
+
do_sample=True,
|
| 319 |
+
top_p=0.9,
|
| 320 |
+
repetition_penalty=1.1,
|
| 321 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 322 |
+
eos_token_id=self.tokenizer.eos_token_id
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Decode response
|
| 326 |
+
full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 327 |
+
|
| 328 |
+
# Extract only the generated response
|
| 329 |
+
if "### Response:" in full_response:
|
| 330 |
+
response = full_response.split("### Response:")[-1].strip()
|
| 331 |
+
else:
|
| 332 |
+
response = full_response[len(formatted_prompt):].strip()
|
| 333 |
+
except Exception as e:
|
| 334 |
+
print(f" ⚠️ Generation error: {e}")
|
| 335 |
+
response = "申し訳ございません。応答を生成できませんでした。"
|
| 336 |
+
|
| 337 |
+
return response
|
| 338 |
+
|
| 339 |
+
def calculate_bleu_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
|
| 340 |
+
"""
|
| 341 |
+
Calculate BLEU scores using Japanese tokenization
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
reference: Reference text
|
| 345 |
+
hypothesis: Generated text
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
Dictionary of BLEU scores
|
| 349 |
+
"""
|
| 350 |
+
# Tokenize using MeCab or fallback
|
| 351 |
+
ref_tokens = self.tokenize_japanese(reference)
|
| 352 |
+
hyp_tokens = self.tokenize_japanese(hypothesis)
|
| 353 |
+
|
| 354 |
+
# Ensure we have tokens
|
| 355 |
+
if not ref_tokens:
|
| 356 |
+
ref_tokens = ['empty']
|
| 357 |
+
if not hyp_tokens:
|
| 358 |
+
hyp_tokens = ['empty']
|
| 359 |
+
|
| 360 |
+
# Calculate BLEU scores
|
| 361 |
+
scores = {}
|
| 362 |
+
|
| 363 |
+
try:
|
| 364 |
+
# BLEU-1 through BLEU-4
|
| 365 |
+
for n in range(1, 5):
|
| 366 |
+
weights = tuple([1/n] * n + [0] * (4-n))
|
| 367 |
+
score = sentence_bleu(
|
| 368 |
+
[ref_tokens],
|
| 369 |
+
hyp_tokens,
|
| 370 |
+
weights=weights,
|
| 371 |
+
smoothing_function=self.smoothing
|
| 372 |
+
)
|
| 373 |
+
scores[f'BLEU-{n}'] = score
|
| 374 |
+
except Exception as e:
|
| 375 |
+
print(f" ⚠️ BLEU calculation error: {e}")
|
| 376 |
+
for n in range(1, 5):
|
| 377 |
+
scores[f'BLEU-{n}'] = 0.0
|
| 378 |
+
|
| 379 |
+
return scores
|
| 380 |
+
|
| 381 |
+
def calculate_rouge_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
|
| 382 |
+
"""
|
| 383 |
+
Calculate ROUGE scores for Japanese text
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
reference: Reference text
|
| 387 |
+
hypothesis: Generated text
|
| 388 |
+
|
| 389 |
+
Returns:
|
| 390 |
+
Dictionary of ROUGE scores
|
| 391 |
+
"""
|
| 392 |
+
try:
|
| 393 |
+
# For Japanese, we need to add spaces between tokens for ROUGE scorer
|
| 394 |
+
if self.mecab:
|
| 395 |
+
ref_tokenized = ' '.join(self.tokenize_japanese(reference))
|
| 396 |
+
hyp_tokenized = ' '.join(self.tokenize_japanese(hypothesis))
|
| 397 |
+
else:
|
| 398 |
+
# Character-level with spaces
|
| 399 |
+
ref_tokenized = ' '.join(list(reference))
|
| 400 |
+
hyp_tokenized = ' '.join(list(hypothesis))
|
| 401 |
+
|
| 402 |
+
# Calculate ROUGE scores
|
| 403 |
+
scores = self.rouge_scorer.score(ref_tokenized, hyp_tokenized)
|
| 404 |
+
|
| 405 |
+
return {
|
| 406 |
+
'ROUGE-1': scores['rouge1'].fmeasure,
|
| 407 |
+
'ROUGE-2': scores['rouge2'].fmeasure,
|
| 408 |
+
'ROUGE-L': scores['rougeL'].fmeasure
|
| 409 |
+
}
|
| 410 |
+
except Exception as e:
|
| 411 |
+
print(f" ⚠️ ROUGE calculation error: {e}")
|
| 412 |
+
return {
|
| 413 |
+
'ROUGE-1': 0.0,
|
| 414 |
+
'ROUGE-2': 0.0,
|
| 415 |
+
'ROUGE-L': 0.0
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
def calculate_bert_score(self, references: List[str], hypotheses: List[str]) -> Dict[str, float]:
|
| 419 |
+
"""
|
| 420 |
+
Calculate BERTScore for semantic similarity
|
| 421 |
+
|
| 422 |
+
Args:
|
| 423 |
+
references: List of reference texts
|
| 424 |
+
hypotheses: List of generated texts
|
| 425 |
+
|
| 426 |
+
Returns:
|
| 427 |
+
Dictionary with BERTScore metrics
|
| 428 |
+
"""
|
| 429 |
+
try:
|
| 430 |
+
# Calculate BERTScore
|
| 431 |
+
P, R, F1 = bert_score(
|
| 432 |
+
hypotheses,
|
| 433 |
+
references,
|
| 434 |
+
lang='ja',
|
| 435 |
+
verbose=False,
|
| 436 |
+
device=self.device
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
return {
|
| 440 |
+
'BERTScore_P': float(P.mean()),
|
| 441 |
+
'BERTScore_R': float(R.mean()),
|
| 442 |
+
'BERTScore_F1': float(F1.mean())
|
| 443 |
+
}
|
| 444 |
+
except Exception as e:
|
| 445 |
+
print(f" ⚠️ BERTScore calculation failed: {e}")
|
| 446 |
+
print(" Install with: pip install bert-score")
|
| 447 |
+
return {
|
| 448 |
+
'BERTScore_P': 0.0,
|
| 449 |
+
'BERTScore_R': 0.0,
|
| 450 |
+
'BERTScore_F1': 0.0
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
def evaluate_counseling_quality(self, response: str) -> Dict[str, float]:
|
| 454 |
+
"""
|
| 455 |
+
Evaluate counseling-specific qualities
|
| 456 |
+
Based on KokoroChat paper evaluation criteria
|
| 457 |
+
|
| 458 |
+
Args:
|
| 459 |
+
response: Generated counseling response
|
| 460 |
+
|
| 461 |
+
Returns:
|
| 462 |
+
Dictionary of counseling quality scores
|
| 463 |
+
"""
|
| 464 |
+
scores = {}
|
| 465 |
+
|
| 466 |
+
# 1. Empathy Score (共感度)
|
| 467 |
+
empathy_keywords = [
|
| 468 |
+
'わかります', '理解', '共感', 'お気持ち', 'つらい',
|
| 469 |
+
'大変', 'お察し', 'そうですね', 'なるほど', '感じ'
|
| 470 |
+
]
|
| 471 |
+
empathy_score = sum(1 for keyword in empathy_keywords if keyword in response)
|
| 472 |
+
scores['empathy'] = min(empathy_score / 5.0, 1.0) # Normalize to 0-1
|
| 473 |
+
|
| 474 |
+
# 2. Support Score (支援度)
|
| 475 |
+
support_keywords = [
|
| 476 |
+
'サポート', '支援', '助け', '一緒に', '協力',
|
| 477 |
+
'応援', 'お手伝い', '力になり', '相談', '話を聞'
|
| 478 |
+
]
|
| 479 |
+
support_score = sum(1 for keyword in support_keywords if keyword in response)
|
| 480 |
+
scores['support'] = min(support_score / 5.0, 1.0)
|
| 481 |
+
|
| 482 |
+
# 3. Active Listening (傾聴)
|
| 483 |
+
listening_indicators = ['?', 'でしょうか', 'ですか', 'いかがですか', 'どのような']
|
| 484 |
+
scores['active_listening'] = 1.0 if any(ind in response for ind in listening_indicators) else 0.3
|
| 485 |
+
|
| 486 |
+
# 4. Positivity (前向きさ)
|
| 487 |
+
positive_keywords = ['大丈夫', '良い', '素晴らしい', '頑張', '希望', '改善', '解決']
|
| 488 |
+
positive_score = sum(1 for keyword in positive_keywords if keyword in response)
|
| 489 |
+
scores['positivity'] = min(positive_score / 3.0, 1.0)
|
| 490 |
+
|
| 491 |
+
# 5. Response Appropriateness (応答の適切さ)
|
| 492 |
+
response_length = len(response)
|
| 493 |
+
if 30 <= response_length <= 200:
|
| 494 |
+
scores['appropriateness'] = 1.0
|
| 495 |
+
elif 20 <= response_length < 30 or 200 < response_length <= 300:
|
| 496 |
+
scores['appropriateness'] = 0.7
|
| 497 |
+
else:
|
| 498 |
+
scores['appropriateness'] = 0.4
|
| 499 |
+
|
| 500 |
+
return scores
|
| 501 |
+
|
| 502 |
+
def run_comprehensive_benchmark(self, num_samples: Optional[int] = None):
|
| 503 |
+
"""
|
| 504 |
+
Run comprehensive benchmark evaluation
|
| 505 |
+
|
| 506 |
+
Args:
|
| 507 |
+
num_samples: Number of samples to evaluate (None for all)
|
| 508 |
+
"""
|
| 509 |
+
print("\n" + "="*80)
|
| 510 |
+
print("🚀 Running Comprehensive Benchmark")
|
| 511 |
+
print("="*80)
|
| 512 |
+
|
| 513 |
+
# Load test data
|
| 514 |
+
test_data = self.load_test_data(max_samples=num_samples)
|
| 515 |
+
|
| 516 |
+
if not test_data:
|
| 517 |
+
raise ValueError("No test data available!")
|
| 518 |
+
|
| 519 |
+
# Initialize metric collectors
|
| 520 |
+
base_metrics = defaultdict(list)
|
| 521 |
+
finetuned_metrics = defaultdict(list)
|
| 522 |
+
|
| 523 |
+
# Collect all responses for BERTScore
|
| 524 |
+
all_references = []
|
| 525 |
+
all_base_responses = []
|
| 526 |
+
all_finetuned_responses = []
|
| 527 |
+
|
| 528 |
+
print(f"\n📊 Evaluating {len(test_data)} test examples...")
|
| 529 |
+
print("-"*80)
|
| 530 |
+
|
| 531 |
+
# Process each test example
|
| 532 |
+
for i, example in enumerate(tqdm(test_data, desc="Evaluating")):
|
| 533 |
+
input_text = example['input']
|
| 534 |
+
reference = example['reference']
|
| 535 |
+
|
| 536 |
+
# Generate responses
|
| 537 |
+
base_response = self.generate_response(self.base_model, input_text)
|
| 538 |
+
finetuned_response = self.generate_response(self.finetuned_model, input_text)
|
| 539 |
+
|
| 540 |
+
# Collect for BERTScore
|
| 541 |
+
all_references.append(reference)
|
| 542 |
+
all_base_responses.append(base_response)
|
| 543 |
+
all_finetuned_responses.append(finetuned_response)
|
| 544 |
+
|
| 545 |
+
# Calculate BLEU scores
|
| 546 |
+
base_bleu = self.calculate_bleu_scores(reference, base_response)
|
| 547 |
+
finetuned_bleu = self.calculate_bleu_scores(reference, finetuned_response)
|
| 548 |
+
|
| 549 |
+
for key, value in base_bleu.items():
|
| 550 |
+
base_metrics[key].append(value)
|
| 551 |
+
for key, value in finetuned_bleu.items():
|
| 552 |
+
finetuned_metrics[key].append(value)
|
| 553 |
+
|
| 554 |
+
# Calculate ROUGE scores
|
| 555 |
+
base_rouge = self.calculate_rouge_scores(reference, base_response)
|
| 556 |
+
finetuned_rouge = self.calculate_rouge_scores(reference, finetuned_response)
|
| 557 |
+
|
| 558 |
+
for key, value in base_rouge.items():
|
| 559 |
+
base_metrics[key].append(value)
|
| 560 |
+
for key, value in finetuned_rouge.items():
|
| 561 |
+
finetuned_metrics[key].append(value)
|
| 562 |
+
|
| 563 |
+
# Evaluate counseling quality
|
| 564 |
+
base_quality = self.evaluate_counseling_quality(base_response)
|
| 565 |
+
finetuned_quality = self.evaluate_counseling_quality(finetuned_response)
|
| 566 |
+
|
| 567 |
+
for key, value in base_quality.items():
|
| 568 |
+
base_metrics[f'quality_{key}'].append(value)
|
| 569 |
+
for key, value in finetuned_quality.items():
|
| 570 |
+
finetuned_metrics[f'quality_{key}'].append(value)
|
| 571 |
+
|
| 572 |
+
# Store detailed results
|
| 573 |
+
self.detailed_results.append({
|
| 574 |
+
'input': input_text,
|
| 575 |
+
'reference': reference,
|
| 576 |
+
'base_response': base_response,
|
| 577 |
+
'finetuned_response': finetuned_response,
|
| 578 |
+
'base_metrics': {**base_bleu, **base_rouge, **base_quality},
|
| 579 |
+
'finetuned_metrics': {**finetuned_bleu, **finetuned_rouge, **finetuned_quality}
|
| 580 |
+
})
|
| 581 |
+
|
| 582 |
+
# Show sample outputs
|
| 583 |
+
if i < 3:
|
| 584 |
+
print(f"\n📝 Example {i+1}:")
|
| 585 |
+
print(f"Input: {input_text[:100]}...")
|
| 586 |
+
print(f"Base BLEU-4: {base_bleu['BLEU-4']:.3f}, Fine-tuned BLEU-4: {finetuned_bleu['BLEU-4']:.3f}")
|
| 587 |
+
|
| 588 |
+
# Calculate BERTScore for all examples
|
| 589 |
+
if len(all_references) > 0:
|
| 590 |
+
print("\n🧮 Calculating BERTScore...")
|
| 591 |
+
base_bert = self.calculate_bert_score(all_references, all_base_responses)
|
| 592 |
+
finetuned_bert = self.calculate_bert_score(all_references, all_finetuned_responses)
|
| 593 |
+
|
| 594 |
+
for key, value in base_bert.items():
|
| 595 |
+
base_metrics[key] = [value] * len(test_data)
|
| 596 |
+
for key, value in finetuned_bert.items():
|
| 597 |
+
finetuned_metrics[key] = [value] * len(test_data)
|
| 598 |
+
|
| 599 |
+
# Calculate aggregate statistics
|
| 600 |
+
self.results = self.calculate_aggregate_statistics(base_metrics, finetuned_metrics)
|
| 601 |
+
|
| 602 |
+
# Print results
|
| 603 |
+
self.print_results()
|
| 604 |
+
|
| 605 |
+
return self.results
|
| 606 |
+
|
| 607 |
+
def calculate_aggregate_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict:
|
| 608 |
+
"""
|
| 609 |
+
Calculate aggregate statistics from collected metrics
|
| 610 |
+
|
| 611 |
+
Args:
|
| 612 |
+
base_metrics: Base model metrics
|
| 613 |
+
finetuned_metrics: Fine-tuned model metrics
|
| 614 |
+
|
| 615 |
+
Returns:
|
| 616 |
+
Dictionary of aggregate results
|
| 617 |
+
"""
|
| 618 |
+
results = {
|
| 619 |
+
'metrics': {},
|
| 620 |
+
'improvements': {},
|
| 621 |
+
'summary': {}
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
# Calculate statistics for each metric
|
| 625 |
+
all_metric_names = set(base_metrics.keys()) | set(finetuned_metrics.keys())
|
| 626 |
+
|
| 627 |
+
for metric in all_metric_names:
|
| 628 |
+
base_values = base_metrics.get(metric, [0])
|
| 629 |
+
finetuned_values = finetuned_metrics.get(metric, [0])
|
| 630 |
+
|
| 631 |
+
results['metrics'][metric] = {
|
| 632 |
+
'base': {
|
| 633 |
+
'mean': float(np.mean(base_values)),
|
| 634 |
+
'std': float(np.std(base_values)),
|
| 635 |
+
'min': float(np.min(base_values)),
|
| 636 |
+
'max': float(np.max(base_values))
|
| 637 |
+
},
|
| 638 |
+
'finetuned': {
|
| 639 |
+
'mean': float(np.mean(finetuned_values)),
|
| 640 |
+
'std': float(np.std(finetuned_values)),
|
| 641 |
+
'min': float(np.min(finetuned_values)),
|
| 642 |
+
'max': float(np.max(finetuned_values))
|
| 643 |
+
}
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
# Calculate improvement
|
| 647 |
+
base_mean = np.mean(base_values)
|
| 648 |
+
finetuned_mean = np.mean(finetuned_values)
|
| 649 |
+
if base_mean > 0:
|
| 650 |
+
improvement = ((finetuned_mean - base_mean) / base_mean) * 100
|
| 651 |
+
else:
|
| 652 |
+
improvement = 0
|
| 653 |
+
|
| 654 |
+
results['improvements'][metric] = improvement
|
| 655 |
+
|
| 656 |
+
# Calculate summary statistics
|
| 657 |
+
bleu_metrics = [m for m in results['metrics'] if 'BLEU' in m]
|
| 658 |
+
rouge_metrics = [m for m in results['metrics'] if 'ROUGE' in m]
|
| 659 |
+
quality_metrics = [m for m in results['metrics'] if 'quality' in m]
|
| 660 |
+
|
| 661 |
+
# Average improvements
|
| 662 |
+
results['summary'] = {
|
| 663 |
+
'bleu_avg_improvement': np.mean([results['improvements'][m] for m in bleu_metrics]) if bleu_metrics else 0,
|
| 664 |
+
'rouge_avg_improvement': np.mean([results['improvements'][m] for m in rouge_metrics]) if rouge_metrics else 0,
|
| 665 |
+
'quality_avg_improvement': np.mean([results['improvements'][m] for m in quality_metrics]) if quality_metrics else 0,
|
| 666 |
+
'overall_improvement': np.mean(list(results['improvements'].values())) if results['improvements'] else 0
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
return results
|
| 670 |
+
|
| 671 |
+
def print_results(self):
|
| 672 |
+
"""Print formatted benchmark results"""
|
| 673 |
+
print("\n" + "="*80)
|
| 674 |
+
print("📊 BENCHMARK RESULTS")
|
| 675 |
+
print("="*80)
|
| 676 |
+
|
| 677 |
+
# Group metrics by category
|
| 678 |
+
bleu_metrics = sorted([m for m in self.results['metrics'] if 'BLEU' in m])
|
| 679 |
+
rouge_metrics = sorted([m for m in self.results['metrics'] if 'ROUGE' in m])
|
| 680 |
+
bert_metrics = sorted([m for m in self.results['metrics'] if 'BERT' in m])
|
| 681 |
+
quality_metrics = sorted([m for m in self.results['metrics'] if 'quality' in m])
|
| 682 |
+
|
| 683 |
+
# Print BLEU scores
|
| 684 |
+
if bleu_metrics:
|
| 685 |
+
print("\n📘 BLEU Scores:")
|
| 686 |
+
print("-"*60)
|
| 687 |
+
print(f"{'Metric':<15} {'Base Model':<20} {'Fine-tuned':<20} {'Improvement':<15}")
|
| 688 |
+
print("-"*60)
|
| 689 |
+
for metric in bleu_metrics:
|
| 690 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 691 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 692 |
+
improvement = self.results['improvements'][metric]
|
| 693 |
+
print(f"{metric:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
|
| 694 |
+
f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
|
| 695 |
+
f"{improvement:+.1f}%")
|
| 696 |
+
|
| 697 |
+
# Print ROUGE scores
|
| 698 |
+
if rouge_metrics:
|
| 699 |
+
print("\n📕 ROUGE Scores:")
|
| 700 |
+
print("-"*60)
|
| 701 |
+
for metric in rouge_metrics:
|
| 702 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 703 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 704 |
+
improvement = self.results['improvements'][metric]
|
| 705 |
+
print(f"{metric:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
|
| 706 |
+
f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
|
| 707 |
+
f"{improvement:+.1f}%")
|
| 708 |
+
|
| 709 |
+
# Print BERTScore
|
| 710 |
+
if bert_metrics:
|
| 711 |
+
print("\n📗 BERTScore:")
|
| 712 |
+
print("-"*60)
|
| 713 |
+
for metric in bert_metrics:
|
| 714 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 715 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 716 |
+
improvement = self.results['improvements'][metric]
|
| 717 |
+
print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
|
| 718 |
+
|
| 719 |
+
# Print Counseling Quality scores
|
| 720 |
+
if quality_metrics:
|
| 721 |
+
print("\n💬 Counseling Quality Metrics:")
|
| 722 |
+
print("-"*60)
|
| 723 |
+
for metric in quality_metrics:
|
| 724 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 725 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 726 |
+
improvement = self.results['improvements'][metric]
|
| 727 |
+
metric_name = metric.replace('quality_', '').capitalize()
|
| 728 |
+
print(f"{metric_name:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
|
| 729 |
+
f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
|
| 730 |
+
f"{improvement:+.1f}%")
|
| 731 |
+
|
| 732 |
+
# Print summary
|
| 733 |
+
print("\n" + "="*80)
|
| 734 |
+
print("📈 SUMMARY")
|
| 735 |
+
print("="*80)
|
| 736 |
+
print(f"Average BLEU Improvement: {self.results['summary']['bleu_avg_improvement']:+.1f}%")
|
| 737 |
+
print(f"Average ROUGE Improvement: {self.results['summary']['rouge_avg_improvement']:+.1f}%")
|
| 738 |
+
print(f"Average Quality Improvement: {self.results['summary']['quality_avg_improvement']:+.1f}%")
|
| 739 |
+
print(f"Overall Improvement: {self.results['summary']['overall_improvement']:+.1f}%")
|
| 740 |
+
print("="*80)
|
| 741 |
+
|
| 742 |
+
def save_results(self, output_dir: str = "./benchmark_results"):
|
| 743 |
+
"""Save all benchmark results"""
|
| 744 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 745 |
+
|
| 746 |
+
# Save detailed results
|
| 747 |
+
with open(os.path.join(output_dir, "detailed_results.json"), 'w', encoding='utf-8') as f:
|
| 748 |
+
json.dump(self.detailed_results, f, ensure_ascii=False, indent=2, default=str)
|
| 749 |
+
|
| 750 |
+
# Save aggregate results
|
| 751 |
+
with open(os.path.join(output_dir, "aggregate_results.json"), 'w', encoding='utf-8') as f:
|
| 752 |
+
json.dump(self.results, f, ensure_ascii=False, indent=2, default=str)
|
| 753 |
+
|
| 754 |
+
print(f"✅ Results saved to {output_dir}/")
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
def main():
|
| 758 |
+
"""Main execution function"""
|
| 759 |
+
import argparse
|
| 760 |
+
|
| 761 |
+
parser = argparse.ArgumentParser(description='Japanese Counseling Model Benchmark')
|
| 762 |
+
parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B',
|
| 763 |
+
help='Base model name or path')
|
| 764 |
+
parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model',
|
| 765 |
+
help='Path to fine-tuned merged model')
|
| 766 |
+
parser.add_argument('--test_data', type=str, default='./processed_data_score70/test.jsonl',
|
| 767 |
+
help='Path to test data')
|
| 768 |
+
parser.add_argument('--num_samples', type=int, default=None,
|
| 769 |
+
help='Number of samples to evaluate (None for all)')
|
| 770 |
+
parser.add_argument('--output_dir', type=str, default='./benchmark_results',
|
| 771 |
+
help='Directory to save results')
|
| 772 |
+
|
| 773 |
+
args = parser.parse_args()
|
| 774 |
+
|
| 775 |
+
try:
|
| 776 |
+
# Initialize benchmark
|
| 777 |
+
print("🎌 Initializing Japanese Counseling Benchmark Suite")
|
| 778 |
+
benchmark = JapaneseCounselingBenchmark(
|
| 779 |
+
base_model_name=args.base_model,
|
| 780 |
+
finetuned_model_path=args.finetuned_model,
|
| 781 |
+
test_data_path=args.test_data
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
# Load models
|
| 785 |
+
benchmark.load_models()
|
| 786 |
+
|
| 787 |
+
# Run benchmark
|
| 788 |
+
results = benchmark.run_comprehensive_benchmark(num_samples=args.num_samples)
|
| 789 |
+
|
| 790 |
+
# Save results
|
| 791 |
+
benchmark.save_results(args.output_dir)
|
| 792 |
+
|
| 793 |
+
print("\n✅ Benchmark completed successfully!")
|
| 794 |
+
print(f"📁 Results saved to {args.output_dir}/")
|
| 795 |
+
|
| 796 |
+
except Exception as e:
|
| 797 |
+
print(f"\n❌ Error during benchmarking: {e}")
|
| 798 |
+
import traceback
|
| 799 |
+
traceback.print_exc()
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
if __name__ == "__main__":
|
| 803 |
+
main()
|
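The scoring pipeline in benchmark_v1.py above hinges on tokenizing Japanese text before computing n-gram metrics. A compact, self-contained sketch of that MeCab + BLEU/ROUGE combination follows; the example sentences are illustrative, and the calls mirror the ones used in the class above.

# Minimal sketch of the Japanese BLEU/ROUGE scoring approach used above.
import MeCab
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

tagger = MeCab.Tagger("-Owakati")                 # wakati-gaki: space-separated surface forms
reference  = "ストレスを感じているのですね。どのような状況で感じますか?"
hypothesis = "ストレスを感じていらっしゃるのですね。"

ref_tokens = tagger.parse(reference).strip().split()
hyp_tokens = tagger.parse(hypothesis).strip().split()

# BLEU-4 over MeCab tokens; smoothing keeps short responses from collapsing to zero.
bleu4 = sentence_bleu([ref_tokens], hyp_tokens,
                      smoothing_function=SmoothingFunction().method1)

# rouge-score expects space-separated "words", so rejoin the MeCab tokens with spaces.
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
rouge = scorer.score(" ".join(ref_tokens), " ".join(hyp_tokens))

print(f"BLEU-4: {bleu4:.3f}  ROUGE-L: {rouge['rougeL'].fmeasure:.3f}")
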
benchmarking_v2.py
ADDED
|
@@ -0,0 +1,782 @@
|
| 1 |
+
"""
|
| 2 |
+
Fixed Optimized Japanese Counseling Model Benchmark with proper DataParallel handling
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.nn.parallel import DataParallel
|
| 8 |
+
from torch.utils.data import Dataset, DataLoader
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import List, Dict, Tuple, Optional, Any
|
| 12 |
+
import json
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import os
|
| 15 |
+
import gc
|
| 16 |
+
import warnings
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
import MeCab
|
| 21 |
+
from rouge_score import rouge_scorer
|
| 22 |
+
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
| 23 |
+
import re
|
| 24 |
+
import wandb
|
| 25 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 26 |
+
import time
|
| 27 |
+
|
| 28 |
+
# Suppress warnings
|
| 29 |
+
warnings.filterwarnings('ignore')
|
| 30 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 31 |
+
|
| 32 |
+
# Suppress Pydantic warnings
|
| 33 |
+
import logging
|
| 34 |
+
logging.getLogger('pydantic').setLevel(logging.ERROR)
|
| 35 |
+
|
| 36 |
+
class TestDataset(Dataset):
|
| 37 |
+
"""Custom dataset for efficient batch processing"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, data: List[Dict]):
|
| 40 |
+
self.data = data
|
| 41 |
+
|
| 42 |
+
def __len__(self):
|
| 43 |
+
return len(self.data)
|
| 44 |
+
|
| 45 |
+
def __getitem__(self, idx):
|
| 46 |
+
return self.data[idx]
|
| 47 |
+
|
| 48 |
+
def custom_collate_fn(batch):
|
| 49 |
+
"""Custom collate function to handle dictionary data properly"""
|
| 50 |
+
return batch
|
| 51 |
+
|
| 52 |
+
class OptimizedJapaneseBenchmark:
|
| 53 |
+
"""
|
| 54 |
+
Highly optimized benchmark suite with multi-GPU support and WandB logging
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self,
|
| 58 |
+
base_model_name: str = "LiquidAI/LFM2-1.2B",
|
| 59 |
+
finetuned_model_path: str = "./merged_counselor_model",
|
| 60 |
+
test_data_path: str = "./processed_data_score80/test.jsonl",
|
| 61 |
+
batch_size: int = 16, # Reduced for stability
|
| 62 |
+
num_workers: int = 0,
|
| 63 |
+
use_wandb: bool = True):
|
| 64 |
+
"""
|
| 65 |
+
Initialize optimized benchmark with multi-GPU support
|
| 66 |
+
"""
|
| 67 |
+
self.base_model_name = base_model_name
|
| 68 |
+
self.finetuned_model_path = finetuned_model_path
|
| 69 |
+
self.test_data_path = test_data_path
|
| 70 |
+
self.batch_size = batch_size
|
| 71 |
+
self.num_workers = num_workers
|
| 72 |
+
|
| 73 |
+
# Setup devices
|
| 74 |
+
self.setup_devices()
|
| 75 |
+
|
| 76 |
+
# Initialize WandB
|
| 77 |
+
if use_wandb:
|
| 78 |
+
self.init_wandb()
|
| 79 |
+
else:
|
| 80 |
+
self.wandb_enabled = False
|
| 81 |
+
|
| 82 |
+
# Initialize tokenizers and scorers
|
| 83 |
+
self.setup_tokenizers_and_scorers()
|
| 84 |
+
|
| 85 |
+
# Results storage
|
| 86 |
+
self.results = {}
|
| 87 |
+
self.detailed_results = []
|
| 88 |
+
|
| 89 |
+
def setup_devices(self):
|
| 90 |
+
"""Setup multi-GPU configuration"""
|
| 91 |
+
if torch.cuda.is_available():
|
| 92 |
+
self.num_gpus = torch.cuda.device_count()
|
| 93 |
+
print(f"🚀 Found {self.num_gpus} GPUs")
|
| 94 |
+
|
| 95 |
+
self.device_ids = list(range(self.num_gpus))
|
| 96 |
+
self.device = torch.device("cuda:0")
|
| 97 |
+
|
| 98 |
+
for i in range(self.num_gpus):
|
| 99 |
+
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
| 100 |
+
print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
|
| 101 |
+
else:
|
| 102 |
+
self.num_gpus = 0
|
| 103 |
+
self.device = torch.device("cpu")
|
| 104 |
+
print("⚠️ No GPU found, using CPU")
|
| 105 |
+
|
| 106 |
+
def init_wandb(self):
|
| 107 |
+
"""Initialize WandB for experiment tracking"""
|
| 108 |
+
try:
|
| 109 |
+
run_name = f"benchmark-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
| 110 |
+
|
| 111 |
+
wandb.init(
|
| 112 |
+
project="japanese-counseling-benchmark",
|
| 113 |
+
name=run_name,
|
| 114 |
+
config={
|
| 115 |
+
"base_model": self.base_model_name,
|
| 116 |
+
"finetuned_model": self.finetuned_model_path,
|
| 117 |
+
"batch_size": self.batch_size,
|
| 118 |
+
"num_gpus": self.num_gpus,
|
| 119 |
+
"timestamp": datetime.now().isoformat()
|
| 120 |
+
},
|
| 121 |
+
tags=["benchmark", "japanese", "counseling", "multi-gpu"]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
self.wandb_enabled = True
|
| 125 |
+
print(f"✅ WandB initialized: {wandb.run.name}")
|
| 126 |
+
print(f"📊 View at: {wandb.run.get_url()}")
|
| 127 |
+
except Exception as e:
|
| 128 |
+
print(f"⚠️ WandB initialization failed: {e}")
|
| 129 |
+
self.wandb_enabled = False
|
| 130 |
+
|
| 131 |
+
def setup_tokenizers_and_scorers(self):
|
| 132 |
+
"""Setup tokenizers and scoring functions"""
|
| 133 |
+
# Initialize MeCab for Japanese tokenization
|
| 134 |
+
try:
|
| 135 |
+
self.mecab = MeCab.Tagger("-Owakati")
|
| 136 |
+
print("✅ MeCab initialized")
|
| 137 |
+
except:
|
| 138 |
+
print("⚠️ MeCab not available, using character tokenization")
|
| 139 |
+
self.mecab = None
|
| 140 |
+
|
| 141 |
+
# Initialize ROUGE scorer
|
| 142 |
+
self.rouge_scorer = rouge_scorer.RougeScorer(
|
| 143 |
+
['rouge1', 'rouge2', 'rougeL'],
|
| 144 |
+
use_stemmer=False
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# BLEU smoothing
|
| 148 |
+
self.smoothing = SmoothingFunction().method1
|
| 149 |
+
|
| 150 |
+
def load_test_data_fast(self, max_samples: Optional[int] = None) -> List[Dict]:
|
| 151 |
+
"""Fast loading of test data"""
|
| 152 |
+
print(f"\n📚 Loading test data from {self.test_data_path}")
|
| 153 |
+
|
| 154 |
+
test_data = []
|
| 155 |
+
|
| 156 |
+
if not os.path.exists(self.test_data_path):
|
| 157 |
+
print("⚠️ Test data not found, using synthetic data")
|
| 158 |
+
return self.create_synthetic_test_data()
|
| 159 |
+
|
| 160 |
+
try:
|
| 161 |
+
with open(self.test_data_path, 'r', encoding='utf-8') as f:
|
| 162 |
+
lines = f.readlines()
|
| 163 |
+
|
| 164 |
+
if max_samples:
|
| 165 |
+
lines = lines[:max_samples]
|
| 166 |
+
|
| 167 |
+
for line in tqdm(lines, desc="Loading data"):
|
| 168 |
+
try:
|
| 169 |
+
data = json.loads(line)
|
| 170 |
+
text = data.get('text', '')
|
| 171 |
+
|
| 172 |
+
if "### Input:" in text and "### Response:" in text:
|
| 173 |
+
input_part = text.split("### Input:")[1].split("### Response:")[0].strip()
|
| 174 |
+
response_part = text.split("### Response:")[1].strip()
|
| 175 |
+
|
| 176 |
+
test_data.append({
|
| 177 |
+
'input': input_part,
|
| 178 |
+
'reference': response_part,
|
| 179 |
+
'score': data.get('score', 0),
|
| 180 |
+
'topic': data.get('topic', 'Unknown')
|
| 181 |
+
})
|
| 182 |
+
except:
|
| 183 |
+
continue
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f"Error loading data: {e}")
|
| 187 |
+
return self.create_synthetic_test_data()
|
| 188 |
+
|
| 189 |
+
if not test_data:
|
| 190 |
+
print("⚠️ No valid data found, using synthetic data")
|
| 191 |
+
return self.create_synthetic_test_data()
|
| 192 |
+
|
| 193 |
+
print(f"✅ Loaded {len(test_data)} test examples")
|
| 194 |
+
|
| 195 |
+
if self.wandb_enabled:
|
| 196 |
+
wandb.log({"test_data_size": len(test_data)})
|
| 197 |
+
|
| 198 |
+
return test_data
|
| 199 |
+
|
| 200 |
+
def create_synthetic_test_data(self) -> List[Dict]:
|
| 201 |
+
"""Create synthetic test data"""
|
| 202 |
+
return [
|
| 203 |
+
{
|
| 204 |
+
'input': 'ストレスを感じています。',
|
| 205 |
+
'reference': 'お気持ちわかります。どのような状況でストレスを感じていますか?',
|
| 206 |
+
'score': 75,
|
| 207 |
+
'topic': 'stress'
|
| 208 |
+
}
|
| 209 |
+
for i in range(10)
|
| 210 |
+
]
|
| 211 |
+
|
| 212 |
+
def load_models_optimized(self):
|
| 213 |
+
"""Load models with optimization for multi-GPU"""
|
| 214 |
+
print("\n🤖 Loading models with optimization...")
|
| 215 |
+
|
| 216 |
+
# Load tokenizer
|
| 217 |
+
print(" Loading tokenizer...")
|
| 218 |
+
try:
|
| 219 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 220 |
+
self.base_model_name,
|
| 221 |
+
use_fast=True
|
| 222 |
+
)
|
| 223 |
+
except:
|
| 224 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
|
| 225 |
+
|
| 226 |
+
if self.tokenizer.pad_token is None:
|
| 227 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 228 |
+
|
| 229 |
+
# Load base model
|
| 230 |
+
print(" Loading base model...")
|
| 231 |
+
try:
|
| 232 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 233 |
+
self.base_model_name,
|
| 234 |
+
torch_dtype=torch.float16,
|
| 235 |
+
trust_remote_code=True,
|
| 236 |
+
low_cpu_mem_usage=True
|
| 237 |
+
)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f" Error loading base model: {e}")
|
| 240 |
+
print(" Using GPT2 as fallback...")
|
| 241 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 242 |
+
"gpt2",
|
| 243 |
+
torch_dtype=torch.float16
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Load fine-tuned model
|
| 247 |
+
print(" Loading fine-tuned model...")
|
| 248 |
+
if os.path.exists(self.finetuned_model_path):
|
| 249 |
+
try:
|
| 250 |
+
finetuned_model = AutoModelForCausalLM.from_pretrained(
|
| 251 |
+
self.finetuned_model_path,
|
| 252 |
+
torch_dtype=torch.float16,
|
| 253 |
+
trust_remote_code=True,
|
| 254 |
+
low_cpu_mem_usage=True,
|
| 255 |
+
local_files_only=True
|
| 256 |
+
)
|
| 257 |
+
except Exception as e:
|
| 258 |
+
print(f" Error loading fine-tuned model: {e}")
|
| 259 |
+
finetuned_model = base_model
|
| 260 |
+
else:
|
| 261 |
+
print(" Fine-tuned model not found, using base model")
|
| 262 |
+
finetuned_model = base_model
|
| 263 |
+
|
| 264 |
+
# Move models to GPU
|
| 265 |
+
base_model = base_model.to(self.device)
|
| 266 |
+
finetuned_model = finetuned_model.to(self.device)
|
| 267 |
+
|
| 268 |
+
# Setup for multi-GPU if available
|
| 269 |
+
if self.num_gpus > 1:
|
| 270 |
+
print(f" Setting up DataParallel for {self.num_gpus} GPUs...")
|
| 271 |
+
self.base_model = DataParallel(base_model, device_ids=self.device_ids)
|
| 272 |
+
self.finetuned_model = DataParallel(finetuned_model, device_ids=self.device_ids)
|
| 273 |
+
else:
|
| 274 |
+
self.base_model = base_model
|
| 275 |
+
self.finetuned_model = finetuned_model
|
| 276 |
+
|
| 277 |
+
self.base_model.eval()
|
| 278 |
+
self.finetuned_model.eval()
|
| 279 |
+
|
| 280 |
+
print("✅ Models loaded and optimized!")
|
| 281 |
+
|
| 282 |
+
if self.wandb_enabled:
|
| 283 |
+
wandb.log({
|
| 284 |
+
"model_loaded": True,
|
| 285 |
+
"num_gpus_used": self.num_gpus
|
| 286 |
+
})
|
| 287 |
+
|
| 288 |
+
def generate_batch_responses(self, model, prompts: List[str], max_length: int = 150) -> List[str]:
|
| 289 |
+
"""Generate responses in batch for efficiency"""
|
| 290 |
+
if len(prompts) == 0:
|
| 291 |
+
return []
|
| 292 |
+
|
| 293 |
+
formatted_prompts = [
|
| 294 |
+
f"""### Instruction:
|
| 295 |
+
あなたは思いやりのある心理カウンセラーです。
|
| 296 |
+
|
| 297 |
+
### Input:
|
| 298 |
+
{prompt}
|
| 299 |
+
|
| 300 |
+
### Response:
|
| 301 |
+
""" for prompt in prompts
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
try:
|
| 305 |
+
# Tokenize all prompts at once
|
| 306 |
+
inputs = self.tokenizer(
|
| 307 |
+
formatted_prompts,
|
| 308 |
+
return_tensors="pt",
|
| 309 |
+
truncation=True,
|
| 310 |
+
max_length=512,
|
| 311 |
+
padding=True,
|
| 312 |
+
padding_side='left'
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 316 |
+
|
| 317 |
+
# Get the actual model from DataParallel if needed
|
| 318 |
+
actual_model = model.module if isinstance(model, DataParallel) else model
|
| 319 |
+
|
| 320 |
+
# Generate in batch
|
| 321 |
+
with torch.no_grad():
|
| 322 |
+
with torch.cuda.amp.autocast():
|
| 323 |
+
outputs = actual_model.generate(
|
| 324 |
+
**inputs,
|
| 325 |
+
max_new_tokens=max_length,
|
| 326 |
+
temperature=0.7,
|
| 327 |
+
do_sample=True,
|
| 328 |
+
top_p=0.9,
|
| 329 |
+
num_beams=1,
|
| 330 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 331 |
+
eos_token_id=self.tokenizer.eos_token_id
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# Decode all at once
|
| 335 |
+
responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 336 |
+
|
| 337 |
+
# Extract only generated parts
|
| 338 |
+
extracted_responses = []
|
| 339 |
+
for i, response in enumerate(responses):
|
| 340 |
+
if "### Response:" in response:
|
| 341 |
+
extracted = response.split("### Response:")[-1].strip()
|
| 342 |
+
else:
|
| 343 |
+
extracted = response[len(formatted_prompts[i]):].strip()
|
| 344 |
+
extracted_responses.append(extracted if extracted else "応答を生成できませんでした。")
|
| 345 |
+
|
| 346 |
+
return extracted_responses
|
| 347 |
+
|
| 348 |
+
except Exception as e:
|
| 349 |
+
print(f"Error in batch generation: {e}")
|
| 350 |
+
# Return default responses
|
| 351 |
+
return ["申し訳ございません。応答を生成できませんでした。"] * len(prompts)
|
| 352 |
+
|
| 353 |
+
def tokenize_japanese(self, text: str) -> List[str]:
|
| 354 |
+
"""Tokenize Japanese text"""
|
| 355 |
+
if not text:
|
| 356 |
+
return ['empty']
|
| 357 |
+
|
| 358 |
+
if self.mecab:
|
| 359 |
+
try:
|
| 360 |
+
tokens = self.mecab.parse(text).strip().split()
|
| 361 |
+
return tokens if tokens else list(text)
|
| 362 |
+
except:
|
| 363 |
+
pass
|
| 364 |
+
|
| 365 |
+
# Fallback to character tokenization
|
| 366 |
+
return list(text.replace(' ', ''))
|
| 367 |
+
|
| 368 |
+
def calculate_metrics_batch(self, references: List[str], hypotheses: List[str]) -> Dict:
|
| 369 |
+
"""Calculate all metrics in batch"""
|
| 370 |
+
metrics = defaultdict(list)
|
| 371 |
+
|
| 372 |
+
for ref, hyp in zip(references, hypotheses):
|
| 373 |
+
if not ref or not hyp:
|
| 374 |
+
# Add default scores for empty strings
|
| 375 |
+
for n in range(1, 5):
|
| 376 |
+
metrics[f'BLEU-{n}'].append(0.0)
|
| 377 |
+
metrics['ROUGE-1'].append(0.0)
|
| 378 |
+
metrics['ROUGE-2'].append(0.0)
|
| 379 |
+
metrics['ROUGE-L'].append(0.0)
|
| 380 |
+
continue
|
| 381 |
+
|
| 382 |
+
try:
|
| 383 |
+
# Tokenize
|
| 384 |
+
ref_tokens = self.tokenize_japanese(ref)
|
| 385 |
+
hyp_tokens = self.tokenize_japanese(hyp)
|
| 386 |
+
|
| 387 |
+
# BLEU scores
|
| 388 |
+
for n in range(1, 5):
|
| 389 |
+
weights = tuple([1/n] * n + [0] * (4-n))
|
| 390 |
+
try:
|
| 391 |
+
score = sentence_bleu(
|
| 392 |
+
[ref_tokens],
|
| 393 |
+
hyp_tokens,
|
| 394 |
+
weights=weights,
|
| 395 |
+
smoothing_function=self.smoothing
|
| 396 |
+
)
|
| 397 |
+
metrics[f'BLEU-{n}'].append(score)
|
| 398 |
+
except:
|
| 399 |
+
metrics[f'BLEU-{n}'].append(0.0)
|
| 400 |
+
|
| 401 |
+
# ROUGE scores
|
| 402 |
+
try:
|
| 403 |
+
ref_spaced = ' '.join(ref_tokens)
|
| 404 |
+
hyp_spaced = ' '.join(hyp_tokens)
|
| 405 |
+
rouge_scores = self.rouge_scorer.score(ref_spaced, hyp_spaced)
|
| 406 |
+
metrics['ROUGE-1'].append(rouge_scores['rouge1'].fmeasure)
|
| 407 |
+
metrics['ROUGE-2'].append(rouge_scores['rouge2'].fmeasure)
|
| 408 |
+
metrics['ROUGE-L'].append(rouge_scores['rougeL'].fmeasure)
|
| 409 |
+
except:
|
| 410 |
+
metrics['ROUGE-1'].append(0.0)
|
| 411 |
+
metrics['ROUGE-2'].append(0.0)
|
| 412 |
+
metrics['ROUGE-L'].append(0.0)
|
| 413 |
+
|
| 414 |
+
except Exception as e:
|
| 415 |
+
# Add zeros for failed calculations
|
| 416 |
+
for n in range(1, 5):
|
| 417 |
+
metrics[f'BLEU-{n}'].append(0.0)
|
| 418 |
+
metrics['ROUGE-1'].append(0.0)
|
| 419 |
+
metrics['ROUGE-2'].append(0.0)
|
| 420 |
+
metrics['ROUGE-L'].append(0.0)
|
| 421 |
+
|
| 422 |
+
return dict(metrics)
|
| 423 |
+
|
| 424 |
+
def run_fast_benchmark(self, num_samples: Optional[int] = None):
|
| 425 |
+
"""Run optimized benchmark with batch processing"""
|
| 426 |
+
print("\n" + "="*80)
|
| 427 |
+
print("🚀 Running Fast Multi-GPU Benchmark")
|
| 428 |
+
print("="*80)
|
| 429 |
+
|
| 430 |
+
start_time = time.time()
|
| 431 |
+
|
| 432 |
+
# Load test data
|
| 433 |
+
test_data = self.load_test_data_fast(max_samples=num_samples)
|
| 434 |
+
|
| 435 |
+
if not test_data:
|
| 436 |
+
raise ValueError("No test data available!")
|
| 437 |
+
|
| 438 |
+
# Create DataLoader
|
| 439 |
+
dataset = TestDataset(test_data)
|
| 440 |
+
dataloader = DataLoader(
|
| 441 |
+
dataset,
|
| 442 |
+
batch_size=self.batch_size,
|
| 443 |
+
shuffle=False,
|
| 444 |
+
num_workers=0,
|
| 445 |
+
collate_fn=custom_collate_fn,
|
| 446 |
+
pin_memory=True if self.device.type == 'cuda' else False
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
# Initialize metric collectors
|
| 450 |
+
all_base_metrics = defaultdict(list)
|
| 451 |
+
all_finetuned_metrics = defaultdict(list)
|
| 452 |
+
|
| 453 |
+
print(f"\n📊 Evaluating {len(test_data)} examples in {len(dataloader)} batches...")
|
| 454 |
+
print(f" Batch size: {self.batch_size}")
|
| 455 |
+
print(f" Using {self.num_gpus} GPU(s)")
|
| 456 |
+
|
| 457 |
+
# Process batches
|
| 458 |
+
successful_batches = 0
|
| 459 |
+
for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
|
| 460 |
+
try:
|
| 461 |
+
# Extract batch data
|
| 462 |
+
inputs = [item['input'] for item in batch]
|
| 463 |
+
references = [item['reference'] for item in batch]
|
| 464 |
+
|
| 465 |
+
# Generate responses in batch
|
| 466 |
+
base_responses = self.generate_batch_responses(self.base_model, inputs)
|
| 467 |
+
finetuned_responses = self.generate_batch_responses(self.finetuned_model, inputs)
|
| 468 |
+
|
| 469 |
+
# Calculate metrics in batch
|
| 470 |
+
base_metrics = self.calculate_metrics_batch(references, base_responses)
|
| 471 |
+
finetuned_metrics = self.calculate_metrics_batch(references, finetuned_responses)
|
| 472 |
+
|
| 473 |
+
# Aggregate metrics
|
| 474 |
+
for key, values in base_metrics.items():
|
| 475 |
+
all_base_metrics[key].extend(values)
|
| 476 |
+
for key, values in finetuned_metrics.items():
|
| 477 |
+
all_finetuned_metrics[key].extend(values)
|
| 478 |
+
|
| 479 |
+
successful_batches += 1
|
| 480 |
+
|
| 481 |
+
# Log progress to WandB
|
| 482 |
+
if self.wandb_enabled and batch_idx % 5 == 0:
|
| 483 |
+
progress = (batch_idx + 1) / len(dataloader) * 100
|
| 484 |
+
|
| 485 |
+
# Calculate current averages
|
| 486 |
+
current_bleu4_base = np.mean(all_base_metrics.get('BLEU-4', [0]))
|
| 487 |
+
current_bleu4_finetuned = np.mean(all_finetuned_metrics.get('BLEU-4', [0]))
|
| 488 |
+
current_rouge_l_base = np.mean(all_base_metrics.get('ROUGE-L', [0]))
|
| 489 |
+
current_rouge_l_finetuned = np.mean(all_finetuned_metrics.get('ROUGE-L', [0]))
|
| 490 |
+
|
| 491 |
+
wandb.log({
|
| 492 |
+
"progress": progress,
|
| 493 |
+
"batches_processed": batch_idx + 1,
|
| 494 |
+
"samples_processed": min((batch_idx + 1) * self.batch_size, len(test_data)),
|
| 495 |
+
"current_bleu4_base": current_bleu4_base,
|
| 496 |
+
"current_bleu4_finetuned": current_bleu4_finetuned,
|
| 497 |
+
"current_rouge_l_base": current_rouge_l_base,
|
| 498 |
+
"current_rouge_l_finetuned": current_rouge_l_finetuned
|
| 499 |
+
})
|
| 500 |
+
|
| 501 |
+
# Store examples for analysis
|
| 502 |
+
if batch_idx == 0 and len(inputs) > 0:
|
| 503 |
+
for i in range(min(3, len(inputs))):
|
| 504 |
+
self.detailed_results.append({
|
| 505 |
+
'input': inputs[i],
|
| 506 |
+
'reference': references[i],
|
| 507 |
+
'base_response': base_responses[i] if i < len(base_responses) else "",
|
| 508 |
+
'finetuned_response': finetuned_responses[i] if i < len(finetuned_responses) else ""
|
| 509 |
+
})
|
| 510 |
+
|
| 511 |
+
# Print sample
|
| 512 |
+
print(f"\n📝 Sample Example:")
|
| 513 |
+
print(f"Input: {inputs[0][:100]}...")
|
| 514 |
+
print(f"Reference: {references[0][:100]}...")
|
| 515 |
+
print(f"Base response: {base_responses[0][:100]}...")
|
| 516 |
+
print(f"Fine-tuned response: {finetuned_responses[0][:100]}...")
|
| 517 |
+
|
| 518 |
+
except Exception as e:
|
| 519 |
+
print(f"Error processing batch {batch_idx}: {e}")
|
| 520 |
+
continue
|
| 521 |
+
|
| 522 |
+
print(f"\n✅ Successfully processed {successful_batches}/{len(dataloader)} batches")
|
| 523 |
+
|
| 524 |
+
# Calculate final statistics
|
| 525 |
+
self.results = self.calculate_final_statistics(all_base_metrics, all_finetuned_metrics)
|
| 526 |
+
|
| 527 |
+
# Calculate processing time
|
| 528 |
+
total_time = time.time() - start_time
|
| 529 |
+
samples_per_second = len(test_data) / total_time if total_time > 0 else 0
|
| 530 |
+
|
| 531 |
+
print(f"\n⏱️ Benchmark completed in {total_time:.2f} seconds")
|
| 532 |
+
print(f" Processing speed: {samples_per_second:.2f} samples/second")
|
| 533 |
+
|
| 534 |
+
# Log final results to WandB
|
| 535 |
+
if self.wandb_enabled:
|
| 536 |
+
wandb.log({
|
| 537 |
+
"total_time_seconds": total_time,
|
| 538 |
+
"samples_per_second": samples_per_second,
|
| 539 |
+
"total_samples": len(test_data),
|
| 540 |
+
"successful_batches": successful_batches,
|
| 541 |
+
**{f"final_{k}": v for k, v in self.results['summary'].items()}
|
| 542 |
+
})
|
| 543 |
+
|
| 544 |
+
# Log detailed metrics
|
| 545 |
+
for metric_name, improvements in self.results['improvements'].items():
|
| 546 |
+
wandb.log({f"improvement_{metric_name}": improvements})
|
| 547 |
+
|
| 548 |
+
# Create visualization
|
| 549 |
+
if self.results['metrics']:
|
| 550 |
+
self.create_wandb_visualizations()
|
| 551 |
+
|
| 552 |
+
# Print results
|
| 553 |
+
self.print_results()
|
| 554 |
+
|
| 555 |
+
return self.results
|
| 556 |
+
|
| 557 |
+
def create_wandb_visualizations(self):
|
| 558 |
+
"""Create WandB visualizations"""
|
| 559 |
+
if not self.wandb_enabled or not self.results.get('metrics'):
|
| 560 |
+
return
|
| 561 |
+
|
| 562 |
+
try:
|
| 563 |
+
# Create comparison table
|
| 564 |
+
data = []
|
| 565 |
+
for metric in self.results['metrics']:
|
| 566 |
+
data.append([
|
| 567 |
+
metric,
|
| 568 |
+
self.results['metrics'][metric]['base']['mean'],
|
| 569 |
+
self.results['metrics'][metric]['finetuned']['mean'],
|
| 570 |
+
self.results['improvements'][metric]
|
| 571 |
+
])
|
| 572 |
+
|
| 573 |
+
table = wandb.Table(
|
| 574 |
+
columns=["Metric", "Base", "Fine-tuned", "Improvement (%)"],
|
| 575 |
+
data=data
|
| 576 |
+
)
|
| 577 |
+
wandb.log({"results_comparison": table})
|
| 578 |
+
|
| 579 |
+
# Log bar chart of improvements
|
| 580 |
+
wandb.log({
|
| 581 |
+
"improvements_chart": wandb.plot.bar(
|
| 582 |
+
wandb.Table(
|
| 583 |
+
data=[[m, self.results['improvements'][m]] for m in self.results['improvements']],
|
| 584 |
+
columns=["Metric", "Improvement (%)"]
|
| 585 |
+
),
|
| 586 |
+
"Metric", "Improvement (%)",
|
| 587 |
+
title="Model Improvements"
|
| 588 |
+
)
|
| 589 |
+
})
|
| 590 |
+
except Exception as e:
|
| 591 |
+
print(f"Error creating visualizations: {e}")
|
| 592 |
+
|
| 593 |
+
def calculate_final_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict:
|
| 594 |
+
"""Calculate final aggregate statistics"""
|
| 595 |
+
results = {
|
| 596 |
+
'metrics': {},
|
| 597 |
+
'improvements': {},
|
| 598 |
+
'summary': {}
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
# Calculate statistics for each metric
|
| 602 |
+
all_metric_names = set(base_metrics.keys()) | set(finetuned_metrics.keys())
|
| 603 |
+
|
| 604 |
+
for metric in all_metric_names:
|
| 605 |
+
base_values = base_metrics.get(metric, [0])
|
| 606 |
+
finetuned_values = finetuned_metrics.get(metric, [0])
|
| 607 |
+
|
| 608 |
+
# Filter out any None values
|
| 609 |
+
base_values = [v for v in base_values if v is not None]
|
| 610 |
+
finetuned_values = [v for v in finetuned_values if v is not None]
|
| 611 |
+
|
| 612 |
+
if not base_values:
|
| 613 |
+
base_values = [0]
|
| 614 |
+
if not finetuned_values:
|
| 615 |
+
finetuned_values = [0]
|
| 616 |
+
|
| 617 |
+
results['metrics'][metric] = {
|
| 618 |
+
'base': {
|
| 619 |
+
'mean': float(np.mean(base_values)),
|
| 620 |
+
'std': float(np.std(base_values)),
|
| 621 |
+
'min': float(np.min(base_values)),
|
| 622 |
+
'max': float(np.max(base_values))
|
| 623 |
+
},
|
| 624 |
+
'finetuned': {
|
| 625 |
+
'mean': float(np.mean(finetuned_values)),
|
| 626 |
+
'std': float(np.std(finetuned_values)),
|
| 627 |
+
'min': float(np.min(finetuned_values)),
|
| 628 |
+
'max': float(np.max(finetuned_values))
|
| 629 |
+
}
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
# Calculate improvement
|
| 633 |
+
base_mean = np.mean(base_values)
|
| 634 |
+
finetuned_mean = np.mean(finetuned_values)
|
| 635 |
+
if base_mean > 0:
|
| 636 |
+
improvement = ((finetuned_mean - base_mean) / base_mean) * 100
|
| 637 |
+
else:
|
| 638 |
+
improvement = 0 if finetuned_mean == 0 else 100
|
| 639 |
+
|
| 640 |
+
results['improvements'][metric] = improvement
|
| 641 |
+
|
| 642 |
+
# Calculate summary statistics
|
| 643 |
+
bleu_metrics = [m for m in results['metrics'] if 'BLEU' in m]
|
| 644 |
+
rouge_metrics = [m for m in results['metrics'] if 'ROUGE' in m]
|
| 645 |
+
|
| 646 |
+
results['summary'] = {
|
| 647 |
+
'bleu_avg_improvement': np.mean([results['improvements'][m] for m in bleu_metrics]) if bleu_metrics else 0,
|
| 648 |
+
'rouge_avg_improvement': np.mean([results['improvements'][m] for m in rouge_metrics]) if rouge_metrics else 0,
|
| 649 |
+
'overall_improvement': np.mean(list(results['improvements'].values())) if results['improvements'] else 0
|
| 650 |
+
}
|
| 651 |
+
|
| 652 |
+
return results
|
| 653 |
+
|
| 654 |
+
def print_results(self):
|
| 655 |
+
"""Print formatted results"""
|
| 656 |
+
print("\n" + "="*80)
|
| 657 |
+
print("📊 BENCHMARK RESULTS")
|
| 658 |
+
print("="*80)
|
| 659 |
+
|
| 660 |
+
if not self.results or 'metrics' not in self.results:
|
| 661 |
+
print("No results to display")
|
| 662 |
+
return
|
| 663 |
+
|
| 664 |
+
# BLEU scores
|
| 665 |
+
print("\n📘 BLEU Scores:")
|
| 666 |
+
print("-"*60)
|
| 667 |
+
print(f"{'Metric':<15} {'Base':<15} {'Fine-tuned':<15} {'Improvement':<15}")
|
| 668 |
+
print("-"*60)
|
| 669 |
+
|
| 670 |
+
for metric in sorted([m for m in self.results['metrics'] if 'BLEU' in m]):
|
| 671 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 672 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 673 |
+
improvement = self.results['improvements'][metric]
|
| 674 |
+
print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
|
| 675 |
+
|
| 676 |
+
# ROUGE scores
|
| 677 |
+
print("\n📕 ROUGE Scores:")
|
| 678 |
+
print("-"*60)
|
| 679 |
+
|
| 680 |
+
for metric in sorted([m for m in self.results['metrics'] if 'ROUGE' in m]):
|
| 681 |
+
base = self.results['metrics'][metric]['base']['mean']
|
| 682 |
+
finetuned = self.results['metrics'][metric]['finetuned']['mean']
|
| 683 |
+
improvement = self.results['improvements'][metric]
|
| 684 |
+
print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
|
| 685 |
+
|
| 686 |
+
# Summary
|
| 687 |
+
print("\n" + "="*80)
|
| 688 |
+
print("📈 SUMMARY")
|
| 689 |
+
print("="*80)
|
| 690 |
+
print(f"BLEU Average Improvement: {self.results['summary']['bleu_avg_improvement']:+.1f}%")
|
| 691 |
+
print(f"ROUGE Average Improvement: {self.results['summary']['rouge_avg_improvement']:+.1f}%")
|
| 692 |
+
print(f"Overall Improvement: {self.results['summary']['overall_improvement']:+.1f}%")
|
| 693 |
+
print("="*80)
|
| 694 |
+
|
| 695 |
+
def save_results(self, output_dir: str = "./benchmark_results"):
|
| 696 |
+
"""Save results"""
|
| 697 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 698 |
+
|
| 699 |
+
# Save results
|
| 700 |
+
with open(os.path.join(output_dir, "results.json"), 'w', encoding='utf-8') as f:
|
| 701 |
+
json.dump(self.results, f, ensure_ascii=False, indent=2, default=str)
|
| 702 |
+
|
| 703 |
+
with open(os.path.join(output_dir, "examples.json"), 'w', encoding='utf-8') as f:
|
| 704 |
+
json.dump(self.detailed_results, f, ensure_ascii=False, indent=2)
|
| 705 |
+
|
| 706 |
+
# Save to WandB
|
| 707 |
+
if self.wandb_enabled:
|
| 708 |
+
try:
|
| 709 |
+
artifact = wandb.Artifact(
|
| 710 |
+
name=f"benchmark-results-{wandb.run.id}",
|
| 711 |
+
type="benchmark_results",
|
| 712 |
+
description="Japanese counseling model benchmark results"
|
| 713 |
+
)
|
| 714 |
+
artifact.add_dir(output_dir)
|
| 715 |
+
wandb.log_artifact(artifact)
|
| 716 |
+
except Exception as e:
|
| 717 |
+
print(f"Error saving to WandB: {e}")
|
| 718 |
+
|
| 719 |
+
print(f"✅ Results saved to {output_dir}/")
|
| 720 |
+
|
| 721 |
+
def cleanup(self):
|
| 722 |
+
"""Clean up resources"""
|
| 723 |
+
if self.wandb_enabled:
|
| 724 |
+
wandb.finish()
|
| 725 |
+
|
| 726 |
+
if torch.cuda.is_available():
|
| 727 |
+
torch.cuda.empty_cache()
|
| 728 |
+
|
| 729 |
+
gc.collect()
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def main():
|
| 733 |
+
"""Main execution"""
|
| 734 |
+
import argparse
|
| 735 |
+
|
| 736 |
+
parser = argparse.ArgumentParser(description='Optimized Japanese Counseling Benchmark')
|
| 737 |
+
parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B')
|
| 738 |
+
parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model')
|
| 739 |
+
parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl')
|
| 740 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing')
|
| 741 |
+
parser.add_argument('--num_samples', type=int, default=None, help='Number of samples to evaluate')
|
| 742 |
+
parser.add_argument('--output_dir', type=str, default='./benchmark_results_fast')
|
| 743 |
+
parser.add_argument('--no_wandb', action='store_true', help='Disable WandB logging')
|
| 744 |
+
|
| 745 |
+
args = parser.parse_args()
|
| 746 |
+
|
| 747 |
+
try:
|
| 748 |
+
# Initialize benchmark
|
| 749 |
+
print("🚀 Initializing Optimized Benchmark Suite")
|
| 750 |
+
benchmark = OptimizedJapaneseBenchmark(
|
| 751 |
+
base_model_name=args.base_model,
|
| 752 |
+
finetuned_model_path=args.finetuned_model,
|
| 753 |
+
test_data_path=args.test_data,
|
| 754 |
+
batch_size=args.batch_size,
|
| 755 |
+
use_wandb=not args.no_wandb
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# Load models
|
| 759 |
+
benchmark.load_models_optimized()
|
| 760 |
+
|
| 761 |
+
# Run benchmark
|
| 762 |
+
results = benchmark.run_fast_benchmark(num_samples=args.num_samples)
|
| 763 |
+
|
| 764 |
+
# Save results
|
| 765 |
+
benchmark.save_results(args.output_dir)
|
| 766 |
+
|
| 767 |
+
# Cleanup
|
| 768 |
+
benchmark.cleanup()
|
| 769 |
+
|
| 770 |
+
print("\n✅ Benchmark completed successfully!")
|
| 771 |
+
|
| 772 |
+
except Exception as e:
|
| 773 |
+
print(f"\n❌ Error: {e}")
|
| 774 |
+
import traceback
|
| 775 |
+
traceback.print_exc()
|
| 776 |
+
|
| 777 |
+
if 'benchmark' in locals():
|
| 778 |
+
benchmark.cleanup()
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
if __name__ == "__main__":
|
| 782 |
+
main()
|
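Besides the CLI entry point (for example: python benchmark.py --batch_size 8 --num_samples 100 --no_wandb), the suite can be driven programmatically. A minimal sketch, assuming the merged model directory and the test JSONL exist at the default paths used above:

    from benchmark import OptimizedJapaneseBenchmark

    bench = OptimizedJapaneseBenchmark(
        base_model_name="LiquidAI/LFM2-1.2B",
        finetuned_model_path="./merged_counselor_model",
        test_data_path="./processed_data_score80/test.jsonl",
        batch_size=8,        # smaller batch if GPU memory is tight
        use_wandb=False,     # skip WandB for a quick local run
    )
    bench.load_models_optimized()
    results = bench.run_fast_benchmark(num_samples=100)  # None evaluates the full test set
    bench.save_results("./benchmark_results_fast")
    bench.cleanup()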
chat.py
ADDED
|
@@ -0,0 +1,339 @@
|
| 1 |
+
"""
|
| 2 |
+
Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 7 |
+
import os
|
| 8 |
+
import warnings
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
warnings.filterwarnings('ignore')
|
| 13 |
+
|
| 14 |
+
class CounselorChatInterface:
|
| 15 |
+
def __init__(self, model_path: str = "./merged_counselor_model"):
|
| 16 |
+
"""
|
| 17 |
+
Initialize the chat interface with the fine-tuned model
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
model_path: Path to the fine-tuned model
|
| 21 |
+
"""
|
| 22 |
+
self.model_path = model_path
|
| 23 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
|
| 25 |
+
print("="*80)
|
| 26 |
+
print("🎌 Japanese Counseling Model Chat Interface")
|
| 27 |
+
print("="*80)
|
| 28 |
+
print(f"📍 Device: {self.device}")
|
| 29 |
+
|
| 30 |
+
if self.device.type == "cuda":
|
| 31 |
+
print(f" GPU: {torch.cuda.get_device_name(0)}")
|
| 32 |
+
print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
| 33 |
+
|
| 34 |
+
self.load_model()
|
| 35 |
+
self.conversation_history = []
|
| 36 |
+
|
| 37 |
+
def load_model(self):
|
| 38 |
+
"""Load the fine-tuned model and tokenizer"""
|
| 39 |
+
print(f"\n🤖 Loading model from {self.model_path}...")
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
# Load tokenizer
|
| 43 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 44 |
+
self.model_path,
|
| 45 |
+
local_files_only=True
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Set padding token if not set
|
| 49 |
+
if self.tokenizer.pad_token is None:
|
| 50 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 51 |
+
|
| 52 |
+
# Load model
|
| 53 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 54 |
+
self.model_path,
|
| 55 |
+
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
|
| 56 |
+
device_map="auto" if self.device.type == "cuda" else None,
|
| 57 |
+
local_files_only=True,
|
| 58 |
+
trust_remote_code=True
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
self.model.eval()
|
| 62 |
+
print("✅ Model loaded successfully!")
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print(f"❌ Error loading model: {e}")
|
| 66 |
+
print("Trying alternative loading method...")
|
| 67 |
+
|
| 68 |
+
# Try loading with base tokenizer
|
| 69 |
+
try:
|
| 70 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 71 |
+
if self.tokenizer.pad_token is None:
|
| 72 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 73 |
+
|
| 74 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 75 |
+
self.model_path,
|
| 76 |
+
torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
|
| 77 |
+
local_files_only=True
|
| 78 |
+
)
|
| 79 |
+
self.model = self.model.to(self.device)
|
| 80 |
+
self.model.eval()
|
| 81 |
+
print("✅ Model loaded with fallback tokenizer!")
|
| 82 |
+
except Exception as e2:
|
| 83 |
+
print(f"❌ Failed to load model: {e2}")
|
| 84 |
+
raise
|
| 85 |
+
|
| 86 |
+
def generate_response(self, user_input: str,
|
| 87 |
+
temperature: float = 0.1,
|
| 88 |
+
max_length: int = 200,
|
| 89 |
+
use_context: bool = True) -> str:
|
| 90 |
+
"""
|
| 91 |
+
Generate a counseling response
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
user_input: User's message
|
| 95 |
+
temperature: Generation temperature (0.1-1.0)
|
| 96 |
+
max_length: Maximum response length
|
| 97 |
+
use_context: Whether to use conversation history
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
Generated response
|
| 101 |
+
"""
|
| 102 |
+
# Format the prompt
|
| 103 |
+
if use_context and len(self.conversation_history) > 0:
|
| 104 |
+
# Include recent context
|
| 105 |
+
context = "\n".join(self.conversation_history[-4:]) # Last 2 exchanges
|
| 106 |
+
prompt = f"""### Instruction:
|
| 107 |
+
あなたは思いやりのある心理カウンセラーです。
|
| 108 |
+
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
|
| 109 |
+
|
| 110 |
+
### Context:
|
| 111 |
+
{context}
|
| 112 |
+
|
| 113 |
+
### Input:
|
| 114 |
+
{user_input}
|
| 115 |
+
|
| 116 |
+
### Response:
|
| 117 |
+
"""
|
| 118 |
+
else:
|
| 119 |
+
prompt = f"""### Instruction:
|
| 120 |
+
あなたは思いやりのある心理カウンセラーです。
|
| 121 |
+
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
|
| 122 |
+
|
| 123 |
+
### Input:
|
| 124 |
+
{user_input}
|
| 125 |
+
|
| 126 |
+
### Response:
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
# Tokenize
|
| 130 |
+
inputs = self.tokenizer(
|
| 131 |
+
prompt,
|
| 132 |
+
return_tensors="pt",
|
| 133 |
+
truncation=True,
|
| 134 |
+
max_length=512
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if self.device.type == "cuda":
|
| 138 |
+
inputs = {k: v.cuda() for k, v in inputs.items()}
|
| 139 |
+
|
| 140 |
+
# Generate
|
| 141 |
+
try:
|
| 142 |
+
with torch.no_grad():
|
| 143 |
+
with torch.cuda.amp.autocast() if self.device.type == "cuda" else torch.autocast("cpu"):
|
| 144 |
+
outputs = self.model.generate(
|
| 145 |
+
**inputs,
|
| 146 |
+
max_new_tokens=max_length,
|
| 147 |
+
temperature=temperature,
|
| 148 |
+
do_sample=True,
|
| 149 |
+
top_p=0.9,
|
| 150 |
+
top_k=50,
|
| 151 |
+
repetition_penalty=1.1,
|
| 152 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 153 |
+
eos_token_id=self.tokenizer.eos_token_id
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# Decode
|
| 157 |
+
full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 158 |
+
|
| 159 |
+
# Extract only the response part
|
| 160 |
+
if "### Response:" in full_response:
|
| 161 |
+
response = full_response.split("### Response:")[-1].strip()
|
| 162 |
+
else:
|
| 163 |
+
response = full_response[len(prompt):].strip()
|
| 164 |
+
|
| 165 |
+
return response
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"Error generating response: {e}")
|
| 169 |
+
return "申し訳ございません。応答の生成中にエラーが発生しました。"
|
| 170 |
+
|
| 171 |
+
def chat(self):
|
| 172 |
+
"""Start interactive chat session"""
|
| 173 |
+
print("\n" + "="*80)
|
| 174 |
+
print("💬 チャットを開始します (Chat session started)")
|
| 175 |
+
print("="*80)
|
| 176 |
+
print("Commands:")
|
| 177 |
+
print(" /quit or /exit - 終了 (Exit)")
|
| 178 |
+
print(" /clear - 会話履歴をクリア (Clear conversation history)")
|
| 179 |
+
print(" /save - 会話を保存 (Save conversation)")
|
| 180 |
+
print(" /temp <value> - 温度パラメータを設定 (Set temperature, e.g., /temp 0.8)")
|
| 181 |
+
print(" /context on/off - コンテキスト使用の切り替え (Toggle context usage)")
|
| 182 |
+
print("-"*80)
|
| 183 |
+
|
| 184 |
+
temperature = 0.1
|
| 185 |
+
use_context = True
|
| 186 |
+
|
| 187 |
+
while True:
|
| 188 |
+
try:
|
| 189 |
+
# Get user input
|
| 190 |
+
user_input = input("\n👤 You: ").strip()
|
| 191 |
+
|
| 192 |
+
# Check for commands
|
| 193 |
+
if user_input.lower() in ['/quit', '/exit', '/q']:
|
| 194 |
+
print("\n👋 さようなら!(Goodbye!)")
|
| 195 |
+
break
|
| 196 |
+
|
| 197 |
+
elif user_input.lower() == '/clear':
|
| 198 |
+
self.conversation_history = []
|
| 199 |
+
print("✅ 会話履歴をクリアしました (Conversation history cleared)")
|
| 200 |
+
continue
|
| 201 |
+
|
| 202 |
+
elif user_input.lower() == '/save':
|
| 203 |
+
self.save_conversation()
|
| 204 |
+
continue
|
| 205 |
+
|
| 206 |
+
elif user_input.lower().startswith('/temp'):
|
| 207 |
+
try:
|
| 208 |
+
temperature = float(user_input.split()[1])
|
| 209 |
+
temperature = max(0.1, min(1.0, temperature))  # clamp to the documented 0.1-1.0 range
|
| 210 |
+
print(f"✅ Temperature set to {temperature}")
|
| 211 |
+
except:
|
| 212 |
+
print("❌ Invalid temperature. Use: /temp 0.7")
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
elif user_input.lower().startswith('/context'):
|
| 216 |
+
try:
|
| 217 |
+
setting = user_input.split()[1].lower()
|
| 218 |
+
use_context = setting == 'on'
|
| 219 |
+
print(f"✅ Context {'enabled' if use_context else 'disabled'}")
|
| 220 |
+
except:
|
| 221 |
+
print("❌ Use: /context on or /context off")
|
| 222 |
+
continue
|
| 223 |
+
|
| 224 |
+
elif user_input.startswith('/'):
|
| 225 |
+
print("❌ Unknown command")
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
# Generate response
|
| 229 |
+
print("\n🤖 Counselor: ", end="", flush=True)
|
| 230 |
+
response = self.generate_response(
|
| 231 |
+
user_input,
|
| 232 |
+
temperature=temperature,
|
| 233 |
+
use_context=use_context
|
| 234 |
+
)
|
| 235 |
+
print(response)
|
| 236 |
+
|
| 237 |
+
# Add to history
|
| 238 |
+
self.conversation_history.append(f"Client: {user_input}")
|
| 239 |
+
self.conversation_history.append(f"Counselor: {response}")
|
| 240 |
+
|
| 241 |
+
except KeyboardInterrupt:
|
| 242 |
+
print("\n\n👋 さようなら!(Goodbye!)")
|
| 243 |
+
break
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"\n❌ Error: {e}")
|
| 246 |
+
continue
|
| 247 |
+
|
| 248 |
+
def save_conversation(self):
|
| 249 |
+
"""Save the conversation to a file"""
|
| 250 |
+
if not self.conversation_history:
|
| 251 |
+
print("❌ No conversation to save")
|
| 252 |
+
return
|
| 253 |
+
|
| 254 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 255 |
+
filename = f"conversation_{timestamp}.json"
|
| 256 |
+
|
| 257 |
+
conversation_data = {
|
| 258 |
+
"timestamp": timestamp,
|
| 259 |
+
"model_path": self.model_path,
|
| 260 |
+
"conversation": self.conversation_history
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
| 264 |
+
json.dump(conversation_data, f, ensure_ascii=False, indent=2)
|
| 265 |
+
|
| 266 |
+
print(f"✅ Conversation saved to {filename}")
|
| 267 |
+
|
| 268 |
+
def test_responses(self):
|
| 269 |
+
"""Test the model with predefined inputs"""
|
| 270 |
+
print("\n" + "="*80)
|
| 271 |
+
print("🧪 Testing Model Responses")
|
| 272 |
+
print("="*80)
|
| 273 |
+
|
| 274 |
+
test_inputs = [
|
| 275 |
+
"こんにちは。最近ストレスを感じています。",
|
| 276 |
+
"仕事がうまくいかなくて悩んでいます。",
|
| 277 |
+
"人間関係で困っています。どうすればいいでしょうか。",
|
| 278 |
+
"将来が不安で眠れません。",
|
| 279 |
+
"自分に自信が持てません。",
|
| 280 |
+
"家族との関係で悩んでいます。",
|
| 281 |
+
"毎日が辛いです。",
|
| 282 |
+
"誰にも相談できません。"
|
| 283 |
+
]
|
| 284 |
+
|
| 285 |
+
print("\nTesting with different temperature settings:\n")
|
| 286 |
+
|
| 287 |
+
for temp in [0.1, 0.7]:  # temperature 0 is invalid when do_sample=True
|
| 288 |
+
print(f"\n🌡️ Temperature: {temp}")
|
| 289 |
+
print("-"*60)
|
| 290 |
+
|
| 291 |
+
for i, test_input in enumerate(test_inputs[:3], 1):
|
| 292 |
+
print(f"\n{i}. Input: {test_input}")
|
| 293 |
+
response = self.generate_response(test_input, temperature=temp, use_context=False)
|
| 294 |
+
print(f" Response: {response[:200]}...")
|
| 295 |
+
print()
|
| 296 |
+
|
| 297 |
+
print("="*80)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def main():
|
| 301 |
+
"""Main function"""
|
| 302 |
+
import argparse
|
| 303 |
+
|
| 304 |
+
parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
|
| 305 |
+
parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
|
| 306 |
+
help='Path to the fine-tuned model')
|
| 307 |
+
parser.add_argument('--test_only', action='store_true',
|
| 308 |
+
help='Only run test responses without chat')
|
| 309 |
+
|
| 310 |
+
args = parser.parse_args()
|
| 311 |
+
|
| 312 |
+
# Check if model exists
|
| 313 |
+
if not os.path.exists(args.model_path):
|
| 314 |
+
print(f"❌ Model not found at {args.model_path}")
|
| 315 |
+
print("\nAvailable models:")
|
| 316 |
+
for item in os.listdir('.'):
|
| 317 |
+
if 'model' in item.lower() and os.path.isdir(item):
|
| 318 |
+
print(f" - {item}")
|
| 319 |
+
return
|
| 320 |
+
|
| 321 |
+
try:
|
| 322 |
+
# Initialize chat interface
|
| 323 |
+
chat = CounselorChatInterface(model_path=args.model_path)
|
| 324 |
+
|
| 325 |
+
if args.test_only:
|
| 326 |
+
# Run tests only
|
| 327 |
+
chat.test_responses()
|
| 328 |
+
else:
|
| 329 |
+
# Start interactive chat
|
| 330 |
+
chat.chat()
|
| 331 |
+
|
| 332 |
+
except Exception as e:
|
| 333 |
+
print(f"❌ Error: {e}")
|
| 334 |
+
import traceback
|
| 335 |
+
traceback.print_exc()
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
if __name__ == "__main__":
|
| 339 |
+
main()
|
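Besides the interactive loop (python chat.py --model_path ./merged_counselor_model) and the --test_only mode, the interface can also be exercised directly. A minimal sketch, assuming the merged model directory is available locally:

    from chat import CounselorChatInterface

    chat = CounselorChatInterface(model_path="./merged_counselor_model")
    # Single-turn reply without using conversation context
    reply = chat.generate_response("最近ストレスを感じています。", temperature=0.3, use_context=False)
    print(reply)
    chat.test_responses()  # runs the predefined prompts at two temperature settings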
data_preprocessor.py
ADDED
|
@@ -0,0 +1,524 @@
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from typing import List, Dict, Tuple, Optional
|
| 6 |
+
import random
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import re
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
class KokoroChatPreprocessor:
|
| 13 |
+
def __init__(self, data_path: str, max_length: int = 2048, min_score: int = 60):
|
| 14 |
+
"""
|
| 15 |
+
Initialize the preprocessor for KokoroChat dataset
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
data_path: Path to KokoroChat repository
|
| 19 |
+
max_length: Maximum sequence length for model input
|
| 20 |
+
min_score: Minimum score threshold for filtering conversations (default: 60)
|
| 21 |
+
"""
|
| 22 |
+
self.data_path = Path(data_path)
|
| 23 |
+
self.max_length = max_length
|
| 24 |
+
self.min_score = min_score
|
| 25 |
+
self.conversations = []
|
| 26 |
+
self.score_distribution = [] # Track score distribution
|
| 27 |
+
self.system_prompt = """あなたは思いやりのある心理カウンセラーです。
|
| 28 |
+
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
|
| 29 |
+
プライバシーを尊重し、判断を下さず、希望と実用的な洞察を提供することに焦点を当ててください。"""
|
| 30 |
+
|
| 31 |
+
def load_json_files(self) -> List[Dict]:
|
| 32 |
+
"""Load all JSON files from the dataset"""
|
| 33 |
+
json_files = []
|
| 34 |
+
# Changed from "data" to "kokorochat_dialogues"
|
| 35 |
+
data_dir = self.data_path / "kokorochat_dialogues"
|
| 36 |
+
|
| 37 |
+
# Check if data directory exists, if not try root directory
|
| 38 |
+
if not data_dir.exists():
|
| 39 |
+
data_dir = self.data_path
|
| 40 |
+
print(f"Using root directory: {data_dir}")
|
| 41 |
+
else:
|
| 42 |
+
print(f"Using data directory: {data_dir}")
|
| 43 |
+
|
| 44 |
+
for root, dirs, files in os.walk(data_dir):
|
| 45 |
+
for file in tqdm(files, desc="Loading JSON files"):
|
| 46 |
+
if file.endswith('.json'):
|
| 47 |
+
file_path = os.path.join(root, file)
|
| 48 |
+
try:
|
| 49 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 50 |
+
data = json.load(f)
|
| 51 |
+
json_files.append(data)
|
| 52 |
+
except Exception as e:
|
| 53 |
+
print(f"Error loading {file_path}: {e}")
|
| 54 |
+
|
| 55 |
+
return json_files
|
| 56 |
+
|
| 57 |
+
def analyze_score_distribution(self, json_files: List[Dict]) -> Dict:
|
| 58 |
+
"""
|
| 59 |
+
Analyze the distribution of scores in the dataset
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Dictionary with score statistics
|
| 63 |
+
"""
|
| 64 |
+
scores = []
|
| 65 |
+
for data in json_files:
|
| 66 |
+
if 'review_by_client_jp' in data:
|
| 67 |
+
score = data['review_by_client_jp'].get('点数', 0)
|
| 68 |
+
if score > 0: # Only count valid scores
|
| 69 |
+
scores.append(score)
|
| 70 |
+
self.score_distribution.append(score)
|
| 71 |
+
|
| 72 |
+
if scores:
|
| 73 |
+
stats = {
|
| 74 |
+
'total_conversations': len(json_files),
|
| 75 |
+
'conversations_with_scores': len(scores),
|
| 76 |
+
'mean_score': float(np.mean(scores)),
|
| 77 |
+
'median_score': float(np.median(scores)),
|
| 78 |
+
'std_score': float(np.std(scores)),
|
| 79 |
+
'min_score': float(np.min(scores)),
|
| 80 |
+
'max_score': float(np.max(scores)),
|
| 81 |
+
'percentiles': {
|
| 82 |
+
'25th': float(np.percentile(scores, 25)),
|
| 83 |
+
'50th': float(np.percentile(scores, 50)),
|
| 84 |
+
'75th': float(np.percentile(scores, 75)),
|
| 85 |
+
'90th': float(np.percentile(scores, 90))
|
| 86 |
+
},
|
| 87 |
+
'score_ranges': {
|
| 88 |
+
'0-30': int(sum(1 for s in scores if 0 <= s < 30)),
|
| 89 |
+
'30-50': int(sum(1 for s in scores if 30 <= s < 50)),
|
| 90 |
+
'50-60': int(sum(1 for s in scores if 50 <= s < 60)),
|
| 91 |
+
'60-70': int(sum(1 for s in scores if 60 <= s < 70)),
|
| 92 |
+
'70-80': int(sum(1 for s in scores if 70 <= s < 80)),
|
| 93 |
+
'80-90': int(sum(1 for s in scores if 80 <= s < 90)),
|
| 94 |
+
'90-100': int(sum(1 for s in scores if 90 <= s <= 100)),
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# Calculate how many conversations would be kept at different thresholds
|
| 99 |
+
threshold_analysis = {}
|
| 100 |
+
for threshold in [30, 40, 50, 60, 65, 70, 75, 80]:
|
| 101 |
+
kept = sum(1 for s in scores if s >= threshold)
|
| 102 |
+
threshold_analysis[f'threshold_{threshold}'] = {
|
| 103 |
+
'conversations_kept': kept,
|
| 104 |
+
'percentage_kept': round((kept / len(scores)) * 100, 2)
|
| 105 |
+
}
|
| 106 |
+
stats['threshold_analysis'] = threshold_analysis
|
| 107 |
+
|
| 108 |
+
return stats
|
| 109 |
+
else:
|
| 110 |
+
return {'error': 'No valid scores found in dataset'}
|
| 111 |
+
|
| 112 |
+
def plot_score_distribution(self, save_path: str = "score_distribution.png"):
|
| 113 |
+
"""
|
| 114 |
+
Plot the distribution of scores
|
| 115 |
+
"""
|
| 116 |
+
if not self.score_distribution:
|
| 117 |
+
print("No scores to plot. Run analyze_score_distribution first.")
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
| 121 |
+
|
| 122 |
+
# Histogram
|
| 123 |
+
axes[0, 0].hist(self.score_distribution, bins=20, edgecolor='black', alpha=0.7)
|
| 124 |
+
axes[0, 0].axvline(self.min_score, color='red', linestyle='--',
|
| 125 |
+
label=f'Current threshold: {self.min_score}')
|
| 126 |
+
axes[0, 0].set_xlabel('Score')
|
| 127 |
+
axes[0, 0].set_ylabel('Frequency')
|
| 128 |
+
axes[0, 0].set_title('Score Distribution')
|
| 129 |
+
axes[0, 0].legend()
|
| 130 |
+
axes[0, 0].grid(True, alpha=0.3)
|
| 131 |
+
|
| 132 |
+
# Box plot
|
| 133 |
+
axes[0, 1].boxplot(self.score_distribution, vert=True)
|
| 134 |
+
axes[0, 1].set_ylabel('Score')
|
| 135 |
+
axes[0, 1].set_title('Score Box Plot')
|
| 136 |
+
axes[0, 1].grid(True, alpha=0.3)
|
| 137 |
+
|
| 138 |
+
# Cumulative distribution
|
| 139 |
+
sorted_scores = np.sort(self.score_distribution)
|
| 140 |
+
cumulative = np.arange(1, len(sorted_scores) + 1) / len(sorted_scores)
|
| 141 |
+
axes[1, 0].plot(sorted_scores, cumulative)
|
| 142 |
+
axes[1, 0].axvline(self.min_score, color='red', linestyle='--',
|
| 143 |
+
label=f'Current threshold: {self.min_score}')
|
| 144 |
+
axes[1, 0].set_xlabel('Score')
|
| 145 |
+
axes[1, 0].set_ylabel('Cumulative Probability')
|
| 146 |
+
axes[1, 0].set_title('Cumulative Distribution')
|
| 147 |
+
axes[1, 0].legend()
|
| 148 |
+
axes[1, 0].grid(True, alpha=0.3)
|
| 149 |
+
|
| 150 |
+
# Threshold impact analysis
|
| 151 |
+
thresholds = range(30, 90, 5)
|
| 152 |
+
kept_percentages = []
|
| 153 |
+
for t in thresholds:
|
| 154 |
+
kept = sum(1 for s in self.score_distribution if s >= t)
|
| 155 |
+
kept_percentages.append((kept / len(self.score_distribution)) * 100)
|
| 156 |
+
|
| 157 |
+
axes[1, 1].plot(thresholds, kept_percentages, marker='o')
|
| 158 |
+
axes[1, 1].axvline(self.min_score, color='red', linestyle='--',
|
| 159 |
+
label=f'Current threshold: {self.min_score}')
|
| 160 |
+
axes[1, 1].set_xlabel('Score Threshold')
|
| 161 |
+
axes[1, 1].set_ylabel('% of Conversations Kept')
|
| 162 |
+
axes[1, 1].set_title('Impact of Score Threshold')
|
| 163 |
+
axes[1, 1].legend()
|
| 164 |
+
axes[1, 1].grid(True, alpha=0.3)
|
| 165 |
+
|
| 166 |
+
plt.tight_layout()
|
| 167 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 168 |
+
plt.show()
|
| 169 |
+
print(f"Score distribution plot saved to {save_path}")
|
| 170 |
+
|
| 171 |
+
def extract_high_quality_conversations(self, data: Dict) -> List[Dict]:
|
| 172 |
+
"""
|
| 173 |
+
Extract conversations with high counselor ratings based on min_score
|
| 174 |
+
Focus on conversations where counselor performed well
|
| 175 |
+
"""
|
| 176 |
+
conversations = []
|
| 177 |
+
|
| 178 |
+
# Check if review exists and has good score
|
| 179 |
+
if 'review_by_client_jp' in data:
|
| 180 |
+
review = data['review_by_client_jp']
|
| 181 |
+
score = review.get('点数', 0)
|
| 182 |
+
|
| 183 |
+
# Use configurable min_score threshold
|
| 184 |
+
if score >= self.min_score:
|
| 185 |
+
dialogue = data.get('dialogue', [])
|
| 186 |
+
|
| 187 |
+
# Create conversation pairs
|
| 188 |
+
conversation_text = ""
|
| 189 |
+
for turn in dialogue:
|
| 190 |
+
role = turn['role']
|
| 191 |
+
utterance = turn['utterance']
|
| 192 |
+
|
| 193 |
+
if role == 'counselor':
|
| 194 |
+
conversation_text += f"カウンセラー: {utterance}\n"
|
| 195 |
+
else:
|
| 196 |
+
conversation_text += f"クライアント: {utterance}\n"
|
| 197 |
+
|
| 198 |
+
# Extract detailed metrics for potential weighted training
|
| 199 |
+
conversations.append({
|
| 200 |
+
'text': conversation_text,
|
| 201 |
+
'score': score, # Store the score here
|
| 202 |
+
'topic': data.get('topic', {}).get('main_jp', 'Unknown'),
|
| 203 |
+
'review_metrics': {
|
| 204 |
+
'empathy': review.get('聴いてもらえた、わかってもらえたと感じた', 0),
|
| 205 |
+
'respect': review.get('尊重されたと感じた', 0),
|
| 206 |
+
'insights': review.get('新しい気づきや体験があった', 0),
|
| 207 |
+
'hope': review.get('希望や期待を感じられた', 0),
|
| 208 |
+
'concerns_addressed': review.get('取り組みたかったことを扱えた', 0),
|
| 209 |
+
'collaboration': review.get('一緒に考えながら取り組めた', 0),
|
| 210 |
+
'rhythm': review.get('やりとりのリズムがあっていた', 0),
|
| 211 |
+
'comfort': review.get('居心地のよいやりとりだった', 0),
|
| 212 |
+
'overall_appropriate': review.get('全体として適切でよかった', 0),
|
| 213 |
+
'valuable': review.get('今回の相談は価値があった', 0),
|
| 214 |
+
'smooth_start': review.get('相談開始の円滑さ', 0),
|
| 215 |
+
'good_ending': review.get('相談終了のタイミング(不必要に聴きすぎていないか)、円滑さ', 0),
|
| 216 |
+
'acceptance_empathy': review.get('受容·共感', 0),
|
| 217 |
+
'affirmation': review.get('肯定·承認', 0),
|
| 218 |
+
'effective_questions': review.get('的確な質問による会話の促進', 0),
|
| 219 |
+
'summarization': review.get('要約', 0),
|
| 220 |
+
'problem_clarification': review.get('問題の明確化', 0),
|
| 221 |
+
'goal_clarification': review.get('この相談での目標の明確化', 0),
|
| 222 |
+
'actionable_suggestions': review.get('次の行動につながる提案', 0),
|
| 223 |
+
'encouragement': review.get('勇気づけ·希望の喚起', 0)
|
| 224 |
+
}
|
| 225 |
+
})
|
| 226 |
+
|
| 227 |
+
return conversations
|
| 228 |
+
|
| 229 |
+
def create_training_examples(self, conversations: List[Dict],
|
| 230 |
+
use_weighted_sampling: bool = False) -> List[Dict]:
|
| 231 |
+
"""
|
| 232 |
+
Create training examples in instruction-following format
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
conversations: List of conversation dictionaries
|
| 236 |
+
use_weighted_sampling: If True, create more examples from higher-scored conversations
|
| 237 |
+
"""
|
| 238 |
+
training_examples = []
|
| 239 |
+
|
| 240 |
+
for conv in tqdm(conversations, desc="Creating training examples"):
|
| 241 |
+
dialogue_lines = conv['text'].split('\n')
|
| 242 |
+
score = conv['score'] # Get score from the conversation dict
|
| 243 |
+
|
| 244 |
+
# Calculate sampling weight based on score if enabled
|
| 245 |
+
if use_weighted_sampling:
|
| 246 |
+
# Higher scores get more weight (normalized to 1-3 range)
|
| 247 |
+
weight = max(1, int((score - self.min_score) / 20) + 1)
|
| 248 |
+
else:
|
| 249 |
+
weight = 1
|
| 250 |
+
|
| 251 |
+
# Create multiple training examples from each conversation
|
| 252 |
+
for _ in range(weight): # Repeat based on weight
|
| 253 |
+
for i in range(0, len(dialogue_lines) - 1, 2):
|
| 254 |
+
if i + 1 < len(dialogue_lines):
|
| 255 |
+
client_line = dialogue_lines[i]
|
| 256 |
+
counselor_line = dialogue_lines[i + 1]
|
| 257 |
+
|
| 258 |
+
# Check if lines contain the expected prefixes
|
| 259 |
+
if 'クライアント:' in client_line and 'カウンセラー:' in counselor_line:
|
| 260 |
+
client_msg = client_line.replace('クライアント: ', '').replace('クライアント:', '').strip()
|
| 261 |
+
counselor_msg = counselor_line.replace('カウンセラー: ', '').replace('カウンセラー:', '').strip()
|
| 262 |
+
|
| 263 |
+
# Skip empty messages
|
| 264 |
+
if not client_msg or not counselor_msg:
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
+
# Format for instruction tuning
|
| 268 |
+
example = {
|
| 269 |
+
'instruction': self.system_prompt,
|
| 270 |
+
'input': client_msg,
|
| 271 |
+
'output': counselor_msg,
|
| 272 |
+
'score': score, # Use the score from conversation
|
| 273 |
+
'topic': conv['topic'],
|
| 274 |
+
'metrics': conv['review_metrics'] # Include detailed metrics
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
training_examples.append(example)
|
| 278 |
+
|
| 279 |
+
return training_examples
|
| 280 |
+
|
| 281 |
+
def prepare_dataset(self, test_size: float = 0.1, val_size: float = 0.1,
|
| 282 |
+
use_weighted_sampling: bool = False,
|
| 283 |
+
analyze_scores: bool = True):
|
| 284 |
+
"""
|
| 285 |
+
Prepare train, validation, and test datasets
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
test_size: Proportion of data for testing
|
| 289 |
+
val_size: Proportion of data for validation
|
| 290 |
+
use_weighted_sampling: If True, oversample high-quality conversations
|
| 291 |
+
analyze_scores: If True, print score distribution analysis
|
| 292 |
+
"""
|
| 293 |
+
print("Loading KokoroChat dataset...")
|
| 294 |
+
json_files = self.load_json_files()
|
| 295 |
+
print(f"Loaded {len(json_files)} conversation files")
|
| 296 |
+
|
| 297 |
+
# Analyze score distribution if requested
|
| 298 |
+
if analyze_scores:
|
| 299 |
+
print("\n" + "="*60)
|
| 300 |
+
print("SCORE DISTRIBUTION ANALYSIS")
|
| 301 |
+
print("="*60)
|
| 302 |
+
stats = self.analyze_score_distribution(json_files)
|
| 303 |
+
|
| 304 |
+
if 'error' not in stats:
|
| 305 |
+
print(f"Total conversations: {stats['total_conversations']}")
|
| 306 |
+
print(f"Conversations with scores: {stats['conversations_with_scores']}")
|
| 307 |
+
print(f"\nScore Statistics:")
|
| 308 |
+
print(f" Mean: {stats['mean_score']:.2f}")
|
| 309 |
+
print(f" Median: {stats['median_score']:.2f}")
|
| 310 |
+
print(f" Std Dev: {stats['std_score']:.2f}")
|
| 311 |
+
print(f" Range: {stats['min_score']:.0f} - {stats['max_score']:.0f}")
|
| 312 |
+
|
| 313 |
+
print(f"\nScore Distribution:")
|
| 314 |
+
for range_name, count in stats['score_ranges'].items():
|
| 315 |
+
percentage = (count / stats['conversations_with_scores']) * 100
|
| 316 |
+
print(f" {range_name}: {count} ({percentage:.1f}%)")
|
| 317 |
+
|
| 318 |
+
print(f"\nThreshold Impact Analysis:")
|
| 319 |
+
for threshold_name, data in stats['threshold_analysis'].items():
|
| 320 |
+
threshold = threshold_name.split('_')[1]
|
| 321 |
+
print(f" Threshold >= {threshold}: {data['conversations_kept']} conversations ({data['percentage_kept']:.1f}%)")
|
| 322 |
+
|
| 323 |
+
print(f"\nCurrent threshold ({self.min_score}) will keep: ", end="")
|
| 324 |
+
kept = sum(1 for s in self.score_distribution if s >= self.min_score)
|
| 325 |
+
print(f"{kept} conversations ({(kept/len(self.score_distribution))*100:.1f}%)")
|
| 326 |
+
print("="*60 + "\n")
|
| 327 |
+
|
| 328 |
+
# Plot distribution
|
| 329 |
+
self.plot_score_distribution()
|
| 330 |
+
|
| 331 |
+
all_conversations = []
|
| 332 |
+
filtered_count = 0
|
| 333 |
+
total_count = 0
|
| 334 |
+
|
| 335 |
+
for data in json_files:
|
| 336 |
+
if 'review_by_client_jp' in data:
|
| 337 |
+
total_count += 1
|
| 338 |
+
score = data['review_by_client_jp'].get('点数', 0)
|
| 339 |
+
if score < self.min_score:
|
| 340 |
+
filtered_count += 1
|
| 341 |
+
|
| 342 |
+
conversations = self.extract_high_quality_conversations(data)
|
| 343 |
+
all_conversations.extend(conversations)
|
| 344 |
+
|
| 345 |
+
print(f"Filtered out {filtered_count} conversations with score < {self.min_score}")
|
| 346 |
+
print(f"Extracted {len(all_conversations)} high-quality conversations (score >= {self.min_score})")
|
| 347 |
+
|
| 348 |
+
# Create training examples
|
| 349 |
+
training_examples = self.create_training_examples(
|
| 350 |
+
all_conversations,
|
| 351 |
+
use_weighted_sampling=use_weighted_sampling
|
| 352 |
+
)
|
| 353 |
+
print(f"Created {len(training_examples)} training examples")
|
| 354 |
+
|
| 355 |
+
if use_weighted_sampling:
|
| 356 |
+
print("Note: Used weighted sampling - higher scored conversations appear more frequently")
|
| 357 |
+
|
| 358 |
+
# Shuffle and split
|
| 359 |
+
random.shuffle(training_examples)
|
| 360 |
+
|
| 361 |
+
total_size = len(training_examples)
|
| 362 |
+
test_split = int(total_size * test_size)
|
| 363 |
+
val_split = int(total_size * val_size)
|
| 364 |
+
|
| 365 |
+
test_data = training_examples[:test_split]
|
| 366 |
+
val_data = training_examples[test_split:test_split + val_split]
|
| 367 |
+
train_data = training_examples[test_split + val_split:]
|
| 368 |
+
|
| 369 |
+
print(f"\nDataset splits:")
|
| 370 |
+
print(f" Train: {len(train_data)} examples")
|
| 371 |
+
print(f" Validation: {len(val_data)} examples")
|
| 372 |
+
print(f" Test: {len(test_data)} examples")
|
| 373 |
+
|
| 374 |
+
return {
|
| 375 |
+
'train': train_data,
|
| 376 |
+
'validation': val_data,
|
| 377 |
+
'test': test_data
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
def format_for_lfm(self, example: Dict) -> str:
|
| 381 |
+
"""
|
| 382 |
+
Format example for LFM model training
|
| 383 |
+
"""
|
| 384 |
+
formatted = f"""### Instruction:
|
| 385 |
+
{example['instruction']}
|
| 386 |
+
|
| 387 |
+
### Input:
|
| 388 |
+
{example['input']}
|
| 389 |
+
|
| 390 |
+
### Response:
|
| 391 |
+
{example['output']}"""
|
| 392 |
+
return formatted
|
| 393 |
+
|
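For reference, format_for_lfm renders each training pair into this plain-text template (placeholder values shown, not actual dataset content):

### Instruction:
<system prompt / counseling instruction>

### Input:
<client message>

### Response:
<counselor reply>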
| 394 |
+
def save_datasets(self, datasets: Dict, output_dir: str):
|
| 395 |
+
"""Save processed datasets with proper type conversion for JSON serialization"""
|
| 396 |
+
output_path = Path(output_dir)
|
| 397 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 398 |
+
|
| 399 |
+
# Helper function to convert numpy types to Python native types
|
| 400 |
+
def convert_to_native(obj):
|
| 401 |
+
if isinstance(obj, np.integer):
|
| 402 |
+
return int(obj)
|
| 403 |
+
elif isinstance(obj, np.floating):
|
| 404 |
+
return float(obj)
|
| 405 |
+
elif isinstance(obj, np.ndarray):
|
| 406 |
+
return obj.tolist()
|
| 407 |
+
else:
|
| 408 |
+
return obj
|
| 409 |
+
|
| 410 |
+
# Save dataset statistics
|
| 411 |
+
stats = {
|
| 412 |
+
'min_score_threshold': int(self.min_score),
|
| 413 |
+
'dataset_sizes': {
|
| 414 |
+
'train': len(datasets['train']),
|
| 415 |
+
'validation': len(datasets['validation']),
|
| 416 |
+
'test': len(datasets['test'])
|
| 417 |
+
},
|
| 418 |
+
'score_distribution': {}
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
for split_name, data in datasets.items():
|
| 422 |
+
# Calculate score distribution for this split
|
| 423 |
+
scores = [ex['score'] for ex in data]
|
| 424 |
+
if scores:
|
| 425 |
+
stats['score_distribution'][split_name] = {
|
| 426 |
+
'mean': float(np.mean(scores)),
|
| 427 |
+
'median': float(np.median(scores)),
|
| 428 |
+
'min': float(np.min(scores)),
|
| 429 |
+
'max': float(np.max(scores)),
|
| 430 |
+
'std': float(np.std(scores))
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
# Save as JSONL for easier streaming
|
| 434 |
+
file_path = output_path / f"{split_name}.jsonl"
|
| 435 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 436 |
+
for example in data:
|
| 437 |
+
formatted_text = self.format_for_lfm(example)
|
| 438 |
+
# Convert all numpy types to native Python types
|
| 439 |
+
json_obj = {
|
| 440 |
+
'text': formatted_text,
|
| 441 |
+
'score': convert_to_native(example['score']),
|
| 442 |
+
'topic': example['topic']
|
| 443 |
+
}
|
| 444 |
+
json_line = json.dumps(json_obj, ensure_ascii=False)
|
| 445 |
+
f.write(json_line + '\n')
|
| 446 |
+
|
| 447 |
+
print(f"Saved {split_name} dataset with {len(data)} examples to {file_path}")
|
| 448 |
+
|
| 449 |
+
# Save statistics
|
| 450 |
+
stats_path = output_path / "dataset_stats.json"
|
| 451 |
+
with open(stats_path, 'w', encoding='utf-8') as f:
|
| 452 |
+
json.dump(stats, f, ensure_ascii=False, indent=2)
|
| 453 |
+
print(f"Saved dataset statistics to {stats_path}")
|
| 454 |
+
|
| 455 |
+
# Print summary statistics
|
| 456 |
+
print("\n" + "="*60)
|
| 457 |
+
print("DATASET SUMMARY")
|
| 458 |
+
print("="*60)
|
| 459 |
+
print(f"Minimum score threshold: {stats['min_score_threshold']}")
|
| 460 |
+
print("\nDataset sizes:")
|
| 461 |
+
for split, size in stats['dataset_sizes'].items():
|
| 462 |
+
print(f" {split}: {size} examples")
|
| 463 |
+
|
| 464 |
+
print("\nScore distributions by split:")
|
| 465 |
+
for split, dist in stats['score_distribution'].items():
|
| 466 |
+
print(f" {split}:")
|
| 467 |
+
print(f" Mean: {dist['mean']:.2f}")
|
| 468 |
+
print(f" Std: {dist['std']:.2f}")
|
| 469 |
+
print(f" Range: {dist['min']:.0f} - {dist['max']:.0f}")
|
| 470 |
+
print("="*60)
|
| 471 |
+
|
| 472 |
+
# Run preprocessing with different score thresholds
|
| 473 |
+
if __name__ == "__main__":
|
| 474 |
+
import argparse
|
| 475 |
+
|
| 476 |
+
parser = argparse.ArgumentParser(description='Preprocess KokoroChat dataset')
|
| 477 |
+
parser.add_argument('--data_path', type=str, default='./KokoroChat',
|
| 478 |
+
help='Path to KokoroChat repository')
|
| 479 |
+
parser.add_argument('--min_score', type=int, default=70,
|
| 480 |
+
help='Minimum score threshold for filtering (default: 70)')
|
| 481 |
+
parser.add_argument('--output_dir', type=str, default='./processed_data',
|
| 482 |
+
help='Output directory for processed data')
|
| 483 |
+
parser.add_argument('--weighted_sampling', action='store_true',
|
| 484 |
+
help='Use weighted sampling based on scores')
|
| 485 |
+
parser.add_argument('--test_size', type=float, default=0.1,
|
| 486 |
+
help='Test set size (default: 0.1)')
|
| 487 |
+
parser.add_argument('--val_size', type=float, default=0.1,
|
| 488 |
+
help='Validation set size (default: 0.1)')
|
| 489 |
+
parser.add_argument('--analyze_only', action='store_true',
|
| 490 |
+
help='Only analyze score distribution without processing')
|
| 491 |
+
|
| 492 |
+
args = parser.parse_args()
|
| 493 |
+
|
| 494 |
+
# Initialize preprocessor with configurable min_score
|
| 495 |
+
preprocessor = KokoroChatPreprocessor(
|
| 496 |
+
data_path=args.data_path,
|
| 497 |
+
min_score=args.min_score
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if args.analyze_only:
|
| 501 |
+
# Just analyze the score distribution
|
| 502 |
+
print("Running score distribution analysis only...")
|
| 503 |
+
json_files = preprocessor.load_json_files()
|
| 504 |
+
stats = preprocessor.analyze_score_distribution(json_files)
|
| 505 |
+
preprocessor.plot_score_distribution(f"score_analysis_threshold_{args.min_score}.png")
|
| 506 |
+
else:
|
| 507 |
+
# Full preprocessing
|
| 508 |
+
print(f"Processing with minimum score threshold: {args.min_score}")
|
| 509 |
+
datasets = preprocessor.prepare_dataset(
|
| 510 |
+
test_size=args.test_size,
|
| 511 |
+
val_size=args.val_size,
|
| 512 |
+
use_weighted_sampling=args.weighted_sampling,
|
| 513 |
+
analyze_scores=True
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
# Save with threshold in directory name
|
| 517 |
+
output_dir = f"{args.output_dir}_score{args.min_score}"
|
| 518 |
+
preprocessor.save_datasets(datasets, output_dir)
|
| 519 |
+
|
| 520 |
+
print(f"\nProcessing complete! Data saved to {output_dir}")
|
| 521 |
+
print("\nNext steps:")
|
| 522 |
+
print("1. Run fine-tuning: python finetune_lfm.py")
|
| 523 |
+
print("2. Run benchmarking: python benchmark_model.py")
|
| 524 |
+
print("3. Optimize for mobile: python optimize_for_mobile.py")
|
finalmerged_model.zip
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73fdb12fa8819f3d5160ec5414e55e827d08d1d69874a4168035b7f0c9fb02a4
|
| 3 |
+
size 1806737356
|
finetune_lfm.py
ADDED
|
@@ -0,0 +1,1311 @@
|
| 1 |
+
# import torch
|
| 2 |
+
# from transformers import (
|
| 3 |
+
# AutoModelForCausalLM,
|
| 4 |
+
# AutoTokenizer,
|
| 5 |
+
# TrainingArguments,
|
| 6 |
+
# Trainer,
|
| 7 |
+
# DataCollatorForLanguageModeling,
|
| 8 |
+
# BitsAndBytesConfig
|
| 9 |
+
# )
|
| 10 |
+
# from peft import (
|
| 11 |
+
# LoraConfig,
|
| 12 |
+
# get_peft_model,
|
| 13 |
+
# prepare_model_for_kbit_training,
|
| 14 |
+
# TaskType
|
| 15 |
+
# )
|
| 16 |
+
# from datasets import load_dataset, Dataset
|
| 17 |
+
# import os
|
| 18 |
+
# from typing import Dict, List, Optional
|
| 19 |
+
# import numpy as np
|
| 20 |
+
# from tqdm import tqdm
|
| 21 |
+
# import json
|
| 22 |
+
# import gc
|
| 23 |
+
# import warnings
|
| 24 |
+
# warnings.filterwarnings('ignore')
|
| 25 |
+
|
| 26 |
+
# class LFMCounselorFineTuner:
|
| 27 |
+
# def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
|
| 28 |
+
# """
|
| 29 |
+
# Initialize the fine-tuner for LFM models
|
| 30 |
+
|
| 31 |
+
# Args:
|
| 32 |
+
# model_name: Name of the base model
|
| 33 |
+
# use_4bit: Whether to use 4-bit quantization for memory efficiency
|
| 34 |
+
# """
|
| 35 |
+
# self.model_name = model_name
|
| 36 |
+
# self.use_4bit = use_4bit
|
| 37 |
+
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
|
| 39 |
+
# print(f"Using device: {self.device}")
|
| 40 |
+
# if torch.cuda.is_available():
|
| 41 |
+
# print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 42 |
+
# print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
|
| 43 |
+
|
| 44 |
+
# # Disable wandb for simplicity
|
| 45 |
+
# os.environ["WANDB_DISABLED"] = "true"
|
| 46 |
+
|
| 47 |
+
# def setup_model_and_tokenizer(self):
|
| 48 |
+
# """Setup model with quantization and LoRA"""
|
| 49 |
+
|
| 50 |
+
# print("Loading tokenizer...")
|
| 51 |
+
# # Tokenizer setup
|
| 52 |
+
# try:
|
| 53 |
+
# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 54 |
+
# except:
|
| 55 |
+
# # Fallback to a known working tokenizer if model-specific one fails
|
| 56 |
+
# print("Using fallback tokenizer...")
|
| 57 |
+
# self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 58 |
+
|
| 59 |
+
# # Add padding token if it doesn't exist
|
| 60 |
+
# if self.tokenizer.pad_token is None:
|
| 61 |
+
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 62 |
+
# if self.tokenizer.eos_token is None:
|
| 63 |
+
# self.tokenizer.eos_token = "</s>"
|
| 64 |
+
# self.tokenizer.pad_token = "</s>"
|
| 65 |
+
|
| 66 |
+
# self.tokenizer.padding_side = "right"
|
| 67 |
+
|
| 68 |
+
# # Quantization config for memory efficiency
|
| 69 |
+
# if self.use_4bit:
|
| 70 |
+
# print("Setting up 4-bit quantization...")
|
| 71 |
+
# bnb_config = BitsAndBytesConfig(
|
| 72 |
+
# load_in_4bit=True,
|
| 73 |
+
# bnb_4bit_quant_type="nf4",
|
| 74 |
+
# bnb_4bit_compute_dtype=torch.float16, # Use float16 for better compatibility
|
| 75 |
+
# bnb_4bit_use_double_quant=True
|
| 76 |
+
# )
|
| 77 |
+
# else:
|
| 78 |
+
# bnb_config = None
|
| 79 |
+
|
| 80 |
+
# # Load model
|
| 81 |
+
# print(f"Loading model: {self.model_name}...")
|
| 82 |
+
# try:
|
| 83 |
+
# self.model = AutoModelForCausalLM.from_pretrained(
|
| 84 |
+
# self.model_name,
|
| 85 |
+
# quantization_config=bnb_config,
|
| 86 |
+
# device_map="auto",
|
| 87 |
+
# trust_remote_code=True,
|
| 88 |
+
# torch_dtype=torch.float16
|
| 89 |
+
# )
|
| 90 |
+
# except Exception as e:
|
| 91 |
+
# print(f"Error loading model: {e}")
|
| 92 |
+
# print("Attempting to load without quantization...")
|
| 93 |
+
# self.model = AutoModelForCausalLM.from_pretrained(
|
| 94 |
+
# self.model_name,
|
| 95 |
+
# device_map="auto",
|
| 96 |
+
# trust_remote_code=True,
|
| 97 |
+
# torch_dtype=torch.float16,
|
| 98 |
+
# low_cpu_mem_usage=True
|
| 99 |
+
# )
|
| 100 |
+
|
| 101 |
+
# # Enable gradient checkpointing to save memory
|
| 102 |
+
# if hasattr(self.model, 'gradient_checkpointing_enable'):
|
| 103 |
+
# self.model.gradient_checkpointing_enable()
|
| 104 |
+
|
| 105 |
+
# # Prepare model for k-bit training
|
| 106 |
+
# if self.use_4bit:
|
| 107 |
+
# print("Preparing model for 4-bit training...")
|
| 108 |
+
# self.model = prepare_model_for_kbit_training(self.model)
|
| 109 |
+
|
| 110 |
+
# # LoRA configuration - optimized for counseling task
|
| 111 |
+
# print("Applying LoRA configuration...")
|
| 112 |
+
|
| 113 |
+
# # Find the target modules dynamically
|
| 114 |
+
# target_modules = self.find_target_modules()
|
| 115 |
+
|
| 116 |
+
# lora_config = LoraConfig(
|
| 117 |
+
# r=16, # Reduced rank for stability
|
| 118 |
+
# lora_alpha=32, # Alpha parameter for LoRA scaling
|
| 119 |
+
# target_modules=target_modules,
|
| 120 |
+
# lora_dropout=0.05,
|
| 121 |
+
# bias="none",
|
| 122 |
+
# task_type=TaskType.CAUSAL_LM,
|
| 123 |
+
# inference_mode=False
|
| 124 |
+
# )
|
| 125 |
+
|
| 126 |
+
# # Apply LoRA
|
| 127 |
+
# self.model = get_peft_model(self.model, lora_config)
|
| 128 |
+
|
| 129 |
+
# # Print trainable parameters
|
| 130 |
+
# self.model.print_trainable_parameters()
|
| 131 |
+
|
| 132 |
+
# def find_target_modules(self):
|
| 133 |
+
# """Find linear modules to apply LoRA to"""
|
| 134 |
+
# target_modules = []
|
| 135 |
+
# for name, module in self.model.named_modules():
|
| 136 |
+
# if isinstance(module, torch.nn.Linear):
|
| 137 |
+
# # Extract the module name
|
| 138 |
+
# names = name.split('.')
|
| 139 |
+
# if len(names) > 0:
|
| 140 |
+
# target_modules.append(names[-1])
|
| 141 |
+
|
| 142 |
+
# # Remove duplicates and filter common patterns
|
| 143 |
+
# target_modules = list(set(target_modules))
|
| 144 |
+
|
| 145 |
+
# # Common patterns for transformer models
|
| 146 |
+
# common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
|
| 147 |
+
# "gate_proj", "up_proj", "down_proj",
|
| 148 |
+
# "fc1", "fc2", "query", "key", "value", "dense"]
|
| 149 |
+
|
| 150 |
+
# # Filter to only include common targets if they exist
|
| 151 |
+
# final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
|
| 152 |
+
|
| 153 |
+
# # If no common targets found, use all linear layers
|
| 154 |
+
# if not final_targets:
|
| 155 |
+
# final_targets = target_modules[:6] # Limit to prevent too many parameters
|
| 156 |
+
|
| 157 |
+
# print(f"LoRA target modules: {final_targets}")
|
| 158 |
+
# return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
|
| 159 |
+
|
| 160 |
+
# def load_and_process_datasets(self, data_path: str):
|
| 161 |
+
# """Load and process datasets without multiprocessing issues"""
|
| 162 |
+
|
| 163 |
+
# print(f"Loading datasets from {data_path}...")
|
| 164 |
+
|
| 165 |
+
# # Load train dataset
|
| 166 |
+
# train_texts = []
|
| 167 |
+
# with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
|
| 168 |
+
# for line in tqdm(f, desc="Loading training data"):
|
| 169 |
+
# data = json.loads(line)
|
| 170 |
+
# train_texts.append(data['text'])
|
| 171 |
+
|
| 172 |
+
# # Load validation dataset
|
| 173 |
+
# val_texts = []
|
| 174 |
+
# with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
|
| 175 |
+
# for line in tqdm(f, desc="Loading validation data"):
|
| 176 |
+
# data = json.loads(line)
|
| 177 |
+
# val_texts.append(data['text'])
|
| 178 |
+
|
| 179 |
+
# print(f"Loaded {len(train_texts)} training examples")
|
| 180 |
+
# print(f"Loaded {len(val_texts)} validation examples")
|
| 181 |
+
|
| 182 |
+
# # Tokenize datasets in batches (avoiding multiprocessing)
|
| 183 |
+
# print("Tokenizing training dataset...")
|
| 184 |
+
# train_encodings = self.tokenize_texts(train_texts)
|
| 185 |
+
|
| 186 |
+
# print("Tokenizing validation dataset...")
|
| 187 |
+
# val_encodings = self.tokenize_texts(val_texts)
|
| 188 |
+
|
| 189 |
+
# # Create datasets
|
| 190 |
+
# self.train_dataset = Dataset.from_dict(train_encodings)
|
| 191 |
+
# self.val_dataset = Dataset.from_dict(val_encodings)
|
| 192 |
+
|
| 193 |
+
# # Set format for PyTorch
|
| 194 |
+
# self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 195 |
+
# self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 196 |
+
|
| 197 |
+
# # Clean up memory
|
| 198 |
+
# del train_texts, val_texts, train_encodings, val_encodings
|
| 199 |
+
# gc.collect()
|
| 200 |
+
|
| 201 |
+
# def tokenize_texts(self, texts: List[str], batch_size: int = 100):
|
| 202 |
+
# """Tokenize texts in batches to avoid memory issues"""
|
| 203 |
+
# all_input_ids = []
|
| 204 |
+
# all_attention_masks = []
|
| 205 |
+
|
| 206 |
+
# for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
|
| 207 |
+
# batch_texts = texts[i:i + batch_size]
|
| 208 |
+
|
| 209 |
+
# # Tokenize batch
|
| 210 |
+
# encodings = self.tokenizer(
|
| 211 |
+
# batch_texts,
|
| 212 |
+
# truncation=True,
|
| 213 |
+
# padding='max_length',
|
| 214 |
+
# max_length=512,
|
| 215 |
+
# return_tensors='pt'
|
| 216 |
+
# )
|
| 217 |
+
|
| 218 |
+
# # Convert to lists
|
| 219 |
+
# all_input_ids.extend(encodings['input_ids'].tolist())
|
| 220 |
+
# all_attention_masks.extend(encodings['attention_mask'].tolist())
|
| 221 |
+
|
| 222 |
+
# # Create labels (same as input_ids for language modeling)
|
| 223 |
+
# labels = all_input_ids.copy()
|
| 224 |
+
|
| 225 |
+
# return {
|
| 226 |
+
# 'input_ids': all_input_ids,
|
| 227 |
+
# 'attention_mask': all_attention_masks,
|
| 228 |
+
# 'labels': labels
|
| 229 |
+
# }
|
| 230 |
+
|
| 231 |
+
# def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
|
| 232 |
+
# """Setup training arguments optimized for counseling task"""
|
| 233 |
+
|
| 234 |
+
# print("Setting up training arguments...")
|
| 235 |
+
|
| 236 |
+
# # Calculate batch sizes based on available memory
|
| 237 |
+
# if torch.cuda.is_available():
|
| 238 |
+
# gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 239 |
+
# if gpu_memory < 16: # Less than 16GB
|
| 240 |
+
# batch_size = 1
|
| 241 |
+
# gradient_accumulation = 16
|
| 242 |
+
# elif gpu_memory < 24: # Less than 24GB
|
| 243 |
+
# batch_size = 2
|
| 244 |
+
# gradient_accumulation = 8
|
| 245 |
+
# else: # 24GB or more
|
| 246 |
+
# batch_size = 4
|
| 247 |
+
# gradient_accumulation = 4
|
| 248 |
+
# else:
|
| 249 |
+
# batch_size = 1
|
| 250 |
+
# gradient_accumulation = 16
|
| 251 |
+
|
| 252 |
+
# print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
|
| 253 |
+
|
| 254 |
+
# self.training_args = TrainingArguments(
|
| 255 |
+
# output_dir=output_dir,
|
| 256 |
+
# num_train_epochs=3,
|
| 257 |
+
# per_device_train_batch_size=batch_size,
|
| 258 |
+
# per_device_eval_batch_size=batch_size,
|
| 259 |
+
# gradient_accumulation_steps=gradient_accumulation,
|
| 260 |
+
# gradient_checkpointing=True,
|
| 261 |
+
# warmup_steps=100,
|
| 262 |
+
# learning_rate=5e-5, # Conservative learning rate
|
| 263 |
+
# fp16=True,
|
| 264 |
+
# logging_steps=50,
|
| 265 |
+
# eval_strategy="steps",
|
| 266 |
+
# eval_steps=200,
|
| 267 |
+
# save_strategy="steps",
|
| 268 |
+
# save_steps=400,
|
| 269 |
+
# save_total_limit=2,
|
| 270 |
+
# load_best_model_at_end=True,
|
| 271 |
+
# metric_for_best_model="eval_loss",
|
| 272 |
+
# greater_is_better=False,
|
| 273 |
+
# report_to="none", # Disable all reporting
|
| 274 |
+
# push_to_hub=False,
|
| 275 |
+
# optim="adamw_torch", # Use standard optimizer
|
| 276 |
+
# lr_scheduler_type="linear",
|
| 277 |
+
# weight_decay=0.01,
|
| 278 |
+
# max_grad_norm=1.0,
|
| 279 |
+
# remove_unused_columns=False,
|
| 280 |
+
# label_names=["labels"],
|
| 281 |
+
# dataloader_num_workers=0, # Disable multiprocessing in dataloader
|
| 282 |
+
# dataloader_pin_memory=False, # Disable pinned memory to avoid issues
|
| 283 |
+
# )
|
| 284 |
+
|
| 285 |
+
# def train(self):
|
| 286 |
+
# """Execute training"""
|
| 287 |
+
|
| 288 |
+
# print("Initializing trainer...")
|
| 289 |
+
|
| 290 |
+
# # Data collator for language modeling
|
| 291 |
+
# data_collator = DataCollatorForLanguageModeling(
|
| 292 |
+
# tokenizer=self.tokenizer,
|
| 293 |
+
# mlm=False,
|
| 294 |
+
# pad_to_multiple_of=8
|
| 295 |
+
# )
|
| 296 |
+
|
| 297 |
+
# # Custom training to handle potential issues
|
| 298 |
+
# try:
|
| 299 |
+
# # Initialize trainer
|
| 300 |
+
# trainer = Trainer(
|
| 301 |
+
# model=self.model,
|
| 302 |
+
# args=self.training_args,
|
| 303 |
+
# train_dataset=self.train_dataset,
|
| 304 |
+
# eval_dataset=self.val_dataset,
|
| 305 |
+
# data_collator=data_collator,
|
| 306 |
+
# tokenizer=self.tokenizer,
|
| 307 |
+
# )
|
| 308 |
+
|
| 309 |
+
# # Start training
|
| 310 |
+
# print("="*50)
|
| 311 |
+
# print("Starting fine-tuning...")
|
| 312 |
+
# print("="*50)
|
| 313 |
+
|
| 314 |
+
# # Train with error handling
|
| 315 |
+
# train_result = trainer.train()
|
| 316 |
+
|
| 317 |
+
# # Save the final model
|
| 318 |
+
# print("\nSaving fine-tuned model...")
|
| 319 |
+
# trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
|
| 320 |
+
# self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
|
| 321 |
+
|
| 322 |
+
# # Save training metrics
|
| 323 |
+
# with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
|
| 324 |
+
# json.dump(train_result.metrics, f, indent=2)
|
| 325 |
+
|
| 326 |
+
# print("\n" + "="*50)
|
| 327 |
+
# print("Training completed successfully!")
|
| 328 |
+
# print(f"Model saved to: {self.training_args.output_dir}/final_model_2b")
|
| 329 |
+
# print("="*50)
|
| 330 |
+
|
| 331 |
+
# return trainer
|
| 332 |
+
|
| 333 |
+
# except Exception as e:
|
| 334 |
+
# print(f"Error during training: {e}")
|
| 335 |
+
# print("Attempting to save checkpoint...")
|
| 336 |
+
|
| 337 |
+
# # Try to save whatever we have
|
| 338 |
+
# try:
|
| 339 |
+
# self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
|
| 340 |
+
# self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
|
| 341 |
+
# print(f"Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
|
| 342 |
+
# except:
|
| 343 |
+
# print("Could not save emergency checkpoint")
|
| 344 |
+
|
| 345 |
+
# raise e
|
| 346 |
+
|
| 347 |
+
# def test_model(model_path: str, tokenizer_path: str):
|
| 348 |
+
# """Test the fine-tuned model with a sample input"""
|
| 349 |
+
|
| 350 |
+
# print("\n" + "="*50)
|
| 351 |
+
# print("Testing fine-tuned model...")
|
| 352 |
+
# print("="*50)
|
| 353 |
+
|
| 354 |
+
# # Load model and tokenizer
|
| 355 |
+
# from peft import PeftModel, PeftConfig
|
| 356 |
+
|
| 357 |
+
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 358 |
+
|
| 359 |
+
# # Try to load as PEFT model
|
| 360 |
+
# try:
|
| 361 |
+
# config = PeftConfig.from_pretrained(model_path)
|
| 362 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
| 363 |
+
# config.base_model_name_or_path,
|
| 364 |
+
# torch_dtype=torch.float16,
|
| 365 |
+
# device_map="auto"
|
| 366 |
+
# )
|
| 367 |
+
# model = PeftModel.from_pretrained(model, model_path)
|
| 368 |
+
# except:
|
| 369 |
+
# # Load as regular model
|
| 370 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
| 371 |
+
# model_path,
|
| 372 |
+
# torch_dtype=torch.float16,
|
| 373 |
+
# device_map="auto"
|
| 374 |
+
# )
|
| 375 |
+
|
| 376 |
+
# model.eval()
|
| 377 |
+
|
| 378 |
+
# # Test input
|
| 379 |
+
# test_input = "こんにちは。最近ストレスを感じています。"
|
| 380 |
+
|
| 381 |
+
# # Generate response
|
| 382 |
+
# inputs = tokenizer(test_input, return_tensors="pt")
|
| 383 |
+
# inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
|
| 384 |
+
|
| 385 |
+
# with torch.no_grad():
|
| 386 |
+
# outputs = model.generate(
|
| 387 |
+
# **inputs,
|
| 388 |
+
# max_new_tokens=100,
|
| 389 |
+
# temperature=0.1,
|
| 390 |
+
# do_sample=True,
|
| 391 |
+
# top_p=0.9
|
| 392 |
+
# )
|
| 393 |
+
|
| 394 |
+
# response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 395 |
+
# print(f"Input: {test_input}")
|
| 396 |
+
# print(f"Response: {response}")
|
| 397 |
+
# print("="*50)
|
| 398 |
+
|
| 399 |
+
# # Main training script
|
| 400 |
+
# if __name__ == "__main__":
|
| 401 |
+
# import argparse
|
| 402 |
+
|
| 403 |
+
# parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
|
| 404 |
+
# parser.add_argument('--model_name', type=str, default='gpt2', # Using GPT2 as fallback
|
| 405 |
+
# help='Base model name (use gpt2 if liquid model fails)')
|
| 406 |
+
# parser.add_argument('--data_path', type=str, default='./processed_data_score80',
|
| 407 |
+
# help='Path to processed data')
|
| 408 |
+
# parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
|
| 409 |
+
# help='Output directory for fine-tuned model')
|
| 410 |
+
# parser.add_argument('--use_4bit', action='store_true', default=False,
|
| 411 |
+
# help='Use 4-bit quantization (set to False for stability)')
|
| 412 |
+
# parser.add_argument('--test_only', action='store_true',
|
| 413 |
+
# help='Only test existing model')
|
| 414 |
+
|
| 415 |
+
# args = parser.parse_args()
|
| 416 |
+
|
| 417 |
+
# if args.test_only:
|
| 418 |
+
# # Test existing model
|
| 419 |
+
# test_model(
|
| 420 |
+
# f"{args.output_dir}/final_model_2b",
|
| 421 |
+
# f"{args.output_dir}/final_model_2b"
|
| 422 |
+
# )
|
| 423 |
+
# else:
|
| 424 |
+
# # Check if CUDA is available
|
| 425 |
+
# if not torch.cuda.is_available():
|
| 426 |
+
# print("Warning: CUDA is not available. Training will be very slow on CPU.")
|
| 427 |
+
# print("It's highly recommended to use a GPU for training.")
|
| 428 |
+
# response = input("Do you want to continue anyway? (y/n): ")
|
| 429 |
+
# if response.lower() != 'y':
|
| 430 |
+
# exit()
|
| 431 |
+
|
| 432 |
+
# try:
|
| 433 |
+
# # Clear GPU cache
|
| 434 |
+
# if torch.cuda.is_available():
|
| 435 |
+
# torch.cuda.empty_cache()
|
| 436 |
+
|
| 437 |
+
# # Initialize fine-tuner
|
| 438 |
+
# print(f"Initializing fine-tuner with model: {args.model_name}")
|
| 439 |
+
# finetuner = LFMCounselorFineTuner(
|
| 440 |
+
# model_name=args.model_name,
|
| 441 |
+
# use_4bit=args.use_4bit
|
| 442 |
+
# )
|
| 443 |
+
|
| 444 |
+
# # Setup model
|
| 445 |
+
# print("\nSetting up model and tokenizer...")
|
| 446 |
+
# finetuner.setup_model_and_tokenizer()
|
| 447 |
+
|
| 448 |
+
# # Load datasets (using new method without multiprocessing)
|
| 449 |
+
# print("\nLoading and processing datasets...")
|
| 450 |
+
# finetuner.load_and_process_datasets(args.data_path)
|
| 451 |
+
|
| 452 |
+
# # Setup training arguments
|
| 453 |
+
# print("\nSetting up training arguments...")
|
| 454 |
+
# finetuner.setup_training_args(args.output_dir)
|
| 455 |
+
|
| 456 |
+
# # Train
|
| 457 |
+
# trainer = finetuner.train()
|
| 458 |
+
|
| 459 |
+
# # Test the model
|
| 460 |
+
# print("\nTesting the fine-tuned model...")
|
| 461 |
+
# test_model(
|
| 462 |
+
# f"{args.output_dir}/final_model_2b",
|
| 463 |
+
# f"{args.output_dir}/final_model_2b"
|
| 464 |
+
# )
|
| 465 |
+
|
| 466 |
+
# print("\n✅ Fine-tuning completed successfully!")
|
| 467 |
+
# print(f"📁 Model saved to: {args.output_dir}/final_model_2b")
|
| 468 |
+
# print("\nNext steps:")
|
| 469 |
+
# print("1. Test more: python finetune_lfm.py --test_only")
|
| 470 |
+
# print("2. Run benchmarking: python benchmark_model.py")
|
| 471 |
+
# print("3. Optimize for mobile: python optimize_for_mobile.py")
|
| 472 |
+
|
| 473 |
+
# except KeyboardInterrupt:
|
| 474 |
+
# print("\n\nTraining interrupted by user.")
|
| 475 |
+
# print("Partial model may be saved in checkpoints.")
|
| 476 |
+
# except Exception as e:
|
| 477 |
+
# print(f"\n❌ Error during fine-tuning: {e}")
|
| 478 |
+
# import traceback
|
| 479 |
+
# traceback.print_exc()
|
| 480 |
+
# print("\nTroubleshooting tips:")
|
| 481 |
+
# print("1. Try reducing batch size")
|
| 482 |
+
# print("2. Try without 4-bit quantization: remove --use_4bit")
|
| 483 |
+
# print("3. Try with a smaller model like gpt2")
|
| 484 |
+
# print("4. Ensure you have enough GPU memory")
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
###### wandb login ######
|
| 489 |
+
|
| 490 |
+
import torch
|
| 491 |
+
from transformers import (
|
| 492 |
+
AutoModelForCausalLM,
|
| 493 |
+
AutoTokenizer,
|
| 494 |
+
TrainingArguments,
|
| 495 |
+
Trainer,
|
| 496 |
+
DataCollatorForLanguageModeling,
|
| 497 |
+
BitsAndBytesConfig,
|
| 498 |
+
TrainerCallback
|
| 499 |
+
)
|
| 500 |
+
from peft import (
|
| 501 |
+
LoraConfig,
|
| 502 |
+
get_peft_model,
|
| 503 |
+
prepare_model_for_kbit_training,
|
| 504 |
+
TaskType
|
| 505 |
+
)
|
| 506 |
+
from datasets import load_dataset, Dataset
|
| 507 |
+
import os
|
| 508 |
+
from typing import Dict, List, Optional
|
| 509 |
+
import numpy as np
|
| 510 |
+
from tqdm import tqdm
|
| 511 |
+
import json
|
| 512 |
+
import gc
|
| 513 |
+
import warnings
|
| 514 |
+
import wandb
|
| 515 |
+
from datetime import datetime
|
| 516 |
+
|
| 517 |
+
warnings.filterwarnings('ignore')
|
| 518 |
+
|
| 519 |
+
class LFMCounselorFineTuner:
|
| 520 |
+
def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
|
| 521 |
+
"""
|
| 522 |
+
Initialize the fine-tuner for LFM models
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
model_name: Name of the base model
|
| 526 |
+
use_4bit: Whether to use 4-bit quantization for memory efficiency
|
| 527 |
+
"""
|
| 528 |
+
self.model_name = model_name
|
| 529 |
+
self.use_4bit = use_4bit
|
| 530 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 531 |
+
|
| 532 |
+
print(f"Using device: {self.device}")
|
| 533 |
+
gpu_memory = 0
|
| 534 |
+
if torch.cuda.is_available():
|
| 535 |
+
gpu_name = torch.cuda.get_device_name(0)
|
| 536 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 537 |
+
print(f"GPU: {gpu_name}")
|
| 538 |
+
print(f"GPU Memory: {gpu_memory:.2f} GB")
|
| 539 |
+
|
| 540 |
+
# Initialize WandB (always enabled)
|
| 541 |
+
try:
|
| 542 |
+
# Create a unique run name with timestamp
|
| 543 |
+
run_name = f"lfm-counselor-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
| 544 |
+
|
| 545 |
+
# Initialize wandb with comprehensive config
|
| 546 |
+
wandb.init(
|
| 547 |
+
project="liquid-counselor-hackathon",
|
| 548 |
+
name=run_name,
|
| 549 |
+
config={
|
| 550 |
+
"model_name": model_name,
|
| 551 |
+
"use_4bit_quantization": use_4bit,
|
| 552 |
+
"device": str(self.device),
|
| 553 |
+
"gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
|
| 554 |
+
"gpu_memory_gb": gpu_memory,
|
| 555 |
+
"framework": "transformers",
|
| 556 |
+
"peft_method": "LoRA",
|
| 557 |
+
"task": "japanese_counseling",
|
| 558 |
+
"dataset": "KokoroChat"
|
| 559 |
+
},
|
| 560 |
+
tags=["counseling", "japanese", "lfm", "finetune", "hackathon"]
|
| 561 |
+
)
|
| 562 |
+
print(f"✅ WandB initialized: {wandb.run.name}")
|
| 563 |
+
print(f"📊 View run at: {wandb.run.get_url()}")
|
| 564 |
+
self.wandb_enabled = True
|
| 565 |
+
except Exception as e:
|
| 566 |
+
print(f"⚠️ WandB initialization failed: {e}")
|
| 567 |
+
print("Continuing without WandB logging...")
|
| 568 |
+
self.wandb_enabled = False
|
| 569 |
+
os.environ["WANDB_DISABLED"] = "true"
|
| 570 |
+
|
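Because wandb.init above is always attempted, the run needs Weights & Biases credentials; the usual options (standard wandb mechanisms, not specific to this script) are:

# one-time interactive login:
#   wandb login
# or set an API key in the environment before launching:
#   export WANDB_API_KEY=<your key>
# without either, wandb.init may prompt or fail, and this script then falls back to self.wandb_enabled = False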
| 571 |
+
def setup_model_and_tokenizer(self):
|
| 572 |
+
"""Setup model with quantization and LoRA"""
|
| 573 |
+
|
| 574 |
+
print("Loading tokenizer...")
|
| 575 |
+
# Tokenizer setup
|
| 576 |
+
try:
|
| 577 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 578 |
+
except Exception:
|
| 579 |
+
# Fallback to a known working tokenizer if model-specific one fails
|
| 580 |
+
print("Using fallback tokenizer...")
|
| 581 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 582 |
+
|
| 583 |
+
# Add padding token if it doesn't exist
|
| 584 |
+
if self.tokenizer.pad_token is None:
|
| 585 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 586 |
+
if self.tokenizer.eos_token is None:
|
| 587 |
+
self.tokenizer.eos_token = "</s>"
|
| 588 |
+
self.tokenizer.pad_token = "</s>"
|
| 589 |
+
|
| 590 |
+
self.tokenizer.padding_side = "right"
|
| 591 |
+
|
| 592 |
+
# Quantization config for memory efficiency
|
| 593 |
+
if self.use_4bit:
|
| 594 |
+
print("Setting up 4-bit quantization...")
|
| 595 |
+
bnb_config = BitsAndBytesConfig(
|
| 596 |
+
load_in_4bit=True,
|
| 597 |
+
bnb_4bit_quant_type="nf4",
|
| 598 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 599 |
+
bnb_4bit_use_double_quant=True
|
| 600 |
+
)
|
| 601 |
+
else:
|
| 602 |
+
bnb_config = None
|
| 603 |
+
|
| 604 |
+
# Load model
|
| 605 |
+
print(f"Loading model: {self.model_name}...")
|
| 606 |
+
try:
|
| 607 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 608 |
+
self.model_name,
|
| 609 |
+
quantization_config=bnb_config,
|
| 610 |
+
device_map="auto",
|
| 611 |
+
trust_remote_code=True,
|
| 612 |
+
torch_dtype=torch.float16
|
| 613 |
+
)
|
| 614 |
+
except Exception as e:
|
| 615 |
+
print(f"Error loading model: {e}")
|
| 616 |
+
print("Attempting to load without quantization...")
|
| 617 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 618 |
+
self.model_name,
|
| 619 |
+
device_map="auto",
|
| 620 |
+
trust_remote_code=True,
|
| 621 |
+
torch_dtype=torch.float16,
|
| 622 |
+
low_cpu_mem_usage=True
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
# Enable gradient checkpointing to save memory
|
| 626 |
+
if hasattr(self.model, 'gradient_checkpointing_enable'):
|
| 627 |
+
self.model.gradient_checkpointing_enable()
|
| 628 |
+
|
| 629 |
+
# Prepare model for k-bit training
|
| 630 |
+
if self.use_4bit:
|
| 631 |
+
print("Preparing model for 4-bit training...")
|
| 632 |
+
self.model = prepare_model_for_kbit_training(self.model)
|
| 633 |
+
|
| 634 |
+
# LoRA configuration - optimized for counseling task
|
| 635 |
+
print("Applying LoRA configuration...")
|
| 636 |
+
|
| 637 |
+
# Find the target modules dynamically
|
| 638 |
+
target_modules = self.find_target_modules()
|
| 639 |
+
|
| 640 |
+
lora_config = LoraConfig(
|
| 641 |
+
r=16, # Reduced rank for stability
|
| 642 |
+
lora_alpha=32, # Alpha parameter for LoRA scaling
|
| 643 |
+
target_modules=target_modules,
|
| 644 |
+
lora_dropout=0.05,
|
| 645 |
+
bias="none",
|
| 646 |
+
task_type=TaskType.CAUSAL_LM,
|
| 647 |
+
inference_mode=False
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
# Apply LoRA
|
| 651 |
+
self.model = get_peft_model(self.model, lora_config)
|
| 652 |
+
|
| 653 |
+
# Get trainable parameters info
|
| 654 |
+
trainable_params = 0
|
| 655 |
+
all_params = 0
|
| 656 |
+
for _, param in self.model.named_parameters():
|
| 657 |
+
all_params += param.numel()
|
| 658 |
+
if param.requires_grad:
|
| 659 |
+
trainable_params += param.numel()
|
| 660 |
+
|
| 661 |
+
trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
|
| 662 |
+
|
| 663 |
+
print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
|
| 664 |
+
|
| 665 |
+
# Log model architecture to WandB
|
| 666 |
+
if self.wandb_enabled:
|
| 667 |
+
wandb.config.update({
|
| 668 |
+
"lora_r": lora_config.r,
|
| 669 |
+
"lora_alpha": lora_config.lora_alpha,
|
| 670 |
+
"lora_dropout": lora_config.lora_dropout,
|
| 671 |
+
"lora_target_modules": target_modules,
|
| 672 |
+
"total_parameters": all_params,
|
| 673 |
+
"trainable_parameters": trainable_params,
|
| 674 |
+
"trainable_percentage": trainable_percentage
|
| 675 |
+
})
|
| 676 |
+
|
| 677 |
+
self.model.print_trainable_parameters()
|
| 678 |
+
|
| 679 |
+
def find_target_modules(self):
|
| 680 |
+
"""Find linear modules to apply LoRA to"""
|
| 681 |
+
target_modules = []
|
| 682 |
+
for name, module in self.model.named_modules():
|
| 683 |
+
if isinstance(module, torch.nn.Linear):
|
| 684 |
+
# Extract the module name
|
| 685 |
+
names = name.split('.')
|
| 686 |
+
if len(names) > 0:
|
| 687 |
+
target_modules.append(names[-1])
|
| 688 |
+
|
| 689 |
+
# Remove duplicates and filter common patterns
|
| 690 |
+
target_modules = list(set(target_modules))
|
| 691 |
+
|
| 692 |
+
# Common patterns for transformer models
|
| 693 |
+
common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
|
| 694 |
+
"gate_proj", "up_proj", "down_proj",
|
| 695 |
+
"fc1", "fc2", "query", "key", "value", "dense"]
|
| 696 |
+
|
| 697 |
+
# Filter to only include common targets if they exist
|
| 698 |
+
final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
|
| 699 |
+
|
| 700 |
+
# If no common targets found, use all linear layers
|
| 701 |
+
if not final_targets:
|
| 702 |
+
final_targets = target_modules[:6] # Limit to prevent too many parameters
|
| 703 |
+
|
| 704 |
+
print(f"LoRA target modules: {final_targets}")
|
| 705 |
+
return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
|
| 706 |
+
|
| 707 |
+
def load_and_process_datasets(self, data_path: str):
|
| 708 |
+
"""Load and process datasets without multiprocessing issues"""
|
| 709 |
+
|
| 710 |
+
print(f"Loading datasets from {data_path}...")
|
| 711 |
+
|
| 712 |
+
# Load train dataset
|
| 713 |
+
train_texts = []
|
| 714 |
+
train_scores = []
|
| 715 |
+
train_topics = []
|
| 716 |
+
|
| 717 |
+
with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
|
| 718 |
+
for line in tqdm(f, desc="Loading training data"):
|
| 719 |
+
data = json.loads(line)
|
| 720 |
+
train_texts.append(data['text'])
|
| 721 |
+
train_scores.append(data.get('score', 0))
|
| 722 |
+
train_topics.append(data.get('topic', 'Unknown'))
|
| 723 |
+
|
| 724 |
+
# Load validation dataset
|
| 725 |
+
val_texts = []
|
| 726 |
+
val_scores = []
|
| 727 |
+
val_topics = []
|
| 728 |
+
|
| 729 |
+
with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
|
| 730 |
+
for line in tqdm(f, desc="Loading validation data"):
|
| 731 |
+
data = json.loads(line)
|
| 732 |
+
val_texts.append(data['text'])
|
| 733 |
+
val_scores.append(data.get('score', 0))
|
| 734 |
+
val_topics.append(data.get('topic', 'Unknown'))
|
| 735 |
+
|
| 736 |
+
print(f"Loaded {len(train_texts)} training examples")
|
| 737 |
+
print(f"Loaded {len(val_texts)} validation examples")
|
| 738 |
+
|
| 739 |
+
# Log dataset statistics to WandB
|
| 740 |
+
if self.wandb_enabled:
|
| 741 |
+
# Calculate score statistics
|
| 742 |
+
train_score_stats = {
|
| 743 |
+
"train_examples": len(train_texts),
|
| 744 |
+
"train_avg_score": float(np.mean(train_scores)),
|
| 745 |
+
"train_min_score": float(np.min(train_scores)),
|
| 746 |
+
"train_max_score": float(np.max(train_scores)),
|
| 747 |
+
"train_std_score": float(np.std(train_scores))
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
val_score_stats = {
|
| 751 |
+
"val_examples": len(val_texts),
|
| 752 |
+
"val_avg_score": float(np.mean(val_scores)),
|
| 753 |
+
"val_min_score": float(np.min(val_scores)),
|
| 754 |
+
"val_max_score": float(np.max(val_scores)),
|
| 755 |
+
"val_std_score": float(np.std(val_scores))
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
wandb.config.update(train_score_stats)
|
| 759 |
+
wandb.config.update(val_score_stats)
|
| 760 |
+
|
| 761 |
+
# Log score distribution histogram
|
| 762 |
+
wandb.log({
|
| 763 |
+
"train_score_distribution": wandb.Histogram(train_scores),
|
| 764 |
+
"val_score_distribution": wandb.Histogram(val_scores)
|
| 765 |
+
})
|
| 766 |
+
|
| 767 |
+
# Log topic distribution
|
| 768 |
+
train_topic_counts = {}
|
| 769 |
+
for topic in train_topics:
|
| 770 |
+
train_topic_counts[topic] = train_topic_counts.get(topic, 0) + 1
|
| 771 |
+
|
| 772 |
+
# Create a bar chart for topics (top 20)
|
| 773 |
+
if len(train_topic_counts) > 0:
|
| 774 |
+
top_topics = sorted(train_topic_counts.items(), key=lambda x: x[1], reverse=True)[:20]
|
| 775 |
+
wandb.log({
|
| 776 |
+
"topic_distribution": wandb.plot.bar(
|
| 777 |
+
wandb.Table(data=[[k, v] for k, v in top_topics],
|
| 778 |
+
columns=["Topic", "Count"]),
|
| 779 |
+
"Topic", "Count", title="Training Topic Distribution (Top 20)"
|
| 780 |
+
)
|
| 781 |
+
})
|
| 782 |
+
|
| 783 |
+
# Tokenize datasets in batches (avoiding multiprocessing)
|
| 784 |
+
print("Tokenizing training dataset...")
|
| 785 |
+
train_encodings = self.tokenize_texts(train_texts)
|
| 786 |
+
|
| 787 |
+
print("Tokenizing validation dataset...")
|
| 788 |
+
val_encodings = self.tokenize_texts(val_texts)
|
| 789 |
+
|
| 790 |
+
# Create datasets
|
| 791 |
+
self.train_dataset = Dataset.from_dict(train_encodings)
|
| 792 |
+
self.val_dataset = Dataset.from_dict(val_encodings)
|
| 793 |
+
|
| 794 |
+
# Set format for PyTorch
|
| 795 |
+
self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 796 |
+
self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 797 |
+
|
| 798 |
+
# Clean up memory
|
| 799 |
+
del train_texts, val_texts, train_encodings, val_encodings
|
| 800 |
+
gc.collect()
|
| 801 |
+
|
| 802 |
+
def tokenize_texts(self, texts: List[str], batch_size: int = 100):
|
| 803 |
+
"""Tokenize texts in batches to avoid memory issues"""
|
| 804 |
+
all_input_ids = []
|
| 805 |
+
all_attention_masks = []
|
| 806 |
+
|
| 807 |
+
for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
|
| 808 |
+
batch_texts = texts[i:i + batch_size]
|
| 809 |
+
|
| 810 |
+
# Tokenize batch
|
| 811 |
+
encodings = self.tokenizer(
|
| 812 |
+
batch_texts,
|
| 813 |
+
truncation=True,
|
| 814 |
+
padding='max_length',
|
| 815 |
+
max_length=512,
|
| 816 |
+
return_tensors='pt'
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
# Convert to lists
|
| 820 |
+
all_input_ids.extend(encodings['input_ids'].tolist())
|
| 821 |
+
all_attention_masks.extend(encodings['attention_mask'].tolist())
|
| 822 |
+
|
| 823 |
+
# Create labels (same as input_ids for language modeling)
|
| 824 |
+
labels = all_input_ids.copy()
|
| 825 |
+
|
| 826 |
+
return {
|
| 827 |
+
'input_ids': all_input_ids,
|
| 828 |
+
'attention_mask': all_attention_masks,
|
| 829 |
+
'labels': labels
|
| 830 |
+
}
|
| 831 |
+
|
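Note that tokenize_texts copies input_ids into labels verbatim, so the labels column still contains padding token ids. If those labels were consumed directly (rather than rebuilt by the data collator), loss would also be computed on padding; a small masking variant that sets pads to -100 (the index ignored by the loss) would look like this, reusing the names defined above:

labels = [
    [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
    for ids, attn in zip(all_input_ids, all_attention_masks)
]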
| 832 |
+
def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
|
| 833 |
+
"""Setup training arguments optimized for counseling task"""
|
| 834 |
+
|
| 835 |
+
print("Setting up training arguments...")
|
| 836 |
+
|
| 837 |
+
# Calculate batch sizes based on available memory
|
| 838 |
+
if torch.cuda.is_available():
|
| 839 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 840 |
+
if gpu_memory < 16: # Less than 16GB
|
| 841 |
+
batch_size = 1
|
| 842 |
+
gradient_accumulation = 16
|
| 843 |
+
elif gpu_memory < 24: # Less than 24GB
|
| 844 |
+
batch_size = 2
|
| 845 |
+
gradient_accumulation = 8
|
| 846 |
+
else: # 24GB or more
|
| 847 |
+
batch_size = 4
|
| 848 |
+
gradient_accumulation = 4
|
| 849 |
+
else:
|
| 850 |
+
batch_size = 1
|
| 851 |
+
gradient_accumulation = 16
|
| 852 |
+
|
| 853 |
+
print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
|
| 854 |
+
|
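Note that every branch above lands on the same effective batch size per optimizer step: 1 × 16 = 2 × 8 = 4 × 4 = 16 sequences; the branches only trade per-step memory for more or fewer accumulation steps.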
| 855 |
+
# Update WandB config with training hyperparameters
|
| 856 |
+
if self.wandb_enabled:
|
| 857 |
+
wandb.config.update({
|
| 858 |
+
"batch_size": batch_size,
|
| 859 |
+
"gradient_accumulation_steps": gradient_accumulation,
|
| 860 |
+
"effective_batch_size": batch_size * gradient_accumulation,
|
| 861 |
+
"num_epochs": 3,
|
| 862 |
+
"learning_rate": 5e-5,
|
| 863 |
+
"warmup_steps": 100,
|
| 864 |
+
"weight_decay": 0.01,
|
| 865 |
+
"max_grad_norm": 1.0,
|
| 866 |
+
"lr_scheduler": "linear",
|
| 867 |
+
"optimizer": "adamw_torch",
|
| 868 |
+
"fp16": True,
|
| 869 |
+
"max_length": 512
|
| 870 |
+
})
|
| 871 |
+
|
| 872 |
+
# Set report_to based on wandb availability
|
| 873 |
+
report_to = "wandb" if self.wandb_enabled else "none"
|
| 874 |
+
|
| 875 |
+
self.training_args = TrainingArguments(
|
| 876 |
+
output_dir=output_dir,
|
| 877 |
+
num_train_epochs=3,
|
| 878 |
+
per_device_train_batch_size=batch_size,
|
| 879 |
+
per_device_eval_batch_size=batch_size,
|
| 880 |
+
gradient_accumulation_steps=gradient_accumulation,
|
| 881 |
+
gradient_checkpointing=True,
|
| 882 |
+
warmup_steps=100,
|
| 883 |
+
learning_rate=5e-5,
|
| 884 |
+
fp16=True,
|
| 885 |
+
logging_steps=50,
|
| 886 |
+
logging_first_step=True,
|
| 887 |
+
eval_strategy="steps",
|
| 888 |
+
eval_steps=200,
|
| 889 |
+
save_strategy="steps",
|
| 890 |
+
save_steps=400,
|
| 891 |
+
save_total_limit=2,
|
| 892 |
+
load_best_model_at_end=True,
|
| 893 |
+
metric_for_best_model="eval_loss",
|
| 894 |
+
greater_is_better=False,
|
| 895 |
+
report_to=report_to,
|
| 896 |
+
run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
|
| 897 |
+
push_to_hub=False,
|
| 898 |
+
optim="adamw_torch",
|
| 899 |
+
lr_scheduler_type="linear",
|
| 900 |
+
weight_decay=0.01,
|
| 901 |
+
max_grad_norm=1.0,
|
| 902 |
+
remove_unused_columns=False,
|
| 903 |
+
label_names=["labels"],
|
| 904 |
+
dataloader_num_workers=0,
|
| 905 |
+
dataloader_pin_memory=False,
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
def train(self):
|
| 909 |
+
"""Execute training"""
|
| 910 |
+
|
| 911 |
+
print("Initializing trainer...")
|
| 912 |
+
|
| 913 |
+
# Data collator for language modeling
|
| 914 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 915 |
+
tokenizer=self.tokenizer,
|
| 916 |
+
mlm=False,
|
| 917 |
+
pad_to_multiple_of=8
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
# Custom callback for additional metrics (properly inheriting from TrainerCallback)
|
| 921 |
+
class CustomMetricsCallback(TrainerCallback):
|
| 922 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 923 |
+
if logs and self.wandb_enabled:
|
| 924 |
+
# Add perplexity metrics
|
| 925 |
+
if "loss" in logs:
|
| 926 |
+
logs["perplexity"] = np.exp(logs["loss"])
|
| 927 |
+
if "eval_loss" in logs:
|
| 928 |
+
logs["eval_perplexity"] = np.exp(logs["eval_loss"])
|
| 929 |
+
return control
|
| 930 |
+
|
| 931 |
+
# Create callback instance with wandb_enabled flag
|
| 932 |
+
custom_callback = CustomMetricsCallback()
|
| 933 |
+
custom_callback.wandb_enabled = self.wandb_enabled
|
| 934 |
+
|
| 935 |
+
# Custom training to handle potential issues
|
| 936 |
+
try:
|
| 937 |
+
# Initialize trainer with callbacks
|
| 938 |
+
trainer = Trainer(
|
| 939 |
+
model=self.model,
|
| 940 |
+
args=self.training_args,
|
| 941 |
+
train_dataset=self.train_dataset,
|
| 942 |
+
eval_dataset=self.val_dataset,
|
| 943 |
+
data_collator=data_collator,
|
| 944 |
+
tokenizer=self.tokenizer,
|
| 945 |
+
callbacks=[custom_callback] if self.wandb_enabled else [],
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
# Calculate total training steps
|
| 949 |
+
total_steps = len(self.train_dataset) // (self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps) * self.training_args.num_train_epochs
|
| 950 |
+
|
| 951 |
+
# Start training
|
| 952 |
+
print("="*50)
|
| 953 |
+
print("Starting fine-tuning...")
|
| 954 |
+
print(f"Total training samples: {len(self.train_dataset)}")
|
| 955 |
+
print(f"Total validation samples: {len(self.val_dataset)}")
|
| 956 |
+
print(f"Total training steps: {total_steps}")
|
| 957 |
+
print("="*50)
|
| 958 |
+
|
| 959 |
+
# Log training start
|
| 960 |
+
if self.wandb_enabled:
|
| 961 |
+
wandb.log({"training_status": "started", "total_steps": total_steps})
|
| 962 |
+
|
| 963 |
+
# Train with error handling
|
| 964 |
+
train_result = trainer.train()
|
| 965 |
+
|
| 966 |
+
# Save the final model
|
| 967 |
+
print("\nSaving fine-tuned model...")
|
| 968 |
+
trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
|
| 969 |
+
self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
|
| 970 |
+
|
| 971 |
+
# Save training metrics
|
| 972 |
+
with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
|
| 973 |
+
json.dump(train_result.metrics, f, indent=2)
|
| 974 |
+
|
| 975 |
+
# Final evaluation
|
| 976 |
+
print("\nRunning final evaluation...")
|
| 977 |
+
eval_results = trainer.evaluate()
|
| 978 |
+
|
| 979 |
+
# Save evaluation metrics
|
| 980 |
+
with open(f"{self.training_args.output_dir}/eval_metrics.json", 'w') as f:
|
| 981 |
+
json.dump(eval_results, f, indent=2)
|
| 982 |
+
|
| 983 |
+
# Log final metrics to WandB
|
| 984 |
+
if self.wandb_enabled:
|
| 985 |
+
# Log final metrics
|
| 986 |
+
wandb.run.summary.update({
|
| 987 |
+
"final_train_loss": train_result.metrics.get("train_loss", 0),
|
| 988 |
+
"final_eval_loss": eval_results.get("eval_loss", 0),
|
| 989 |
+
"final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
|
| 990 |
+
"total_training_time": train_result.metrics.get("train_runtime", 0),
|
| 991 |
+
"training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
|
| 992 |
+
"training_status": "completed"
|
| 993 |
+
})
|
| 994 |
+
|
| 995 |
+
# Create a summary table
|
| 996 |
+
summary_table = wandb.Table(
|
| 997 |
+
columns=["Metric", "Value"],
|
| 998 |
+
data=[
|
| 999 |
+
["Final Training Loss", f"{train_result.metrics.get('train_loss', 0):.4f}"],
|
| 1000 |
+
["Final Eval Loss", f"{eval_results.get('eval_loss', 0):.4f}"],
|
| 1001 |
+
["Final Perplexity", f"{np.exp(eval_results.get('eval_loss', 0)):.2f}"],
|
| 1002 |
+
["Training Time (seconds)", f"{train_result.metrics.get('train_runtime', 0):.0f}"],
|
| 1003 |
+
["Training Samples/Second", f"{train_result.metrics.get('train_samples_per_second', 0):.2f}"]
|
| 1004 |
+
]
|
| 1005 |
+
)
|
| 1006 |
+
wandb.log({"training_summary": summary_table})
|
| 1007 |
+
|
| 1008 |
+
# Save model artifact
|
| 1009 |
+
try:
|
| 1010 |
+
artifact = wandb.Artifact(
|
| 1011 |
+
name=f"counselor-model-{wandb.run.id}",
|
| 1012 |
+
type="model",
|
| 1013 |
+
description="Fine-tuned Japanese counseling model",
|
| 1014 |
+
metadata={
|
| 1015 |
+
"base_model": self.model_name,
|
| 1016 |
+
"final_loss": float(eval_results.get("eval_loss", 0)),
|
| 1017 |
+
"final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
|
| 1018 |
+
"dataset": "KokoroChat"
|
| 1019 |
+
}
|
| 1020 |
+
)
|
| 1021 |
+
artifact.add_dir(f"{self.training_args.output_dir}/final_model_2b")
|
| 1022 |
+
wandb.log_artifact(artifact)
|
| 1023 |
+
except Exception as e:
|
| 1024 |
+
print(f"Warning: Could not save model artifact: {e}")
|
| 1025 |
+
|
| 1026 |
+
print("\n" + "="*50)
|
| 1027 |
+
print("✅ Training completed successfully!")
|
| 1028 |
+
print(f"📁 Model saved to: {self.training_args.output_dir}/final_model_2b")
|
| 1029 |
+
print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
|
| 1030 |
+
print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
|
| 1031 |
+
if self.wandb_enabled and wandb.run:
|
| 1032 |
+
print(f"🔗 View results at: {wandb.run.get_url()}")
|
| 1033 |
+
print("="*50)
|
| 1034 |
+
|
| 1035 |
+
return trainer
|
| 1036 |
+
|
| 1037 |
+
except Exception as e:
|
| 1038 |
+
print(f"❌ Error during training: {e}")
|
| 1039 |
+
|
| 1040 |
+
# Log error to WandB
|
| 1041 |
+
if self.wandb_enabled:
|
| 1042 |
+
wandb.run.summary["training_status"] = "failed"
|
| 1043 |
+
wandb.run.summary["error"] = str(e)
|
| 1044 |
+
|
| 1045 |
+
print("Attempting to save checkpoint...")
|
| 1046 |
+
|
| 1047 |
+
# Try to save whatever we have
|
| 1048 |
+
try:
|
| 1049 |
+
self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
|
| 1050 |
+
self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
|
| 1051 |
+
print(f"💾 Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
|
| 1052 |
+
except:
|
| 1053 |
+
print("❌ Could not save emergency checkpoint")
|
| 1054 |
+
|
| 1055 |
+
raise e
|
| 1056 |
+
finally:
|
| 1057 |
+
# Ensure WandB run is finished
|
| 1058 |
+
if self.wandb_enabled:
|
| 1059 |
+
wandb.finish()
|
| 1060 |
+
|
| 1061 |
+
# def test_model(model_path: str, tokenizer_path: str):
|
| 1062 |
+
# """Test the fine-tuned model with sample inputs"""
|
| 1063 |
+
|
| 1064 |
+
# print("\n" + "="*50)
|
| 1065 |
+
# print("Testing fine-tuned model...")
|
| 1066 |
+
# print("="*50)
|
| 1067 |
+
|
| 1068 |
+
# # Load model and tokenizer
|
| 1069 |
+
# from peft import PeftModel, PeftConfig
|
| 1070 |
+
|
| 1071 |
+
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 1072 |
+
# if tokenizer.pad_token is None:
|
| 1073 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 1074 |
+
|
| 1075 |
+
# # Try to load as PEFT model
|
| 1076 |
+
# try:
|
| 1077 |
+
# config = PeftConfig.from_pretrained(model_path)
|
| 1078 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
| 1079 |
+
# config.base_model_name_or_path,
|
| 1080 |
+
# torch_dtype=torch.float16,
|
| 1081 |
+
# device_map="auto"
|
| 1082 |
+
# )
|
| 1083 |
+
# model = PeftModel.from_pretrained(model, model_path)
|
| 1084 |
+
# except:
|
| 1085 |
+
# # Load as regular model
|
| 1086 |
+
# model = AutoModelForCausalLM.from_pretrained(
|
| 1087 |
+
# model_path,
|
| 1088 |
+
# torch_dtype=torch.float16,
|
| 1089 |
+
# device_map="auto"
|
| 1090 |
+
# )
|
| 1091 |
+
|
| 1092 |
+
# model.eval()
|
| 1093 |
+
|
| 1094 |
+
# # Test inputs
|
| 1095 |
+
# test_cases = [
|
| 1096 |
+
# "こんにちは。最近ストレスを感じています。",
|
| 1097 |
+
# "仕事がうまくいかなくて悩んでいます。",
|
| 1098 |
+
# "人間関係で困っています。どうすればいいでしょうか。"
|
| 1099 |
+
# ]
|
| 1100 |
+
|
| 1101 |
+
# print("Sample conversations:")
|
| 1102 |
+
# print("-" * 50)
|
| 1103 |
+
|
| 1104 |
+
def test_model(model_path: str, tokenizer_path: str):
|
| 1105 |
+
"""Test the fine-tuned model with sample inputs"""
|
| 1106 |
+
|
| 1107 |
+
print("\n" + "="*50)
|
| 1108 |
+
print("Testing fine-tuned model...")
|
| 1109 |
+
print("="*50)
|
| 1110 |
+
|
| 1111 |
+
# Load model and tokenizer with proper local path handling
|
| 1112 |
+
from peft import PeftModel, PeftConfig
|
| 1113 |
+
import os
|
| 1114 |
+
|
| 1115 |
+
# Fix tokenizer loading for local paths
|
| 1116 |
+
try:
|
| 1117 |
+
# Check if tokenizer files exist in the path
|
| 1118 |
+
if os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
| 1119 |
+
print(f"Loading tokenizer from {tokenizer_path}")
|
| 1120 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
|
| 1121 |
+
else:
|
| 1122 |
+
print(f"Tokenizer not found at {tokenizer_path}, using base model tokenizer")
|
| 1123 |
+
# Fallback to base model tokenizer
|
| 1124 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 1125 |
+
except Exception as e:
|
| 1126 |
+
print(f"Error loading tokenizer: {e}")
|
| 1127 |
+
print("Using fallback GPT-2 tokenizer")
|
| 1128 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 1129 |
+
|
| 1130 |
+
if tokenizer.pad_token is None:
|
| 1131 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 1132 |
+
|
| 1133 |
+
# Try to load model
|
| 1134 |
+
try:
|
| 1135 |
+
# Check if it's a PEFT model
|
| 1136 |
+
adapter_config_path = os.path.join(model_path, "adapter_config.json")
|
| 1137 |
+
if os.path.exists(adapter_config_path):
|
| 1138 |
+
print("Loading as PEFT model...")
|
| 1139 |
+
config = PeftConfig.from_pretrained(model_path)
|
| 1140 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 1141 |
+
config.base_model_name_or_path,
|
| 1142 |
+
torch_dtype=torch.float16,
|
| 1143 |
+
device_map="auto",
|
| 1144 |
+
trust_remote_code=True
|
| 1145 |
+
)
|
| 1146 |
+
model = PeftModel.from_pretrained(base_model, model_path)
|
| 1147 |
+
else:
|
| 1148 |
+
# Load as regular model
|
| 1149 |
+
print("Loading as regular model...")
|
| 1150 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 1151 |
+
model_path,
|
| 1152 |
+
torch_dtype=torch.float16,
|
| 1153 |
+
device_map="auto",
|
| 1154 |
+
local_files_only=True,
|
| 1155 |
+
trust_remote_code=True
|
| 1156 |
+
)
|
| 1157 |
+
except Exception as e:
|
| 1158 |
+
print(f"Error loading model: {e}")
|
| 1159 |
+
raise
|
| 1160 |
+
|
| 1161 |
+
model.eval()
|
| 1162 |
+
|
| 1163 |
+
# Test inputs
|
| 1164 |
+
test_cases = [
|
| 1165 |
+
"こんにちは。最近ストレスを感じています。",
|
| 1166 |
+
"仕事がうまくいかなくて悩んでいます。",
|
| 1167 |
+
"人間関係で困っています。どうすればいいでしょうか。"
|
| 1168 |
+
]
|
| 1169 |
+
|
| 1170 |
+
print("Sample conversations:")
|
| 1171 |
+
print("-" * 50)
|
| 1172 |
+
|
| 1173 |
+
for test_input in test_cases:
|
| 1174 |
+
# Generate response
|
| 1175 |
+
inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
|
| 1176 |
+
inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
|
| 1177 |
+
|
| 1178 |
+
with torch.no_grad():
|
| 1179 |
+
outputs = model.generate(
|
| 1180 |
+
**inputs,
|
| 1181 |
+
max_new_tokens=150,
|
| 1182 |
+
temperature=0.1,
|
| 1183 |
+
do_sample=True,
|
| 1184 |
+
top_p=0.9,
|
| 1185 |
+
pad_token_id=tokenizer.pad_token_id
|
| 1186 |
+
)
|
| 1187 |
+
|
| 1188 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 1189 |
+
response = response[len(test_input):].strip() # Remove input from response
|
| 1190 |
+
|
| 1191 |
+
print(f"Client: {test_input}")
|
| 1192 |
+
print(f"Counselor: {response[:200]}...")
|
| 1193 |
+
print("-" * 50)
|
| 1194 |
+
|
| 1195 |
+
print("="*50)
|
| 1196 |
+
|
| 1197 |
+
for test_input in test_cases:
|
| 1198 |
+
# Generate response
|
| 1199 |
+
inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
|
| 1200 |
+
inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
|
| 1201 |
+
|
| 1202 |
+
with torch.no_grad():
|
| 1203 |
+
outputs = model.generate(
|
| 1204 |
+
**inputs,
|
| 1205 |
+
max_new_tokens=150,
|
| 1206 |
+
temperature=0.1,
|
| 1207 |
+
do_sample=True,
|
| 1208 |
+
top_p=0.9,
|
| 1209 |
+
pad_token_id=tokenizer.pad_token_id
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 1213 |
+
response = response[len(test_input):].strip() # Remove input from response
|
| 1214 |
+
|
| 1215 |
+
print(f"Client: {test_input}")
|
| 1216 |
+
print(f"Counselor: {response[:200]}...")
|
| 1217 |
+
print("-" * 50)
|
| 1218 |
+
|
| 1219 |
+
print("="*50)
|
| 1220 |
+
|
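The generation loop in test_model above strips the echoed prompt by slicing the decoded string with len(test_input); a slightly more robust alternative, sketched here against the same inputs/outputs/tokenizer variables, is to decode only the newly generated token ids:

# Decode only the tokens produced after the prompt instead of slicing the decoded string.
prompt_length = inputs["input_ids"].shape[1]
new_token_ids = outputs[0][prompt_length:]
response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()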
| 1221 |
+
# Main training script
|
| 1222 |
+
if __name__ == "__main__":
|
| 1223 |
+
import argparse
|
| 1224 |
+
|
| 1225 |
+
parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
|
| 1226 |
+
parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
|
| 1227 |
+
help='Base model name')
|
| 1228 |
+
parser.add_argument('--data_path', type=str, default='./processed_data_score80',
|
| 1229 |
+
help='Path to processed data')
|
| 1230 |
+
parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
|
| 1231 |
+
help='Output directory for fine-tuned model')
|
| 1232 |
+
parser.add_argument('--use_4bit', action='store_true', default=False,
|
| 1233 |
+
help='Use 4-bit quantization')
|
| 1234 |
+
parser.add_argument('--wandb_api_key', type=str, default=None,
|
| 1235 |
+
help='WandB API key (optional, can use wandb login instead)')
|
| 1236 |
+
parser.add_argument('--test_only', action='store_true',
|
| 1237 |
+
help='Only test existing model')
|
| 1238 |
+
|
| 1239 |
+
args = parser.parse_args()
|
| 1240 |
+
|
| 1241 |
+
# Set WandB API key if provided
|
| 1242 |
+
if args.wandb_api_key:
|
| 1243 |
+
os.environ["WANDB_API_KEY"] = args.wandb_api_key
|
| 1244 |
+
|
| 1245 |
+
if args.test_only:
|
| 1246 |
+
# Test existing model
|
| 1247 |
+
test_model(
|
| 1248 |
+
f"{args.output_dir}/final_model_2b",
|
| 1249 |
+
f"{args.output_dir}/final_model_2b"
|
| 1250 |
+
)
|
| 1251 |
+
else:
|
| 1252 |
+
# Check if CUDA is available
|
| 1253 |
+
if not torch.cuda.is_available():
|
| 1254 |
+
print("⚠️ Warning: CUDA is not available. Training will be very slow on CPU.")
|
| 1255 |
+
print("It's highly recommended to use a GPU for training.")
|
| 1256 |
+
response = input("Do you want to continue anyway? (y/n): ")
|
| 1257 |
+
if response.lower() != 'y':
|
| 1258 |
+
exit()
|
| 1259 |
+
|
| 1260 |
+
try:
|
| 1261 |
+
# Clear GPU cache
|
| 1262 |
+
if torch.cuda.is_available():
|
| 1263 |
+
torch.cuda.empty_cache()
|
| 1264 |
+
|
| 1265 |
+
# Initialize fine-tuner (WandB is enabled by default)
|
| 1266 |
+
print(f"🚀 Initializing fine-tuner with model: {args.model_name}")
|
| 1267 |
+
finetuner = LFMCounselorFineTuner(
|
| 1268 |
+
model_name=args.model_name,
|
| 1269 |
+
use_4bit=args.use_4bit
|
| 1270 |
+
)
|
| 1271 |
+
|
| 1272 |
+
# Setup model
|
| 1273 |
+
print("\n🔧 Setting up model and tokenizer...")
|
| 1274 |
+
finetuner.setup_model_and_tokenizer()
|
| 1275 |
+
|
| 1276 |
+
# Load datasets
|
| 1277 |
+
print("\n📚 Loading and processing datasets...")
|
| 1278 |
+
finetuner.load_and_process_datasets(args.data_path)
|
| 1279 |
+
|
| 1280 |
+
# Setup training arguments
|
| 1281 |
+
print("\n⚙️ Setting up training arguments...")
|
| 1282 |
+
finetuner.setup_training_args(args.output_dir)
|
| 1283 |
+
|
| 1284 |
+
# Train
|
| 1285 |
+
trainer = finetuner.train()
|
| 1286 |
+
|
| 1287 |
+
# Test the model
|
| 1288 |
+
print("\n🧪 Testing the fine-tuned model...")
|
| 1289 |
+
test_model(
|
| 1290 |
+
f"{args.output_dir}/final_model_2b_v2",
|
| 1291 |
+
f"{args.output_dir}/final_model_2b_v2"
|
| 1292 |
+
)
|
| 1293 |
+
|
| 1294 |
+
print("\n✅ Fine-tuning completed successfully!")
|
| 1295 |
+
print(f"📁 Model saved to: {args.output_dir}/final_model_2b_v2")
|
| 1296 |
+
print("\n📋 Next steps:")
|
| 1297 |
+
print("1. Test more: python finetune_lfm.py --test_only")
|
| 1298 |
+
print("2. Run benchmarking: python benchmark_model.py")
|
| 1299 |
+
print("3. Optimize for mobile: python optimize_for_mobile.py")
|
| 1300 |
+
|
| 1301 |
+
except KeyboardInterrupt:
|
| 1302 |
+
print("\n\n⚠️ Training interrupted by user.")
|
| 1303 |
+
print("Partial model may be saved in checkpoints.")
|
| 1304 |
+
if wandb.run:
|
| 1305 |
+
wandb.finish()
|
| 1306 |
+
except Exception as e:
|
| 1307 |
+
print(f"\n❌ Error during fine-tuning: {e}")
|
| 1308 |
+
import traceback
|
| 1309 |
+
traceback.print_exc()
|
| 1310 |
+
if wandb.run:
|
| 1311 |
+
wandb.finish()
|
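Since the checkpoint saved above appears to be a LoRA adapter (the PEFT loading path in test_model suggests as much, and the repository ships merge_model.py for its own merging step), a minimal sketch of folding it into the base model for standalone inference could look like this; the local paths are assumptions based on the default --output_dir above:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftConfig, PeftModel

adapter_path = "./counselor_model_2b/final_model_2b"   # assumed location of the saved adapter
config = PeftConfig.from_pretrained(adapter_path)
base = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, adapter_path)
merged = model.merge_and_unload()                      # fold the LoRA weights into the base model
merged.save_pretrained("./counselor_model_2b/merged_model")
AutoTokenizer.from_pretrained(adapter_path).save_pretrained("./counselor_model_2b/merged_model")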
finetune_lfm_complete_history.py
ADDED
|
@@ -0,0 +1,801 @@
|
| 1 |
+
"""
|
| 2 |
+
Fine-tuning Script for LFM2-2.6B with Complete Dialogue History
|
| 3 |
+
Following KokoroChat methodology - uses entire conversation context
|
| 4 |
+
Filename: finetune_lfm_complete_history.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import (
|
| 9 |
+
AutoModelForCausalLM,
|
| 10 |
+
AutoTokenizer,
|
| 11 |
+
TrainingArguments,
|
| 12 |
+
Trainer,
|
| 13 |
+
DataCollatorForLanguageModeling,
|
| 14 |
+
BitsAndBytesConfig,
|
| 15 |
+
TrainerCallback
|
| 16 |
+
)
|
| 17 |
+
from peft import (
|
| 18 |
+
LoraConfig,
|
| 19 |
+
get_peft_model,
|
| 20 |
+
prepare_model_for_kbit_training,
|
| 21 |
+
TaskType,
|
| 22 |
+
PeftModel,
|
| 23 |
+
PeftConfig
|
| 24 |
+
)
|
| 25 |
+
from datasets import load_dataset, Dataset
|
| 26 |
+
import os
|
| 27 |
+
from typing import Dict, List, Optional
|
| 28 |
+
import numpy as np
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
import json
|
| 31 |
+
import gc
|
| 32 |
+
import warnings
|
| 33 |
+
import wandb
|
| 34 |
+
from datetime import datetime
|
| 35 |
+
|
| 36 |
+
warnings.filterwarnings('ignore')
|
| 37 |
+
|
| 38 |
+
# Enable TF32 for H100 optimization
|
| 39 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 40 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 41 |
+
|
| 42 |
+
class LFMKokoroChatFineTuner:
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
model_name: str = "LiquidAI/LFM2-2.6B",
|
| 46 |
+
use_4bit: bool = False, # H100 has enough memory
|
| 47 |
+
max_seq_length: int = 2048 # Increased for complete dialogue history
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Initialize the fine-tuner for LFM models with complete dialogue history support
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
model_name: Name of the base model
|
| 54 |
+
use_4bit: Whether to use 4-bit quantization
|
| 55 |
+
max_seq_length: Maximum sequence length for complete dialogues
|
| 56 |
+
"""
|
| 57 |
+
self.model_name = model_name
|
| 58 |
+
self.use_4bit = use_4bit
|
| 59 |
+
self.max_seq_length = max_seq_length
|
| 60 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 61 |
+
|
| 62 |
+
print("="*80)
|
| 63 |
+
print("🚀 LFM Fine-tuning with Complete Dialogue History (KokoroChat Method)")
|
| 64 |
+
print("="*80)
|
| 65 |
+
print(f"Model: {model_name}")
|
| 66 |
+
print(f"Device: {self.device}")
|
| 67 |
+
print(f"Max sequence length: {max_seq_length}")
|
| 68 |
+
|
| 69 |
+
# GPU information
|
| 70 |
+
if torch.cuda.is_available():
|
| 71 |
+
num_gpus = torch.cuda.device_count()
|
| 72 |
+
print(f"Number of GPUs: {num_gpus}")
|
| 73 |
+
for i in range(num_gpus):
|
| 74 |
+
gpu_name = torch.cuda.get_device_name(i)
|
| 75 |
+
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1e9
|
| 76 |
+
print(f" GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
|
| 77 |
+
|
| 78 |
+
# Initialize WandB
|
| 79 |
+
self.init_wandb()
|
| 80 |
+
|
| 81 |
+
def init_wandb(self):
|
| 82 |
+
"""Initialize WandB for experiment tracking"""
|
| 83 |
+
try:
|
| 84 |
+
run_name = f"lfm-kokoro-complete-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
|
| 85 |
+
|
| 86 |
+
wandb.init(
|
| 87 |
+
project="lfm-kokoro-complete-history",
|
| 88 |
+
name=run_name,
|
| 89 |
+
config={
|
| 90 |
+
"model_name": self.model_name,
|
| 91 |
+
"use_4bit_quantization": self.use_4bit,
|
| 92 |
+
"max_seq_length": self.max_seq_length,
|
| 93 |
+
"device": str(self.device),
|
| 94 |
+
"num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
| 95 |
+
"methodology": "Complete dialogue history (KokoroChat)",
|
| 96 |
+
"framework": "transformers + peft",
|
| 97 |
+
"task": "japanese_counseling"
|
| 98 |
+
},
|
| 99 |
+
tags=["counseling", "japanese", "lfm", "complete-history", "kokoro"]
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
print(f"✅ WandB initialized: {wandb.run.name}")
|
| 103 |
+
print(f"📊 View run at: {wandb.run.get_url()}")
|
| 104 |
+
self.wandb_enabled = True
|
| 105 |
+
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"⚠️ WandB initialization failed: {e}")
|
| 108 |
+
self.wandb_enabled = False
|
| 109 |
+
os.environ["WANDB_DISABLED"] = "true"
|
| 110 |
+
|
| 111 |
+
def setup_model_and_tokenizer(self):
|
| 112 |
+
"""Setup model with quantization and LoRA"""
|
| 113 |
+
|
| 114 |
+
print("\n📚 Setting up model and tokenizer...")
|
| 115 |
+
|
| 116 |
+
# Load tokenizer
|
| 117 |
+
print("Loading tokenizer...")
|
| 118 |
+
try:
|
| 119 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 120 |
+
self.model_name,
|
| 121 |
+
trust_remote_code=True
|
| 122 |
+
)
|
| 123 |
+
except:
|
| 124 |
+
print("Using fallback tokenizer...")
|
| 125 |
+
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
| 126 |
+
|
| 127 |
+
# Set special tokens
|
| 128 |
+
if self.tokenizer.pad_token is None:
|
| 129 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 130 |
+
if self.tokenizer.eos_token is None:
|
| 131 |
+
self.tokenizer.eos_token = "</s>"
|
| 132 |
+
self.tokenizer.pad_token = "</s>"
|
| 133 |
+
|
| 134 |
+
self.tokenizer.padding_side = "left" # Important for batch generation
|
| 135 |
+
|
| 136 |
+
# Quantization config
|
| 137 |
+
if self.use_4bit:
|
| 138 |
+
print("Setting up 4-bit quantization...")
|
| 139 |
+
bnb_config = BitsAndBytesConfig(
|
| 140 |
+
load_in_4bit=True,
|
| 141 |
+
bnb_4bit_quant_type="nf4",
|
| 142 |
+
bnb_4bit_compute_dtype=torch.bfloat16, # BF16 for H100
|
| 143 |
+
bnb_4bit_use_double_quant=True
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
bnb_config = None
|
| 147 |
+
|
| 148 |
+
# Load model
|
| 149 |
+
print(f"Loading model: {self.model_name}...")
|
| 150 |
+
model_kwargs = {
|
| 151 |
+
"trust_remote_code": True,
|
| 152 |
+
"torch_dtype": torch.bfloat16, # BF16 for H100
|
| 153 |
+
"device_map": "auto",
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
if bnb_config:
|
| 157 |
+
model_kwargs["quantization_config"] = bnb_config
|
| 158 |
+
|
| 159 |
+
try:
|
| 160 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 161 |
+
self.model_name,
|
| 162 |
+
**model_kwargs
|
| 163 |
+
)
|
| 164 |
+
except Exception as e:
|
| 165 |
+
print(f"Error loading model: {e}")
|
| 166 |
+
print("Attempting without device_map...")
|
| 167 |
+
model_kwargs.pop("device_map", None)
|
| 168 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 169 |
+
self.model_name,
|
| 170 |
+
**model_kwargs
|
| 171 |
+
)
|
| 172 |
+
self.model = self.model.to(self.device)
|
| 173 |
+
|
| 174 |
+
# Enable gradient checkpointing
|
| 175 |
+
if hasattr(self.model, 'gradient_checkpointing_enable'):
|
| 176 |
+
self.model.gradient_checkpointing_enable()
|
| 177 |
+
|
| 178 |
+
# Prepare for k-bit training if using quantization
|
| 179 |
+
if self.use_4bit:
|
| 180 |
+
print("Preparing model for 4-bit training...")
|
| 181 |
+
self.model = prepare_model_for_kbit_training(self.model)
|
| 182 |
+
|
| 183 |
+
# LoRA configuration optimized for dialogue with complete history
|
| 184 |
+
print("Applying LoRA configuration...")
|
| 185 |
+
|
| 186 |
+
# Find target modules
|
| 187 |
+
target_modules = self.find_target_modules()
|
| 188 |
+
|
| 189 |
+
# Higher rank for complex dialogue understanding
|
| 190 |
+
lora_config = LoraConfig(
|
| 191 |
+
r=64, # Increased for better dialogue understanding
|
| 192 |
+
lora_alpha=128,
|
| 193 |
+
target_modules=target_modules,
|
| 194 |
+
lora_dropout=0.05,
|
| 195 |
+
bias="none",
|
| 196 |
+
task_type=TaskType.CAUSAL_LM,
|
| 197 |
+
inference_mode=False
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Apply LoRA
|
| 201 |
+
self.model = get_peft_model(self.model, lora_config)
|
| 202 |
+
|
| 203 |
+
# Print trainable parameters
|
| 204 |
+
trainable_params = 0
|
| 205 |
+
all_params = 0
|
| 206 |
+
for _, param in self.model.named_parameters():
|
| 207 |
+
all_params += param.numel()
|
| 208 |
+
if param.requires_grad:
|
| 209 |
+
trainable_params += param.numel()
|
| 210 |
+
|
| 211 |
+
trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
|
| 212 |
+
|
| 213 |
+
print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
|
| 214 |
+
|
| 215 |
+
# Log to WandB
|
| 216 |
+
if self.wandb_enabled:
|
| 217 |
+
wandb.config.update({
|
| 218 |
+
"lora_r": lora_config.r,
|
| 219 |
+
"lora_alpha": lora_config.lora_alpha,
|
| 220 |
+
"lora_dropout": lora_config.lora_dropout,
|
| 221 |
+
"lora_target_modules": target_modules,
|
| 222 |
+
"total_parameters": all_params,
|
| 223 |
+
"trainable_parameters": trainable_params,
|
| 224 |
+
"trainable_percentage": trainable_percentage
|
| 225 |
+
})
|
| 226 |
+
|
| 227 |
+
self.model.print_trainable_parameters()
|
| 228 |
+
|
| 229 |
+
def find_target_modules(self):
|
| 230 |
+
"""Find linear modules to apply LoRA to"""
|
| 231 |
+
target_modules = []
|
| 232 |
+
for name, module in self.model.named_modules():
|
| 233 |
+
if isinstance(module, torch.nn.Linear):
|
| 234 |
+
names = name.split('.')
|
| 235 |
+
if len(names) > 0:
|
| 236 |
+
target_modules.append(names[-1])
|
| 237 |
+
|
| 238 |
+
# Remove duplicates
|
| 239 |
+
target_modules = list(set(target_modules))
|
| 240 |
+
|
| 241 |
+
# Common patterns for transformer models
|
| 242 |
+
common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
|
| 243 |
+
"gate_proj", "up_proj", "down_proj",
|
| 244 |
+
"fc1", "fc2", "query", "key", "value", "dense"]
|
| 245 |
+
|
| 246 |
+
# Filter to common targets
|
| 247 |
+
final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
|
| 248 |
+
|
| 249 |
+
if not final_targets:
|
| 250 |
+
# Fallback to specific modules for LFM
|
| 251 |
+
final_targets = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
| 252 |
+
|
| 253 |
+
print(f"LoRA target modules: {final_targets}")
|
| 254 |
+
return final_targets
|
| 255 |
+
|
| 256 |
+
def load_and_process_datasets(self, data_path: str):
|
| 257 |
+
"""
|
| 258 |
+
Load and process datasets with complete dialogue history
|
| 259 |
+
Handles the new data format with full conversation context
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
print(f"\n📚 Loading datasets from {data_path}...")
|
| 263 |
+
|
| 264 |
+
# Check for dataset statistics
|
| 265 |
+
stats_file = os.path.join(data_path, 'dataset_stats.json')
|
| 266 |
+
if os.path.exists(stats_file):
|
| 267 |
+
with open(stats_file, 'r') as f:
|
| 268 |
+
stats = json.load(f)
|
| 269 |
+
print("Dataset statistics:")
|
| 270 |
+
print(f" Average dialogue history: {stats['dialogue_history_stats']['mean_length']:.1f} turns")
|
| 271 |
+
print(f" Max dialogue history: {stats['dialogue_history_stats']['max_length']} turns")
|
| 272 |
+
print(f" Median dialogue history: {stats['dialogue_history_stats']['median_length']:.1f} turns")
|
| 273 |
+
|
| 274 |
+
# Load datasets
|
| 275 |
+
train_data = []
|
| 276 |
+
val_data = []
|
| 277 |
+
|
| 278 |
+
# Load training data
|
| 279 |
+
train_file = os.path.join(data_path, 'train.jsonl')
|
| 280 |
+
with open(train_file, 'r', encoding='utf-8') as f:
|
| 281 |
+
for line in tqdm(f, desc="Loading training data"):
|
| 282 |
+
item = json.loads(line)
|
| 283 |
+
train_data.append({
|
| 284 |
+
'text': item['text'],
|
| 285 |
+
'history_length': item.get('history_length', 0),
|
| 286 |
+
'score': item.get('score', 100),
|
| 287 |
+
'topic': item.get('topic', 'general')
|
| 288 |
+
})
|
| 289 |
+
|
| 290 |
+
# Load validation data
|
| 291 |
+
val_file = os.path.join(data_path, 'val.jsonl')
|
| 292 |
+
with open(val_file, 'r', encoding='utf-8') as f:
|
| 293 |
+
for line in tqdm(f, desc="Loading validation data"):
|
| 294 |
+
item = json.loads(line)
|
| 295 |
+
val_data.append({
|
| 296 |
+
'text': item['text'],
|
| 297 |
+
'history_length': item.get('history_length', 0),
|
| 298 |
+
'score': item.get('score', 100),
|
| 299 |
+
'topic': item.get('topic', 'general')
|
| 300 |
+
})
|
| 301 |
+
|
| 302 |
+
print(f"Loaded {len(train_data)} training examples")
|
| 303 |
+
print(f"Loaded {len(val_data)} validation examples")
|
| 304 |
+
|
| 305 |
+
# Analyze dialogue history lengths
|
| 306 |
+
train_history_lengths = [d['history_length'] for d in train_data]
|
| 307 |
+
val_history_lengths = [d['history_length'] for d in val_data]
|
| 308 |
+
|
| 309 |
+
print(f"\nDialogue history length distribution:")
|
| 310 |
+
print(f" Training - Mean: {np.mean(train_history_lengths):.1f}, Max: {max(train_history_lengths)}")
|
| 311 |
+
print(f" Validation - Mean: {np.mean(val_history_lengths):.1f}, Max: {max(val_history_lengths)}")
|
| 312 |
+
|
| 313 |
+
# Log to WandB
|
| 314 |
+
if self.wandb_enabled:
|
| 315 |
+
wandb.config.update({
|
| 316 |
+
"train_examples": len(train_data),
|
| 317 |
+
"val_examples": len(val_data),
|
| 318 |
+
"avg_train_history_length": float(np.mean(train_history_lengths)),
|
| 319 |
+
"max_train_history_length": int(max(train_history_lengths)),
|
| 320 |
+
"avg_val_history_length": float(np.mean(val_history_lengths)),
|
| 321 |
+
"max_val_history_length": int(max(val_history_lengths))
|
| 322 |
+
})
|
| 323 |
+
|
| 324 |
+
# Log history length distribution
|
| 325 |
+
wandb.log({
|
| 326 |
+
"train_history_distribution": wandb.Histogram(train_history_lengths),
|
| 327 |
+
"val_history_distribution": wandb.Histogram(val_history_lengths)
|
| 328 |
+
})
|
| 329 |
+
|
| 330 |
+
# Tokenize datasets
|
| 331 |
+
print("\nTokenizing datasets with complete dialogue history...")
|
| 332 |
+
print(f"Using max sequence length: {self.max_seq_length}")
|
| 333 |
+
|
| 334 |
+
# Extract texts for tokenization
|
| 335 |
+
train_texts = [d['text'] for d in train_data]
|
| 336 |
+
val_texts = [d['text'] for d in val_data]
|
| 337 |
+
|
| 338 |
+
# Tokenize with longer context for complete history
|
| 339 |
+
train_encodings = self.tokenize_texts(train_texts, desc="Tokenizing training data")
|
| 340 |
+
val_encodings = self.tokenize_texts(val_texts, desc="Tokenizing validation data")
|
| 341 |
+
|
| 342 |
+
# Create datasets
|
| 343 |
+
self.train_dataset = Dataset.from_dict(train_encodings)
|
| 344 |
+
self.val_dataset = Dataset.from_dict(val_encodings)
|
| 345 |
+
|
| 346 |
+
# Set format for PyTorch
|
| 347 |
+
self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 348 |
+
self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
|
| 349 |
+
|
| 350 |
+
# Clean up memory
|
| 351 |
+
del train_texts, val_texts, train_encodings, val_encodings, train_data, val_data
|
| 352 |
+
gc.collect()
|
| 353 |
+
|
| 354 |
+
print("✅ Datasets loaded and tokenized")
|
| 355 |
+
|
| 356 |
+
def tokenize_texts(self, texts: List[str], batch_size: int = 50, desc: str = "Tokenizing"):
|
| 357 |
+
"""
|
| 358 |
+
Tokenize texts in batches with support for longer sequences
|
| 359 |
+
"""
|
| 360 |
+
all_input_ids = []
|
| 361 |
+
all_attention_masks = []
|
| 362 |
+
|
| 363 |
+
# Process in smaller batches for long sequences
|
| 364 |
+
for i in tqdm(range(0, len(texts), batch_size), desc=desc):
|
| 365 |
+
batch_texts = texts[i:i + batch_size]
|
| 366 |
+
|
| 367 |
+
# Tokenize batch with longer max length
|
| 368 |
+
encodings = self.tokenizer(
|
| 369 |
+
batch_texts,
|
| 370 |
+
truncation=True,
|
| 371 |
+
padding='max_length',
|
| 372 |
+
max_length=self.max_seq_length,
|
| 373 |
+
return_tensors='pt'
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
# Convert to lists
|
| 377 |
+
all_input_ids.extend(encodings['input_ids'].tolist())
|
| 378 |
+
all_attention_masks.extend(encodings['attention_mask'].tolist())
|
| 379 |
+
|
| 380 |
+
# Create labels (same as input_ids for causal LM)
|
| 381 |
+
labels = all_input_ids.copy()
|
| 382 |
+
|
| 383 |
+
return {
|
| 384 |
+
'input_ids': all_input_ids,
|
| 385 |
+
'attention_mask': all_attention_masks,
|
| 386 |
+
'labels': labels
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
def setup_training_args(self, output_dir: str = "./lfm_kokoro_complete"):
|
| 390 |
+
"""Setup training arguments optimized for complete dialogue history"""
|
| 391 |
+
|
| 392 |
+
print("\n⚙️ Setting up training arguments...")
|
| 393 |
+
|
| 394 |
+
# Calculate batch sizes based on sequence length and GPU memory
|
| 395 |
+
if torch.cuda.is_available():
|
| 396 |
+
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
|
| 397 |
+
num_gpus = torch.cuda.device_count()
|
| 398 |
+
|
| 399 |
+
# Adjust batch size based on sequence length and GPU memory
|
| 400 |
+
if self.max_seq_length >= 2048:
|
| 401 |
+
if gpu_memory >= 80: # H100 80GB
|
| 402 |
+
batch_size = 4
|
| 403 |
+
gradient_accumulation = 4
|
| 404 |
+
elif gpu_memory >= 40:
|
| 405 |
+
batch_size = 2
|
| 406 |
+
gradient_accumulation = 8
|
| 407 |
+
else:
|
| 408 |
+
batch_size = 1
|
| 409 |
+
gradient_accumulation = 16
|
| 410 |
+
else:
|
| 411 |
+
batch_size = 8
|
| 412 |
+
gradient_accumulation = 2
|
| 413 |
+
|
| 414 |
+
# Adjust for multiple GPUs
|
| 415 |
+
if num_gpus > 1:
|
| 416 |
+
batch_size = batch_size * num_gpus
|
| 417 |
+
gradient_accumulation = max(1, gradient_accumulation // num_gpus)
|
| 418 |
+
else:
|
| 419 |
+
batch_size = 1
|
| 420 |
+
gradient_accumulation = 32
|
| 421 |
+
|
| 422 |
+
print(f"Batch configuration:")
|
| 423 |
+
print(f" Per device batch size: {batch_size}")
|
| 424 |
+
print(f" Gradient accumulation steps: {gradient_accumulation}")
|
| 425 |
+
print(f" Effective batch size: {batch_size * gradient_accumulation}")
|
| 426 |
+
|
| 427 |
+
# Update WandB config
|
| 428 |
+
if self.wandb_enabled:
|
| 429 |
+
wandb.config.update({
|
| 430 |
+
"batch_size": batch_size,
|
| 431 |
+
"gradient_accumulation_steps": gradient_accumulation,
|
| 432 |
+
"effective_batch_size": batch_size * gradient_accumulation,
|
| 433 |
+
"num_epochs": 3,
|
| 434 |
+
"learning_rate": 2e-4,
|
| 435 |
+
"warmup_ratio": 0.1,
|
| 436 |
+
"weight_decay": 0.01,
|
| 437 |
+
"max_grad_norm": 1.0,
|
| 438 |
+
"lr_scheduler": "cosine",
|
| 439 |
+
"optimizer": "adamw_torch"
|
| 440 |
+
})
|
| 441 |
+
|
| 442 |
+
self.training_args = TrainingArguments(
|
| 443 |
+
output_dir=output_dir,
|
| 444 |
+
num_train_epochs=3,
|
| 445 |
+
per_device_train_batch_size=batch_size,
|
| 446 |
+
per_device_eval_batch_size=batch_size,
|
| 447 |
+
gradient_accumulation_steps=gradient_accumulation,
|
| 448 |
+
gradient_checkpointing=True,
|
| 449 |
+
warmup_ratio=0.1,
|
| 450 |
+
learning_rate=2e-4,
|
| 451 |
+
bf16=True, # Use BF16 for H100
|
| 452 |
+
tf32=True, # Enable TF32 for H100
|
| 453 |
+
logging_steps=10,
|
| 454 |
+
logging_first_step=True,
|
| 455 |
+
eval_strategy="steps",
|
| 456 |
+
eval_steps=100,
|
| 457 |
+
save_strategy="steps",
|
| 458 |
+
save_steps=200,
|
| 459 |
+
save_total_limit=3,
|
| 460 |
+
load_best_model_at_end=True,
|
| 461 |
+
metric_for_best_model="eval_loss",
|
| 462 |
+
greater_is_better=False,
|
| 463 |
+
report_to="wandb" if self.wandb_enabled else "none",
|
| 464 |
+
run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
|
| 465 |
+
optim="adamw_torch",
|
| 466 |
+
lr_scheduler_type="cosine",
|
| 467 |
+
weight_decay=0.01,
|
| 468 |
+
max_grad_norm=1.0,
|
| 469 |
+
remove_unused_columns=False,
|
| 470 |
+
label_names=["labels"],
|
| 471 |
+
dataloader_num_workers=4,
|
| 472 |
+
dataloader_pin_memory=True,
|
| 473 |
+
ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,
|
| 474 |
+
)
|
| 475 |
+
|
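Note that this configuration uses warmup_ratio rather than the fixed warmup_steps of the previous script; a small sketch of how the ratio translates into warmup steps (the total-step count here is hypothetical):

import math

total_steps = 1_500                                     # hypothetical, computed as in train() below
warmup_ratio = 0.1                                      # as configured above
warmup_steps = math.ceil(total_steps * warmup_ratio)    # 150 warmup steps before cosine decay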
| 476 |
+
def train(self):
|
| 477 |
+
"""Execute training with complete dialogue history"""
|
| 478 |
+
|
| 479 |
+
print("\n🎯 Starting training with complete dialogue history...")
|
| 480 |
+
|
| 481 |
+
# Data collator
|
| 482 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 483 |
+
tokenizer=self.tokenizer,
|
| 484 |
+
mlm=False,
|
| 485 |
+
pad_to_multiple_of=8
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
# Custom callback for metrics
|
| 489 |
+
class MetricsCallback(TrainerCallback):
|
| 490 |
+
def __init__(self, wandb_enabled):
|
| 491 |
+
self.wandb_enabled = wandb_enabled
|
| 492 |
+
|
| 493 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 494 |
+
if logs and self.wandb_enabled:
|
| 495 |
+
# Add perplexity
|
| 496 |
+
if "loss" in logs:
|
| 497 |
+
logs["perplexity"] = np.exp(logs["loss"])
|
| 498 |
+
if "eval_loss" in logs:
|
| 499 |
+
logs["eval_perplexity"] = np.exp(logs["eval_loss"])
|
| 500 |
+
|
| 501 |
+
# Log to WandB
|
| 502 |
+
wandb.log(logs, step=state.global_step)
|
| 503 |
+
|
| 504 |
+
return control
|
| 505 |
+
|
| 506 |
+
# Initialize trainer
|
| 507 |
+
trainer = Trainer(
|
| 508 |
+
model=self.model,
|
| 509 |
+
args=self.training_args,
|
| 510 |
+
train_dataset=self.train_dataset,
|
| 511 |
+
eval_dataset=self.val_dataset,
|
| 512 |
+
data_collator=data_collator,
|
| 513 |
+
tokenizer=self.tokenizer,
|
| 514 |
+
callbacks=[MetricsCallback(self.wandb_enabled)] if self.wandb_enabled else [],
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
# Calculate total steps
|
| 518 |
+
total_steps = len(self.train_dataset) // (
|
| 519 |
+
self.training_args.per_device_train_batch_size *
|
| 520 |
+
self.training_args.gradient_accumulation_steps
|
| 521 |
+
) * self.training_args.num_train_epochs
|
| 522 |
+
|
| 523 |
+
print("="*60)
|
| 524 |
+
print("Training Information:")
|
| 525 |
+
print(f" Total training samples: {len(self.train_dataset)}")
|
| 526 |
+
print(f" Total validation samples: {len(self.val_dataset)}")
|
| 527 |
+
print(f" Total training steps: {total_steps}")
|
| 528 |
+
print(f" Max sequence length: {self.max_seq_length}")
|
| 529 |
+
print("="*60)
|
| 530 |
+
|
| 531 |
+
# Log training start
|
| 532 |
+
if self.wandb_enabled:
|
| 533 |
+
wandb.log({
|
| 534 |
+
"training_status": "started",
|
| 535 |
+
"total_steps": total_steps,
|
| 536 |
+
"max_seq_length": self.max_seq_length
|
| 537 |
+
})
|
| 538 |
+
|
| 539 |
+
try:
|
| 540 |
+
# Train
|
| 541 |
+
print("\n🚀 Training started...")
|
| 542 |
+
train_result = trainer.train()
|
| 543 |
+
|
| 544 |
+
# Save model
|
| 545 |
+
print("\n💾 Saving fine-tuned model...")
|
| 546 |
+
final_model_path = os.path.join(self.training_args.output_dir, "final_model")
|
| 547 |
+
trainer.save_model(final_model_path)
|
| 548 |
+
self.tokenizer.save_pretrained(final_model_path)
|
| 549 |
+
|
| 550 |
+
# Save training metrics
|
| 551 |
+
with open(os.path.join(self.training_args.output_dir, "training_metrics.json"), 'w') as f:
|
| 552 |
+
json.dump(train_result.metrics, f, indent=2)
|
| 553 |
+
|
| 554 |
+
# Final evaluation
|
| 555 |
+
print("\n📊 Running final evaluation...")
|
| 556 |
+
eval_results = trainer.evaluate()
|
| 557 |
+
|
| 558 |
+
# Save evaluation metrics
|
| 559 |
+
with open(os.path.join(self.training_args.output_dir, "eval_metrics.json"), 'w') as f:
|
| 560 |
+
json.dump(eval_results, f, indent=2)
|
| 561 |
+
|
| 562 |
+
# Log final metrics
|
| 563 |
+
if self.wandb_enabled:
|
| 564 |
+
wandb.run.summary.update({
|
| 565 |
+
"final_train_loss": train_result.metrics.get("train_loss", 0),
|
| 566 |
+
"final_eval_loss": eval_results.get("eval_loss", 0),
|
| 567 |
+
"final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
|
| 568 |
+
"total_training_time": train_result.metrics.get("train_runtime", 0),
|
| 569 |
+
"training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
|
| 570 |
+
"training_status": "completed"
|
| 571 |
+
})
|
| 572 |
+
|
| 573 |
+
# Save model artifact
|
| 574 |
+
artifact = wandb.Artifact(
|
| 575 |
+
name=f"kokoro-model-complete-{wandb.run.id}",
|
| 576 |
+
type="model",
|
| 577 |
+
description="LFM model fine-tuned with complete dialogue history",
|
| 578 |
+
metadata={
|
| 579 |
+
"base_model": self.model_name,
|
| 580 |
+
"final_loss": float(eval_results.get("eval_loss", 0)),
|
| 581 |
+
"final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
|
| 582 |
+
"max_seq_length": self.max_seq_length,
|
| 583 |
+
"methodology": "Complete dialogue history (KokoroChat)"
|
| 584 |
+
}
|
| 585 |
+
)
|
| 586 |
+
artifact.add_dir(final_model_path)
|
| 587 |
+
wandb.log_artifact(artifact)
|
| 588 |
+
|
| 589 |
+
print("\n" + "="*60)
|
| 590 |
+
print("✅ Training completed successfully!")
|
| 591 |
+
print(f"📁 Model saved to: {final_model_path}")
|
| 592 |
+
print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
|
| 593 |
+
print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
|
| 594 |
+
if self.wandb_enabled and wandb.run:
|
| 595 |
+
print(f"🔗 View results at: {wandb.run.get_url()}")
|
| 596 |
+
print("="*60)
|
| 597 |
+
|
| 598 |
+
return trainer
|
| 599 |
+
|
| 600 |
+
except Exception as e:
|
| 601 |
+
print(f"❌ Error during training: {e}")
|
| 602 |
+
|
| 603 |
+
if self.wandb_enabled:
|
| 604 |
+
wandb.run.summary["training_status"] = "failed"
|
| 605 |
+
wandb.run.summary["error"] = str(e)
|
| 606 |
+
|
| 607 |
+
# Save emergency checkpoint
|
| 608 |
+
try:
|
| 609 |
+
emergency_path = os.path.join(self.training_args.output_dir, "emergency_checkpoint")
|
| 610 |
+
self.model.save_pretrained(emergency_path)
|
| 611 |
+
self.tokenizer.save_pretrained(emergency_path)
|
| 612 |
+
print(f"💾 Emergency checkpoint saved to: {emergency_path}")
|
| 613 |
+
except:
|
| 614 |
+
print("❌ Could not save emergency checkpoint")
|
| 615 |
+
|
| 616 |
+
raise e
|
| 617 |
+
|
| 618 |
+
finally:
|
| 619 |
+
if self.wandb_enabled:
|
| 620 |
+
wandb.finish()
|
| 621 |
+
|
| 622 |
+
def test_model_with_complete_history(model_path: str):
|
| 623 |
+
"""Test the fine-tuned model with complete dialogue history examples"""
|
| 624 |
+
|
| 625 |
+
print("\n" + "="*60)
|
| 626 |
+
print("🧪 Testing model with complete dialogue history")
|
| 627 |
+
print("="*60)
|
| 628 |
+
|
| 629 |
+
# Load tokenizer and model
|
| 630 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
|
| 631 |
+
|
| 632 |
+
# Check if it's a PEFT model
|
| 633 |
+
adapter_config_path = os.path.join(model_path, "adapter_config.json")
|
| 634 |
+
if os.path.exists(adapter_config_path):
|
| 635 |
+
print("Loading as PEFT model...")
|
| 636 |
+
config = PeftConfig.from_pretrained(model_path)
|
| 637 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 638 |
+
config.base_model_name_or_path,
|
| 639 |
+
torch_dtype=torch.bfloat16,
|
| 640 |
+
device_map="auto",
|
| 641 |
+
trust_remote_code=True
|
| 642 |
+
)
|
| 643 |
+
model = PeftModel.from_pretrained(base_model, model_path)
|
| 644 |
+
else:
|
| 645 |
+
print("Loading as regular model...")
|
| 646 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 647 |
+
model_path,
|
| 648 |
+
torch_dtype=torch.bfloat16,
|
| 649 |
+
device_map="auto",
|
| 650 |
+
local_files_only=True,
|
| 651 |
+
trust_remote_code=True
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
model.eval()
|
| 655 |
+
|
| 656 |
+
# Test with dialogue history examples
|
| 657 |
+
test_cases = [
|
| 658 |
+
{
|
| 659 |
+
"history": "クライアント: こんにちは。最近ストレスを感じています。\nカウンセラー: こんにちは。ストレスを感じていらっしゃるのですね。どのような状況でストレスを感じることが多いですか?\n",
|
| 660 |
+
"current": "クライアント: 仕事が忙しくて、休む時間がありません。"
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"history": "",
|
| 664 |
+
"current": "クライアント: 人間関係で悩んでいます。"
|
| 665 |
+
}
|
| 666 |
+
]
|
| 667 |
+
|
| 668 |
+
print("Testing with complete dialogue history:\n")
|
| 669 |
+
|
| 670 |
+
for i, test_case in enumerate(test_cases, 1):
|
| 671 |
+
print(f"Test Case {i}:")
|
| 672 |
+
print("-" * 40)
|
| 673 |
+
|
| 674 |
+
# Format input with complete history
|
| 675 |
+
if test_case["history"]:
|
| 676 |
+
prompt = f"""### Instruction:
|
| 677 |
+
あなたは専門的な訓練を受けた心理カウンセラーです。
|
| 678 |
+
以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。
|
| 679 |
+
|
| 680 |
+
### Dialogue History:
|
| 681 |
+
{test_case["history"]}{test_case["current"]}
|
| 682 |
+
|
| 683 |
+
### Response:
|
| 684 |
+
"""
|
| 685 |
+
else:
|
| 686 |
+
prompt = f"""### Instruction:
|
| 687 |
+
あなたは専門的な訓練を受けた心理カウンセラーです。
|
| 688 |
+
|
| 689 |
+
### Dialogue History:
|
| 690 |
+
{test_case["current"]}
|
| 691 |
+
|
| 692 |
+
### Response:
|
| 693 |
+
"""
|
| 694 |
+
|
| 695 |
+
# Generate response
|
| 696 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
|
| 697 |
+
inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
|
| 698 |
+
|
| 699 |
+
with torch.no_grad():
|
| 700 |
+
outputs = model.generate(
|
| 701 |
+
**inputs,
|
| 702 |
+
max_new_tokens=150,
|
| 703 |
+
temperature=0.1,  # temperature must be > 0 when do_sample=True
|
| 704 |
+
do_sample=True,
|
| 705 |
+
top_p=0.9,
|
| 706 |
+
pad_token_id=tokenizer.pad_token_id
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 710 |
+
response = response.split("### Response:")[-1].strip() if "### Response:" in response else response
|
| 711 |
+
|
| 712 |
+
# print(f"History Length: {len(test_case['history'].split('\\n')) if test_case['history'] else 0} turns")
|
| 713 |
+
print("History Length: {} turns".format(len(test_case['history'].split('\\n')) if test_case['history'] else 0))
|
| 714 |
+
|
| 715 |
+
print(f"Current Input: {test_case['current']}")
|
| 716 |
+
print(f"Generated Response: {response[:300]}...")
|
| 717 |
+
print()
|
| 718 |
+
|
| 719 |
+
print("="*60)
|
| 720 |
+
|
| 721 |
+
# Main execution
|
| 722 |
+
if __name__ == "__main__":
|
| 723 |
+
import argparse
|
| 724 |
+
|
| 725 |
+
parser = argparse.ArgumentParser(description='Fine-tune LFM model with complete dialogue history')
|
| 726 |
+
parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
|
| 727 |
+
help='Base model name')
|
| 728 |
+
parser.add_argument('--data_path', type=str, default='./kokoro_processed_data',
|
| 729 |
+
help='Path to processed data with complete dialogue history')
|
| 730 |
+
parser.add_argument('--output_dir', type=str, default='./lfm_kokoro_complete',
|
| 731 |
+
help='Output directory for fine-tuned model')
|
| 732 |
+
parser.add_argument('--max_seq_length', type=int, default=2048,
|
| 733 |
+
help='Maximum sequence length for complete dialogues')
|
| 734 |
+
parser.add_argument('--use_4bit', action='store_true',
|
| 735 |
+
help='Use 4-bit quantization')
|
| 736 |
+
parser.add_argument('--test_only', action='store_true',
|
| 737 |
+
help='Only test existing model')
|
| 738 |
+
|
| 739 |
+
args = parser.parse_args()
|
| 740 |
+
|
| 741 |
+
if args.test_only:
|
| 742 |
+
# Test existing model
|
| 743 |
+
test_model_with_complete_history(
|
| 744 |
+
os.path.join(args.output_dir, "final_model")
|
| 745 |
+
)
|
| 746 |
+
else:
|
| 747 |
+
# Check CUDA availability
|
| 748 |
+
if not torch.cuda.is_available():
|
| 749 |
+
print("⚠️ Warning: CUDA is not available. Training will be slow.")
|
| 750 |
+
response = input("Continue? (y/n): ")
|
| 751 |
+
if response.lower() != 'y':
|
| 752 |
+
exit()
|
| 753 |
+
|
| 754 |
+
try:
|
| 755 |
+
# Clear GPU cache
|
| 756 |
+
if torch.cuda.is_available():
|
| 757 |
+
torch.cuda.empty_cache()
|
| 758 |
+
|
| 759 |
+
# Initialize fine-tuner
|
| 760 |
+
print(f"🚀 Initializing fine-tuner for complete dialogue history")
|
| 761 |
+
finetuner = LFMKokoroChatFineTuner(
|
| 762 |
+
model_name=args.model_name,
|
| 763 |
+
use_4bit=args.use_4bit,
|
| 764 |
+
max_seq_length=args.max_seq_length
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# Setup model
|
| 768 |
+
finetuner.setup_model_and_tokenizer()
|
| 769 |
+
|
| 770 |
+
# Load datasets
|
| 771 |
+
finetuner.load_and_process_datasets(args.data_path)
|
| 772 |
+
|
| 773 |
+
# Setup training arguments
|
| 774 |
+
finetuner.setup_training_args(args.output_dir)
|
| 775 |
+
|
| 776 |
+
# Train
|
| 777 |
+
trainer = finetuner.train()
|
| 778 |
+
|
| 779 |
+
# Test the model
|
| 780 |
+
print("\n🧪 Testing the fine-tuned model...")
|
| 781 |
+
test_model_with_complete_history(
|
| 782 |
+
os.path.join(args.output_dir, "final_model")
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
print("\n✅ Fine-tuning with complete dialogue history completed!")
|
| 786 |
+
print(f"📁 Model saved to: {args.output_dir}/final_model")
|
| 787 |
+
print("\n📋 Next steps:")
|
| 788 |
+
print(f"1. Test more: python {__file__} --test_only --output_dir {args.output_dir}")
|
| 789 |
+
print("2. Run benchmarking with complete history support")
|
| 790 |
+
print("3. Deploy for production use")
|
| 791 |
+
|
| 792 |
+
except KeyboardInterrupt:
|
| 793 |
+
print("\n\n⚠️ Training interrupted by user.")
|
| 794 |
+
if wandb.run:
|
| 795 |
+
wandb.finish()
|
| 796 |
+
except Exception as e:
|
| 797 |
+
print(f"\n❌ Error: {e}")
|
| 798 |
+
import traceback
|
| 799 |
+
traceback.print_exc()
|
| 800 |
+
if wandb.run:
|
| 801 |
+
wandb.finish()
|
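For reference, the prompt template used by test_model_with_complete_history above can be factored into a small helper; build_prompt is a name introduced here purely for illustration, and the strings are copied from the script:

def build_prompt(history: str, current: str) -> str:
    # Mirrors the template in test_model_with_complete_history.
    if history:
        instruction = (
            "あなたは専門的な訓練を受けた心理カウンセラーです。\n"
            "以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。"
        )
        dialogue = f"{history}{current}"
    else:
        instruction = "あなたは専門的な訓練を受けた心理カウンセラーです。"
        dialogue = current
    return f"### Instruction:\n{instruction}\n\n### Dialogue History:\n{dialogue}\n\n### Response:\n"

prompt = build_prompt(
    history="クライアント: こんにちは。最近ストレスを感じています。\n",
    current="クライアント: 仕事が忙しくて、休む時間がありません。",
)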
finetune_trl_supervised.py
ADDED
|
@@ -0,0 +1,221 @@
|
"""
Minimal Working Fine-tuning Script - No Complex Dependencies
Filename: finetune_minimal.py
"""

import torch
import os
import json
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np

# Fix the import issues by reinstalling
import subprocess
import sys

def fix_environment():
    """Fix the broken environment"""
    print("Fixing environment...")
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "torchvision"], check=False)
    subprocess.run([sys.executable, "-m", "pip", "install", "--no-deps", "transformers==4.36.0"], check=False)
    subprocess.run([sys.executable, "-m", "pip", "install", "peft==0.7.0", "accelerate==0.25.0"], check=False)

# Uncomment if needed
# fix_environment()

# Now import after fixing
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

class SimpleDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length=1024):
        self.data = []
        with open(data_path, 'r') as f:
            for line in f:
                item = json.loads(line)
                self.data.append(item['text'])

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze()
        }

def train_simple():
    """Simple training without complex dependencies"""

    # Configuration
    model_name = "LiquidAI/LFM2-2.6B"
    data_dir = "./kokoro_processed_data"
    output_dir = "./lfm_minimal_output"
    batch_size = 4
    learning_rate = 2e-4
    num_epochs = 2
    max_length = 1024

    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("Minimal Fine-tuning Script")
    print("="*60)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )

    # Apply LoRA
    print("Applying LoRA...")
    peft_config = LoraConfig(
        r=32,
        lora_alpha=64,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM
    )

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Load dataset
    print("Loading dataset...")
    train_dataset = SimpleDataset(
        os.path.join(data_dir, "train.jsonl"),
        tokenizer,
        max_length
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Training loop
    print(f"\nStarting training for {num_epochs} epochs...")
    model.train()

    global_step = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in progress_bar:
            global_step += 1

            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=input_ids
            )

            loss = outputs.loss
            total_loss += loss.item()

            # Backward pass
            loss.backward()

            # Update weights every 4 steps (gradient accumulation)
            if global_step % 4 == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})

            # Save checkpoint
            if global_step % 500 == 0:
                print(f"\nSaving checkpoint at step {global_step}...")
                model.save_pretrained(os.path.join(output_dir, f"checkpoint-{global_step}"))
                tokenizer.save_pretrained(os.path.join(output_dir, f"checkpoint-{global_step}"))

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

    # Save final model
    print("\nSaving final model...")
    model.save_pretrained(os.path.join(output_dir, "final_model"))
    tokenizer.save_pretrained(os.path.join(output_dir, "final_model"))

    print(f"\n✅ Training complete! Model saved to {output_dir}/final_model")

    # Test the model
    print("\nTesting model...")
    test_model(os.path.join(output_dir, "final_model"))

def test_model(model_path):
    """Test the fine-tuned model"""

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    test_input = "最近ストレスを感じています。"
    prompt = f"""### Instruction:
あなたは心理カウンセラーです。

### Input:
{test_input}

### Response:
"""

    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids.cuda(),
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"\nTest Input: {test_input}")
    print(f"Response: {response.split('### Response:')[-1].strip()}")

if __name__ == "__main__":
    train_simple()
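Note (not part of the uploaded script): the manual loop above steps the optimizer every 4 batches, giving an effective batch size of 4 x 4 = 16, but it does not rescale the loss for accumulation. A common adjustment, assuming the same variable names as above, would be:

    accumulation_steps = 4
    loss = outputs.loss / accumulation_steps  # keep gradient magnitude comparable to one large batch
    loss.backward()
    if global_step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()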
merge_model.py
ADDED
@@ -0,0 +1,74 @@
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from peft import PeftModel
# import torch

# print("Loading base model...")
# base_model = AutoModelForCausalLM.from_pretrained(
#     "./models/LFM2-1.2B",
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     trust_remote_code=True
# )

# print("Loading LoRA adapters...")
# model = PeftModel.from_pretrained(base_model, "./counselor_model/final_model")

# print("Merging adapters with base model...")
# merged_model = model.merge_and_unload()

# print("Saving merged model...")
# merged_model.save_pretrained("./counselor_model-merged", safe_serialization=True)

# tokenizer = AutoTokenizer.from_pretrained("./models/LFM2-1.2B")
# tokenizer.save_pretrained("./counselor_model-merged")

# print("Model merge complete!")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import os

def merge_and_save_model(
    base_model_name: str = "LiquidAI/LFM2-2.6B",
    adapter_path: str = "./lfm_minimal_output/final_model",
    output_path: str = "./merged_counselor_minimal_2b"
):
    """
    Properly merge LoRA weights with base model
    """
    print("Loading base model...")
    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    print("Loading LoRA adapter...")
    # Load the PEFT model (LoRA adapter)
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        torch_dtype=torch.float16,
    )

    print("Merging weights...")
    # Merge LoRA weights with base model
    model = model.merge_and_unload()

    print(f"Saving merged model to {output_path}...")
    # Save the merged model
    model.save_pretrained(output_path)

    # Also save the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    tokenizer.save_pretrained(output_path)

    print("✅ Model merged and saved successfully!")
    return model, tokenizer

# Run the merge
if __name__ == "__main__":
    merge_and_save_model()
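Not part of the upload: a minimal sketch of loading the merged checkpoint for inference, assuming the default output_path above ("./merged_counselor_minimal_2b").

    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch

    # Load the merged (adapter-free) model directory produced by merge_and_save_model()
    tokenizer = AutoTokenizer.from_pretrained("./merged_counselor_minimal_2b")
    model = AutoModelForCausalLM.from_pretrained(
        "./merged_counselor_minimal_2b",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )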
preprocess_kokoro_method.py
ADDED
@@ -0,0 +1,651 @@
"""
Fixed Data Preprocessing for directory of JSON files with client-counselor dialogues
Following KokoroChat methodology with COMPLETE dialogue history
Filename: preprocess_kokoro_directory_fixed.py
"""

import json
import os
from typing import List, Dict, Tuple, Optional, Any
from tqdm import tqdm
import random
from collections import defaultdict
import numpy as np
from pathlib import Path
import glob

class KokoroChatDirectoryPreprocessor:
    def __init__(self,
                 input_dir: str = "./raw_counseling_data",
                 output_dir: str = "./kokoro_processed_data",
                 min_score: int = 70,
                 train_ratio: float = 0.8,
                 val_ratio: float = 0.1,
                 test_ratio: float = 0.1):
        """
        Initialize preprocessor for directory of JSON files

        Args:
            input_dir: Directory containing JSON files with conversations
            output_dir: Directory to save processed data
            min_score: Minimum score threshold for filtering (if scores exist)
            train_ratio: Ratio for training data
            val_ratio: Ratio for validation data
            test_ratio: Ratio for test data
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.min_score = min_score
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        os.makedirs(output_dir, exist_ok=True)

        # Track statistics
        self.total_conversations = 0
        self.total_utterances = 0
        self.skipped_files = 0

    def load_json_file(self, filepath: str) -> Optional[Dict]:
        """Load a single JSON file"""
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data
        except Exception as e:
            print(f"⚠️ Error loading {filepath}: {e}")
            self.skipped_files += 1
            return None

    def safe_get_value(self, obj: Any, default: Any = None) -> Any:
        """Safely get a value, handling nested dicts and lists"""
        if isinstance(obj, dict):
            # If it's a dict, try to get a meaningful string representation
            if 'name' in obj:
                return str(obj['name'])
            elif 'value' in obj:
                return str(obj['value'])
            elif 'text' in obj:
                return str(obj['text'])
            else:
                # Return first string value found or convert to string
                for v in obj.values():
                    if isinstance(v, str):
                        return v
                return str(list(obj.values())[0]) if obj else default
        elif isinstance(obj, list):
            # If it's a list, join elements or return first element
            if obj:
                return str(obj[0]) if len(obj) == 1 else ', '.join(str(x) for x in obj)
            return default
        elif obj is None:
            return default
        else:
            return str(obj)

    def extract_dialogue_from_json(self, data: Dict, filepath: str) -> List[Dict]:
        """
        Extract dialogue from various JSON formats
        Handles different possible structures
        """
        conversations = []

        # Try different possible structures
        if isinstance(data, list):
            # If the JSON is directly a list of utterances
            conversations.append({
                'dialogue': data,
                'id': os.path.basename(filepath).replace('.json', ''),
                'score': 100,  # Default score
                'topic': 'general',
                'source_file': filepath
            })

        elif isinstance(data, dict):
            # Extract score safely
            score = data.get('score', 100)
            if isinstance(score, dict):
                score = score.get('value', 100) if 'value' in score else 100
            try:
                score = float(score)
            except:
                score = 100

            # Extract topic safely
            topic = self.safe_get_value(data.get('topic', 'general'), 'general')

            # Check for different possible keys
            if 'dialogue' in data:
                conversations.append({
                    'dialogue': data['dialogue'],
                    'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
                    'score': score,
                    'topic': topic,
                    'source_file': filepath
                })

            elif 'messages' in data:
                conversations.append({
                    'dialogue': data['messages'],
                    'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
                    'score': score,
                    'topic': topic,
                    'source_file': filepath
                })

            elif 'utterances' in data:
                conversations.append({
                    'dialogue': data['utterances'],
                    'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
                    'score': score,
                    'topic': topic,
                    'source_file': filepath
                })

            elif 'conversations' in data:
                # Multiple conversations in one file
                for conv in data['conversations']:
                    if isinstance(conv, dict) and any(key in conv for key in ['dialogue', 'messages', 'utterances']):
                        dialogue_key = 'dialogue' if 'dialogue' in conv else ('messages' if 'messages' in conv else 'utterances')

                        # Extract score and topic safely for each conversation
                        conv_score = conv.get('score', score)
                        if isinstance(conv_score, dict):
                            conv_score = conv_score.get('value', 100) if 'value' in conv_score else 100
                        try:
                            conv_score = float(conv_score)
                        except:
                            conv_score = 100

                        conv_topic = self.safe_get_value(conv.get('topic', topic), 'general')

                        conversations.append({
                            'dialogue': conv[dialogue_key],
                            'id': conv.get('id', f"{os.path.basename(filepath)}_{len(conversations)}"),
                            'score': conv_score,
                            'topic': conv_topic,
                            'source_file': filepath
                        })

            else:
                # Try to find any list that looks like dialogue
                for key, value in data.items():
                    if isinstance(value, list) and len(value) > 0:
                        # Check if it looks like dialogue data
                        if isinstance(value[0], dict) and any(k in value[0] for k in ['speaker', 'role', 'text', 'content', 'utterance']):
                            conversations.append({
                                'dialogue': value,
                                'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
                                'score': score,
                                'topic': topic,
                                'source_file': filepath
                            })
                            break

        return conversations

    def normalize_utterance(self, utterance: Dict) -> Optional[Dict]:
        """
        Normalize utterance format from various possible structures
        Returns: {'speaker': str, 'text': str} or None
        """
        # Determine speaker
        speaker = None
        if 'speaker' in utterance:
            speaker = utterance['speaker']
        elif 'role' in utterance:
            speaker = utterance['role']
        elif 'sender' in utterance:
            speaker = utterance['sender']
        elif 'from' in utterance:
            speaker = utterance['from']
        elif 'type' in utterance:
            speaker = utterance['type']

        # Determine text content
        text = None
        if 'text' in utterance:
            text = utterance['text']
        elif 'content' in utterance:
            text = utterance['content']
        elif 'message' in utterance:
            text = utterance['message']
        elif 'utterance' in utterance:
            text = utterance['utterance']
        elif 'response' in utterance:
            text = utterance['response']

        if speaker and text:
            # Normalize speaker labels
            speaker_lower = str(speaker).lower()
            if speaker_lower in ['client', 'user', 'patient', 'クライアント', '相談者', 'c']:
                normalized_speaker = 'client'
            elif speaker_lower in ['counselor', 'therapist', 'assistant', 'カウンセラー', '相談員', 's', 'system']:
                normalized_speaker = 'counselor'
            else:
                # Try to infer from position or content
                normalized_speaker = 'client' if 'client' in speaker_lower else 'counselor'

            return {
                'speaker': normalized_speaker,
                'text': str(text).strip()
            }

        return None

    def merge_consecutive_utterances(self, dialogue: List[Dict]) -> List[Dict]:
        """
        Merge consecutive utterances from the same speaker
        Following KokoroChat paper methodology
        """
        if not dialogue:
            return []

        merged = []
        current_utterance = None

        for utt in dialogue:
            normalized = self.normalize_utterance(utt)
            if not normalized:
                continue

            if current_utterance is None:
                current_utterance = normalized
            elif current_utterance['speaker'] == normalized['speaker']:
                # Same speaker - merge utterances
                current_utterance['text'] += ' ' + normalized['text']
            else:
                # Different speaker - save current and start new
                merged.append(current_utterance)
                current_utterance = normalized

        # Don't forget the last utterance
        if current_utterance:
            merged.append(current_utterance)

        return merged

    def create_training_examples(self, conversation: Dict) -> List[Dict]:
        """
        Create training examples with COMPLETE dialogue history
        Following the paper: Dt = {uC1, uS2, uC3, ..., uCt} -> uSt+1
        """
        examples = []

        # Get dialogue
        dialogue = conversation.get('dialogue', [])
        if not dialogue:
            return []

        # Merge consecutive utterances from same speaker
        merged_dialogue = self.merge_consecutive_utterances(dialogue)

        if not merged_dialogue:
            return []

        # Create examples with COMPLETE history
        for i in range(len(merged_dialogue)):
            current = merged_dialogue[i]

            # Only create examples where counselor responds
            if current['speaker'] == 'counselor':
                # Get COMPLETE dialogue history from beginning
                complete_history = merged_dialogue[:i]

                # Skip if no history or if history doesn't start with client
                if not complete_history or complete_history[0]['speaker'] != 'client':
                    continue

                # Ensure topic is a string
                topic = conversation.get('topic', 'general')
                if not isinstance(topic, str):
                    topic = self.safe_get_value(topic, 'general')

                # Create training example
                example = {
                    'dialogue_history': complete_history,
                    'response': current['text'],
                    'score': conversation.get('score', 100),
                    'topic': topic,
                    'conversation_id': conversation.get('id', 'unknown'),
                    'source_file': conversation.get('source_file', 'unknown'),
                    'turn_number': i,
                    'history_length': len(complete_history)
                }

                examples.append(example)

        return examples

    def format_for_training(self, example: Dict, format_type: str = 'simple') -> str:
        """
        Format example for training

        Args:
            format_type: 'simple' or 'llama' format
        """
        # Build complete dialogue history
        history_text = ""
        for turn in example['dialogue_history']:
            speaker = "クライアント" if turn['speaker'] == 'client' else "カウンセラー"
            history_text += f"{speaker}: {turn['text']}\n"

        if format_type == 'llama':
            # Llama-style format with special tokens
            formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
あなたは専門的な訓練を受けた心理カウンセラーです。クライアントの感情に共感し、適切な支援を提供してください。
これまでの対話履歴全体を考慮して、適切な応答を生成してください。<|eot_id|>

<|start_header_id|>user<|end_header_id|>
以下は、クライアントとカウンセラーの完全な対話履歴です。
この履歴全体を踏まえて、次のカウンセラーの応答を生成してください。

完全な対話履歴:
{history_text}
次のカウンセラーの応答を生成してください。<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
{example['response']}<|eot_id|>"""

        else:
            # Simple format for models without special tokens
            formatted = f"""### Instruction:
あなたは専門的な訓練を受けた心理カウンセラーです。
以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。

### Dialogue History:
{history_text}
### Response:
{example['response']}"""

        return formatted

    def process_directory(self, format_type: str = 'simple'):
        """Process all JSON files in the input directory"""
        print(f"🔍 Scanning directory: {self.input_dir}")

        # Find all JSON files
        json_files = []
        for pattern in ['*.json', '*.jsonl']:
            json_files.extend(glob.glob(os.path.join(self.input_dir, '**', pattern), recursive=True))

        print(f"Found {len(json_files)} JSON files")

        if not json_files:
            print("❌ No JSON files found in the directory!")
            return

        # Process each file
        all_conversations = []

        for filepath in tqdm(json_files, desc="Loading JSON files"):
            # Handle both .json and .jsonl files
            if filepath.endswith('.jsonl'):
                # JSONL file - each line is a separate JSON object
                with open(filepath, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f):
                        try:
                            data = json.loads(line)
                            conversations = self.extract_dialogue_from_json(data, f"{filepath}_line{line_num}")
                            all_conversations.extend(conversations)
                        except:
                            continue
            else:
                # Regular JSON file
                data = self.load_json_file(filepath)
                if data:
                    conversations = self.extract_dialogue_from_json(data, filepath)
                    all_conversations.extend(conversations)

        print(f"✅ Loaded {len(all_conversations)} conversations from {len(json_files) - self.skipped_files} files")
        print(f"⚠️ Skipped {self.skipped_files} files due to errors")

        # Filter by score
        conversations_before_filter = len(all_conversations)
        filtered_conversations = [
            conv for conv in all_conversations
            if conv.get('score', 100) >= self.min_score
        ]
        conversations_after_filter = len(filtered_conversations)

        print(f"📊 Score filtering (>= {self.min_score}):")
        print(f"   Before: {conversations_before_filter} conversations")
        print(f"   After: {conversations_after_filter} conversations")
        print(f"   Filtered out: {conversations_before_filter - conversations_after_filter} conversations")

        # Create training examples
        all_examples = []
        history_lengths = []

        for conv in tqdm(filtered_conversations, desc="Creating training examples"):
            examples = self.create_training_examples(conv)
            all_examples.extend(examples)
            history_lengths.extend([ex['history_length'] for ex in examples])

        if not all_examples:
            print("❌ No training examples created!")
            return

        print(f"✅ Created {len(all_examples)} training examples from {len(filtered_conversations)} conversations")
        print(f"📊 Dialogue history statistics:")
        print(f"   - Mean length: {np.mean(history_lengths):.1f} turns")
        print(f"   - Median length: {np.median(history_lengths):.1f} turns")
        print(f"   - Max length: {max(history_lengths)} turns")
        print(f"   - Min length: {min(history_lengths)} turns")

        # Shuffle and split
        random.shuffle(all_examples)

        train_size = int(self.train_ratio * len(all_examples))
        val_size = int(self.val_ratio * len(all_examples))

        train_data = all_examples[:train_size]
        val_data = all_examples[train_size:train_size + val_size]
        test_data = all_examples[train_size + val_size:]

        print(f"\n📂 Split sizes:")
        print(f"   Train: {len(train_data)} ({self.train_ratio*100:.0f}%)")
        print(f"   Val: {len(val_data)} ({self.val_ratio*100:.0f}%)")
        print(f"   Test: {len(test_data)} ({self.test_ratio*100:.0f}%)")

        # Save splits
        self.save_split(train_data, 'train', format_type)
        self.save_split(val_data, 'val', format_type)
        self.save_split(test_data, 'test', format_type)

        # Save statistics
        self.save_statistics(
            train_data, val_data, test_data,
            all_conversations, filtered_conversations,
            history_lengths
        )

        print(f"\n✅ Processing complete! Data saved to {self.output_dir}")

    def save_split(self, data: List[Dict], split_name: str, format_type: str = 'simple'):
        """Save processed data split"""
        output_file = os.path.join(self.output_dir, f"{split_name}.jsonl")

        with open(output_file, 'w', encoding='utf-8') as f:
            for example in tqdm(data, desc=f"Saving {split_name} data"):
                formatted_text = self.format_for_training(example, format_type)

                # Ensure topic is string
                topic = example.get('topic', 'general')
                if not isinstance(topic, str):
                    topic = self.safe_get_value(topic, 'general')

                output_item = {
                    'text': formatted_text,
                    'dialogue_history': example['dialogue_history'],
                    'response': example['response'],
                    'score': example['score'],
                    'topic': topic,
                    'conversation_id': example['conversation_id'],
                    'source_file': example['source_file'],
                    'turn_number': example['turn_number'],
                    'history_length': example['history_length']
                }

                f.write(json.dumps(output_item, ensure_ascii=False) + '\n')

        print(f"✅ Saved {split_name} data to {output_file}")

    def save_statistics(self, train_data, val_data, test_data,
                        all_conversations, filtered_conversations, history_lengths):
        """Save comprehensive statistics"""
        # Calculate topic distribution (safely)
        topic_counts = defaultdict(int)
        for example in train_data:
            topic = example.get('topic', 'general')
            if not isinstance(topic, str):
                topic = self.safe_get_value(topic, 'general')
            topic_counts[topic] += 1

        # Calculate source file distribution
        source_counts = defaultdict(int)
        for example in train_data:
            source_file = os.path.basename(example.get('source_file', 'unknown'))
            source_counts[source_file] += 1

        # Score statistics for filtered conversations
        scores = [conv.get('score', 100) for conv in filtered_conversations]

        stats = {
            'preprocessing_info': {
                'input_directory': self.input_dir,
                'output_directory': self.output_dir,
                'total_files_processed': len(set(conv.get('source_file', 'unknown') for conv in all_conversations)),
                'total_conversations_loaded': len(all_conversations),
                'conversations_after_filtering': len(filtered_conversations),
                'conversations_filtered_out': len(all_conversations) - len(filtered_conversations),
                'total_training_examples': len(train_data) + len(val_data) + len(test_data),
                'min_score_threshold': self.min_score,
                'methodology': 'KokoroChat paper - complete dialogue history'
            },
            'score_filtering': {
                'threshold': self.min_score,
                'before_filtering': len(all_conversations),
                'after_filtering': len(filtered_conversations),
                'filtered_out': len(all_conversations) - len(filtered_conversations),
                'percentage_kept': (len(filtered_conversations) / len(all_conversations) * 100) if all_conversations else 0
            },
            'score_statistics': {
                'mean': float(np.mean(scores)),
                'std': float(np.std(scores)),
                'min': float(min(scores)),
                'max': float(max(scores)),
                'median': float(np.median(scores)),
                'percentile_25': float(np.percentile(scores, 25)),
                'percentile_75': float(np.percentile(scores, 75))
            },
            'split_sizes': {
                'train': len(train_data),
                'val': len(val_data),
                'test': len(test_data),
                'train_ratio': self.train_ratio,
                'val_ratio': self.val_ratio,
                'test_ratio': self.test_ratio
            },
            'dialogue_history_stats': {
                'mean_length': float(np.mean(history_lengths)),
                'std_length': float(np.std(history_lengths)),
                'min_length': int(min(history_lengths)),
                'max_length': int(max(history_lengths)),
                'median_length': float(np.median(history_lengths)),
                'percentile_25': float(np.percentile(history_lengths, 25)),
                'percentile_75': float(np.percentile(history_lengths, 75)),
                'percentile_95': float(np.percentile(history_lengths, 95))
            },
            'topic_distribution': dict(list(topic_counts.items())[:20]),  # Top 20 topics
            'source_file_distribution': dict(list(source_counts.items())[:20]),  # Top 20 files
            'history_length_bins': {
                '1-5_turns': sum(1 for l in history_lengths if l <= 5),
                '6-10_turns': sum(1 for l in history_lengths if 5 < l <= 10),
                '11-15_turns': sum(1 for l in history_lengths if 10 < l <= 15),
                '16-20_turns': sum(1 for l in history_lengths if 15 < l <= 20),
                '21-30_turns': sum(1 for l in history_lengths if 20 < l <= 30),
                '31-50_turns': sum(1 for l in history_lengths if 30 < l <= 50),
                '50+_turns': sum(1 for l in history_lengths if l > 50)
            }
        }

        stats_file = os.path.join(self.output_dir, 'dataset_stats.json')
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        print(f"\n📊 Statistics saved to {stats_file}")

        # Print summary
        print("\n" + "="*70)
        print("📈 DATASET STATISTICS SUMMARY")
        print("="*70)
        print(f"Files processed: {stats['preprocessing_info']['total_files_processed']}")
        print(f"Conversations loaded: {stats['preprocessing_info']['total_conversations_loaded']}")
        print(f"After score filtering (>={self.min_score}): {stats['preprocessing_info']['conversations_after_filtering']}")
        print(f"Training examples created: {stats['preprocessing_info']['total_training_examples']}")
        print(f"\nScore Statistics (after filtering):")
        print(f"   Mean: {stats['score_statistics']['mean']:.1f}")
        print(f"   Median: {stats['score_statistics']['median']:.1f}")
        print(f"   Range: {stats['score_statistics']['min']:.0f} - {stats['score_statistics']['max']:.0f}")
        print(f"\nDialogue History Length Distribution:")
        for bin_name, count in stats['history_length_bins'].items():
            percentage = (count / len(history_lengths)) * 100 if history_lengths else 0
            print(f"   {bin_name}: {count} ({percentage:.1f}%)")
        print("="*70)


def main():
    import argparse

    parser = argparse.ArgumentParser(
        description='Preprocess directory of JSON files with counseling dialogues'
    )
    parser.add_argument(
        '--input_dir',
        type=str,
        default='./KokoroChat/kokorochat_dialogues',
        help='Directory containing JSON files with conversations'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='./kokoro_processed_data',
        help='Output directory for processed data'
    )
    parser.add_argument(
        '--min_score',
        type=int,
        default=70,
        help='Minimum score threshold (if scores exist in data)'
    )
    parser.add_argument(
        '--format',
        type=str,
        choices=['simple', 'llama'],
        default='simple',
        help='Output format type'
    )

    args = parser.parse_args()

    # Initialize preprocessor
    preprocessor = KokoroChatDirectoryPreprocessor(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        min_score=args.min_score
    )

    print("🚀 Starting preprocessing with COMPLETE dialogue history")
    print("   Following KokoroChat paper methodology")
    print("="*70)

    # Process directory
    preprocessor.process_directory(format_type=args.format)

    print("\n✅ Preprocessing complete!")


if __name__ == "__main__":
    main()
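Not part of the upload: a hedged usage sketch of the preprocessor defined above, equivalent to its CLI defaults (python preprocess_kokoro_method.py --input_dir ./KokoroChat/kokorochat_dialogues --min_score 70 --format simple).

    from preprocess_kokoro_method import KokoroChatDirectoryPreprocessor

    # Builds train/val/test JSONL splits plus dataset_stats.json in output_dir
    preprocessor = KokoroChatDirectoryPreprocessor(
        input_dir="./KokoroChat/kokorochat_dialogues",
        output_dir="./kokoro_processed_data",
        min_score=70,
    )
    preprocessor.process_directory(format_type="simple")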
score_analysis_threshold_60.png
ADDED
Git LFS Details

score_distribution.png
ADDED
Git LFS Details
training_config.json
ADDED
@@ -0,0 +1,23 @@
{
  "model_name_or_path": "LiquidAI/LFM2-2.6B",
  "use_lora": true,
  "lora_r": 64,
  "lora_alpha": 128,
  "lora_dropout": 0.05,
  "data_path": "./kokoro_processed_data",
  "max_seq_length": 2048,
  "response_template": "### Response:",
  "output_dir": "./lfm_trl_finetuned",
  "num_train_epochs": 3,
  "per_device_train_batch_size": 4,
  "per_device_eval_batch_size": 4,
  "gradient_accumulation_steps": 4,
  "learning_rate": 2e-4,
  "warmup_ratio": 0.1,
  "logging_steps": 10,
  "save_steps": 100,
  "eval_steps": 100,
  "bf16": true,
  "tf32": true,
  "seed": 42
}
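Not part of the upload: a hedged sketch of how a training script might consume this config, mapping a subset of its keys onto transformers.TrainingArguments (the actual consumer script in this repo may differ).

    import json
    from transformers import TrainingArguments

    with open("training_config.json") as f:
        cfg = json.load(f)

    # Only the keys that correspond to TrainingArguments fields are forwarded here;
    # LoRA- and data-related keys would be handled separately by the training script.
    args = TrainingArguments(
        output_dir=cfg["output_dir"],
        num_train_epochs=cfg["num_train_epochs"],
        per_device_train_batch_size=cfg["per_device_train_batch_size"],
        gradient_accumulation_steps=cfg["gradient_accumulation_steps"],
        learning_rate=cfg["learning_rate"],
        warmup_ratio=cfg["warmup_ratio"],
        logging_steps=cfg["logging_steps"],
        save_steps=cfg["save_steps"],
        bf16=cfg["bf16"],
        seed=cfg["seed"],
    )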