"""Fine-tuning entry point for the AgriQA-Assistant project (src/training/finetune.py)."""
import os
import json
import yaml
import argparse
import logging
from typing import Dict, Any
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from datasets import Dataset
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class AgriQAFineTuner:
    """LoRA fine-tuning pipeline for the AgriQA assistant, driven by a YAML config."""

    def __init__(self, config_path: str):
        self.config = self.load_config(config_path)  # load the config file
        self.setup_environment()

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Read the YAML training configuration into a dictionary."""
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        return config
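    # The configuration file is expected to contain the sections read throughout
    # this class: model, hardware, lora, data, training, logging and generation.
    # A minimal sketch of that layout follows; the section and key names are the
    # ones this module actually accesses, while every value shown is an
    # illustrative placeholder, not the project's real setting.
    #
    #   model:
    #     base_model: Qwen/Qwen2-1.5B-Instruct    # assumed example checkpoint
    #     trust_remote_code: true
    #   hardware:
    #     use_4bit: true
    #     bnb_4bit_quant_type: nf4
    #     bnb_4bit_use_double_quant: true
    #     device_map: auto
    #   lora:
    #     r: 16
    #     lora_alpha: 32
    #     target_modules: [q_proj, k_proj, v_proj, o_proj]
    #     lora_dropout: 0.05
    #     bias: none
    #     task_type: CAUSAL_LM
    #   data:
    #     tokenized_dir: data/tokenized
    #     max_samples: null
    #   training:
    #     output_dir: outputs/agriqa-lora
    #     num_train_epochs: 3
    #     learning_rate: 2e-4
    #     # ...plus the remaining keys passed to TrainingArguments below
    #   logging:
    #     report_to: none
    #     run_name: agriqa-finetune
    #     log_level: info
    #   generation: {}    # stored alongside the model in save_model()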
    def setup_environment(self) -> None:
        """Select the compute device and create the output directory."""
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Create output directory
        os.makedirs(self.config['training']['output_dir'], exist_ok=True)
    def load_model_and_tokenizer(self):
        """Load the base model and tokenizer, with optional 4-bit quantization."""
        logger.info(f"Loading model: {self.config['model']['base_model']}")
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['model']['base_model'],
            trust_remote_code=self.config['model']['trust_remote_code']
        )
        # Add padding token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load model with quantization if specified
        if self.config['hardware']['use_4bit']:
            logger.info("Loading model with 4-bit quantization")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type=self.config['hardware']['bnb_4bit_quant_type'],
                bnb_4bit_use_double_quant=self.config['hardware']['bnb_4bit_use_double_quant'],
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                quantization_config=quantization_config,
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )
        # Prepare the quantized model for k-bit (QLoRA-style) training
        if self.config['hardware']['use_4bit']:
            self.model = prepare_model_for_kbit_training(self.model)
        logger.info("Model and tokenizer loaded successfully")
    def setup_lora(self):
        """Attach LoRA adapters to the base model."""
        logger.info("Setting up LoRA configuration")
        lora_config = LoraConfig(
            r=self.config['lora']['r'],
            lora_alpha=self.config['lora']['lora_alpha'],
            target_modules=self.config['lora']['target_modules'],
            lora_dropout=self.config['lora']['lora_dropout'],
            bias=self.config['lora']['bias'],
            task_type=self.config['lora']['task_type'],
        )
        # Enable gradient checkpointing for memory optimization
        if self.config['training']['gradient_checkpointing']:
            self.model.gradient_checkpointing_enable()
            logger.info("Gradient checkpointing enabled for memory optimization")
        # Apply LoRA
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
        logger.info("LoRA configuration applied successfully")
    def load_dataset(self):
        """Load the pre-tokenized train/validation datasets from disk."""
        logger.info("Loading pre-tokenized datasets...")
        train_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "train"))
        val_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "validation"))
        # Limit samples if specified
        max_samples = self.config['data'].get('max_samples', None)
        if max_samples:
            logger.info(f"Limiting training samples to {max_samples}")
            train_dataset = train_dataset.select(range(min(max_samples, len(train_dataset))))
            # Keep roughly 10% of the cap for validation, but never an empty set
            val_dataset = val_dataset.select(range(min(max(1, max_samples // 10), len(val_dataset))))
        logger.info(f"Loaded tokenized training samples: {len(train_dataset)}")
        logger.info(f"Loaded tokenized validation samples: {len(val_dataset)}")
        return train_dataset, val_dataset
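    # This loader assumes a separate preprocessing step has already tokenized the
    # corpus and written Hugging Face `datasets` folders to <tokenized_dir>/train
    # and <tokenized_dir>/validation, with at least input_ids/attention_mask
    # columns (plus a "length" column, since group_by_length is enabled below).
    # A minimal sketch of that step, under those assumptions only:
    #
    #     tokenized = Dataset.from_dict({
    #         "input_ids": all_input_ids,
    #         "attention_mask": all_attention_masks,
    #         "length": [len(ids) for ids in all_input_ids],
    #     })
    #     tokenized.save_to_disk(os.path.join(tokenized_dir, "train"))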
    def setup_training(self, train_dataset, val_dataset):
        """Build the TrainingArguments, data collator and Trainer."""
        logger.info("Setting up training configuration")

        def convert_numeric(value):
            """Coerce numeric strings from YAML (e.g. '2e-4' or '100') to numbers."""
            if isinstance(value, str):
                try:
                    return int(value)
                except ValueError:
                    try:
                        return float(value)
                    except ValueError:
                        return value
            return value

        # Training arguments with memory optimizations
        training_args = TrainingArguments(
            output_dir=self.config['training']['output_dir'],
            num_train_epochs=convert_numeric(self.config['training']['num_train_epochs']),
            per_device_train_batch_size=convert_numeric(self.config['training']['per_device_train_batch_size']),
            per_device_eval_batch_size=convert_numeric(self.config['training']['per_device_eval_batch_size']),
            gradient_accumulation_steps=convert_numeric(self.config['training']['gradient_accumulation_steps']),
            learning_rate=convert_numeric(self.config['training']['learning_rate']),
            weight_decay=convert_numeric(self.config['training']['weight_decay']),
            warmup_steps=convert_numeric(self.config['training']['warmup_steps']),
            logging_steps=convert_numeric(self.config['training']['logging_steps']),
            save_steps=convert_numeric(self.config['training']['save_steps']),
            eval_steps=convert_numeric(self.config['training']['eval_steps']),
            evaluation_strategy=self.config['training']['evaluation_strategy'],
            save_strategy=self.config['training']['save_strategy'],
            save_total_limit=convert_numeric(self.config['training']['save_total_limit']),
            load_best_model_at_end=self.config['training']['load_best_model_at_end'],
            metric_for_best_model=self.config['training']['metric_for_best_model'],
            greater_is_better=self.config['training']['greater_is_better'],
            fp16=self.config['training']['fp16'],
            dataloader_num_workers=convert_numeric(self.config['training']['dataloader_num_workers']),
            gradient_checkpointing=self.config['training']['gradient_checkpointing'],
            max_grad_norm=convert_numeric(self.config['training']['max_grad_norm']),
            report_to=self.config['logging']['report_to'],
            run_name=self.config['logging']['run_name'],
            log_level=self.config['logging']['log_level'],
            # Memory optimization settings
            dataloader_drop_last=True,
            group_by_length=True,
            length_column_name="length",
            # Disable features that use more memory
            ddp_find_unused_parameters=False,
            dataloader_pin_memory=False,
            # Additional memory optimizations
            optim="adamw_torch_fused",  # Fused optimizer for speed
            torch_compile=False,  # Disable torch.compile to save memory
            use_cpu=False,  # Keep training on GPU
            dataloader_persistent_workers=False,  # Reduce memory fragmentation
        )
        # Data collator for pre-tokenized causal-LM data (mlm=False)
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )
        # Trainer with early stopping after 3 evaluations without improvement
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )
        logger.info("Training setup completed")
    def train(self):
        """Run training and persist the final model, metrics and trainer state."""
        logger.info("Starting training...")
        try:
            # Train the model
            train_result = self.trainer.train()
            # Save the final model
            self.trainer.save_model()
            # Save training metrics
            metrics = train_result.metrics
            self.trainer.log_metrics("train", metrics)
            self.trainer.save_metrics("train", metrics)
            self.trainer.save_state()
            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {metrics}")
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
    def save_model(self):
        """Save the tokenizer and a JSON summary of the model configuration."""
        logger.info("Saving model...")
        output_dir = self.config['training']['output_dir']
        # Save tokenizer
        self.tokenizer.save_pretrained(output_dir)
        # Save model configuration alongside the weights
        model_config = {
            'base_model': self.config['model']['base_model'],
            'lora_config': self.config['lora'],
            'generation_config': self.config['generation']
        }
        config_path = os.path.join(output_dir, 'model_config.json')
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=2)
        logger.info(f"Model saved to {output_dir}")
    def run(self):
        """Execute the full fine-tuning pipeline end to end."""
        logger.info("Starting AgriQA fine-tuning pipeline...")
        # Load model and tokenizer
        self.load_model_and_tokenizer()
        # Setup LoRA
        self.setup_lora()
        # Load and prepare datasets
        train_dataset, val_dataset = self.load_dataset()
        # Setup training
        self.setup_training(train_dataset, val_dataset)
        # Train the model
        self.train()
        # Save the model
        self.save_model()
        logger.info("Fine-tuning pipeline completed successfully!")
def main():
    parser = argparse.ArgumentParser(description="Fine-tune a Qwen model on the AgriQA dataset")
    parser.add_argument("--config", type=str, default="configs/training_config.yaml",
                        help="Path to training configuration file")
    args = parser.parse_args()
    # Initialize and run fine-tuning
    fine_tuner = AgriQAFineTuner(args.config)
    fine_tuner.run()


if __name__ == "__main__":
    main()
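# Example invocation (the config path is the argparse default above; adjust to
# your own layout):
#   python src/training/finetune.py --config configs/training_config.yaml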