File size: 24,095 Bytes
27c46c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
import json
import os
from pathlib import Path
import pandas as pd
from typing import List, Dict, Tuple, Optional
import random
from tqdm import tqdm
import re
import matplotlib.pyplot as plt
import numpy as np

class KokoroChatPreprocessor:
    """Load, filter and split the KokoroChat counseling dataset for fine-tuning.

    Conversations are read from JSON files, filtered by the client's overall
    score (``review_by_client_jp['点数']``), converted into
    client-message -> counselor-response instruction examples, and split into
    train / validation / test sets.
    """

    # Display labels used when rendering a dialogue as "label: utterance" text.
    ROLE_LABELS = {'counselor': 'カウンセラー', 'client': 'クライアント'}

    # English metric name -> Japanese key in ``review_by_client_jp``.
    REVIEW_METRIC_KEYS = {
        'empathy': '聴いてもらえた、わかってもらえたと感じた',
        'respect': '尊重されたと感じた',
        'insights': '新しい気づきや体験があった',
        'hope': '希望や期待を感じられた',
        'concerns_addressed': '取り組みたかったことを扱えた',
        'collaboration': '一緒に考えながら取り組めた',
        'rhythm': 'やりとりのリズムがあっていた',
        'comfort': '居心地のよいやりとりだった',
        'overall_appropriate': '全体として適切でよかった',
        'valuable': '今回の相談は価値があった',
        'smooth_start': '相談開始の円滑さ',
        'good_ending': '相談終了のタイミング(不必要に聴きすぎていないか)、円滑さ',
        'acceptance_empathy': '受容·共感',
        'affirmation': '肯定·承認',
        'effective_questions': '的確な質問による会話の促進',
        'summarization': '要約',
        'problem_clarification': '問題の明確化',
        'goal_clarification': 'この相談での目標の明確化',
        'actionable_suggestions': '次の行動につながる提案',
        'encouragement': '勇気づけ·希望の喚起',
    }

    def __init__(self, data_path: str, max_length: int = 2048, min_score: int = 60):
        """
        Initialize the preprocessor for KokoroChat dataset

        Args:
            data_path: Path to KokoroChat repository
            max_length: Maximum sequence length for model input (stored only;
                not used by the methods of this class)
            min_score: Minimum score threshold for filtering conversations (default: 60)
        """
        self.data_path = Path(data_path)
        self.max_length = max_length
        self.min_score = min_score
        self.conversations = []
        self.score_distribution = []  # Valid (> 0) scores from the last analysis
        self.system_prompt = """あなたは思いやりのある心理カウンセラーです。
クライアントの感情を理解し、共感的で支援的な応答を提供してください。
プライバシーを尊重し、判断を下さず、希望と実用的な洞察を提供することに焦点を当ててください。"""

    # ------------------------------------------------------------------ loading

    def load_json_files(self) -> List[Dict]:
        """Load every ``*.json`` file under the dataset directory.

        Looks in ``<data_path>/kokorochat_dialogues`` first and falls back to
        ``data_path`` itself when that subdirectory does not exist. Files are
        visited in sorted path order so repeated runs see the same sequence.
        Files that cannot be read or parsed are reported and skipped.

        Returns:
            List of parsed JSON objects, one per successfully loaded file.
        """
        data_dir = self.data_path / "kokorochat_dialogues"
        if not data_dir.exists():
            data_dir = self.data_path
            print(f"Using root directory: {data_dir}")
        else:
            print(f"Using data directory: {data_dir}")

        # Collect paths first: one progress bar for the whole dataset, and a
        # deterministic order (os.walk order depends on the filesystem).
        json_paths = sorted(
            Path(root) / name
            for root, _dirs, files in os.walk(data_dir)
            for name in files
            if name.endswith('.json')
        )

        json_files = []
        for file_path in tqdm(json_paths, desc="Loading JSON files"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_files.append(json.load(f))
            except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
                print(f"Error loading {file_path}: {e}")

        return json_files

    @staticmethod
    def _get_score(data: Dict) -> Optional[float]:
        """Return the client's overall score (``点数``) for one conversation.

        Returns None when the conversation has no ``review_by_client_jp`` dict.
        A missing or non-numeric ``点数`` is reported as 0, which the rest of
        the class treats as "no valid score".
        """
        if not isinstance(data, dict):
            return None
        review = data.get('review_by_client_jp')
        if not isinstance(review, dict):
            return None
        score = review.get('点数', 0)
        if isinstance(score, (int, float, np.integer, np.floating)):
            return score
        try:
            return float(score)
        except (TypeError, ValueError):
            return 0

    # ----------------------------------------------------------------- analysis

    def analyze_score_distribution(self, json_files: List[Dict]) -> Dict:
        """
        Analyze the distribution of scores in the dataset

        Only scores > 0 are considered valid. ``self.score_distribution`` is
        replaced (not extended) with those scores, so calling this repeatedly
        does not double-count.

        Returns:
            Dictionary with score statistics, or ``{'error': ...}`` when no
            valid scores exist.
        """
        scores = []
        for data in json_files:
            score = self._get_score(data)
            if score is not None and score > 0:  # Only count valid scores
                scores.append(score)

        self.score_distribution = list(scores)

        if not scores:
            return {'error': 'No valid scores found in dataset'}

        stats = {
            'total_conversations': len(json_files),
            'conversations_with_scores': len(scores),
            'mean_score': float(np.mean(scores)),
            'median_score': float(np.median(scores)),
            'std_score': float(np.std(scores)),
            'min_score': float(np.min(scores)),
            'max_score': float(np.max(scores)),
            'percentiles': {
                '25th': float(np.percentile(scores, 25)),
                '50th': float(np.percentile(scores, 50)),
                '75th': float(np.percentile(scores, 75)),
                '90th': float(np.percentile(scores, 90))
            },
            'score_ranges': {
                '0-30': int(sum(1 for s in scores if 0 <= s < 30)),
                '30-50': int(sum(1 for s in scores if 30 <= s < 50)),
                '50-60': int(sum(1 for s in scores if 50 <= s < 60)),
                '60-70': int(sum(1 for s in scores if 60 <= s < 70)),
                '70-80': int(sum(1 for s in scores if 70 <= s < 80)),
                '80-90': int(sum(1 for s in scores if 80 <= s < 90)),
                '90-100': int(sum(1 for s in scores if 90 <= s <= 100)),
            }
        }

        # How many conversations would be kept at different thresholds
        threshold_analysis = {}
        for threshold in [30, 40, 50, 60, 65, 70, 75, 80]:
            kept = sum(1 for s in scores if s >= threshold)
            threshold_analysis[f'threshold_{threshold}'] = {
                'conversations_kept': kept,
                'percentage_kept': round((kept / len(scores)) * 100, 2)
            }
        stats['threshold_analysis'] = threshold_analysis

        return stats

    def plot_score_distribution(self, save_path: str = "score_distribution.png",
                                show: bool = True):
        """
        Plot the distribution of scores (histogram, box plot, CDF, threshold impact)

        Args:
            save_path: File the figure is saved to.
            show: If True (default), call ``plt.show()`` before closing the figure.
        """
        if not self.score_distribution:
            print("No scores to plot. Run analyze_score_distribution first.")
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Histogram
        axes[0, 0].hist(self.score_distribution, bins=20, edgecolor='black', alpha=0.7)
        axes[0, 0].axvline(self.min_score, color='red', linestyle='--',
                           label=f'Current threshold: {self.min_score}')
        axes[0, 0].set_xlabel('Score')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('Score Distribution')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Box plot (vertical is the default; the `vert` kwarg is deprecated)
        axes[0, 1].boxplot(self.score_distribution)
        axes[0, 1].set_ylabel('Score')
        axes[0, 1].set_title('Score Box Plot')
        axes[0, 1].grid(True, alpha=0.3)

        # Cumulative distribution
        sorted_scores = np.sort(self.score_distribution)
        cumulative = np.arange(1, len(sorted_scores) + 1) / len(sorted_scores)
        axes[1, 0].plot(sorted_scores, cumulative)
        axes[1, 0].axvline(self.min_score, color='red', linestyle='--',
                           label=f'Current threshold: {self.min_score}')
        axes[1, 0].set_xlabel('Score')
        axes[1, 0].set_ylabel('Cumulative Probability')
        axes[1, 0].set_title('Cumulative Distribution')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Threshold impact analysis
        thresholds = range(30, 90, 5)
        total = len(self.score_distribution)
        kept_percentages = [
            sum(1 for s in self.score_distribution if s >= t) / total * 100
            for t in thresholds
        ]

        axes[1, 1].plot(thresholds, kept_percentages, marker='o')
        axes[1, 1].axvline(self.min_score, color='red', linestyle='--',
                           label=f'Current threshold: {self.min_score}')
        axes[1, 1].set_xlabel('Score Threshold')
        axes[1, 1].set_ylabel('% of Conversations Kept')
        axes[1, 1].set_title('Impact of Score Threshold')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        if show:
            plt.show()
        # Release the figure; otherwise repeated calls keep figures alive.
        plt.close(fig)
        print(f"Score distribution plot saved to {save_path}")

    def _report_score_distribution(self, json_files: List[Dict]) -> None:
        """Print the score analysis for ``json_files`` and plot it."""
        print("\n" + "=" * 60)
        print("SCORE DISTRIBUTION ANALYSIS")
        print("=" * 60)
        stats = self.analyze_score_distribution(json_files)
        if 'error' in stats:
            return

        print(f"Total conversations: {stats['total_conversations']}")
        print(f"Conversations with scores: {stats['conversations_with_scores']}")
        print("\nScore Statistics:")
        print(f"  Mean: {stats['mean_score']:.2f}")
        print(f"  Median: {stats['median_score']:.2f}")
        print(f"  Std Dev: {stats['std_score']:.2f}")
        print(f"  Range: {stats['min_score']:.0f} - {stats['max_score']:.0f}")

        print("\nScore Distribution:")
        for range_name, count in stats['score_ranges'].items():
            percentage = (count / stats['conversations_with_scores']) * 100
            print(f"  {range_name}: {count} ({percentage:.1f}%)")

        print("\nThreshold Impact Analysis:")
        for threshold_name, info in stats['threshold_analysis'].items():
            threshold = threshold_name.split('_')[1]
            print(f"  Threshold >= {threshold}: {info['conversations_kept']} conversations ({info['percentage_kept']:.1f}%)")

        kept = sum(1 for s in self.score_distribution if s >= self.min_score)
        print(f"\nCurrent threshold ({self.min_score}) will keep: "
              f"{kept} conversations ({(kept / len(self.score_distribution)) * 100:.1f}%)")
        print("=" * 60 + "\n")

        self.plot_score_distribution()

    # --------------------------------------------------------------- extraction

    def extract_high_quality_conversations(self, data: Dict) -> List[Dict]:
        """
        Extract a conversation if its client score is >= ``self.min_score``.

        Returns:
            A list with zero or one dict containing:
            ``text`` (one "label: utterance" line per turn), ``turns``
            (list of ``(role, utterance)`` with role 'counselor' or 'client'),
            ``score``, ``topic`` and ``review_metrics``. Conversations without
            a review or below the threshold yield an empty list.
        """
        score = self._get_score(data)
        if score is None or score < self.min_score:
            return []

        review = data['review_by_client_jp']

        # Any role other than 'counselor' is treated as the client.
        turns = []
        for turn in data.get('dialogue') or []:
            if not isinstance(turn, dict):
                continue
            role = 'counselor' if turn.get('role') == 'counselor' else 'client'
            turns.append((role, turn.get('utterance', '')))

        conversation_text = ''.join(
            f"{self.ROLE_LABELS[role]}: {utterance}\n" for role, utterance in turns
        )

        topic = data.get('topic')
        topic_name = topic.get('main_jp', 'Unknown') if isinstance(topic, dict) else 'Unknown'

        return [{
            'text': conversation_text,
            # Structured turns avoid re-parsing `text`, which breaks on
            # utterances that contain newlines.
            'turns': turns,
            'score': score,
            'topic': topic_name,
            'review_metrics': {
                name: review.get(key, 0)
                for name, key in self.REVIEW_METRIC_KEYS.items()
            },
        }]

    @classmethod
    def _parse_conversation_text(cls, text: str) -> List[Tuple[str, str]]:
        """Rebuild ``(role, utterance)`` turns from "label: utterance" text.

        Lines that start with no known label are treated as continuations of
        the previous utterance (multi-line utterances).
        """
        turns: List[List[str]] = []
        for line in text.split('\n'):
            for role, label in cls.ROLE_LABELS.items():
                prefix = f"{label}:"
                if line.startswith(prefix):
                    turns.append([role, line[len(prefix):]])
                    break
            else:  # no label matched -> continuation line
                if turns:
                    turns[-1][1] += '\n' + line
        return [(role, utterance) for role, utterance in turns]

    @classmethod
    def _client_counselor_pairs(cls, conv: Dict) -> List[Tuple[str, str]]:
        """Return ``(client_message, counselor_response)`` pairs for a conversation.

        Consecutive non-empty utterances by the same speaker are merged with a
        newline. Each client block is paired with the counselor block directly
        after it, regardless of who opens the conversation. Uses ``conv['turns']``
        when present, else parses ``conv['text']``.
        """
        turns = conv.get('turns')
        if turns is None:
            turns = cls._parse_conversation_text(conv.get('text', ''))

        merged: List[Tuple[str, List[str]]] = []
        for role, utterance in turns:
            utterance = str(utterance or '').strip()
            if not utterance:
                continue
            if merged and merged[-1][0] == role:
                merged[-1][1].append(utterance)
            else:
                merged.append((role, [utterance]))

        return [
            ('\n'.join(client_msgs), '\n'.join(counselor_msgs))
            for (first_role, client_msgs), (second_role, counselor_msgs)
            in zip(merged, merged[1:])
            if first_role == 'client' and second_role == 'counselor'
        ]

    def _sampling_weight(self, score) -> int:
        """Number of copies (1-3) to emit for a conversation under weighted sampling.

        One extra copy per full 20 points above ``min_score``, capped at 3.
        """
        return min(3, max(1, int((score - self.min_score) / 20) + 1))

    def create_training_examples(self, conversations: List[Dict],
                                 use_weighted_sampling: bool = False) -> List[Dict]:
        """
        Create training examples in instruction-following format

        Args:
            conversations: Dicts as produced by ``extract_high_quality_conversations``
                (must contain ``score``, ``topic``, ``review_metrics`` and either
                ``turns`` or ``text``).
            use_weighted_sampling: If True, repeat examples from higher-scored
                conversations 1-3 times (see ``_sampling_weight``).

        Returns:
            List of dicts with keys ``instruction``, ``input``, ``output``,
            ``score``, ``topic`` and ``metrics``.
        """
        training_examples = []

        for conv in tqdm(conversations, desc="Creating training examples"):
            score = conv['score']
            weight = self._sampling_weight(score) if use_weighted_sampling else 1
            pairs = self._client_counselor_pairs(conv)

            for _ in range(weight):
                for client_msg, counselor_msg in pairs:
                    training_examples.append({
                        'instruction': self.system_prompt,
                        'input': client_msg,
                        'output': counselor_msg,
                        'score': score,
                        'topic': conv['topic'],
                        'metrics': conv['review_metrics'],
                    })

        return training_examples

    # -------------------------------------------------------------- dataset I/O

    def prepare_dataset(self, test_size: float = 0.1, val_size: float = 0.1,
                        use_weighted_sampling: bool = False,
                        analyze_scores: bool = True,
                        seed: Optional[int] = None):
        """
        Prepare train, validation, and test datasets

        The split is made at the conversation level *before* examples are
        created. That keeps turns of one conversation (and weighted-sampling
        duplicates) from leaking between train and evaluation sets. Weighted
        sampling is applied to the train split only.

        Args:
            test_size: Proportion of conversations for testing
            val_size: Proportion of conversations for validation
            use_weighted_sampling: If True, oversample high-quality conversations in train
            analyze_scores: If True, print score distribution analysis
            seed: Optional seed for a reproducible shuffle; when None the
                module-level ``random`` state is used.

        Raises:
            ValueError: If a size is negative or the two sizes sum to more than 1.
        """
        if test_size < 0 or val_size < 0 or test_size + val_size > 1:
            raise ValueError("test_size and val_size must be non-negative and sum to at most 1")

        print("Loading KokoroChat dataset...")
        json_files = self.load_json_files()
        print(f"Loaded {len(json_files)} conversation files")

        if analyze_scores:
            self._report_score_distribution(json_files)

        all_conversations = []
        filtered_count = 0
        missing_review_count = 0
        for data in json_files:
            score = self._get_score(data)
            if score is None:
                missing_review_count += 1
            elif score < self.min_score:
                filtered_count += 1
            all_conversations.extend(self.extract_high_quality_conversations(data))

        print(f"Filtered out {filtered_count} conversations with score < {self.min_score}")
        if missing_review_count:
            print(f"Skipped {missing_review_count} conversations without a client review")
        print(f"Extracted {len(all_conversations)} high-quality conversations (score >= {self.min_score})")

        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(all_conversations)

        total_size = len(all_conversations)
        test_split = int(total_size * test_size)
        val_split = int(total_size * val_size)

        test_convs = all_conversations[:test_split]
        val_convs = all_conversations[test_split:test_split + val_split]
        train_convs = all_conversations[test_split + val_split:]

        train_data = self.create_training_examples(
            train_convs, use_weighted_sampling=use_weighted_sampling
        )
        val_data = self.create_training_examples(val_convs)
        test_data = self.create_training_examples(test_convs)

        # Interleave examples so consecutive ones aren't all from one conversation.
        for split in (train_data, val_data, test_data):
            rng.shuffle(split)

        print(f"Created {len(train_data) + len(val_data) + len(test_data)} training examples")
        if use_weighted_sampling:
            print("Note: Used weighted sampling - higher scored conversations appear more frequently in the train split")

        print("\nDataset splits:")
        print(f"  Train: {len(train_data)} examples")
        print(f"  Validation: {len(val_data)} examples")
        print(f"  Test: {len(test_data)} examples")

        return {
            'train': train_data,
            'validation': val_data,
            'test': test_data
        }

    def format_for_lfm(self, example: Dict) -> str:
        """
        Format one example as an Instruction / Input / Response prompt for LFM training
        """
        formatted = f"""### Instruction:
{example['instruction']}

### Input:
{example['input']}

### Response:
{example['output']}"""
        return formatted

    def save_datasets(self, datasets: Dict, output_dir: str):
        """Write each split to ``<output_dir>/<split>.jsonl`` plus summary statistics.

        Each JSONL line holds the LFM-formatted ``text``, the ``score`` and the
        ``topic``. Split sizes and per-split score statistics are written to
        ``dataset_stats.json`` and printed.

        Args:
            datasets: Mapping of split name -> list of examples (as returned by
                ``prepare_dataset``).
            output_dir: Directory to create (if needed) and write into.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        def convert_to_native(obj):
            """Convert numpy scalars/arrays to JSON-serializable Python types."""
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj

        stats = {
            'min_score_threshold': int(self.min_score),
            # Derived from the mapping so a missing split doesn't raise KeyError.
            'dataset_sizes': {name: len(data) for name, data in datasets.items()},
            'score_distribution': {}
        }

        for split_name, data in datasets.items():
            scores = [ex['score'] for ex in data]
            if scores:
                stats['score_distribution'][split_name] = {
                    'mean': float(np.mean(scores)),
                    'median': float(np.median(scores)),
                    'min': float(np.min(scores)),
                    'max': float(np.max(scores)),
                    'std': float(np.std(scores))
                }

            # JSONL for easier streaming
            file_path = output_path / f"{split_name}.jsonl"
            with open(file_path, 'w', encoding='utf-8') as f:
                for example in data:
                    json_obj = {
                        'text': self.format_for_lfm(example),
                        'score': convert_to_native(example['score']),
                        'topic': example['topic']
                    }
                    f.write(json.dumps(json_obj, ensure_ascii=False) + '\n')

            print(f"Saved {split_name} dataset with {len(data)} examples to {file_path}")

        stats_path = output_path / "dataset_stats.json"
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)
        print(f"Saved dataset statistics to {stats_path}")

        print("\n" + "=" * 60)
        print("DATASET SUMMARY")
        print("=" * 60)
        print(f"Minimum score threshold: {stats['min_score_threshold']}")
        print("\nDataset sizes:")
        for split, size in stats['dataset_sizes'].items():
            print(f"  {split}: {size} examples")

        print("\nScore distributions by split:")
        for split, dist in stats['score_distribution'].items():
            print(f"  {split}:")
            print(f"    Mean: {dist['mean']:.2f}")
            print(f"    Std:  {dist['std']:.2f}")
            print(f"    Range: {dist['min']:.0f} - {dist['max']:.0f}")
        print("=" * 60)

# Command-line entry point: preprocess KokoroChat with a single, configurable score threshold.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Preprocess KokoroChat dataset')
    parser.add_argument('--data_path', type=str, default='./KokoroChat',
                        help='Path to KokoroChat repository')
    parser.add_argument('--min_score', type=int, default=70,
                        help='Minimum score threshold for filtering (default: 70)')
    parser.add_argument('--output_dir', type=str, default='./processed_data',
                        help='Output directory for processed data')
    parser.add_argument('--weighted_sampling', action='store_true',
                        help='Use weighted sampling based on scores')
    parser.add_argument('--test_size', type=float, default=0.1,
                        help='Test set size (default: 0.1)')
    parser.add_argument('--val_size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')
    parser.add_argument('--analyze_only', action='store_true',
                        help='Only analyze score distribution without processing')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible dataset splits')

    args = parser.parse_args()

    if args.test_size < 0 or args.val_size < 0 or args.test_size + args.val_size > 1:
        parser.error('--test_size and --val_size must be non-negative and sum to at most 1')

    if args.seed is not None:
        # Seed the module-level RNG, which the dataset shuffle uses by default.
        random.seed(args.seed)

    preprocessor = KokoroChatPreprocessor(
        data_path=args.data_path,
        min_score=args.min_score
    )

    if args.analyze_only:
        print("Running score distribution analysis only...")
        json_files = preprocessor.load_json_files()
        stats = preprocessor.analyze_score_distribution(json_files)
        # Previously computed but never shown; print the full statistics.
        print(json.dumps(stats, ensure_ascii=False, indent=2))
        preprocessor.plot_score_distribution(f"score_analysis_threshold_{args.min_score}.png")
    else:
        print(f"Processing with minimum score threshold: {args.min_score}")
        datasets = preprocessor.prepare_dataset(
            test_size=args.test_size,
            val_size=args.val_size,
            use_weighted_sampling=args.weighted_sampling,
            analyze_scores=True
        )

        # Threshold is encoded in the directory name so runs don't overwrite each other.
        output_dir = f"{args.output_dir}_score{args.min_score}"
        preprocessor.save_datasets(datasets, output_dir)

        print(f"\nProcessing complete! Data saved to {output_dir}")
        print("\nNext steps:")
        print("1. Run fine-tuning: python finetune_lfm.py")
        print("2. Run benchmarking: python benchmark_model.py")
        print("3. Optimize for mobile: python optimize_for_mobile.py")