import json
import os
import random
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
class KokoroChatPreprocessor:
    """Build instruction-tuning datasets from the KokoroChat counseling corpus.

    Pipeline: load the raw dialogue JSON files, optionally analyse the
    distribution of the client's overall rating (``review_by_client_jp['点数']``),
    keep dialogues rated at least ``min_score``, turn every run of client
    messages and the counselor reply that follows it into an
    instruction/input/output example, split the dialogues into
    train/validation/test and save them as JSONL.
    """

    # Speaker labels used when a dialogue is flattened into a single text block.
    CLIENT_LABEL = "クライアント:"
    COUNSELOR_LABEL = "カウンセラー:"

    # English metric name -> question key inside ``review_by_client_jp``.
    # NOTE(review): several keys use '·' (U+00B7); confirm the dataset does not
    # spell them with the katakana middle dot '・' (U+30FB), or they read as 0.
    REVIEW_METRIC_KEYS: Dict[str, str] = {
        'empathy': '聴いてもらえた、わかってもらえたと感じた',
        'respect': '尊重されたと感じた',
        'insights': '新しい気づきや体験があった',
        'hope': '希望や期待を感じられた',
        'concerns_addressed': '取り組みたかったことを扱えた',
        'collaboration': '一緒に考えながら取り組めた',
        'rhythm': 'やりとりのリズムがあっていた',
        'comfort': '居心地のよいやりとりだった',
        'overall_appropriate': '全体として適切でよかった',
        'valuable': '今回の相談は価値があった',
        'smooth_start': '相談開始の円滑さ',
        'good_ending': '相談終了のタイミング(不必要に聴きすぎていないか)、円滑さ',
        'acceptance_empathy': '受容·共感',
        'affirmation': '肯定·承認',
        'effective_questions': '的確な質問による会話の促進',
        'summarization': '要約',
        'problem_clarification': '問題の明確化',
        'goal_clarification': 'この相談での目標の明確化',
        'actionable_suggestions': '次の行動につながる提案',
        'encouragement': '勇気づけ·希望の喚起',
    }

    def __init__(self, data_path: str, max_length: int = 2048, min_score: int = 60):
        """Initialize the preprocessor for the KokoroChat dataset.

        Args:
            data_path: Path to the KokoroChat repository (or directly to a
                directory of dialogue JSON files).
            max_length: Maximum sequence length for model input. Stored for
                callers; not used by this class itself.
            min_score: Minimum client score (点数) a dialogue needs to be kept
                (default: 60).
        """
        self.data_path = Path(data_path)
        self.max_length = max_length
        self.min_score = min_score
        self.conversations: List[Dict] = []
        # Valid (> 0) scores from the most recent analyze_score_distribution call.
        self.score_distribution: List[float] = []
        # Instruction text shared by every training example.
        self.system_prompt = (
            "あなたは思いやりのある心理カウンセラーです。\n"
            "クライアントの感情を理解し、共感的で支援的な応答を提供してください。\n"
            "プライバシーを尊重し、判断を下さず、希望と実用的な洞察を提供することに焦点を当ててください。"
        )

    # ------------------------------------------------------------------ #
    # Score helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _client_review(data: Any) -> Optional[Dict]:
        """Return the dialogue's ``review_by_client_jp`` dict, or None if absent/malformed."""
        if not isinstance(data, dict):
            return None
        review = data.get('review_by_client_jp')
        return review if isinstance(review, dict) else None

    @classmethod
    def _get_score(cls, data: Any) -> float:
        """Return the client's overall score (点数) for one dialogue.

        0 is the "no usable score" marker. It is returned when the review or
        the score is missing, or when the score is not a number. A non-number
        would otherwise raise TypeError in the ``>``/``>=`` comparisons done by
        the callers.
        """
        review = cls._client_review(data)
        if review is None:
            return 0
        score = review.get('点数', 0)
        if isinstance(score, bool) or not isinstance(score, (int, float)):
            return 0
        return score

    # ------------------------------------------------------------------ #
    # Loading and analysis
    # ------------------------------------------------------------------ #
    def load_json_files(self) -> List[Dict]:
        """Load every ``*.json`` file under the dataset directory.

        Files are read recursively from ``<data_path>/kokorochat_dialogues``
        when that directory exists, otherwise from ``data_path`` itself.
        Unreadable or malformed files are reported and skipped.

        Returns:
            The parsed JSON documents, in sorted file-path order.
        """
        data_dir = self.data_path / "kokorochat_dialogues"
        if not data_dir.exists():
            data_dir = self.data_path
            print(f"Using root directory: {data_dir}")
        else:
            print(f"Using data directory: {data_dir}")

        # Sorted so the load order (and hence a seeded shuffle downstream)
        # does not depend on the filesystem's directory enumeration order.
        file_paths = sorted(p for p in data_dir.rglob("*.json") if p.is_file())

        json_files = []
        for file_path in tqdm(file_paths, desc="Loading JSON files"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    json_files.append(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Error loading {file_path}: {e}")

        return json_files

    def analyze_score_distribution(self, json_files: List[Dict]) -> Dict:
        """Summarise the client scores (点数) found in ``json_files``.

        Only positive scores count as valid; a missing score reads as 0 and is
        skipped. The valid scores also replace ``self.score_distribution``
        (used by ``plot_score_distribution``), so repeated calls never
        accumulate duplicates.

        Returns:
            Dictionary with counts, mean/median/std/min/max, percentiles,
            counts per score range and the share of dialogues kept at several
            thresholds, or ``{'error': ...}`` when no valid score exists.
        """
        scores = [s for s in map(self._get_score, json_files) if s > 0]
        self.score_distribution = list(scores)

        if not scores:
            return {'error': 'No valid scores found in dataset'}

        p25, p50, p75, p90 = np.percentile(scores, [25, 50, 75, 90])

        # Half-open ranges, except the last which also includes 100.
        range_bounds = [('0-30', 0, 30), ('30-50', 30, 50), ('50-60', 50, 60),
                        ('60-70', 60, 70), ('70-80', 70, 80), ('80-90', 80, 90)]
        score_ranges = {label: sum(1 for s in scores if lo <= s < hi)
                        for label, lo, hi in range_bounds}
        score_ranges['90-100'] = sum(1 for s in scores if 90 <= s <= 100)

        threshold_analysis = {}
        for threshold in (30, 40, 50, 60, 65, 70, 75, 80):
            kept = sum(1 for s in scores if s >= threshold)
            threshold_analysis[f'threshold_{threshold}'] = {
                'conversations_kept': kept,
                'percentage_kept': round((kept / len(scores)) * 100, 2),
            }

        return {
            'total_conversations': len(json_files),
            'conversations_with_scores': len(scores),
            'mean_score': float(np.mean(scores)),
            'median_score': float(np.median(scores)),
            'std_score': float(np.std(scores)),
            'min_score': float(np.min(scores)),
            'max_score': float(np.max(scores)),
            'percentiles': {
                '25th': float(p25),
                '50th': float(p50),
                '75th': float(p75),
                '90th': float(p90),
            },
            'score_ranges': score_ranges,
            'threshold_analysis': threshold_analysis,
        }

    def plot_score_distribution(self, save_path: str = "score_distribution.png",
                                show: bool = True):
        """Plot the cached scores: histogram, box plot, CDF and threshold impact.

        Args:
            save_path: File the 2x2 figure is written to.
            show: Whether to call ``plt.show()``, which blocks with interactive
                backends. The figure is closed afterwards either way.
        """
        if not self.score_distribution:
            print("No scores to plot. Run analyze_score_distribution first.")
            return

        scores = self.score_distribution
        threshold_label = f'Current threshold: {self.min_score}'

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        try:
            hist_ax, box_ax, cdf_ax, impact_ax = axes.flat

            hist_ax.hist(scores, bins=20, edgecolor='black', alpha=0.7)
            hist_ax.axvline(self.min_score, color='red', linestyle='--', label=threshold_label)
            hist_ax.set_xlabel('Score')
            hist_ax.set_ylabel('Frequency')
            hist_ax.set_title('Score Distribution')
            hist_ax.legend()
            hist_ax.grid(True, alpha=0.3)

            # Vertical orientation is the default; `vert=` is deprecated in newer Matplotlib.
            box_ax.boxplot(scores)
            box_ax.set_ylabel('Score')
            box_ax.set_title('Score Box Plot')
            box_ax.grid(True, alpha=0.3)

            sorted_scores = np.sort(scores)
            cumulative = np.arange(1, len(sorted_scores) + 1) / len(sorted_scores)
            cdf_ax.plot(sorted_scores, cumulative)
            cdf_ax.axvline(self.min_score, color='red', linestyle='--', label=threshold_label)
            cdf_ax.set_xlabel('Score')
            cdf_ax.set_ylabel('Cumulative Probability')
            cdf_ax.set_title('Cumulative Distribution')
            cdf_ax.legend()
            cdf_ax.grid(True, alpha=0.3)

            thresholds = range(30, 90, 5)
            kept_percentages = [
                (sum(1 for s in scores if s >= t) / len(scores)) * 100 for t in thresholds
            ]
            impact_ax.plot(thresholds, kept_percentages, marker='o')
            impact_ax.axvline(self.min_score, color='red', linestyle='--', label=threshold_label)
            impact_ax.set_xlabel('Score Threshold')
            impact_ax.set_ylabel('% of Conversations Kept')
            impact_ax.set_title('Impact of Score Threshold')
            impact_ax.legend()
            impact_ax.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            if show:
                plt.show()
        finally:
            # Release the figure; otherwise repeated calls keep figures open.
            plt.close(fig)

        print(f"Score distribution plot saved to {save_path}")

    # ------------------------------------------------------------------ #
    # Conversation extraction and example creation
    # ------------------------------------------------------------------ #
    def extract_high_quality_conversations(self, data: Dict) -> List[Dict]:
        """Extract one dialogue as a conversation record if its score is high enough.

        Args:
            data: One parsed KokoroChat JSON document.

        Returns:
            An empty list when the dialogue has no client review or scores below
            ``min_score``. Otherwise a one-element list holding a dict with:
            ``text`` (flattened transcript, one ``"<label> <utterance>\\n"``
            per turn), ``turns`` (list of ``(role, utterance)`` with role
            ``'counselor'`` or ``'client'``), ``score``, ``topic`` and
            ``review_metrics``.
        """
        review = self._client_review(data)
        if review is None:
            return []

        score = self._get_score(data)
        if score < self.min_score:
            return []

        turns: List[Tuple[str, str]] = []
        for turn in data.get('dialogue') or []:
            if not isinstance(turn, dict):
                continue
            # Any role other than 'counselor' is treated as the client.
            role = 'counselor' if turn.get('role') == 'counselor' else 'client'
            utterance = turn.get('utterance')
            turns.append((role, '' if utterance is None else str(utterance)))

        conversation_text = "".join(
            f"{self.COUNSELOR_LABEL if role == 'counselor' else self.CLIENT_LABEL} {utterance}\n"
            for role, utterance in turns
        )

        topic = data.get('topic')
        topic_name = topic.get('main_jp', 'Unknown') if isinstance(topic, dict) else 'Unknown'

        return [{
            'text': conversation_text,
            'turns': turns,
            'score': score,
            'topic': topic_name,
            'review_metrics': {
                name: review.get(key, 0) for name, key in self.REVIEW_METRIC_KEYS.items()
            },
        }]

    @classmethod
    def _parse_turns(cls, text: str) -> List[Tuple[str, str]]:
        """Recover ``(role, utterance)`` turns from a flattened transcript.

        A line starting with a speaker label (with or without the following
        space) opens a new turn. Any other line continues the previous
        utterance, because utterances may themselves contain newlines. Text
        before the first labelled line is ignored.
        """
        labels = ((cls.CLIENT_LABEL, 'client'), (cls.COUNSELOR_LABEL, 'counselor'))
        turns: List[Tuple[str, str]] = []
        for line in text.split('\n'):
            for label, role in labels:
                if line.startswith(label):
                    body = line[len(label):]
                    if body.startswith(' '):
                        body = body[1:]
                    turns.append((role, body))
                    break
            else:
                if turns:
                    prev_role, prev_text = turns[-1]
                    turns[-1] = (prev_role, prev_text + '\n' + line)
        return turns

    @staticmethod
    def _pair_turns(turns: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Pair each client run with the counselor run that immediately follows it.

        Empty utterances are dropped, and consecutive utterances by the same
        speaker are merged with newlines. Only client->counselor transitions
        produce a pair, so an opening counselor greeting and a final unanswered
        client message are skipped.

        Returns:
            List of ``(client_message, counselor_reply)`` tuples.
        """
        runs: List[Tuple[str, List[str]]] = []
        for role, utterance in turns:
            utterance = utterance.strip()
            if not utterance:
                continue
            if runs and runs[-1][0] == role:
                runs[-1][1].append(utterance)
            else:
                runs.append((role, [utterance]))

        return [
            ('\n'.join(prev_parts), '\n'.join(parts))
            for (prev_role, prev_parts), (role, parts) in zip(runs, runs[1:])
            if prev_role == 'client' and role == 'counselor'
        ]

    def create_training_examples(self, conversations: List[Dict],
                                 use_weighted_sampling: bool = False) -> List[Dict]:
        """Create training examples in instruction-following format.

        Each run of client utterances and the counselor run that follows it
        becomes one example (see ``_pair_turns``).

        Args:
            conversations: Conversation dicts from
                ``extract_high_quality_conversations``. ``turns`` is used when
                present; otherwise ``text`` is re-parsed.
            use_weighted_sampling: If True, repeat each conversation's examples
                ``max(1, int((score - min_score) / 20) + 1)`` times.

        Returns:
            Dicts with ``instruction``, ``input``, ``output``, ``score``,
            ``topic`` and ``metrics``.
        """
        training_examples = []

        for conv in tqdm(conversations, desc="Creating training examples"):
            score = conv['score']
            weight = max(1, int((score - self.min_score) / 20) + 1) if use_weighted_sampling else 1

            turns = conv.get('turns')
            if turns is None:
                turns = self._parse_turns(conv['text'])
            pairs = self._pair_turns(turns)

            for _ in range(weight):
                for client_msg, counselor_msg in pairs:
                    training_examples.append({
                        'instruction': self.system_prompt,
                        'input': client_msg,
                        'output': counselor_msg,
                        'score': score,
                        'topic': conv['topic'],
                        'metrics': conv['review_metrics'],
                    })

        return training_examples

    # ------------------------------------------------------------------ #
    # Dataset preparation
    # ------------------------------------------------------------------ #
    def _report_score_analysis(self, json_files: List[Dict]) -> None:
        """Print the score-distribution analysis for ``json_files`` and save its plot."""
        print("\n" + "="*60)
        print("SCORE DISTRIBUTION ANALYSIS")
        print("="*60)
        stats = self.analyze_score_distribution(json_files)

        if 'error' in stats:
            print(stats['error'])
            return

        n_scored = stats['conversations_with_scores']
        print(f"Total conversations: {stats['total_conversations']}")
        print(f"Conversations with scores: {n_scored}")
        print("\nScore Statistics:")
        print(f" Mean: {stats['mean_score']:.2f}")
        print(f" Median: {stats['median_score']:.2f}")
        print(f" Std Dev: {stats['std_score']:.2f}")
        print(f" Range: {stats['min_score']:.0f} - {stats['max_score']:.0f}")

        print("\nScore Distribution:")
        for range_name, count in stats['score_ranges'].items():
            percentage = (count / n_scored) * 100
            print(f" {range_name}: {count} ({percentage:.1f}%)")

        print("\nThreshold Impact Analysis:")
        for threshold_name, info in stats['threshold_analysis'].items():
            threshold = threshold_name.split('_')[1]
            print(f" Threshold >= {threshold}: {info['conversations_kept']} conversations ({info['percentage_kept']:.1f}%)")

        print(f"\nCurrent threshold ({self.min_score}) will keep: ", end="")
        kept = sum(1 for s in self.score_distribution if s >= self.min_score)
        print(f"{kept} conversations ({(kept/len(self.score_distribution))*100:.1f}%)")
        print("="*60 + "\n")

        self.plot_score_distribution()

    def prepare_dataset(self, test_size: float = 0.1, val_size: float = 0.1,
                        use_weighted_sampling: bool = False,
                        analyze_scores: bool = True,
                        seed: Optional[int] = None):
        """Prepare train, validation, and test datasets.

        Whole dialogues, not individual examples, are shuffled and split, so
        every example from one counseling session lands in the same split.
        Weighted oversampling is applied to the training split only.
        Duplicating examples before splitting would copy identical examples
        into validation and test.

        Args:
            test_size: Fraction of kept dialogues used for testing.
            val_size: Fraction of kept dialogues used for validation.
            use_weighted_sampling: If True, oversample higher-scored dialogues
                in the training split.
            analyze_scores: If True, print the score distribution analysis and
                save its plot.
            seed: Seed for a private RNG used for shuffling. ``None`` uses the
                global ``random`` module state.

        Returns:
            ``{'train': [...], 'validation': [...], 'test': [...]}``, each a
            list of examples from ``create_training_examples``.

        Raises:
            ValueError: If a split size is negative or the two sum to more than 1.
        """
        if test_size < 0 or val_size < 0 or test_size + val_size > 1:
            raise ValueError(
                f"test_size ({test_size}) and val_size ({val_size}) must be "
                "non-negative and sum to at most 1"
            )
        rng = random.Random(seed) if seed is not None else random

        print("Loading KokoroChat dataset...")
        json_files = self.load_json_files()
        print(f"Loaded {len(json_files)} conversation files")

        if analyze_scores:
            self._report_score_analysis(json_files)

        all_conversations = []
        filtered_count = 0
        for data in json_files:
            if self._client_review(data) is not None and self._get_score(data) < self.min_score:
                filtered_count += 1
            all_conversations.extend(self.extract_high_quality_conversations(data))

        print(f"Filtered out {filtered_count} conversations with score < {self.min_score}")
        print(f"Extracted {len(all_conversations)} high-quality conversations (score >= {self.min_score})")

        rng.shuffle(all_conversations)
        total_size = len(all_conversations)
        test_split = int(total_size * test_size)
        val_split = int(total_size * val_size)

        test_convs = all_conversations[:test_split]
        val_convs = all_conversations[test_split:test_split + val_split]
        train_convs = all_conversations[test_split + val_split:]

        train_data = self.create_training_examples(
            train_convs, use_weighted_sampling=use_weighted_sampling
        )
        val_data = self.create_training_examples(val_convs)
        test_data = self.create_training_examples(test_convs)
        # Mix examples within each split so batches are not grouped by session.
        for split in (train_data, val_data, test_data):
            rng.shuffle(split)

        print(f"Created {len(train_data) + len(val_data) + len(test_data)} training examples")
        if use_weighted_sampling:
            print("Note: Used weighted sampling - higher scored conversations appear more frequently in the training split")

        print("\nDataset splits:")
        print(f" Train: {len(train_data)} examples")
        print(f" Validation: {len(val_data)} examples")
        print(f" Test: {len(test_data)} examples")

        return {
            'train': train_data,
            'validation': val_data,
            'test': test_data,
        }

    # ------------------------------------------------------------------ #
    # Output
    # ------------------------------------------------------------------ #
    def format_for_lfm(self, example: Dict) -> str:
        """Format one example as an Instruction/Input/Response prompt for LFM training."""
        return (
            f"### Instruction:\n{example['instruction']}\n\n"
            f"### Input:\n{example['input']}\n\n"
            f"### Response:\n{example['output']}"
        )

    def save_datasets(self, datasets: Dict, output_dir: str):
        """Write each split to ``<output_dir>/<split>.jsonl`` and a stats summary.

        Each JSONL line holds ``text`` (from ``format_for_lfm``), ``score`` and
        ``topic``. ``dataset_stats.json`` records the threshold, split sizes
        and per-split score statistics.

        Args:
            datasets: Mapping with at least ``train``, ``validation`` and ``test``
                keys, each a list of examples.
            output_dir: Directory to create (if needed) and write into.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        def to_native(obj: Any) -> Any:
            """json fallback: convert NumPy scalars/arrays to plain Python types."""
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        stats = {
            'min_score_threshold': int(self.min_score),
            'dataset_sizes': {
                'train': len(datasets['train']),
                'validation': len(datasets['validation']),
                'test': len(datasets['test']),
            },
            'score_distribution': {},
        }

        for split_name, data in datasets.items():
            scores = [ex['score'] for ex in data]
            if scores:
                stats['score_distribution'][split_name] = {
                    'mean': float(np.mean(scores)),
                    'median': float(np.median(scores)),
                    'min': float(np.min(scores)),
                    'max': float(np.max(scores)),
                    'std': float(np.std(scores)),
                }

            file_path = output_path / f"{split_name}.jsonl"
            with open(file_path, 'w', encoding='utf-8') as f:
                for example in data:
                    json_obj = {
                        'text': self.format_for_lfm(example),
                        'score': example['score'],
                        'topic': example['topic'],
                    }
                    f.write(json.dumps(json_obj, ensure_ascii=False, default=to_native) + '\n')

            print(f"Saved {split_name} dataset with {len(data)} examples to {file_path}")

        stats_path = output_path / "dataset_stats.json"
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2, default=to_native)
        print(f"Saved dataset statistics to {stats_path}")

        print("\n" + "="*60)
        print("DATASET SUMMARY")
        print("="*60)
        print(f"Minimum score threshold: {stats['min_score_threshold']}")
        print("\nDataset sizes:")
        for split, size in stats['dataset_sizes'].items():
            print(f" {split}: {size} examples")

        print("\nScore distributions by split:")
        for split, dist in stats['score_distribution'].items():
            print(f" {split}:")
            print(f" Mean: {dist['mean']:.2f}")
            print(f" Std: {dist['std']:.2f}")
            print(f" Range: {dist['min']:.0f} - {dist['max']:.0f}")
        print("="*60)
def build_arg_parser():
    """Build the command-line parser for the KokoroChat preprocessing script."""
    import argparse

    parser = argparse.ArgumentParser(description='Preprocess KokoroChat dataset')
    parser.add_argument('--data_path', type=str, default='./KokoroChat',
                        help='Path to KokoroChat repository')
    parser.add_argument('--min_score', type=int, default=70,
                        help='Minimum score threshold for filtering (default: 70)')
    parser.add_argument('--output_dir', type=str, default='./processed_data',
                        help='Output directory for processed data')
    parser.add_argument('--weighted_sampling', action='store_true',
                        help='Use weighted sampling based on scores')
    parser.add_argument('--test_size', type=float, default=0.1,
                        help='Test set size (default: 0.1)')
    parser.add_argument('--val_size', type=float, default=0.1,
                        help='Validation set size (default: 0.1)')
    parser.add_argument('--analyze_only', action='store_true',
                        help='Only analyze score distribution without processing')
    return parser


def main(argv: Optional[List[str]] = None) -> None:
    """Run the KokoroChat preprocessing CLI.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)

    preprocessor = KokoroChatPreprocessor(
        data_path=args.data_path,
        min_score=args.min_score
    )

    if args.analyze_only:
        print("Running score distribution analysis only...")
        json_files = preprocessor.load_json_files()
        stats = preprocessor.analyze_score_distribution(json_files)
        # Show the numbers; this mode exists to inspect them before picking --min_score.
        print(json.dumps(stats, ensure_ascii=False, indent=2))
        preprocessor.plot_score_distribution(f"score_analysis_threshold_{args.min_score}.png")
        return

    print(f"Processing with minimum score threshold: {args.min_score}")
    datasets = preprocessor.prepare_dataset(
        test_size=args.test_size,
        val_size=args.val_size,
        use_weighted_sampling=args.weighted_sampling,
        analyze_scores=True
    )

    # Suffix the directory with the threshold so runs with different cut-offs coexist.
    output_dir = f"{args.output_dir}_score{args.min_score}"
    preprocessor.save_datasets(datasets, output_dir)

    print(f"\nProcessing complete! Data saved to {output_dir}")
    print("\nNext steps:")
    print("1. Run fine-tuning: python finetune_lfm.py")
    print("2. Run benchmarking: python benchmark_model.py")
    print("3. Optimize for mobile: python optimize_for_mobile.py")


if __name__ == "__main__":
    main()