codellama-fine-tuning / scripts /training /finetune_mistral7b.py
Prithvik-1's picture
Upload scripts/training/finetune_mistral7b.py with huggingface_hub
8514fc9 verified
#!/usr/bin/env python3
"""
Fine-tuning script for Mistral models (7B, 3B, etc.) using LoRA (Low-Rank Adaptation)
This script uses Hugging Face Transformers, PEFT, and BitsAndBytes for efficient training.
"""
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
BitsAndBytesConfig,
Trainer,
DataCollatorForLanguageModeling
)
from peft import (
LoraConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
TaskType,
)
import json
def get_device_info():
    """Probe the available torch backend and describe it as a dict.

    Backends are tried in the order CUDA, Apple-Silicon MPS, CPU, and a
    one-line status message is printed for the one that is selected.

    Returns:
        dict with the keys ``device`` and ``device_type`` (``"cuda"``,
        ``"mps"`` or ``"cpu"``), ``use_quantization`` (True only on CUDA) and
        ``dtype`` (``torch.float16`` on CUDA/MPS, ``torch.float32`` on CPU).
        On CUDA the dict also carries ``device_count`` and ``device_name``
        (the name of GPU index 0).
    """
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        print(f"✓ CUDA GPU detected: {gpu_name} (Count: {gpu_count})")
        return {
            "device": "cuda",
            "device_type": "cuda",
            "use_quantization": True,
            "dtype": torch.float16,
            "device_count": gpu_count,
            "device_name": gpu_name,
        }

    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("✓ Apple Silicon GPU (MPS) detected")
        return {
            "device": "mps",
            "device_type": "mps",
            "use_quantization": False,  # BitsAndBytes doesn't support MPS
            "dtype": torch.float16,
        }

    print("⚠ No GPU detected, using CPU (training will be very slow)")
    return {
        "device": "cpu",
        "device_type": "cpu",
        "use_quantization": False,
        "dtype": torch.float32,
    }
# Defaults for the --base-model, --output-dir and --dataset command-line options
DEFAULT_BASE_MODEL = "mistralai/Mistral-7B-v0.1"  # Hugging Face model id of the base model
DEFAULT_OUTPUT_DIR = "./mistral-finetuned"  # Directory the adapters and tokenizer are saved to
DEFAULT_DATASET_PATH = "./training_data.jsonl"  # Path to your training data (JSONL, one example per line)
# LoRA configuration applied when a fresh adapter is created (i.e. when no
# existing adapter path is supplied to load_and_prepare_model).
# Dropout was increased for regularization.
LORA_CONFIG = LoraConfig(
    r=16, # Rank of the low-rank update matrices
    lora_alpha=32, # LoRA alpha scaling parameter
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # Names of the submodules that receive adapters
    lora_dropout=0.1, # Increased from 0.05 to 0.1 for better regularization
    bias="none",  # Bias terms are left untouched
    task_type=TaskType.CAUSAL_LM,
)
# 4-bit BitsAndBytes quantization settings (only meaningful on CUDA)
def get_bitsandbytes_config():
    """Build the 4-bit quantization config, or return None without CUDA.

    Returns:
        A ``BitsAndBytesConfig`` requesting NF4 4-bit loading with double
        quantization and float16 compute when ``torch.cuda.is_available()``
        is true; ``None`` otherwise.
    """
    if not torch.cuda.is_available():
        return None

    quant_settings = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
    }
    return BitsAndBytesConfig(**quant_settings)
def load_and_prepare_model(model_name: str, adapter_path: str | None = None):
    """Load a causal-LM base model and tokenizer and wrap the model with LoRA.

    The loading strategy depends on the detected device: 4-bit BitsAndBytes
    quantization on CUDA, float16 weights on MPS, float32 weights on CPU.
    Gradient checkpointing is enabled on the returned model to save memory.

    Args:
        model_name: Hugging Face model id (or local path) of the base model.
        adapter_path: Optional path to an existing LoRA adapter.  When given,
            the adapter is loaded in trainable mode instead of creating a new
            one from ``LORA_CONFIG``; if the path is a directory the tokenizer
            is loaded from it as well.

    Returns:
        Tuple ``(model, tokenizer, device_info)`` where ``model`` is the
        PEFT-wrapped model, ``tokenizer`` has a pad token set, and
        ``device_info`` is the dict produced by ``get_device_info()``.
    """
    device_info = get_device_info()
    print(f"\nLoading model: {model_name}")

    # Prefer the tokenizer stored next to an existing adapter so that a
    # continued run uses the same tokenizer files as the previous one.
    tokenizer_source = adapter_path if adapter_path and os.path.isdir(adapter_path) else model_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    if tokenizer.pad_token is None:
        # Padding to a fixed length needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Quantization config is None whenever CUDA is unavailable.
    bnb_config = get_bitsandbytes_config()

    # NOTE(review): trust_remote_code=True lets the model repository execute
    # its own Python code at load time; only point this at repositories you trust.
    model_kwargs = {
        "trust_remote_code": True,
    }
    if bnb_config is not None:
        print("Using 4-bit quantization (CUDA)")
        model_kwargs["quantization_config"] = bnb_config
        model_kwargs["device_map"] = "auto"
    elif device_info["device_type"] == "mps":
        print(f"Using MPS device with {device_info['dtype']}")
        model_kwargs["torch_dtype"] = device_info["dtype"]
        model_kwargs["device_map"] = "auto"
    else:
        print("Using CPU (no quantization)")
        model_kwargs["torch_dtype"] = torch.float32
        model_kwargs["device_map"] = "cpu"

    base_model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    if bnb_config is not None:
        # Prepares the quantized model for training (only on the quantized path).
        base_model = prepare_model_for_kbit_training(base_model)
    else:
        # Gradient checkpointing is enabled below while every base weight is
        # frozen.  Without this call the checkpointed blocks receive inputs
        # that do not require grad, so no gradient reaches the LoRA weights.
        # prepare_model_for_kbit_training takes care of this on the quantized
        # path only.
        base_model.enable_input_require_grads()

    if adapter_path:
        print(f"Loading existing LoRA adapter from: {adapter_path}")
        model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
    else:
        model = get_peft_model(base_model, LORA_CONFIG)

    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print("Model loaded successfully!")
    print(f" - Device: {device_info['device']}")
    print(f" - Trainable parameters: {trainable_params:,}")
    print(f" - Total parameters: {total_params:,}")
    print(f" - Trainable ratio: {100 * trainable_params / total_params:.2f}%\n")
    return model, tokenizer, device_info
def load_training_data(file_path):
    """Load training records from a JSONL file (one JSON value per line).

    If ``file_path`` does not exist, a two-example sample dataset is written
    there first so the script can be run end-to-end as a demonstration.
    Blank or whitespace-only lines are skipped.  The file is read as UTF-8
    (a leading byte-order mark is tolerated) regardless of the platform's
    default encoding.

    Args:
        file_path: Path of the JSONL file to read (and create if missing).

    Returns:
        list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message names the file and the 1-based line number.
    """
    print(f"Loading training data from {file_path}")
    if not os.path.exists(file_path):
        print(f"Warning: {file_path} not found. Creating a sample dataset...")
        # Create a sample dataset for demonstration
        sample_data = [
            {"instruction": "What is AI?", "response": "AI (Artificial Intelligence) is the simulation of human intelligence by machines."},
            {"instruction": "Explain machine learning", "response": "Machine learning is a subset of AI that enables systems to learn from data."},
        ]
        with open(file_path, "w", encoding="utf-8") as f:
            for item in sample_data:
                f.write(json.dumps(item) + "\n")
        print(f"Sample dataset created at {file_path}")

    records = []
    # "utf-8-sig" reads plain UTF-8 and also strips a BOM if an editor added one.
    with open(file_path, "r", encoding="utf-8-sig") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Trailing newlines / blank separator lines are not records.
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type, but tell the user where the bad record is.
                raise json.JSONDecodeError(
                    f"{file_path}, line {line_no}: {err.msg}", err.doc, err.pos
                ) from err
    return records
def clean_completion(completion):
    """Strip the ``### Strict JSON ###`` / ``### End ###`` wrappers from a completion.

    Falsy input (``None``, ``""``) is returned unchanged.  Otherwise, if the
    start marker is present, only the text between its first and second
    occurrence (or the end of the string) is kept; anything from the first
    end marker onwards is dropped; the result is whitespace-stripped.
    """
    if not completion:
        return completion

    start_marker = "### Strict JSON ###"
    end_marker = "### End ###"

    cleaned = completion
    if start_marker in cleaned:
        # Keep the segment that follows the first start marker, up to the next one.
        after_start = cleaned.partition(start_marker)[2]
        cleaned = after_start.partition(start_marker)[0]
    # partition() returns the whole string as the head when the marker is absent.
    cleaned = cleaned.partition(end_marker)[0]
    return cleaned.strip()
def format_prompt(instruction, response=None):
    """Render an example in the ``### Instruction`` / ``### Response`` template.

    Args:
        instruction: Text placed under the ``### Instruction:`` header.
        response: Optional answer text.  When truthy it is passed through
            ``clean_completion`` and, if anything is left, appended after the
            ``### Response:`` header.

    Returns:
        The formatted prompt string; it ends right after the response header
        when there is no (non-empty) response.
    """
    # Clean response to remove format markers
    cleaned = clean_completion(response) if response else response
    header = f"### Instruction:\n{instruction}\n\n### Response:\n"
    return f"{header}{cleaned}" if cleaned else header
def tokenize_function(examples, tokenizer, max_length=512):
    """Tokenize a batch of instruction/response examples for causal-LM training.

    Args:
        examples: Batch mapping with parallel ``"instruction"`` and
            ``"response"`` sequences.
        tokenizer: Callable tokenizer; it is invoked with truncation, padding
            to ``max_length`` and ``return_tensors="pt"``.
        max_length: Fixed sequence length every example is truncated/padded to.

    Returns:
        The tokenizer output with an added ``"labels"`` entry: a copy of
        ``input_ids`` in which padding positions (``attention_mask == 0``)
        are set to -100 so that they are excluded from the loss.
    """
    texts = [format_prompt(inst, resp) for inst, resp in zip(examples["instruction"], examples["response"])]
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    labels = tokenized["input_ids"].clone()
    # Mask by attention mask rather than by token id: the pad token may be the
    # same id as EOS, and only real padding should be ignored by the loss.
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    return tokenized
def _message_list_to_pair(messages):
    """Collapse a chat-style ``messages`` list into an (instruction, response) pair."""
    if not isinstance(messages, list) or len(messages) == 0:
        raise KeyError("'messages' entries must be non-empty lists.")
    prompt_parts = []
    assistant_reply = None
    last_index = len(messages) - 1
    for idx, message in enumerate(messages):
        role = message.get("role", "user")
        content = str(message.get("content", "")).strip()
        if idx == last_index and role == "assistant":
            # A trailing assistant turn is the training target.
            assistant_reply = content
        else:
            prompt_parts.append(f"{role.upper()}: {content}")
    if assistant_reply is None:
        # NOTE(review): when the last turn is not from the assistant its text is
        # used as the target although it is also part of the prompt - confirm
        # that this fallback is intended.
        assistant_reply = str(messages[-1].get("content", "")).strip()
    return "\n\n".join(prompt_parts), assistant_reply


def _records_to_pairs(training_data):
    """Convert raw JSONL records into parallel instruction / response lists."""
    instructions = []
    responses = []
    for item in training_data:
        if "instruction" in item:
            instructions.append(item["instruction"])
            responses.append(item.get("response", ""))
        elif "prompt" in item and "completion" in item:
            instructions.append(item["prompt"])
            completion_value = item["completion"]
            if isinstance(completion_value, (dict, list)):
                # Structured completions are trained on as their JSON text.
                responses.append(json.dumps(completion_value))
            else:
                responses.append(str(completion_value))
        elif "messages" in item:
            prompt_text, assistant_reply = _message_list_to_pair(item["messages"])
            instructions.append(prompt_text)
            responses.append(assistant_reply)
        else:
            raise KeyError("Each training example must include either 'instruction'/'response', 'prompt'/'completion', or 'messages'.")
    return instructions, responses


def _plan_schedule(num_train_samples, per_device_batch_size, gradient_accumulation_steps, num_epochs):
    """Estimate the number of optimizer steps and a warmup length that fits inside it."""
    # Ceil division: a final partial batch / accumulation group still counts,
    # and even a tiny dataset yields at least one step per epoch.
    batches_per_epoch = -(-num_train_samples // per_device_batch_size)
    steps_per_epoch = max(1, -(-batches_per_epoch // gradient_accumulation_steps))
    total_steps = steps_per_epoch * num_epochs
    # 10% warmup with a floor of 10 steps, but never longer than the run itself.
    warmup_steps = min(total_steps, max(10, int(0.1 * total_steps)))
    return total_steps, warmup_steps


def main():
    """Command-line entry point: fine-tune a Mistral model with LoRA.

    Parses the CLI options, loads and validates the JSONL training data,
    loads the model, tokenizes the data, splits it 80/20 into train and
    validation sets, trains with evaluation and early stopping, and saves the
    adapters and tokenizer to the output directory.

    Raises:
        KeyError: if a training record has none of the supported layouts.
        ValueError: if fewer than two training examples are available (the
            train/validation split needs at least one example on each side).
    """
    import argparse

    from datasets import Dataset
    from transformers import EarlyStoppingCallback

    parser = argparse.ArgumentParser(description="Fine-tune Mistral models with LoRA")
    parser.add_argument("--base-model", default=DEFAULT_BASE_MODEL, help="HF model id (e.g. mistralai/Mistral-7B-v0.1 or mistralai/Mistral-3B-v0.1)")
    parser.add_argument("--adapter-path", default=None, help="Optional path to existing LoRA adapters to continue training")
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR, help="Where to write the fine-tuned adapters")
    parser.add_argument("--dataset", default=DEFAULT_DATASET_PATH, help="Path to training data JSONL")
    parser.add_argument("--max-length", type=int, default=512, help="Max sequence length for tokenization")
    args = parser.parse_args()

    print("Starting Mistral Fine-tuning with LoRA")
    print("=" * 50)
    print(f"Base model: {args.base_model}")
    print(f"Training data: {args.dataset}")
    print(f"Output dir: {args.output_dir}\n")

    # Load and validate the data first so that a malformed or too-small
    # dataset fails fast, before the expensive model load.
    training_data = load_training_data(args.dataset)
    instructions, responses = _records_to_pairs(training_data)
    if len(instructions) < 2:
        raise ValueError(
            f"Need at least 2 training examples for the train/validation split, got {len(instructions)}."
        )

    # Load model and tokenizer
    model, tokenizer, device_info = load_and_prepare_model(args.base_model, args.adapter_path)

    # Create a simple dataset dict
    dataset = Dataset.from_dict({
        "instruction": instructions,
        "response": responses,
    })

    # Tokenize dataset
    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        lambda batch: tokenize_function(batch, tokenizer, max_length=args.max_length),
        batched=True,
        remove_columns=dataset.column_names,
    )

    # Split dataset into train/validation (80/20)
    print("Splitting dataset into train/validation (80/20)...")
    train_val_split = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
    train_dataset = train_val_split["train"]
    val_dataset = train_val_split["test"]
    print(f" - Training samples: {len(train_dataset)}")
    print(f" - Validation samples: {len(val_dataset)}")

    # Training arguments - adjust based on device
    # NOTE(review): fp16 is also requested on MPS; whether mixed precision is
    # accepted there depends on the installed transformers/accelerate
    # versions - confirm on the target environment.
    use_fp16 = device_info["device_type"] in ["cuda", "mps"]
    num_epochs = 3
    per_device_batch_size = 2 if device_info["device_type"] != "cpu" else 1
    gradient_accumulation_steps = 4

    # Estimate total steps and pick an appropriate warmup
    total_steps, warmup_steps = _plan_schedule(
        len(train_dataset), per_device_batch_size, gradient_accumulation_steps, num_epochs
    )
    print("\nTraining Configuration:")
    print(f" - Total training steps: {total_steps}")
    print(f" - Warmup steps: {warmup_steps} ({100*warmup_steps/total_steps:.1f}% of training)")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,  # Dynamic warmup (10% of total steps, capped)
        learning_rate=5e-5,  # Reduced from 2e-4 to prevent overfitting
        weight_decay=0.01,  # Added L2 regularization
        fp16=use_fp16,  # Only enable on GPU (CUDA/MPS)
        bf16=False,  # Can enable for newer CUDA GPUs if needed
        logging_steps=10,
        save_steps=50,  # Save more frequently
        eval_strategy="steps",  # Enable evaluation
        eval_steps=50,  # Evaluate every 50 steps
        save_total_limit=3,
        load_best_model_at_end=True,  # Load best checkpoint based on validation loss
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        lr_scheduler_type="cosine",  # Cosine learning rate decay
        max_grad_norm=1.0,  # Gradient clipping
        report_to="none",
        push_to_hub=False,
        dataloader_pin_memory=device_info["device_type"] == "cuda",  # Only pin memory for CUDA
        remove_unused_columns=False,
    )
    print("Training Configuration:")
    print(f" - Device: {device_info['device']}")
    print(f" - Mixed precision (fp16): {use_fp16}")
    print(f" - Batch size: {training_args.per_device_train_batch_size}")
    print(f" - Gradient accumulation: {training_args.gradient_accumulation_steps}")
    print(f" - Learning rate: {training_args.learning_rate}")
    print(f" - Weight decay: {training_args.weight_decay}")
    print(f" - LR scheduler: {training_args.lr_scheduler_type}")
    print(f" - Max grad norm: {training_args.max_grad_norm}")
    print("=" * 50)

    # Data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    # Create trainer with validation set and early stopping
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,  # Add validation set
        data_collator=data_collator,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # Stop if no improvement for 3 evals
    )

    # Train
    print("\nStarting training...")
    trainer.train()

    # Save model
    print(f"\nSaving fine-tuned model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    # Save LoRA adapters separately
    model.save_pretrained(args.output_dir)

    print("\nFine-tuning complete!")
    print(f"Model saved to: {args.output_dir}")
    print(f"To load for inference, use the inference script with: {args.output_dir}")


if __name__ == "__main__":
    main()