Techiiot committed on
Commit 27c46c6 · verified · 1 Parent(s): 76d2998

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ benchmark_results.png filter=lfs diff=lfs merge=lfs -text
37
+ score_analysis_threshold_60.png filter=lfs diff=lfs merge=lfs -text
38
+ score_distribution.png filter=lfs diff=lfs merge=lfs -text
benchmark.py ADDED
@@ -0,0 +1,1201 @@
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+ import json
5
+ from typing import Dict, List, Tuple
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from sklearn.metrics import accuracy_score, f1_score
9
+ import evaluate
10
+ from datasets import load_dataset
11
+ import pandas as pd
12
+ import matplotlib.pyplot as plt
13
+ import seaborn as sns
14
+
15
+ class CounselorBenchmark:
16
+ def __init__(self, base_model_path: str, finetuned_model_path: str):
17
+ """
18
+ Initialize benchmark suite for counselor models
19
+ """
20
+ self.base_model_path = base_model_path
21
+ self.finetuned_model_path = finetuned_model_path
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ # Load evaluation metrics
25
+ self.bleu = evaluate.load("sacrebleu")
26
+ self.rouge = evaluate.load("rouge")
27
+ self.bertscore = evaluate.load("bertscore")
28
+
29
+ def load_models(self):
30
+ """Load both base and fine-tuned models for comparison"""
31
+
32
+ # Load base model
33
+ print("Loading base model...")
34
+ self.base_tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
35
+ self.base_model = AutoModelForCausalLM.from_pretrained(
36
+ self.base_model_path,
37
+ torch_dtype=torch.bfloat16,
38
+ device_map="auto"
39
+ )
40
+
41
+ # Load fine-tuned model
42
+ print("Loading fine-tuned model...")
43
+ self.ft_tokenizer = AutoTokenizer.from_pretrained(self.finetuned_model_path)
44
+ self.ft_model = AutoModelForCausalLM.from_pretrained(
45
+ self.finetuned_model_path,
46
+ torch_dtype=torch.bfloat16,
47
+ device_map="auto"
48
+ )
49
+
50
+ def generate_response(self, model, tokenizer, prompt: str, max_length: int = 256):
51
+ """Generate response from model"""
52
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
53
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
54
+
55
+ with torch.no_grad():
56
+ outputs = model.generate(
57
+ **inputs,
58
+ max_new_tokens=max_length,
59
+ temperature=0.7,
60
+ do_sample=True,
61
+ top_p=0.9,
62
+ repetition_penalty=1.1
63
+ )
64
+
65
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
+ # Extract only the generated part
67
+ response = response[len(prompt):].strip()
68
+ return response
69
+
70
+ def evaluate_empathy_score(self, response: str) -> float:
71
+ """
72
+ Evaluate empathy in counselor response
73
+ Custom metric based on Japanese counseling keywords
74
+ """
75
+ empathy_keywords = [
76
+ 'わかります', '理解', '共感', '気持ち', '感じ',
77
+ 'つらい', '大変', 'お察し', '心配', '支援'
78
+ ]
79
+
80
+ score = sum(1 for keyword in empathy_keywords if keyword in response)
81
+ return min(score / len(empathy_keywords), 1.0)
82
+
83
+ def evaluate_response_quality(self, response: str) -> Dict[str, float]:
84
+ """
85
+ Comprehensive response quality evaluation
86
+ """
87
+ metrics = {}
88
+
89
+ # Length appropriateness (not too short, not too long)
90
+ response_length = len(response)
91
+ if 50 <= response_length <= 300:
92
+ metrics['length_score'] = 1.0
93
+ elif response_length < 50:
94
+ metrics['length_score'] = response_length / 50
95
+ else:
96
+ metrics['length_score'] = max(0, 1 - (response_length - 300) / 500)
97
+
98
+ # Question engagement (does counselor ask clarifying questions?)
99
+ metrics['question_score'] = 1.0 if '?' in response or 'か?' in response else 0.0
100
+
101
+ # Supportive language
102
+ support_phrases = ['大丈夫', '一緒に', '支援', 'サポート', '助け']
103
+ metrics['support_score'] = sum(1 for phrase in support_phrases if phrase in response) / len(support_phrases)
104
+
105
+ # Empathy score
106
+ metrics['empathy_score'] = self.evaluate_empathy_score(response)
107
+
108
+ return metrics
109
+
110
+ def benchmark_on_test_set(self, test_data_path: str, num_samples: int = 100):
111
+ """
112
+ Run comprehensive benchmark on test set
113
+ """
114
+ # Load test data
115
+ test_dataset = load_dataset('json', data_files=test_data_path, split='train')
116
+ test_samples = test_dataset.select(range(min(num_samples, len(test_dataset))))
117
+
118
+ results = {
119
+ 'base_model': {'responses': [], 'metrics': []},
120
+ 'finetuned_model': {'responses': [], 'metrics': []}
121
+ }
122
+
123
+ print(f"Evaluating on {len(test_samples)} test samples...")
124
+
125
+ for sample in tqdm(test_samples):
126
+ prompt = sample['text'].split('### Response:')[0] + '### Response:'
127
+ reference = sample['text'].split('### Response:')[1].strip() if '### Response:' in sample['text'] else ""
128
+
129
+ # Generate responses
130
+ base_response = self.generate_response(self.base_model, self.base_tokenizer, prompt)
131
+ ft_response = self.generate_response(self.ft_model, self.ft_tokenizer, prompt)
132
+
133
+ # Store responses
134
+ results['base_model']['responses'].append(base_response)
135
+ results['finetuned_model']['responses'].append(ft_response)
136
+
137
+ # Evaluate quality
138
+ base_metrics = self.evaluate_response_quality(base_response)
139
+ ft_metrics = self.evaluate_response_quality(ft_response)
140
+
141
+ results['base_model']['metrics'].append(base_metrics)
142
+ results['finetuned_model']['metrics'].append(ft_metrics)
143
+
144
+ return results
145
+
146
+ def calculate_aggregate_metrics(self, results: Dict) -> Dict:
147
+ """Calculate aggregate metrics for comparison"""
148
+ aggregate = {}
149
+
150
+ for model_name in ['base_model', 'finetuned_model']:
151
+ model_metrics = results[model_name]['metrics']
152
+
153
+ aggregate[model_name] = {}
154
+
155
+ # Calculate average for each metric
156
+ metric_names = model_metrics[0].keys() if model_metrics else []
157
+
158
+ for metric in metric_names:
159
+ values = [m[metric] for m in model_metrics]
160
+ aggregate[model_name][metric] = {
161
+ 'mean': np.mean(values),
162
+ 'std': np.std(values),
163
+ 'min': np.min(values),
164
+ 'max': np.max(values)
165
+ }
166
+
167
+ return aggregate
168
+
169
+ def generate_comparison_report(self, results: Dict, aggregate: Dict):
170
+ """Generate detailed comparison report"""
171
+
172
+ report = []
173
+ report.append("=" * 80)
174
+ report.append("COUNSELOR MODEL BENCHMARK REPORT")
175
+ report.append("=" * 80)
176
+ report.append("")
177
+
178
+ # Overall performance comparison
179
+ report.append("PERFORMANCE COMPARISON:")
180
+ report.append("-" * 40)
181
+
182
+ for metric in aggregate['base_model'].keys():
183
+ base_score = aggregate['base_model'][metric]['mean']
184
+ ft_score = aggregate['finetuned_model'][metric]['mean']
185
+ improvement = ((ft_score - base_score) / base_score * 100) if base_score > 0 else 0
186
+
187
+ report.append(f"\n{metric.upper()}:")
188
+ report.append(f" Base Model: {base_score:.3f} (±{aggregate['base_model'][metric]['std']:.3f})")
189
+ report.append(f" Fine-tuned Model: {ft_score:.3f} (±{aggregate['finetuned_model'][metric]['std']:.3f})")
190
+ report.append(f" Improvement: {improvement:+.1f}%")
191
+
192
+ # Calculate overall score
193
+ base_overall = np.mean([aggregate['base_model'][m]['mean'] for m in aggregate['base_model']])
194
+ ft_overall = np.mean([aggregate['finetuned_model'][m]['mean'] for m in aggregate['finetuned_model']])
195
+ overall_improvement = ((ft_overall - base_overall) / base_overall * 100) if base_overall > 0 else 0
196
+
197
+ report.append("\n" + "=" * 40)
198
+ report.append("OVERALL PERFORMANCE:")
199
+ report.append(f" Base Model: {base_overall:.3f}")
200
+ report.append(f" Fine-tuned Model: {ft_overall:.3f}")
201
+ report.append(f" Overall Improvement: {overall_improvement:+.1f}%")
202
+ report.append("=" * 40)
203
+
204
+ return "\n".join(report)
205
+
206
+ def visualize_results(self, aggregate: Dict):
207
+ """Create visualization of benchmark results"""
208
+
209
+ # Prepare data for plotting
210
+ metrics = list(aggregate['base_model'].keys())
211
+ base_scores = [aggregate['base_model'][m]['mean'] for m in metrics]
212
+ ft_scores = [aggregate['finetuned_model'][m]['mean'] for m in metrics]
213
+
214
+ # Create comparison plot
215
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
216
+
217
+ # Bar plot comparison
218
+ x = np.arange(len(metrics))
219
+ width = 0.35
220
+
221
+ ax1.bar(x - width/2, base_scores, width, label='Base Model', color='lightblue')
222
+ ax1.bar(x + width/2, ft_scores, width, label='Fine-tuned Model', color='darkblue')
223
+ ax1.set_xlabel('Metrics')
224
+ ax1.set_ylabel('Score')
225
+ ax1.set_title('Model Performance Comparison')
226
+ ax1.set_xticks(x)
227
+ ax1.set_xticklabels(metrics, rotation=45, ha='right')
228
+ ax1.legend()
229
+ ax1.grid(True, alpha=0.3)
230
+
231
+ # Improvement percentage plot
232
+ improvements = [((ft - base) / base * 100) if base > 0 else 0
233
+ for base, ft in zip(base_scores, ft_scores)]
234
+
235
+ colors = ['green' if imp > 0 else 'red' for imp in improvements]
236
+ ax2.bar(metrics, improvements, color=colors, alpha=0.7)
237
+ ax2.set_xlabel('Metrics')
238
+ ax2.set_ylabel('Improvement (%)')
239
+ ax2.set_title('Fine-tuning Improvement over Base Model')
240
+ ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
241
+ ax2.set_xticklabels(metrics, rotation=45, ha='right')
242
+ ax2.grid(True, alpha=0.3)
243
+
244
+ plt.tight_layout()
245
+ plt.savefig('benchmark_results.png', dpi=300, bbox_inches='tight')
246
+ plt.show()
247
+
248
+ print("Visualization saved as 'benchmark_results.png'")
249
+
250
+ # Run benchmarking
251
+ if __name__ == "__main__":
252
+ # Initialize benchmark
253
+ benchmark = CounselorBenchmark(
254
+ base_model_path="./models/LFM2-2.6B",
255
+ finetuned_model_path="./merged_counselor_mode_2b"
256
+ )
257
+
258
+ # Load models
259
+ benchmark.load_models()
260
+
261
+ # Run benchmark
262
+ print("Running benchmark evaluation...")
263
+ results = benchmark.benchmark_on_test_set("./processed_data_score80/test.jsonl", num_samples=100)
264
+
265
+ # Calculate aggregate metrics
266
+ aggregate = benchmark.calculate_aggregate_metrics(results)
267
+
268
+ # Generate report
269
+ report = benchmark.generate_comparison_report(results, aggregate)
270
+ print(report)
271
+
272
+ # Save report
273
+ with open("benchmark_report_2b.txt", "w") as f:
274
+ f.write(report)
275
+
276
+ # Visualize results
277
+ benchmark.visualize_results(aggregate)
278
+
279
+ print("\nBenchmarking completed! Check 'benchmark_report.txt' for detailed results.")
280
+
281
+
282
+ ####################
283
+
284
+ # import torch
285
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
286
+ # from peft import PeftModel, PeftConfig
287
+ # import numpy as np
288
+ # from typing import List, Dict, Tuple, Optional
289
+ # import json
290
+ # from tqdm import tqdm
291
+ # import os
292
+ # import gc
293
+ # import warnings
294
+ # from datetime import datetime
295
+ # import pandas as pd
296
+ # import matplotlib.pyplot as plt
297
+ # import seaborn as sns
298
+ # from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
299
+ # from rouge_score import rouge_scorer
300
+ # import nltk
301
+ # from collections import defaultdict
302
+
303
+ # # Download required NLTK data
304
+ # try:
305
+ # nltk.download('punkt', quiet=True)
306
+ # except:
307
+ # pass
308
+
309
+ # warnings.filterwarnings('ignore')
310
+
311
+ # class AdvancedCounselorBenchmark:
312
+ # def __init__(self,
313
+ # base_model_name: str = "LiquidAI/LFM2-1.2B",
314
+ # finetuned_model_path: str = "./counselor_model/best_model",
315
+ # merged_model_path: str = "./merged_counselor_model",
316
+ # test_data_path: str = "./processed_data_score70/test.jsonl",
317
+ # device: str = None):
318
+ # """
319
+ # Initialize advanced benchmark suite with BLEU and ROUGE metrics
320
+
321
+ # Args:
322
+ # base_model_name: Name/path of base model
323
+ # finetuned_model_path: Path to fine-tuned LoRA adapter
324
+ # merged_model_path: Path to save/load merged model
325
+ # test_data_path: Path to test dataset with reference responses
326
+ # device: Device to run on (cuda/cpu)
327
+ # """
328
+ # self.base_model_name = base_model_name
329
+ # self.finetuned_model_path = finetuned_model_path
330
+ # self.merged_model_path = merged_model_path
331
+ # self.test_data_path = test_data_path
332
+ # self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
333
+
334
+ # print(f"🔧 Initializing Advanced Benchmark Suite")
335
+ # print(f" Device: {self.device}")
336
+ # if self.device == "cuda":
337
+ # print(f" GPU: {torch.cuda.get_device_name(0)}")
338
+ # print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
339
+
340
+ # # Initialize ROUGE scorer
341
+ # self.rouge_scorer = rouge_scorer.RougeScorer(
342
+ # ['rouge1', 'rouge2', 'rougeL'],
343
+ # use_stemmer=False, # Set to False for Japanese
344
+ # lang='japanese'
345
+ # )
346
+
347
+ # # Smoothing function for BLEU scores
348
+ # self.smoothing = SmoothingFunction().method1
349
+
350
+ # self.results = {}
351
+
352
+ # def load_test_data(self) -> List[Dict]:
353
+ # """Load test dataset with reference responses"""
354
+ # print(f"\n📚 Loading test data from {self.test_data_path}")
355
+
356
+ # test_data = []
357
+ # if os.path.exists(self.test_data_path):
358
+ # with open(self.test_data_path, 'r', encoding='utf-8') as f:
359
+ # for line in f:
360
+ # data = json.loads(line)
361
+ # test_data.append(data)
362
+ # print(f" Loaded {len(test_data)} test examples")
363
+ # else:
364
+ # print(f"⚠️ Test data not found. Creating synthetic test data...")
365
+ # test_data = self.create_synthetic_test_data()
366
+
367
+ # return test_data
368
+
369
+ # def create_synthetic_test_data(self) -> List[Dict]:
370
+ # """Create synthetic test data if real data is not available"""
371
+ # synthetic_data = [
372
+ # {
373
+ # "text": "### Input:\n最近ストレスを感じています。\n\n### Response:\nストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが��いですか?お話を聞かせていただければ、一緒に対処法を考えることができます。",
374
+ # "input": "最近ストレスを感じています。",
375
+ # "reference": "ストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが多いですか?お話を聞かせていただければ、一緒に対処法を考えることができます。"
376
+ # },
377
+ # {
378
+ # "text": "### Input:\n仕事がうまくいかなくて悩んでいます。\n\n### Response:\n仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?一緒に整理してみましょう。",
379
+ # "input": "仕事がうまくいかなくて悩んでいます。",
380
+ # "reference": "仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?一緒に整理してみましょう。"
381
+ # },
382
+ # {
383
+ # "text": "### Input:\n人間関係で困っています。\n\n### Response:\n人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?職場、家族、友人関係など、もう少し詳しくお聞かせいただけますか?",
384
+ # "input": "人間関係で困っています。",
385
+ # "reference": "人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?職場、家族、友人関係など、もう少し詳しくお聞かせいただけますか?"
386
+ # },
387
+ # {
388
+ # "text": "### Input:\n将来が不安です。\n\n### Response:\n将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。特にどのような点について不安を感じていらっしゃいますか?",
389
+ # "input": "将来が不安です。",
390
+ # "reference": "将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。特にどのような点について不安を感じていらっしゃいますか?"
391
+ # },
392
+ # {
393
+ # "text": "### Input:\n自信が持てません。\n\n### Response:\n自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。どのような場面で特に自信が持てないと感じますか?あなたの強みも一緒に見つけていきましょう。",
394
+ # "input": "自信が持てません。",
395
+ # "reference": "自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。どのような場面で特に自信が持てないと感じますか?あなたの強みも一緒に見つけていきましょう。"
396
+ # }
397
+ # ]
398
+ # return synthetic_data
399
+
400
+ # def merge_and_save_model(self, force_merge: bool = False):
401
+ # """Merge LoRA weights with base model and save"""
402
+ # if os.path.exists(self.merged_model_path) and not force_merge:
403
+ # print(f"✅ Merged model already exists at {self.merged_model_path}")
404
+ # return
405
+
406
+ # print("\n🔄 Merging LoRA adapter with base model...")
407
+
408
+ # try:
409
+ # # Load base model
410
+ # print(" Loading base model...")
411
+ # base_model = AutoModelForCausalLM.from_pretrained(
412
+ # self.base_model_name,
413
+ # torch_dtype=torch.float16,
414
+ # device_map="auto" if self.device == "cuda" else None,
415
+ # trust_remote_code=True,
416
+ # low_cpu_mem_usage=True
417
+ # )
418
+
419
+ # # Check if adapter exists
420
+ # adapter_config_path = os.path.join(self.finetuned_model_path, "adapter_config.json")
421
+ # if not os.path.exists(adapter_config_path):
422
+ # print(f"⚠️ No LoRA adapter found at {self.finetuned_model_path}")
423
+ # model = base_model
424
+ # else:
425
+ # # Load LoRA adapter
426
+ # print(" Loading LoRA adapter...")
427
+ # model = PeftModel.from_pretrained(
428
+ # base_model,
429
+ # self.finetuned_model_path,
430
+ # torch_dtype=torch.float16
431
+ # )
432
+
433
+ # # Merge weights
434
+ # print(" Merging weights...")
435
+ # model = model.merge_and_unload()
436
+
437
+ # # Save merged model
438
+ # print(f" Saving merged model to {self.merged_model_path}...")
439
+ # model.save_pretrained(self.merged_model_path)
440
+
441
+ # # Save tokenizer
442
+ # tokenizer = AutoTokenizer.from_pretrained(
443
+ # self.finetuned_model_path
444
+ # if os.path.exists(os.path.join(self.finetuned_model_path, "tokenizer_config.json"))
445
+ # else self.base_model_name
446
+ # )
447
+ # tokenizer.save_pretrained(self.merged_model_path)
448
+
449
+ # print("✅ Model merged and saved successfully!")
450
+
451
+ # # Clean up memory
452
+ # del base_model, model
453
+ # gc.collect()
454
+ # torch.cuda.empty_cache()
455
+
456
+ # except Exception as e:
457
+ # print(f"❌ Error during merging: {e}")
458
+ # raise
459
+
460
+ # def load_models(self):
461
+ # """Load base and fine-tuned models for comparison"""
462
+ # print("\n📚 Loading models for benchmarking...")
463
+
464
+ # # Load tokenizer
465
+ # self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
466
+ # if self.tokenizer.pad_token is None:
467
+ # self.tokenizer.pad_token = self.tokenizer.eos_token
468
+
469
+ # # Load base model
470
+ # print(" Loading base model...")
471
+ # self.base_model = AutoModelForCausalLM.from_pretrained(
472
+ # self.base_model_name,
473
+ # torch_dtype=torch.float16,
474
+ # device_map="auto" if self.device == "cuda" else None,
475
+ # trust_remote_code=True,
476
+ # low_cpu_mem_usage=True
477
+ # )
478
+ # self.base_model.eval()
479
+
480
+ # # Load merged fine-tuned model
481
+ # if os.path.exists(self.merged_model_path):
482
+ # print(" Loading merged fine-tuned model...")
483
+ # self.finetuned_model = AutoModelForCausalLM.from_pretrained(
484
+ # self.merged_model_path,
485
+ # torch_dtype=torch.float16,
486
+ # device_map="auto" if self.device == "cuda" else None,
487
+ # trust_remote_code=True,
488
+ # low_cpu_mem_usage=True
489
+ # )
490
+ # else:
491
+ # print(" Loading fine-tuned model (attempting PEFT)...")
492
+ # try:
493
+ # base_for_peft = AutoModelForCausalLM.from_pretrained(
494
+ # self.base_model_name,
495
+ # torch_dtype=torch.float16,
496
+ # device_map="auto" if self.device == "cuda" else None,
497
+ # trust_remote_code=True,
498
+ # low_cpu_mem_usage=True
499
+ # )
500
+ # self.finetuned_model = PeftModel.from_pretrained(
501
+ # base_for_peft,
502
+ # self.finetuned_model_path,
503
+ # torch_dtype=torch.float16
504
+ # )
505
+ # except:
506
+ # self.finetuned_model = AutoModelForCausalLM.from_pretrained(
507
+ # self.finetuned_model_path,
508
+ # torch_dtype=torch.float16,
509
+ # device_map="auto" if self.device == "cuda" else None,
510
+ # trust_remote_code=True,
511
+ # low_cpu_mem_usage=True
512
+ # )
513
+
514
+ # self.finetuned_model.eval()
515
+ # print("✅ Models loaded successfully!")
516
+
517
+ # def generate_response(self, model, prompt: str, max_length: int = 150) -> str:
518
+ # """Generate response from model"""
519
+ # inputs = self.tokenizer(
520
+ # prompt,
521
+ # return_tensors="pt",
522
+ # truncation=True,
523
+ # max_length=512
524
+ # )
525
+
526
+ # if self.device == "cuda":
527
+ # inputs = {k: v.cuda() for k, v in inputs.items()}
528
+
529
+ # with torch.no_grad():
530
+ # outputs = model.generate(
531
+ # **inputs,
532
+ # max_new_tokens=max_length,
533
+ # temperature=0.7,
534
+ # do_sample=True,
535
+ # top_p=0.9,
536
+ # pad_token_id=self.tokenizer.pad_token_id,
537
+ # eos_token_id=self.tokenizer.eos_token_id
538
+ # )
539
+
540
+ # response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
541
+ # # Extract only the generated response
542
+ # if "### Response:" in response:
543
+ # response = response.split("### Response:")[-1].strip()
544
+ # elif "Response:" in response:
545
+ # response = response.split("Response:")[-1].strip()
546
+ # else:
547
+ # # Remove the input prompt from response
548
+ # response = response[len(prompt):].strip()
549
+
550
+ # return response
551
+
552
+ # def tokenize_japanese(self, text: str) -> List[str]:
553
+ # """Tokenize Japanese text for BLEU calculation"""
554
+ # # Simple character-based tokenization for Japanese
555
+ # # In production, use MeCab or similar for better tokenization
556
+ # import re
557
+
558
+ # # Remove special characters and split
559
+ # text = re.sub(r'[。、!?\n]', ' ', text)
560
+ # tokens = text.strip().split()
561
+
562
+ # # Character-level tokenization as fallback
563
+ # if not tokens:
564
+ # tokens = list(text.strip())
565
+
566
+ # return tokens
567
+
568
+ # def calculate_bleu_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
569
+ # """Calculate BLEU-1, BLEU-2, BLEU-3, BLEU-4 scores"""
570
+ # # Tokenize texts
571
+ # ref_tokens = self.tokenize_japanese(reference)
572
+ # hyp_tokens = self.tokenize_japanese(hypothesis)
573
+
574
+ # # Calculate BLEU scores with different n-grams
575
+ # scores = {}
576
+
577
+ # # BLEU-1 (unigram)
578
+ # scores['BLEU-1'] = sentence_bleu(
579
+ # [ref_tokens], hyp_tokens,
580
+ # weights=(1.0, 0, 0, 0),
581
+ # smoothing_function=self.smoothing
582
+ # )
583
+
584
+ # # BLEU-2 (bigram)
585
+ # scores['BLEU-2'] = sentence_bleu(
586
+ # [ref_tokens], hyp_tokens,
587
+ # weights=(0.5, 0.5, 0, 0),
588
+ # smoothing_function=self.smoothing
589
+ # )
590
+
591
+ # # BLEU-3 (trigram)
592
+ # scores['BLEU-3'] = sentence_bleu(
593
+ # [ref_tokens], hyp_tokens,
594
+ # weights=(0.33, 0.33, 0.34, 0),
595
+ # smoothing_function=self.smoothing
596
+ # )
597
+
598
+ # # BLEU-4 (4-gram)
599
+ # scores['BLEU-4'] = sentence_bleu(
600
+ # [ref_tokens], hyp_tokens,
601
+ # weights=(0.25, 0.25, 0.25, 0.25),
602
+ # smoothing_function=self.smoothing
603
+ # )
604
+
605
+ # return scores
606
+
607
+ # def calculate_rouge_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
608
+ # """Calculate ROUGE-1, ROUGE-2, ROUGE-L scores"""
609
+ # scores = self.rouge_scorer.score(reference, hypothesis)
610
+
611
+ # return {
612
+ # 'ROUGE-1': scores['rouge1'].fmeasure,
613
+ # 'ROUGE-2': scores['rouge2'].fmeasure,
614
+ # 'ROUGE-L': scores['rougeL'].fmeasure
615
+ # }
616
+
617
+ # def run_bleu_rouge_benchmark(self, num_samples: int = None):
618
+ # """Run comprehensive BLEU and ROUGE benchmark"""
619
+ # print("\n" + "="*70)
620
+ # print("🏃 RUNNING BLEU & ROUGE BENCHMARK")
621
+ # print("="*70)
622
+
623
+ # # Load test data
624
+ # test_data = self.load_test_data()
625
+
626
+ # if num_samples:
627
+ # test_data = test_data[:num_samples]
628
+ # print(f" Using {num_samples} samples for benchmarking")
629
+
630
+ # # Initialize score collectors
631
+ # base_scores = defaultdict(list)
632
+ # finetuned_scores = defaultdict(list)
633
+
634
+ # # Metrics to calculate
635
+ # metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4',
636
+ # 'ROUGE-1', 'ROUGE-2', 'ROUGE-L']
637
+
638
+ # print(f"\n📊 Evaluating {len(test_data)} test examples...")
639
+ # print("-" * 70)
640
+
641
+ # detailed_results = []
642
+
643
+ # for i, example in enumerate(tqdm(test_data, desc="Evaluating")):
644
+ # # Extract input and reference
645
+ # if 'input' in example:
646
+ # input_text = example['input']
647
+ # else:
648
+ # # Try to extract from text field
649
+ # if "### Input:" in example['text']:
650
+ # input_text = example['text'].split("### Input:")[1].split("### Response:")[0].strip()
651
+ # else:
652
+ # input_text = example['text'].split("\n")[0].strip()
653
+
654
+ # if 'reference' in example:
655
+ # reference = example['reference']
656
+ # else:
657
+ # # Try to extract from text field
658
+ # if "### Response:" in example['text']:
659
+ # reference = example['text'].split("### Response:")[1].strip()
660
+ # else:
661
+ # parts = example['text'].split("\n")
662
+ # reference = parts[1] if len(parts) > 1 else parts[0]
663
+
664
+ # # Format input for models
665
+ # formatted_input = f"### Instruction:\nあなたは思いやりのある心理カウンセラーです。\n\n### Input:\n{input_text}\n\n### Response:\n"
666
+
667
+ # # Generate responses
668
+ # base_response = self.generate_response(self.base_model, formatted_input)
669
+ # finetuned_response = self.generate_response(self.finetuned_model, formatted_input)
670
+
671
+ # # Calculate BLEU scores
672
+ # base_bleu = self.calculate_bleu_scores(reference, base_response)
673
+ # finetuned_bleu = self.calculate_bleu_scores(reference, finetuned_response)
674
+
675
+ # # Calculate ROUGE scores
676
+ # base_rouge = self.calculate_rouge_scores(reference, base_response)
677
+ # finetuned_rouge = self.calculate_rouge_scores(reference, finetuned_response)
678
+
679
+ # # Combine scores
680
+ # base_all_scores = {**base_bleu, **base_rouge}
681
+ # finetuned_all_scores = {**finetuned_bleu, **finetuned_rouge}
682
+
683
+ # # Collect scores
684
+ # for metric in metrics:
685
+ # base_scores[metric].append(base_all_scores[metric])
686
+ # finetuned_scores[metric].append(finetuned_all_scores[metric])
687
+
688
+ # # Store detailed results
689
+ # detailed_results.append({
690
+ # 'input': input_text,
691
+ # 'reference': reference,
692
+ # 'base_response': base_response,
693
+ # 'finetuned_response': finetuned_response,
694
+ # 'base_scores': base_all_scores,
695
+ # 'finetuned_scores': finetuned_all_scores
696
+ # })
697
+
698
+ # # Print sample results
699
+ # if i < 3: # Show first 3 examples
700
+ # print(f"\n📝 Example {i+1}:")
701
+ # print(f" Input: {input_text[:50]}...")
702
+ # print(f" Reference: {reference[:50]}...")
703
+ # print(f" Base response: {base_response[:50]}...")
704
+ # print(f" Fine-tuned response: {finetuned_response[:50]}...")
705
+ # print(f" Base BLEU-4: {base_bleu['BLEU-4']:.3f}")
706
+ # print(f" Fine-tuned BLEU-4: {finetuned_bleu['BLEU-4']:.3f}")
707
+
708
+ # # Calculate aggregate statistics
709
+ # print("\n" + "="*70)
710
+ # print("📈 BENCHMARK RESULTS")
711
+ # print("="*70)
712
+
713
+ # self.results = {
714
+ # 'detailed_results': detailed_results,
715
+ # 'aggregate_scores': {},
716
+ # 'improvements': {}
717
+ # }
718
+
719
+ # # Print and store results
720
+ # print("\n" + "-"*70)
721
+ # print(f"{'Metric':<12} {'Base Model':<20} {'Fine-tuned Model':<20} {'Improvement':<15}")
722
+ # print("-"*70)
723
+
724
+ # for metric in metrics:
725
+ # base_mean = np.mean(base_scores[metric])
726
+ # base_std = np.std(base_scores[metric])
727
+ # finetuned_mean = np.mean(finetuned_scores[metric])
728
+ # finetuned_std = np.std(finetuned_scores[metric])
729
+
730
+ # # Calculate improvement
731
+ # if base_mean > 0:
732
+ # improvement = ((finetuned_mean - base_mean) / base_mean) * 100
733
+ # else:
734
+ # improvement = 0
735
+
736
+ # # Store results
737
+ # self.results['aggregate_scores'][metric] = {
738
+ # 'base_mean': base_mean,
739
+ # 'base_std': base_std,
740
+ # 'finetuned_mean': finetuned_mean,
741
+ # 'finetuned_std': finetuned_std
742
+ # }
743
+ # self.results['improvements'][metric] = improvement
744
+
745
+ # # Print results
746
+ # base_str = f"{base_mean:.3f} (±{base_std:.3f})"
747
+ # finetuned_str = f"{finetuned_mean:.3f} (±{finetuned_std:.3f})"
748
+ # imp_str = f"{improvement:+.1f}%"
749
+
750
+ # # Color code improvement
751
+ # if improvement > 0:
752
+ # imp_str = f"✅ {imp_str}"
753
+ # elif improvement < 0:
754
+ # imp_str = f"⚠️ {imp_str}"
755
+ # else:
756
+ # imp_str = f"➖ {imp_str}"
757
+
758
+ # print(f"{metric:<12} {base_str:<20} {finetuned_str:<20} {imp_str:<15}")
759
+
760
+ # # Calculate overall scores
761
+ # print("\n" + "="*70)
762
+ # print("🎯 OVERALL PERFORMANCE")
763
+ # print("="*70)
764
+
765
+ # # Average BLEU score
766
+ # bleu_metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']
767
+ # base_bleu_avg = np.mean([np.mean(base_scores[m]) for m in bleu_metrics])
768
+ # finetuned_bleu_avg = np.mean([np.mean(finetuned_scores[m]) for m in bleu_metrics])
769
+ # bleu_improvement = ((finetuned_bleu_avg - base_bleu_avg) / base_bleu_avg) * 100 if base_bleu_avg > 0 else 0
770
+
771
+ # # Average ROUGE score
772
+ # rouge_metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']
773
+ # base_rouge_avg = np.mean([np.mean(base_scores[m]) for m in rouge_metrics])
774
+ # finetuned_rouge_avg = np.mean([np.mean(finetuned_scores[m]) for m in rouge_metrics])
775
+ # rouge_improvement = ((finetuned_rouge_avg - base_rouge_avg) / base_rouge_avg) * 100 if base_rouge_avg > 0 else 0
776
+
777
+ # # Overall average
778
+ # base_overall = np.mean([np.mean(base_scores[m]) for m in metrics])
779
+ # finetuned_overall = np.mean([np.mean(finetuned_scores[m]) for m in metrics])
780
+ # overall_improvement = ((finetuned_overall - base_overall) / base_overall) * 100 if base_overall > 0 else 0
781
+
782
+ # self.results['summary'] = {
783
+ # 'bleu_average': {
784
+ # 'base': base_bleu_avg,
785
+ # 'finetuned': finetuned_bleu_avg,
786
+ # 'improvement': bleu_improvement
787
+ # },
788
+ # 'rouge_average': {
789
+ # 'base': base_rouge_avg,
790
+ # 'finetuned': finetuned_rouge_avg,
791
+ # 'improvement': rouge_improvement
792
+ # },
793
+ # 'overall': {
794
+ # 'base': base_overall,
795
+ # 'finetuned': finetuned_overall,
796
+ # 'improvement': overall_improvement
797
+ # }
798
+ # }
799
+
800
+ # print(f"\n📊 Average BLEU Score:")
801
+ # print(f" Base Model: {base_bleu_avg:.3f}")
802
+ # print(f" Fine-tuned Model: {finetuned_bleu_avg:.3f}")
803
+ # print(f" Improvement: {bleu_improvement:+.1f}%")
804
+
805
+ # print(f"\n📊 Average ROUGE Score:")
806
+ # print(f" Base Model: {base_rouge_avg:.3f}")
807
+ # print(f" Fine-tuned Model: {finetuned_rouge_avg:.3f}")
808
+ # print(f" Improvement: {rouge_improvement:+.1f}%")
809
+
810
+ # print(f"\n🎯 Overall Average:")
811
+ # print(f" Base Model: {base_overall:.3f}")
812
+ # print(f" Fine-tuned Model: {finetuned_overall:.3f}")
813
+ # print(f" Improvement: {overall_improvement:+.1f}%")
814
+
815
+ # print("="*70)
816
+
817
+ # return self.results
818
+
819
+ # def visualize_results(self, save_path: str = "bleu_rouge_benchmark.png"):
820
+ # """Create comprehensive visualization of BLEU and ROUGE results"""
821
+ # if 'aggregate_scores' not in self.results:
822
+ # print("❌ No results to visualize. Run benchmark first.")
823
+ # return
824
+
825
+ # print("\n📊 Creating visualizations...")
826
+
827
+ # fig, axes = plt.subplots(2, 3, figsize=(18, 12))
828
+
829
+ # # Color scheme
830
+ # base_color = '#3498db'
831
+ # finetuned_color = '#e74c3c'
832
+ # improvement_positive = '#27ae60'
833
+ # improvement_negative = '#c0392b'
834
+
835
+ # # 1. BLEU Scores Comparison
836
+ # bleu_metrics = ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']
837
+ # bleu_base = [self.results['aggregate_scores'][m]['base_mean'] for m in bleu_metrics]
838
+ # bleu_finetuned = [self.results['aggregate_scores'][m]['finetuned_mean'] for m in bleu_metrics]
839
+
840
+ # x = np.arange(len(bleu_metrics))
841
+ # width = 0.35
842
+
843
+ # axes[0, 0].bar(x - width/2, bleu_base, width, label='Base Model',
844
+ # color=base_color, alpha=0.8)
845
+ # axes[0, 0].bar(x + width/2, bleu_finetuned, width, label='Fine-tuned Model',
846
+ # color=finetuned_color, alpha=0.8)
847
+ # axes[0, 0].set_xlabel('BLEU Metrics')
848
+ # axes[0, 0].set_ylabel('Score')
849
+ # axes[0, 0].set_title('BLEU Score Comparison')
850
+ # axes[0, 0].set_xticks(x)
851
+ # axes[0, 0].set_xticklabels(bleu_metrics)
852
+ # axes[0, 0].legend()
853
+ # axes[0, 0].grid(True, alpha=0.3)
854
+ # axes[0, 0].set_ylim([0, max(max(bleu_base), max(bleu_finetuned)) * 1.2])
855
+
856
+ # # 2. ROUGE Scores Comparison
857
+ # rouge_metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']
858
+ # rouge_base = [self.results['aggregate_scores'][m]['base_mean'] for m in rouge_metrics]
859
+ # rouge_finetuned = [self.results['aggregate_scores'][m]['finetuned_mean'] for m in rouge_metrics]
860
+
861
+ # x = np.arange(len(rouge_metrics))
862
+
863
+ # axes[0, 1].bar(x - width/2, rouge_base, width, label='Base Model',
864
+ # color=base_color, alpha=0.8)
865
+ # axes[0, 1].bar(x + width/2, rouge_finetuned, width, label='Fine-tuned Model',
866
+ # color=finetuned_color, alpha=0.8)
867
+ # axes[0, 1].set_xlabel('ROUGE Metrics')
868
+ # axes[0, 1].set_ylabel('Score')
869
+ # axes[0, 1].set_title('ROUGE Score Comparison')
870
+ # axes[0, 1].set_xticks(x)
871
+ # axes[0, 1].set_xticklabels(rouge_metrics)
872
+ # axes[0, 1].legend()
873
+ # axes[0, 1].grid(True, alpha=0.3)
874
+ # axes[0, 1].set_ylim([0, max(max(rouge_base), max(rouge_finetuned)) * 1.2])
875
+
876
+ # # 3. Improvement Percentages
877
+ # all_metrics = bleu_metrics + rouge_metrics
878
+ # improvements = [self.results['improvements'][m] for m in all_metrics]
879
+ # colors = [improvement_positive if imp > 0 else improvement_negative for imp in improvements]
880
+
881
+ # axes[0, 2].barh(range(len(all_metrics)), improvements, color=colors, alpha=0.7)
882
+ # axes[0, 2].set_yticks(range(len(all_metrics)))
883
+ # axes[0, 2].set_yticklabels(all_metrics)
884
+ # axes[0, 2].set_xlabel('Improvement (%)')
885
+ # axes[0, 2].set_title('Performance Improvement by Metric')
886
+ # axes[0, 2].axvline(x=0, color='black', linestyle='-', linewidth=0.5)
887
+ # axes[0, 2].grid(True, alpha=0.3, axis='x')
888
+
889
+ # # 4. Line plot showing progression
890
+ # axes[1, 0].plot(bleu_metrics, bleu_base, 'o-', label='Base Model',
891
+ # color=base_color, linewidth=2, markersize=8)
892
+ # axes[1, 0].plot(bleu_metrics, bleu_finetuned, 's-', label='Fine-tuned Model',
893
+ # color=finetuned_color, linewidth=2, markersize=8)
894
+ # axes[1, 0].set_xlabel('BLEU N-gram')
895
+ # axes[1, 0].set_ylabel('Score')
896
+ # axes[1, 0].set_title('BLEU Score Progression')
897
+ # axes[1, 0].legend()
898
+ # axes[1, 0].grid(True, alpha=0.3)
899
+
900
+ # # 5. Summary Statistics
901
+ # ax5 = axes[1, 1]
902
+ # ax5.axis('off')
903
+
904
+ # summary_text = f"""
905
+ # BENCHMARK SUMMARY
906
+ # {'='*30}
907
+
908
+ # BLEU Average:
909
+ # Base: {self.results['summary']['bleu_average']['base']:.3f}
910
+ # Fine-tuned: {self.results['summary']['bleu_average']['finetuned']:.3f}
911
+ # Improvement: {self.results['summary']['bleu_average']['improvement']:+.1f}%
912
+
913
+ # ROUGE Average:
914
+ # Base: {self.results['summary']['rouge_average']['base']:.3f}
915
+ # Fine-tuned: {self.results['summary']['rouge_average']['finetuned']:.3f}
916
+ # Improvement: {self.results['summary']['rouge_average']['improvement']:+.1f}%
917
+
918
+ # Overall Performance:
919
+ # Base: {self.results['summary']['overall']['base']:.3f}
920
+ # Fine-tuned: {self.results['summary']['overall']['finetuned']:.3f}
921
+ # Improvement: {self.results['summary']['overall']['improvement']:+.1f}%
922
+
923
+ # Best Improvements:
924
+ # """
925
+
926
+ # # Find best improvements
927
+ # sorted_metrics = sorted(all_metrics,
928
+ # key=lambda m: self.results['improvements'][m],
929
+ # reverse=True)
930
+
931
+ # for m in sorted_metrics[:2]:
932
+ # summary_text += f" • {m}: {self.results['improvements'][m]:+.1f}%\n"
933
+
934
+ # if any(self.results['improvements'][m] < 0 for m in all_metrics):
935
+ # summary_text += f"\nNeeds Attention:\n"
936
+ # for m in sorted_metrics[-2:]:
937
+ # if self.results['improvements'][m] < 0:
938
+ # summary_text += f" • {m}: {self.results['improvements'][m]:+.1f}%\n"
939
+
940
+ # ax5.text(0.1, 0.9, summary_text, transform=ax5.transAxes,
941
+ # fontsize=10, verticalalignment='top', fontfamily='monospace')
942
+
943
+ # # 6. Heatmap of all scores
944
+ # metrics_for_heatmap = all_metrics
945
+ # models = ['Base', 'Fine-tuned']
946
+
947
+ # heatmap_data = []
948
+ # for metric in metrics_for_heatmap:
949
+ # heatmap_data.append([
950
+ # self.results['aggregate_scores'][metric]['base_mean'],
951
+ # self.results['aggregate_scores'][metric]['finetuned_mean']
952
+ # ])
953
+
954
+ # im = axes[1, 2].imshow(heatmap_data, cmap='YlOrRd', aspect='auto')
955
+ # axes[1, 2].set_xticks(np.arange(len(models)))
956
+ # axes[1, 2].set_yticks(np.arange(len(metrics_for_heatmap)))
957
+ # axes[1, 2].set_xticklabels(models)
958
+ # axes[1, 2].set_yticklabels(metrics_for_heatmap)
959
+ # axes[1, 2].set_title('Score Heatmap')
960
+
961
+ # # Add text annotations
962
+ # for i in range(len(metrics_for_heatmap)):
963
+ # for j in range(len(models)):
964
+ # text = axes[1, 2].text(j, i, f'{heatmap_data[i][j]:.3f}',
965
+ # ha="center", va="center", color="black", fontsize=8)
966
+
967
+ # plt.colorbar(im, ax=axes[1, 2])
968
+
969
+ # plt.suptitle('BLEU & ROUGE Benchmark Results', fontsize=16, fontweight='bold')
970
+ # plt.tight_layout()
971
+ # plt.savefig(save_path, dpi=300, bbox_inches='tight')
972
+ # print(f"✅ Visualization saved to {save_path}")
973
+
974
+ # plt.show()
975
+
976
+ # def save_results(self, output_path: str = "bleu_rouge_results.json"):
977
+ # """Save benchmark results to JSON"""
978
+ # # Convert numpy types to Python native types for JSON serialization
979
+ # def convert_to_native(obj):
980
+ # if isinstance(obj, np.floating):
981
+ # return float(obj)
982
+ # elif isinstance(obj, np.integer):
983
+ # return int(obj)
984
+ # elif isinstance(obj, np.ndarray):
985
+ # return obj.tolist()
986
+ # elif isinstance(obj, dict):
987
+ # return {k: convert_to_native(v) for k, v in obj.items()}
988
+ # elif isinstance(obj, list):
989
+ # return [convert_to_native(item) for item in obj]
990
+ # return obj
991
+
992
+ # results_native = convert_to_native(self.results)
993
+
994
+ # with open(output_path, 'w', encoding='utf-8') as f:
995
+ # json.dump(results_native, f, ensure_ascii=False, indent=2)
996
+ # print(f"✅ Results saved to {output_path}")
997
+
998
+ # def generate_detailed_report(self, output_path: str = "bleu_rouge_report.md"):
999
+ # """Generate detailed markdown report"""
1000
+ # if not self.results:
1001
+ # print("❌ No results to report. Run benchmark first.")
1002
+ # return
1003
+
1004
+ # report = f"""# BLEU & ROUGE Benchmark Report
1005
+ # Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
1006
+
1007
+ # ## Executive Summary
1008
+
1009
+ # Comprehensive evaluation of the fine-tuned counseling model using BLEU and ROUGE metrics.
1010
+
1011
+ # ### Overall Performance
1012
+ # - **Base Model Score**: {self.results['summary']['overall']['base']:.3f}
1013
+ # - **Fine-tuned Model Score**: {self.results['summary']['overall']['finetuned']:.3f}
1014
+ # - **Overall Improvement**: {self.results['summary']['overall']['improvement']:+.1f}%
1015
+
1016
+ # ## Detailed Metrics
1017
+
1018
+ # ### BLEU Scores
1019
+ # | Metric | Base Model | Fine-tuned Model | Improvement |
1020
+ # |--------|------------|------------------|-------------|
1021
+ # """
1022
+
1023
+ # for metric in ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4']:
1024
+ # scores = self.results['aggregate_scores'][metric]
1025
+ # report += f"| {metric} | {scores['base_mean']:.3f} (±{scores['base_std']:.3f}) | "
1026
+ # report += f"{scores['finetuned_mean']:.3f} (±{scores['finetuned_std']:.3f}) | "
1027
+ # report += f"{self.results['improvements'][metric]:+.1f}% |\n"
1028
+
1029
+ # report += f"""
1030
+
1031
+ # **BLEU Average**: {self.results['summary']['bleu_average']['improvement']:+.1f}% improvement
1032
+
1033
+ # ### ROUGE Scores
1034
+ # | Metric | Base Model | Fine-tuned Model | Improvement |
1035
+ # |--------|------------|------------------|-------------|
1036
+ # """
1037
+
1038
+ # for metric in ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']:
1039
+ # scores = self.results['aggregate_scores'][metric]
1040
+ # report += f"| {metric} | {scores['base_mean']:.3f} (±{scores['base_std']:.3f}) | "
1041
+ # report += f"{scores['finetuned_mean']:.3f} (±{scores['finetuned_std']:.3f}) | "
1042
+ # report += f"{self.results['improvements'][metric]:+.1f}% |\n"
1043
+
1044
+ # report += f"""
1045
+
1046
+ # **ROUGE Average**: {self.results['summary']['rouge_average']['improvement']:+.1f}% improvement
1047
+
1048
+ # ## Sample Outputs
1049
+
1050
+ # """
1051
+
1052
+ # # Add sample outputs
1053
+ # for i, result in enumerate(self.results['detailed_results'][:3]):
1054
+ # report += f"""### Example {i+1}
1055
+
1056
+ # **Input**: {result['input']}
1057
+
1058
+ # **Reference**: {result['reference'][:200]}...
1059
+
1060
+ # **Base Model Response**: {result['base_response'][:200]}...
1061
+
1062
+ # **Fine-tuned Model Response**: {result['finetuned_response'][:200]}...
1063
+
1064
+ # **Scores**:
1065
+ # - Base BLEU-4: {result['base_scores']['BLEU-4']:.3f}, ROUGE-L: {result['base_scores']['ROUGE-L']:.3f}
1066
+ # - Fine-tuned BLEU-4: {result['finetuned_scores']['BLEU-4']:.3f}, ROUGE-L: {result['finetuned_scores']['ROUGE-L']:.3f}
1067
+
1068
+ # ---
1069
+
1070
+ # """
1071
+
1072
+ # report += """## Analysis & Recommendations
1073
+
1074
+ # """
1075
+
1076
+ # overall_imp = self.results['summary']['overall']['improvement']
1077
+
1078
+ # if overall_imp < -10:
1079
+ # report += """### ⚠️ Significant Performance Degradation
1080
+
1081
+ # The fine-tuned model shows significant degradation in BLEU/ROUGE scores. This indicates:
1082
+
1083
+ # 1. **Catastrophic Forgetting**: The model has lost its language generation capabilities
1084
+ # 2. **Overfitting**: The model memorized training data instead of learning patterns
1085
+ # 3. **Format Mismatch**: Training and inference formats may differ
1086
+
1087
+ # **Immediate Actions Required**:
1088
+ # - ✅ Ensure proper model merging (LoRA weights with base model)
1089
+ # - ✅ Reduce learning rate (try 1e-5 or 2e-5)
1090
+ # - ✅ Use smaller LoRA rank (r=4 or r=8)
1091
+ # - ✅ Mix general conversation data with counseling data (80/20 ratio)
1092
+ # - ✅ Implement regularization (weight decay=0.1, dropout=0.1)
1093
+ # - ✅ Use early stopping with patience=3
1094
+ # """
1095
+ # elif overall_imp < 0:
1096
+ # report += """### ⚠️ Minor Performance Degradation
1097
+
1098
+ # The model shows slight degradation. Common causes:
1099
+
1100
+ # 1. **Aggressive Fine-tuning**: Parameters changed too much
1101
+ # 2. **Limited Training Data**: Not enough diverse examples
1102
+ # 3. **Domain Shift**: Counseling domain too different from base training
1103
+
1104
+ # **Recommended Actions**:
1105
+ # - ✅ Fine-tune for fewer epochs (1-2 instead of 3)
1106
+ # - ✅ Use gradient accumulation for larger effective batch size
1107
+ # - ✅ Implement knowledge distillation from base model
1108
+ # - ✅ Add more diverse training examples
1109
+ # """
1110
+ # elif overall_imp < 10:
1111
+ # report += """### 📊 Modest Improvement
1112
+
1113
+ # The model shows small but positive improvements.
1114
+
1115
+ # **To Further Improve**:
1116
+ # - ✅ Increase training data quality and quantity
1117
+ # - ✅ Experiment with different generation parameters
1118
+ # - ✅ Fine-tune on domain-specific pre-training
1119
+ # - ✅ Use ensemble methods with base model
1120
+ # """
1121
+ # else:
1122
+ # report += """### ✅ Significant Improvement
1123
+
1124
+ # Excellent results! The fine-tuned model shows substantial improvements.
1125
+
1126
+ # **Next Steps**:
1127
+ # - ✅ Deploy for A/B testing with users
1128
+ # - ✅ Monitor performance on edge cases
1129
+ # - ✅ Consider model compression for deployment
1130
+ # - ✅ Collect user feedback for iterative improvement
1131
+ # """
1132
+
1133
+ # with open(output_path, 'w', encoding='utf-8') as f:
1134
+ # f.write(report)
1135
+
1136
+ # print(f"✅ Detailed report saved to {output_path}")
1137
+
1138
+ # # Main execution
1139
+ # if __name__ == "__main__":
1140
+ # import argparse
1141
+
1142
+ # parser = argparse.ArgumentParser(description='Advanced BLEU & ROUGE Benchmark')
1143
+ # parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-2.6B',
1144
+ # help='Base model name')
1145
+ # parser.add_argument('--finetuned_path', type=str, default='./counselor_model/best_model',
1146
+ # help='Path to fine-tuned model')
1147
+ # parser.add_argument('--merged_path', type=str, default='./merged_counselor_mode_2b',
1148
+ # help='Path to save/load merged model')
1149
+ # parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl',
1150
+ # help='Path to test data')
1151
+ # parser.add_argument('--num_samples', type=int, default=None,
1152
+ # help='Number of samples to evaluate (None for all)')
1153
+ # parser.add_argument('--force_merge', action='store_true',
1154
+ # help='Force re-merge even if merged model exists')
1155
+ # parser.add_argument('--skip_merge', action='store_true',
1156
+ # help='Skip merging step')
1157
+ # parser.add_argument('--output_dir', type=str, default='./benchmark_results',
1158
+ # help='Directory to save results')
1159
+
1160
+ # args = parser.parse_args()
1161
+
1162
+ # # Create output directory
1163
+ # os.makedirs(args.output_dir, exist_ok=True)
1164
+
1165
+ # try:
1166
+ # # Initialize benchmark
1167
+ # print("🚀 Initializing Advanced BLEU & ROUGE Benchmark")
1168
+ # benchmark = AdvancedCounselorBenchmark(
1169
+ # base_model_name=args.base_model,
1170
+ # finetuned_model_path=args.finetuned_path,
1171
+ # merged_model_path=args.merged_path,
1172
+ # test_data_path=args.test_data
1173
+ # )
1174
+
1175
+ # # Merge models if needed
1176
+ # if not args.skip_merge:
1177
+ # benchmark.merge_and_save_model(force_merge=args.force_merge)
1178
+
1179
+ # # Load models
1180
+ # benchmark.load_models()
1181
+
1182
+ # # Run BLEU & ROUGE benchmark
1183
+ # results = benchmark.run_bleu_rouge_benchmark(num_samples=args.num_samples)
1184
+
1185
+ # # Save results
1186
+ # benchmark.save_results(os.path.join(args.output_dir, "bleu_rouge_results_2b.json"))
1187
+
1188
+ # # Generate visualizations
1189
+ # benchmark.visualize_results(os.path.join(args.output_dir, "bleu_rouge_visualization_2b.png"))
1190
+
1191
+ # # Generate detailed report
1192
+ # benchmark.generate_detailed_report(os.path.join(args.output_dir, "bleu_rouge_report_2b.md"))
1193
+
1194
+ # print("\n✅ BLEU & ROUGE Benchmarking completed successfully!")
1195
+ # print(f"📁 Results saved to {args.output_dir}/")
1196
+
1197
+ # except Exception as e:
1198
+ # print(f"\n❌ Error during benchmarking: {e}")
1199
+ # import traceback
1200
+ # traceback.print_exc()
1201
+
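For reference, benchmark_on_test_set() in benchmark.py above assumes each line of test.jsonl carries a single "text" field in which the prompt and the reference reply are separated by the literal "### Response:" marker. A minimal sketch of a record in that shape, and of the split the script performs (illustrative example only; the wording and file name are assumptions, not part of the committed files):

import json

# Illustrative record in the format the benchmark expects (assumed example).
record = {
    "text": (
        "### Input:\n最近ストレスを感じています。\n\n"
        "### Response:\nストレスを感じているのですね。どのような状況で感じることが多いですか?"
    )
}
with open("test_example.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")

# benchmark_on_test_set() then recovers prompt and reference exactly like this:
prompt = record["text"].split("### Response:")[0] + "### Response:"
reference = record["text"].split("### Response:")[1].strip()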
benchmark_report.txt ADDED
@@ -0,0 +1,33 @@
1
+ ================================================================================
2
+ COUNSELOR MODEL BENCHMARK REPORT
3
+ ================================================================================
4
+
5
+ PERFORMANCE COMPARISON:
6
+ ----------------------------------------
7
+
8
+ LENGTH_SCORE:
9
+ Base Model: 0.809 (±0.153)
10
+ Fine-tuned Model: 0.840 (±0.174)
11
+ Improvement: +3.8%
12
+
13
+ QUESTION_SCORE:
14
+ Base Model: 0.660 (±0.474)
15
+ Fine-tuned Model: 0.850 (±0.357)
16
+ Improvement: +28.8%
17
+
18
+ SUPPORT_SCORE:
19
+ Base Model: 0.248 (±0.184)
20
+ Fine-tuned Model: 0.088 (±0.124)
21
+ Improvement: -64.5%
22
+
23
+ EMPATHY_SCORE:
24
+ Base Model: 0.262 (±0.086)
25
+ Fine-tuned Model: 0.152 (±0.114)
26
+ Improvement: -42.0%
27
+
28
+ ========================================
29
+ OVERALL PERFORMANCE:
30
+ Base Model: 0.495
31
+ Fine-tuned Model: 0.483
32
+ Overall Improvement: -2.5%
33
+ ========================================
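The OVERALL PERFORMANCE figures above are the unweighted mean of the four per-metric means computed by generate_comparison_report() in benchmark.py. A quick sketch reproducing them from the rounded values printed in this report (the report itself uses the unrounded means):

import numpy as np

base_means = [0.809, 0.660, 0.248, 0.262]  # length, question, support, empathy (base model)
ft_means = [0.840, 0.850, 0.088, 0.152]    # same metrics, fine-tuned model

base_overall = np.mean(base_means)                              # 0.49475 -> reported as 0.495
ft_overall = np.mean(ft_means)                                  # 0.48250 -> reported as 0.483
improvement = (ft_overall - base_overall) / base_overall * 100  # roughly -2.5%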
benchmark_report_2b-Copy1.txt ADDED
@@ -0,0 +1,23 @@
1
+ ================================================================================
2
+ COUNSELOR MODEL BENCHMARK REPORT
3
+ ================================================================================
4
+
5
+ PERFORMANCE COMPARISON:
6
+ ----------------------------------------
7
+
8
+ LENGTH_SCORE:
9
+ Base Model: 0.876 (±0.138)
10
+ Fine-tuned Model: 0.956 (±0.135)
11
+ Improvement: +9.2%
12
+
13
+ QUESTION_SCORE:
14
+ Base Model: 0.670 (±0.470)
15
+ Fine-tuned Model: 0.900 (±0.300)
16
+ Improvement: +34.3%
17
+
18
+ ========================================
19
+ OVERALL PERFORMANCE:
20
+ Base Model: 0.773
21
+ Fine-tuned Model: 0.928
22
+ Overall Improvement: +20.1%
23
+ ========================================
benchmark_report_2b.txt ADDED
@@ -0,0 +1,23 @@
1
+ ================================================================================
2
+ COUNSELOR MODEL BENCHMARK REPORT
3
+ ================================================================================
4
+
5
+ PERFORMANCE COMPARISON:
6
+ ----------------------------------------
7
+
8
+ LENGTH_SCORE:
9
+ Base Model: 0.876 (±0.138)
10
+ Fine-tuned Model: 0.956 (±0.135)
11
+ Improvement: +9.2%
12
+
13
+ QUESTION_SCORE:
14
+ Base Model: 0.670 (±0.470)
15
+ Fine-tuned Model: 0.900 (±0.300)
16
+ Improvement: +34.3%
17
+
18
+ ========================================
19
+ OVERALL PERFORMANCE:
20
+ Base Model: 0.773
21
+ Fine-tuned Model: 0.928
22
+ Overall Improvement: +20.1%
23
+ ========================================
benchmark_report_v2.txt ADDED
@@ -0,0 +1,23 @@
1
+ ================================================================================
2
+ COUNSELOR MODEL BENCHMARK REPORT
3
+ ================================================================================
4
+
5
+ PERFORMANCE COMPARISON:
6
+ ----------------------------------------
7
+
8
+ LENGTH_SCORE:
9
+ Base Model: 0.785 (±0.146)
10
+ Fine-tuned Model: 0.822 (±0.189)
11
+ Improvement: +4.8%
12
+
13
+ QUESTION_SCORE:
14
+ Base Model: 0.680 (±0.466)
15
+ Fine-tuned Model: 0.870 (±0.336)
16
+ Improvement: +27.9%
17
+
18
+ ========================================
19
+ OVERALL PERFORMANCE:
20
+ Base Model: 0.732
21
+ Fine-tuned Model: 0.846
22
+ Overall Improvement: +15.5%
23
+ ========================================
benchmark_report_wo_merging.txt ADDED
@@ -0,0 +1,33 @@
1
+ ================================================================================
2
+ COUNSELOR MODEL BENCHMARK REPORT
3
+ ================================================================================
4
+
5
+ PERFORMANCE COMPARISON:
6
+ ----------------------------------------
7
+
8
+ LENGTH_SCORE:
9
+ Base Model: 0.807 (±0.154)
10
+ Fine-tuned Model: 0.808 (±0.202)
11
+ Improvement: +0.1%
12
+
13
+ QUESTION_SCORE:
14
+ Base Model: 0.670 (±0.470)
15
+ Fine-tuned Model: 0.910 (±0.286)
16
+ Improvement: +35.8%
17
+
18
+ SUPPORT_SCORE:
19
+ Base Model: 0.236 (±0.186)
20
+ Fine-tuned Model: 0.082 (±0.120)
21
+ Improvement: -65.3%
22
+
23
+ EMPATHY_SCORE:
24
+ Base Model: 0.267 (±0.099)
25
+ Fine-tuned Model: 0.141 (±0.100)
26
+ Improvement: -47.2%
27
+
28
+ ========================================
29
+ OVERALL PERFORMANCE:
30
+ Base Model: 0.495
31
+ Fine-tuned Model: 0.485
32
+ Overall Improvement: -2.0%
33
+ ========================================
benchmark_results.png ADDED

Git LFS Details

  • SHA256: d7df2f6780b3503d0d5eabe5374989e62691bbdd5528a1aa61cf69d9181550d6
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
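Note that the image itself is not stored in the Git history: the commit adds only a small Git LFS pointer (hence the 131-byte pointer size), which is why .gitattributes above routes benchmark_results.png through the LFS filter. Such a pointer file contains roughly the following (the exact byte count is not shown on this page; 155 kB is the rounded remote size):

version https://git-lfs.github.com/spec/v1
oid sha256:d7df2f6780b3503d0d5eabe5374989e62691bbdd5528a1aa61cf69d9181550d6
size <exact size in bytes, about 155000>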
benchmark_v1.py ADDED
@@ -0,0 +1,803 @@
1
+ """
2
+ Comprehensive Japanese Counseling Model Benchmark Script
3
+ Based on KokoroChat paper evaluation methodology
4
+ """
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import numpy as np
9
+ from typing import List, Dict, Tuple, Optional, Any
10
+ import json
11
+ from tqdm import tqdm
12
+ import os
13
+ import gc
14
+ import warnings
15
+ from datetime import datetime
16
+ import pandas as pd
17
+ import matplotlib.pyplot as plt
18
+ import seaborn as sns
19
+ from collections import defaultdict
20
+ import MeCab
21
+ from rouge_score import rouge_scorer
22
+ from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
23
+ import sacrebleu
24
+ from bert_score import score as bert_score
25
+ import re
26
+ import statistics
27
+
28
+ warnings.filterwarnings('ignore')
29
+
30
+ # Set style for better visualizations
31
+ plt.style.use('seaborn-v0_8-darkgrid')
32
+ sns.set_palette("husl")
33
+
34
+ class JapaneseCounselingBenchmark:
35
+ """
36
+ Comprehensive benchmark suite for Japanese counseling models
37
+ Following KokoroChat paper evaluation methodology
38
+ """
39
+
40
+ def __init__(self,
41
+ base_model_name: str = "LiquidAI/LFM2-1.2B",
42
+ finetuned_model_path: str = "./merged_counselor_model",
43
+ test_data_path: str = "./processed_data_score70/test.jsonl",
44
+ device: str = None):
45
+ """
46
+ Initialize Japanese counseling benchmark
47
+
48
+ Args:
49
+ base_model_name: Name/path of base model
50
+ finetuned_model_path: Path to fine-tuned merged model
51
+ test_data_path: Path to test dataset
52
+ device: Device to run on (cuda/cpu)
53
+ """
54
+ self.base_model_name = base_model_name
55
+ self.finetuned_model_path = finetuned_model_path
56
+ self.test_data_path = test_data_path
57
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
58
+
59
+ print("="*80)
60
+ print("🎌 Japanese Counseling Model Benchmark Suite")
61
+ print("="*80)
62
+ print(f"📍 Device: {self.device}")
63
+ if self.device == "cuda":
64
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
65
+ print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
66
+
67
+ # Initialize MeCab for Japanese tokenization
68
+ try:
69
+ self.mecab = MeCab.Tagger("-Owakati") # Wakati-gaki mode for word segmentation
70
+ print("✅ MeCab initialized for Japanese tokenization")
71
+ except:
72
+ print("⚠️ MeCab not available. Install with: apt-get install mecab libmecab-dev mecab-ipadic-utf8")
73
+ print(" and: pip install mecab-python3")
74
+ print(" Using fallback character-level tokenization")
75
+ self.mecab = None
76
+
77
+ # Initialize ROUGE scorer (without lang parameter)
78
+ self.rouge_scorer = rouge_scorer.RougeScorer(
79
+ ['rouge1', 'rouge2', 'rougeL'],
80
+ use_stemmer=False # Don't use stemming for Japanese
81
+ )
82
+
83
+ # Smoothing function for BLEU
84
+ self.smoothing = SmoothingFunction().method1
85
+
86
+ # Results storage
87
+ self.results = {}
88
+ self.detailed_results = []
89
+
90
+ def tokenize_japanese(self, text: str) -> List[str]:
91
+ """
92
+ Tokenize Japanese text using MeCab or fallback method
93
+
94
+ Args:
95
+ text: Japanese text to tokenize
96
+
97
+ Returns:
98
+ List of tokens
99
+ """
100
+ if self.mecab:
101
+ try:
102
+ # Use MeCab for proper Japanese tokenization
103
+ tokens = self.mecab.parse(text).strip().split()
104
+ return tokens if tokens else list(text)
105
+ except:
106
+ # Fallback if MeCab fails
107
+ pass
108
+
109
+ # Fallback to character-level tokenization
110
+ # Remove punctuation and split
111
+ text = re.sub(r'[。、!?\n\s]', ' ', text)
112
+ # Split by spaces and then into characters
113
+ words = text.split()
114
+ if words:
115
+ # Try to keep some word boundaries
116
+ tokens = []
117
+ for word in words:
118
+ if len(word) <= 4: # Keep short words together
119
+ tokens.append(word)
120
+ else: # Split longer words into characters
121
+ tokens.extend(list(word))
122
+ return tokens
123
+ else:
124
+ # Pure character-level tokenization
125
+ return list(text.replace(' ', ''))
126
+
127
+ def load_test_data(self, max_samples: Optional[int] = None) -> List[Dict]:
128
+ """
129
+ Load test dataset
130
+
131
+ Args:
132
+ max_samples: Maximum number of samples to load
133
+
134
+ Returns:
135
+ List of test examples
136
+ """
137
+ print(f"\n📚 Loading test data from {self.test_data_path}")
138
+
139
+ test_data = []
140
+
141
+ if not os.path.exists(self.test_data_path):
142
+ print(f"❌ Test data not found at {self.test_data_path}")
143
+ print(" Creating synthetic test data for demonstration...")
144
+ return self.create_synthetic_test_data()
145
+
146
+ with open(self.test_data_path, 'r', encoding='utf-8') as f:
147
+ for i, line in enumerate(f):
148
+ if max_samples and i >= max_samples:
149
+ break
150
+ try:
151
+ data = json.loads(line)
152
+
153
+ # Parse the text field to extract input and response
154
+ text = data.get('text', '')
155
+
156
+ # Extract input and reference response
157
+ if "### Input:" in text and "### Response:" in text:
158
+ parts = text.split("### Input:")
159
+ if len(parts) > 1:
160
+ input_part = parts[1].split("### Response:")[0].strip()
161
+ response_part = text.split("### Response:")[1].strip()
162
+
163
+ test_data.append({
164
+ 'input': input_part,
165
+ 'reference': response_part,
166
+ 'score': data.get('score', 0),
167
+ 'topic': data.get('topic', 'Unknown')
168
+ })
169
+ except Exception as e:
170
+ print(f"⚠️ Error parsing line {i}: {e}")
171
+ continue
172
+
173
+ if not test_data:
174
+ print("⚠️ No valid test data found. Creating synthetic data...")
175
+ return self.create_synthetic_test_data()
176
+
177
+ print(f"✅ Loaded {len(test_data)} test examples")
178
+ return test_data
179
+
180
+ def create_synthetic_test_data(self) -> List[Dict]:
181
+ """Create synthetic test data for demonstration"""
182
+ synthetic_data = [
183
+ {
184
+ 'input': '最近ストレスを感じています。',
185
+ 'reference': 'ストレスを感じているのですね。それは大変つらいことだと思います。どのような状況でストレスを感じることが多いですか?',
186
+ 'score': 75,
187
+ 'topic': 'ストレス'
188
+ },
189
+ {
190
+ 'input': '仕事がうまくいかなくて悩んでいます。',
191
+ 'reference': '仕事でお悩みなのですね。うまくいかないと感じると、本当に辛いですよね。具体的にどのような点で困難を感じていらっしゃいますか?',
192
+ 'score': 78,
193
+ 'topic': '仕事'
194
+ },
195
+ {
196
+ 'input': '人間関係で困っています。',
197
+ 'reference': '人間関係の悩みは本当に心が疲れますよね。お気持ちお察しします。どのような関係性でお困りでしょうか?',
198
+ 'score': 80,
199
+ 'topic': '人間関係'
200
+ },
201
+ {
202
+ 'input': '将来が不安です。',
203
+ 'reference': '将来への不安を抱えていらっしゃるのですね。先が見えない不安は、とても重く感じられることと思います。',
204
+ 'score': 72,
205
+ 'topic': '不安'
206
+ },
207
+ {
208
+ 'input': '自信が持てません。',
209
+ 'reference': '自信が持てないというお気持ち、よくわかります。多くの方が同じような悩みを抱えています。',
210
+ 'score': 76,
211
+ 'topic': '自信'
212
+ }
213
+ ]
214
+ return synthetic_data
215
+
216
+ def load_models(self):
217
+ """Load base and fine-tuned models"""
218
+ print("\n🤖 Loading models for benchmarking...")
219
+
220
+ # Load tokenizer
221
+ print(" Loading tokenizer...")
222
+ try:
223
+ self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
224
+ except:
225
+ print(" Using GPT2 tokenizer as fallback...")
226
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
227
+
228
+ if self.tokenizer.pad_token is None:
229
+ self.tokenizer.pad_token = self.tokenizer.eos_token
230
+
231
+ # Load base model
232
+ print(" Loading base model...")
233
+ try:
234
+ self.base_model = AutoModelForCausalLM.from_pretrained(
235
+ self.base_model_name,
236
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
237
+ device_map="auto" if self.device == "cuda" else None,
238
+ trust_remote_code=True,
239
+ low_cpu_mem_usage=True
240
+ )
241
+ except Exception as e:
242
+ print(f" ⚠️ Could not load base model {self.base_model_name}: {e}")
243
+ print(" Using GPT2 as fallback base model...")
244
+ self.base_model = AutoModelForCausalLM.from_pretrained(
245
+ "gpt2",
246
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
247
+ device_map="auto" if self.device == "cuda" else None
248
+ )
249
+ self.base_model.eval()
250
+
251
+ # Load fine-tuned model
252
+ print(f" Loading fine-tuned model from {self.finetuned_model_path}...")
253
+
254
+ # Check if model exists
255
+ if not os.path.exists(self.finetuned_model_path):
256
+ print(f" ⚠️ Fine-tuned model not found at {self.finetuned_model_path}")
257
+ print(" Using base model for both comparisons (for demonstration)")
258
+ self.finetuned_model = self.base_model
259
+ else:
260
+ try:
261
+ self.finetuned_model = AutoModelForCausalLM.from_pretrained(
262
+ self.finetuned_model_path,
263
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
264
+ device_map="auto" if self.device == "cuda" else None,
265
+ trust_remote_code=True,
266
+ low_cpu_mem_usage=True,
267
+ local_files_only=True
268
+ )
269
+ self.finetuned_model.eval()
270
+ except Exception as e:
271
+ print(f" ⚠️ Error loading fine-tuned model: {e}")
272
+ print(" Using base model for comparison")
273
+ self.finetuned_model = self.base_model
274
+
275
+ print("✅ Models loaded successfully!")
276
+
277
+ def generate_response(self, model, prompt: str, max_length: int = 150) -> str:
278
+ """
279
+ Generate response from model
280
+
281
+ Args:
282
+ model: Model to use for generation
283
+ prompt: Input prompt
284
+ max_length: Maximum length of generated response
285
+
286
+ Returns:
287
+ Generated response text
288
+ """
289
+ # Format prompt for counseling
290
+ formatted_prompt = f"""### Instruction:
291
+ あなたは思いやりのある心理カウンセラーです。
292
+ クライアントの感情を理解し、共感的で支援的な応答を提供してください。
293
+
294
+ ### Input:
295
+ {prompt}
296
+
297
+ ### Response:
298
+ """
299
+
300
+ # Tokenize input
301
+ inputs = self.tokenizer(
302
+ formatted_prompt,
303
+ return_tensors="pt",
304
+ truncation=True,
305
+ max_length=512
306
+ )
307
+
308
+ if self.device == "cuda":
309
+ inputs = {k: v.cuda() for k, v in inputs.items()}
310
+
311
+ # Generate response
312
+ try:
313
+ with torch.no_grad():
314
+ outputs = model.generate(
315
+ **inputs,
316
+ max_new_tokens=max_length,
317
+ temperature=0.7,
318
+ do_sample=True,
319
+ top_p=0.9,
320
+ repetition_penalty=1.1,
321
+ pad_token_id=self.tokenizer.pad_token_id,
322
+ eos_token_id=self.tokenizer.eos_token_id
323
+ )
324
+
325
+ # Decode response
326
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
327
+
328
+ # Extract only the generated response
329
+ if "### Response:" in full_response:
330
+ response = full_response.split("### Response:")[-1].strip()
331
+ else:
332
+ response = full_response[len(formatted_prompt):].strip()
333
+ except Exception as e:
334
+ print(f" ⚠️ Generation error: {e}")
335
+ response = "申し訳ございません。応答を生成できませんでした。"
336
+
337
+ return response
338
+
339
+ def calculate_bleu_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
340
+ """
341
+ Calculate BLEU scores using Japanese tokenization
342
+
343
+ Args:
344
+ reference: Reference text
345
+ hypothesis: Generated text
346
+
347
+ Returns:
348
+ Dictionary of BLEU scores
349
+ """
350
+ # Tokenize using MeCab or fallback
351
+ ref_tokens = self.tokenize_japanese(reference)
352
+ hyp_tokens = self.tokenize_japanese(hypothesis)
353
+
354
+ # Ensure we have tokens
355
+ if not ref_tokens:
356
+ ref_tokens = ['empty']
357
+ if not hyp_tokens:
358
+ hyp_tokens = ['empty']
359
+
360
+ # Calculate BLEU scores
361
+ scores = {}
362
+
363
+ try:
364
+ # BLEU-1 through BLEU-4
365
+ for n in range(1, 5):
366
+ weights = tuple([1/n] * n + [0] * (4-n))
367
+ score = sentence_bleu(
368
+ [ref_tokens],
369
+ hyp_tokens,
370
+ weights=weights,
371
+ smoothing_function=self.smoothing
372
+ )
373
+ scores[f'BLEU-{n}'] = score
374
+ except Exception as e:
375
+ print(f" ⚠️ BLEU calculation error: {e}")
376
+ for n in range(1, 5):
377
+ scores[f'BLEU-{n}'] = 0.0
378
+
379
+ return scores
380
+
381
+ def calculate_rouge_scores(self, reference: str, hypothesis: str) -> Dict[str, float]:
382
+ """
383
+ Calculate ROUGE scores for Japanese text
384
+
385
+ Args:
386
+ reference: Reference text
387
+ hypothesis: Generated text
388
+
389
+ Returns:
390
+ Dictionary of ROUGE scores
391
+ """
392
+ try:
393
+ # For Japanese, we need to add spaces between tokens for ROUGE scorer
394
+ if self.mecab:
395
+ ref_tokenized = ' '.join(self.tokenize_japanese(reference))
396
+ hyp_tokenized = ' '.join(self.tokenize_japanese(hypothesis))
397
+ else:
398
+ # Character-level with spaces
399
+ ref_tokenized = ' '.join(list(reference))
400
+ hyp_tokenized = ' '.join(list(hypothesis))
401
+
402
+ # Calculate ROUGE scores
403
+ scores = self.rouge_scorer.score(ref_tokenized, hyp_tokenized)
404
+
405
+ return {
406
+ 'ROUGE-1': scores['rouge1'].fmeasure,
407
+ 'ROUGE-2': scores['rouge2'].fmeasure,
408
+ 'ROUGE-L': scores['rougeL'].fmeasure
409
+ }
410
+ except Exception as e:
411
+ print(f" ⚠️ ROUGE calculation error: {e}")
412
+ return {
413
+ 'ROUGE-1': 0.0,
414
+ 'ROUGE-2': 0.0,
415
+ 'ROUGE-L': 0.0
416
+ }
417
+
418
+ def calculate_bert_score(self, references: List[str], hypotheses: List[str]) -> Dict[str, float]:
419
+ """
420
+ Calculate BERTScore for semantic similarity
421
+
422
+ Args:
423
+ references: List of reference texts
424
+ hypotheses: List of generated texts
425
+
426
+ Returns:
427
+ Dictionary with BERTScore metrics
428
+ """
429
+ try:
430
+ # Calculate BERTScore
431
+ P, R, F1 = bert_score(
432
+ hypotheses,
433
+ references,
434
+ lang='ja',
435
+ verbose=False,
436
+ device=self.device
437
+ )
438
+
439
+ return {
440
+ 'BERTScore_P': float(P.mean()),
441
+ 'BERTScore_R': float(R.mean()),
442
+ 'BERTScore_F1': float(F1.mean())
443
+ }
444
+ except Exception as e:
445
+ print(f" ⚠️ BERTScore calculation failed: {e}")
446
+ print(" Install with: pip install bert-score")
447
+ return {
448
+ 'BERTScore_P': 0.0,
449
+ 'BERTScore_R': 0.0,
450
+ 'BERTScore_F1': 0.0
451
+ }
452
+
453
+ def evaluate_counseling_quality(self, response: str) -> Dict[str, float]:
454
+ """
455
+ Evaluate counseling-specific qualities
456
+ Based on KokoroChat paper evaluation criteria
457
+
458
+ Args:
459
+ response: Generated counseling response
460
+
461
+ Returns:
462
+ Dictionary of counseling quality scores
463
+ """
464
+ scores = {}
465
+
466
+ # 1. Empathy Score (共感度)
467
+ empathy_keywords = [
468
+ 'わかります', '理解', '共感', 'お気持ち', 'つらい',
469
+ '大変', 'お察し', 'そうですね', 'なるほど', '感じ'
470
+ ]
471
+ empathy_score = sum(1 for keyword in empathy_keywords if keyword in response)
472
+ scores['empathy'] = min(empathy_score / 5.0, 1.0) # Normalize to 0-1
473
+
474
+ # 2. Support Score (支援度)
475
+ support_keywords = [
476
+ 'サポート', '支援', '助け', '一緒に', '協力',
477
+ '応援', 'お手伝い', '力になり', '相談', '話を聞'
478
+ ]
479
+ support_score = sum(1 for keyword in support_keywords if keyword in response)
480
+ scores['support'] = min(support_score / 5.0, 1.0)
481
+
482
+ # 3. Active Listening (傾聴)
483
+ listening_indicators = ['?', 'でしょうか', 'ですか', 'いかがですか', 'どのような']
484
+ scores['active_listening'] = 1.0 if any(ind in response for ind in listening_indicators) else 0.3
485
+
486
+ # 4. Positivity (前向きさ)
487
+ positive_keywords = ['大丈夫', '良い', '素晴らしい', '頑張', '希望', '改善', '解決']
488
+ positive_score = sum(1 for keyword in positive_keywords if keyword in response)
489
+ scores['positivity'] = min(positive_score / 3.0, 1.0)
490
+
491
+ # 5. Response Appropriateness (応答の適切さ)
492
+ response_length = len(response)
493
+ if 30 <= response_length <= 200:
494
+ scores['appropriateness'] = 1.0
495
+ elif 20 <= response_length < 30 or 200 < response_length <= 300:
496
+ scores['appropriateness'] = 0.7
497
+ else:
498
+ scores['appropriateness'] = 0.4
499
+
500
+ return scores
501
+
502
+ def run_comprehensive_benchmark(self, num_samples: Optional[int] = None):
503
+ """
504
+ Run comprehensive benchmark evaluation
505
+
506
+ Args:
507
+ num_samples: Number of samples to evaluate (None for all)
508
+ """
509
+ print("\n" + "="*80)
510
+ print("🚀 Running Comprehensive Benchmark")
511
+ print("="*80)
512
+
513
+ # Load test data
514
+ test_data = self.load_test_data(max_samples=num_samples)
515
+
516
+ if not test_data:
517
+ raise ValueError("No test data available!")
518
+
519
+ # Initialize metric collectors
520
+ base_metrics = defaultdict(list)
521
+ finetuned_metrics = defaultdict(list)
522
+
523
+ # Collect all responses for BERTScore
524
+ all_references = []
525
+ all_base_responses = []
526
+ all_finetuned_responses = []
527
+
528
+ print(f"\n📊 Evaluating {len(test_data)} test examples...")
529
+ print("-"*80)
530
+
531
+ # Process each test example
532
+ for i, example in enumerate(tqdm(test_data, desc="Evaluating")):
533
+ input_text = example['input']
534
+ reference = example['reference']
535
+
536
+ # Generate responses
537
+ base_response = self.generate_response(self.base_model, input_text)
538
+ finetuned_response = self.generate_response(self.finetuned_model, input_text)
539
+
540
+ # Collect for BERTScore
541
+ all_references.append(reference)
542
+ all_base_responses.append(base_response)
543
+ all_finetuned_responses.append(finetuned_response)
544
+
545
+ # Calculate BLEU scores
546
+ base_bleu = self.calculate_bleu_scores(reference, base_response)
547
+ finetuned_bleu = self.calculate_bleu_scores(reference, finetuned_response)
548
+
549
+ for key, value in base_bleu.items():
550
+ base_metrics[key].append(value)
551
+ for key, value in finetuned_bleu.items():
552
+ finetuned_metrics[key].append(value)
553
+
554
+ # Calculate ROUGE scores
555
+ base_rouge = self.calculate_rouge_scores(reference, base_response)
556
+ finetuned_rouge = self.calculate_rouge_scores(reference, finetuned_response)
557
+
558
+ for key, value in base_rouge.items():
559
+ base_metrics[key].append(value)
560
+ for key, value in finetuned_rouge.items():
561
+ finetuned_metrics[key].append(value)
562
+
563
+ # Evaluate counseling quality
564
+ base_quality = self.evaluate_counseling_quality(base_response)
565
+ finetuned_quality = self.evaluate_counseling_quality(finetuned_response)
566
+
567
+ for key, value in base_quality.items():
568
+ base_metrics[f'quality_{key}'].append(value)
569
+ for key, value in finetuned_quality.items():
570
+ finetuned_metrics[f'quality_{key}'].append(value)
571
+
572
+ # Store detailed results
573
+ self.detailed_results.append({
574
+ 'input': input_text,
575
+ 'reference': reference,
576
+ 'base_response': base_response,
577
+ 'finetuned_response': finetuned_response,
578
+ 'base_metrics': {**base_bleu, **base_rouge, **base_quality},
579
+ 'finetuned_metrics': {**finetuned_bleu, **finetuned_rouge, **finetuned_quality}
580
+ })
581
+
582
+ # Show sample outputs
583
+ if i < 3:
584
+ print(f"\n📝 Example {i+1}:")
585
+ print(f"Input: {input_text[:100]}...")
586
+ print(f"Base BLEU-4: {base_bleu['BLEU-4']:.3f}, Fine-tuned BLEU-4: {finetuned_bleu['BLEU-4']:.3f}")
587
+
588
+ # Calculate BERTScore for all examples
589
+ if len(all_references) > 0:
590
+ print("\n🧮 Calculating BERTScore...")
591
+ base_bert = self.calculate_bert_score(all_references, all_base_responses)
592
+ finetuned_bert = self.calculate_bert_score(all_references, all_finetuned_responses)
593
+
594
+ for key, value in base_bert.items():
595
+ base_metrics[key] = [value] * len(test_data)
596
+ for key, value in finetuned_bert.items():
597
+ finetuned_metrics[key] = [value] * len(test_data)
598
+
599
+ # Calculate aggregate statistics
600
+ self.results = self.calculate_aggregate_statistics(base_metrics, finetuned_metrics)
601
+
602
+ # Print results
603
+ self.print_results()
604
+
605
+ return self.results
606
+
607
+ def calculate_aggregate_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict:
608
+ """
609
+ Calculate aggregate statistics from collected metrics
610
+
611
+ Args:
612
+ base_metrics: Base model metrics
613
+ finetuned_metrics: Fine-tuned model metrics
614
+
615
+ Returns:
616
+ Dictionary of aggregate results
617
+ """
618
+ results = {
619
+ 'metrics': {},
620
+ 'improvements': {},
621
+ 'summary': {}
622
+ }
623
+
624
+ # Calculate statistics for each metric
625
+ all_metric_names = set(base_metrics.keys()) | set(finetuned_metrics.keys())
626
+
627
+ for metric in all_metric_names:
628
+ base_values = base_metrics.get(metric, [0])
629
+ finetuned_values = finetuned_metrics.get(metric, [0])
630
+
631
+ results['metrics'][metric] = {
632
+ 'base': {
633
+ 'mean': float(np.mean(base_values)),
634
+ 'std': float(np.std(base_values)),
635
+ 'min': float(np.min(base_values)),
636
+ 'max': float(np.max(base_values))
637
+ },
638
+ 'finetuned': {
639
+ 'mean': float(np.mean(finetuned_values)),
640
+ 'std': float(np.std(finetuned_values)),
641
+ 'min': float(np.min(finetuned_values)),
642
+ 'max': float(np.max(finetuned_values))
643
+ }
644
+ }
645
+
646
+ # Calculate improvement
647
+ base_mean = np.mean(base_values)
648
+ finetuned_mean = np.mean(finetuned_values)
649
+ if base_mean > 0:
650
+ improvement = ((finetuned_mean - base_mean) / base_mean) * 100
651
+ else:
652
+ improvement = 0
653
+
654
+ results['improvements'][metric] = improvement
655
+
656
+ # Calculate summary statistics
657
+ bleu_metrics = [m for m in results['metrics'] if 'BLEU' in m]
658
+ rouge_metrics = [m for m in results['metrics'] if 'ROUGE' in m]
659
+ quality_metrics = [m for m in results['metrics'] if 'quality' in m]
660
+
661
+ # Average improvements
662
+ results['summary'] = {
663
+ 'bleu_avg_improvement': np.mean([results['improvements'][m] for m in bleu_metrics]) if bleu_metrics else 0,
664
+ 'rouge_avg_improvement': np.mean([results['improvements'][m] for m in rouge_metrics]) if rouge_metrics else 0,
665
+ 'quality_avg_improvement': np.mean([results['improvements'][m] for m in quality_metrics]) if quality_metrics else 0,
666
+ 'overall_improvement': np.mean(list(results['improvements'].values())) if results['improvements'] else 0
667
+ }
668
+
669
+ return results
670
+
671
+ def print_results(self):
672
+ """Print formatted benchmark results"""
673
+ print("\n" + "="*80)
674
+ print("📊 BENCHMARK RESULTS")
675
+ print("="*80)
676
+
677
+ # Group metrics by category
678
+ bleu_metrics = sorted([m for m in self.results['metrics'] if 'BLEU' in m])
679
+ rouge_metrics = sorted([m for m in self.results['metrics'] if 'ROUGE' in m])
680
+ bert_metrics = sorted([m for m in self.results['metrics'] if 'BERT' in m])
681
+ quality_metrics = sorted([m for m in self.results['metrics'] if 'quality' in m])
682
+
683
+ # Print BLEU scores
684
+ if bleu_metrics:
685
+ print("\n📘 BLEU Scores:")
686
+ print("-"*60)
687
+ print(f"{'Metric':<15} {'Base Model':<20} {'Fine-tuned':<20} {'Improvement':<15}")
688
+ print("-"*60)
689
+ for metric in bleu_metrics:
690
+ base = self.results['metrics'][metric]['base']['mean']
691
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
692
+ improvement = self.results['improvements'][metric]
693
+ print(f"{metric:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
694
+ f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
695
+ f"{improvement:+.1f}%")
696
+
697
+ # Print ROUGE scores
698
+ if rouge_metrics:
699
+ print("\n📕 ROUGE Scores:")
700
+ print("-"*60)
701
+ for metric in rouge_metrics:
702
+ base = self.results['metrics'][metric]['base']['mean']
703
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
704
+ improvement = self.results['improvements'][metric]
705
+ print(f"{metric:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
706
+ f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
707
+ f"{improvement:+.1f}%")
708
+
709
+ # Print BERTScore
710
+ if bert_metrics:
711
+ print("\n📗 BERTScore:")
712
+ print("-"*60)
713
+ for metric in bert_metrics:
714
+ base = self.results['metrics'][metric]['base']['mean']
715
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
716
+ improvement = self.results['improvements'][metric]
717
+ print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
718
+
719
+ # Print Counseling Quality scores
720
+ if quality_metrics:
721
+ print("\n💬 Counseling Quality Metrics:")
722
+ print("-"*60)
723
+ for metric in quality_metrics:
724
+ base = self.results['metrics'][metric]['base']['mean']
725
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
726
+ improvement = self.results['improvements'][metric]
727
+ metric_name = metric.replace('quality_', '').capitalize()
728
+ print(f"{metric_name:<15} {base:.4f}±{self.results['metrics'][metric]['base']['std']:.3f} "
729
+ f"{finetuned:.4f}±{self.results['metrics'][metric]['finetuned']['std']:.3f} "
730
+ f"{improvement:+.1f}%")
731
+
732
+ # Print summary
733
+ print("\n" + "="*80)
734
+ print("📈 SUMMARY")
735
+ print("="*80)
736
+ print(f"Average BLEU Improvement: {self.results['summary']['bleu_avg_improvement']:+.1f}%")
737
+ print(f"Average ROUGE Improvement: {self.results['summary']['rouge_avg_improvement']:+.1f}%")
738
+ print(f"Average Quality Improvement: {self.results['summary']['quality_avg_improvement']:+.1f}%")
739
+ print(f"Overall Improvement: {self.results['summary']['overall_improvement']:+.1f}%")
740
+ print("="*80)
741
+
742
+ def save_results(self, output_dir: str = "./benchmark_results"):
743
+ """Save all benchmark results"""
744
+ os.makedirs(output_dir, exist_ok=True)
745
+
746
+ # Save detailed results
747
+ with open(os.path.join(output_dir, "detailed_results.json"), 'w', encoding='utf-8') as f:
748
+ json.dump(self.detailed_results, f, ensure_ascii=False, indent=2, default=str)
749
+
750
+ # Save aggregate results
751
+ with open(os.path.join(output_dir, "aggregate_results.json"), 'w', encoding='utf-8') as f:
752
+ json.dump(self.results, f, ensure_ascii=False, indent=2, default=str)
753
+
754
+ print(f"✅ Results saved to {output_dir}/")
755
+
756
+
757
+ def main():
758
+ """Main execution function"""
759
+ import argparse
760
+
761
+ parser = argparse.ArgumentParser(description='Japanese Counseling Model Benchmark')
762
+ parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B',
763
+ help='Base model name or path')
764
+ parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model',
765
+ help='Path to fine-tuned merged model')
766
+ parser.add_argument('--test_data', type=str, default='./processed_data_score70/test.jsonl',
767
+ help='Path to test data')
768
+ parser.add_argument('--num_samples', type=int, default=None,
769
+ help='Number of samples to evaluate (None for all)')
770
+ parser.add_argument('--output_dir', type=str, default='./benchmark_results',
771
+ help='Directory to save results')
772
+
773
+ args = parser.parse_args()
774
+
775
+ try:
776
+ # Initialize benchmark
777
+ print("🎌 Initializing Japanese Counseling Benchmark Suite")
778
+ benchmark = JapaneseCounselingBenchmark(
779
+ base_model_name=args.base_model,
780
+ finetuned_model_path=args.finetuned_model,
781
+ test_data_path=args.test_data
782
+ )
783
+
784
+ # Load models
785
+ benchmark.load_models()
786
+
787
+ # Run benchmark
788
+ results = benchmark.run_comprehensive_benchmark(num_samples=args.num_samples)
789
+
790
+ # Save results
791
+ benchmark.save_results(args.output_dir)
792
+
793
+ print("\n✅ Benchmark completed successfully!")
794
+ print(f"📁 Results saved to {args.output_dir}/")
795
+
796
+ except Exception as e:
797
+ print(f"\n❌ Error during benchmarking: {e}")
798
+ import traceback
799
+ traceback.print_exc()
800
+
801
+
802
+ if __name__ == "__main__":
803
+ main()
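For reference, benchmark_v1.py can also be driven programmatically rather than through its argparse CLI. A minimal sketch follows; the module import name and the sample count are assumptions for illustration, while the class and method names are the ones defined above:

# Programmatic usage sketch for the benchmark class defined above (paths are the script defaults).
from benchmark_v1 import JapaneseCounselingBenchmark  # assumes the file is importable under this name

benchmark = JapaneseCounselingBenchmark(
    base_model_name="LiquidAI/LFM2-1.2B",
    finetuned_model_path="./merged_counselor_model",
    test_data_path="./processed_data_score70/test.jsonl",
)
benchmark.load_models()
results = benchmark.run_comprehensive_benchmark(num_samples=50)  # None evaluates the full test set
benchmark.save_results("./benchmark_results")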
benchmarking_v2.py ADDED
@@ -0,0 +1,782 @@
1
+ """
2
+ Fixed Optimized Japanese Counseling Model Benchmark with proper DataParallel handling
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.parallel import DataParallel
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ import numpy as np
11
+ from typing import List, Dict, Tuple, Optional, Any
12
+ import json
13
+ from tqdm import tqdm
14
+ import os
15
+ import gc
16
+ import warnings
17
+ from datetime import datetime
18
+ import pandas as pd
19
+ from collections import defaultdict
20
+ import MeCab
21
+ from rouge_score import rouge_scorer
22
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
23
+ import re
24
+ import wandb
25
+ from concurrent.futures import ThreadPoolExecutor
26
+ import time
27
+
28
+ # Suppress warnings
29
+ warnings.filterwarnings('ignore')
30
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
31
+
32
+ # Suppress Pydantic warnings
33
+ import logging
34
+ logging.getLogger('pydantic').setLevel(logging.ERROR)
35
+
36
+ class TestDataset(Dataset):
37
+ """Custom dataset for efficient batch processing"""
38
+
39
+ def __init__(self, data: List[Dict]):
40
+ self.data = data
41
+
42
+ def __len__(self):
43
+ return len(self.data)
44
+
45
+ def __getitem__(self, idx):
46
+ return self.data[idx]
47
+
48
+ def custom_collate_fn(batch):
49
+ """Custom collate function to handle dictionary data properly"""
50
+ return batch
51
+
52
+ class OptimizedJapaneseBenchmark:
53
+ """
54
+ Highly optimized benchmark suite with multi-GPU support and WandB logging
55
+ """
56
+
57
+ def __init__(self,
58
+ base_model_name: str = "LiquidAI/LFM2-1.2B",
59
+ finetuned_model_path: str = "./merged_counselor_model",
60
+ test_data_path: str = "./processed_data_score80/test.jsonl",
61
+ batch_size: int = 16, # Reduced for stability
62
+ num_workers: int = 0,
63
+ use_wandb: bool = True):
64
+ """
65
+ Initialize optimized benchmark with multi-GPU support
66
+ """
67
+ self.base_model_name = base_model_name
68
+ self.finetuned_model_path = finetuned_model_path
69
+ self.test_data_path = test_data_path
70
+ self.batch_size = batch_size
71
+ self.num_workers = num_workers
72
+
73
+ # Setup devices
74
+ self.setup_devices()
75
+
76
+ # Initialize WandB
77
+ if use_wandb:
78
+ self.init_wandb()
79
+ else:
80
+ self.wandb_enabled = False
81
+
82
+ # Initialize tokenizers and scorers
83
+ self.setup_tokenizers_and_scorers()
84
+
85
+ # Results storage
86
+ self.results = {}
87
+ self.detailed_results = []
88
+
89
+ def setup_devices(self):
90
+ """Setup multi-GPU configuration"""
91
+ if torch.cuda.is_available():
92
+ self.num_gpus = torch.cuda.device_count()
93
+ print(f"🚀 Found {self.num_gpus} GPUs")
94
+
95
+ self.device_ids = list(range(self.num_gpus))
96
+ self.device = torch.device("cuda:0")
97
+
98
+ for i in range(self.num_gpus):
99
+ print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
100
+ print(f" Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
101
+ else:
102
+ self.num_gpus = 0
103
+ self.device = torch.device("cpu")
104
+ print("⚠️ No GPU found, using CPU")
105
+
106
+ def init_wandb(self):
107
+ """Initialize WandB for experiment tracking"""
108
+ try:
109
+ run_name = f"benchmark-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
110
+
111
+ wandb.init(
112
+ project="japanese-counseling-benchmark",
113
+ name=run_name,
114
+ config={
115
+ "base_model": self.base_model_name,
116
+ "finetuned_model": self.finetuned_model_path,
117
+ "batch_size": self.batch_size,
118
+ "num_gpus": self.num_gpus,
119
+ "timestamp": datetime.now().isoformat()
120
+ },
121
+ tags=["benchmark", "japanese", "counseling", "multi-gpu"]
122
+ )
123
+
124
+ self.wandb_enabled = True
125
+ print(f"✅ WandB initialized: {wandb.run.name}")
126
+ print(f"📊 View at: {wandb.run.get_url()}")
127
+ except Exception as e:
128
+ print(f"⚠️ WandB initialization failed: {e}")
129
+ self.wandb_enabled = False
130
+
131
+ def setup_tokenizers_and_scorers(self):
132
+ """Setup tokenizers and scoring functions"""
133
+ # Initialize MeCab for Japanese tokenization
134
+ try:
135
+ self.mecab = MeCab.Tagger("-Owakati")
136
+ print("✅ MeCab initialized")
137
+ except:
138
+ print("⚠️ MeCab not available, using character tokenization")
139
+ self.mecab = None
140
+
141
+ # Initialize ROUGE scorer
142
+ self.rouge_scorer = rouge_scorer.RougeScorer(
143
+ ['rouge1', 'rouge2', 'rougeL'],
144
+ use_stemmer=False
145
+ )
146
+
147
+ # BLEU smoothing
148
+ self.smoothing = SmoothingFunction().method1
149
+
150
+ def load_test_data_fast(self, max_samples: Optional[int] = None) -> List[Dict]:
151
+ """Fast loading of test data"""
152
+ print(f"\n📚 Loading test data from {self.test_data_path}")
153
+
154
+ test_data = []
155
+
156
+ if not os.path.exists(self.test_data_path):
157
+ print("⚠️ Test data not found, using synthetic data")
158
+ return self.create_synthetic_test_data()
159
+
160
+ try:
161
+ with open(self.test_data_path, 'r', encoding='utf-8') as f:
162
+ lines = f.readlines()
163
+
164
+ if max_samples:
165
+ lines = lines[:max_samples]
166
+
167
+ for line in tqdm(lines, desc="Loading data"):
168
+ try:
169
+ data = json.loads(line)
170
+ text = data.get('text', '')
171
+
172
+ if "### Input:" in text and "### Response:" in text:
173
+ input_part = text.split("### Input:")[1].split("### Response:")[0].strip()
174
+ response_part = text.split("### Response:")[1].strip()
175
+
176
+ test_data.append({
177
+ 'input': input_part,
178
+ 'reference': response_part,
179
+ 'score': data.get('score', 0),
180
+ 'topic': data.get('topic', 'Unknown')
181
+ })
182
+ except:
183
+ continue
184
+
185
+ except Exception as e:
186
+ print(f"Error loading data: {e}")
187
+ return self.create_synthetic_test_data()
188
+
189
+ if not test_data:
190
+ print("⚠️ No valid data found, using synthetic data")
191
+ return self.create_synthetic_test_data()
192
+
193
+ print(f"✅ Loaded {len(test_data)} test examples")
194
+
195
+ if self.wandb_enabled:
196
+ wandb.log({"test_data_size": len(test_data)})
197
+
198
+ return test_data
199
+
200
+ def create_synthetic_test_data(self) -> List[Dict]:
201
+ """Create synthetic test data"""
202
+ return [
203
+ {
204
+ 'input': f'ストレスを感じています。',
205
+ 'reference': f'お気持ちわかります。どのような状況でストレスを感じていますか?',
206
+ 'score': 75,
207
+ 'topic': 'stress'
208
+ }
209
+ for i in range(10)
210
+ ]
211
+
212
+ def load_models_optimized(self):
213
+ """Load models with optimization for multi-GPU"""
214
+ print("\n🤖 Loading models with optimization...")
215
+
216
+ # Load tokenizer
217
+ print(" Loading tokenizer...")
218
+ try:
219
+ self.tokenizer = AutoTokenizer.from_pretrained(
220
+ self.base_model_name,
221
+ use_fast=True
222
+ )
223
+ except:
224
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
225
+
226
+ if self.tokenizer.pad_token is None:
227
+ self.tokenizer.pad_token = self.tokenizer.eos_token
228
+
229
+ # Load base model
230
+ print(" Loading base model...")
231
+ try:
232
+ base_model = AutoModelForCausalLM.from_pretrained(
233
+ self.base_model_name,
234
+ torch_dtype=torch.float16,
235
+ trust_remote_code=True,
236
+ low_cpu_mem_usage=True
237
+ )
238
+ except Exception as e:
239
+ print(f" Error loading base model: {e}")
240
+ print(" Using GPT2 as fallback...")
241
+ base_model = AutoModelForCausalLM.from_pretrained(
242
+ "gpt2",
243
+ torch_dtype=torch.float16
244
+ )
245
+
246
+ # Load fine-tuned model
247
+ print(" Loading fine-tuned model...")
248
+ if os.path.exists(self.finetuned_model_path):
249
+ try:
250
+ finetuned_model = AutoModelForCausalLM.from_pretrained(
251
+ self.finetuned_model_path,
252
+ torch_dtype=torch.float16,
253
+ trust_remote_code=True,
254
+ low_cpu_mem_usage=True,
255
+ local_files_only=True
256
+ )
257
+ except Exception as e:
258
+ print(f" Error loading fine-tuned model: {e}")
259
+ finetuned_model = base_model
260
+ else:
261
+ print(" Fine-tuned model not found, using base model")
262
+ finetuned_model = base_model
263
+
264
+ # Move models to GPU
265
+ base_model = base_model.to(self.device)
266
+ finetuned_model = finetuned_model.to(self.device)
267
+
268
+ # Setup for multi-GPU if available
269
+ if self.num_gpus > 1:
270
+ print(f" Setting up DataParallel for {self.num_gpus} GPUs...")
271
+ self.base_model = DataParallel(base_model, device_ids=self.device_ids)
272
+ self.finetuned_model = DataParallel(finetuned_model, device_ids=self.device_ids)
273
+ else:
274
+ self.base_model = base_model
275
+ self.finetuned_model = finetuned_model
276
+
277
+ self.base_model.eval()
278
+ self.finetuned_model.eval()
279
+
280
+ print("✅ Models loaded and optimized!")
281
+
282
+ if self.wandb_enabled:
283
+ wandb.log({
284
+ "model_loaded": True,
285
+ "num_gpus_used": self.num_gpus
286
+ })
287
+
288
+ def generate_batch_responses(self, model, prompts: List[str], max_length: int = 150) -> List[str]:
289
+ """Generate responses in batch for efficiency"""
290
+ if len(prompts) == 0:
291
+ return []
292
+
293
+ formatted_prompts = [
294
+ f"""### Instruction:
295
+ あなたは思いやりのある心理カウンセラーです。
296
+
297
+ ### Input:
298
+ {prompt}
299
+
300
+ ### Response:
301
+ """ for prompt in prompts
302
+ ]
303
+
304
+ try:
305
+ # Tokenize all prompts at once
306
+ inputs = self.tokenizer(
307
+ formatted_prompts,
308
+ return_tensors="pt",
309
+ truncation=True,
310
+ max_length=512,
311
+ padding=True,
312
+ padding_side='left'
313
+ )
314
+
315
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
316
+
317
+ # Get the actual model from DataParallel if needed
318
+ actual_model = model.module if isinstance(model, DataParallel) else model
319
+
320
+ # Generate in batch
321
+ with torch.no_grad():
322
+ with torch.cuda.amp.autocast():
323
+ outputs = actual_model.generate(
324
+ **inputs,
325
+ max_new_tokens=max_length,
326
+ temperature=0.7,
327
+ do_sample=True,
328
+ top_p=0.9,
329
+ num_beams=1,
330
+ pad_token_id=self.tokenizer.pad_token_id,
331
+ eos_token_id=self.tokenizer.eos_token_id
332
+ )
333
+
334
+ # Decode all at once
335
+ responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
336
+
337
+ # Extract only generated parts
338
+ extracted_responses = []
339
+ for i, response in enumerate(responses):
340
+ if "### Response:" in response:
341
+ extracted = response.split("### Response:")[-1].strip()
342
+ else:
343
+ extracted = response[len(formatted_prompts[i]):].strip()
344
+ extracted_responses.append(extracted if extracted else "応答を生成できませんでした。")
345
+
346
+ return extracted_responses
347
+
348
+ except Exception as e:
349
+ print(f"Error in batch generation: {e}")
350
+ # Return default responses
351
+ return ["申し訳ございません。応答を生成できませんでした。"] * len(prompts)
352
+
353
+ def tokenize_japanese(self, text: str) -> List[str]:
354
+ """Tokenize Japanese text"""
355
+ if not text:
356
+ return ['empty']
357
+
358
+ if self.mecab:
359
+ try:
360
+ tokens = self.mecab.parse(text).strip().split()
361
+ return tokens if tokens else list(text)
362
+ except:
363
+ pass
364
+
365
+ # Fallback to character tokenization
366
+ return list(text.replace(' ', ''))
367
+
368
+ def calculate_metrics_batch(self, references: List[str], hypotheses: List[str]) -> Dict:
369
+ """Calculate all metrics in batch"""
370
+ metrics = defaultdict(list)
371
+
372
+ for ref, hyp in zip(references, hypotheses):
373
+ if not ref or not hyp:
374
+ # Add default scores for empty strings
375
+ for n in range(1, 5):
376
+ metrics[f'BLEU-{n}'].append(0.0)
377
+ metrics['ROUGE-1'].append(0.0)
378
+ metrics['ROUGE-2'].append(0.0)
379
+ metrics['ROUGE-L'].append(0.0)
380
+ continue
381
+
382
+ try:
383
+ # Tokenize
384
+ ref_tokens = self.tokenize_japanese(ref)
385
+ hyp_tokens = self.tokenize_japanese(hyp)
386
+
387
+ # BLEU scores
388
+ for n in range(1, 5):
389
+ weights = tuple([1/n] * n + [0] * (4-n))
390
+ try:
391
+ score = sentence_bleu(
392
+ [ref_tokens],
393
+ hyp_tokens,
394
+ weights=weights,
395
+ smoothing_function=self.smoothing
396
+ )
397
+ metrics[f'BLEU-{n}'].append(score)
398
+ except:
399
+ metrics[f'BLEU-{n}'].append(0.0)
400
+
401
+ # ROUGE scores
402
+ try:
403
+ ref_spaced = ' '.join(ref_tokens)
404
+ hyp_spaced = ' '.join(hyp_tokens)
405
+ rouge_scores = self.rouge_scorer.score(ref_spaced, hyp_spaced)
406
+ metrics['ROUGE-1'].append(rouge_scores['rouge1'].fmeasure)
407
+ metrics['ROUGE-2'].append(rouge_scores['rouge2'].fmeasure)
408
+ metrics['ROUGE-L'].append(rouge_scores['rougeL'].fmeasure)
409
+ except:
410
+ metrics['ROUGE-1'].append(0.0)
411
+ metrics['ROUGE-2'].append(0.0)
412
+ metrics['ROUGE-L'].append(0.0)
413
+
414
+ except Exception as e:
415
+ # Add zeros for failed calculations
416
+ for n in range(1, 5):
417
+ metrics[f'BLEU-{n}'].append(0.0)
418
+ metrics['ROUGE-1'].append(0.0)
419
+ metrics['ROUGE-2'].append(0.0)
420
+ metrics['ROUGE-L'].append(0.0)
421
+
422
+ return dict(metrics)
423
+
424
+ def run_fast_benchmark(self, num_samples: Optional[int] = None):
425
+ """Run optimized benchmark with batch processing"""
426
+ print("\n" + "="*80)
427
+ print("🚀 Running Fast Multi-GPU Benchmark")
428
+ print("="*80)
429
+
430
+ start_time = time.time()
431
+
432
+ # Load test data
433
+ test_data = self.load_test_data_fast(max_samples=num_samples)
434
+
435
+ if not test_data:
436
+ raise ValueError("No test data available!")
437
+
438
+ # Create DataLoader
439
+ dataset = TestDataset(test_data)
440
+ dataloader = DataLoader(
441
+ dataset,
442
+ batch_size=self.batch_size,
443
+ shuffle=False,
444
+ num_workers=0,
445
+ collate_fn=custom_collate_fn,
446
+ pin_memory=True if self.device.type == 'cuda' else False
447
+ )
448
+
449
+ # Initialize metric collectors
450
+ all_base_metrics = defaultdict(list)
451
+ all_finetuned_metrics = defaultdict(list)
452
+
453
+ print(f"\n📊 Evaluating {len(test_data)} examples in {len(dataloader)} batches...")
454
+ print(f" Batch size: {self.batch_size}")
455
+ print(f" Using {self.num_gpus} GPU(s)")
456
+
457
+ # Process batches
458
+ successful_batches = 0
459
+ for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
460
+ try:
461
+ # Extract batch data
462
+ inputs = [item['input'] for item in batch]
463
+ references = [item['reference'] for item in batch]
464
+
465
+ # Generate responses in batch
466
+ base_responses = self.generate_batch_responses(self.base_model, inputs)
467
+ finetuned_responses = self.generate_batch_responses(self.finetuned_model, inputs)
468
+
469
+ # Calculate metrics in batch
470
+ base_metrics = self.calculate_metrics_batch(references, base_responses)
471
+ finetuned_metrics = self.calculate_metrics_batch(references, finetuned_responses)
472
+
473
+ # Aggregate metrics
474
+ for key, values in base_metrics.items():
475
+ all_base_metrics[key].extend(values)
476
+ for key, values in finetuned_metrics.items():
477
+ all_finetuned_metrics[key].extend(values)
478
+
479
+ successful_batches += 1
480
+
481
+ # Log progress to WandB
482
+ if self.wandb_enabled and batch_idx % 5 == 0:
483
+ progress = (batch_idx + 1) / len(dataloader) * 100
484
+
485
+ # Calculate current averages
486
+ current_bleu4_base = np.mean(all_base_metrics.get('BLEU-4', [0]))
487
+ current_bleu4_finetuned = np.mean(all_finetuned_metrics.get('BLEU-4', [0]))
488
+ current_rouge_l_base = np.mean(all_base_metrics.get('ROUGE-L', [0]))
489
+ current_rouge_l_finetuned = np.mean(all_finetuned_metrics.get('ROUGE-L', [0]))
490
+
491
+ wandb.log({
492
+ "progress": progress,
493
+ "batches_processed": batch_idx + 1,
494
+ "samples_processed": min((batch_idx + 1) * self.batch_size, len(test_data)),
495
+ "current_bleu4_base": current_bleu4_base,
496
+ "current_bleu4_finetuned": current_bleu4_finetuned,
497
+ "current_rouge_l_base": current_rouge_l_base,
498
+ "current_rouge_l_finetuned": current_rouge_l_finetuned
499
+ })
500
+
501
+ # Store examples for analysis
502
+ if batch_idx == 0 and len(inputs) > 0:
503
+ for i in range(min(3, len(inputs))):
504
+ self.detailed_results.append({
505
+ 'input': inputs[i],
506
+ 'reference': references[i],
507
+ 'base_response': base_responses[i] if i < len(base_responses) else "",
508
+ 'finetuned_response': finetuned_responses[i] if i < len(finetuned_responses) else ""
509
+ })
510
+
511
+ # Print sample
512
+ print(f"\n📝 Sample Example:")
513
+ print(f"Input: {inputs[0][:100]}...")
514
+ print(f"Reference: {references[0][:100]}...")
515
+ print(f"Base response: {base_responses[0][:100]}...")
516
+ print(f"Fine-tuned response: {finetuned_responses[0][:100]}...")
517
+
518
+ except Exception as e:
519
+ print(f"Error processing batch {batch_idx}: {e}")
520
+ continue
521
+
522
+ print(f"\n✅ Successfully processed {successful_batches}/{len(dataloader)} batches")
523
+
524
+ # Calculate final statistics
525
+ self.results = self.calculate_final_statistics(all_base_metrics, all_finetuned_metrics)
526
+
527
+ # Calculate processing time
528
+ total_time = time.time() - start_time
529
+ samples_per_second = len(test_data) / total_time if total_time > 0 else 0
530
+
531
+ print(f"\n⏱️ Benchmark completed in {total_time:.2f} seconds")
532
+ print(f" Processing speed: {samples_per_second:.2f} samples/second")
533
+
534
+ # Log final results to WandB
535
+ if self.wandb_enabled:
536
+ wandb.log({
537
+ "total_time_seconds": total_time,
538
+ "samples_per_second": samples_per_second,
539
+ "total_samples": len(test_data),
540
+ "successful_batches": successful_batches,
541
+ **{f"final_{k}": v for k, v in self.results['summary'].items()}
542
+ })
543
+
544
+ # Log detailed metrics
545
+ for metric_name, improvements in self.results['improvements'].items():
546
+ wandb.log({f"improvement_{metric_name}": improvements})
547
+
548
+ # Create visualization
549
+ if self.results['metrics']:
550
+ self.create_wandb_visualizations()
551
+
552
+ # Print results
553
+ self.print_results()
554
+
555
+ return self.results
556
+
557
+ def create_wandb_visualizations(self):
558
+ """Create WandB visualizations"""
559
+ if not self.wandb_enabled or not self.results.get('metrics'):
560
+ return
561
+
562
+ try:
563
+ # Create comparison table
564
+ data = []
565
+ for metric in self.results['metrics']:
566
+ data.append([
567
+ metric,
568
+ self.results['metrics'][metric]['base']['mean'],
569
+ self.results['metrics'][metric]['finetuned']['mean'],
570
+ self.results['improvements'][metric]
571
+ ])
572
+
573
+ table = wandb.Table(
574
+ columns=["Metric", "Base", "Fine-tuned", "Improvement (%)"],
575
+ data=data
576
+ )
577
+ wandb.log({"results_comparison": table})
578
+
579
+ # Log bar chart of improvements
580
+ wandb.log({
581
+ "improvements_chart": wandb.plot.bar(
582
+ wandb.Table(
583
+ data=[[m, self.results['improvements'][m]] for m in self.results['improvements']],
584
+ columns=["Metric", "Improvement (%)"]
585
+ ),
586
+ "Metric", "Improvement (%)",
587
+ title="Model Improvements"
588
+ )
589
+ })
590
+ except Exception as e:
591
+ print(f"Error creating visualizations: {e}")
592
+
593
+ def calculate_final_statistics(self, base_metrics: Dict, finetuned_metrics: Dict) -> Dict:
594
+ """Calculate final aggregate statistics"""
595
+ results = {
596
+ 'metrics': {},
597
+ 'improvements': {},
598
+ 'summary': {}
599
+ }
600
+
601
+ # Calculate statistics for each metric
602
+ all_metric_names = set(base_metrics.keys()) | set(finetuned_metrics.keys())
603
+
604
+ for metric in all_metric_names:
605
+ base_values = base_metrics.get(metric, [0])
606
+ finetuned_values = finetuned_metrics.get(metric, [0])
607
+
608
+ # Filter out any None values
609
+ base_values = [v for v in base_values if v is not None]
610
+ finetuned_values = [v for v in finetuned_values if v is not None]
611
+
612
+ if not base_values:
613
+ base_values = [0]
614
+ if not finetuned_values:
615
+ finetuned_values = [0]
616
+
617
+ results['metrics'][metric] = {
618
+ 'base': {
619
+ 'mean': float(np.mean(base_values)),
620
+ 'std': float(np.std(base_values)),
621
+ 'min': float(np.min(base_values)),
622
+ 'max': float(np.max(base_values))
623
+ },
624
+ 'finetuned': {
625
+ 'mean': float(np.mean(finetuned_values)),
626
+ 'std': float(np.std(finetuned_values)),
627
+ 'min': float(np.min(finetuned_values)),
628
+ 'max': float(np.max(finetuned_values))
629
+ }
630
+ }
631
+
632
+ # Calculate improvement
633
+ base_mean = np.mean(base_values)
634
+ finetuned_mean = np.mean(finetuned_values)
635
+ if base_mean > 0:
636
+ improvement = ((finetuned_mean - base_mean) / base_mean) * 100
637
+ else:
638
+ improvement = 0 if finetuned_mean == 0 else 100
639
+
640
+ results['improvements'][metric] = improvement
641
+
642
+ # Calculate summary statistics
643
+ bleu_metrics = [m for m in results['metrics'] if 'BLEU' in m]
644
+ rouge_metrics = [m for m in results['metrics'] if 'ROUGE' in m]
645
+
646
+ results['summary'] = {
647
+ 'bleu_avg_improvement': np.mean([results['improvements'][m] for m in bleu_metrics]) if bleu_metrics else 0,
648
+ 'rouge_avg_improvement': np.mean([results['improvements'][m] for m in rouge_metrics]) if rouge_metrics else 0,
649
+ 'overall_improvement': np.mean(list(results['improvements'].values())) if results['improvements'] else 0
650
+ }
651
+
652
+ return results
653
+
654
+ def print_results(self):
655
+ """Print formatted results"""
656
+ print("\n" + "="*80)
657
+ print("📊 BENCHMARK RESULTS")
658
+ print("="*80)
659
+
660
+ if not self.results or 'metrics' not in self.results:
661
+ print("No results to display")
662
+ return
663
+
664
+ # BLEU scores
665
+ print("\n📘 BLEU Scores:")
666
+ print("-"*60)
667
+ print(f"{'Metric':<15} {'Base':<15} {'Fine-tuned':<15} {'Improvement':<15}")
668
+ print("-"*60)
669
+
670
+ for metric in sorted([m for m in self.results['metrics'] if 'BLEU' in m]):
671
+ base = self.results['metrics'][metric]['base']['mean']
672
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
673
+ improvement = self.results['improvements'][metric]
674
+ print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
675
+
676
+ # ROUGE scores
677
+ print("\n📕 ROUGE Scores:")
678
+ print("-"*60)
679
+
680
+ for metric in sorted([m for m in self.results['metrics'] if 'ROUGE' in m]):
681
+ base = self.results['metrics'][metric]['base']['mean']
682
+ finetuned = self.results['metrics'][metric]['finetuned']['mean']
683
+ improvement = self.results['improvements'][metric]
684
+ print(f"{metric:<15} {base:.4f} {finetuned:.4f} {improvement:+.1f}%")
685
+
686
+ # Summary
687
+ print("\n" + "="*80)
688
+ print("📈 SUMMARY")
689
+ print("="*80)
690
+ print(f"BLEU Average Improvement: {self.results['summary']['bleu_avg_improvement']:+.1f}%")
691
+ print(f"ROUGE Average Improvement: {self.results['summary']['rouge_avg_improvement']:+.1f}%")
692
+ print(f"Overall Improvement: {self.results['summary']['overall_improvement']:+.1f}%")
693
+ print("="*80)
694
+
695
+ def save_results(self, output_dir: str = "./benchmark_results"):
696
+ """Save results"""
697
+ os.makedirs(output_dir, exist_ok=True)
698
+
699
+ # Save results
700
+ with open(os.path.join(output_dir, "results.json"), 'w', encoding='utf-8') as f:
701
+ json.dump(self.results, f, ensure_ascii=False, indent=2, default=str)
702
+
703
+ with open(os.path.join(output_dir, "examples.json"), 'w', encoding='utf-8') as f:
704
+ json.dump(self.detailed_results, f, ensure_ascii=False, indent=2)
705
+
706
+ # Save to WandB
707
+ if self.wandb_enabled:
708
+ try:
709
+ artifact = wandb.Artifact(
710
+ name=f"benchmark-results-{wandb.run.id}",
711
+ type="benchmark_results",
712
+ description="Japanese counseling model benchmark results"
713
+ )
714
+ artifact.add_dir(output_dir)
715
+ wandb.log_artifact(artifact)
716
+ except Exception as e:
717
+ print(f"Error saving to WandB: {e}")
718
+
719
+ print(f"✅ Results saved to {output_dir}/")
720
+
721
+ def cleanup(self):
722
+ """Clean up resources"""
723
+ if self.wandb_enabled:
724
+ wandb.finish()
725
+
726
+ if torch.cuda.is_available():
727
+ torch.cuda.empty_cache()
728
+
729
+ gc.collect()
730
+
731
+
732
+ def main():
733
+ """Main execution"""
734
+ import argparse
735
+
736
+ parser = argparse.ArgumentParser(description='Optimized Japanese Counseling Benchmark')
737
+ parser.add_argument('--base_model', type=str, default='LiquidAI/LFM2-1.2B')
738
+ parser.add_argument('--finetuned_model', type=str, default='./merged_counselor_model')
739
+ parser.add_argument('--test_data', type=str, default='./processed_data_score80/test.jsonl')
740
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size for processing')
741
+ parser.add_argument('--num_samples', type=int, default=None, help='Number of samples to evaluate')
742
+ parser.add_argument('--output_dir', type=str, default='./benchmark_results_fast')
743
+ parser.add_argument('--no_wandb', action='store_true', help='Disable WandB logging')
744
+
745
+ args = parser.parse_args()
746
+
747
+ try:
748
+ # Initialize benchmark
749
+ print("🚀 Initializing Optimized Benchmark Suite")
750
+ benchmark = OptimizedJapaneseBenchmark(
751
+ base_model_name=args.base_model,
752
+ finetuned_model_path=args.finetuned_model,
753
+ test_data_path=args.test_data,
754
+ batch_size=args.batch_size,
755
+ use_wandb=not args.no_wandb
756
+ )
757
+
758
+ # Load models
759
+ benchmark.load_models_optimized()
760
+
761
+ # Run benchmark
762
+ results = benchmark.run_fast_benchmark(num_samples=args.num_samples)
763
+
764
+ # Save results
765
+ benchmark.save_results(args.output_dir)
766
+
767
+ # Cleanup
768
+ benchmark.cleanup()
769
+
770
+ print("\n✅ Benchmark completed successfully!")
771
+
772
+ except Exception as e:
773
+ print(f"\n❌ Error: {e}")
774
+ import traceback
775
+ traceback.print_exc()
776
+
777
+ if 'benchmark' in locals():
778
+ benchmark.cleanup()
779
+
780
+
781
+ if __name__ == "__main__":
782
+ main()
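As in benchmark_v1.py, the n-gram weights for BLEU-1 through BLEU-4 are built inline with tuple([1/n] * n + [0] * (4-n)). A small standalone illustration of what that expression expands to:

# Illustration of the BLEU weight tuples used in calculate_metrics_batch above.
for n in range(1, 5):
    weights = tuple([1 / n] * n + [0] * (4 - n))
    print(f"BLEU-{n} weights: {weights}")
# Expected output:
#   BLEU-1 weights: (1.0, 0, 0, 0)
#   BLEU-2 weights: (0.5, 0.5, 0, 0)
#   BLEU-3 weights: (0.333..., 0.333..., 0.333..., 0)   (1/3 each for the first three)
#   BLEU-4 weights: (0.25, 0.25, 0.25, 0.25)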
chat.py ADDED
@@ -0,0 +1,339 @@
1
+ """
2
+ Interactive Chat Interface for Testing Fine-tuned Japanese Counseling Model
3
+ """
4
+
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ import os
8
+ import warnings
9
+ from datetime import datetime
10
+ import json
11
+
12
+ warnings.filterwarnings('ignore')
13
+
14
+ class CounselorChatInterface:
15
+ def __init__(self, model_path: str = "./merged_counselor_model"):
16
+ """
17
+ Initialize the chat interface with the fine-tuned model
18
+
19
+ Args:
20
+ model_path: Path to the fine-tuned model
21
+ """
22
+ self.model_path = model_path
23
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+ print("="*80)
26
+ print("🎌 Japanese Counseling Model Chat Interface")
27
+ print("="*80)
28
+ print(f"📍 Device: {self.device}")
29
+
30
+ if self.device.type == "cuda":
31
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
32
+ print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
33
+
34
+ self.load_model()
35
+ self.conversation_history = []
36
+
37
+ def load_model(self):
38
+ """Load the fine-tuned model and tokenizer"""
39
+ print(f"\n🤖 Loading model from {self.model_path}...")
40
+
41
+ try:
42
+ # Load tokenizer
43
+ self.tokenizer = AutoTokenizer.from_pretrained(
44
+ self.model_path,
45
+ local_files_only=True
46
+ )
47
+
48
+ # Set padding token if not set
49
+ if self.tokenizer.pad_token is None:
50
+ self.tokenizer.pad_token = self.tokenizer.eos_token
51
+
52
+ # Load model
53
+ self.model = AutoModelForCausalLM.from_pretrained(
54
+ self.model_path,
55
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
56
+ device_map="auto" if self.device.type == "cuda" else None,
57
+ local_files_only=True,
58
+ trust_remote_code=True
59
+ )
60
+
61
+ self.model.eval()
62
+ print("✅ Model loaded successfully!")
63
+
64
+ except Exception as e:
65
+ print(f"❌ Error loading model: {e}")
66
+ print("Trying alternative loading method...")
67
+
68
+ # Try loading with base tokenizer
69
+ try:
70
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
71
+ if self.tokenizer.pad_token is None:
72
+ self.tokenizer.pad_token = self.tokenizer.eos_token
73
+
74
+ self.model = AutoModelForCausalLM.from_pretrained(
75
+ self.model_path,
76
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
77
+ local_files_only=True
78
+ )
79
+ self.model = self.model.to(self.device)
80
+ self.model.eval()
81
+ print("✅ Model loaded with fallback tokenizer!")
82
+ except Exception as e2:
83
+ print(f"❌ Failed to load model: {e2}")
84
+ raise
85
+
86
+ def generate_response(self, user_input: str,
87
+ temperature: float = 0.7,
88
+ max_length: int = 200,
89
+ use_context: bool = True) -> str:
90
+ """
91
+ Generate a counseling response
92
+
93
+ Args:
94
+ user_input: User's message
95
+ temperature: Generation temperature (0.1-1.0)
96
+ max_length: Maximum response length
97
+ use_context: Whether to use conversation history
98
+
99
+ Returns:
100
+ Generated response
101
+ """
102
+ # Format the prompt
103
+ if use_context and len(self.conversation_history) > 0:
104
+ # Include recent context
105
+ context = "\n".join(self.conversation_history[-4:]) # Last 2 exchanges
106
+ prompt = f"""### Instruction:
107
+ あなたは思いやりのある心理カウンセラーです。
108
+ クライアントの感情を理解し、共感的で支援的な応答を提供してください。
109
+
110
+ ### Context:
111
+ {context}
112
+
113
+ ### Input:
114
+ {user_input}
115
+
116
+ ### Response:
117
+ """
118
+ else:
119
+ prompt = f"""### Instruction:
120
+ あなたは思いやりのある心理カウンセラーです。
121
+ クライアントの感情を理解し、共感的で支援的な応答を提供してください。
122
+
123
+ ### Input:
124
+ {user_input}
125
+
126
+ ### Response:
127
+ """
128
+
129
+ # Tokenize
130
+ inputs = self.tokenizer(
131
+ prompt,
132
+ return_tensors="pt",
133
+ truncation=True,
134
+ max_length=512
135
+ )
136
+
137
+ if self.device.type == "cuda":
138
+ inputs = {k: v.cuda() for k, v in inputs.items()}
139
+
140
+ # Generate
141
+ try:
142
+ with torch.no_grad():
143
+ with torch.autocast(self.device.type):  # mixed-precision context on CUDA or CPU
144
+ outputs = self.model.generate(
145
+ **inputs,
146
+ max_new_tokens=max_length,
147
+ temperature=temperature,
148
+ do_sample=True,
149
+ top_p=0.9,
150
+ top_k=50,
151
+ repetition_penalty=1.1,
152
+ pad_token_id=self.tokenizer.pad_token_id,
153
+ eos_token_id=self.tokenizer.eos_token_id
154
+ )
155
+
156
+ # Decode
157
+ full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
158
+
159
+ # Extract only the response part
160
+ if "### Response:" in full_response:
161
+ response = full_response.split("### Response:")[-1].strip()
162
+ else:
163
+ response = full_response[len(prompt):].strip()
164
+
165
+ return response
166
+
167
+ except Exception as e:
168
+ print(f"Error generating response: {e}")
169
+ return "申し訳ございません。応答の生成中にエラーが発生しました。"
170
+
171
+ def chat(self):
172
+ """Start interactive chat session"""
173
+ print("\n" + "="*80)
174
+ print("💬 チャットを開始します (Chat session started)")
175
+ print("="*80)
176
+ print("Commands:")
177
+ print(" /quit or /exit - 終了 (Exit)")
178
+ print(" /clear - 会話履歴をクリア (Clear conversation history)")
179
+ print(" /save - 会話を保存 (Save conversation)")
180
+ print(" /temp <value> - 温度パラメータを設定 (Set temperature, e.g., /temp 0.8)")
181
+ print(" /context on/off - コンテキスト使用の切り替え (Toggle context usage)")
182
+ print("-"*80)
183
+
184
+ temperature = 0.1
185
+ use_context = True
186
+
187
+ while True:
188
+ try:
189
+ # Get user input
190
+ user_input = input("\n👤 You: ").strip()
191
+
192
+ # Check for commands
193
+ if user_input.lower() in ['/quit', '/exit', '/q']:
194
+ print("\n👋 さようなら!(Goodbye!)")
195
+ break
196
+
197
+ elif user_input.lower() == '/clear':
198
+ self.conversation_history = []
199
+ print("✅ 会話履歴をクリアしました (Conversation history cleared)")
200
+ continue
201
+
202
+ elif user_input.lower() == '/save':
203
+ self.save_conversation()
204
+ continue
205
+
206
+ elif user_input.lower().startswith('/temp'):
207
+ try:
208
+ temperature = float(user_input.split()[1])
209
+ temperature = max(0.1, min(1.0, temperature))  # clamp to the supported 0.1-1.0 range
210
+ print(f"✅ Temperature set to {temperature}")
211
+ except (IndexError, ValueError):
212
+ print("❌ Invalid temperature. Use: /temp 0.7")
213
+ continue
214
+
215
+ elif user_input.lower().startswith('/context'):
216
+ try:
217
+ setting = user_input.split()[1].lower()
218
+ use_context = setting == 'on'
219
+ print(f"✅ Context {'enabled' if use_context else 'disabled'}")
220
+ except (IndexError, ValueError):
221
+ print("❌ Use: /context on or /context off")
222
+ continue
223
+
224
+ elif user_input.startswith('/'):
225
+ print("❌ Unknown command")
226
+ continue
227
+
228
+ # Generate response
229
+ print("\n🤖 Counselor: ", end="", flush=True)
230
+ response = self.generate_response(
231
+ user_input,
232
+ temperature=temperature,
233
+ use_context=use_context
234
+ )
235
+ print(response)
236
+
237
+ # Add to history
238
+ self.conversation_history.append(f"Client: {user_input}")
239
+ self.conversation_history.append(f"Counselor: {response}")
240
+
241
+ except KeyboardInterrupt:
242
+ print("\n\n👋 さようなら!(Goodbye!)")
243
+ break
244
+ except Exception as e:
245
+ print(f"\n❌ Error: {e}")
246
+ continue
247
+
248
+ def save_conversation(self):
249
+ """Save the conversation to a file"""
250
+ if not self.conversation_history:
251
+ print("❌ No conversation to save")
252
+ return
253
+
254
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
255
+ filename = f"conversation_{timestamp}.json"
256
+
257
+ conversation_data = {
258
+ "timestamp": timestamp,
259
+ "model_path": self.model_path,
260
+ "conversation": self.conversation_history
261
+ }
262
+
263
+ with open(filename, 'w', encoding='utf-8') as f:
264
+ json.dump(conversation_data, f, ensure_ascii=False, indent=2)
265
+
266
+ print(f"✅ Conversation saved to {filename}")
267
+
268
+ def test_responses(self):
269
+ """Test the model with predefined inputs"""
270
+ print("\n" + "="*80)
271
+ print("🧪 Testing Model Responses")
272
+ print("="*80)
273
+
274
+ test_inputs = [
275
+ "こんにちは。最近ストレスを感じています。",
276
+ "仕事がうまくいかなくて悩んでいます。",
277
+ "人間関係で困っています。どうすればいいでしょうか。",
278
+ "将来が不安で眠れません。",
279
+ "自分に自信が持てません。",
280
+ "家族との関係で悩んでいます。",
281
+ "毎日が辛いです。",
282
+ "誰にも相談できません。"
283
+ ]
284
+
285
+ print("\nTesting with different temperature settings:\n")
286
+
287
+ for temp in [0.1, 0.7]:
288
+ print(f"\n🌡️ Temperature: {temp}")
289
+ print("-"*60)
290
+
291
+ for i, test_input in enumerate(test_inputs[:3], 1):
292
+ print(f"\n{i}. Input: {test_input}")
293
+ response = self.generate_response(test_input, temperature=temp, use_context=False)
294
+ print(f" Response: {response[:200]}...")
295
+ print()
296
+
297
+ print("="*80)
298
+
299
+
300
+ def main():
301
+ """Main function"""
302
+ import argparse
303
+
304
+ parser = argparse.ArgumentParser(description='Chat with fine-tuned counseling model')
305
+ parser.add_argument('--model_path', type=str, default='./merged_counselor_mode_2b',
306
+ help='Path to the fine-tuned model')
307
+ parser.add_argument('--test_only', action='store_true',
308
+ help='Only run test responses without chat')
309
+
310
+ args = parser.parse_args()
311
+
312
+ # Check if model exists
313
+ if not os.path.exists(args.model_path):
314
+ print(f"❌ Model not found at {args.model_path}")
315
+ print("\nAvailable models:")
316
+ for item in os.listdir('.'):
317
+ if 'model' in item.lower() and os.path.isdir(item):
318
+ print(f" - {item}")
319
+ return
320
+
321
+ try:
322
+ # Initialize chat interface
323
+ chat = CounselorChatInterface(model_path=args.model_path)
324
+
325
+ if args.test_only:
326
+ # Run tests only
327
+ chat.test_responses()
328
+ else:
329
+ # Start interactive chat
330
+ chat.chat()
331
+
332
+ except Exception as e:
333
+ print(f"❌ Error: {e}")
334
+ import traceback
335
+ traceback.print_exc()
336
+
337
+
338
+ if __name__ == "__main__":
339
+ main()
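A minimal sketch of exercising CounselorChatInterface programmatically, outside the interactive loop, assuming chat.py is importable and a merged model exists at the placeholder path; generate_response is called exactly as defined above:

    # Sketch only: the model path is a placeholder.
    from chat import CounselorChatInterface

    chat = CounselorChatInterface(model_path="./merged_counselor_model")
    reply = chat.generate_response(
        "最近ストレスを感じています。",  # "I have been feeling stressed lately."
        temperature=0.7,
        use_context=False,
    )
    print(reply)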
data_preprocessor.py ADDED
@@ -0,0 +1,524 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ import pandas as pd
5
+ from typing import List, Dict, Tuple, Optional
6
+ import random
7
+ from tqdm import tqdm
8
+ import re
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
+ class KokoroChatPreprocessor:
13
+ def __init__(self, data_path: str, max_length: int = 2048, min_score: int = 60):
14
+ """
15
+ Initialize the preprocessor for KokoroChat dataset
16
+
17
+ Args:
18
+ data_path: Path to KokoroChat repository
19
+ max_length: Maximum sequence length for model input
20
+ min_score: Minimum score threshold for filtering conversations (default: 60)
21
+ """
22
+ self.data_path = Path(data_path)
23
+ self.max_length = max_length
24
+ self.min_score = min_score
25
+ self.conversations = []
26
+ self.score_distribution = [] # Track score distribution
27
+ self.system_prompt = """あなたは思いやりのある心理カウンセラーです。
28
+ クライアントの感情を理解し、共感的で支援的な応答を提供してください。
29
+ プライバシーを尊重し、判断を下さず、希望と実用的な洞察を提供することに焦点を当ててください。"""
30
+
31
+ def load_json_files(self) -> List[Dict]:
32
+ """Load all JSON files from the dataset"""
33
+ json_files = []
34
+ # Changed from "data" to "kokorochat_dialogues"
35
+ data_dir = self.data_path / "kokorochat_dialogues"
36
+
37
+ # Check if data directory exists, if not try root directory
38
+ if not data_dir.exists():
39
+ data_dir = self.data_path
40
+ print(f"Using root directory: {data_dir}")
41
+ else:
42
+ print(f"Using data directory: {data_dir}")
43
+
44
+ for root, dirs, files in os.walk(data_dir):
45
+ for file in tqdm(files, desc="Loading JSON files"):
46
+ if file.endswith('.json'):
47
+ file_path = os.path.join(root, file)
48
+ try:
49
+ with open(file_path, 'r', encoding='utf-8') as f:
50
+ data = json.load(f)
51
+ json_files.append(data)
52
+ except Exception as e:
53
+ print(f"Error loading {file_path}: {e}")
54
+
55
+ return json_files
56
+
57
+ def analyze_score_distribution(self, json_files: List[Dict]) -> Dict:
58
+ """
59
+ Analyze the distribution of scores in the dataset
60
+
61
+ Returns:
62
+ Dictionary with score statistics
63
+ """
64
+ scores = []
65
+ for data in json_files:
66
+ if 'review_by_client_jp' in data:
67
+ score = data['review_by_client_jp'].get('点数', 0)
68
+ if score > 0: # Only count valid scores
69
+ scores.append(score)
70
+ self.score_distribution.append(score)
71
+
72
+ if scores:
73
+ stats = {
74
+ 'total_conversations': len(json_files),
75
+ 'conversations_with_scores': len(scores),
76
+ 'mean_score': float(np.mean(scores)),
77
+ 'median_score': float(np.median(scores)),
78
+ 'std_score': float(np.std(scores)),
79
+ 'min_score': float(np.min(scores)),
80
+ 'max_score': float(np.max(scores)),
81
+ 'percentiles': {
82
+ '25th': float(np.percentile(scores, 25)),
83
+ '50th': float(np.percentile(scores, 50)),
84
+ '75th': float(np.percentile(scores, 75)),
85
+ '90th': float(np.percentile(scores, 90))
86
+ },
87
+ 'score_ranges': {
88
+ '0-30': int(sum(1 for s in scores if 0 <= s < 30)),
89
+ '30-50': int(sum(1 for s in scores if 30 <= s < 50)),
90
+ '50-60': int(sum(1 for s in scores if 50 <= s < 60)),
91
+ '60-70': int(sum(1 for s in scores if 60 <= s < 70)),
92
+ '70-80': int(sum(1 for s in scores if 70 <= s < 80)),
93
+ '80-90': int(sum(1 for s in scores if 80 <= s < 90)),
94
+ '90-100': int(sum(1 for s in scores if 90 <= s <= 100)),
95
+ }
96
+ }
97
+
98
+ # Calculate how many conversations would be kept at different thresholds
99
+ threshold_analysis = {}
100
+ for threshold in [30, 40, 50, 60, 65, 70, 75, 80]:
101
+ kept = sum(1 for s in scores if s >= threshold)
102
+ threshold_analysis[f'threshold_{threshold}'] = {
103
+ 'conversations_kept': kept,
104
+ 'percentage_kept': round((kept / len(scores)) * 100, 2)
105
+ }
106
+ stats['threshold_analysis'] = threshold_analysis
107
+
108
+ return stats
109
+ else:
110
+ return {'error': 'No valid scores found in dataset'}
111
+
112
+ def plot_score_distribution(self, save_path: str = "score_distribution.png"):
113
+ """
114
+ Plot the distribution of scores
115
+ """
116
+ if not self.score_distribution:
117
+ print("No scores to plot. Run analyze_score_distribution first.")
118
+ return
119
+
120
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
121
+
122
+ # Histogram
123
+ axes[0, 0].hist(self.score_distribution, bins=20, edgecolor='black', alpha=0.7)
124
+ axes[0, 0].axvline(self.min_score, color='red', linestyle='--',
125
+ label=f'Current threshold: {self.min_score}')
126
+ axes[0, 0].set_xlabel('Score')
127
+ axes[0, 0].set_ylabel('Frequency')
128
+ axes[0, 0].set_title('Score Distribution')
129
+ axes[0, 0].legend()
130
+ axes[0, 0].grid(True, alpha=0.3)
131
+
132
+ # Box plot
133
+ axes[0, 1].boxplot(self.score_distribution, vert=True)
134
+ axes[0, 1].set_ylabel('Score')
135
+ axes[0, 1].set_title('Score Box Plot')
136
+ axes[0, 1].grid(True, alpha=0.3)
137
+
138
+ # Cumulative distribution
139
+ sorted_scores = np.sort(self.score_distribution)
140
+ cumulative = np.arange(1, len(sorted_scores) + 1) / len(sorted_scores)
141
+ axes[1, 0].plot(sorted_scores, cumulative)
142
+ axes[1, 0].axvline(self.min_score, color='red', linestyle='--',
143
+ label=f'Current threshold: {self.min_score}')
144
+ axes[1, 0].set_xlabel('Score')
145
+ axes[1, 0].set_ylabel('Cumulative Probability')
146
+ axes[1, 0].set_title('Cumulative Distribution')
147
+ axes[1, 0].legend()
148
+ axes[1, 0].grid(True, alpha=0.3)
149
+
150
+ # Threshold impact analysis
151
+ thresholds = range(30, 90, 5)
152
+ kept_percentages = []
153
+ for t in thresholds:
154
+ kept = sum(1 for s in self.score_distribution if s >= t)
155
+ kept_percentages.append((kept / len(self.score_distribution)) * 100)
156
+
157
+ axes[1, 1].plot(thresholds, kept_percentages, marker='o')
158
+ axes[1, 1].axvline(self.min_score, color='red', linestyle='--',
159
+ label=f'Current threshold: {self.min_score}')
160
+ axes[1, 1].set_xlabel('Score Threshold')
161
+ axes[1, 1].set_ylabel('% of Conversations Kept')
162
+ axes[1, 1].set_title('Impact of Score Threshold')
163
+ axes[1, 1].legend()
164
+ axes[1, 1].grid(True, alpha=0.3)
165
+
166
+ plt.tight_layout()
167
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
168
+ plt.show()
169
+ print(f"Score distribution plot saved to {save_path}")
170
+
171
+ def extract_high_quality_conversations(self, data: Dict) -> List[Dict]:
172
+ """
173
+ Extract conversations with high counselor ratings based on min_score
174
+ Focus on conversations where counselor performed well
175
+ """
176
+ conversations = []
177
+
178
+ # Check if review exists and has good score
179
+ if 'review_by_client_jp' in data:
180
+ review = data['review_by_client_jp']
181
+ score = review.get('点数', 0)
182
+
183
+ # Use configurable min_score threshold
184
+ if score >= self.min_score:
185
+ dialogue = data.get('dialogue', [])
186
+
187
+ # Create conversation pairs
188
+ conversation_text = ""
189
+ for turn in dialogue:
190
+ role = turn['role']
191
+ utterance = turn['utterance']
192
+
193
+ if role == 'counselor':
194
+ conversation_text += f"カウンセラー: {utterance}\n"
195
+ else:
196
+ conversation_text += f"クライアント: {utterance}\n"
197
+
198
+ # Extract detailed metrics for potential weighted training
199
+ conversations.append({
200
+ 'text': conversation_text,
201
+ 'score': score, # Store the score here
202
+ 'topic': data.get('topic', {}).get('main_jp', 'Unknown'),
203
+ 'review_metrics': {
204
+ 'empathy': review.get('聴いてもらえた、わかってもらえたと感じた', 0),
205
+ 'respect': review.get('尊重されたと感じた', 0),
206
+ 'insights': review.get('新しい気づきや体験があった', 0),
207
+ 'hope': review.get('希望や期待を感じられた', 0),
208
+ 'concerns_addressed': review.get('取り組みたかったことを扱えた', 0),
209
+ 'collaboration': review.get('一緒に考えながら取り組めた', 0),
210
+ 'rhythm': review.get('やりとりのリズムがあっていた', 0),
211
+ 'comfort': review.get('居心地のよいやりとりだった', 0),
212
+ 'overall_appropriate': review.get('全体として適切でよかった', 0),
213
+ 'valuable': review.get('今回の相談は価値があった', 0),
214
+ 'smooth_start': review.get('相談開始の円滑さ', 0),
215
+ 'good_ending': review.get('相談終了のタイミング(不必要に聴きすぎていないか)、円滑さ', 0),
216
+ 'acceptance_empathy': review.get('受容·共感', 0),
217
+ 'affirmation': review.get('肯定·承認', 0),
218
+ 'effective_questions': review.get('的確な質問による会話の促進', 0),
219
+ 'summarization': review.get('要約', 0),
220
+ 'problem_clarification': review.get('問題の明確化', 0),
221
+ 'goal_clarification': review.get('この相談での目標の明確化', 0),
222
+ 'actionable_suggestions': review.get('次の行動につながる提案', 0),
223
+ 'encouragement': review.get('勇気づけ·希望の喚起', 0)
224
+ }
225
+ })
226
+
227
+ return conversations
228
+
229
+ def create_training_examples(self, conversations: List[Dict],
230
+ use_weighted_sampling: bool = False) -> List[Dict]:
231
+ """
232
+ Create training examples in instruction-following format
233
+
234
+ Args:
235
+ conversations: List of conversation dictionaries
236
+ use_weighted_sampling: If True, create more examples from higher-scored conversations
237
+ """
238
+ training_examples = []
239
+
240
+ for conv in tqdm(conversations, desc="Creating training examples"):
241
+ dialogue_lines = conv['text'].split('\n')
242
+ score = conv['score'] # Get score from the conversation dict
243
+
244
+ # Calculate sampling weight based on score if enabled
245
+ if use_weighted_sampling:
246
+ # Higher scores get more weight (normalized to 1-3 range)
247
+ weight = max(1, int((score - self.min_score) / 20) + 1)
248
+ else:
249
+ weight = 1
250
+
251
+ # Create multiple training examples from each conversation
252
+ for _ in range(weight): # Repeat based on weight
253
+ for i in range(0, len(dialogue_lines) - 1, 2):
254
+ if i + 1 < len(dialogue_lines):
255
+ client_line = dialogue_lines[i]
256
+ counselor_line = dialogue_lines[i + 1]
257
+
258
+ # Check if lines contain the expected prefixes
259
+ if 'クライアント:' in client_line and 'カウンセラー:' in counselor_line:
260
+ client_msg = client_line.replace('クライアント: ', '').replace('クライアント:', '').strip()
261
+ counselor_msg = counselor_line.replace('カウンセラー: ', '').replace('カウンセラー:', '').strip()
262
+
263
+ # Skip empty messages
264
+ if not client_msg or not counselor_msg:
265
+ continue
266
+
267
+ # Format for instruction tuning
268
+ example = {
269
+ 'instruction': self.system_prompt,
270
+ 'input': client_msg,
271
+ 'output': counselor_msg,
272
+ 'score': score, # Use the score from conversation
273
+ 'topic': conv['topic'],
274
+ 'metrics': conv['review_metrics'] # Include detailed metrics
275
+ }
276
+
277
+ training_examples.append(example)
278
+
279
+ return training_examples
280
+
281
+ def prepare_dataset(self, test_size: float = 0.1, val_size: float = 0.1,
282
+ use_weighted_sampling: bool = False,
283
+ analyze_scores: bool = True):
284
+ """
285
+ Prepare train, validation, and test datasets
286
+
287
+ Args:
288
+ test_size: Proportion of data for testing
289
+ val_size: Proportion of data for validation
290
+ use_weighted_sampling: If True, oversample high-quality conversations
291
+ analyze_scores: If True, print score distribution analysis
292
+ """
293
+ print("Loading KokoroChat dataset...")
294
+ json_files = self.load_json_files()
295
+ print(f"Loaded {len(json_files)} conversation files")
296
+
297
+ # Analyze score distribution if requested
298
+ if analyze_scores:
299
+ print("\n" + "="*60)
300
+ print("SCORE DISTRIBUTION ANALYSIS")
301
+ print("="*60)
302
+ stats = self.analyze_score_distribution(json_files)
303
+
304
+ if 'error' not in stats:
305
+ print(f"Total conversations: {stats['total_conversations']}")
306
+ print(f"Conversations with scores: {stats['conversations_with_scores']}")
307
+ print(f"\nScore Statistics:")
308
+ print(f" Mean: {stats['mean_score']:.2f}")
309
+ print(f" Median: {stats['median_score']:.2f}")
310
+ print(f" Std Dev: {stats['std_score']:.2f}")
311
+ print(f" Range: {stats['min_score']:.0f} - {stats['max_score']:.0f}")
312
+
313
+ print(f"\nScore Distribution:")
314
+ for range_name, count in stats['score_ranges'].items():
315
+ percentage = (count / stats['conversations_with_scores']) * 100
316
+ print(f" {range_name}: {count} ({percentage:.1f}%)")
317
+
318
+ print(f"\nThreshold Impact Analysis:")
319
+ for threshold_name, data in stats['threshold_analysis'].items():
320
+ threshold = threshold_name.split('_')[1]
321
+ print(f" Threshold >= {threshold}: {data['conversations_kept']} conversations ({data['percentage_kept']:.1f}%)")
322
+
323
+ print(f"\nCurrent threshold ({self.min_score}) will keep: ", end="")
324
+ kept = sum(1 for s in self.score_distribution if s >= self.min_score)
325
+ print(f"{kept} conversations ({(kept/len(self.score_distribution))*100:.1f}%)")
326
+ print("="*60 + "\n")
327
+
328
+ # Plot distribution
329
+ self.plot_score_distribution()
330
+
331
+ all_conversations = []
332
+ filtered_count = 0
333
+ total_count = 0
334
+
335
+ for data in json_files:
336
+ if 'review_by_client_jp' in data:
337
+ total_count += 1
338
+ score = data['review_by_client_jp'].get('点数', 0)
339
+ if score < self.min_score:
340
+ filtered_count += 1
341
+
342
+ conversations = self.extract_high_quality_conversations(data)
343
+ all_conversations.extend(conversations)
344
+
345
+ print(f"Filtered out {filtered_count} conversations with score < {self.min_score}")
346
+ print(f"Extracted {len(all_conversations)} high-quality conversations (score >= {self.min_score})")
347
+
348
+ # Create training examples
349
+ training_examples = self.create_training_examples(
350
+ all_conversations,
351
+ use_weighted_sampling=use_weighted_sampling
352
+ )
353
+ print(f"Created {len(training_examples)} training examples")
354
+
355
+ if use_weighted_sampling:
356
+ print("Note: Used weighted sampling - higher scored conversations appear more frequently")
357
+
358
+ # Shuffle and split
359
+ random.shuffle(training_examples)
360
+
361
+ total_size = len(training_examples)
362
+ test_split = int(total_size * test_size)
363
+ val_split = int(total_size * val_size)
364
+
365
+ test_data = training_examples[:test_split]
366
+ val_data = training_examples[test_split:test_split + val_split]
367
+ train_data = training_examples[test_split + val_split:]
368
+
369
+ print(f"\nDataset splits:")
370
+ print(f" Train: {len(train_data)} examples")
371
+ print(f" Validation: {len(val_data)} examples")
372
+ print(f" Test: {len(test_data)} examples")
373
+
374
+ return {
375
+ 'train': train_data,
376
+ 'validation': val_data,
377
+ 'test': test_data
378
+ }
379
+
380
+ def format_for_lfm(self, example: Dict) -> str:
381
+ """
382
+ Format example for LFM model training
383
+ """
384
+ formatted = f"""### Instruction:
385
+ {example['instruction']}
386
+
387
+ ### Input:
388
+ {example['input']}
389
+
390
+ ### Response:
391
+ {example['output']}"""
392
+ return formatted
393
+
394
+ def save_datasets(self, datasets: Dict, output_dir: str):
395
+ """Save processed datasets with proper type conversion for JSON serialization"""
396
+ output_path = Path(output_dir)
397
+ output_path.mkdir(parents=True, exist_ok=True)
398
+
399
+ # Helper function to convert numpy types to Python native types
400
+ def convert_to_native(obj):
401
+ if isinstance(obj, np.integer):
402
+ return int(obj)
403
+ elif isinstance(obj, np.floating):
404
+ return float(obj)
405
+ elif isinstance(obj, np.ndarray):
406
+ return obj.tolist()
407
+ else:
408
+ return obj
409
+
410
+ # Save dataset statistics
411
+ stats = {
412
+ 'min_score_threshold': int(self.min_score),
413
+ 'dataset_sizes': {
414
+ 'train': len(datasets['train']),
415
+ 'validation': len(datasets['validation']),
416
+ 'test': len(datasets['test'])
417
+ },
418
+ 'score_distribution': {}
419
+ }
420
+
421
+ for split_name, data in datasets.items():
422
+ # Calculate score distribution for this split
423
+ scores = [ex['score'] for ex in data]
424
+ if scores:
425
+ stats['score_distribution'][split_name] = {
426
+ 'mean': float(np.mean(scores)),
427
+ 'median': float(np.median(scores)),
428
+ 'min': float(np.min(scores)),
429
+ 'max': float(np.max(scores)),
430
+ 'std': float(np.std(scores))
431
+ }
432
+
433
+ # Save as JSONL for easier streaming
434
+ file_path = output_path / f"{split_name}.jsonl"
435
+ with open(file_path, 'w', encoding='utf-8') as f:
436
+ for example in data:
437
+ formatted_text = self.format_for_lfm(example)
438
+ # Convert all numpy types to native Python types
439
+ json_obj = {
440
+ 'text': formatted_text,
441
+ 'score': convert_to_native(example['score']),
442
+ 'topic': example['topic']
443
+ }
444
+ json_line = json.dumps(json_obj, ensure_ascii=False)
445
+ f.write(json_line + '\n')
446
+
447
+ print(f"Saved {split_name} dataset with {len(data)} examples to {file_path}")
448
+
449
+ # Save statistics
450
+ stats_path = output_path / "dataset_stats.json"
451
+ with open(stats_path, 'w', encoding='utf-8') as f:
452
+ json.dump(stats, f, ensure_ascii=False, indent=2)
453
+ print(f"Saved dataset statistics to {stats_path}")
454
+
455
+ # Print summary statistics
456
+ print("\n" + "="*60)
457
+ print("DATASET SUMMARY")
458
+ print("="*60)
459
+ print(f"Minimum score threshold: {stats['min_score_threshold']}")
460
+ print("\nDataset sizes:")
461
+ for split, size in stats['dataset_sizes'].items():
462
+ print(f" {split}: {size} examples")
463
+
464
+ print("\nScore distributions by split:")
465
+ for split, dist in stats['score_distribution'].items():
466
+ print(f" {split}:")
467
+ print(f" Mean: {dist['mean']:.2f}")
468
+ print(f" Std: {dist['std']:.2f}")
469
+ print(f" Range: {dist['min']:.0f} - {dist['max']:.0f}")
470
+ print("="*60)
471
+
472
+ # Run preprocessing with different score thresholds
473
+ if __name__ == "__main__":
474
+ import argparse
475
+
476
+ parser = argparse.ArgumentParser(description='Preprocess KokoroChat dataset')
477
+ parser.add_argument('--data_path', type=str, default='./KokoroChat',
478
+ help='Path to KokoroChat repository')
479
+ parser.add_argument('--min_score', type=int, default=70,
480
+ help='Minimum score threshold for filtering (default: 70)')
481
+ parser.add_argument('--output_dir', type=str, default='./processed_data',
482
+ help='Output directory for processed data')
483
+ parser.add_argument('--weighted_sampling', action='store_true',
484
+ help='Use weighted sampling based on scores')
485
+ parser.add_argument('--test_size', type=float, default=0.1,
486
+ help='Test set size (default: 0.1)')
487
+ parser.add_argument('--val_size', type=float, default=0.1,
488
+ help='Validation set size (default: 0.1)')
489
+ parser.add_argument('--analyze_only', action='store_true',
490
+ help='Only analyze score distribution without processing')
491
+
492
+ args = parser.parse_args()
493
+
494
+ # Initialize preprocessor with configurable min_score
495
+ preprocessor = KokoroChatPreprocessor(
496
+ data_path=args.data_path,
497
+ min_score=args.min_score
498
+ )
499
+
500
+ if args.analyze_only:
501
+ # Just analyze the score distribution
502
+ print("Running score distribution analysis only...")
503
+ json_files = preprocessor.load_json_files()
504
+ stats = preprocessor.analyze_score_distribution(json_files)
505
+ preprocessor.plot_score_distribution(f"score_analysis_threshold_{args.min_score}.png")
506
+ else:
507
+ # Full preprocessing
508
+ print(f"Processing with minimum score threshold: {args.min_score}")
509
+ datasets = preprocessor.prepare_dataset(
510
+ test_size=args.test_size,
511
+ val_size=args.val_size,
512
+ use_weighted_sampling=args.weighted_sampling,
513
+ analyze_scores=True
514
+ )
515
+
516
+ # Save with threshold in directory name
517
+ output_dir = f"{args.output_dir}_score{args.min_score}"
518
+ preprocessor.save_datasets(datasets, output_dir)
519
+
520
+ print(f"\nProcessing complete! Data saved to {output_dir}")
521
+ print("\nNext steps:")
522
+ print("1. Run fine-tuning: python finetune_lfm.py")
523
+ print("2. Run benchmarking: python benchmark_model.py")
524
+ print("3. Optimize for mobile: python optimize_for_mobile.py")
finalmerged_model.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73fdb12fa8819f3d5160ec5414e55e827d08d1d69874a4168035b7f0c9fb02a4
3
+ size 1806737356
finetune_lfm.py ADDED
@@ -0,0 +1,1311 @@
1
+ # import torch
2
+ # from transformers import (
3
+ # AutoModelForCausalLM,
4
+ # AutoTokenizer,
5
+ # TrainingArguments,
6
+ # Trainer,
7
+ # DataCollatorForLanguageModeling,
8
+ # BitsAndBytesConfig
9
+ # )
10
+ # from peft import (
11
+ # LoraConfig,
12
+ # get_peft_model,
13
+ # prepare_model_for_kbit_training,
14
+ # TaskType
15
+ # )
16
+ # from datasets import load_dataset, Dataset
17
+ # import os
18
+ # from typing import Dict, List, Optional
19
+ # import numpy as np
20
+ # from tqdm import tqdm
21
+ # import json
22
+ # import gc
23
+ # import warnings
24
+ # warnings.filterwarnings('ignore')
25
+
26
+ # class LFMCounselorFineTuner:
27
+ # def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
28
+ # """
29
+ # Initialize the fine-tuner for LFM models
30
+
31
+ # Args:
32
+ # model_name: Name of the base model
33
+ # use_4bit: Whether to use 4-bit quantization for memory efficiency
34
+ # """
35
+ # self.model_name = model_name
36
+ # self.use_4bit = use_4bit
37
+ # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+
39
+ # print(f"Using device: {self.device}")
40
+ # if torch.cuda.is_available():
41
+ # print(f"GPU: {torch.cuda.get_device_name(0)}")
42
+ # print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
43
+
44
+ # # Disable wandb for simplicity
45
+ # os.environ["WANDB_DISABLED"] = "true"
46
+
47
+ # def setup_model_and_tokenizer(self):
48
+ # """Setup model with quantization and LoRA"""
49
+
50
+ # print("Loading tokenizer...")
51
+ # # Tokenizer setup
52
+ # try:
53
+ # self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
54
+ # except:
55
+ # # Fallback to a known working tokenizer if model-specific one fails
56
+ # print("Using fallback tokenizer...")
57
+ # self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
58
+
59
+ # # Add padding token if it doesn't exist
60
+ # if self.tokenizer.pad_token is None:
61
+ # self.tokenizer.pad_token = self.tokenizer.eos_token
62
+ # if self.tokenizer.eos_token is None:
63
+ # self.tokenizer.eos_token = "</s>"
64
+ # self.tokenizer.pad_token = "</s>"
65
+
66
+ # self.tokenizer.padding_side = "right"
67
+
68
+ # # Quantization config for memory efficiency
69
+ # if self.use_4bit:
70
+ # print("Setting up 4-bit quantization...")
71
+ # bnb_config = BitsAndBytesConfig(
72
+ # load_in_4bit=True,
73
+ # bnb_4bit_quant_type="nf4",
74
+ # bnb_4bit_compute_dtype=torch.float16, # Use float16 for better compatibility
75
+ # bnb_4bit_use_double_quant=True
76
+ # )
77
+ # else:
78
+ # bnb_config = None
79
+
80
+ # # Load model
81
+ # print(f"Loading model: {self.model_name}...")
82
+ # try:
83
+ # self.model = AutoModelForCausalLM.from_pretrained(
84
+ # self.model_name,
85
+ # quantization_config=bnb_config,
86
+ # device_map="auto",
87
+ # trust_remote_code=True,
88
+ # torch_dtype=torch.float16
89
+ # )
90
+ # except Exception as e:
91
+ # print(f"Error loading model: {e}")
92
+ # print("Attempting to load without quantization...")
93
+ # self.model = AutoModelForCausalLM.from_pretrained(
94
+ # self.model_name,
95
+ # device_map="auto",
96
+ # trust_remote_code=True,
97
+ # torch_dtype=torch.float16,
98
+ # low_cpu_mem_usage=True
99
+ # )
100
+
101
+ # # Enable gradient checkpointing to save memory
102
+ # if hasattr(self.model, 'gradient_checkpointing_enable'):
103
+ # self.model.gradient_checkpointing_enable()
104
+
105
+ # # Prepare model for k-bit training
106
+ # if self.use_4bit:
107
+ # print("Preparing model for 4-bit training...")
108
+ # self.model = prepare_model_for_kbit_training(self.model)
109
+
110
+ # # LoRA configuration - optimized for counseling task
111
+ # print("Applying LoRA configuration...")
112
+
113
+ # # Find the target modules dynamically
114
+ # target_modules = self.find_target_modules()
115
+
116
+ # lora_config = LoraConfig(
117
+ # r=16, # Reduced rank for stability
118
+ # lora_alpha=32, # Alpha parameter for LoRA scaling
119
+ # target_modules=target_modules,
120
+ # lora_dropout=0.05,
121
+ # bias="none",
122
+ # task_type=TaskType.CAUSAL_LM,
123
+ # inference_mode=False
124
+ # )
125
+
126
+ # # Apply LoRA
127
+ # self.model = get_peft_model(self.model, lora_config)
128
+
129
+ # # Print trainable parameters
130
+ # self.model.print_trainable_parameters()
131
+
132
+ # def find_target_modules(self):
133
+ # """Find linear modules to apply LoRA to"""
134
+ # target_modules = []
135
+ # for name, module in self.model.named_modules():
136
+ # if isinstance(module, torch.nn.Linear):
137
+ # # Extract the module name
138
+ # names = name.split('.')
139
+ # if len(names) > 0:
140
+ # target_modules.append(names[-1])
141
+
142
+ # # Remove duplicates and filter common patterns
143
+ # target_modules = list(set(target_modules))
144
+
145
+ # # Common patterns for transformer models
146
+ # common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
147
+ # "gate_proj", "up_proj", "down_proj",
148
+ # "fc1", "fc2", "query", "key", "value", "dense"]
149
+
150
+ # # Filter to only include common targets if they exist
151
+ # final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
152
+
153
+ # # If no common targets found, use all linear layers
154
+ # if not final_targets:
155
+ # final_targets = target_modules[:6] # Limit to prevent too many parameters
156
+
157
+ # print(f"LoRA target modules: {final_targets}")
158
+ # return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
159
+
160
+ # def load_and_process_datasets(self, data_path: str):
161
+ # """Load and process datasets without multiprocessing issues"""
162
+
163
+ # print(f"Loading datasets from {data_path}...")
164
+
165
+ # # Load train dataset
166
+ # train_texts = []
167
+ # with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
168
+ # for line in tqdm(f, desc="Loading training data"):
169
+ # data = json.loads(line)
170
+ # train_texts.append(data['text'])
171
+
172
+ # # Load validation dataset
173
+ # val_texts = []
174
+ # with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
175
+ # for line in tqdm(f, desc="Loading validation data"):
176
+ # data = json.loads(line)
177
+ # val_texts.append(data['text'])
178
+
179
+ # print(f"Loaded {len(train_texts)} training examples")
180
+ # print(f"Loaded {len(val_texts)} validation examples")
181
+
182
+ # # Tokenize datasets in batches (avoiding multiprocessing)
183
+ # print("Tokenizing training dataset...")
184
+ # train_encodings = self.tokenize_texts(train_texts)
185
+
186
+ # print("Tokenizing validation dataset...")
187
+ # val_encodings = self.tokenize_texts(val_texts)
188
+
189
+ # # Create datasets
190
+ # self.train_dataset = Dataset.from_dict(train_encodings)
191
+ # self.val_dataset = Dataset.from_dict(val_encodings)
192
+
193
+ # # Set format for PyTorch
194
+ # self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
195
+ # self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
196
+
197
+ # # Clean up memory
198
+ # del train_texts, val_texts, train_encodings, val_encodings
199
+ # gc.collect()
200
+
201
+ # def tokenize_texts(self, texts: List[str], batch_size: int = 100):
202
+ # """Tokenize texts in batches to avoid memory issues"""
203
+ # all_input_ids = []
204
+ # all_attention_masks = []
205
+
206
+ # for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
207
+ # batch_texts = texts[i:i + batch_size]
208
+
209
+ # # Tokenize batch
210
+ # encodings = self.tokenizer(
211
+ # batch_texts,
212
+ # truncation=True,
213
+ # padding='max_length',
214
+ # max_length=512,
215
+ # return_tensors='pt'
216
+ # )
217
+
218
+ # # Convert to lists
219
+ # all_input_ids.extend(encodings['input_ids'].tolist())
220
+ # all_attention_masks.extend(encodings['attention_mask'].tolist())
221
+
222
+ # # Create labels (same as input_ids for language modeling)
223
+ # labels = all_input_ids.copy()
224
+
225
+ # return {
226
+ # 'input_ids': all_input_ids,
227
+ # 'attention_mask': all_attention_masks,
228
+ # 'labels': labels
229
+ # }
230
+
231
+ # def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
232
+ # """Setup training arguments optimized for counseling task"""
233
+
234
+ # print("Setting up training arguments...")
235
+
236
+ # # Calculate batch sizes based on available memory
237
+ # if torch.cuda.is_available():
238
+ # gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
239
+ # if gpu_memory < 16: # Less than 16GB
240
+ # batch_size = 1
241
+ # gradient_accumulation = 16
242
+ # elif gpu_memory < 24: # Less than 24GB
243
+ # batch_size = 2
244
+ # gradient_accumulation = 8
245
+ # else: # 24GB or more
246
+ # batch_size = 4
247
+ # gradient_accumulation = 4
248
+ # else:
249
+ # batch_size = 1
250
+ # gradient_accumulation = 16
251
+
252
+ # print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
253
+
254
+ # self.training_args = TrainingArguments(
255
+ # output_dir=output_dir,
256
+ # num_train_epochs=3,
257
+ # per_device_train_batch_size=batch_size,
258
+ # per_device_eval_batch_size=batch_size,
259
+ # gradient_accumulation_steps=gradient_accumulation,
260
+ # gradient_checkpointing=True,
261
+ # warmup_steps=100,
262
+ # learning_rate=5e-5, # Conservative learning rate
263
+ # fp16=True,
264
+ # logging_steps=50,
265
+ # eval_strategy="steps",
266
+ # eval_steps=200,
267
+ # save_strategy="steps",
268
+ # save_steps=400,
269
+ # save_total_limit=2,
270
+ # load_best_model_at_end=True,
271
+ # metric_for_best_model="eval_loss",
272
+ # greater_is_better=False,
273
+ # report_to="none", # Disable all reporting
274
+ # push_to_hub=False,
275
+ # optim="adamw_torch", # Use standard optimizer
276
+ # lr_scheduler_type="linear",
277
+ # weight_decay=0.01,
278
+ # max_grad_norm=1.0,
279
+ # remove_unused_columns=False,
280
+ # label_names=["labels"],
281
+ # dataloader_num_workers=0, # Disable multiprocessing in dataloader
282
+ # dataloader_pin_memory=False, # Disable pinned memory to avoid issues
283
+ # )
284
+
285
+ # def train(self):
286
+ # """Execute training"""
287
+
288
+ # print("Initializing trainer...")
289
+
290
+ # # Data collator for language modeling
291
+ # data_collator = DataCollatorForLanguageModeling(
292
+ # tokenizer=self.tokenizer,
293
+ # mlm=False,
294
+ # pad_to_multiple_of=8
295
+ # )
296
+
297
+ # # Custom training to handle potential issues
298
+ # try:
299
+ # # Initialize trainer
300
+ # trainer = Trainer(
301
+ # model=self.model,
302
+ # args=self.training_args,
303
+ # train_dataset=self.train_dataset,
304
+ # eval_dataset=self.val_dataset,
305
+ # data_collator=data_collator,
306
+ # tokenizer=self.tokenizer,
307
+ # )
308
+
309
+ # # Start training
310
+ # print("="*50)
311
+ # print("Starting fine-tuning...")
312
+ # print("="*50)
313
+
314
+ # # Train with error handling
315
+ # train_result = trainer.train()
316
+
317
+ # # Save the final model
318
+ # print("\nSaving fine-tuned model...")
319
+ # trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
320
+ # self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
321
+
322
+ # # Save training metrics
323
+ # with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
324
+ # json.dump(train_result.metrics, f, indent=2)
325
+
326
+ # print("\n" + "="*50)
327
+ # print("Training completed successfully!")
328
+ # print(f"Model saved to: {self.training_args.output_dir}/final_model_2b")
329
+ # print("="*50)
330
+
331
+ # return trainer
332
+
333
+ # except Exception as e:
334
+ # print(f"Error during training: {e}")
335
+ # print("Attempting to save checkpoint...")
336
+
337
+ # # Try to save whatever we have
338
+ # try:
339
+ # self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
340
+ # self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
341
+ # print(f"Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
342
+ # except:
343
+ # print("Could not save emergency checkpoint")
344
+
345
+ # raise e
346
+
347
+ # def test_model(model_path: str, tokenizer_path: str):
348
+ # """Test the fine-tuned model with a sample input"""
349
+
350
+ # print("\n" + "="*50)
351
+ # print("Testing fine-tuned model...")
352
+ # print("="*50)
353
+
354
+ # # Load model and tokenizer
355
+ # from peft import PeftModel, PeftConfig
356
+
357
+ # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
358
+
359
+ # # Try to load as PEFT model
360
+ # try:
361
+ # config = PeftConfig.from_pretrained(model_path)
362
+ # model = AutoModelForCausalLM.from_pretrained(
363
+ # config.base_model_name_or_path,
364
+ # torch_dtype=torch.float16,
365
+ # device_map="auto"
366
+ # )
367
+ # model = PeftModel.from_pretrained(model, model_path)
368
+ # except:
369
+ # # Load as regular model
370
+ # model = AutoModelForCausalLM.from_pretrained(
371
+ # model_path,
372
+ # torch_dtype=torch.float16,
373
+ # device_map="auto"
374
+ # )
375
+
376
+ # model.eval()
377
+
378
+ # # Test input
379
+ # test_input = "こんにちは。最近ストレスを感じています。"
380
+
381
+ # # Generate response
382
+ # inputs = tokenizer(test_input, return_tensors="pt")
383
+ # inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
384
+
385
+ # with torch.no_grad():
386
+ # outputs = model.generate(
387
+ # **inputs,
388
+ # max_new_tokens=100,
389
+ # temperature=0.1,
390
+ # do_sample=True,
391
+ # top_p=0.9
392
+ # )
393
+
394
+ # response = tokenizer.decode(outputs[0], skip_special_tokens=True)
395
+ # print(f"Input: {test_input}")
396
+ # print(f"Response: {response}")
397
+ # print("="*50)
398
+
399
+ # # Main training script
400
+ # if __name__ == "__main__":
401
+ # import argparse
402
+
403
+ # parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
404
+ # parser.add_argument('--model_name', type=str, default='gpt2', # Using GPT2 as fallback
405
+ # help='Base model name (use gpt2 if liquid model fails)')
406
+ # parser.add_argument('--data_path', type=str, default='./processed_data_score80',
407
+ # help='Path to processed data')
408
+ # parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
409
+ # help='Output directory for fine-tuned model')
410
+ # parser.add_argument('--use_4bit', action='store_true', default=False,
411
+ # help='Use 4-bit quantization (set to False for stability)')
412
+ # parser.add_argument('--test_only', action='store_true',
413
+ # help='Only test existing model')
414
+
415
+ # args = parser.parse_args()
416
+
417
+ # if args.test_only:
418
+ # # Test existing model
419
+ # test_model(
420
+ # f"{args.output_dir}/final_model_2b",
421
+ # f"{args.output_dir}/final_model_2b"
422
+ # )
423
+ # else:
424
+ # # Check if CUDA is available
425
+ # if not torch.cuda.is_available():
426
+ # print("Warning: CUDA is not available. Training will be very slow on CPU.")
427
+ # print("It's highly recommended to use a GPU for training.")
428
+ # response = input("Do you want to continue anyway? (y/n): ")
429
+ # if response.lower() != 'y':
430
+ # exit()
431
+
432
+ # try:
433
+ # # Clear GPU cache
434
+ # if torch.cuda.is_available():
435
+ # torch.cuda.empty_cache()
436
+
437
+ # # Initialize fine-tuner
438
+ # print(f"Initializing fine-tuner with model: {args.model_name}")
439
+ # finetuner = LFMCounselorFineTuner(
440
+ # model_name=args.model_name,
441
+ # use_4bit=args.use_4bit
442
+ # )
443
+
444
+ # # Setup model
445
+ # print("\nSetting up model and tokenizer...")
446
+ # finetuner.setup_model_and_tokenizer()
447
+
448
+ # # Load datasets (using new method without multiprocessing)
449
+ # print("\nLoading and processing datasets...")
450
+ # finetuner.load_and_process_datasets(args.data_path)
451
+
452
+ # # Setup training arguments
453
+ # print("\nSetting up training arguments...")
454
+ # finetuner.setup_training_args(args.output_dir)
455
+
456
+ # # Train
457
+ # trainer = finetuner.train()
458
+
459
+ # # Test the model
460
+ # print("\nTesting the fine-tuned model...")
461
+ # test_model(
462
+ # f"{args.output_dir}/final_model_2b",
463
+ # f"{args.output_dir}/final_model_2b"
464
+ # )
465
+
466
+ # print("\n✅ Fine-tuning completed successfully!")
467
+ # print(f"📁 Model saved to: {args.output_dir}/final_model_2b")
468
+ # print("\nNext steps:")
469
+ # print("1. Test more: python finetune_lfm.py --test_only")
470
+ # print("2. Run benchmarking: python benchmark_model.py")
471
+ # print("3. Optimize for mobile: python optimize_for_mobile.py")
472
+
473
+ # except KeyboardInterrupt:
474
+ # print("\n\nTraining interrupted by user.")
475
+ # print("Partial model may be saved in checkpoints.")
476
+ # except Exception as e:
477
+ # print(f"\n❌ Error during fine-tuning: {e}")
478
+ # import traceback
479
+ # traceback.print_exc()
480
+ # print("\nTroubleshooting tips:")
481
+ # print("1. Try reducing batch size")
482
+ # print("2. Try without 4-bit quantization: remove --use_4bit")
483
+ # print("3. Try with a smaller model like gpt2")
484
+ # print("4. Ensure you have enough GPU memory")
485
+
486
+
487
+
488
+ ###### wandb login ######
489
+
490
+ import torch
491
+ from transformers import (
492
+ AutoModelForCausalLM,
493
+ AutoTokenizer,
494
+ TrainingArguments,
495
+ Trainer,
496
+ DataCollatorForLanguageModeling,
497
+ BitsAndBytesConfig,
498
+ TrainerCallback
499
+ )
500
+ from peft import (
501
+ LoraConfig,
502
+ get_peft_model,
503
+ prepare_model_for_kbit_training,
504
+ TaskType
505
+ )
506
+ from datasets import load_dataset, Dataset
507
+ import os
508
+ from typing import Dict, List, Optional
509
+ import numpy as np
510
+ from tqdm import tqdm
511
+ import json
512
+ import gc
513
+ import warnings
514
+ import wandb
515
+ from datetime import datetime
516
+
517
+ warnings.filterwarnings('ignore')
518
+
519
+ class LFMCounselorFineTuner:
520
+ def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
521
+ """
522
+ Initialize the fine-tuner for LFM models
523
+
524
+ Args:
525
+ model_name: Name of the base model
526
+ use_4bit: Whether to use 4-bit quantization for memory efficiency
527
+ """
528
+ self.model_name = model_name
529
+ self.use_4bit = use_4bit
530
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
531
+
532
+ print(f"Using device: {self.device}")
533
+ gpu_memory = 0
534
+ if torch.cuda.is_available():
535
+ gpu_name = torch.cuda.get_device_name(0)
536
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
537
+ print(f"GPU: {gpu_name}")
538
+ print(f"GPU Memory: {gpu_memory:.2f} GB")
539
+
540
+ # Initialize WandB (always enabled)
541
+ try:
542
+ # Create a unique run name with timestamp
543
+ run_name = f"lfm-counselor-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
544
+
545
+ # Initialize wandb with comprehensive config
546
+ wandb.init(
547
+ project="liquid-counselor-hackathon",
548
+ name=run_name,
549
+ config={
550
+ "model_name": model_name,
551
+ "use_4bit_quantization": use_4bit,
552
+ "device": str(self.device),
553
+ "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
554
+ "gpu_memory_gb": gpu_memory,
555
+ "framework": "transformers",
556
+ "peft_method": "LoRA",
557
+ "task": "japanese_counseling",
558
+ "dataset": "KokoroChat"
559
+ },
560
+ tags=["counseling", "japanese", "lfm", "finetune", "hackathon"]
561
+ )
562
+ print(f"✅ WandB initialized: {wandb.run.name}")
563
+ print(f"📊 View run at: {wandb.run.get_url()}")
564
+ self.wandb_enabled = True
565
+ except Exception as e:
566
+ print(f"⚠️ WandB initialization failed: {e}")
567
+ print("Continuing without WandB logging...")
568
+ self.wandb_enabled = False
569
+ os.environ["WANDB_DISABLED"] = "true"
570
+
571
+ def setup_model_and_tokenizer(self):
572
+ """Setup model with quantization and LoRA"""
573
+
574
+ print("Loading tokenizer...")
575
+ # Tokenizer setup
576
+ try:
577
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
578
+ except:
579
+ # Fallback to a known working tokenizer if model-specific one fails
580
+ print("Using fallback tokenizer...")
581
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
582
+
583
+ # Add padding token if it doesn't exist
584
+ if self.tokenizer.pad_token is None:
585
+ self.tokenizer.pad_token = self.tokenizer.eos_token
586
+ if self.tokenizer.eos_token is None:
587
+ self.tokenizer.eos_token = "</s>"
588
+ self.tokenizer.pad_token = "</s>"
589
+
590
+ self.tokenizer.padding_side = "right"
591
+
592
+ # Quantization config for memory efficiency
593
+ if self.use_4bit:
594
+ print("Setting up 4-bit quantization...")
595
+ bnb_config = BitsAndBytesConfig(
596
+ load_in_4bit=True,
597
+ bnb_4bit_quant_type="nf4",
598
+ bnb_4bit_compute_dtype=torch.float16,
599
+ bnb_4bit_use_double_quant=True
600
+ )
601
+ else:
602
+ bnb_config = None
603
+
604
+ # Load model
605
+ print(f"Loading model: {self.model_name}...")
606
+ try:
607
+ self.model = AutoModelForCausalLM.from_pretrained(
608
+ self.model_name,
609
+ quantization_config=bnb_config,
610
+ device_map="auto",
611
+ trust_remote_code=True,
612
+ torch_dtype=torch.float16
613
+ )
614
+ except Exception as e:
615
+ print(f"Error loading model: {e}")
616
+ print("Attempting to load without quantization...")
617
+ self.model = AutoModelForCausalLM.from_pretrained(
618
+ self.model_name,
619
+ device_map="auto",
620
+ trust_remote_code=True,
621
+ torch_dtype=torch.float16,
622
+ low_cpu_mem_usage=True
623
+ )
624
+
625
+ # Enable gradient checkpointing to save memory
626
+ if hasattr(self.model, 'gradient_checkpointing_enable'):
627
+ self.model.gradient_checkpointing_enable()
628
+
629
+ # Prepare model for k-bit training
630
+ if self.use_4bit:
631
+ print("Preparing model for 4-bit training...")
632
+ self.model = prepare_model_for_kbit_training(self.model)
633
+
634
+ # LoRA configuration - optimized for counseling task
635
+ print("Applying LoRA configuration...")
636
+
637
+ # Find the target modules dynamically
638
+ target_modules = self.find_target_modules()
639
+
640
+ lora_config = LoraConfig(
641
+ r=16, # Reduced rank for stability
642
+ lora_alpha=32, # Alpha parameter for LoRA scaling
643
+ target_modules=target_modules,
644
+ lora_dropout=0.05,
645
+ bias="none",
646
+ task_type=TaskType.CAUSAL_LM,
647
+ inference_mode=False
648
+ )
649
+
650
+ # Apply LoRA
651
+ self.model = get_peft_model(self.model, lora_config)
652
+
653
+ # Get trainable parameters info
654
+ trainable_params = 0
655
+ all_params = 0
656
+ for _, param in self.model.named_parameters():
657
+ all_params += param.numel()
658
+ if param.requires_grad:
659
+ trainable_params += param.numel()
660
+
661
+ trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
662
+
663
+ print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
664
+
665
+ # Log model architecture to WandB
666
+ if self.wandb_enabled:
667
+ wandb.config.update({
668
+ "lora_r": lora_config.r,
669
+ "lora_alpha": lora_config.lora_alpha,
670
+ "lora_dropout": lora_config.lora_dropout,
671
+ "lora_target_modules": target_modules,
672
+ "total_parameters": all_params,
673
+ "trainable_parameters": trainable_params,
674
+ "trainable_percentage": trainable_percentage
675
+ })
676
+
677
+ self.model.print_trainable_parameters()
678
+
679
+ def find_target_modules(self):
680
+ """Find linear modules to apply LoRA to"""
681
+ target_modules = []
682
+ for name, module in self.model.named_modules():
683
+ if isinstance(module, torch.nn.Linear):
684
+ # Extract the module name
685
+ names = name.split('.')
686
+ if len(names) > 0:
687
+ target_modules.append(names[-1])
688
+
689
+ # Remove duplicates and filter common patterns
690
+ target_modules = list(set(target_modules))
691
+
692
+ # Common patterns for transformer models
693
+ common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
694
+ "gate_proj", "up_proj", "down_proj",
695
+ "fc1", "fc2", "query", "key", "value", "dense"]
696
+
697
+ # Filter to only include common targets if they exist
698
+ final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
699
+
700
+ # If no common targets found, use all linear layers
701
+ if not final_targets:
702
+ final_targets = target_modules[:6] # Limit to prevent too many parameters
703
+
704
+ print(f"LoRA target modules: {final_targets}")
705
+ return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
706
+
707
+ def load_and_process_datasets(self, data_path: str):
708
+ """Load and process datasets without multiprocessing issues"""
709
+
710
+ print(f"Loading datasets from {data_path}...")
711
+
712
+ # Load train dataset
713
+ train_texts = []
714
+ train_scores = []
715
+ train_topics = []
716
+
717
+ with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
718
+ for line in tqdm(f, desc="Loading training data"):
719
+ data = json.loads(line)
720
+ train_texts.append(data['text'])
721
+ train_scores.append(data.get('score', 0))
722
+ train_topics.append(data.get('topic', 'Unknown'))
723
+
724
+ # Load validation dataset
725
+ val_texts = []
726
+ val_scores = []
727
+ val_topics = []
728
+
729
+ with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
730
+ for line in tqdm(f, desc="Loading validation data"):
731
+ data = json.loads(line)
732
+ val_texts.append(data['text'])
733
+ val_scores.append(data.get('score', 0))
734
+ val_topics.append(data.get('topic', 'Unknown'))
735
+
736
+ print(f"Loaded {len(train_texts)} training examples")
737
+ print(f"Loaded {len(val_texts)} validation examples")
738
+
739
+ # Log dataset statistics to WandB
740
+ if self.wandb_enabled:
741
+ # Calculate score statistics
742
+ train_score_stats = {
743
+ "train_examples": len(train_texts),
744
+ "train_avg_score": float(np.mean(train_scores)),
745
+ "train_min_score": float(np.min(train_scores)),
746
+ "train_max_score": float(np.max(train_scores)),
747
+ "train_std_score": float(np.std(train_scores))
748
+ }
749
+
750
+ val_score_stats = {
751
+ "val_examples": len(val_texts),
752
+ "val_avg_score": float(np.mean(val_scores)),
753
+ "val_min_score": float(np.min(val_scores)),
754
+ "val_max_score": float(np.max(val_scores)),
755
+ "val_std_score": float(np.std(val_scores))
756
+ }
757
+
758
+ wandb.config.update(train_score_stats)
759
+ wandb.config.update(val_score_stats)
760
+
761
+ # Log score distribution histogram
762
+ wandb.log({
763
+ "train_score_distribution": wandb.Histogram(train_scores),
764
+ "val_score_distribution": wandb.Histogram(val_scores)
765
+ })
766
+
767
+ # Log topic distribution
768
+ train_topic_counts = {}
769
+ for topic in train_topics:
770
+ train_topic_counts[topic] = train_topic_counts.get(topic, 0) + 1
771
+
772
+ # Create a bar chart for topics (top 20)
773
+ if len(train_topic_counts) > 0:
774
+ top_topics = sorted(train_topic_counts.items(), key=lambda x: x[1], reverse=True)[:20]
775
+ wandb.log({
776
+ "topic_distribution": wandb.plot.bar(
777
+ wandb.Table(data=[[k, v] for k, v in top_topics],
778
+ columns=["Topic", "Count"]),
779
+ "Topic", "Count", title="Training Topic Distribution (Top 20)"
780
+ )
781
+ })
782
+
783
+ # Tokenize datasets in batches (avoiding multiprocessing)
784
+ print("Tokenizing training dataset...")
785
+ train_encodings = self.tokenize_texts(train_texts)
786
+
787
+ print("Tokenizing validation dataset...")
788
+ val_encodings = self.tokenize_texts(val_texts)
789
+
790
+ # Create datasets
791
+ self.train_dataset = Dataset.from_dict(train_encodings)
792
+ self.val_dataset = Dataset.from_dict(val_encodings)
793
+
794
+ # Set format for PyTorch
795
+ self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
796
+ self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
797
+
798
+ # Clean up memory
799
+ del train_texts, val_texts, train_encodings, val_encodings
800
+ gc.collect()
801
+
802
+ def tokenize_texts(self, texts: List[str], batch_size: int = 100):
803
+ """Tokenize texts in batches to avoid memory issues"""
804
+ all_input_ids = []
805
+ all_attention_masks = []
806
+
807
+ for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
808
+ batch_texts = texts[i:i + batch_size]
809
+
810
+ # Tokenize batch
811
+ encodings = self.tokenizer(
812
+ batch_texts,
813
+ truncation=True,
814
+ padding='max_length',
815
+ max_length=512,
816
+ return_tensors='pt'
817
+ )
818
+
819
+ # Convert to lists
820
+ all_input_ids.extend(encodings['input_ids'].tolist())
821
+ all_attention_masks.extend(encodings['attention_mask'].tolist())
822
+
823
+ # Create labels (same as input_ids for language modeling)
824
+ labels = all_input_ids.copy()
825
+
826
+ return {
827
+ 'input_ids': all_input_ids,
828
+ 'attention_mask': all_attention_masks,
829
+ 'labels': labels
830
+ }
831
+
832
+ def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
833
+ """Setup training arguments optimized for counseling task"""
834
+
835
+ print("Setting up training arguments...")
836
+
837
+ # Calculate batch sizes based on available memory
838
+ if torch.cuda.is_available():
839
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
840
+ if gpu_memory < 16: # Less than 16GB
841
+ batch_size = 1
842
+ gradient_accumulation = 16
843
+ elif gpu_memory < 24: # Less than 24GB
844
+ batch_size = 2
845
+ gradient_accumulation = 8
846
+ else: # 24GB or more
847
+ batch_size = 4
848
+ gradient_accumulation = 4
849
+ else:
850
+ batch_size = 1
851
+ gradient_accumulation = 16
852
+
853
+ print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
854
+
855
+ # Update WandB config with training hyperparameters
856
+ if self.wandb_enabled:
857
+ wandb.config.update({
858
+ "batch_size": batch_size,
859
+ "gradient_accumulation_steps": gradient_accumulation,
860
+ "effective_batch_size": batch_size * gradient_accumulation,
861
+ "num_epochs": 3,
862
+ "learning_rate": 5e-5,
863
+ "warmup_steps": 100,
864
+ "weight_decay": 0.01,
865
+ "max_grad_norm": 1.0,
866
+ "lr_scheduler": "linear",
867
+ "optimizer": "adamw_torch",
868
+ "fp16": True,
869
+ "max_length": 512
870
+ })
871
+
872
+ # Set report_to based on wandb availability
873
+ report_to = "wandb" if self.wandb_enabled else "none"
874
+
875
+ self.training_args = TrainingArguments(
876
+ output_dir=output_dir,
877
+ num_train_epochs=3,
878
+ per_device_train_batch_size=batch_size,
879
+ per_device_eval_batch_size=batch_size,
880
+ gradient_accumulation_steps=gradient_accumulation,
881
+ gradient_checkpointing=True,
882
+ warmup_steps=100,
883
+ learning_rate=5e-5,
884
+ fp16=True,
885
+ logging_steps=50,
886
+ logging_first_step=True,
887
+ eval_strategy="steps",
888
+ eval_steps=200,
889
+ save_strategy="steps",
890
+ save_steps=400,
891
+ save_total_limit=2,
892
+ load_best_model_at_end=True,
893
+ metric_for_best_model="eval_loss",
894
+ greater_is_better=False,
895
+ report_to=report_to,
896
+ run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
897
+ push_to_hub=False,
898
+ optim="adamw_torch",
899
+ lr_scheduler_type="linear",
900
+ weight_decay=0.01,
901
+ max_grad_norm=1.0,
902
+ remove_unused_columns=False,
903
+ label_names=["labels"],
904
+ dataloader_num_workers=0,
905
+ dataloader_pin_memory=False,
906
+ )
907
+
908
+ def train(self):
909
+ """Execute training"""
910
+
911
+ print("Initializing trainer...")
912
+
913
+ # Data collator for language modeling
914
+ data_collator = DataCollatorForLanguageModeling(
915
+ tokenizer=self.tokenizer,
916
+ mlm=False,
917
+ pad_to_multiple_of=8
918
+ )
919
+
920
+ # Custom callback for additional metrics (properly inheriting from TrainerCallback)
921
+ class CustomMetricsCallback(TrainerCallback):
922
+ def on_log(self, args, state, control, logs=None, **kwargs):
923
+ if logs and self.wandb_enabled:
924
+ # Add perplexity metrics
925
+ if "loss" in logs:
926
+ logs["perplexity"] = np.exp(logs["loss"])
927
+ if "eval_loss" in logs:
928
+ logs["eval_perplexity"] = np.exp(logs["eval_loss"])
929
+ return control
930
+
931
+ # Create callback instance with wandb_enabled flag
932
+ custom_callback = CustomMetricsCallback()
933
+ custom_callback.wandb_enabled = self.wandb_enabled
934
+
935
+ # Custom training to handle potential issues
936
+ try:
937
+ # Initialize trainer with callbacks
938
+ trainer = Trainer(
939
+ model=self.model,
940
+ args=self.training_args,
941
+ train_dataset=self.train_dataset,
942
+ eval_dataset=self.val_dataset,
943
+ data_collator=data_collator,
944
+ tokenizer=self.tokenizer,
945
+ callbacks=[custom_callback] if self.wandb_enabled else [],
946
+ )
947
+
948
+ # Calculate total training steps
949
+ total_steps = len(self.train_dataset) // (self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps) * self.training_args.num_train_epochs
950
+
951
+ # Start training
952
+ print("="*50)
953
+ print("Starting fine-tuning...")
954
+ print(f"Total training samples: {len(self.train_dataset)}")
955
+ print(f"Total validation samples: {len(self.val_dataset)}")
956
+ print(f"Total training steps: {total_steps}")
957
+ print("="*50)
958
+
959
+ # Log training start
960
+ if self.wandb_enabled:
961
+ wandb.log({"training_status": "started", "total_steps": total_steps})
962
+
963
+ # Train with error handling
964
+ train_result = trainer.train()
965
+
966
+ # Save the final model
967
+ print("\nSaving fine-tuned model...")
968
+ trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
969
+ self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
970
+
971
+ # Save training metrics
972
+ with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
973
+ json.dump(train_result.metrics, f, indent=2)
974
+
975
+ # Final evaluation
976
+ print("\nRunning final evaluation...")
977
+ eval_results = trainer.evaluate()
978
+
979
+ # Save evaluation metrics
980
+ with open(f"{self.training_args.output_dir}/eval_metrics.json", 'w') as f:
981
+ json.dump(eval_results, f, indent=2)
982
+
983
+ # Log final metrics to WandB
984
+ if self.wandb_enabled:
985
+ # Log final metrics
986
+ wandb.run.summary.update({
987
+ "final_train_loss": train_result.metrics.get("train_loss", 0),
988
+ "final_eval_loss": eval_results.get("eval_loss", 0),
989
+ "final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
990
+ "total_training_time": train_result.metrics.get("train_runtime", 0),
991
+ "training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
992
+ "training_status": "completed"
993
+ })
994
+
995
+ # Create a summary table
996
+ summary_table = wandb.Table(
997
+ columns=["Metric", "Value"],
998
+ data=[
999
+ ["Final Training Loss", f"{train_result.metrics.get('train_loss', 0):.4f}"],
1000
+ ["Final Eval Loss", f"{eval_results.get('eval_loss', 0):.4f}"],
1001
+ ["Final Perplexity", f"{np.exp(eval_results.get('eval_loss', 0)):.2f}"],
1002
+ ["Training Time (seconds)", f"{train_result.metrics.get('train_runtime', 0):.0f}"],
1003
+ ["Training Samples/Second", f"{train_result.metrics.get('train_samples_per_second', 0):.2f}"]
1004
+ ]
1005
+ )
1006
+ wandb.log({"training_summary": summary_table})
1007
+
1008
+ # Save model artifact
1009
+ try:
1010
+ artifact = wandb.Artifact(
1011
+ name=f"counselor-model-{wandb.run.id}",
1012
+ type="model",
1013
+ description="Fine-tuned Japanese counseling model",
1014
+ metadata={
1015
+ "base_model": self.model_name,
1016
+ "final_loss": float(eval_results.get("eval_loss", 0)),
1017
+ "final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
1018
+ "dataset": "KokoroChat"
1019
+ }
1020
+ )
1021
+ artifact.add_dir(f"{self.training_args.output_dir}/final_model_2b")
1022
+ wandb.log_artifact(artifact)
1023
+ except Exception as e:
1024
+ print(f"Warning: Could not save model artifact: {e}")
1025
+
1026
+ print("\n" + "="*50)
1027
+ print("✅ Training completed successfully!")
1028
+ print(f"📁 Model saved to: {self.training_args.output_dir}/final_model_2b")
1029
+ print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
1030
+ print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
1031
+ if self.wandb_enabled and wandb.run:
1032
+ print(f"🔗 View results at: {wandb.run.get_url()}")
1033
+ print("="*50)
1034
+
1035
+ return trainer
1036
+
1037
+ except Exception as e:
1038
+ print(f"❌ Error during training: {e}")
1039
+
1040
+ # Log error to WandB
1041
+ if self.wandb_enabled:
1042
+ wandb.run.summary["training_status"] = "failed"
1043
+ wandb.run.summary["error"] = str(e)
1044
+
1045
+ print("Attempting to save checkpoint...")
1046
+
1047
+ # Try to save whatever we have
1048
+ try:
1049
+ self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
1050
+ self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
1051
+ print(f"💾 Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
1052
+ except:
1053
+ print("❌ Could not save emergency checkpoint")
1054
+
1055
+ raise e
1056
+ finally:
1057
+ # Ensure WandB run is finished
1058
+ if self.wandb_enabled:
1059
+ wandb.finish()
1060
+
1061
+ # def test_model(model_path: str, tokenizer_path: str):
1062
+ # """Test the fine-tuned model with sample inputs"""
1063
+
1064
+ # print("\n" + "="*50)
1065
+ # print("Testing fine-tuned model...")
1066
+ # print("="*50)
1067
+
1068
+ # # Load model and tokenizer
1069
+ # from peft import PeftModel, PeftConfig
1070
+
1071
+ # tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
1072
+ # if tokenizer.pad_token is None:
1073
+ # tokenizer.pad_token = tokenizer.eos_token
1074
+
1075
+ # # Try to load as PEFT model
1076
+ # try:
1077
+ # config = PeftConfig.from_pretrained(model_path)
1078
+ # model = AutoModelForCausalLM.from_pretrained(
1079
+ # config.base_model_name_or_path,
1080
+ # torch_dtype=torch.float16,
1081
+ # device_map="auto"
1082
+ # )
1083
+ # model = PeftModel.from_pretrained(model, model_path)
1084
+ # except:
1085
+ # # Load as regular model
1086
+ # model = AutoModelForCausalLM.from_pretrained(
1087
+ # model_path,
1088
+ # torch_dtype=torch.float16,
1089
+ # device_map="auto"
1090
+ # )
1091
+
1092
+ # model.eval()
1093
+
1094
+ # # Test inputs
1095
+ # test_cases = [
1096
+ # "こんにちは。最近ストレスを感じています。",
1097
+ # "仕事がうまくいかなくて悩んでいます。",
1098
+ # "人間関係で困っています。どうすればいいでしょうか。"
1099
+ # ]
1100
+
1101
+ # print("Sample conversations:")
1102
+ # print("-" * 50)
1103
+
1104
+ def test_model(model_path: str, tokenizer_path: str):
1105
+ """Test the fine-tuned model with sample inputs"""
1106
+
1107
+ print("\n" + "="*50)
1108
+ print("Testing fine-tuned model...")
1109
+ print("="*50)
1110
+
1111
+ # Load model and tokenizer with proper local path handling
1112
+ from peft import PeftModel, PeftConfig
1113
+ import os
1114
+
1115
+ # Fix tokenizer loading for local paths
1116
+ try:
1117
+ # Check if tokenizer files exist in the path
1118
+ if os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
1119
+ print(f"Loading tokenizer from {tokenizer_path}")
1120
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
1121
+ else:
1122
+ print(f"Tokenizer not found at {tokenizer_path}, using base model tokenizer")
1123
+ # Fallback to base model tokenizer
1124
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
1125
+ except Exception as e:
1126
+ print(f"Error loading tokenizer: {e}")
1127
+ print("Using fallback GPT-2 tokenizer")
1128
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
1129
+
1130
+ if tokenizer.pad_token is None:
1131
+ tokenizer.pad_token = tokenizer.eos_token
1132
+
1133
+ # Try to load model
1134
+ try:
1135
+ # Check if it's a PEFT model
1136
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
1137
+ if os.path.exists(adapter_config_path):
1138
+ print("Loading as PEFT model...")
1139
+ config = PeftConfig.from_pretrained(model_path)
1140
+ base_model = AutoModelForCausalLM.from_pretrained(
1141
+ config.base_model_name_or_path,
1142
+ torch_dtype=torch.float16,
1143
+ device_map="auto",
1144
+ trust_remote_code=True
1145
+ )
1146
+ model = PeftModel.from_pretrained(base_model, model_path)
1147
+ else:
1148
+ # Load as regular model
1149
+ print("Loading as regular model...")
1150
+ model = AutoModelForCausalLM.from_pretrained(
1151
+ model_path,
1152
+ torch_dtype=torch.float16,
1153
+ device_map="auto",
1154
+ local_files_only=True,
1155
+ trust_remote_code=True
1156
+ )
1157
+ except Exception as e:
1158
+ print(f"Error loading model: {e}")
1159
+ raise
1160
+
1161
+ model.eval()
1162
+
1163
+ # Test inputs
1164
+ test_cases = [
1165
+ "こんにちは。最近ストレスを感じています。",
1166
+ "仕事がうまくいかなくて悩んでいます。",
1167
+ "人間関係で困っています。どうすればいいでしょうか。"
1168
+ ]
1169
+
1170
+ print("Sample conversations:")
1171
+ print("-" * 50)
1172
+
1173
+ for test_input in test_cases:
1174
+ # Generate response
1175
+ inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
1176
+ inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
1177
+
1178
+ with torch.no_grad():
1179
+ outputs = model.generate(
1180
+ **inputs,
1181
+ max_new_tokens=150,
1182
+ temperature=0.1,
1183
+ do_sample=True,
1184
+ top_p=0.9,
1185
+ pad_token_id=tokenizer.pad_token_id
1186
+ )
1187
+
1188
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
1189
+ response = response[len(test_input):].strip() # Remove input from response
1190
+
1191
+ print(f"Client: {test_input}")
1192
+ print(f"Counselor: {response[:200]}...")
1193
+ print("-" * 50)
1194
+
1195
+ print("="*50)
1196
+
1221
+ # Main training script
1222
+ if __name__ == "__main__":
1223
+ import argparse
1224
+
1225
+ parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
1226
+ parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
1227
+ help='Base model name')
1228
+ parser.add_argument('--data_path', type=str, default='./processed_data_score80',
1229
+ help='Path to processed data')
1230
+ parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
1231
+ help='Output directory for fine-tuned model')
1232
+ parser.add_argument('--use_4bit', action='store_true', default=False,
1233
+ help='Use 4-bit quantization')
1234
+ parser.add_argument('--wandb_api_key', type=str, default=None,
1235
+ help='WandB API key (optional, can use wandb login instead)')
1236
+ parser.add_argument('--test_only', action='store_true',
1237
+ help='Only test existing model')
1238
+
1239
+ args = parser.parse_args()
1240
+
1241
+ # Set WandB API key if provided
1242
+ if args.wandb_api_key:
1243
+ os.environ["WANDB_API_KEY"] = args.wandb_api_key
1244
+
1245
+ if args.test_only:
1246
+ # Test existing model
1247
+ test_model(
1248
+ f"{args.output_dir}/final_model_2b",
1249
+ f"{args.output_dir}/final_model_2b"
1250
+ )
1251
+ else:
1252
+ # Check if CUDA is available
1253
+ if not torch.cuda.is_available():
1254
+ print("⚠️ Warning: CUDA is not available. Training will be very slow on CPU.")
1255
+ print("It's highly recommended to use a GPU for training.")
1256
+ response = input("Do you want to continue anyway? (y/n): ")
1257
+ if response.lower() != 'y':
1258
+ exit()
1259
+
1260
+ try:
1261
+ # Clear GPU cache
1262
+ if torch.cuda.is_available():
1263
+ torch.cuda.empty_cache()
1264
+
1265
+ # Initialize fine-tuner (WandB is enabled by default)
1266
+ print(f"🚀 Initializing fine-tuner with model: {args.model_name}")
1267
+ finetuner = LFMCounselorFineTuner(
1268
+ model_name=args.model_name,
1269
+ use_4bit=args.use_4bit
1270
+ )
1271
+
1272
+ # Setup model
1273
+ print("\n🔧 Setting up model and tokenizer...")
1274
+ finetuner.setup_model_and_tokenizer()
1275
+
1276
+ # Load datasets
1277
+ print("\n📚 Loading and processing datasets...")
1278
+ finetuner.load_and_process_datasets(args.data_path)
1279
+
1280
+ # Setup training arguments
1281
+ print("\n⚙️ Setting up training arguments...")
1282
+ finetuner.setup_training_args(args.output_dir)
1283
+
1284
+ # Train
1285
+ trainer = finetuner.train()
1286
+
1287
+ # Test the model
1288
+ print("\n🧪 Testing the fine-tuned model...")
1289
+ test_model(
1290
+ f"{args.output_dir}/final_model_2b_v2",
1291
+ f"{args.output_dir}/final_model_2b_v2"
1292
+ )
1293
+
1294
+ print("\n✅ Fine-tuning completed successfully!")
1295
+ print(f"📁 Model saved to: {args.output_dir}/final_model_2b_v2")
1296
+ print("\n📋 Next steps:")
1297
+ print("1. Test more: python finetune_lfm.py --test_only")
1298
+ print("2. Run benchmarking: python benchmark_model.py")
1299
+ print("3. Optimize for mobile: python optimize_for_mobile.py")
1300
+
1301
+ except KeyboardInterrupt:
1302
+ print("\n\n⚠️ Training interrupted by user.")
1303
+ print("Partial model may be saved in checkpoints.")
1304
+ if wandb.run:
1305
+ wandb.finish()
1306
+ except Exception as e:
1307
+ print(f"\n❌ Error during fine-tuning: {e}")
1308
+ import traceback
1309
+ traceback.print_exc()
1310
+ if wandb.run:
1311
+ wandb.finish()
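For quick verification outside the training script, a minimal sketch of loading the saved LoRA adapter for inference. The adapter path and base model name follow the defaults used above (./counselor_model_2b/final_model_2b and LiquidAI/LFM2-2.6B); adjust them if your run used different values.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_path = "./counselor_model_2b/final_model_2b"  # default save location used above
tokenizer = AutoTokenizer.from_pretrained(adapter_path)
base = AutoModelForCausalLM.from_pretrained(
    "LiquidAI/LFM2-2.6B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, adapter_path).eval()

prompt = "こんにちは。最近ストレスを感じています。"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))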
finetune_lfm_complete_history.py ADDED
@@ -0,0 +1,801 @@
1
+ """
2
+ Fine-tuning Script for LFM2-2.6B with Complete Dialogue History
3
+ Following KokoroChat methodology - uses entire conversation context
4
+ Filename: finetune_lfm_complete_history.py
5
+ """
6
+
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ TrainingArguments,
12
+ Trainer,
13
+ DataCollatorForLanguageModeling,
14
+ BitsAndBytesConfig,
15
+ TrainerCallback
16
+ )
17
+ from peft import (
18
+ LoraConfig,
19
+ get_peft_model,
20
+ prepare_model_for_kbit_training,
21
+ TaskType,
22
+ PeftModel,
23
+ PeftConfig
24
+ )
25
+ from datasets import load_dataset, Dataset
26
+ import os
27
+ from typing import Dict, List, Optional
28
+ import numpy as np
29
+ from tqdm import tqdm
30
+ import json
31
+ import gc
32
+ import warnings
33
+ import wandb
34
+ from datetime import datetime
35
+
36
+ warnings.filterwarnings('ignore')
37
+
38
+ # Enable TF32 for H100 optimization
39
+ torch.backends.cuda.matmul.allow_tf32 = True
40
+ torch.backends.cudnn.allow_tf32 = True
41
+
42
+ class LFMKokoroChatFineTuner:
43
+ def __init__(
44
+ self,
45
+ model_name: str = "LiquidAI/LFM2-2.6B",
46
+ use_4bit: bool = False, # H100 has enough memory
47
+ max_seq_length: int = 2048 # Increased for complete dialogue history
48
+ ):
49
+ """
50
+ Initialize the fine-tuner for LFM models with complete dialogue history support
51
+
52
+ Args:
53
+ model_name: Name of the base model
54
+ use_4bit: Whether to use 4-bit quantization
55
+ max_seq_length: Maximum sequence length for complete dialogues
56
+ """
57
+ self.model_name = model_name
58
+ self.use_4bit = use_4bit
59
+ self.max_seq_length = max_seq_length
60
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+
62
+ print("="*80)
63
+ print("🚀 LFM Fine-tuning with Complete Dialogue History (KokoroChat Method)")
64
+ print("="*80)
65
+ print(f"Model: {model_name}")
66
+ print(f"Device: {self.device}")
67
+ print(f"Max sequence length: {max_seq_length}")
68
+
69
+ # GPU information
70
+ if torch.cuda.is_available():
71
+ num_gpus = torch.cuda.device_count()
72
+ print(f"Number of GPUs: {num_gpus}")
73
+ for i in range(num_gpus):
74
+ gpu_name = torch.cuda.get_device_name(i)
75
+ gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1e9
76
+ print(f" GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
77
+
78
+ # Initialize WandB
79
+ self.init_wandb()
80
+
81
+ def init_wandb(self):
82
+ """Initialize WandB for experiment tracking"""
83
+ try:
84
+ run_name = f"lfm-kokoro-complete-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
85
+
86
+ wandb.init(
87
+ project="lfm-kokoro-complete-history",
88
+ name=run_name,
89
+ config={
90
+ "model_name": self.model_name,
91
+ "use_4bit_quantization": self.use_4bit,
92
+ "max_seq_length": self.max_seq_length,
93
+ "device": str(self.device),
94
+ "num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
95
+ "methodology": "Complete dialogue history (KokoroChat)",
96
+ "framework": "transformers + peft",
97
+ "task": "japanese_counseling"
98
+ },
99
+ tags=["counseling", "japanese", "lfm", "complete-history", "kokoro"]
100
+ )
101
+
102
+ print(f"✅ WandB initialized: {wandb.run.name}")
103
+ print(f"📊 View run at: {wandb.run.get_url()}")
104
+ self.wandb_enabled = True
105
+
106
+ except Exception as e:
107
+ print(f"⚠️ WandB initialization failed: {e}")
108
+ self.wandb_enabled = False
109
+ os.environ["WANDB_DISABLED"] = "true"
110
+
111
+ def setup_model_and_tokenizer(self):
112
+ """Setup model with quantization and LoRA"""
113
+
114
+ print("\n📚 Setting up model and tokenizer...")
115
+
116
+ # Load tokenizer
117
+ print("Loading tokenizer...")
118
+ try:
119
+ self.tokenizer = AutoTokenizer.from_pretrained(
120
+ self.model_name,
121
+ trust_remote_code=True
122
+ )
123
+ except:
124
+ print("Using fallback tokenizer...")
125
+ self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
126
+
127
+ # Set special tokens
128
+ if self.tokenizer.pad_token is None:
129
+ self.tokenizer.pad_token = self.tokenizer.eos_token
130
+ if self.tokenizer.eos_token is None:
131
+ self.tokenizer.eos_token = "</s>"
132
+ self.tokenizer.pad_token = "</s>"
133
+
134
+ self.tokenizer.padding_side = "left" # Important for batch generation
135
+
136
+ # Quantization config
137
+ if self.use_4bit:
138
+ print("Setting up 4-bit quantization...")
139
+ bnb_config = BitsAndBytesConfig(
140
+ load_in_4bit=True,
141
+ bnb_4bit_quant_type="nf4",
142
+ bnb_4bit_compute_dtype=torch.bfloat16, # BF16 for H100
143
+ bnb_4bit_use_double_quant=True
144
+ )
145
+ else:
146
+ bnb_config = None
147
+
148
+ # Load model
149
+ print(f"Loading model: {self.model_name}...")
150
+ model_kwargs = {
151
+ "trust_remote_code": True,
152
+ "torch_dtype": torch.bfloat16, # BF16 for H100
153
+ "device_map": "auto",
154
+ }
155
+
156
+ if bnb_config:
157
+ model_kwargs["quantization_config"] = bnb_config
158
+
159
+ try:
160
+ self.model = AutoModelForCausalLM.from_pretrained(
161
+ self.model_name,
162
+ **model_kwargs
163
+ )
164
+ except Exception as e:
165
+ print(f"Error loading model: {e}")
166
+ print("Attempting without device_map...")
167
+ model_kwargs.pop("device_map", None)
168
+ self.model = AutoModelForCausalLM.from_pretrained(
169
+ self.model_name,
170
+ **model_kwargs
171
+ )
172
+ self.model = self.model.to(self.device)
173
+
174
+ # Enable gradient checkpointing
175
+ if hasattr(self.model, 'gradient_checkpointing_enable'):
176
+ self.model.gradient_checkpointing_enable()
177
+
178
+ # Prepare for k-bit training if using quantization
179
+ if self.use_4bit:
180
+ print("Preparing model for 4-bit training...")
181
+ self.model = prepare_model_for_kbit_training(self.model)
182
+
183
+ # LoRA configuration optimized for dialogue with complete history
184
+ print("Applying LoRA configuration...")
185
+
186
+ # Find target modules
187
+ target_modules = self.find_target_modules()
188
+
189
+ # Higher rank for complex dialogue understanding
190
+ lora_config = LoraConfig(
191
+ r=64, # Increased for better dialogue understanding
192
+ lora_alpha=128,
193
+ target_modules=target_modules,
194
+ lora_dropout=0.05,
195
+ bias="none",
196
+ task_type=TaskType.CAUSAL_LM,
197
+ inference_mode=False
198
+ )
199
+
200
+ # Apply LoRA
201
+ self.model = get_peft_model(self.model, lora_config)
202
+
203
+ # Print trainable parameters
204
+ trainable_params = 0
205
+ all_params = 0
206
+ for _, param in self.model.named_parameters():
207
+ all_params += param.numel()
208
+ if param.requires_grad:
209
+ trainable_params += param.numel()
210
+
211
+ trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
212
+
213
+ print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
214
+
215
+ # Log to WandB
216
+ if self.wandb_enabled:
217
+ wandb.config.update({
218
+ "lora_r": lora_config.r,
219
+ "lora_alpha": lora_config.lora_alpha,
220
+ "lora_dropout": lora_config.lora_dropout,
221
+ "lora_target_modules": target_modules,
222
+ "total_parameters": all_params,
223
+ "trainable_parameters": trainable_params,
224
+ "trainable_percentage": trainable_percentage
225
+ })
226
+
227
+ self.model.print_trainable_parameters()
228
+
229
+ def find_target_modules(self):
230
+ """Find linear modules to apply LoRA to"""
231
+ target_modules = []
232
+ for name, module in self.model.named_modules():
233
+ if isinstance(module, torch.nn.Linear):
234
+ names = name.split('.')
235
+ if len(names) > 0:
236
+ target_modules.append(names[-1])
237
+
238
+ # Remove duplicates
239
+ target_modules = list(set(target_modules))
240
+
241
+ # Common patterns for transformer models
242
+ common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
243
+ "gate_proj", "up_proj", "down_proj",
244
+ "fc1", "fc2", "query", "key", "value", "dense"]
245
+
246
+ # Filter to common targets
247
+ final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
248
+
249
+ if not final_targets:
250
+ # Fallback to specific modules for LFM
251
+ final_targets = ["q_proj", "v_proj", "k_proj", "o_proj"]
252
+
253
+ print(f"LoRA target modules: {final_targets}")
254
+ return final_targets
255
+
256
+ def load_and_process_datasets(self, data_path: str):
257
+ """
258
+ Load and process datasets with complete dialogue history
259
+ Handles the new data format with full conversation context
260
+ """
261
+
262
+ print(f"\n📚 Loading datasets from {data_path}...")
263
+
264
+ # Check for dataset statistics
265
+ stats_file = os.path.join(data_path, 'dataset_stats.json')
266
+ if os.path.exists(stats_file):
267
+ with open(stats_file, 'r') as f:
268
+ stats = json.load(f)
269
+ print("Dataset statistics:")
270
+ print(f" Average dialogue history: {stats['dialogue_history_stats']['mean_length']:.1f} turns")
271
+ print(f" Max dialogue history: {stats['dialogue_history_stats']['max_length']} turns")
272
+ print(f" Median dialogue history: {stats['dialogue_history_stats']['median_length']:.1f} turns")
273
+
274
+ # Load datasets
275
+ train_data = []
276
+ val_data = []
277
+
278
+ # Load training data
279
+ train_file = os.path.join(data_path, 'train.jsonl')
280
+ with open(train_file, 'r', encoding='utf-8') as f:
281
+ for line in tqdm(f, desc="Loading training data"):
282
+ item = json.loads(line)
283
+ train_data.append({
284
+ 'text': item['text'],
285
+ 'history_length': item.get('history_length', 0),
286
+ 'score': item.get('score', 100),
287
+ 'topic': item.get('topic', 'general')
288
+ })
289
+
290
+ # Load validation data
291
+ val_file = os.path.join(data_path, 'val.jsonl')
292
+ with open(val_file, 'r', encoding='utf-8') as f:
293
+ for line in tqdm(f, desc="Loading validation data"):
294
+ item = json.loads(line)
295
+ val_data.append({
296
+ 'text': item['text'],
297
+ 'history_length': item.get('history_length', 0),
298
+ 'score': item.get('score', 100),
299
+ 'topic': item.get('topic', 'general')
300
+ })
301
+
302
+ print(f"Loaded {len(train_data)} training examples")
303
+ print(f"Loaded {len(val_data)} validation examples")
304
+
305
+ # Analyze dialogue history lengths
306
+ train_history_lengths = [d['history_length'] for d in train_data]
307
+ val_history_lengths = [d['history_length'] for d in val_data]
308
+
309
+ print(f"\nDialogue history length distribution:")
310
+ print(f" Training - Mean: {np.mean(train_history_lengths):.1f}, Max: {max(train_history_lengths)}")
311
+ print(f" Validation - Mean: {np.mean(val_history_lengths):.1f}, Max: {max(val_history_lengths)}")
312
+
313
+ # Log to WandB
314
+ if self.wandb_enabled:
315
+ wandb.config.update({
316
+ "train_examples": len(train_data),
317
+ "val_examples": len(val_data),
318
+ "avg_train_history_length": float(np.mean(train_history_lengths)),
319
+ "max_train_history_length": int(max(train_history_lengths)),
320
+ "avg_val_history_length": float(np.mean(val_history_lengths)),
321
+ "max_val_history_length": int(max(val_history_lengths))
322
+ })
323
+
324
+ # Log history length distribution
325
+ wandb.log({
326
+ "train_history_distribution": wandb.Histogram(train_history_lengths),
327
+ "val_history_distribution": wandb.Histogram(val_history_lengths)
328
+ })
329
+
330
+ # Tokenize datasets
331
+ print("\nTokenizing datasets with complete dialogue history...")
332
+ print(f"Using max sequence length: {self.max_seq_length}")
333
+
334
+ # Extract texts for tokenization
335
+ train_texts = [d['text'] for d in train_data]
336
+ val_texts = [d['text'] for d in val_data]
337
+
338
+ # Tokenize with longer context for complete history
339
+ train_encodings = self.tokenize_texts(train_texts, desc="Tokenizing training data")
340
+ val_encodings = self.tokenize_texts(val_texts, desc="Tokenizing validation data")
341
+
342
+ # Create datasets
343
+ self.train_dataset = Dataset.from_dict(train_encodings)
344
+ self.val_dataset = Dataset.from_dict(val_encodings)
345
+
346
+ # Set format for PyTorch
347
+ self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
348
+ self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
349
+
350
+ # Clean up memory
351
+ del train_texts, val_texts, train_encodings, val_encodings, train_data, val_data
352
+ gc.collect()
353
+
354
+ print("✅ Datasets loaded and tokenized")
355
+
356
+ def tokenize_texts(self, texts: List[str], batch_size: int = 50, desc: str = "Tokenizing"):
357
+ """
358
+ Tokenize texts in batches with support for longer sequences
359
+ """
360
+ all_input_ids = []
361
+ all_attention_masks = []
362
+
363
+ # Process in smaller batches for long sequences
364
+ for i in tqdm(range(0, len(texts), batch_size), desc=desc):
365
+ batch_texts = texts[i:i + batch_size]
366
+
367
+ # Tokenize batch with longer max length
368
+ encodings = self.tokenizer(
369
+ batch_texts,
370
+ truncation=True,
371
+ padding='max_length',
372
+ max_length=self.max_seq_length,
373
+ return_tensors='pt'
374
+ )
375
+
376
+ # Convert to lists
377
+ all_input_ids.extend(encodings['input_ids'].tolist())
378
+ all_attention_masks.extend(encodings['attention_mask'].tolist())
379
+
380
+ # Create labels (same as input_ids for causal LM)
381
+ labels = all_input_ids.copy()
382
+
383
+ return {
384
+ 'input_ids': all_input_ids,
385
+ 'attention_mask': all_attention_masks,
386
+ 'labels': labels
387
+ }
388
+
389
+ def setup_training_args(self, output_dir: str = "./lfm_kokoro_complete"):
390
+ """Setup training arguments optimized for complete dialogue history"""
391
+
392
+ print("\n⚙️ Setting up training arguments...")
393
+
394
+ # Calculate batch sizes based on sequence length and GPU memory
395
+ if torch.cuda.is_available():
396
+ gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
397
+ num_gpus = torch.cuda.device_count()
398
+
399
+ # Adjust batch size based on sequence length and GPU memory
400
+ if self.max_seq_length >= 2048:
401
+ if gpu_memory >= 80: # H100 80GB
402
+ batch_size = 4
403
+ gradient_accumulation = 4
404
+ elif gpu_memory >= 40:
405
+ batch_size = 2
406
+ gradient_accumulation = 8
407
+ else:
408
+ batch_size = 1
409
+ gradient_accumulation = 16
410
+ else:
411
+ batch_size = 8
412
+ gradient_accumulation = 2
413
+
414
+ # Adjust for multiple GPUs
415
+ if num_gpus > 1:
416
+ batch_size = batch_size * num_gpus
417
+ gradient_accumulation = max(1, gradient_accumulation // num_gpus)
418
+ else:
419
+ batch_size = 1
420
+ gradient_accumulation = 32
421
+
422
+ print(f"Batch configuration:")
423
+ print(f" Per device batch size: {batch_size}")
424
+ print(f" Gradient accumulation steps: {gradient_accumulation}")
425
+ print(f" Effective batch size: {batch_size * gradient_accumulation}")
426
+
427
+ # Update WandB config
428
+ if self.wandb_enabled:
429
+ wandb.config.update({
430
+ "batch_size": batch_size,
431
+ "gradient_accumulation_steps": gradient_accumulation,
432
+ "effective_batch_size": batch_size * gradient_accumulation,
433
+ "num_epochs": 3,
434
+ "learning_rate": 2e-4,
435
+ "warmup_ratio": 0.1,
436
+ "weight_decay": 0.01,
437
+ "max_grad_norm": 1.0,
438
+ "lr_scheduler": "cosine",
439
+ "optimizer": "adamw_torch"
440
+ })
441
+
442
+ self.training_args = TrainingArguments(
443
+ output_dir=output_dir,
444
+ num_train_epochs=3,
445
+ per_device_train_batch_size=batch_size,
446
+ per_device_eval_batch_size=batch_size,
447
+ gradient_accumulation_steps=gradient_accumulation,
448
+ gradient_checkpointing=True,
449
+ warmup_ratio=0.1,
450
+ learning_rate=2e-4,
451
+ bf16=True, # Use BF16 for H100
452
+ tf32=True, # Enable TF32 for H100
453
+ logging_steps=10,
454
+ logging_first_step=True,
455
+ eval_strategy="steps",
456
+ eval_steps=100,
457
+ save_strategy="steps",
458
+ save_steps=200,
459
+ save_total_limit=3,
460
+ load_best_model_at_end=True,
461
+ metric_for_best_model="eval_loss",
462
+ greater_is_better=False,
463
+ report_to="wandb" if self.wandb_enabled else "none",
464
+ run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
465
+ optim="adamw_torch",
466
+ lr_scheduler_type="cosine",
467
+ weight_decay=0.01,
468
+ max_grad_norm=1.0,
469
+ remove_unused_columns=False,
470
+ label_names=["labels"],
471
+ dataloader_num_workers=4,
472
+ dataloader_pin_memory=True,
473
+ ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,
474
+ )
475
+
476
+ def train(self):
477
+ """Execute training with complete dialogue history"""
478
+
479
+ print("\n🎯 Starting training with complete dialogue history...")
480
+
481
+ # Data collator
482
+ data_collator = DataCollatorForLanguageModeling(
483
+ tokenizer=self.tokenizer,
484
+ mlm=False,
485
+ pad_to_multiple_of=8
486
+ )
487
+
488
+ # Custom callback for metrics
489
+ class MetricsCallback(TrainerCallback):
490
+ def __init__(self, wandb_enabled):
491
+ self.wandb_enabled = wandb_enabled
492
+
493
+ def on_log(self, args, state, control, logs=None, **kwargs):
494
+ if logs and self.wandb_enabled:
495
+ # Add perplexity
496
+ if "loss" in logs:
497
+ logs["perplexity"] = np.exp(logs["loss"])
498
+ if "eval_loss" in logs:
499
+ logs["eval_perplexity"] = np.exp(logs["eval_loss"])
500
+
501
+ # Log to WandB
502
+ wandb.log(logs, step=state.global_step)
503
+
504
+ return control
505
+
506
+ # Initialize trainer
507
+ trainer = Trainer(
508
+ model=self.model,
509
+ args=self.training_args,
510
+ train_dataset=self.train_dataset,
511
+ eval_dataset=self.val_dataset,
512
+ data_collator=data_collator,
513
+ tokenizer=self.tokenizer,
514
+ callbacks=[MetricsCallback(self.wandb_enabled)] if self.wandb_enabled else [],
515
+ )
516
+
517
+ # Calculate total steps
518
+ total_steps = len(self.train_dataset) // (
519
+ self.training_args.per_device_train_batch_size *
520
+ self.training_args.gradient_accumulation_steps
521
+ ) * self.training_args.num_train_epochs
522
+
523
+ print("="*60)
524
+ print("Training Information:")
525
+ print(f" Total training samples: {len(self.train_dataset)}")
526
+ print(f" Total validation samples: {len(self.val_dataset)}")
527
+ print(f" Total training steps: {total_steps}")
528
+ print(f" Max sequence length: {self.max_seq_length}")
529
+ print("="*60)
530
+
531
+ # Log training start
532
+ if self.wandb_enabled:
533
+ wandb.log({
534
+ "training_status": "started",
535
+ "total_steps": total_steps,
536
+ "max_seq_length": self.max_seq_length
537
+ })
538
+
539
+ try:
540
+ # Train
541
+ print("\n🚀 Training started...")
542
+ train_result = trainer.train()
543
+
544
+ # Save model
545
+ print("\n💾 Saving fine-tuned model...")
546
+ final_model_path = os.path.join(self.training_args.output_dir, "final_model")
547
+ trainer.save_model(final_model_path)
548
+ self.tokenizer.save_pretrained(final_model_path)
549
+
550
+ # Save training metrics
551
+ with open(os.path.join(self.training_args.output_dir, "training_metrics.json"), 'w') as f:
552
+ json.dump(train_result.metrics, f, indent=2)
553
+
554
+ # Final evaluation
555
+ print("\n📊 Running final evaluation...")
556
+ eval_results = trainer.evaluate()
557
+
558
+ # Save evaluation metrics
559
+ with open(os.path.join(self.training_args.output_dir, "eval_metrics.json"), 'w') as f:
560
+ json.dump(eval_results, f, indent=2)
561
+
562
+ # Log final metrics
563
+ if self.wandb_enabled:
564
+ wandb.run.summary.update({
565
+ "final_train_loss": train_result.metrics.get("train_loss", 0),
566
+ "final_eval_loss": eval_results.get("eval_loss", 0),
567
+ "final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
568
+ "total_training_time": train_result.metrics.get("train_runtime", 0),
569
+ "training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
570
+ "training_status": "completed"
571
+ })
572
+
573
+ # Save model artifact
574
+ artifact = wandb.Artifact(
575
+ name=f"kokoro-model-complete-{wandb.run.id}",
576
+ type="model",
577
+ description="LFM model fine-tuned with complete dialogue history",
578
+ metadata={
579
+ "base_model": self.model_name,
580
+ "final_loss": float(eval_results.get("eval_loss", 0)),
581
+ "final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
582
+ "max_seq_length": self.max_seq_length,
583
+ "methodology": "Complete dialogue history (KokoroChat)"
584
+ }
585
+ )
586
+ artifact.add_dir(final_model_path)
587
+ wandb.log_artifact(artifact)
588
+
589
+ print("\n" + "="*60)
590
+ print("✅ Training completed successfully!")
591
+ print(f"📁 Model saved to: {final_model_path}")
592
+ print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
593
+ print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
594
+ if self.wandb_enabled and wandb.run:
595
+ print(f"🔗 View results at: {wandb.run.get_url()}")
596
+ print("="*60)
597
+
598
+ return trainer
599
+
600
+ except Exception as e:
601
+ print(f"❌ Error during training: {e}")
602
+
603
+ if self.wandb_enabled:
604
+ wandb.run.summary["training_status"] = "failed"
605
+ wandb.run.summary["error"] = str(e)
606
+
607
+ # Save emergency checkpoint
608
+ try:
609
+ emergency_path = os.path.join(self.training_args.output_dir, "emergency_checkpoint")
610
+ self.model.save_pretrained(emergency_path)
611
+ self.tokenizer.save_pretrained(emergency_path)
612
+ print(f"💾 Emergency checkpoint saved to: {emergency_path}")
613
+ except:
614
+ print("❌ Could not save emergency checkpoint")
615
+
616
+ raise e
617
+
618
+ finally:
619
+ if self.wandb_enabled:
620
+ wandb.finish()
621
+
622
+ def test_model_with_complete_history(model_path: str):
623
+ """Test the fine-tuned model with complete dialogue history examples"""
624
+
625
+ print("\n" + "="*60)
626
+ print("🧪 Testing model with complete dialogue history")
627
+ print("="*60)
628
+
629
+ # Load tokenizer and model
630
+ tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
631
+
632
+ # Check if it's a PEFT model
633
+ adapter_config_path = os.path.join(model_path, "adapter_config.json")
634
+ if os.path.exists(adapter_config_path):
635
+ print("Loading as PEFT model...")
636
+ config = PeftConfig.from_pretrained(model_path)
637
+ base_model = AutoModelForCausalLM.from_pretrained(
638
+ config.base_model_name_or_path,
639
+ torch_dtype=torch.bfloat16,
640
+ device_map="auto",
641
+ trust_remote_code=True
642
+ )
643
+ model = PeftModel.from_pretrained(base_model, model_path)
644
+ else:
645
+ print("Loading as regular model...")
646
+ model = AutoModelForCausalLM.from_pretrained(
647
+ model_path,
648
+ torch_dtype=torch.bfloat16,
649
+ device_map="auto",
650
+ local_files_only=True,
651
+ trust_remote_code=True
652
+ )
653
+
654
+ model.eval()
655
+
656
+ # Test with dialogue history examples
657
+ test_cases = [
658
+ {
659
+ "history": "クライアント: こんにちは。最近ストレスを感じています。\nカウンセラー: こんにちは。ストレスを感じていらっしゃるのですね。どのような状況でストレスを感じることが多いですか?\n",
660
+ "current": "クライアント: 仕事が忙しくて、休む時間がありません。"
661
+ },
662
+ {
663
+ "history": "",
664
+ "current": "クライアント: 人間関係で悩んでいます。"
665
+ }
666
+ ]
667
+
668
+ print("Testing with complete dialogue history:\n")
669
+
670
+ for i, test_case in enumerate(test_cases, 1):
671
+ print(f"Test Case {i}:")
672
+ print("-" * 40)
673
+
674
+ # Format input with complete history
675
+ if test_case["history"]:
676
+ prompt = f"""### Instruction:
677
+ あなたは専門的な訓練を受けた心理カウンセラーです。
678
+ 以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。
679
+
680
+ ### Dialogue History:
681
+ {test_case["history"]}{test_case["current"]}
682
+
683
+ ### Response:
684
+ """
685
+ else:
686
+ prompt = f"""### Instruction:
687
+ あなたは専門的な訓練を受けた心理カウンセラーです。
688
+
689
+ ### Dialogue History:
690
+ {test_case["current"]}
691
+
692
+ ### Response:
693
+ """
694
+
695
+ # Generate response
696
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
697
+ inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
698
+
699
+ with torch.no_grad():
700
+ outputs = model.generate(
701
+ **inputs,
702
+ max_new_tokens=150,
703
+ temperature=0.1,  # sampling requires a strictly positive temperature
704
+ do_sample=True,
705
+ top_p=0.9,
706
+ pad_token_id=tokenizer.pad_token_id
707
+ )
708
+
709
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
710
+ response = response.split("### Response:")[-1].strip() if "### Response:" in response else response
711
+
712
+ # print(f"History Length: {len(test_case['history'].split('\\n')) if test_case['history'] else 0} turns")
713
+ print("History Length: {} turns".format(len(test_case['history'].split('\\n')) if test_case['history'] else 0))
714
+
715
+ print(f"Current Input: {test_case['current']}")
716
+ print(f"Generated Response: {response[:300]}...")
717
+ print()
718
+
719
+ print("="*60)
720
+
721
+ # Main execution
722
+ if __name__ == "__main__":
723
+ import argparse
724
+
725
+ parser = argparse.ArgumentParser(description='Fine-tune LFM model with complete dialogue history')
726
+ parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
727
+ help='Base model name')
728
+ parser.add_argument('--data_path', type=str, default='./kokoro_processed_data',
729
+ help='Path to processed data with complete dialogue history')
730
+ parser.add_argument('--output_dir', type=str, default='./lfm_kokoro_complete',
731
+ help='Output directory for fine-tuned model')
732
+ parser.add_argument('--max_seq_length', type=int, default=2048,
733
+ help='Maximum sequence length for complete dialogues')
734
+ parser.add_argument('--use_4bit', action='store_true',
735
+ help='Use 4-bit quantization')
736
+ parser.add_argument('--test_only', action='store_true',
737
+ help='Only test existing model')
738
+
739
+ args = parser.parse_args()
740
+
741
+ if args.test_only:
742
+ # Test existing model
743
+ test_model_with_complete_history(
744
+ os.path.join(args.output_dir, "final_model")
745
+ )
746
+ else:
747
+ # Check CUDA availability
748
+ if not torch.cuda.is_available():
749
+ print("⚠️ Warning: CUDA is not available. Training will be slow.")
750
+ response = input("Continue? (y/n): ")
751
+ if response.lower() != 'y':
752
+ exit()
753
+
754
+ try:
755
+ # Clear GPU cache
756
+ if torch.cuda.is_available():
757
+ torch.cuda.empty_cache()
758
+
759
+ # Initialize fine-tuner
760
+ print(f"🚀 Initializing fine-tuner for complete dialogue history")
761
+ finetuner = LFMKokoroChatFineTuner(
762
+ model_name=args.model_name,
763
+ use_4bit=args.use_4bit,
764
+ max_seq_length=args.max_seq_length
765
+ )
766
+
767
+ # Setup model
768
+ finetuner.setup_model_and_tokenizer()
769
+
770
+ # Load datasets
771
+ finetuner.load_and_process_datasets(args.data_path)
772
+
773
+ # Setup training arguments
774
+ finetuner.setup_training_args(args.output_dir)
775
+
776
+ # Train
777
+ trainer = finetuner.train()
778
+
779
+ # Test the model
780
+ print("\n🧪 Testing the fine-tuned model...")
781
+ test_model_with_complete_history(
782
+ os.path.join(args.output_dir, "final_model")
783
+ )
784
+
785
+ print("\n✅ Fine-tuning with complete dialogue history completed!")
786
+ print(f"📁 Model saved to: {args.output_dir}/final_model")
787
+ print("\n📋 Next steps:")
788
+ print(f"1. Test more: python {__file__} --test_only --output_dir {args.output_dir}")
789
+ print("2. Run benchmarking with complete history support")
790
+ print("3. Deploy for production use")
791
+
792
+ except KeyboardInterrupt:
793
+ print("\n\n⚠️ Training interrupted by user.")
794
+ if wandb.run:
795
+ wandb.finish()
796
+ except Exception as e:
797
+ print(f"\n❌ Error: {e}")
798
+ import traceback
799
+ traceback.print_exc()
800
+ if wandb.run:
801
+ wandb.finish()
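The test function above assumes training examples were rendered into an "### Instruction / ### Dialogue History / ### Response" prompt. The actual rendering lives in preprocess_kokoro_method.py; the sketch below only illustrates the assumed format, with a hypothetical helper name and sample turns:

def build_example(turns, counselor_reply):
    """Render one training example; `turns` is a list of (speaker, utterance) pairs."""
    history = "".join(f"{speaker}: {utterance}\n" for speaker, utterance in turns)
    return (
        "### Instruction:\n"
        "あなたは専門的な訓練を受けた心理カウンセラーです。\n"
        "以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。\n\n"
        "### Dialogue History:\n"
        f"{history}\n"
        "### Response:\n"
        f"{counselor_reply}"
    )

example_text = build_example(
    [("クライアント", "最近ストレスを感じています。"),
     ("カウンセラー", "どのような状況でストレスを感じることが多いですか?")],
    "お仕事の場面が多いのですね。もう少し詳しく聞かせてください。",
)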
finetune_trl_supervised.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Minimal Working Fine-tuning Script - No Complex Dependencies
3
+ Filename: finetune_minimal.py
4
+ """
5
+
6
+ import torch
7
+ import os
8
+ import json
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+ # Fix the import issues by reinstalling
14
+ import subprocess
15
+ import sys
16
+
17
+ def fix_environment():
18
+ """Fix the broken environment"""
19
+ print("Fixing environment...")
20
+ subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "torchvision"], check=False)
21
+ subprocess.run([sys.executable, "-m", "pip", "install", "--no-deps", "transformers==4.36.0"], check=False)
22
+ subprocess.run([sys.executable, "-m", "pip", "install", "peft==0.7.0", "accelerate==0.25.0"], check=False)
23
+
24
+ # Uncomment if needed
25
+ # fix_environment()
26
+
27
+ # Now import after fixing
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+ from peft import LoraConfig, get_peft_model, TaskType
30
+
31
+ class SimpleDataset(Dataset):
32
+ def __init__(self, data_path, tokenizer, max_length=1024):
33
+ self.data = []
34
+ with open(data_path, 'r') as f:
35
+ for line in f:
36
+ item = json.loads(line)
37
+ self.data.append(item['text'])
38
+
39
+ self.tokenizer = tokenizer
40
+ self.max_length = max_length
41
+
42
+ def __len__(self):
43
+ return len(self.data)
44
+
45
+ def __getitem__(self, idx):
46
+ text = self.data[idx]
47
+ encoded = self.tokenizer(
48
+ text,
49
+ truncation=True,
50
+ padding='max_length',
51
+ max_length=self.max_length,
52
+ return_tensors='pt'
53
+ )
54
+ return {
55
+ 'input_ids': encoded['input_ids'].squeeze(),
56
+ 'attention_mask': encoded['attention_mask'].squeeze()
57
+ }
58
+
59
+ def train_simple():
60
+ """Simple training without complex dependencies"""
61
+
62
+ # Configuration
63
+ model_name = "LiquidAI/LFM2-2.6B"
64
+ data_dir = "./kokoro_processed_data"
65
+ output_dir = "./lfm_minimal_output"
66
+ batch_size = 4
67
+ learning_rate = 2e-4
68
+ num_epochs = 2
69
+ max_length = 1024
70
+
71
+ os.makedirs(output_dir, exist_ok=True)
72
+
73
+ print("="*60)
74
+ print("Minimal Fine-tuning Script")
75
+ print("="*60)
76
+
77
+ # Device
78
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ print(f"Device: {device}")
80
+
81
+ # Load tokenizer
82
+ print("Loading tokenizer...")
83
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
84
+ if tokenizer.pad_token is None:
85
+ tokenizer.pad_token = tokenizer.eos_token
86
+
87
+ # Load model
88
+ print("Loading model...")
89
+ model = AutoModelForCausalLM.from_pretrained(
90
+ model_name,
91
+ torch_dtype=torch.bfloat16,
92
+ device_map="auto",
93
+ trust_remote_code=True
94
+ )
95
+
96
+ # Apply LoRA
97
+ print("Applying LoRA...")
98
+ peft_config = LoraConfig(
99
+ r=32,
100
+ lora_alpha=64,
101
+ target_modules=["q_proj", "v_proj"],
102
+ lora_dropout=0.05,
103
+ bias="none",
104
+ task_type=TaskType.CAUSAL_LM
105
+ )
106
+
107
+ model = get_peft_model(model, peft_config)
108
+ model.print_trainable_parameters()
109
+
110
+ # Load dataset
111
+ print("Loading dataset...")
112
+ train_dataset = SimpleDataset(
113
+ os.path.join(data_dir, "train.jsonl"),
114
+ tokenizer,
115
+ max_length
116
+ )
117
+
118
+ train_loader = DataLoader(
119
+ train_dataset,
120
+ batch_size=batch_size,
121
+ shuffle=True
122
+ )
123
+
124
+ # Optimizer
125
+ optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
126
+
127
+ # Training loop
128
+ print(f"\nStarting training for {num_epochs} epochs...")
129
+ model.train()
130
+
131
+ global_step = 0
132
+ for epoch in range(num_epochs):
133
+ print(f"\nEpoch {epoch+1}/{num_epochs}")
134
+
135
+ total_loss = 0
136
+ progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
137
+
138
+ for batch in progress_bar:
139
+ global_step += 1
140
+
141
+ # Move to device
142
+ input_ids = batch['input_ids'].to(device)
143
+ attention_mask = batch['attention_mask'].to(device)
144
+
145
+ # Forward pass
146
+ outputs = model(
147
+ input_ids=input_ids,
148
+ attention_mask=attention_mask,
149
+ labels=input_ids
150
+ )
151
+
152
+ loss = outputs.loss
153
+ total_loss += loss.item()
154
+
155
+ # Backward pass
156
+ loss.backward()
157
+
158
+ # Update weights every 4 steps (gradient accumulation)
159
+ if global_step % 4 == 0:
160
+ optimizer.step()
161
+ optimizer.zero_grad()
162
+
163
+ # Update progress bar
164
+ progress_bar.set_postfix({'loss': loss.item()})
165
+
166
+ # Save checkpoint
167
+ if global_step % 500 == 0:
168
+ print(f"\nSaving checkpoint at step {global_step}...")
169
+ model.save_pretrained(os.path.join(output_dir, f"checkpoint-{global_step}"))
170
+ tokenizer.save_pretrained(os.path.join(output_dir, f"checkpoint-{global_step}"))
171
+
172
+ avg_loss = total_loss / len(train_loader)
173
+ print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")
174
+
175
+ # Save final model
176
+ print("\nSaving final model...")
177
+ model.save_pretrained(os.path.join(output_dir, "final_model"))
178
+ tokenizer.save_pretrained(os.path.join(output_dir, "final_model"))
179
+
180
+ print(f"\n✅ Training complete! Model saved to {output_dir}/final_model")
181
+
182
+ # Test the model
183
+ print("\nTesting model...")
184
+ test_model(os.path.join(output_dir, "final_model"))
185
+
186
+ def test_model(model_path):
187
+ """Test the fine-tuned model"""
188
+
189
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
190
+ model = AutoModelForCausalLM.from_pretrained(
191
+ model_path,
192
+ torch_dtype=torch.bfloat16,
193
+ device_map="auto"
194
+ )
195
+
196
+ test_input = "最近ストレスを感じています。"
197
+ prompt = f"""### Instruction:
198
+ あなたは心理カウンセラーです。
199
+
200
+ ### Input:
201
+ {test_input}
202
+
203
+ ### Response:
204
+ """
205
+
206
+ inputs = tokenizer(prompt, return_tensors="pt")
207
+
208
+ with torch.no_grad():
209
+ outputs = model.generate(
210
+ inputs.input_ids.cuda(),
211
+ max_new_tokens=100,
212
+ temperature=0.7,
213
+ do_sample=True
214
+ )
215
+
216
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
217
+ print(f"\nTest Input: {test_input}")
218
+ print(f"Response: {response.split('### Response:')[-1].strip()}")
219
+
220
+ if __name__ == "__main__":
221
+ train_simple()
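One caveat in the minimal loop above: loss.backward() runs on every batch, but the optimizer only steps every 4 batches and the loss is never divided by the accumulation factor, so the accumulated gradient is roughly 4x larger than intended. A sketch of the usual normalization, reusing model, optimizer, train_loader and device from train_simple() above:

accumulation_steps = 4  # matches the `global_step % 4 == 0` check above

optimizer.zero_grad()
for step, batch in enumerate(train_loader, start=1):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    # Scale so the summed gradients match one large batch of size batch_size * accumulation_steps
    (outputs.loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()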
merge_model.py ADDED
@@ -0,0 +1,74 @@
1
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ # from peft import PeftModel
3
+ # import torch
4
+
5
+ # print("Loading base model...")
6
+ # base_model = AutoModelForCausalLM.from_pretrained(
7
+ # "./models/LFM2-1.2B",
8
+ # torch_dtype=torch.bfloat16,
9
+ # device_map="auto",
10
+ # trust_remote_code=True
11
+ # )
12
+
13
+ # print("Loading LoRA adapters...")
14
+ # model = PeftModel.from_pretrained(base_model, "./counselor_model/final_model")
15
+
16
+ # print("Merging adapters with base model...")
17
+ # merged_model = model.merge_and_unload()
18
+
19
+ # print("Saving merged model...")
20
+ # merged_model.save_pretrained("./counselor_model-merged", safe_serialization=True)
21
+
22
+ # tokenizer = AutoTokenizer.from_pretrained("./models/LFM2-1.2B")
23
+ # tokenizer.save_pretrained("./counselor_model-merged")
24
+
25
+ # print("Model merge complete!")
26
+
27
+ import torch
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+ from peft import PeftModel, PeftConfig
30
+ import os
31
+
32
+ def merge_and_save_model(
33
+ base_model_name: str = "LiquidAI/LFM2-2.6B",
34
+ adapter_path: str = "./lfm_minimal_output/final_model",
35
+ output_path: str = "./merged_counselor_minimal_2b"
36
+ ):
37
+ """
38
+ Properly merge LoRA weights with base model
39
+ """
40
+ print("Loading base model...")
41
+ # Load the base model
42
+ base_model = AutoModelForCausalLM.from_pretrained(
43
+ base_model_name,
44
+ torch_dtype=torch.float16,
45
+ device_map="auto",
46
+ trust_remote_code=True
47
+ )
48
+
49
+ print("Loading LoRA adapter...")
50
+ # Load the PEFT model (LoRA adapter)
51
+ model = PeftModel.from_pretrained(
52
+ base_model,
53
+ adapter_path,
54
+ torch_dtype=torch.float16,
55
+ )
56
+
57
+ print("Merging weights...")
58
+ # Merge LoRA weights with base model
59
+ model = model.merge_and_unload()
60
+
61
+ print(f"Saving merged model to {output_path}...")
62
+ # Save the merged model
63
+ model.save_pretrained(output_path)
64
+
65
+ # Also save the tokenizer
66
+ tokenizer = AutoTokenizer.from_pretrained(adapter_path)
67
+ tokenizer.save_pretrained(output_path)
68
+
69
+ print("✅ Model merged and saved successfully!")
70
+ return model, tokenizer
71
+
72
+ # Run the merge
73
+ if __name__ == "__main__":
74
+ merge_and_save_model()
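
A quick, hypothetical smoke test for the merged checkpoint produced above. The path mirrors the `output_path` default in `merge_and_save_model`; the prompt text and generation settings are illustrative only.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    merged_path = "./merged_counselor_minimal_2b"  # matches the output_path default above
    tokenizer = AutoTokenizer.from_pretrained(merged_path)
    model = AutoModelForCausalLM.from_pretrained(
        merged_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Illustrative prompt in the same "### Instruction / Input / Response" format used for training.
    prompt = "### Instruction:\nあなたは心理カウンセラーです。\n\n### Input:\n最近眠れません。\n\n### Response:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
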
preprocess_kokoro_method.py ADDED
@@ -0,0 +1,651 @@
1
+ """
2
+ Fixed Data Preprocessing for directory of JSON files with client-counselor dialogues
3
+ Following KokoroChat methodology with COMPLETE dialogue history
4
+ Filename: preprocess_kokoro_method.py
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from typing import List, Dict, Tuple, Optional, Any
10
+ from tqdm import tqdm
11
+ import random
12
+ from collections import defaultdict
13
+ import numpy as np
14
+ from pathlib import Path
15
+ import glob
16
+
17
+ class KokoroChatDirectoryPreprocessor:
18
+ def __init__(self,
19
+ input_dir: str = "./raw_counseling_data",
20
+ output_dir: str = "./kokoro_processed_data",
21
+ min_score: int = 70,
22
+ train_ratio: float = 0.8,
23
+ val_ratio: float = 0.1,
24
+ test_ratio: float = 0.1):
25
+ """
26
+ Initialize preprocessor for directory of JSON files
27
+
28
+ Args:
29
+ input_dir: Directory containing JSON files with conversations
30
+ output_dir: Directory to save processed data
31
+ min_score: Minimum score threshold for filtering (if scores exist)
32
+ train_ratio: Ratio for training data
33
+ val_ratio: Ratio for validation data
34
+ test_ratio: Ratio for test data
35
+ """
36
+ self.input_dir = input_dir
37
+ self.output_dir = output_dir
38
+ self.min_score = min_score
39
+ self.train_ratio = train_ratio
40
+ self.val_ratio = val_ratio
41
+ self.test_ratio = test_ratio
42
+
43
+ os.makedirs(output_dir, exist_ok=True)
44
+
45
+ # Track statistics
46
+ self.total_conversations = 0
47
+ self.total_utterances = 0
48
+ self.skipped_files = 0
49
+
50
+ def load_json_file(self, filepath: str) -> Optional[Dict]:
51
+ """Load a single JSON file"""
52
+ try:
53
+ with open(filepath, 'r', encoding='utf-8') as f:
54
+ data = json.load(f)
55
+ return data
56
+ except Exception as e:
57
+ print(f"⚠️ Error loading {filepath}: {e}")
58
+ self.skipped_files += 1
59
+ return None
60
+
61
+ def safe_get_value(self, obj: Any, default: Any = None) -> Any:
62
+ """Safely get a value, handling nested dicts and lists"""
63
+ if isinstance(obj, dict):
64
+ # If it's a dict, try to get a meaningful string representation
65
+ if 'name' in obj:
66
+ return str(obj['name'])
67
+ elif 'value' in obj:
68
+ return str(obj['value'])
69
+ elif 'text' in obj:
70
+ return str(obj['text'])
71
+ else:
72
+ # Return first string value found or convert to string
73
+ for v in obj.values():
74
+ if isinstance(v, str):
75
+ return v
76
+ return str(list(obj.values())[0]) if obj else default
77
+ elif isinstance(obj, list):
78
+ # If it's a list, join elements or return first element
79
+ if obj:
80
+ return str(obj[0]) if len(obj) == 1 else ', '.join(str(x) for x in obj)
81
+ return default
82
+ elif obj is None:
83
+ return default
84
+ else:
85
+ return str(obj)
86
+
87
+ def extract_dialogue_from_json(self, data: Dict, filepath: str) -> List[Dict]:
88
+ """
89
+ Extract dialogue from various JSON formats
90
+ Handles different possible structures
91
+ """
92
+ conversations = []
93
+
94
+ # Try different possible structures
95
+ if isinstance(data, list):
96
+ # If the JSON is directly a list of utterances
97
+ conversations.append({
98
+ 'dialogue': data,
99
+ 'id': os.path.basename(filepath).replace('.json', ''),
100
+ 'score': 100, # Default score
101
+ 'topic': 'general',
102
+ 'source_file': filepath
103
+ })
104
+
105
+ elif isinstance(data, dict):
106
+ # Extract score safely
107
+ score = data.get('score', 100)
108
+ if isinstance(score, dict):
109
+ score = score.get('value', 100) if 'value' in score else 100
110
+ try:
111
+ score = float(score)
112
+ except:
113
+ score = 100
114
+
115
+ # Extract topic safely
116
+ topic = self.safe_get_value(data.get('topic', 'general'), 'general')
117
+
118
+ # Check for different possible keys
119
+ if 'dialogue' in data:
120
+ conversations.append({
121
+ 'dialogue': data['dialogue'],
122
+ 'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
123
+ 'score': score,
124
+ 'topic': topic,
125
+ 'source_file': filepath
126
+ })
127
+
128
+ elif 'messages' in data:
129
+ conversations.append({
130
+ 'dialogue': data['messages'],
131
+ 'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
132
+ 'score': score,
133
+ 'topic': topic,
134
+ 'source_file': filepath
135
+ })
136
+
137
+ elif 'utterances' in data:
138
+ conversations.append({
139
+ 'dialogue': data['utterances'],
140
+ 'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
141
+ 'score': score,
142
+ 'topic': topic,
143
+ 'source_file': filepath
144
+ })
145
+
146
+ elif 'conversations' in data:
147
+ # Multiple conversations in one file
148
+ for conv in data['conversations']:
149
+ if isinstance(conv, dict) and any(key in conv for key in ['dialogue', 'messages', 'utterances']):
150
+ dialogue_key = 'dialogue' if 'dialogue' in conv else ('messages' if 'messages' in conv else 'utterances')
151
+
152
+ # Extract score and topic safely for each conversation
153
+ conv_score = conv.get('score', score)
154
+ if isinstance(conv_score, dict):
155
+ conv_score = conv_score.get('value', 100) if 'value' in conv_score else 100
156
+ try:
157
+ conv_score = float(conv_score)
158
+ except:
159
+ conv_score = 100
160
+
161
+ conv_topic = self.safe_get_value(conv.get('topic', topic), 'general')
162
+
163
+ conversations.append({
164
+ 'dialogue': conv[dialogue_key],
165
+ 'id': conv.get('id', f"{os.path.basename(filepath)}_{len(conversations)}"),
166
+ 'score': conv_score,
167
+ 'topic': conv_topic,
168
+ 'source_file': filepath
169
+ })
170
+
171
+ else:
172
+ # Try to find any list that looks like dialogue
173
+ for key, value in data.items():
174
+ if isinstance(value, list) and len(value) > 0:
175
+ # Check if it looks like dialogue data
176
+ if isinstance(value[0], dict) and any(k in value[0] for k in ['speaker', 'role', 'text', 'content', 'utterance']):
177
+ conversations.append({
178
+ 'dialogue': value,
179
+ 'id': data.get('id', os.path.basename(filepath).replace('.json', '')),
180
+ 'score': score,
181
+ 'topic': topic,
182
+ 'source_file': filepath
183
+ })
184
+ break
185
+
186
+ return conversations
187
+
188
+ def normalize_utterance(self, utterance: Dict) -> Optional[Dict]:
189
+ """
190
+ Normalize utterance format from various possible structures
191
+ Returns: {'speaker': str, 'text': str} or None
192
+ """
193
+ # Determine speaker
194
+ speaker = None
195
+ if 'speaker' in utterance:
196
+ speaker = utterance['speaker']
197
+ elif 'role' in utterance:
198
+ speaker = utterance['role']
199
+ elif 'sender' in utterance:
200
+ speaker = utterance['sender']
201
+ elif 'from' in utterance:
202
+ speaker = utterance['from']
203
+ elif 'type' in utterance:
204
+ speaker = utterance['type']
205
+
206
+ # Determine text content
207
+ text = None
208
+ if 'text' in utterance:
209
+ text = utterance['text']
210
+ elif 'content' in utterance:
211
+ text = utterance['content']
212
+ elif 'message' in utterance:
213
+ text = utterance['message']
214
+ elif 'utterance' in utterance:
215
+ text = utterance['utterance']
216
+ elif 'response' in utterance:
217
+ text = utterance['response']
218
+
219
+ if speaker and text:
220
+ # Normalize speaker labels
221
+ speaker_lower = str(speaker).lower()
222
+ if speaker_lower in ['client', 'user', 'patient', 'クライアント', '相談者', 'c']:
223
+ normalized_speaker = 'client'
224
+ elif speaker_lower in ['counselor', 'therapist', 'assistant', 'カウンセラー', '相談員', 's', 'system']:
225
+ normalized_speaker = 'counselor'
226
+ else:
227
+ # Try to infer from position or content
228
+ normalized_speaker = 'client' if 'client' in speaker_lower else 'counselor'
229
+
230
+ return {
231
+ 'speaker': normalized_speaker,
232
+ 'text': str(text).strip()
233
+ }
234
+
235
+ return None
236
+
237
+ def merge_consecutive_utterances(self, dialogue: List[Dict]) -> List[Dict]:
238
+ """
239
+ Merge consecutive utterances from the same speaker
240
+ Following KokoroChat paper methodology
241
+ """
242
+ if not dialogue:
243
+ return []
244
+
245
+ merged = []
246
+ current_utterance = None
247
+
248
+ for utt in dialogue:
249
+ normalized = self.normalize_utterance(utt)
250
+ if not normalized:
251
+ continue
252
+
253
+ if current_utterance is None:
254
+ current_utterance = normalized
255
+ elif current_utterance['speaker'] == normalized['speaker']:
256
+ # Same speaker - merge utterances
257
+ current_utterance['text'] += ' ' + normalized['text']
258
+ else:
259
+ # Different speaker - save current and start new
260
+ merged.append(current_utterance)
261
+ current_utterance = normalized
262
+
263
+ # Don't forget the last utterance
264
+ if current_utterance:
265
+ merged.append(current_utterance)
266
+
267
+ return merged
268
+
269
+ def create_training_examples(self, conversation: Dict) -> List[Dict]:
270
+ """
271
+ Create training examples with COMPLETE dialogue history
272
+ Following the paper: Dt = {uC1, uS2, uC3, ..., uCt} -> uSt+1
273
+ """
274
+ examples = []
275
+
276
+ # Get dialogue
277
+ dialogue = conversation.get('dialogue', [])
278
+ if not dialogue:
279
+ return []
280
+
281
+ # Merge consecutive utterances from same speaker
282
+ merged_dialogue = self.merge_consecutive_utterances(dialogue)
283
+
284
+ if not merged_dialogue:
285
+ return []
286
+
287
+ # Create examples with COMPLETE history
288
+ for i in range(len(merged_dialogue)):
289
+ current = merged_dialogue[i]
290
+
291
+ # Only create examples where counselor responds
292
+ if current['speaker'] == 'counselor':
293
+ # Get COMPLETE dialogue history from beginning
294
+ complete_history = merged_dialogue[:i]
295
+
296
+ # Skip if no history or if history doesn't start with client
297
+ if not complete_history or complete_history[0]['speaker'] != 'client':
298
+ continue
299
+
300
+ # Ensure topic is a string
301
+ topic = conversation.get('topic', 'general')
302
+ if not isinstance(topic, str):
303
+ topic = self.safe_get_value(topic, 'general')
304
+
305
+ # Create training example
306
+ example = {
307
+ 'dialogue_history': complete_history,
308
+ 'response': current['text'],
309
+ 'score': conversation.get('score', 100),
310
+ 'topic': topic,
311
+ 'conversation_id': conversation.get('id', 'unknown'),
312
+ 'source_file': conversation.get('source_file', 'unknown'),
313
+ 'turn_number': i,
314
+ 'history_length': len(complete_history)
315
+ }
316
+
317
+ examples.append(example)
318
+
319
+ return examples
320
+
321
+ def format_for_training(self, example: Dict, format_type: str = 'simple') -> str:
322
+ """
323
+ Format example for training
324
+
325
+ Args:
326
+ format_type: 'simple' or 'llama' format
327
+ """
328
+ # Build complete dialogue history
329
+ history_text = ""
330
+ for turn in example['dialogue_history']:
331
+ speaker = "クライアント" if turn['speaker'] == 'client' else "カウンセラー"
332
+ history_text += f"{speaker}: {turn['text']}\n"
333
+
334
+ if format_type == 'llama':
335
+ # Llama-style format with special tokens
336
+ formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
337
+ あなたは専門的な訓練を受けた心理カウンセラーです。クライアントの感情に共感し、適切な支援を提供してください。
338
+ これまでの対話履歴全体を考慮して、適切な応答を生成してください。<|eot_id|>
339
+
340
+ <|start_header_id|>user<|end_header_id|>
341
+ 以下は、クライアントとカウンセラーの完全な対話履歴です。
342
+ この履歴全体を踏まえて、次のカウンセラーの応答を生成してください。
343
+
344
+ 完全な対話履歴:
345
+ {history_text}
346
+ 次のカウンセラーの応答を生成してください。<|eot_id|>
347
+
348
+ <|start_header_id|>assistant<|end_header_id|>
349
+ {example['response']}<|eot_id|>"""
350
+
351
+ else:
352
+ # Simple format for models without special tokens
353
+ formatted = f"""### Instruction:
354
+ あなたは専門的な訓練を受けた心理カウンセラーです。
355
+ 以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。
356
+
357
+ ### Dialogue History:
358
+ {history_text}
359
+ ### Response:
360
+ {example['response']}"""
361
+
362
+ return formatted
363
+
364
+ def process_directory(self, format_type: str = 'simple'):
365
+ """Process all JSON files in the input directory"""
366
+ print(f"🔍 Scanning directory: {self.input_dir}")
367
+
368
+ # Find all JSON files
369
+ json_files = []
370
+ for pattern in ['*.json', '*.jsonl']:
371
+ json_files.extend(glob.glob(os.path.join(self.input_dir, '**', pattern), recursive=True))
372
+
373
+ print(f"Found {len(json_files)} JSON files")
374
+
375
+ if not json_files:
376
+ print("❌ No JSON files found in the directory!")
377
+ return
378
+
379
+ # Process each file
380
+ all_conversations = []
381
+
382
+ for filepath in tqdm(json_files, desc="Loading JSON files"):
383
+ # Handle both .json and .jsonl files
384
+ if filepath.endswith('.jsonl'):
385
+ # JSONL file - each line is a separate JSON object
386
+ with open(filepath, 'r', encoding='utf-8') as f:
387
+ for line_num, line in enumerate(f):
388
+ try:
389
+ data = json.loads(line)
390
+ conversations = self.extract_dialogue_from_json(data, f"{filepath}_line{line_num}")
391
+ all_conversations.extend(conversations)
392
+ except:
393
+ continue
394
+ else:
395
+ # Regular JSON file
396
+ data = self.load_json_file(filepath)
397
+ if data:
398
+ conversations = self.extract_dialogue_from_json(data, filepath)
399
+ all_conversations.extend(conversations)
400
+
401
+ print(f"✅ Loaded {len(all_conversations)} conversations from {len(json_files) - self.skipped_files} files")
402
+ print(f"⚠️ Skipped {self.skipped_files} files due to errors")
403
+
404
+ # Filter by score
405
+ conversations_before_filter = len(all_conversations)
406
+ filtered_conversations = [
407
+ conv for conv in all_conversations
408
+ if conv.get('score', 100) >= self.min_score
409
+ ]
410
+ conversations_after_filter = len(filtered_conversations)
411
+
412
+ print(f"📊 Score filtering (>= {self.min_score}):")
413
+ print(f" Before: {conversations_before_filter} conversations")
414
+ print(f" After: {conversations_after_filter} conversations")
415
+ print(f" Filtered out: {conversations_before_filter - conversations_after_filter} conversations")
416
+
417
+ # Create training examples
418
+ all_examples = []
419
+ history_lengths = []
420
+
421
+ for conv in tqdm(filtered_conversations, desc="Creating training examples"):
422
+ examples = self.create_training_examples(conv)
423
+ all_examples.extend(examples)
424
+ history_lengths.extend([ex['history_length'] for ex in examples])
425
+
426
+ if not all_examples:
427
+ print("❌ No training examples created!")
428
+ return
429
+
430
+ print(f"✅ Created {len(all_examples)} training examples from {len(filtered_conversations)} conversations")
431
+ print(f"📊 Dialogue history statistics:")
432
+ print(f" - Mean length: {np.mean(history_lengths):.1f} turns")
433
+ print(f" - Median length: {np.median(history_lengths):.1f} turns")
434
+ print(f" - Max length: {max(history_lengths)} turns")
435
+ print(f" - Min length: {min(history_lengths)} turns")
436
+
437
+ # Shuffle and split
438
+ random.shuffle(all_examples)
439
+
440
+ train_size = int(self.train_ratio * len(all_examples))
441
+ val_size = int(self.val_ratio * len(all_examples))
442
+
443
+ train_data = all_examples[:train_size]
444
+ val_data = all_examples[train_size:train_size + val_size]
445
+ test_data = all_examples[train_size + val_size:]
446
+
447
+ print(f"\n📂 Split sizes:")
448
+ print(f" Train: {len(train_data)} ({self.train_ratio*100:.0f}%)")
449
+ print(f" Val: {len(val_data)} ({self.val_ratio*100:.0f}%)")
450
+ print(f" Test: {len(test_data)} ({self.test_ratio*100:.0f}%)")
451
+
452
+ # Save splits
453
+ self.save_split(train_data, 'train', format_type)
454
+ self.save_split(val_data, 'val', format_type)
455
+ self.save_split(test_data, 'test', format_type)
456
+
457
+ # Save statistics
458
+ self.save_statistics(
459
+ train_data, val_data, test_data,
460
+ all_conversations, filtered_conversations,
461
+ history_lengths
462
+ )
463
+
464
+ print(f"\n✅ Processing complete! Data saved to {self.output_dir}")
465
+
466
+ def save_split(self, data: List[Dict], split_name: str, format_type: str = 'simple'):
467
+ """Save processed data split"""
468
+ output_file = os.path.join(self.output_dir, f"{split_name}.jsonl")
469
+
470
+ with open(output_file, 'w', encoding='utf-8') as f:
471
+ for example in tqdm(data, desc=f"Saving {split_name} data"):
472
+ formatted_text = self.format_for_training(example, format_type)
473
+
474
+ # Ensure topic is string
475
+ topic = example.get('topic', 'general')
476
+ if not isinstance(topic, str):
477
+ topic = self.safe_get_value(topic, 'general')
478
+
479
+ output_item = {
480
+ 'text': formatted_text,
481
+ 'dialogue_history': example['dialogue_history'],
482
+ 'response': example['response'],
483
+ 'score': example['score'],
484
+ 'topic': topic,
485
+ 'conversation_id': example['conversation_id'],
486
+ 'source_file': example['source_file'],
487
+ 'turn_number': example['turn_number'],
488
+ 'history_length': example['history_length']
489
+ }
490
+
491
+ f.write(json.dumps(output_item, ensure_ascii=False) + '\n')
492
+
493
+ print(f"✅ Saved {split_name} data to {output_file}")
494
+
495
+ def save_statistics(self, train_data, val_data, test_data,
496
+ all_conversations, filtered_conversations, history_lengths):
497
+ """Save comprehensive statistics"""
498
+ # Calculate topic distribution (safely)
499
+ topic_counts = defaultdict(int)
500
+ for example in train_data:
501
+ topic = example.get('topic', 'general')
502
+ if not isinstance(topic, str):
503
+ topic = self.safe_get_value(topic, 'general')
504
+ topic_counts[topic] += 1
505
+
506
+ # Calculate source file distribution
507
+ source_counts = defaultdict(int)
508
+ for example in train_data:
509
+ source_file = os.path.basename(example.get('source_file', 'unknown'))
510
+ source_counts[source_file] += 1
511
+
512
+ # Score statistics for filtered conversations
513
+ scores = [conv.get('score', 100) for conv in filtered_conversations]
514
+
515
+ stats = {
516
+ 'preprocessing_info': {
517
+ 'input_directory': self.input_dir,
518
+ 'output_directory': self.output_dir,
519
+ 'total_files_processed': len(set(conv.get('source_file', 'unknown') for conv in all_conversations)),
520
+ 'total_conversations_loaded': len(all_conversations),
521
+ 'conversations_after_filtering': len(filtered_conversations),
522
+ 'conversations_filtered_out': len(all_conversations) - len(filtered_conversations),
523
+ 'total_training_examples': len(train_data) + len(val_data) + len(test_data),
524
+ 'min_score_threshold': self.min_score,
525
+ 'methodology': 'KokoroChat paper - complete dialogue history'
526
+ },
527
+ 'score_filtering': {
528
+ 'threshold': self.min_score,
529
+ 'before_filtering': len(all_conversations),
530
+ 'after_filtering': len(filtered_conversations),
531
+ 'filtered_out': len(all_conversations) - len(filtered_conversations),
532
+ 'percentage_kept': (len(filtered_conversations) / len(all_conversations) * 100) if all_conversations else 0
533
+ },
534
+ 'score_statistics': {
535
+ 'mean': float(np.mean(scores)),
536
+ 'std': float(np.std(scores)),
537
+ 'min': float(min(scores)),
538
+ 'max': float(max(scores)),
539
+ 'median': float(np.median(scores)),
540
+ 'percentile_25': float(np.percentile(scores, 25)),
541
+ 'percentile_75': float(np.percentile(scores, 75))
542
+ },
543
+ 'split_sizes': {
544
+ 'train': len(train_data),
545
+ 'val': len(val_data),
546
+ 'test': len(test_data),
547
+ 'train_ratio': self.train_ratio,
548
+ 'val_ratio': self.val_ratio,
549
+ 'test_ratio': self.test_ratio
550
+ },
551
+ 'dialogue_history_stats': {
552
+ 'mean_length': float(np.mean(history_lengths)),
553
+ 'std_length': float(np.std(history_lengths)),
554
+ 'min_length': int(min(history_lengths)),
555
+ 'max_length': int(max(history_lengths)),
556
+ 'median_length': float(np.median(history_lengths)),
557
+ 'percentile_25': float(np.percentile(history_lengths, 25)),
558
+ 'percentile_75': float(np.percentile(history_lengths, 75)),
559
+ 'percentile_95': float(np.percentile(history_lengths, 95))
560
+ },
561
+ 'topic_distribution': dict(sorted(topic_counts.items(), key=lambda kv: kv[1], reverse=True)[:20]), # Top 20 topics by count
562
+ 'source_file_distribution': dict(sorted(source_counts.items(), key=lambda kv: kv[1], reverse=True)[:20]), # Top 20 source files by count
563
+ 'history_length_bins': {
564
+ '1-5_turns': sum(1 for l in history_lengths if l <= 5),
565
+ '6-10_turns': sum(1 for l in history_lengths if 5 < l <= 10),
566
+ '11-15_turns': sum(1 for l in history_lengths if 10 < l <= 15),
567
+ '16-20_turns': sum(1 for l in history_lengths if 15 < l <= 20),
568
+ '21-30_turns': sum(1 for l in history_lengths if 20 < l <= 30),
569
+ '31-50_turns': sum(1 for l in history_lengths if 30 < l <= 50),
570
+ '50+_turns': sum(1 for l in history_lengths if l > 50)
571
+ }
572
+ }
573
+
574
+ stats_file = os.path.join(self.output_dir, 'dataset_stats.json')
575
+ with open(stats_file, 'w', encoding='utf-8') as f:
576
+ json.dump(stats, f, ensure_ascii=False, indent=2)
577
+
578
+ print(f"\n📊 Statistics saved to {stats_file}")
579
+
580
+ # Print summary
581
+ print("\n" + "="*70)
582
+ print("📈 DATASET STATISTICS SUMMARY")
583
+ print("="*70)
584
+ print(f"Files processed: {stats['preprocessing_info']['total_files_processed']}")
585
+ print(f"Conversations loaded: {stats['preprocessing_info']['total_conversations_loaded']}")
586
+ print(f"After score filtering (>={self.min_score}): {stats['preprocessing_info']['conversations_after_filtering']}")
587
+ print(f"Training examples created: {stats['preprocessing_info']['total_training_examples']}")
588
+ print(f"\nScore Statistics (after filtering):")
589
+ print(f" Mean: {stats['score_statistics']['mean']:.1f}")
590
+ print(f" Median: {stats['score_statistics']['median']:.1f}")
591
+ print(f" Range: {stats['score_statistics']['min']:.0f} - {stats['score_statistics']['max']:.0f}")
592
+ print(f"\nDialogue History Length Distribution:")
593
+ for bin_name, count in stats['history_length_bins'].items():
594
+ percentage = (count / len(history_lengths)) * 100 if history_lengths else 0
595
+ print(f" {bin_name}: {count} ({percentage:.1f}%)")
596
+ print("="*70)
597
+
598
+
599
+ def main():
600
+ import argparse
601
+
602
+ parser = argparse.ArgumentParser(
603
+ description='Preprocess directory of JSON files with counseling dialogues'
604
+ )
605
+ parser.add_argument(
606
+ '--input_dir',
607
+ type=str,
608
+ default='./KokoroChat/kokorochat_dialogues',
609
+ help='Directory containing JSON files with conversations'
610
+ )
611
+ parser.add_argument(
612
+ '--output_dir',
613
+ type=str,
614
+ default='./kokoro_processed_data',
615
+ help='Output directory for processed data'
616
+ )
617
+ parser.add_argument(
618
+ '--min_score',
619
+ type=int,
620
+ default=70,
621
+ help='Minimum score threshold (if scores exist in data)'
622
+ )
623
+ parser.add_argument(
624
+ '--format',
625
+ type=str,
626
+ choices=['simple', 'llama'],
627
+ default='simple',
628
+ help='Output format type'
629
+ )
630
+
631
+ args = parser.parse_args()
632
+
633
+ # Initialize preprocessor
634
+ preprocessor = KokoroChatDirectoryPreprocessor(
635
+ input_dir=args.input_dir,
636
+ output_dir=args.output_dir,
637
+ min_score=args.min_score
638
+ )
639
+
640
+ print("🚀 Starting preprocessing with COMPLETE dialogue history")
641
+ print(" Following KokoroChat paper methodology")
642
+ print("="*70)
643
+
644
+ # Process directory
645
+ preprocessor.process_directory(format_type=args.format)
646
+
647
+ print("\n✅ Preprocessing complete!")
648
+
649
+
650
+ if __name__ == "__main__":
651
+ main()
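
A minimal sketch of how the preprocessor above might be invoked programmatically instead of via the CLI. The module name follows this file, the directory defaults follow `main()`, and the comment on the output record layout reflects `save_split` above.

    from preprocess_kokoro_method import KokoroChatDirectoryPreprocessor

    preprocessor = KokoroChatDirectoryPreprocessor(
        input_dir="./KokoroChat/kokorochat_dialogues",
        output_dir="./kokoro_processed_data",
        min_score=70,
    )
    preprocessor.process_directory(format_type="simple")

    # Each line of ./kokoro_processed_data/train.jsonl is one JSON object with
    # 'text' (the formatted prompt plus response), 'dialogue_history', 'response',
    # 'score', 'topic', 'conversation_id', and bookkeeping fields.
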
score_analysis_threshold_60.png ADDED

Git LFS Details

  • SHA256: db92bcddc596a139b29fd09421a87035d96547c5ff020732e2662e2c366e7d79
  • Pointer size: 131 Bytes
  • Size of remote file: 343 kB
score_distribution.png ADDED

Git LFS Details

  • SHA256: 4a943a076098bbe516b6cb0cb71a39c8a5d704b6c7b60398b89747065665fd15
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
training_config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "model_name_or_path": "LiquidAI/LFM2-2.6B",
3
+ "use_lora": true,
4
+ "lora_r": 64,
5
+ "lora_alpha": 128,
6
+ "lora_dropout": 0.05,
7
+ "data_path": "./kokoro_processed_data",
8
+ "max_seq_length": 2048,
9
+ "response_template": "### Response:",
10
+ "output_dir": "./lfm_trl_finetuned",
11
+ "num_train_epochs": 3,
12
+ "per_device_train_batch_size": 4,
13
+ "per_device_eval_batch_size": 4,
14
+ "gradient_accumulation_steps": 4,
15
+ "learning_rate": 2e-4,
16
+ "warmup_ratio": 0.1,
17
+ "logging_steps": 10,
18
+ "save_steps": 100,
19
+ "eval_steps": 100,
20
+ "bf16": true,
21
+ "tf32": true,
22
+ "seed": 42
23
+ }
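
The training script that consumes this config is not included in the upload, so the snippet below is only a sketch of how the hyperparameter fields could be mapped onto Hugging Face `TrainingArguments`; the LoRA, data, and response-template keys would be handled separately by whatever trainer reads the file.

    import json
    from transformers import TrainingArguments

    with open("training_config.json") as f:
        cfg = json.load(f)

    # Map only the fields that TrainingArguments understands directly.
    args = TrainingArguments(
        output_dir=cfg["output_dir"],
        num_train_epochs=cfg["num_train_epochs"],
        per_device_train_batch_size=cfg["per_device_train_batch_size"],
        per_device_eval_batch_size=cfg["per_device_eval_batch_size"],
        gradient_accumulation_steps=cfg["gradient_accumulation_steps"],
        learning_rate=cfg["learning_rate"],
        warmup_ratio=cfg["warmup_ratio"],
        logging_steps=cfg["logging_steps"],
        save_steps=cfg["save_steps"],
        eval_steps=cfg["eval_steps"],
        bf16=cfg["bf16"],
        tf32=cfg["tf32"],
        seed=cfg["seed"],
    )
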