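"""Fine-tune a causal LM (e.g. Qwen) on the agriQA dataset with LoRA/QLoRA.

Example invocation (the script filename is illustrative; the default config
path comes from the argparse setup at the bottom of this file):

    python train.py --config configs/training_config.yaml
"""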
import os
import json
import argparse
import logging
from typing import Dict, Any

import yaml
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    BitsAndBytesConfig,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType,
)
from datasets import Dataset

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AgriQAFineTuner:
    def __init__(self, config_path: str):
        self.config = self.load_config(config_path)  # load the config file
        self.setup_environment()

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Parse the YAML training configuration."""
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        return config
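
    # A minimal sketch of the YAML structure this script reads. All values
    # below are illustrative assumptions, not the project's actual config:
    #
    #   model:
    #     base_model: Qwen/Qwen2-1.5B-Instruct     # hypothetical checkpoint
    #     trust_remote_code: true
    #   hardware:
    #     use_4bit: true
    #     bnb_4bit_quant_type: nf4
    #     bnb_4bit_use_double_quant: true
    #     device_map: auto
    #   lora:
    #     r: 16
    #     lora_alpha: 32
    #     target_modules: [q_proj, k_proj, v_proj, o_proj]
    #     lora_dropout: 0.05
    #     bias: none
    #     task_type: CAUSAL_LM
    #   data:
    #     tokenized_dir: data/tokenized
    #     max_samples: null
    #   training: { output_dir: outputs/agriqa, ... }  # plus the keys used below
    #   logging:  { report_to: none, run_name: agriqa, log_level: info }
    #   generation: {}                                 # saved verbatim in save_model()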

    def setup_environment(self) -> None:
        # Set device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        # Create output directory
        os.makedirs(self.config['training']['output_dir'], exist_ok=True)

    def load_model_and_tokenizer(self):
        """Load the base model and tokenizer, optionally 4-bit quantized."""
        logger.info(f"Loading model: {self.config['model']['base_model']}")
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config['model']['base_model'],
            trust_remote_code=self.config['model']['trust_remote_code']
        )
        # Add padding token if not present
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load model with quantization if specified
        if self.config['hardware']['use_4bit']:
            logger.info("Loading model with 4-bit quantization")
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type=self.config['hardware']['bnb_4bit_quant_type'],
                bnb_4bit_use_double_quant=self.config['hardware']['bnb_4bit_use_double_quant'],
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                quantization_config=quantization_config,
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )
            # Prepare the quantized model for k-bit (QLoRA) training
            self.model = prepare_model_for_kbit_training(self.model)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.config['model']['base_model'],
                device_map=self.config['hardware']['device_map'],
                trust_remote_code=self.config['model']['trust_remote_code']
            )
        logger.info("Model and tokenizer loaded successfully")

    def setup_lora(self):
        """Wrap the base model with LoRA adapters."""
        logger.info("Setting up LoRA configuration")
        lora_config = LoraConfig(
            r=self.config['lora']['r'],
            lora_alpha=self.config['lora']['lora_alpha'],
            target_modules=self.config['lora']['target_modules'],
            lora_dropout=self.config['lora']['lora_dropout'],
            bias=self.config['lora']['bias'],
            # TaskType(...) validates the config string (e.g. "CAUSAL_LM")
            task_type=TaskType(self.config['lora']['task_type']),
        )
        # Enable gradient checkpointing for memory optimization
        if self.config['training']['gradient_checkpointing']:
            self.model.gradient_checkpointing_enable()
            logger.info("Gradient checkpointing enabled for memory optimization")
        # Apply LoRA
        self.model = get_peft_model(self.model, lora_config)
        self.model.print_trainable_parameters()
        logger.info("LoRA configuration applied successfully")

    def load_dataset(self):
        """Load the pre-tokenized train/validation datasets from disk."""
        logger.info("Loading pre-tokenized datasets...")
        train_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "train"))
        val_dataset = Dataset.load_from_disk(os.path.join(self.config['data']['tokenized_dir'], "validation"))
        # Limit samples if specified
        max_samples = self.config['data'].get('max_samples', None)
        if max_samples:
            logger.info(f"Limiting training samples to {max_samples}")
            train_dataset = train_dataset.select(range(min(max_samples, len(train_dataset))))
            # Keep roughly 10% of that for validation, but never fewer than 1 sample
            val_dataset = val_dataset.select(range(min(max(max_samples // 10, 1), len(val_dataset))))
        logger.info(f"Loaded tokenized training samples: {len(train_dataset)}")
        logger.info(f"Loaded tokenized validation samples: {len(val_dataset)}")
        return train_dataset, val_dataset
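
    # Assumption: the saved datasets already contain "input_ids" and
    # "attention_mask". group_by_length below uses a precomputed "length"
    # column when one exists; a sketch of adding it if the tokenization step
    # did not:
    #
    #   train_dataset = train_dataset.map(lambda ex: {"length": len(ex["input_ids"])})
    #   val_dataset = val_dataset.map(lambda ex: {"length": len(ex["input_ids"])})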

    def setup_training(self, train_dataset, val_dataset):
        logger.info("Setting up training configuration")

        # YAML can leave numbers such as "2e-5" as strings; coerce them back,
        # trying int first so counts (batch sizes, steps) keep their type
        def convert_numeric(value):
            if isinstance(value, str):
                try:
                    return int(value)
                except ValueError:
                    try:
                        return float(value)
                    except ValueError:
                        return value
            return value

        # Training arguments with memory optimizations
        training_args = TrainingArguments(
            output_dir=self.config['training']['output_dir'],
            num_train_epochs=convert_numeric(self.config['training']['num_train_epochs']),
            per_device_train_batch_size=convert_numeric(self.config['training']['per_device_train_batch_size']),
            per_device_eval_batch_size=convert_numeric(self.config['training']['per_device_eval_batch_size']),
            gradient_accumulation_steps=convert_numeric(self.config['training']['gradient_accumulation_steps']),
            learning_rate=convert_numeric(self.config['training']['learning_rate']),
            weight_decay=convert_numeric(self.config['training']['weight_decay']),
            warmup_steps=convert_numeric(self.config['training']['warmup_steps']),
            logging_steps=convert_numeric(self.config['training']['logging_steps']),
            save_steps=convert_numeric(self.config['training']['save_steps']),
            eval_steps=convert_numeric(self.config['training']['eval_steps']),
            # Renamed to eval_strategy in transformers >= 4.41; keep the old
            # name only if pinned to an earlier version
            evaluation_strategy=self.config['training']['evaluation_strategy'],
            save_strategy=self.config['training']['save_strategy'],
            save_total_limit=convert_numeric(self.config['training']['save_total_limit']),
            load_best_model_at_end=self.config['training']['load_best_model_at_end'],
            metric_for_best_model=self.config['training']['metric_for_best_model'],
            greater_is_better=self.config['training']['greater_is_better'],
            fp16=self.config['training']['fp16'],
            dataloader_num_workers=convert_numeric(self.config['training']['dataloader_num_workers']),
            gradient_checkpointing=self.config['training']['gradient_checkpointing'],
            max_grad_norm=convert_numeric(self.config['training']['max_grad_norm']),
            report_to=self.config['logging']['report_to'],
            run_name=self.config['logging']['run_name'],
            log_level=self.config['logging']['log_level'],
            # Memory optimization settings
            dataloader_drop_last=True,
            group_by_length=True,
            length_column_name="length",
            # Disable features that use more memory
            ddp_find_unused_parameters=False,
            dataloader_pin_memory=False,
            # Additional memory optimizations
            optim="adamw_torch_fused",  # fused optimizer for speed
            torch_compile=False,        # skip torch.compile to save memory
            use_cpu=False,              # keep on GPU but optimize memory
            # Reduce memory fragmentation
            dataloader_persistent_workers=False,
        )
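
        # Effective batch size per optimizer step:
        #   per_device_train_batch_size * gradient_accumulation_steps * num_gpus
        # e.g. with illustrative values, 2 * 8 * 1 = 16 samples per update.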

        # Causal-LM collator: mlm=False copies input_ids into labels (with
        # padding masked out); the model itself handles the next-token shift
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )
        # Trainer; EarlyStoppingCallback requires load_best_model_at_end=True
        # and metric_for_best_model, both supplied via the config above
        self.trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )
        logger.info("Training setup completed")

    def train(self):
        logger.info("Starting training...")
        try:
            # Train the model
            train_result = self.trainer.train()
            # Save the final model (the best checkpoint, given load_best_model_at_end)
            self.trainer.save_model()
            # Save training metrics and trainer state
            metrics = train_result.metrics
            self.trainer.log_metrics("train", metrics)
            self.trainer.save_metrics("train", metrics)
            self.trainer.save_state()
            logger.info("Training completed successfully!")
            logger.info(f"Training metrics: {metrics}")
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
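
    # To resume an interrupted run from the latest checkpoint in output_dir,
    # Trainer.train accepts resume_from_checkpoint (sketch):
    #
    #   train_result = self.trainer.train(resume_from_checkpoint=True)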

    def save_model(self):
        logger.info("Saving model...")
        output_dir = self.config['training']['output_dir']
        # Save tokenizer alongside the adapter weights
        self.tokenizer.save_pretrained(output_dir)
        # Save a small JSON record of how the model was built
        model_config = {
            'base_model': self.config['model']['base_model'],
            'lora_config': self.config['lora'],
            'generation_config': self.config['generation']
        }
        config_path = os.path.join(output_dir, 'model_config.json')
        with open(config_path, 'w') as f:
            json.dump(model_config, f, indent=2)
        logger.info(f"Model saved to {output_dir}")

    def run(self):
        logger.info("Starting agriQA fine-tuning pipeline...")
        # Load model and tokenizer
        self.load_model_and_tokenizer()
        # Setup LoRA
        self.setup_lora()
        # Load and prepare datasets
        train_dataset, val_dataset = self.load_dataset()
        # Setup training
        self.setup_training(train_dataset, val_dataset)
        # Train the model
        self.train()
        # Save the model
        self.save_model()
        logger.info("Fine-tuning pipeline completed successfully!")

def main():
    parser = argparse.ArgumentParser(description="Fine-tune Qwen model on agriQA dataset")
    parser.add_argument("--config", type=str, default="configs/training_config.yaml",
                        help="Path to training configuration file")
    args = parser.parse_args()
    # Initialize and run fine-tuning
    fine_tuner = AgriQAFineTuner(args.config)
    fine_tuner.run()


if __name__ == "__main__":
    main()