"""
Fine-tuning Script for LFM2-2.6B with Complete Dialogue History
Follows the KokoroChat methodology: the entire conversation history is used as context for each training example
Filename: finetune_lfm_complete_history.py
"""
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
BitsAndBytesConfig,
TrainerCallback
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
TaskType,
PeftModel,
PeftConfig
)
from datasets import load_dataset, Dataset
import os
from typing import Dict, List, Optional
import numpy as np
from tqdm import tqdm
import json
import gc
import warnings
import wandb
from datetime import datetime
warnings.filterwarnings('ignore')
# Enable TF32 for H100 optimization
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
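# (TF32 speeds up float32 matmuls on Ampere/Hopper GPUs at a small, usually negligible, precision cost)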
class LFMKokoroChatFineTuner:
def __init__(
self,
model_name: str = "LiquidAI/LFM2-2.6B",
use_4bit: bool = False, # H100 has enough memory
max_seq_length: int = 2048 # Increased for complete dialogue history
):
"""
Initialize the fine-tuner for LFM models with complete dialogue history support
Args:
model_name: Name of the base model
use_4bit: Whether to use 4-bit quantization
max_seq_length: Maximum sequence length for complete dialogues
"""
self.model_name = model_name
self.use_4bit = use_4bit
self.max_seq_length = max_seq_length
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("="*80)
print("🚀 LFM Fine-tuning with Complete Dialogue History (KokoroChat Method)")
print("="*80)
print(f"Model: {model_name}")
print(f"Device: {self.device}")
print(f"Max sequence length: {max_seq_length}")
# GPU information
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
print(f"Number of GPUs: {num_gpus}")
for i in range(num_gpus):
gpu_name = torch.cuda.get_device_name(i)
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1e9
print(f" GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
# Initialize WandB
self.init_wandb()
def init_wandb(self):
"""Initialize WandB for experiment tracking"""
try:
run_name = f"lfm-kokoro-complete-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
wandb.init(
project="lfm-kokoro-complete-history",
name=run_name,
config={
"model_name": self.model_name,
"use_4bit_quantization": self.use_4bit,
"max_seq_length": self.max_seq_length,
"device": str(self.device),
"num_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"methodology": "Complete dialogue history (KokoroChat)",
"framework": "transformers + peft",
"task": "japanese_counseling"
},
tags=["counseling", "japanese", "lfm", "complete-history", "kokoro"]
)
print(f"✅ WandB initialized: {wandb.run.name}")
print(f"📊 View run at: {wandb.run.get_url()}")
self.wandb_enabled = True
except Exception as e:
print(f"⚠️ WandB initialization failed: {e}")
self.wandb_enabled = False
os.environ["WANDB_DISABLED"] = "true"
def setup_model_and_tokenizer(self):
"""Setup model with quantization and LoRA"""
print("\n📚 Setting up model and tokenizer...")
# Load tokenizer
print("Loading tokenizer...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
except Exception as e:
print(f"Tokenizer load failed ({e}); using fallback tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Set special tokens
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
if self.tokenizer.eos_token is None:
self.tokenizer.eos_token = "</s>"
self.tokenizer.pad_token = "</s>"
self.tokenizer.padding_side = "left" # Important for batch generation
# Quantization config
if self.use_4bit:
print("Setting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # BF16 for H100
bnb_4bit_use_double_quant=True
)
else:
bnb_config = None
# Load model
print(f"Loading model: {self.model_name}...")
model_kwargs = {
"trust_remote_code": True,
"torch_dtype": torch.bfloat16, # BF16 for H100
"device_map": "auto",
}
if bnb_config:
model_kwargs["quantization_config"] = bnb_config
try:
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
**model_kwargs
)
except Exception as e:
print(f"Error loading model: {e}")
print("Attempting without device_map...")
model_kwargs.pop("device_map", None)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
**model_kwargs
)
self.model = self.model.to(self.device)
# Enable gradient checkpointing
if hasattr(self.model, 'gradient_checkpointing_enable'):
self.model.gradient_checkpointing_enable()
# Prepare for k-bit training if using quantization
if self.use_4bit:
print("Preparing model for 4-bit training...")
self.model = prepare_model_for_kbit_training(self.model)
# LoRA configuration optimized for dialogue with complete history
print("Applying LoRA configuration...")
# Find target modules
target_modules = self.find_target_modules()
# Higher rank for complex dialogue understanding
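# (the LoRA update is scaled by lora_alpha / r; with alpha=128 and r=64 the effective scaling factor is 2.0)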
lora_config = LoraConfig(
r=64, # Increased for better dialogue understanding
lora_alpha=128,
target_modules=target_modules,
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)
# Apply LoRA
self.model = get_peft_model(self.model, lora_config)
# Print trainable parameters
trainable_params = 0
all_params = 0
for _, param in self.model.named_parameters():
all_params += param.numel()
if param.requires_grad:
trainable_params += param.numel()
trainable_percentage = 100 * trainable_params / all_params if all_params > 0 else 0
print(f"Trainable parameters: {trainable_params:,} / {all_params:,} ({trainable_percentage:.2f}%)")
# Log to WandB
if self.wandb_enabled:
wandb.config.update({
"lora_r": lora_config.r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"lora_target_modules": target_modules,
"total_parameters": all_params,
"trainable_parameters": trainable_params,
"trainable_percentage": trainable_percentage
})
self.model.print_trainable_parameters()
def find_target_modules(self):
"""Find linear modules to apply LoRA to"""
target_modules = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
names = name.split('.')
if len(names) > 0:
target_modules.append(names[-1])
# Remove duplicates
target_modules = list(set(target_modules))
# Common patterns for transformer models
common_targets = ["q_proj", "v_proj", "k_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
"fc1", "fc2", "query", "key", "value", "dense"]
# Filter to common targets
final_targets = [t for t in target_modules if any(ct in t.lower() for ct in common_targets)]
if not final_targets:
# Fallback to specific modules for LFM
final_targets = ["q_proj", "v_proj", "k_proj", "o_proj"]
print(f"LoRA target modules: {final_targets}")
return final_targets
def load_and_process_datasets(self, data_path: str):
"""
Load and process datasets with complete dialogue history
Handles the new data format with full conversation context
"""
print(f"\n📚 Loading datasets from {data_path}...")
# Check for dataset statistics
stats_file = os.path.join(data_path, 'dataset_stats.json')
if os.path.exists(stats_file):
with open(stats_file, 'r') as f:
stats = json.load(f)
print("Dataset statistics:")
print(f" Average dialogue history: {stats['dialogue_history_stats']['mean_length']:.1f} turns")
print(f" Max dialogue history: {stats['dialogue_history_stats']['max_length']} turns")
print(f" Median dialogue history: {stats['dialogue_history_stats']['median_length']:.1f} turns")
# Load datasets
train_data = []
val_data = []
# Load training data
train_file = os.path.join(data_path, 'train.jsonl')
with open(train_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Loading training data"):
item = json.loads(line)
train_data.append({
'text': item['text'],
'history_length': item.get('history_length', 0),
'score': item.get('score', 100),
'topic': item.get('topic', 'general')
})
# Load validation data
val_file = os.path.join(data_path, 'val.jsonl')
with open(val_file, 'r', encoding='utf-8') as f:
for line in tqdm(f, desc="Loading validation data"):
item = json.loads(line)
val_data.append({
'text': item['text'],
'history_length': item.get('history_length', 0),
'score': item.get('score', 100),
'topic': item.get('topic', 'general')
})
print(f"Loaded {len(train_data)} training examples")
print(f"Loaded {len(val_data)} validation examples")
# Analyze dialogue history lengths
train_history_lengths = [d['history_length'] for d in train_data]
val_history_lengths = [d['history_length'] for d in val_data]
print(f"\nDialogue history length distribution:")
print(f" Training - Mean: {np.mean(train_history_lengths):.1f}, Max: {max(train_history_lengths)}")
print(f" Validation - Mean: {np.mean(val_history_lengths):.1f}, Max: {max(val_history_lengths)}")
# Log to WandB
if self.wandb_enabled:
wandb.config.update({
"train_examples": len(train_data),
"val_examples": len(val_data),
"avg_train_history_length": float(np.mean(train_history_lengths)),
"max_train_history_length": int(max(train_history_lengths)),
"avg_val_history_length": float(np.mean(val_history_lengths)),
"max_val_history_length": int(max(val_history_lengths))
})
# Log history length distribution
wandb.log({
"train_history_distribution": wandb.Histogram(train_history_lengths),
"val_history_distribution": wandb.Histogram(val_history_lengths)
})
# Tokenize datasets
print("\nTokenizing datasets with complete dialogue history...")
print(f"Using max sequence length: {self.max_seq_length}")
# Extract texts for tokenization
train_texts = [d['text'] for d in train_data]
val_texts = [d['text'] for d in val_data]
# Tokenize with longer context for complete history
train_encodings = self.tokenize_texts(train_texts, desc="Tokenizing training data")
val_encodings = self.tokenize_texts(val_texts, desc="Tokenizing validation data")
# Create datasets
self.train_dataset = Dataset.from_dict(train_encodings)
self.val_dataset = Dataset.from_dict(val_encodings)
# Set format for PyTorch
self.train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
self.val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
# Clean up memory
del train_texts, val_texts, train_encodings, val_encodings, train_data, val_data
gc.collect()
print("✅ Datasets loaded and tokenized")
def tokenize_texts(self, texts: List[str], batch_size: int = 50, desc: str = "Tokenizing"):
"""
Tokenize texts in batches with support for longer sequences
"""
all_input_ids = []
all_attention_masks = []
# Process in smaller batches for long sequences
for i in tqdm(range(0, len(texts), batch_size), desc=desc):
batch_texts = texts[i:i + batch_size]
# Tokenize batch with longer max length
encodings = self.tokenizer(
batch_texts,
truncation=True,
padding='max_length',
max_length=self.max_seq_length,
return_tensors='pt'
)
# Convert to lists
all_input_ids.extend(encodings['input_ids'].tolist())
all_attention_masks.extend(encodings['attention_mask'].tolist())
# Create labels (same as input_ids for causal LM)
labels = all_input_ids.copy()
return {
'input_ids': all_input_ids,
'attention_mask': all_attention_masks,
'labels': labels
}
def setup_training_args(self, output_dir: str = "./lfm_kokoro_complete"):
"""Setup training arguments optimized for complete dialogue history"""
print("\n⚙️ Setting up training arguments...")
# Calculate batch sizes based on sequence length and GPU memory
if torch.cuda.is_available():
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
num_gpus = torch.cuda.device_count()
# Adjust batch size based on sequence length and GPU memory
if self.max_seq_length >= 2048:
if gpu_memory >= 80: # H100 80GB
batch_size = 4
gradient_accumulation = 4
elif gpu_memory >= 40:
batch_size = 2
gradient_accumulation = 8
else:
batch_size = 1
gradient_accumulation = 16
else:
batch_size = 8
gradient_accumulation = 2
# Adjust for multiple GPUs
if num_gpus > 1:
batch_size = batch_size * num_gpus
gradient_accumulation = max(1, gradient_accumulation // num_gpus)
else:
batch_size = 1
gradient_accumulation = 32
print(f"Batch configuration:")
print(f" Per device batch size: {batch_size}")
print(f" Gradient accumulation steps: {gradient_accumulation}")
print(f" Effective batch size: {batch_size * gradient_accumulation}")
# Update WandB config
if self.wandb_enabled:
wandb.config.update({
"batch_size": batch_size,
"gradient_accumulation_steps": gradient_accumulation,
"effective_batch_size": batch_size * gradient_accumulation,
"num_epochs": 3,
"learning_rate": 2e-4,
"warmup_ratio": 0.1,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"lr_scheduler": "cosine",
"optimizer": "adamw_torch"
})
self.training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=3,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation,
gradient_checkpointing=True,
warmup_ratio=0.1,
learning_rate=2e-4,
bf16=True, # Use BF16 for H100
tf32=True, # Enable TF32 for H100
logging_steps=10,
logging_first_step=True,
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=200,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
report_to="wandb" if self.wandb_enabled else "none",
run_name=wandb.run.name if self.wandb_enabled and wandb.run else "local_run",
optim="adamw_torch",
lr_scheduler_type="cosine",
weight_decay=0.01,
max_grad_norm=1.0,
remove_unused_columns=False,
label_names=["labels"],
dataloader_num_workers=4,
dataloader_pin_memory=True,
ddp_find_unused_parameters=False if torch.cuda.device_count() > 1 else None,
)
def train(self):
"""Execute training with complete dialogue history"""
print("\n🎯 Starting training with complete dialogue history...")
# Data collator
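# With mlm=False this prepares causal-LM batches: the collator derives labels from input_ids and sets
# pad positions to -100, so padding tokens are excluded from the loss.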
data_collator = DataCollatorForLanguageModeling(
tokenizer=self.tokenizer,
mlm=False,
pad_to_multiple_of=8
)
# Custom callback for metrics
class MetricsCallback(TrainerCallback):
def __init__(self, wandb_enabled):
self.wandb_enabled = wandb_enabled
def on_log(self, args, state, control, logs=None, **kwargs):
if logs and self.wandb_enabled:
# Add perplexity
if "loss" in logs:
logs["perplexity"] = np.exp(logs["loss"])
if "eval_loss" in logs:
logs["eval_perplexity"] = np.exp(logs["eval_loss"])
# Log to WandB
wandb.log(logs, step=state.global_step)
return control
# Initialize trainer
trainer = Trainer(
model=self.model,
args=self.training_args,
train_dataset=self.train_dataset,
eval_dataset=self.val_dataset,
data_collator=data_collator,
tokenizer=self.tokenizer,
callbacks=[MetricsCallback(self.wandb_enabled)] if self.wandb_enabled else [],
)
# Calculate total steps
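# (approximate: ignores drop_last rounding and, in multi-GPU runs, the data-parallel world size)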
total_steps = len(self.train_dataset) // (
self.training_args.per_device_train_batch_size *
self.training_args.gradient_accumulation_steps
) * self.training_args.num_train_epochs
print("="*60)
print("Training Information:")
print(f" Total training samples: {len(self.train_dataset)}")
print(f" Total validation samples: {len(self.val_dataset)}")
print(f" Total training steps: {total_steps}")
print(f" Max sequence length: {self.max_seq_length}")
print("="*60)
# Log training start
if self.wandb_enabled:
wandb.log({
"training_status": "started",
"total_steps": total_steps,
"max_seq_length": self.max_seq_length
})
try:
# Train
print("\n🚀 Training started...")
train_result = trainer.train()
# Save model
print("\n💾 Saving fine-tuned model...")
final_model_path = os.path.join(self.training_args.output_dir, "final_model")
trainer.save_model(final_model_path)
self.tokenizer.save_pretrained(final_model_path)
# Save training metrics
with open(os.path.join(self.training_args.output_dir, "training_metrics.json"), 'w') as f:
json.dump(train_result.metrics, f, indent=2)
# Final evaluation
print("\n📊 Running final evaluation...")
eval_results = trainer.evaluate()
# Save evaluation metrics
with open(os.path.join(self.training_args.output_dir, "eval_metrics.json"), 'w') as f:
json.dump(eval_results, f, indent=2)
# Log final metrics
if self.wandb_enabled:
wandb.run.summary.update({
"final_train_loss": train_result.metrics.get("train_loss", 0),
"final_eval_loss": eval_results.get("eval_loss", 0),
"final_eval_perplexity": np.exp(eval_results.get("eval_loss", 0)),
"total_training_time": train_result.metrics.get("train_runtime", 0),
"training_samples_per_second": train_result.metrics.get("train_samples_per_second", 0),
"training_status": "completed"
})
# Save model artifact
artifact = wandb.Artifact(
name=f"kokoro-model-complete-{wandb.run.id}",
type="model",
description="LFM model fine-tuned with complete dialogue history",
metadata={
"base_model": self.model_name,
"final_loss": float(eval_results.get("eval_loss", 0)),
"final_perplexity": float(np.exp(eval_results.get("eval_loss", 0))),
"max_seq_length": self.max_seq_length,
"methodology": "Complete dialogue history (KokoroChat)"
}
)
artifact.add_dir(final_model_path)
wandb.log_artifact(artifact)
print("\n" + "="*60)
print("✅ Training completed successfully!")
print(f"📁 Model saved to: {final_model_path}")
print(f"📉 Final eval loss: {eval_results.get('eval_loss', 0):.4f}")
print(f"📊 Final perplexity: {np.exp(eval_results.get('eval_loss', 0)):.2f}")
if self.wandb_enabled and wandb.run:
print(f"🔗 View results at: {wandb.run.get_url()}")
print("="*60)
return trainer
except Exception as e:
print(f"❌ Error during training: {e}")
if self.wandb_enabled:
wandb.run.summary["training_status"] = "failed"
wandb.run.summary["error"] = str(e)
# Save emergency checkpoint
try:
emergency_path = os.path.join(self.training_args.output_dir, "emergency_checkpoint")
self.model.save_pretrained(emergency_path)
self.tokenizer.save_pretrained(emergency_path)
print(f"💾 Emergency checkpoint saved to: {emergency_path}")
except Exception as save_error:
print(f"❌ Could not save emergency checkpoint: {save_error}")
raise e
finally:
if self.wandb_enabled:
wandb.finish()
def test_model_with_complete_history(model_path: str):
"""Test the fine-tuned model with complete dialogue history examples"""
print("\n" + "="*60)
print("🧪 Testing model with complete dialogue history")
print("="*60)
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
# Check if it's a PEFT model
adapter_config_path = os.path.join(model_path, "adapter_config.json")
if os.path.exists(adapter_config_path):
print("Loading as PEFT model...")
config = PeftConfig.from_pretrained(model_path)
base_model = AutoModelForCausalLM.from_pretrained(
config.base_model_name_or_path,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
model = PeftModel.from_pretrained(base_model, model_path)
else:
print("Loading as regular model...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map="auto",
local_files_only=True,
trust_remote_code=True
)
model.eval()
# Test with dialogue history examples
test_cases = [
{
"history": "クライアント: こんにちは。最近ストレスを感じています。\nカウンセラー: こんにちは。ストレスを感じていらっしゃるのですね。どのような状況でストレスを感じることが多いですか?\n",
"current": "クライアント: 仕事が忙しくて、休む時間がありません。"
},
{
"history": "",
"current": "クライアント: 人間関係で悩んでいます。"
}
]
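# Test cases are in Japanese (the target domain): case 1 continues an ongoing conversation about work
# stress, case 2 is an opening message about relationship worries with no prior history.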
print("Testing with complete dialogue history:\n")
for i, test_case in enumerate(test_cases, 1):
print(f"Test Case {i}:")
print("-" * 40)
# Format input with complete history
if test_case["history"]:
prompt = f"""### Instruction:
あなたは専門的な訓練を受けた心理カウンセラーです。
以下の完全な対話履歴を踏まえて、カウンセラーとして適切な応答を生成してください。
### Dialogue History:
{test_case["history"]}{test_case["current"]}
### Response:
"""
else:
prompt = f"""### Instruction:
あなたは専門的な訓練を受けた心理カウンセラーです。
### Dialogue History:
{test_case["current"]}
### Response:
"""
# Generate response
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
inputs = {k: v.cuda() if torch.cuda.is_available() else v for k, v in inputs.items()}
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=150,
temperature=0.7,  # must be > 0 when do_sample=True (0 raises an error); lower it for more deterministic replies
do_sample=True,
top_p=0.9,
pad_token_id=tokenizer.pad_token_id
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.split("### Response:")[-1].strip() if "### Response:" in response else response
# print(f"History Length: {len(test_case['history'].split('\\n')) if test_case['history'] else 0} turns")
print("History Length: {} turns".format(len(test_case['history'].split('\\n')) if test_case['history'] else 0))
print(f"Current Input: {test_case['current']}")
print(f"Generated Response: {response[:300]}...")
print()
print("="*60)
# Main execution
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Fine-tune LFM model with complete dialogue history')
parser.add_argument('--model_name', type=str, default='LiquidAI/LFM2-2.6B',
help='Base model name')
parser.add_argument('--data_path', type=str, default='./kokoro_processed_data',
help='Path to processed data with complete dialogue history')
parser.add_argument('--output_dir', type=str, default='./lfm_kokoro_complete',
help='Output directory for fine-tuned model')
parser.add_argument('--max_seq_length', type=int, default=2048,
help='Maximum sequence length for complete dialogues')
parser.add_argument('--use_4bit', action='store_true',
help='Use 4-bit quantization')
parser.add_argument('--test_only', action='store_true',
help='Only test existing model')
args = parser.parse_args()
if args.test_only:
# Test existing model
test_model_with_complete_history(
os.path.join(args.output_dir, "final_model")
)
else:
# Check CUDA availability
if not torch.cuda.is_available():
print("⚠️ Warning: CUDA is not available. Training will be slow.")
response = input("Continue? (y/n): ")
if response.lower() != 'y':
exit()
try:
# Clear GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Initialize fine-tuner
print(f"🚀 Initializing fine-tuner for complete dialogue history")
finetuner = LFMKokoroChatFineTuner(
model_name=args.model_name,
use_4bit=args.use_4bit,
max_seq_length=args.max_seq_length
)
# Setup model
finetuner.setup_model_and_tokenizer()
# Load datasets
finetuner.load_and_process_datasets(args.data_path)
# Setup training arguments
finetuner.setup_training_args(args.output_dir)
# Train
trainer = finetuner.train()
# Test the model
print("\n🧪 Testing the fine-tuned model...")
test_model_with_complete_history(
os.path.join(args.output_dir, "final_model")
)
print("\n✅ Fine-tuning with complete dialogue history completed!")
print(f"📁 Model saved to: {args.output_dir}/final_model")
print("\n📋 Next steps:")
print(f"1. Test more: python {__file__} --test_only --output_dir {args.output_dir}")
print("2. Run benchmarking with complete history support")
print("3. Deploy for production use")
except KeyboardInterrupt:
print("\n\n⚠️ Training interrupted by user.")
if wandb.run:
wandb.finish()
except Exception as e:
print(f"\n❌ Error: {e}")
import traceback
traceback.print_exc()
if wandb.run:
wandb.finish()