import sys

from peft import get_peft_model, LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from transformers import EarlyStoppingCallback

from codeInsight.logger import logging
from codeInsight.exception import ExceptionHandle


class ModelTrainer:
    """Wires a base model, tokenizer, and datasets into a LoRA SFT training run."""

    def __init__(self, model, tokenizer, datasets: dict, config: dict):
        self.model = model
        self.tokenizer = tokenizer
        self.datasets = datasets
        self.lora_config = config['lora']
        self.training_config = config['training']
        self.paths_config = config['paths']
        self.trainer = self._setup_trainer()
        logging.info("ModelTrainer initialized.")

    def _get_target_modules(self, model) -> list:
        """Scan the model's named modules for common attention/MLP projection
        layers to use as LoRA targets; fall back to q_proj/v_proj if none match."""
        try:
            logging.info("Finding LoRA target modules")
            candidates = [
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
            present = set()
            for name, _module in model.named_modules():
                for cand in candidates:
                    if name.endswith(cand):
                        present.add(cand)
            return list(present) if present else ["q_proj", "v_proj"]
        except Exception as e:
            logging.error("Failed to find LoRA target modules")
            raise ExceptionHandle(e, sys)

    def _peft_model_setup(self):
        """Wrap the base model with a LoRA adapter built from the config."""
        try:
            logging.info("Setting up PEFT LoRA model")
            lora_config = LoraConfig(
                r=self.lora_config['r'],
                lora_alpha=self.lora_config['lora_alpha'],
                target_modules=self._get_target_modules(self.model),
                lora_dropout=self.lora_config['lora_dropout'],
                bias=self.lora_config['bias'],
                task_type=self.lora_config['task_type'],
                use_rslora=self.lora_config['use_rslora'],
            )
            peft_model = get_peft_model(self.model, lora_config)
            logging.info("PEFT model created successfully.")
            peft_model.print_trainable_parameters()
            return peft_model
        except Exception as e:
            logging.error("Failed to set up PEFT model")
            raise ExceptionHandle(e, sys)

    def _get_training_args(self) -> SFTConfig:
        """Build the SFTConfig from the training/paths sections of the config."""
        try:
            return SFTConfig(
                output_dir=self.paths_config['output_dir'],
                per_device_train_batch_size=self.training_config['per_device_train_batch_size'],
                per_device_eval_batch_size=self.training_config['per_device_eval_batch_size'],
                gradient_accumulation_steps=self.training_config['gradient_accumulation_steps'],
                num_train_epochs=self.training_config['num_train_epochs'],
                learning_rate=self.training_config['learning_rate'],
                warmup_ratio=self.training_config['warmup_ratio'],
                warmup_steps=self.training_config['warmup_steps'],
                bf16=self.training_config['bf16'],
                tf32=self.training_config['tf32'],
                fp16=self.training_config['fp16'],
                lr_scheduler_type=self.training_config['lr_scheduler_type'],
                optim=self.training_config['optim'],
                gradient_checkpointing=self.training_config['gradient_checkpointing'],
                gradient_checkpointing_kwargs=self.training_config['gradient_checkpointing_kwargs'],
                max_grad_norm=self.training_config['max_grad_norm'],
                weight_decay=self.training_config['weight_decay'],
                logging_steps=self.training_config['logging_steps'],
                eval_steps=self.training_config['eval_steps'],
                save_steps=self.training_config['save_steps'],
                # `evaluation_strategy` was renamed to `eval_strategy` in recent
                # transformers releases; this matches the config key as well.
                eval_strategy=self.training_config['eval_strategy'],
                save_strategy=self.training_config['save_strategy'],
                save_total_limit=self.training_config['save_total_limit'],
                load_best_model_at_end=self.training_config['load_best_model_at_end'],
                metric_for_best_model=self.training_config['metric_for_best_model'],
                greater_is_better=self.training_config['greater_is_better'],
                prediction_loss_only=self.training_config['prediction_loss_only'],
                report_to=self.training_config['report_to'],
                dataloader_num_workers=self.training_config['dataloader_num_workers'],
                max_seq_length=self.training_config['max_seq_length'],
                dataset_text_field=self.training_config['dataset_text_field'],
                label_names=self.training_config['label_names'],
                neftune_noise_alpha=self.training_config['neftune_noise_alpha'],
            )
        except Exception as e:
            logging.error("Failed to create SFTConfig")
            raise ExceptionHandle(e, sys)

    def _data_collator(self):
        """Mask prompt tokens so loss is computed only on the assistant completion."""
        try:
            return DataCollatorForCompletionOnlyLM(
                response_template="<|assistant|>",
                tokenizer=self.tokenizer,
            )
        except Exception as e:
            logging.error("Failed to create data collator")
            raise ExceptionHandle(e, sys)

    def _setup_trainer(self) -> SFTTrainer:
        """Assemble the SFTTrainer from the PEFT model, datasets, collator,
        and an early-stopping callback tied to the configured eval metric."""
        logging.info("Initializing SFTTrainer")
        peft_model = self._peft_model_setup()
        training_args = self._get_training_args()
        trainer = SFTTrainer(
            model=peft_model,
            train_dataset=self.datasets['train'],
            eval_dataset=self.datasets['val'],
            args=training_args,
            data_collator=self._data_collator(),
            callbacks=[
                EarlyStoppingCallback(
                    early_stopping_patience=5,
                    early_stopping_threshold=0.001,
                )
            ],
        )
        logging.info("SFTTrainer initialized successfully.")
        return trainer

    def save_adapter(self):
        """Persist only the LoRA adapter weights (not the full base model)."""
        try:
            adapter_path = self.paths_config['adapter_save_dir']
            self.trainer.model.save_pretrained(adapter_path)
            logging.info(f"LoRA adapter saved successfully to {adapter_path}")
        except Exception as e:
            logging.error("Failed to save LoRA adapter")
            raise ExceptionHandle(e, sys)
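

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module's API).
# The model name, config path, and data files below are assumptions; the real
# pipeline presumably builds `datasets` and `config` in its own stages.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import yaml
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical YAML file providing the 'lora', 'training', and 'paths'
    # sections that ModelTrainer expects.
    with open("config/train_config.yaml") as f:
        config = yaml.safe_load(f)

    # Placeholder base model; any causal LM compatible with the LoRA targets works.
    model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")

    # Hypothetical pre-formatted splits; each record must contain the text
    # column named in config['training']['dataset_text_field'], with prompts
    # and completions separated by the "<|assistant|>" response template.
    raw = load_dataset(
        "json",
        data_files={"train": "data/train.jsonl", "val": "data/val.jsonl"},
    )
    datasets = {"train": raw["train"], "val": raw["val"]}

    model_trainer = ModelTrainer(model, tokenizer, datasets, config)
    model_trainer.trainer.train()   # run completion-only SFT
    model_trainer.save_adapter()    # persist the LoRA adapter weights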