Upload folder using huggingface_hub
Browse files- data_preparation.py +133 -0
- dataprocessing_multiturn.py +177 -0
- finetune_lfm2.6b.py +303 -0
data_preparation.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# data_preparation.py
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
|
| 11 |
+
class KokoroChatProcessor:
    """Load KokoroChat counseling dialogues and turn them into instruction-tuning examples.

    Typical flow: load_all_conversations() -> create_training_examples()
    -> prepare_for_finetuning().
    """

    def __init__(self, data_path: str):
        # Root directory searched recursively for *.json dialogue files.
        self.data_path = Path(data_path)
        self.conversations: List[Dict] = []   # raw parsed JSON documents
        self.processed_data: List[Dict] = []  # flattened training examples

    def load_all_conversations(self) -> List[Dict]:
        """Load all JSON files from the KokoroChat dataset into self.conversations."""
        json_files = list(self.data_path.glob("**/*.json"))
        print(f"Found {len(json_files)} conversation files")

        for json_file in tqdm(json_files, desc="Loading conversations"):
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.conversations.append(data)
            except Exception as e:
                # Best-effort load: report and skip unreadable/corrupt files.
                print(f"Error loading {json_file}: {e}")

        return self.conversations

    def create_training_examples(self) -> List[Dict]:
        """Convert conversations to (input, output) training examples.

        BUG FIX: the previous implementation paired dialogue[i] (counselor) with
        dialogue[i + 1] (client), i.e. the "output" counselor utterance came
        BEFORE the client "input" in the conversation, and the context
        (dialogue[:i+1]) already contained the answer.  We now scan every
        adjacent client -> counselor pair, so the output is the counselor's
        actual response and the context holds only the turns before the client
        message.  Scanning all adjacencies (not every other index) also keeps
        pairs when roles do not alternate strictly.
        """
        for conv_data in tqdm(self.conversations, desc="Processing conversations"):
            dialogue = conv_data.get('dialogue', [])
            topic = conv_data.get('topic', {})
            review = conv_data.get('review_by_client_jp', {})

            for i in range(len(dialogue) - 1):
                client_msg = dialogue[i]
                counselor_msg = dialogue[i + 1]

                if client_msg['role'] == 'client' and counselor_msg['role'] == 'counselor':
                    # Context is everything said BEFORE the client message we answer.
                    context = self._build_context(dialogue[:i])

                    training_example = {
                        'instruction': "あなたは共感的で専門的な心理カウンセラーです。クライアントの悩みに寄り添い、適切なサポートを提供してください。",
                        'input': f"クライアント: {client_msg['utterance']}",
                        'output': counselor_msg['utterance'],
                        'context': context,
                        'topic': topic.get('main_jp', ''),
                        'quality_score': self._calculate_quality_score(review)
                    }

                    self.processed_data.append(training_example)

        return self.processed_data

    def _build_context(self, dialogue_history: List[Dict], max_turns: int = 5) -> str:
        """Render the most recent max_turns exchanges of history as role-tagged lines."""
        context_parts = []
        # Keep at most max_turns exchanges (2 utterances per exchange).
        start_idx = max(0, len(dialogue_history) - max_turns * 2)

        for msg in dialogue_history[start_idx:]:
            role = "カウンセラー" if msg['role'] == 'counselor' else "クライアント"
            context_parts.append(f"{role}: {msg['utterance']}")

        return "\n".join(context_parts)

    def _calculate_quality_score(self, review: Dict) -> float:
        """Return a 0..1 quality score from the client review ('点数' is assumed 0-100)."""
        if not review or review.get('点数') is None:
            return 0.5  # Default middle score when no review is available
        return review.get('点数', 50) / 100.0

    def prepare_for_finetuning(self, test_size: float = 0.1, val_size: float = 0.1):
        """Split high-quality examples into train/val/test prompt dicts.

        Returns:
            (train, val, test) lists of {'text': prompt} dicts.
        """
        # Keep only examples whose review score exceeds 0.6.
        high_quality = [ex for ex in self.processed_data if ex['quality_score'] > 0.6]
        print(f"Selected {len(high_quality)} high-quality examples")

        # Deterministic splits (fixed random_state) for reproducibility.
        train_data, test_data = train_test_split(high_quality, test_size=test_size, random_state=42)
        train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=42)

        def format_example(ex):
            # Single-string prompt in the 指示/コンテキスト/入力/応答 layout.
            prompt = f"""### 指示:
{ex['instruction']}

### コンテキスト:
{ex['context']}

### 入力:
{ex['input']}

### 応答:
{ex['output']}"""
            return {'text': prompt}

        train_formatted = [format_example(ex) for ex in train_data]
        val_formatted = [format_example(ex) for ex in val_data]
        test_formatted = [format_example(ex) for ex in test_data]

        return train_formatted, val_formatted, test_formatted
|
| 115 |
+
|
| 116 |
+
# Execute data preparation: load, convert, split.
processor = KokoroChatProcessor('KokoroChat/data')
processor.load_all_conversations()
processor.create_training_examples()
train_data, val_data, test_data = processor.prepare_for_finetuning()

# Persist all three splits in a single pickle for the training script.
import pickle

splits = {
    'train': train_data,
    'val': val_data,
    'test': test_data,
}
with open('processed_data.pkl', 'wb') as dump_file:
    pickle.dump(splits, dump_file)

# Report split sizes so a truncated run is easy to spot in the log.
for label, split in (("Training", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"{label} examples: {len(split)}")
|
dataprocessing_multiturn.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# prepare_dataset_multiturn.py
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from datasets import Dataset, Features, Value
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
def parse_kokorochat_with_context(json_file_path, context_window=4, max_history_tokens=1500):
    """
    Parse one KokoroChat JSON file into (examples, score), keeping dialogue history.

    Each example answers a client utterance with the counselor utterance that
    immediately follows it, plus up to ``context_window`` preceding turns.

    Args:
        json_file_path: Path to the JSON file.
        context_window: Number of previous turns to include (default: 4 = 2 exchanges).
        max_history_tokens: Approximate token budget for history (~3 chars/token
            for Japanese); pairs with longer history are skipped.

    Returns:
        (list of example dicts, overall client review score).
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as src:
            data = json.load(src)
    except Exception:
        # Unreadable / malformed file behaves like an empty, zero-score dialogue.
        return [], 0

    dialogue = data.get('dialogue', [])
    total_score = data.get('review_by_client_en', {}).get('score', 0)
    topic = data.get('topic', {})
    main_topic = topic.get('main_en', '')
    sub_topic = topic.get('sub', '')
    dialogue_id = Path(json_file_path).stem

    examples = []
    # Walk adjacent turn pairs; only client -> counselor adjacencies become pairs.
    for i, (turn, reply) in enumerate(zip(dialogue, dialogue[1:])):
        if turn['role'] != 'client' or reply['role'] != 'counselor':
            continue

        client_msg = turn['utterance'].strip()
        counselor_msg = reply['utterance'].strip()
        if len(client_msg) <= 5 or len(counselor_msg) <= 5:
            continue  # drop near-empty utterances

        history = dialogue[max(0, i - context_window):i]
        # Rough length check: ~3 characters per Japanese token.
        history_chars = sum(len(h['utterance']) for h in history)
        if history_chars >= max_history_tokens * 3:
            continue  # history too long for the sequence budget

        examples.append({
            'history': history,
            'client': client_msg,
            'counselor': counselor_msg,
            'quality_score': total_score,
            'topic_main': main_topic,
            'topic_sub': sub_topic,
            'dialogue_id': dialogue_id,
        })

    return examples, total_score
|
| 64 |
+
|
| 65 |
+
def format_conversation_for_lfm2(conversation):
    """
    Render a conversation (history + current exchange) as an LFM2 ChatML string.

    History turns map client -> user and counselor -> assistant; any other role
    is dropped.  The final counselor reply is the supervised target and is
    terminated with <|endoftext|>.
    """
    parts = [
        "<|im_start|>system\n",
        "あなたは経験豊富な心理カウンセラーです。クライアントの話を傾聴し、共感的で支援的な応答をしてください。<|im_end|>\n",
    ]

    role_map = {'client': 'user', 'counselor': 'assistant'}
    for turn in conversation['history']:
        chat_role = role_map.get(turn['role'])
        if chat_role is not None:
            parts.append(f"<|im_start|>{chat_role}\n{turn['utterance']}<|im_end|>\n")

    # Current exchange: the pair the model is actually trained on.
    parts.append(f"<|im_start|>user\n{conversation['client']}<|im_end|>\n")
    parts.append(f"<|im_start|>assistant\n{conversation['counselor']}<|im_end|><|endoftext|>")

    return "".join(parts)
|
| 85 |
+
|
| 86 |
+
def create_training_dataset_multiturn(
    data_dir="./KokoroChat/data",
    min_score=70,
    context_window=4
):
    """
    Build a HuggingFace dataset of multi-turn counseling examples.

    Args:
        data_dir: Directory scanned recursively for KokoroChat JSON files.
        min_score: Minimum client review score (0-100) a dialogue must reach
            for its examples to be kept (recommend 85 for top quality).
        context_window: Number of previous turns kept as conversation context.

    Returns:
        A train/test DatasetDict (also written to ./kokorochat_processed_multiturn),
        or None when no dialogue qualifies.
    """
    json_files = list(Path(data_dir).rglob("*.json"))
    print(f"Found {len(json_files)} JSON files")

    collected = []
    score_distribution = []

    print("\nProcessing files with multi-turn context...")
    for idx, json_file in enumerate(json_files):
        if idx % 1000 == 0:
            print(f"Processed {idx}/{len(json_files)} files...")

        try:
            convs, score = parse_kokorochat_with_context(
                json_file,
                context_window=context_window
            )
        except Exception:
            continue  # skip files the parser cannot handle

        score_distribution.append(score)
        if score >= min_score:
            collected.extend(convs)

    qualifying = sum(1 for s in score_distribution if s >= min_score)
    print(f"\n=== Processing Results ===")
    print(f"High-quality files (>= {min_score}): {qualifying}")
    print(f"Total conversation examples: {len(collected)}")

    if not collected:
        print(f"❌ No conversations found! Try lowering min_score (current: {min_score})")
        return None

    # Flatten each conversation into the columnar record the dataset expects.
    formatted_data = [
        {
            'text': format_conversation_for_lfm2(conv),
            'quality_score': conv['quality_score'],
            'topic_main': conv['topic_main'],
            'topic_sub': conv['topic_sub'],
            'has_context': len(conv['history']) > 0,
        }
        for conv in collected
    ]

    # Explicit schema keeps column dtypes stable across runs.
    features = Features({
        'text': Value('string'),
        'quality_score': Value('int64'),
        'topic_main': Value('string'),
        'topic_sub': Value('string'),
        'has_context': Value('bool'),
    })

    df = pd.DataFrame(formatted_data)
    dataset = Dataset.from_pandas(df, features=features).train_test_split(
        test_size=0.1, seed=42
    )

    print(f"\n=== Final Dataset ===")
    print(f"Training samples: {len(dataset['train'])}")
    print(f"Validation samples: {len(dataset['test'])}")
    print(f"Examples with context: {sum(df['has_context'])}")

    dataset.save_to_disk("./kokorochat_processed_multiturn")
    print("\n✅ Multi-turn dataset saved to ./kokorochat_processed_multiturn")

    # Show one formatted example (truncated) as a sanity check.
    print("\n=== Sample Training Example (with context) ===")
    sample = dataset['train'][5]['text']
    print(sample[:1000] + "\n..." if len(sample) > 1000 else sample)

    return dataset
|
| 171 |
+
|
| 172 |
+
if __name__ == "__main__":
    # Build the multi-turn dataset from the raw KokoroChat dialogue dump.
    dataset = create_training_dataset_multiturn(
        data_dir="./KokoroChat/kokorochat_dialogues",
        min_score=60,       # roughly the top 30% of dialogues by client score
        context_window=4,   # carry up to 4 previous turns as context
    )
|
finetune_lfm2.6b.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# finetune_lfm2_2.6b_FIXED.py
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import (
|
| 4 |
+
AutoTokenizer,
|
| 5 |
+
AutoModelForCausalLM,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
Trainer,
|
| 8 |
+
BitsAndBytesConfig,
|
| 9 |
+
GPT2Tokenizer
|
| 10 |
+
)
|
| 11 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 12 |
+
from datasets import load_from_disk
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Any, Dict, List
|
| 15 |
+
import wandb
|
| 16 |
+
import os
|
| 17 |
+
import warnings
|
| 18 |
+
warnings.filterwarnings('ignore')
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
# Environment banner: surface the PyTorch/CUDA/GPU configuration up front so a
# mis-provisioned machine is obvious in the first lines of the log.
# ---------------------------------------------------------------------------
print("=" * 80)
print("LFM2-2.6B FINE-TUNING - FIXED VERSION")
print("=" * 80)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

if torch.cuda.is_available():
    # Used later to pick the batch size / gradient-accumulation split.
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Memory: {gpu_memory_gb:.1f} GB")

# Import check only: fail fast before any training work if bitsandbytes is broken.
import bitsandbytes as bnb
print("✅ BitsAndBytes OK")

# Initialize W&B experiment tracking.
wandb.init(
    project="liquid-ai-hackathon-kokorochat",
    name="LFM2-2.6B-counselor-FIXED",
    config={
        "model": "LFM2-2.6B",
        "dataset": "KokoroChat-MultiTurn",
        "task": "psychological-counseling"
    }
)

print("\n" + "=" * 80)
print("LOADING MODEL (WITH FALLBACK)")
print("=" * 80)

LOCAL_MODEL_PATH = "./models/LFM2-2.6B"  # preferred: pre-downloaded snapshot
HF_MODEL_NAME = "LiquidAI/LFM2-2.6B"     # fallback: pull from the Hub

# 1. Load tokenizer with GPT2 fallback
print("\n1. Loading tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        LOCAL_MODEL_PATH,
        trust_remote_code=True,
        local_files_only=True
    )
    print(" ✅ LFM2 tokenizer loaded!")
except Exception as e:
    # NOTE(review): GPT-2's byte-level BPE is a poor fit for Japanese text; if
    # this fallback ever triggers, training quality is questionable — confirm
    # falling back silently is acceptable rather than aborting here.
    print(f" ⚠️ LFM2 tokenizer failed")
    print(" 🔄 Using GPT2 tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    print(" ✅ GPT2 tokenizer loaded!")

# Reuse EOS as PAD; the data collator below masks pad positions out of the labels.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 2. QLoRA config: 4-bit NF4 quantization with bf16 compute and double quant.
print("\n2. Configuring QLoRA...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# 3. Load model with proper fallback
print("\n3. Loading LFM2-2.6B model...")

# LFM2 ships custom modeling code; a weights-only snapshot cannot be loaded,
# so check for the remote-code files first.
print(" 📥 Checking for custom model files...")

custom_files = ["modeling_lfm2.py", "configuration_lfm2.py"]
has_custom_files = all(
    os.path.exists(os.path.join(LOCAL_MODEL_PATH, f))
    for f in custom_files
)

if not has_custom_files:
    print(" ⚠️ Custom model files missing in local directory")
    print(" 📥 Need to download from HuggingFace with custom code...")

    # Download the full snapshot (weights + custom code) into the local path.
    from huggingface_hub import snapshot_download

    print(" ⏳ Downloading model with custom code (one-time)...")
    snapshot_download(
        repo_id=HF_MODEL_NAME,
        local_dir=LOCAL_MODEL_PATH,
        local_dir_use_symlinks=False,
        ignore_patterns=[]  # Don't ignore anything
    )
    print(" ✅ Model downloaded with custom code!")

# Now load the model
print(" ⏳ Loading model (~2-4 minutes)...")

try:
    # Try local first; trust_remote_code is required for LFM2's custom classes.
    model = AutoModelForCausalLM.from_pretrained(
        LOCAL_MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,  # CRITICAL!
        torch_dtype=torch.bfloat16,
        local_files_only=False  # Allow downloading custom code if needed
    )
    print(" ✅ Model loaded from local!")

except Exception as e:
    print(f" ⚠️ Local load failed: {str(e)[:100]}")
    print(" 📥 Loading directly from HuggingFace...")

    # Load from HuggingFace Hub
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(" ✅ Model loaded from HuggingFace!")

# Prepare quantized model for k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(model)
model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing
print(" ✅ Model prepared!")

# 4. LoRA - 2.6B configuration
print("\n4. Applying LoRA (2.6B config)...")
lora_config = LoraConfig(
    r=64,  # Higher rank for 2.6B
    lora_alpha=128,
    # Adapt all attention and MLP projections.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
print("\n📊 Trainable Parameters:")
model.print_trainable_parameters()

# 5. Load the multi-turn dataset produced by the preparation script.
print("\n5. Loading dataset...")
dataset = load_from_disk("./kokorochat_processed_multiturn")
print(f" ✅ Training: {len(dataset['train']):,}, Val: {len(dataset['test']):,}")
|
| 161 |
+
|
| 162 |
+
# 6. Data Collator (same as 1.2B)
|
| 163 |
+
@dataclass
class DataCollatorForCausalLM:
    """Collate raw-text features into padded, label-masked tensors for causal LM training.

    Labels are a copy of input_ids with pad positions set to -100 so the loss
    ignores padding.
    """
    tokenizer: Any
    max_length: int = 2048

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        encoded = self.tokenizer(
            [example["text"] for example in features],
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        labels = encoded["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        encoded["labels"] = labels
        return encoded
|
| 180 |
+
|
| 181 |
+
data_collator = DataCollatorForCausalLM(tokenizer=tokenizer)

# 7. Training Configuration - 2.6B optimized
print("\n6. Configuring training (2.6B optimized)...")

# NOTE(review): assumes CUDA is available; this line crashes on CPU-only hosts.
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)

# Keep the effective batch size at 32 regardless of GPU size.
if gpu_memory_gb >= 70:
    per_device_batch = 2
    grad_accum = 16
    print(f" 🚀 {gpu_memory_gb:.0f}GB GPU → batch=2, accum=16")
else:
    per_device_batch = 1
    grad_accum = 32
    print(f" ⚡ {gpu_memory_gb:.0f}GB GPU → batch=1, accum=32")

training_args = TrainingArguments(
    output_dir="./lfm2-2.6b-checkpoints-fixed",

    # Batch (memory-adjusted for 2.6B)
    per_device_train_batch_size=per_device_batch,
    per_device_eval_batch_size=per_device_batch,
    gradient_accumulation_steps=grad_accum,

    # Learning (optimized for 2.6B)
    num_train_epochs=3,  # fewer epochs than the 1.2B run
    learning_rate=2e-4,  # lower than 1.2B's 3e-4 for stability
    warmup_steps=200,
    lr_scheduler_type="cosine",

    # Optimization / logging / checkpointing
    fp16=False,
    bf16=True,  # bf16 matches the 4-bit compute dtype configured above
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    optim="paged_adamw_8bit",  # paged 8-bit Adam keeps optimizer state small
    report_to="wandb",
    gradient_checkpointing=True,  # trade compute for activation memory
    max_grad_norm=0.3,
    logging_dir="./logs",
    remove_unused_columns=False,  # custom collator reads the raw 'text' column
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)

# Derived statistics for the banner below.
effective_batch = per_device_batch * grad_accum
steps_per_epoch = len(dataset['train']) // effective_batch
total_steps = steps_per_epoch * 3  # NOTE(review): 3 duplicates num_train_epochs above

print("\n" + "=" * 80)
print("📊 2.6B TRAINING CONFIGURATION")
print("=" * 80)
print(f"\n✅ Batch Config:")
print(f" Per-device: {per_device_batch}")
print(f" Gradient accum: {grad_accum}")
print(f" → Effective: {effective_batch}")

print(f"\n✅ Learning Config:")
print(f" Learning rate: 2e-4 (vs 3e-4 for 1.2B)")
print(f" Epochs: 3 (vs 4 for 1.2B)")
print(f" LoRA rank: 64 (vs 32 for 1.2B)")

print(f"\n✅ Training Stats:")
print(f" Training samples: {len(dataset['train']):,}")
print(f" Steps per epoch: {steps_per_epoch:,}")
print(f" Total steps: {total_steps:,}")

print(f"\n⏱️ Estimated Time:")
if gpu_memory_gb >= 80:
    print(f" ~5-8 hours on {gpu_memory_gb:.0f}GB GPU")
else:
    print(f" ~8-12 hours on {gpu_memory_gb:.0f}GB GPU")
|
| 259 |
+
|
| 260 |
+
# 8. Trainer (same setup as the 1.2B run)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
)

# 9. Start training
print("\n" + "=" * 80)
print("🚀 STARTING 2.6B TRAINING")
print("=" * 80)
print(f"📊 Monitor: https://wandb.ai/sandeeptechiot-ai/liquid-ai-hackathon-kokorochat\n")

try:
    trainer.train()
    print("\n✅ TRAINING COMPLETE!")

except KeyboardInterrupt:
    # Ctrl-C: keep whatever progress was made.
    # NOTE(review): execution falls through to the final save below after an
    # interrupt (no re-raise) — confirm that is intended.
    print("\n⚠️ Interrupted - saving...")
    trainer.save_model("./lfm2-2.6b-interrupted")

except Exception as e:
    # Anything else: print the traceback, then re-raise so the run fails loudly.
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()
    raise

# 10. Save the trainer output and the standalone LoRA adapter.
output_dir = "./lfm2-2.6b-counselor-final"
lora_dir = "./lfm2-2.6b-counselor-lora"

trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)  # keep tokenizer alongside the model
model.save_pretrained(lora_dir)        # adapter-only weights for lightweight deployment

print(f"\n✅ Model saved to: {output_dir}")

wandb.finish()

print("\n" + "=" * 80)
print("🎉 2.6B TRAINING COMPLETE!")
print("=" * 80)