Techiiot committed on
Commit
0446288
·
verified ·
1 Parent(s): 27c46c6

Upload folder using huggingface_hub

Browse files
data_preparation.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_preparation.py
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ from typing import List, Dict, Tuple
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ from sklearn.model_selection import train_test_split
10
+
11
class KokoroChatProcessor:
    """Turn raw KokoroChat counseling dialogues into instruction-tuning data.

    Pipeline: load_all_conversations() -> create_training_examples()
    -> prepare_for_finetuning().
    """

    def __init__(self, data_path: str):
        self.data_path = Path(data_path)      # root directory holding *.json dialogues
        self.conversations: List[Dict] = []   # raw parsed JSON documents
        self.processed_data: List[Dict] = []  # flattened training examples

    def load_all_conversations(self) -> List[Dict]:
        """Load all JSON files from the KokoroChat dataset (recursively)."""
        json_files = list(self.data_path.glob("**/*.json"))
        print(f"Found {len(json_files)} conversation files")

        for json_file in tqdm(json_files, desc="Loading conversations"):
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                self.conversations.append(data)
            except Exception as e:
                # Best-effort load: report unreadable/corrupt files and keep going.
                print(f"Error loading {json_file}: {e}")

        return self.conversations

    def create_training_examples(self) -> List[Dict]:
        """Convert conversations into (client input -> counselor output) pairs.

        BUG FIX: the original implementation paired the counselor turn at
        index i with the client turn at i+1, so the training "output" actually
        *preceded* its "input" in the real dialogue, and the context window
        (dialogue[:i+1]) leaked the counselor's answer into the context.
        We now scan for adjacent client -> counselor pairs — the same
        convention used by the multi-turn processor in this project — and
        build the context strictly from turns before the client message.
        """
        for conv_data in tqdm(self.conversations, desc="Processing conversations"):
            dialogue = conv_data.get('dialogue', [])
            topic = conv_data.get('topic', {})
            review = conv_data.get('review_by_client_jp', {})

            for i in range(len(dialogue) - 1):
                client_msg = dialogue[i]
                counselor_msg = dialogue[i + 1]

                if client_msg['role'] == 'client' and counselor_msg['role'] == 'counselor':
                    # History strictly before the client message: no label leakage.
                    context = self._build_context(dialogue[:i])

                    training_example = {
                        'instruction': "あなたは共感的で専門的な心理カウンセラーです。クライアントの悩みに寄り添い、適切なサポートを提供してください。",
                        'input': f"クライアント: {client_msg['utterance']}",
                        'output': counselor_msg['utterance'],
                        'context': context,
                        'topic': topic.get('main_jp', ''),
                        'quality_score': self._calculate_quality_score(review)
                    }

                    self.processed_data.append(training_example)

        return self.processed_data

    def _build_context(self, dialogue_history: List[Dict], max_turns: int = 5) -> str:
        """Render the most recent `max_turns` exchanges (2 turns each) as text."""
        context_parts = []
        start_idx = max(0, len(dialogue_history) - max_turns * 2)

        for msg in dialogue_history[start_idx:]:
            role = "カウンセラー" if msg['role'] == 'counselor' else "クライアント"
            context_parts.append(f"{role}: {msg['utterance']}")

        return "\n".join(context_parts)

    def _calculate_quality_score(self, review: Dict) -> float:
        """Normalize the client review score (点数, assumed 0-100) to [0, 1]."""
        if not review or review.get('点数') is None:
            return 0.5  # Default middle score when no review is available

        # Normalize score (assuming max score is 100)
        return review.get('点数', 50) / 100.0

    def prepare_for_finetuning(self, test_size: float = 0.1, val_size: float = 0.1):
        """Filter high-quality examples and build formatted train/val/test splits.

        Returns:
            (train, val, test) lists of {'text': prompt} dicts.
        """
        # Filter high-quality examples (score > 0.6)
        high_quality = [ex for ex in self.processed_data if ex['quality_score'] > 0.6]
        print(f"Selected {len(high_quality)} high-quality examples")

        # Create splits; val is carved out of the remaining train portion.
        train_data, test_data = train_test_split(high_quality, test_size=test_size, random_state=42)
        train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=42)

        # Format for fine-tuning
        def format_example(ex):
            prompt = f"""### 指示:
{ex['instruction']}

### コンテキスト:
{ex['context']}

### 入力:
{ex['input']}

### 応答:
{ex['output']}"""
            return {'text': prompt}

        train_formatted = [format_example(ex) for ex in train_data]
        val_formatted = [format_example(ex) for ex in val_data]
        test_formatted = [format_example(ex) for ex in test_data]

        return train_formatted, val_formatted, test_formatted
115
+
116
# Execute data preparation
def main() -> None:
    """Run the full preparation pipeline and persist the splits to disk."""
    processor = KokoroChatProcessor('KokoroChat/data')
    processor.load_all_conversations()
    processor.create_training_examples()
    train_data, val_data, test_data = processor.prepare_for_finetuning()

    # Save processed data
    import pickle
    with open('processed_data.pkl', 'wb') as f:
        pickle.dump({
            'train': train_data,
            'val': val_data,
            'test': test_data
        }, f)

    print(f"Training examples: {len(train_data)}")
    print(f"Validation examples: {len(val_data)}")
    print(f"Test examples: {len(test_data)}")


# BUG FIX: guard execution so importing this module (e.g. to reuse
# KokoroChatProcessor) no longer kicks off the whole pipeline — consistent
# with the __main__ guard used by the multi-turn preparation script.
if __name__ == "__main__":
    main()
dataprocessing_multiturn.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prepare_dataset_multiturn.py
2
+ import json
3
+ import os
4
+ from datasets import Dataset, Features, Value
5
+ import pandas as pd
6
+ from pathlib import Path
7
+
8
def parse_kokorochat_with_context(json_file_path, context_window=4, max_history_tokens=1500):
    """
    Parse one KokoroChat JSON file into training examples that carry
    conversation history.

    Args:
        json_file_path: Path to a single dialogue JSON file.
        context_window: Number of previous turns to include (4 = 2 exchanges).
        max_history_tokens: Approximate token budget for the history; examples
            whose history exceeds it are dropped.

    Returns:
        (examples, score): the extracted examples and the file's client
        review score. On any read/parse failure, ([], 0).
    """
    try:
        with open(json_file_path, 'r', encoding='utf-8') as fh:
            doc = json.load(fh)
    except Exception:
        return [], 0

    turns = doc.get('dialogue', [])
    score = doc.get('review_by_client_en', {}).get('score', 0)
    topic_info = doc.get('topic', {})
    main_topic = topic_info.get('main_en', '')
    sub_topic = topic_info.get('sub', '')

    # Rough heuristic: ~3 characters per token for Japanese text.
    char_budget = max_history_tokens * 3
    dialogue_id = Path(json_file_path).stem
    examples = []

    for idx, (cur, nxt) in enumerate(zip(turns, turns[1:])):
        # Only train on client -> counselor adjacency.
        if cur['role'] != 'client' or nxt['role'] != 'counselor':
            continue

        client_text = cur['utterance'].strip()
        counselor_text = nxt['utterance'].strip()

        # Skip trivially short utterances on either side.
        if len(client_text) <= 5 or len(counselor_text) <= 5:
            continue

        history = turns[max(0, idx - context_window):idx]

        # Drop examples whose history would blow the sequence budget.
        if len(''.join(t['utterance'] for t in history)) < char_budget:
            examples.append({
                'history': history,
                'client': client_text,
                'counselor': counselor_text,
                'quality_score': score,
                'topic_main': main_topic,
                'topic_sub': sub_topic,
                'dialogue_id': dialogue_id
            })

    return examples, score
64
+
65
def format_conversation_for_lfm2(conversation):
    """
    Render one conversation (history + current exchange) as an LFM2
    ChatML-style training string: system prompt, replayed history, then the
    client/counselor pair we actually train on.
    """
    role_tags = {'client': 'user', 'counselor': 'assistant'}

    pieces = [
        "<|im_start|>system\n"
        "あなたは経験豊富な心理カウンセラーです。クライアントの話を傾聴し、共感的で支援的な応答をしてください。<|im_end|>\n"
    ]

    # Replay the prior turns so the model sees realistic context.
    for turn in conversation['history']:
        tag = role_tags.get(turn['role'])
        if tag is not None:
            pieces.append(f"<|im_start|>{tag}\n{turn['utterance']}<|im_end|>\n")

    # The exchange being trained on.
    pieces.append(f"<|im_start|>user\n{conversation['client']}<|im_end|>\n")
    pieces.append(f"<|im_start|>assistant\n{conversation['counselor']}<|im_end|><|endoftext|>")

    return ''.join(pieces)
85
+
86
def create_training_dataset_multiturn(
    data_dir="./KokoroChat/data",
    min_score=70,
    context_window=4
):
    """
    Create training dataset with conversation context.

    Args:
        data_dir: Directory containing JSON files
        min_score: Minimum quality score (0-100, recommend 85 for top quality)
        context_window: Number of previous turns to include

    Returns:
        A train/test DatasetDict (also saved to disk), or None if no
        conversation cleared the quality bar.
    """
    files = list(Path(data_dir).rglob("*.json"))
    print(f"Found {len(files)} JSON files")

    kept = []    # examples from files that cleared min_score
    scores = []  # per-file scores, for the quality summary

    print("\nProcessing files with multi-turn context...")
    for n, path in enumerate(files):
        if n % 1000 == 0:
            print(f"Processed {n}/{len(files)} files...")

        try:
            examples, file_score = parse_kokorochat_with_context(
                path, context_window=context_window
            )
        except Exception:
            continue  # best-effort: skip any file the parser chokes on

        scores.append(file_score)
        if file_score >= min_score:
            kept.extend(examples)

    print(f"\n=== Processing Results ===")
    print(f"High-quality files (>= {min_score}): {sum(1 for s in scores if s >= min_score)}")
    print(f"Total conversation examples: {len(kept)}")

    if not kept:
        print(f"❌ No conversations found! Try lowering min_score (current: {min_score})")
        return None

    # Render each example into the LFM2 chat template plus metadata columns.
    rows = [
        {
            'text': format_conversation_for_lfm2(conv),
            'quality_score': conv['quality_score'],
            'topic_main': conv['topic_main'],
            'topic_sub': conv['topic_sub'],
            'has_context': len(conv['history']) > 0
        }
        for conv in kept
    ]

    schema = Features({
        'text': Value('string'),
        'quality_score': Value('int64'),
        'topic_main': Value('string'),
        'topic_sub': Value('string'),
        'has_context': Value('bool')
    })

    frame = pd.DataFrame(rows)
    ds = Dataset.from_pandas(frame, features=schema)
    ds = ds.train_test_split(test_size=0.1, seed=42)

    print(f"\n=== Final Dataset ===")
    print(f"Training samples: {len(ds['train'])}")
    print(f"Validation samples: {len(ds['test'])}")
    print(f"Examples with context: {sum(frame['has_context'])}")

    # Persist for the fine-tuning script.
    ds.save_to_disk("./kokorochat_processed_multiturn")
    print("\n✅ Multi-turn dataset saved to ./kokorochat_processed_multiturn")

    # Show sample
    print("\n=== Sample Training Example (with context) ===")
    sample = ds['train'][5]['text']
    print(sample[:1000] + "\n..." if len(sample) > 1000 else sample)

    return ds
171
+
172
if __name__ == "__main__":
    # min_score=60 keeps roughly the top 30% of sessions by client rating;
    # context_window=4 replays the 4 previous turns in each example.
    run_config = {
        "data_dir": "./KokoroChat/kokorochat_dialogues",
        "min_score": 60,
        "context_window": 4,
    }
    dataset = create_training_dataset_multiturn(**run_config)
finetune_lfm2.6b.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # finetune_lfm2_2.6b_FIXED.py
2
+ import torch
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForCausalLM,
6
+ TrainingArguments,
7
+ Trainer,
8
+ BitsAndBytesConfig,
9
+ GPT2Tokenizer
10
+ )
11
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
12
+ from datasets import load_from_disk
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, List
15
+ import wandb
16
+ import os
17
+ import warnings
18
# Silence library deprecation chatter for cleaner console output.
warnings.filterwarnings('ignore')

# --- Environment banner: PyTorch/CUDA/GPU sanity check ----------------------
print("=" * 80)
print("LFM2-2.6B FINE-TUNING - FIXED VERSION")
print("=" * 80)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

if torch.cuda.is_available():
    # Total VRAM in GiB. NOTE(review): a second, unguarded computation of the
    # same value happens later in this script before TrainingArguments.
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU Memory: {gpu_memory_gb:.1f} GB")

# Hard requirement for 4-bit QLoRA below; raises ImportError if missing.
import bitsandbytes as bnb
print("✅ BitsAndBytes OK")

# Initialize W&B
# NOTE(review): wandb.init prompts/fails without credentials — confirm the
# training host is logged in (or WANDB_MODE=offline is acceptable).
wandb.init(
    project="liquid-ai-hackathon-kokorochat",
    name="LFM2-2.6B-counselor-FIXED",
    config={
        "model": "LFM2-2.6B",
        "dataset": "KokoroChat-MultiTurn",
        "task": "psychological-counseling"
    }
)
44
+
45
+ print("\n" + "=" * 80)
46
+ print("LOADING MODEL (WITH FALLBACK)")
47
+ print("=" * 80)
48
+
49
+ LOCAL_MODEL_PATH = "./models/LFM2-2.6B"
50
+ HF_MODEL_NAME = "LiquidAI/LFM2-2.6B"
51
+
52
+ # 1. Load tokenizer with GPT2 fallback
53
+ print("\n1. Loading tokenizer...")
54
+ try:
55
+ tokenizer = AutoTokenizer.from_pretrained(
56
+ LOCAL_MODEL_PATH,
57
+ trust_remote_code=True,
58
+ local_files_only=True
59
+ )
60
+ print(" ✅ LFM2 tokenizer loaded!")
61
+ except Exception as e:
62
+ print(f" ⚠️ LFM2 tokenizer failed")
63
+ print(" 🔄 Using GPT2 tokenizer...")
64
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
65
+ print(" ✅ GPT2 tokenizer loaded!")
66
+
67
+ tokenizer.pad_token = tokenizer.eos_token
68
+ tokenizer.padding_side = "right"
69
+
70
+ # 2. QLoRA config
71
+ print("\n2. Configuring QLoRA...")
72
+ bnb_config = BitsAndBytesConfig(
73
+ load_in_4bit=True,
74
+ bnb_4bit_quant_type="nf4",
75
+ bnb_4bit_compute_dtype=torch.bfloat16,
76
+ bnb_4bit_use_double_quant=True,
77
+ )
78
+
79
# 3. Load model with proper fallback
print("\n3. Loading LFM2-2.6B model...")

# First, try to ensure we have the custom model files
print(" 📥 Checking for custom model files...")

# Check if modeling files exist — LFM2 ships custom architecture code;
# without these a local trust_remote_code load cannot build the model class.
custom_files = ["modeling_lfm2.py", "configuration_lfm2.py"]
has_custom_files = all(
    os.path.exists(os.path.join(LOCAL_MODEL_PATH, f))
    for f in custom_files
)

if not has_custom_files:
    print(" ⚠️ Custom model files missing in local directory")
    print(" 📥 Need to download from HuggingFace with custom code...")

    # Download with custom code
    from huggingface_hub import snapshot_download

    print(" ⏳ Downloading model with custom code (one-time)...")
    # Mirror the full repo into LOCAL_MODEL_PATH as real files (no symlinks).
    snapshot_download(
        repo_id=HF_MODEL_NAME,
        local_dir=LOCAL_MODEL_PATH,
        local_dir_use_symlinks=False,
        ignore_patterns=[]  # Don't ignore anything
    )
    print(" ✅ Model downloaded with custom code!")

# Now load the model
print(" ⏳ Loading model (~2-4 minutes)...")

try:
    # Try local first with trust_remote_code
    model = AutoModelForCausalLM.from_pretrained(
        LOCAL_MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,  # CRITICAL!
        torch_dtype=torch.bfloat16,
        local_files_only=False  # Allow downloading custom code if needed
    )
    print(" ✅ Model loaded from local!")

except Exception as e:
    # Truncate the error to keep the console readable, then retry from Hub.
    print(f" ⚠️ Local load failed: {str(e)[:100]}")
    print(" 📥 Loading directly from HuggingFace...")

    # Load from HuggingFace Hub
    model = AutoModelForCausalLM.from_pretrained(
        HF_MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    )
    print(" ✅ Model loaded from HuggingFace!")

# Casts norms / enables input grads etc. for stable k-bit (QLoRA) training.
model = prepare_model_for_kbit_training(model)
# KV cache is incompatible with gradient checkpointing during training.
model.config.use_cache = False
print(" ✅ Model prepared!")
140
+
141
# 4. LoRA - 2.6B configuration
print("\n4. Applying LoRA (2.6B config)...")
lora_config = LoraConfig(
    r=64,  # Higher for 2.6B
    lora_alpha=128,  # alpha = 2*r scaling
    # NOTE(review): assumes LFM2 names its attention/MLP projections like a
    # Llama-style model — confirm against the LFM2 module names.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Wrap the quantized base model with trainable LoRA adapters.
model = get_peft_model(model, lora_config)
print("\n📊 Trainable Parameters:")
model.print_trainable_parameters()

# 5. Load dataset — produced by the multi-turn preparation script
# (train/test splits saved via datasets.save_to_disk).
print("\n5. Loading dataset...")
dataset = load_from_disk("./kokorochat_processed_multiturn")
print(f" ✅ Training: {len(dataset['train']):,}, Val: {len(dataset['test']):,}")
161
+
162
# 6. Data Collator (same as 1.2B)
@dataclass
class DataCollatorForCausalLM:
    """Tokenize a batch of raw `text` examples for causal-LM training.

    Attributes:
        tokenizer: HF tokenizer (callable) used to encode the batch.
        max_length: Hard truncation length in tokens.
    """
    tokenizer: Any
    max_length: int = 2048

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        texts = [f["text"] for f in features]
        batch = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt"
        )
        batch["labels"] = batch["input_ids"].clone()
        # BUG FIX: the original masked labels wherever token == pad_token_id.
        # Because pad_token was set to eos_token earlier in this script, that
        # also masked every *real* EOS token, so the model would never learn
        # to stop generating. Mask on the attention mask instead: only true
        # padding positions receive the ignore index (-100).
        if "attention_mask" in batch:
            batch["labels"][batch["attention_mask"] == 0] = -100
        else:
            # Fallback for tokenizers that don't return an attention mask.
            batch["labels"][batch["labels"] == self.tokenizer.pad_token_id] = -100
        return batch
180
+
181
data_collator = DataCollatorForCausalLM(tokenizer=tokenizer)

# 7. Training Configuration - 2.6B optimized
print("\n6. Configuring training (2.6B optimized)...")

# NOTE(review): unguarded CUDA call — crashes on a CPU-only host even though
# the banner near the top of the script tolerated missing CUDA.
gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)

# Pick batch size / accumulation by available VRAM; effective batch is 32
# in both branches.
if gpu_memory_gb >= 70:
    per_device_batch = 2
    grad_accum = 16
    print(f" 🚀 {gpu_memory_gb:.0f}GB GPU → batch=2, accum=16")
else:
    per_device_batch = 1
    grad_accum = 32
    print(f" ⚡ {gpu_memory_gb:.0f}GB GPU → batch=1, accum=32")

training_args = TrainingArguments(
    output_dir="./lfm2-2.6b-checkpoints-fixed",

    # Batch (memory-adjusted for 2.6B)
    per_device_train_batch_size=per_device_batch,
    per_device_eval_batch_size=per_device_batch,
    gradient_accumulation_steps=grad_accum,

    # Learning (optimized for 2.6B)
    num_train_epochs=3,  # 2.6B learns faster
    learning_rate=2e-4,  # Lower for stability
    warmup_steps=200,
    lr_scheduler_type="cosine",

    # Optimization
    fp16=False,
    bf16=True,  # bf16 matches the bnb compute dtype above
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=5,  # keep disk usage bounded
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    optim="paged_adamw_8bit",  # paged optimizer avoids OOM spikes with QLoRA
    report_to="wandb",
    gradient_checkpointing=True,
    max_grad_norm=0.3,
    logging_dir="./logs",
    remove_unused_columns=False,  # custom collator reads the raw 'text' column
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
)
231
+
232
# Derived schedule stats for the banner below.
effective_batch = per_device_batch * grad_accum
steps_per_epoch = len(dataset['train']) // effective_batch
# NOTE(review): the literal 3 duplicates num_train_epochs=3 in
# TrainingArguments — keep the two in sync if the epoch count changes.
total_steps = steps_per_epoch * 3

print("\n" + "=" * 80)
print("📊 2.6B TRAINING CONFIGURATION")
print("=" * 80)
print(f"\n✅ Batch Config:")
print(f" Per-device: {per_device_batch}")
print(f" Gradient accum: {grad_accum}")
print(f" → Effective: {effective_batch}")

print(f"\n✅ Learning Config:")
print(f" Learning rate: 2e-4 (vs 3e-4 for 1.2B)")
print(f" Epochs: 3 (vs 4 for 1.2B)")
print(f" LoRA rank: 64 (vs 32 for 1.2B)")

print(f"\n✅ Training Stats:")
print(f" Training samples: {len(dataset['train']):,}")
print(f" Steps per epoch: {steps_per_epoch:,}")
print(f" Total steps: {total_steps:,}")

# Rough wall-clock estimate by GPU class (informational only).
print(f"\n⏱️ Estimated Time:")
if gpu_memory_gb >= 80:
    print(f" ~5-8 hours on {gpu_memory_gb:.0f}GB GPU")
else:
    print(f" ~8-12 hours on {gpu_memory_gb:.0f}GB GPU")
259
+
260
# 8. Trainer (same as 1.2B)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],  # 'test' split doubles as validation here
    data_collator=data_collator,
)

# 9. Start training
print("\n" + "=" * 80)
print("🚀 STARTING 2.6B TRAINING")
print("=" * 80)
print(f"📊 Monitor: https://wandb.ai/sandeeptechiot-ai/liquid-ai-hackathon-kokorochat\n")

try:
    trainer.train()
    print("\n✅ TRAINING COMPLETE!")

except KeyboardInterrupt:
    # Manual stop (Ctrl-C): persist the partial model, then fall through to
    # the normal save path below.
    print("\n⚠️ Interrupted - saving...")
    trainer.save_model("./lfm2-2.6b-interrupted")

except Exception as e:
    # Anything else is fatal: dump the traceback and re-raise so the run
    # fails loudly instead of saving a broken model.
    print(f"\n❌ Error: {e}")
    import traceback
    traceback.print_exc()
    raise

# 10. Save — full trainer output (+ tokenizer) and the LoRA adapter alone.
output_dir = "./lfm2-2.6b-counselor-final"
lora_dir = "./lfm2-2.6b-counselor-lora"

trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
model.save_pretrained(lora_dir)  # adapter-only copy for lightweight sharing

print(f"\n✅ Model saved to: {output_dir}")

wandb.finish()

print("\n" + "=" * 80)
print("🎉 2.6B TRAINING COMPLETE!")
print("=" * 80)
+ print("=" * 80)