# lfm_complete_code / finetune_lfm.py
# Uploaded via huggingface_hub by Techiiot (commit 27c46c6, verified)
# import torch
# from transformers import (
# AutoModelForCausalLM,
# AutoTokenizer,
# TrainingArguments,
# Trainer,
# DataCollatorForLanguageModeling,
# BitsAndBytesConfig
# )
# from peft import (
# LoraConfig,
# get_peft_model,
# prepare_model_for_kbit_training,
# TaskType
# )
# from datasets import load_dataset, Dataset
# import os
# from typing import Dict, List, Optional
# import numpy as np
# from tqdm import tqdm
# import json
# import gc
# import warnings
# warnings.filterwarnings('ignore')
# class LFMCounselorFineTuner:
# def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
# """
# Initialize the fine-tuner for LFM models
# Args:
# model_name: Name of the base model
# use_4bit: Whether to use 4-bit quantization for memory efficiency
# """
# self.model_name = model_name
# self.use_4bit = use_4bit
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {self.device}")
# if torch.cuda.is_available():
# print(f"GPU: {torch.cuda.get_device_name(0)}")
# print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
# # Disable wandb for simplicity
# os.environ["WANDB_DISABLED"] = "true"
# def setup_model_and_tokenizer(self):
# """Setup model with quantization and LoRA"""
# print("Loading tokenizer...")
# # Tokenizer setup
# try:
# self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# except:
# # Fallback to a known working tokenizer if model-specific one fails
# print("Using fallback tokenizer...")
# self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
# # Add padding token if it doesn't exist
# if self.tokenizer.pad_token is None:
# self.tokenizer.pad_token = self.tokenizer.eos_token
# if self.tokenizer.eos_token is None:
# self.tokenizer.eos_token = "</s>"
# self.tokenizer.pad_token = "</s>"
# self.tokenizer.padding_side = "right"
# # Quantization config for memory efficiency
# if self.use_4bit:
# print("Setting up 4-bit quantization...")
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_quant_type="nf4",
# bnb_4bit_compute_dtype=torch.float16, # Use float16 for better compatibility
# bnb_4bit_use_double_quant=True
# )
# else:
# bnb_config = None
# # Load model
# print(f"Loading model: {self.model_name}...")
# try:
# self.model = AutoModelForCausalLM.from_pretrained(
# self.model_name,
# quantization_config=bnb_config,
# device_map="auto",
# trust_remote_code=True,
# torch_dtype=torch.float16
# )
# except Exception as e:
# print(f"Error loading model: {e}")
# print("Attempting to load without quantization...")
# self.model = AutoModelForCausalLM.from_pretrained(
# self.model_name,
# device_map="auto",
# trust_remote_code=True,
# torch_dtype=torch.float16,
# low_cpu_mem_usage=True
# )
# # Enable gradient checkpointing to save memory
# if hasattr(self.model, 'gradient_checkpointing_enable'):
# self.model.gradient_checkpointing_enable()
# # Prepare model for k-bit training
# if self.use_4bit:
# print("Preparing model for 4-bit training...")
# self.model = prepare_model_for_kbit_training(self.model)
# # LoRA configuration - optimized for counseling task
# print("Applying LoRA configuration...")
# # Find the target modules dynamically
# target_modules = self.find_target_modules()
# lora_config = LoraConfig(
# r=16, # Reduced rank for stability
# lora_alpha=32, # Alpha parameter for LoRA scaling
# target_modules=target_modules,
# lora_dropout=0.05,
# bias="none",
# task_type=TaskType.CAUSAL_LM,
# inference_mode=False
# )
# # Apply LoRA
# self.model = get_peft_model(self.model, lora_config)
# # Print trainable parameters
# self.model.print_trainable_parameters()
# def find_target_modules(self):
# """Find linear modules to apply LoRA to"""
# target_modules = []
# for name, module in self.model.named_modules():
# if isinstance(module, torch.nn.Linear):
# # Extract the module name
# names = name.split('.')
# if len(names) > 0:
# target_modules.append(names[-1])
# # Remove duplicates and filter common patterns
# target_modules = list(set(target_modules))
# # Common patterns for transformer models
# common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
# "gate_proj", "up_proj", "down_proj",
# "fc1", "fc2", "query", "key", "value", "dense"]
# # Filter to only include common targets if they exist
# final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
# # If no common targets found, use all linear layers
# if not final_targets:
# final_targets = target_modules[:6] # Limit to prevent too many parameters
# print(f"LoRA target modules: {final_targets}")
# return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
# def load_and_process_datasets(self, data_path: str):
# """Load and process datasets without multiprocessing issues"""
# print(f"Loading datasets from {data_path}...")
# # Load train dataset
# train_texts = []
# with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
# for line in tqdm(f, desc="Loading training data"):
# data = json.loads(line)
# train_texts.append(data['text'])
# # Load validation dataset
# val_texts = []
# with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
# for line in tqdm(f, desc="Loading validation data"):
# data = json.loads(line)
# val_texts.append(data['text'])
# print(f"Loaded {len(train_texts)} training examples")
# print(f"Loaded {len(val_texts)} validation examples")
# # Tokenize datasets in batches (avoiding multiprocessing)
# print("Tokenizing training dataset...")
# train_encodings = self.tokenize_texts(train_texts)
# print("Tokenizing validation dataset...")
# val_encodings = self.tokenize_texts(val_texts)
# # Create datasets
# self.train_dataset = Dataset.from_dict(train_encodings)
# self.val_dataset = Dataset.from_dict(val_encodings)
# # Set format for PyTorch
# self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# # Clean up memory
# del train_texts, val_texts, train_encodings, val_encodings
# gc.collect()
# def tokenize_texts(self, texts: List[str], batch_size: int = 100):
# """Tokenize texts in batches to avoid memory issues"""
# all_input_ids = []
# all_attention_masks = []
# for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
# batch_texts = texts[i:i + batch_size]
# # Tokenize batch
# encodings = self.tokenizer(
# batch_texts,
# truncation=True,
# padding='max_length',
# max_length=512,
# return_tensors='pt'
# )
# # Convert to lists
# all_input_ids.extend(encodings['input_ids'].tolist())
# all_attention_masks.extend(encodings['attention_mask'].tolist())
# # Create labels (same as input_ids for language modeling)
# labels = all_input_ids.copy()
# return {
# 'input_ids': all_input_ids,
# 'attention_mask': all_attention_masks,
# 'labels': labels
# }
# def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
# """Setup training arguments optimized for counseling task"""
# print("Setting up training arguments...")
# # Calculate batch sizes based on available memory
# if torch.cuda.is_available():
# gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
# if gpu_memory < 16: # Less than 16GB
# batch_size = 1
# gradient_accumulation = 16
# elif gpu_memory < 24: # Less than 24GB
# batch_size = 2
# gradient_accumulation = 8
# else: # 24GB or more
# batch_size = 4
# gradient_accumulation = 4
# else:
# batch_size = 1
# gradient_accumulation = 16
# print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
# self.training_args = TrainingArguments(
# output_dir=output_dir,
# num_train_epochs=3,
# per_device_train_batch_size=batch_size,
# per_device_eval_batch_size=batch_size,
# gradient_accumulation_steps=gradient_accumulation,
# gradient_checkpointing=True,
# warmup_steps=100,
# learning_rate=5e-5, # Conservative learning rate
# fp16=True,
# logging_steps=50,
# eval_strategy="steps",
# eval_steps=200,
# save_strategy="steps",
# save_steps=400,
# save_total_limit=2,
# load_best_model_at_end=True,
# metric_for_best_model="eval_loss",
# greater_is_better=False,
# report_to="none", # Disable all reporting
# push_to_hub=False,
# optim="adamw_torch", # Use standard optimizer
# lr_scheduler_type="linear",
# weight_decay=0.01,
# max_grad_norm=1.0,
# remove_unused_columns=False,
# label_names=["labels"],
# dataloader_num_workers=0, # Disable multiprocessing in dataloader
# dataloader_pin_memory=False, # Disable pinned memory to avoid issues
# )
# def train(self):
# """Execute training"""
# print("Initializing trainer...")
# # Data collator for language modeling
# data_collator = DataCollatorForLanguageModeling(
# tokenizer=self.tokenizer,
# mlm=False,
# pad_to_multiple_of=8
# )
# # Custom training to handle potential issues
# try:
# # Initialize trainer
# trainer = Trainer(
# model=self.model,
# args=self.training_args,
# train_dataset=self.train_dataset,
# eval_dataset=self.val_dataset,
# data_collator=data_collator,
# tokenizer=self.tokenizer,
# )
# # Start training
# print("="*50)
# print("Starting fine-tuning...")
# print("="*50)
# # Train with error handling
# train_result = trainer.train()
# # Save the final model
# print("\nSaving fine-tuned model...")
# trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
# self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
# # Save training metrics
# with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
# json.dump(train_result.metrics, f, indent=2)
# print("\n" + "="*50)
# print("Training completed successfully!")
# print(f"Model saved to: {self.training_args.output_dir}/final_model_2b")
# print("="*50)
# return trainer
# except Exception as e:
# print(f"Error during training: {e}")
# print("Attempting to save checkpoint...")
# # Try to save whatever we have
# try:
# self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
# self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
# print(f"Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
# except:
# print("Could not save emergency checkpoint")
# raise e
# def test_model(model_path: str, tokenizer_path: str):
# """Test the fine-tuned model with a sample input"""
# print("\n" + "="*50)
# print("Testing fine-tuned model...")
# print("="*50)
# # Load model and tokenizer
# from peft import PeftModel, PeftConfig
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# # Try to load as PEFT model
# try:
# config = PeftConfig.from_pretrained(model_path)
# model = AutoModelForCausalLM.from_pretrained(
# config.base_model_name_or_path,
# torch_dtype=torch.float16,
# device_map="auto"
# )
# model = PeftModel.from_pretrained(model, model_path)
# except:
# # Load as regular model
# model = AutoModelForCausalLM.from_pretrained(
# model_path,
# torch_dtype=torch.float16,
# device_map="auto"
# )
# model.eval()
# # Test input
# test_input = "こんにちは。最近ストレスを感じています。"
# # Generate response
# inputs = tokenizer(test_input, return_tensors="pt")
# inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
# with torch.no_grad():
# outputs = model.generate(
# **inputs,
# max_new_tokens=100,
# temperature=0.1,
# do_sample=True,
# top_p=0.9
# )
# response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(f"Input: {test_input}")
# print(f"Response: {response}")
# print("="*50)
# # Main training script
# if __name__ == "__main__":
# import argparse
# parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
# parser.add_argument('--model_name', type=str, default='gpt2', # Using GPT2 as fallback
# help='Base model name (use gpt2 if liquid model fails)')
# parser.add_argument('--data_path', type=str, default='./processed_data_score80',
# help='Path to processed data')
# parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
# help='Output directory for fine-tuned model')
# parser.add_argument('--use_4bit', action='store_true', default=False,
# help='Use 4-bit quantization (set to False for stability)')
# parser.add_argument('--test_only', action='store_true',
# help='Only test existing model')
# args = parser.parse_args()
# if args.test_only:
# # Test existing model
# test_model(
# f"{args.output_dir}/final_model_2b",
# f"{args.output_dir}/final_model_2b"
# )
# else:
# # Check if CUDA is available
# if not torch.cuda.is_available():
# print("Warning: CUDA is not available. Training will be very slow on CPU.")
# print("It's highly recommended to use a GPU for training.")
# response = input("Do you want to continue anyway? (y/n): ")
# if response.lower() != 'y':
# exit()
# try:
# # Clear GPU cache
# if torch.cuda.is_available():
# torch.cuda.empty_cache()
# # Initialize fine-tuner
# print(f"Initializing fine-tuner with model: {args.model_name}")
# finetuner = LFMCounselorFineTuner(
# model_name=args.model_name,
# use_4bit=args.use_4bit
# )
# # Setup model
# print("\nSetting up model and tokenizer...")
# finetuner.setup_model_and_tokenizer()
# # Load datasets (using new method without multiprocessing)
# print("\nLoading and processing datasets...")
# finetuner.load_and_process_datasets(args.data_path)
# # Setup training arguments
# print("\nSetting up training arguments...")
# finetuner.setup_training_args(args.output_dir)
# # Train
# trainer = finetuner.train()
# # Test the model
# print("\nTesting the fine-tuned model...")
# test_model(
# f"{args.output_dir}/final_model_2b",
# f"{args.output_dir}/final_model_2b"
# )
# print("\n✅ Fine-tuning completed successfully!")
# print(f"📁 Model saved to: {args.output_dir}/final_model_2b")
# print("\nNext steps:")
# print("1. Test more: python finetune_lfm.py --test_only")
# print("2. Run benchmarking: python benchmark_model.py")
# print("3. Optimize for mobile: python optimize_for_mobile.py")
# except KeyboardInterrupt:
# print("\n\nTraining interrupted by user.")
# print("Partial model may be saved in checkpoints.")
# except Exception as e:
# print(f"\n❌ Error during fine-tuning: {e}")
# import traceback
# traceback.print_exc()
# print("\nTroubleshooting tips:")
# print("1. Try reducing batch size")
# print("2. Try without 4-bit quantization: remove --use_4bit")
# print("3. Try with a smaller model like gpt2")
# print("4. Ensure you have enough GPU memory")
###### wandb login ######
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
BitsAndBytesConfig,
TrainerCallback
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
TaskType
)
from datasets import load_dataset, Dataset
import os
from typing import Dict, List, Optional
import numpy as np
from tqdm import tqdm
import json
import gc
import warnings
import wandb
from datetime import datetime
warnings.filterwarnings('ignore')
class LFMCounselorFineTuner:
    def __init__(self, model_name: str = "LiquidAI/LFM2-2.6B", use_4bit: bool = True):
        """
        Initialize the fine-tuner for LFM models.

        Selects the compute device, prints basic GPU diagnostics, and starts a
        Weights & Biases run for experiment tracking. If wandb cannot be
        initialized (no login, no network, ...), tracking is disabled via
        ``self.wandb_enabled`` and the WANDB_DISABLED env var, and training
        proceeds without it.

        Args:
            model_name: Name of the base model
            use_4bit: Whether to use 4-bit quantization for memory efficiency
        """
        self.model_name = model_name
        self.use_4bit = use_4bit
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        gpu_memory = 0  # in GB; stays 0 on CPU-only machines
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            print(f"GPU: {gpu_name}")
            print(f"GPU Memory: {gpu_memory:.2f} GB")
        # Initialize WandB (always enabled)
        try:
            # Create a unique run name with timestamp
            run_name = f"lfm-counselor-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
            # Initialize wandb with comprehensive config
            wandb.init(
                project="liquid-counselor-hackathon",
                name=run_name,
                config={
                    "model_name": model_name,
                    "use_4bit_quantization": use_4bit,
                    "device": str(self.device),
                    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
                    "gpu_memory_gb": gpu_memory,
                    "framework": "transformers",
                    "peft_method": "LoRA",
                    "task": "japanese_counseling",
                    "dataset": "KokoroChat"
                },
                tags=["counseling", "japanese", "lfm", "finetune", "hackathon"]
            )
            print(f"✅ WandB initialized: {wandb.run.name}")
            print(f"📊 View run at: {wandb.run.get_url()}")
            self.wandb_enabled = True
        except Exception as e:
            # Any wandb failure degrades gracefully to offline training.
            print(f"⚠️ WandB initialization failed: {e}")
            print("Continuing without WandB logging...")
            self.wandb_enabled = False
            os.environ["WANDB_DISABLED"] = "true"
def setup_model_and_tokenizer(self):
"""Setup model with quantization and LoRA"""
print("Loading tokenizer...")
# Tokenizer setup
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
except:
# Fallback to a known working tokenizer if model-specific one fails
print("Using fallback tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Add padding token if it doesn't exist
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.eos_token is None:
self.tokenizer.eos_token = "</s>"
self.tokenizer.pad_token = "</s>"
self.tokenizer.padding_side = "right"
# Quantization config for memory efficiency
if self.use_4bit:
print("Setting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True
)
else:
bnb_config = None
# Load model
print(f"Loading model: {self.model_name}...")
try:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16
)
except Exception as e:
print(f"Error loading model: {e}")
print("Attempting to load without quantization...")
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
# Enable gradient checkpointing to save memory
if hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
# Prepare model for k-bit training
if self.use_4bit:
print("Preparing model for 4-bit training...")
self.model = prepare_model_for_kbit_training(self.model)
# LoRA configuration - optimized for counseling task
print("Applying LoRA configuration...")
# Find the target modules dynamically
target_modules = self.find_target_modules()
lora_config = LoraConfig(
r=16, # Reduced rank for stability
lora_alpha=32, # Alpha parameter for LoRA scaling
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)
# Apply LoRA
self.model = get_peft_model(self.model, lora_config)
# Get trainable parameters info
trainable_params = 0
all_params = 0
for _, param in self.model.named_parameters():
all_params += param.numel()
if param.requires_grad:
trainable_params += param.numel()
trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
# Log model architecture to WandB
if self.wandb_enabled:
wandb.config.update({
"lora_r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"lora_target_modules": target_modules,
"total_parameters": all_params,
"trainable_parameters": trainable_params,
"trainable_percentage": trainable_percentage
})
self.model.print_trainable_parameters()
def find_target_modules(self):
"""Find linear modules to apply LoRA to"""
target_modules = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
# Extract the module name
names = name.split('.')
if len(names) > 0:
target_modules.append(names[-1])
# Remove duplicates and filter common patterns
target_modules = list(set(target_modules))
# Common patterns for transformer models
common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"fc1", "fc2", "query", "key", "value", "dense"]
# Filter to only include common targets if they exist
final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
# If no common targets found, use all linear layers
if not final_targets:
final_targets = target_modules[:6] # Limit to prevent too many parameters
print(f"LoRA target modules: {final_targets}")
return final_targets if final_targets else ["q_proj", "v_proj"] # Fallback
    def load_and_process_datasets(self, data_path: str):
        """Load train/validation JSONL files, log dataset stats, and tokenize.

        Expects ``{data_path}/train.jsonl`` and ``{data_path}/validation.jsonl``
        where each line is a JSON object with a required ``text`` field and
        optional ``score`` / ``topic`` fields. Populates ``self.train_dataset``
        and ``self.val_dataset`` as PyTorch-formatted ``datasets.Dataset``
        objects, logging score/topic distributions to WandB when enabled.

        Args:
            data_path: Directory containing the processed JSONL files.

        Raises:
            FileNotFoundError: if either JSONL file is missing.
            KeyError: if a line lacks the required 'text' field.
        """
        print(f"Loading datasets from {data_path}...")
        # Load train dataset
        train_texts = []
        train_scores = []
        train_topics = []
        with open(f'{data_path}/train.jsonl', 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading training data"):
                data = json.loads(line)
                train_texts.append(data['text'])
                train_scores.append(data.get('score', 0))  # 0 when score absent
                train_topics.append(data.get('topic', 'Unknown'))
        # Load validation dataset
        val_texts = []
        val_scores = []
        val_topics = []
        with open(f'{data_path}/validation.jsonl', 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading validation data"):
                data = json.loads(line)
                val_texts.append(data['text'])
                val_scores.append(data.get('score', 0))
                val_topics.append(data.get('topic', 'Unknown'))
        print(f"Loaded {len(train_texts)} training examples")
        print(f"Loaded {len(val_texts)} validation examples")
        # Log dataset statistics to WandB
        if self.wandb_enabled:
            # Calculate score statistics
            train_score_stats = {
                "train_examples": len(train_texts),
                "train_avg_score": float(np.mean(train_scores)),
                "train_min_score": float(np.min(train_scores)),
                "train_max_score": float(np.max(train_scores)),
                "train_std_score": float(np.std(train_scores))
            }
            val_score_stats = {
                "val_examples": len(val_texts),
                "val_avg_score": float(np.mean(val_scores)),
                "val_min_score": float(np.min(val_scores)),
                "val_max_score": float(np.max(val_scores)),
                "val_std_score": float(np.std(val_scores))
            }
            wandb.config.update(train_score_stats)
            wandb.config.update(val_score_stats)
            # Log score distribution histogram
            wandb.log({
                "train_score_distribution": wandb.Histogram(train_scores),
                "val_score_distribution": wandb.Histogram(val_scores)
            })
            # Log topic distribution (counted by hand; training topics only)
            train_topic_counts = {}
            for topic in train_topics:
                train_topic_counts[topic] = train_topic_counts.get(topic, 0) + 1
            # Create a bar chart for topics (top 20)
            if len(train_topic_counts) > 0:
                top_topics = sorted(train_topic_counts.items(), key=lambda x: x[1], reverse=True)[:20]
                wandb.log({
                    "topic_distribution": wandb.plot.bar(
                        wandb.Table(data=[[k, v] for k, v in top_topics],
                                    columns=["Topic", "Count"]),
                        "Topic", "Count", title="Training Topic Distribution (Top 20)"
                    )
                })
        # Tokenize datasets in batches (avoiding multiprocessing)
        print("Tokenizing training dataset...")
        train_encodings = self.tokenize_texts(train_texts)
        print("Tokenizing validation dataset...")
        val_encodings = self.tokenize_texts(val_texts)
        # Create datasets
        self.train_dataset = Dataset.from_dict(train_encodings)
        self.val_dataset = Dataset.from_dict(val_encodings)
        # Set format for PyTorch
        self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        # Clean up memory: the raw texts/encodings can be large
        del train_texts, val_texts, train_encodings, val_encodings
        gc.collect()
def tokenize_texts(self, texts: List[str], batch_size: int = 100):
"""Tokenize texts in batches to avoid memory issues"""
all_input_ids = []
all_attention_masks = []
for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing"):
batch_texts = texts[i:i + batch_size]
# Tokenize batch
encodings = self.tokenizer(
batch_texts,
truncation=True,
padding='max_length',
max_length=512,
return_tensors='pt'
)
# Convert to lists
all_input_ids.extend(encodings['input_ids'].tolist())
all_attention_masks.extend(encodings['attention_mask'].tolist())
# Create labels (same as input_ids for language modeling)
labels = all_input_ids.copy()
return {
'input_ids': all_input_ids,
'attention_mask': all_attention_masks,
'labels': labels
}
    def setup_training_args(self, output_dir: str = "./counselor_model_2b"):
        """Build ``TrainingArguments``, sizing the batch to available GPU memory.

        Batch-size / gradient-accumulation tiers (effective batch is always 16):
          * CPU or GPU < 16 GB -> batch 1, accumulation 16
          * GPU < 24 GB        -> batch 2, accumulation 8
          * GPU >= 24 GB       -> batch 4, accumulation 4

        Stores the result in ``self.training_args`` and mirrors the
        hyperparameters into the WandB config when tracking is enabled.

        Args:
            output_dir: Directory for checkpoints and the final model.
        """
        print("Setting up training arguments...")
        # Calculate batch sizes based on available memory
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            if gpu_memory < 16:  # Less than 16GB
                batch_size = 1
                gradient_accumulation = 16
            elif gpu_memory < 24:  # Less than 24GB
                batch_size = 2
                gradient_accumulation = 8
            else:  # 24GB or more
                batch_size = 4
                gradient_accumulation = 4
        else:
            batch_size = 1
            gradient_accumulation = 16
        print(f"Using batch_size={batch_size}, gradient_accumulation={gradient_accumulation}")
        # Update WandB config with training hyperparameters
        if self.wandb_enabled:
            wandb.config.update({
                "batch_size": batch_size,
                "gradient_accumulation_steps": gradient_accumulation,
                "effective_batch_size": batch_size * gradient_accumulation,
                "num_epochs": 3,
                "learning_rate": 5e-5,
                "warmup_steps": 100,
                "weight_decay": 0.01,
                "max_grad_norm": 1.0,
                "lr_scheduler": "linear",
                "optimizer": "adamw_torch",
                "fp16": True,
                "max_length": 512
            })
        # Set report_to based on wandb availability
        report_to = "wandb" if self.wandb_enabled else "none"
        self.training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation,
            gradient_checkpointing=True,  # trade compute for activation memory
            warmup_steps=100,
            learning_rate=5e-5,  # conservative LR for stable fine-tuning
            fp16=True,
            logging_steps=50,
            logging_first_step=True,
            eval_strategy="steps",
            eval_steps=200,
            save_strategy="steps",
            save_steps=400,  # multiple of eval_steps (needed by load_best_model_at_end — verify against transformers version)
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            report_to=report_to,
            run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
            push_to_hub=False,
            optim="adamw_torch",
            lr_scheduler_type="linear",
            weight_decay=0.01,
            max_grad_norm=1.0,
            remove_unused_columns=False,
            label_names=["labels"],
            dataloader_num_workers=0,  # disable multiprocessing in the dataloader
            dataloader_pin_memory=False,  # disable pinned memory to avoid issues
        )
def train(self):
"""Execute training"""
print("Initializing trainer...")
# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False,
pad_to_multiple_of=8
)
# Custom callback for additional metrics (properly inheriting from TrainerCallback)
class CustomMetricsCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if logs and self.wandb_enabled:
# Add perplexity metrics
if "loss" in logs:
logs["perplexity"] = np.exp(logs["loss"])
if "eval_loss" in logs:
logs["eval_perplexity"] = np.exp(logs["eval_loss"])
return control
# Create callback instance with wandb_enabled flag
custom_callback = CustomMetricsCallback()
custom_callback.wandb_enabled = self.wandb_enabled
# Custom training to handle potential issues
try:
# Initialize trainer with callbacks
trainer = Trainer(
model=self.model,
args=self.training_args,
train_dataset=self.train_dataset,
eval_dataset=self.val_dataset,
data_collator=data_collator,
tokenizer=self.tokenizer,
callbacks=[custom_callback] if self.wandb_enabled else [],
)
# Calculate total training steps
total_steps = len(self.train_dataset) // (self.training_args.per_device_train_batch_size * self.training_args.gradient_accumulation_steps) * self.training_args.num_train_epochs
# Start training
print("="*50)
print("Starting fine-tuning...")
print(f"Total training samples: {len(self.train_dataset)}")
print(f"Total validation samples: {len(self.val_dataset)}")
print(f"Total training steps: {total_steps}")
print("="*50)
# Log training start
if self.wandb_enabled:
wandb.log({"training_status": "started", "total_steps": total_steps})
# Train with error handling
train_result = trainer.train()
# Save the final model
print("\nSaving fine-tuned model...")
trainer.save_model(f"{self.training_args.output_dir}/final_model_2b")
self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/final_model_2b")
# Save training metrics
with open(f"{self.training_args.output_dir}/training_metrics.json", 'w') as f:
json.dump(train_result.metrics, f, indent=2)
# Final evaluation
print("\nRunning final evaluation...")
eval_results = trainer.evaluate()
# Save evaluation metrics
with open(f"{self.training_args.output_dir}/eval_metrics.json", 'w') as f:
json.dump(eval_results, f, indent=2)
# Log final metrics to WandB
if self.wandb_enabled:
# Log final metrics
wandb.run.summary.update({
"final_train_loss": train_result.metrics.get("train_loss", 0),
"final_eval_loss": eval_results.get("eval_loss", 0),
"final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
"total_training_time": train_result.metrics.get("train_runtime", 0),
"training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
"training_status": "completed"
})
# Create a summary table
summary_table = wandb.Table(
columns=["Metric", "Value"],
data=[
["Final Training Loss", f"{train_result.metrics.get('train_loss', 0):.4f}"],
["Final Eval Loss", f"{eval_results.get('eval_loss', 0):.4f}"],
["Final Perplexity", f"{np.exp(eval_results.get('eval_loss', 0)):.2f}"],
["Training Time (seconds)", f"{train_result.metrics.get('train_runtime', 0):.0f}"],
["Training Samples/Second", f"{train_result.metrics.get('train_samples_per_second', 0):.2f}"]
]
)
wandb.log({"training_summary": summary_table})
# Save model artifact
try:
artifact = wandb.Artifact(
name=f"counselor-model-{wandb.run.id}",
type="model",
description="Fine-tuned Japanese counseling model",
metadata={
"base_model": self.model_name,
"final_loss": float(eval_results.get("eval_loss", 0)),
"final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
"dataset": "KokoroChat"
}
)
artifact.add_dir(f"{self.training_args.output_dir}/final_model_2b")
wandb.log_artifact(artifact)
except Exception as e:
print(f"Warning: Could not save model artifact: {e}")
print("\n" + "="*50)
print("✅ Training completed successfully!")
print(f"📁 Model saved to: {self.training_args.output_dir}/final_model_2b")
print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
if self.wandb_enabled and wandb.run:
print(f"🔗 View results at: {wandb.run.get_url()}")
print("="*50)
return trainer
except Exception as e:
print(f"❌ Error during training: {e}")
# Log error to WandB
if self.wandb_enabled:
wandb.run.summary["training_status"] = "failed"
wandb.run.summary["error"] = str(e)
print("Attempting to save checkpoint...")
# Try to save whatever we have
try:
self.model.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
self.tokenizer.save_pretrained(f"{self.training_args.output_dir}/checkpoint_emergency")
print(f"💾 Emergency checkpoint saved to: {self.training_args.output_dir}/checkpoint_emergency")
except:
print("❌ Could not save emergency checkpoint")
raise e
finally:
# Ensure WandB run is finished
if self.wandb_enabled:
wandb.finish()
# def test_model(model_path: str, tokenizer_path: str):
# """Test the fine-tuned model with sample inputs"""
# print("\n" + "="*50)
# print("Testing fine-tuned model...")
# print("="*50)
# # Load model and tokenizer
# from peft import PeftModel, PeftConfig
# tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
# if tokenizer.pad_token is None:
# tokenizer.pad_token = tokenizer.eos_token
# # Try to load as PEFT model
# try:
# config = PeftConfig.from_pretrained(model_path)
# model = AutoModelForCausalLM.from_pretrained(
# config.base_model_name_or_path,
# torch_dtype=torch.float16,
# device_map="auto"
# )
# model = PeftModel.from_pretrained(model, model_path)
# except:
# # Load as regular model
# model = AutoModelForCausalLM.from_pretrained(
# model_path,
# torch_dtype=torch.float16,
# device_map="auto"
# )
# model.eval()
# # Test inputs
# test_cases = [
# "こんにちは。最近ストレスを感じています。",
# "仕事がうまくいかなくて悩んでいます。",
# "人間関係で困っています。どうすればいいでしょうか。"
# ]
# print("Sample conversations:")
# print("-" * 50)
def test_model(model_path: str, tokenizer_path: str):
    """Load a fine-tuned (optionally LoRA/PEFT) model and print sample counseling replies.

    Args:
        model_path: Directory containing either a PEFT adapter
            (detected via ``adapter_config.json``) or a full model checkpoint.
        tokenizer_path: Directory expected to hold tokenizer files
            (``tokenizer_config.json``); falls back to the GPT-2 tokenizer
            when the files are missing or fail to load.

    Raises:
        Exception: re-raised when the model itself cannot be loaded
            (tokenizer failures fall back to GPT-2 instead of raising).
    """
    # Load model and tokenizer with proper local path handling
    from peft import PeftModel, PeftConfig
    import os
    print("\n" + "="*50)
    print("Testing fine-tuned model...")
    print("="*50)
    # --- Tokenizer: prefer local files; any failure falls back to GPT-2 ---
    try:
        if os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
            print(f"Loading tokenizer from {tokenizer_path}")
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        else:
            print(f"Tokenizer not found at {tokenizer_path}, using base model tokenizer")
            # Fallback to base model tokenizer
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        print("Using fallback GPT-2 tokenizer")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # --- Model: PEFT adapter if adapter_config.json exists, else full checkpoint ---
    try:
        adapter_config_path = os.path.join(model_path, "adapter_config.json")
        if os.path.exists(adapter_config_path):
            print("Loading as PEFT model...")
            config = PeftConfig.from_pretrained(model_path)
            base_model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            model = PeftModel.from_pretrained(base_model, model_path)
        else:
            print("Loading as regular model...")
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                local_files_only=True,
                trust_remote_code=True
            )
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    model.eval()
    # Japanese sample client utterances used as generation prompts.
    test_cases = [
        "こんにちは。最近ストレスを感じています。",
        "仕事がうまくいかなくて悩んでいます。",
        "人間関係で困っています。どうすればいいでしょうか。"
    ]
    print("Sample conversations:")
    print("-" * 50)
    # BUG FIX: the original function contained this identical generation loop
    # twice back-to-back, so every prompt was generated and printed twice.
    # Run it exactly once.
    for test_input in test_cases:
        inputs = tokenizer(test_input, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.1,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the echoed prompt by character count.
        # NOTE(review): assumes the decoded text starts with the prompt verbatim —
        # confirm for tokenizers that normalize whitespace/unicode on decode.
        response = response[len(test_input):].strip()
        print(f"Client: {test_input}")
        print(f"Counselor: {response[:200]}...")
        print("-" * 50)
    print("="*50)
# Main training script
if __name__ == "__main__":
    import argparse
    # BUG FIX: the trainer saves the final model under "final_model_2b"
    # (see the save/print in LFMCounselorFineTuner.train), but the original
    # script tested "final_model_2b" in --test_only mode and
    # "final_model_2b_v2" after training. Keep one constant so every code
    # path agrees on the saved-model location.
    FINAL_MODEL_SUBDIR = "final_model_2b"

    parser = argparse.ArgumentParser(description='Fine-tune LFM model for counseling')
    parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
                        help='Base model name')
    parser.add_argument('--data_path', type=str, default='./processed_data_score80',
                        help='Path to processed data')
    parser.add_argument('--output_dir', type=str, default='./counselor_model_2b',
                        help='Output directory for fine-tuned model')
    parser.add_argument('--use_4bit', action='store_true', default=False,
                        help='Use 4-bit quantization')
    parser.add_argument('--wandb_api_key', type=str, default=None,
                        help='WandB API key (optional, can use wandb login instead)')
    parser.add_argument('--test_only', action='store_true',
                        help='Only test existing model')
    args = parser.parse_args()

    # Set WandB API key if provided (otherwise rely on `wandb login`)
    if args.wandb_api_key:
        os.environ["WANDB_API_KEY"] = args.wandb_api_key

    final_model_path = f"{args.output_dir}/{FINAL_MODEL_SUBDIR}"

    if args.test_only:
        # Test existing model; adapter and tokenizer live in the same directory.
        test_model(final_model_path, final_model_path)
    else:
        # Check if CUDA is available — CPU training is technically possible but
        # impractically slow, so ask for explicit confirmation.
        if not torch.cuda.is_available():
            print("⚠️ Warning: CUDA is not available. Training will be very slow on CPU.")
            print("It's highly recommended to use a GPU for training.")
            response = input("Do you want to continue anyway? (y/n): ")
            if response.lower() != 'y':
                exit()
        try:
            # Clear GPU cache before loading the model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Initialize fine-tuner (WandB is enabled by default)
            print(f"🚀 Initializing fine-tuner with model: {args.model_name}")
            finetuner = LFMCounselorFineTuner(
                model_name=args.model_name,
                use_4bit=args.use_4bit
            )
            # Setup model
            print("\n🔧 Setting up model and tokenizer...")
            finetuner.setup_model_and_tokenizer()
            # Load datasets
            print("\n📚 Loading and processing datasets...")
            finetuner.load_and_process_datasets(args.data_path)
            # Setup training arguments
            print("\n⚙️ Setting up training arguments...")
            finetuner.setup_training_args(args.output_dir)
            # Train
            trainer = finetuner.train()
            # Test the freshly saved model (same path the trainer saved to)
            print("\n🧪 Testing the fine-tuned model...")
            test_model(final_model_path, final_model_path)
            print("\n✅ Fine-tuning completed successfully!")
            print(f"📁 Model saved to: {final_model_path}")
            print("\n📋 Next steps:")
            print("1. Test more: python finetune_lfm.py --test_only")
            print("2. Run benchmarking: python benchmark_model.py")
            print("3. Optimize for mobile: python optimize_for_mobile.py")
        except KeyboardInterrupt:
            print("\n\n⚠️ Training interrupted by user.")
            print("Partial model may be saved in checkpoints.")
            # NOTE(review): assumes `wandb` is imported at module level (it is
            # referenced the same way inside train()) — confirm the import exists.
            if wandb.run:
                wandb.finish()
        except Exception as e:
            print(f"\n❌ Error during fine-tuning: {e}")
            import traceback
            traceback.print_exc()
            if wandb.run:
                wandb.finish()