| """ |
| Fixed Data Preprocessing for directory of JSON files with client-counselor dialogues |
| Following KokoroChat methodology with COMPLETE dialogue history |
| Filename: preprocess_kokoro_directory_fixed.py |
| """ |
|
|
| import json |
| import os |
| from typing import List, Dict, Tuple, Optional, Any |
| from tqdm import tqdm |
| import random |
| from collections import defaultdict |
| import numpy as np |
| from pathlib import Path |
| import glob |
|
|
class KokoroChatDirectoryPreprocessor:
    """Preprocess a directory of counseling-dialogue JSON files.

    Follows the KokoroChat methodology: every counselor turn becomes one
    training example conditioned on the COMPLETE preceding dialogue history.
    """

    def __init__(self,
                 input_dir: str = "./raw_counseling_data",
                 output_dir: str = "./kokoro_processed_data",
                 min_score: int = 70,
                 train_ratio: float = 0.8,
                 val_ratio: float = 0.1,
                 test_ratio: float = 0.1):
        """
        Initialize preprocessor for directory of JSON files

        Args:
            input_dir: Directory containing JSON files with conversations
            output_dir: Directory to save processed data
            min_score: Minimum score threshold for filtering (if scores exist)
            train_ratio: Ratio for training data
            val_ratio: Ratio for validation data
            test_ratio: Ratio for test data

        Raises:
            ValueError: If the three split ratios do not sum to 1
                (within a small floating-point tolerance).
        """
        # Fail fast on inconsistent split ratios instead of silently
        # producing wrongly sized splits later in process_directory().
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError(
                f"train_ratio + val_ratio + test_ratio must sum to 1, got "
                f"{train_ratio} + {val_ratio} + {test_ratio}"
            )

        self.input_dir = input_dir
        self.output_dir = output_dir
        self.min_score = min_score
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio

        os.makedirs(output_dir, exist_ok=True)

        # Bookkeeping counters updated while scanning the input directory.
        self.total_conversations = 0
        self.total_utterances = 0
        self.skipped_files = 0  # files that failed to load or parse
| def load_json_file(self, filepath: str) -> Optional[Dict]: |
| """Load a single JSON file""" |
| try: |
| with open(filepath, 'r', encoding='utf-8') as f: |
| data = json.load(f) |
| return data |
| except Exception as e: |
| print(f"⚠️ Error loading {filepath}: {e}") |
| self.skipped_files += 1 |
| return None |
| |
| def safe_get_value(self, obj: Any, default: Any = None) -> Any: |
| """Safely get a value, handling nested dicts and lists""" |
| if isinstance(obj, dict): |
| |
| if 'name' in obj: |
| return str(obj['name']) |
| elif 'value' in obj: |
| return str(obj['value']) |
| elif 'text' in obj: |
| return str(obj['text']) |
| else: |
| |
| for v in obj.values(): |
| if isinstance(v, str): |
| return v |
| return str(list(obj.values())[0]) if obj else default |
| elif isinstance(obj, list): |
| |
| if obj: |
| return str(obj[0]) if len(obj) == 1 else ', '.join(str(x) for x in obj) |
| return default |
| elif obj is None: |
| return default |
| else: |
| return str(obj) |
| |
| def extract_dialogue_from_json(self, data: Dict, filepath: str) -> List[Dict]: |
| """ |
| Extract dialogue from various JSON formats |
| Handles different possible structures |
| """ |
| conversations = [] |
| |
| |
| if isinstance(data, list): |
| |
| conversations.append({ |
| 'dialogue': data, |
| 'id': os.path.basename(filepath).replace('.json', ''), |
| 'score': 100, |
| 'topic': 'general', |
| 'source_file': filepath |
| }) |
| |
| elif isinstance(data, dict): |
| |
| score = data.get('score', 100) |
| if isinstance(score, dict): |
| score = score.get('value', 100) if 'value' in score else 100 |
| try: |
| score = float(score) |
| except: |
| score = 100 |
| |
| |
| topic = self.safe_get_value(data.get('topic', 'general'), 'general') |
| |
| |
| if 'dialogue' in data: |
| conversations.append({ |
| 'dialogue': data['dialogue'], |
| 'id': data.get('id', os.path.basename(filepath).replace('.json', '')), |
| 'score': score, |
| 'topic': topic, |
| 'source_file': filepath |
| }) |
| |
| elif 'messages' in data: |
| conversations.append({ |
| 'dialogue': data['messages'], |
| 'id': data.get('id', os.path.basename(filepath).replace('.json', '')), |
| 'score': score, |
| 'topic': topic, |
| 'source_file': filepath |
| }) |
| |
| elif 'utterances' in data: |
| conversations.append({ |
| 'dialogue': data['utterances'], |
| 'id': data.get('id', os.path.basename(filepath).replace('.json', '')), |
| 'score': score, |
| 'topic': topic, |
| 'source_file': filepath |
| }) |
| |
| elif 'conversations' in data: |
| |
| for conv in data['conversations']: |
| if isinstance(conv, dict) and any(key in conv for key in ['dialogue', 'messages', 'utterances']): |
| dialogue_key = 'dialogue' if 'dialogue' in conv else ('messages' if 'messages' in conv else 'utterances') |
| |
| |
| conv_score = conv.get('score', score) |
| if isinstance(conv_score, dict): |
| conv_score = conv_score.get('value', 100) if 'value' in conv_score else 100 |
| try: |
| conv_score = float(conv_score) |
| except: |
| conv_score = 100 |
| |
| conv_topic = self.safe_get_value(conv.get('topic', topic), 'general') |
| |
| conversations.append({ |
| 'dialogue': conv[dialogue_key], |
| 'id': conv.get('id', f"{os.path.basename(filepath)}_{len(conversations)}"), |
| 'score': conv_score, |
| 'topic': conv_topic, |
| 'source_file': filepath |
| }) |
| |
| else: |
| |
| for key, value in data.items(): |
| if isinstance(value, list) and len(value) > 0: |
| |
| if isinstance(value[0], dict) and any(k in value[0] for k in ['speaker', 'role', 'text', 'content', 'utterance']): |
| conversations.append({ |
| 'dialogue': value, |
| 'id': data.get('id', os.path.basename(filepath).replace('.json', '')), |
| 'score': score, |
| 'topic': topic, |
| 'source_file': filepath |
| }) |
| break |
| |
| return conversations |
| |
| def normalize_utterance(self, utterance: Dict) -> Optional[Dict]: |
| """ |
| Normalize utterance format from various possible structures |
| Returns: {'speaker': str, 'text': str} or None |
| """ |
| |
| speaker = None |
| if 'speaker' in utterance: |
| speaker = utterance['speaker'] |
| elif 'role' in utterance: |
| speaker = utterance['role'] |
| elif 'sender' in utterance: |
| speaker = utterance['sender'] |
| elif 'from' in utterance: |
| speaker = utterance['from'] |
| elif 'type' in utterance: |
| speaker = utterance['type'] |
| |
| |
| text = None |
| if 'text' in utterance: |
| text = utterance['text'] |
| elif 'content' in utterance: |
| text = utterance['content'] |
| elif 'message' in utterance: |
| text = utterance['message'] |
| elif 'utterance' in utterance: |
| text = utterance['utterance'] |
| elif 'response' in utterance: |
| text = utterance['response'] |
| |
| if speaker and text: |
| |
| speaker_lower = str(speaker).lower() |
| if speaker_lower in ['client', 'user', 'patient', 'クライアント', '相談者', 'c']: |
| normalized_speaker = 'client' |
| elif speaker_lower in ['counselor', 'therapist', 'assistant', 'カウンセラー', '相談員', 's', 'system']: |
| normalized_speaker = 'counselor' |
| else: |
| |
| normalized_speaker = 'client' if 'client' in speaker_lower else 'counselor' |
| |
| return { |
| 'speaker': normalized_speaker, |
| 'text': str(text).strip() |
| } |
| |
| return None |
| |
| def merge_consecutive_utterances(self, dialogue: List[Dict]) -> List[Dict]: |
| """ |
| Merge consecutive utterances from the same speaker |
| Following KokoroChat paper methodology |
| """ |
| if not dialogue: |
| return [] |
| |
| merged = [] |
| current_utterance = None |
| |
| for utt in dialogue: |
| normalized = self.normalize_utterance(utt) |
| if not normalized: |
| continue |
| |
| if current_utterance is None: |
| current_utterance = normalized |
| elif current_utterance['speaker'] == normalized['speaker']: |
| |
| current_utterance['text'] += ' ' + normalized['text'] |
| else: |
| |
| merged.append(current_utterance) |
| current_utterance = normalized |
| |
| |
| if current_utterance: |
| merged.append(current_utterance) |
| |
| return merged |
| |
| def create_training_examples(self, conversation: Dict) -> List[Dict]: |
| """ |
| Create training examples with COMPLETE dialogue history |
| Following the paper: Dt = {uC1, uS2, uC3, ..., uCt} -> uSt+1 |
| """ |
| examples = [] |
| |
| |
| dialogue = conversation.get('dialogue', []) |
| if not dialogue: |
| return [] |
| |
| |
| merged_dialogue = self.merge_consecutive_utterances(dialogue) |
| |
| if not merged_dialogue: |
| return [] |
| |
| |
| for i in range(len(merged_dialogue)): |
| current = merged_dialogue[i] |
| |
| |
| if current['speaker'] == 'counselor': |
| |
| complete_history = merged_dialogue[:i] |
| |
| |
| if not complete_history or complete_history[0]['speaker'] != 'client': |
| continue |
| |
| |
| topic = conversation.get('topic', 'general') |
| if not isinstance(topic, str): |
| topic = self.safe_get_value(topic, 'general') |
| |
| |
| example = { |
| 'dialogue_history': complete_history, |
| 'response': current['text'], |
| 'score': conversation.get('score', 100), |
| 'topic': topic, |
| 'conversation_id': conversation.get('id', 'unknown'), |
| 'source_file': conversation.get('source_file', 'unknown'), |
| 'turn_number': i, |
| 'history_length': len(complete_history) |
| } |
| |
| examples.append(example) |
| |
| return examples |
| |
| def format_for_training(self, example: Dict, format_type: str = 'simple') -> str: |
| """ |
| Format example for training |
| |
| Args: |
| format_type: 'simple' or 'llama' format |
| """ |
| |
| history_text = "" |
| for turn in example['dialogue_history']: |
| speaker = "クライアント" if turn['speaker'] == 'client' else "カウンセラー" |
| history_text += f"{speaker}: {turn['text']}\n" |
| |
| if format_type == 'llama': |
| |
| formatted = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> |
| あなたは専門的な訓練を受けた心理カウンセラーです。クライアントの感情に共感し、適切な支援を提供してください。 |
| これまでの対話履歴全体を考慮して、適切な応答を生成してください。<|eot_id|> |
| |
| <|start_header_id|>user<|end_header_id|> |
| 以下は、クライアントとカウンセラーの完全な対話履歴です。 |
| この履歴全体を踏まえて、次のカウンセラーの応答を生成してください。 |
| |
| 完全な対話履歴: |
| {history_text} |
| 次のカウンセラーの応答を生成してください。<|eot_id|> |
| |
| <|start_header_id|>assistant<|end_header_id|> |
| {example['response']}<|eot_id|>""" |
| |
| else: |
| |
| formatted = f"""### Instruction: |
| あなたは専門的な訓練を受けた心理カウンセラーです。 |
| 以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。 |
| |
| ### Dialogue History: |
| {history_text} |
| ### Response: |
| {example['response']}""" |
| |
| return formatted |
| |
| def process_directory(self, format_type: str = 'simple'): |
| """Process all JSON files in the input directory""" |
| print(f"🔍 Scanning directory: {self.input_dir}") |
| |
| |
| json_files = [] |
| for pattern in ['*.json', '*.jsonl']: |
| json_files.extend(glob.glob(os.path.join(self.input_dir, '**', pattern), recursive=True)) |
| |
| print(f"Found {len(json_files)} JSON files") |
| |
| if not json_files: |
| print("❌ No JSON files found in the directory!") |
| return |
| |
| |
| all_conversations = [] |
| |
| for filepath in tqdm(json_files, desc="Loading JSON files"): |
| |
| if filepath.endswith('.jsonl'): |
| |
| with open(filepath, 'r', encoding='utf-8') as f: |
| for line_num, line in enumerate(f): |
| try: |
| data = json.loads(line) |
| conversations = self.extract_dialogue_from_json(data, f"{filepath}_line{line_num}") |
| all_conversations.extend(conversations) |
| except: |
| continue |
| else: |
| |
| data = self.load_json_file(filepath) |
| if data: |
| conversations = self.extract_dialogue_from_json(data, filepath) |
| all_conversations.extend(conversations) |
| |
| print(f"✅ Loaded {len(all_conversations)} conversations from {len(json_files) - self.skipped_files} files") |
| print(f"⚠️ Skipped {self.skipped_files} files due to errors") |
| |
| |
| conversations_before_filter = len(all_conversations) |
| filtered_conversations = [ |
| conv for conv in all_conversations |
| if conv.get('score', 100) >= self.min_score |
| ] |
| conversations_after_filter = len(filtered_conversations) |
| |
| print(f"📊 Score filtering (>= {self.min_score}):") |
| print(f" Before: {conversations_before_filter} conversations") |
| print(f" After: {conversations_after_filter} conversations") |
| print(f" Filtered out: {conversations_before_filter - conversations_after_filter} conversations") |
| |
| |
| all_examples = [] |
| history_lengths = [] |
| |
| for conv in tqdm(filtered_conversations, desc="Creating training examples"): |
| examples = self.create_training_examples(conv) |
| all_examples.extend(examples) |
| history_lengths.extend([ex['history_length'] for ex in examples]) |
| |
| if not all_examples: |
| print("❌ No training examples created!") |
| return |
| |
| print(f"✅ Created {len(all_examples)} training examples from {len(filtered_conversations)} conversations") |
| print(f"📊 Dialogue history statistics:") |
| print(f" - Mean length: {np.mean(history_lengths):.1f} turns") |
| print(f" - Median length: {np.median(history_lengths):.1f} turns") |
| print(f" - Max length: {max(history_lengths)} turns") |
| print(f" - Min length: {min(history_lengths)} turns") |
| |
| |
| random.shuffle(all_examples) |
| |
| train_size = int(self.train_ratio * len(all_examples)) |
| val_size = int(self.val_ratio * len(all_examples)) |
| |
| train_data = all_examples[:train_size] |
| val_data = all_examples[train_size:train_size + val_size] |
| test_data = all_examples[train_size + val_size:] |
| |
| print(f"\n📂 Split sizes:") |
| print(f" Train: {len(train_data)} ({self.train_ratio*100:.0f}%)") |
| print(f" Val: {len(val_data)} ({self.val_ratio*100:.0f}%)") |
| print(f" Test: {len(test_data)} ({self.test_ratio*100:.0f}%)") |
| |
| |
| self.save_split(train_data, 'train', format_type) |
| self.save_split(val_data, 'val', format_type) |
| self.save_split(test_data, 'test', format_type) |
| |
| |
| self.save_statistics( |
| train_data, val_data, test_data, |
| all_conversations, filtered_conversations, |
| history_lengths |
| ) |
| |
| print(f"\n✅ Processing complete! Data saved to {self.output_dir}") |
| |
| def save_split(self, data: List[Dict], split_name: str, format_type: str = 'simple'): |
| """Save processed data split""" |
| output_file = os.path.join(self.output_dir, f"{split_name}.jsonl") |
| |
| with open(output_file, 'w', encoding='utf-8') as f: |
| for example in tqdm(data, desc=f"Saving {split_name} data"): |
| formatted_text = self.format_for_training(example, format_type) |
| |
| |
| topic = example.get('topic', 'general') |
| if not isinstance(topic, str): |
| topic = self.safe_get_value(topic, 'general') |
| |
| output_item = { |
| 'text': formatted_text, |
| 'dialogue_history': example['dialogue_history'], |
| 'response': example['response'], |
| 'score': example['score'], |
| 'topic': topic, |
| 'conversation_id': example['conversation_id'], |
| 'source_file': example['source_file'], |
| 'turn_number': example['turn_number'], |
| 'history_length': example['history_length'] |
| } |
| |
| f.write(json.dumps(output_item, ensure_ascii=False) + '\n') |
| |
| print(f"✅ Saved {split_name} data to {output_file}") |
| |
    def save_statistics(self, train_data, val_data, test_data,
                        all_conversations, filtered_conversations, history_lengths):
        """Save comprehensive statistics.

        Writes <output_dir>/dataset_stats.json and prints a human-readable
        summary: score filtering, split sizes, dialogue-history lengths,
        and (train-split only) topic/source-file distributions.
        """
        # Topic distribution over the TRAINING split only.
        topic_counts = defaultdict(int)
        for example in train_data:
            topic = example.get('topic', 'general')
            if not isinstance(topic, str):
                # Topic may still be a nested dict/list; flatten to a string.
                topic = self.safe_get_value(topic, 'general')
            topic_counts[topic] += 1

        # How many training examples each source file contributed.
        source_counts = defaultdict(int)
        for example in train_data:
            source_file = os.path.basename(example.get('source_file', 'unknown'))
            source_counts[source_file] += 1

        # Scores of the conversations that survived filtering.
        # NOTE(review): min()/max() below raise on an empty list; the only
        # caller (process_directory) returns early when no examples exist,
        # which implies a non-empty filtered list — confirm before reusing.
        scores = [conv.get('score', 100) for conv in filtered_conversations]

        stats = {
            'preprocessing_info': {
                'input_directory': self.input_dir,
                'output_directory': self.output_dir,
                'total_files_processed': len(set(conv.get('source_file', 'unknown') for conv in all_conversations)),
                'total_conversations_loaded': len(all_conversations),
                'conversations_after_filtering': len(filtered_conversations),
                'conversations_filtered_out': len(all_conversations) - len(filtered_conversations),
                'total_training_examples': len(train_data) + len(val_data) + len(test_data),
                'min_score_threshold': self.min_score,
                'methodology': 'KokoroChat paper - complete dialogue history'
            },
            'score_filtering': {
                'threshold': self.min_score,
                'before_filtering': len(all_conversations),
                'after_filtering': len(filtered_conversations),
                'filtered_out': len(all_conversations) - len(filtered_conversations),
                # Guarded against division by zero when nothing was loaded.
                'percentage_kept': (len(filtered_conversations) / len(all_conversations) * 100) if all_conversations else 0
            },
            'score_statistics': {
                'mean': float(np.mean(scores)),
                'std': float(np.std(scores)),
                'min': float(min(scores)),
                'max': float(max(scores)),
                'median': float(np.median(scores)),
                'percentile_25': float(np.percentile(scores, 25)),
                'percentile_75': float(np.percentile(scores, 75))
            },
            'split_sizes': {
                'train': len(train_data),
                'val': len(val_data),
                'test': len(test_data),
                'train_ratio': self.train_ratio,
                'val_ratio': self.val_ratio,
                'test_ratio': self.test_ratio
            },
            'dialogue_history_stats': {
                'mean_length': float(np.mean(history_lengths)),
                'std_length': float(np.std(history_lengths)),
                'min_length': int(min(history_lengths)),
                'max_length': int(max(history_lengths)),
                'median_length': float(np.median(history_lengths)),
                'percentile_25': float(np.percentile(history_lengths, 25)),
                'percentile_75': float(np.percentile(history_lengths, 75)),
                'percentile_95': float(np.percentile(history_lengths, 95))
            },
            # Only the first 20 entries of each distribution are persisted
            # (insertion order of the defaultdicts, not sorted by count).
            'topic_distribution': dict(list(topic_counts.items())[:20]),
            'source_file_distribution': dict(list(source_counts.items())[:20]),
            # Histogram of history lengths in hand-picked turn buckets.
            'history_length_bins': {
                '1-5_turns': sum(1 for l in history_lengths if l <= 5),
                '6-10_turns': sum(1 for l in history_lengths if 5 < l <= 10),
                '11-15_turns': sum(1 for l in history_lengths if 10 < l <= 15),
                '16-20_turns': sum(1 for l in history_lengths if 15 < l <= 20),
                '21-30_turns': sum(1 for l in history_lengths if 20 < l <= 30),
                '31-50_turns': sum(1 for l in history_lengths if 30 < l <= 50),
                '50+_turns': sum(1 for l in history_lengths if l > 50)
            }
        }

        # Persist the full statistics object as pretty-printed UTF-8 JSON.
        stats_file = os.path.join(self.output_dir, 'dataset_stats.json')
        with open(stats_file, 'w', encoding='utf-8') as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        print(f"\n📊 Statistics saved to {stats_file}")

        # Console summary of the most important numbers.
        print("\n" + "="*70)
        print("📈 DATASET STATISTICS SUMMARY")
        print("="*70)
        print(f"Files processed: {stats['preprocessing_info']['total_files_processed']}")
        print(f"Conversations loaded: {stats['preprocessing_info']['total_conversations_loaded']}")
        print(f"After score filtering (>={self.min_score}): {stats['preprocessing_info']['conversations_after_filtering']}")
        print(f"Training examples created: {stats['preprocessing_info']['total_training_examples']}")
        print(f"\nScore Statistics (after filtering):")
        print(f" Mean: {stats['score_statistics']['mean']:.1f}")
        print(f" Median: {stats['score_statistics']['median']:.1f}")
        print(f" Range: {stats['score_statistics']['min']:.0f} - {stats['score_statistics']['max']:.0f}")
        print(f"\nDialogue History Length Distribution:")
        for bin_name, count in stats['history_length_bins'].items():
            percentage = (count / len(history_lengths)) * 100 if history_lengths else 0
            print(f" {bin_name}: {count} ({percentage:.1f}%)")
        print("="*70)
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the directory preprocessor."""
    import argparse

    parser = argparse.ArgumentParser(
        description='Preprocess directory of JSON files with counseling dialogues'
    )
    parser.add_argument('--input_dir', type=str,
                        default='./KokoroChat/kokorochat_dialogues',
                        help='Directory containing JSON files with conversations')
    parser.add_argument('--output_dir', type=str,
                        default='./kokoro_processed_data',
                        help='Output directory for processed data')
    parser.add_argument('--min_score', type=int, default=70,
                        help='Minimum score threshold (if scores exist in data)')
    parser.add_argument('--format', type=str, choices=['simple', 'llama'],
                        default='simple',
                        help='Output format type')
    args = parser.parse_args()

    preprocessor = KokoroChatDirectoryPreprocessor(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        min_score=args.min_score
    )

    print("🚀 Starting preprocessing with COMPLETE dialogue history")
    print(" Following KokoroChat paper methodology")
    print("="*70)

    preprocessor.process_directory(format_type=args.format)

    print("\n✅ Preprocessing complete!")
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|