import sys
from peft import get_peft_model, LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from transformers import EarlyStoppingCallback
from codeInsight.logger import logging
from codeInsight.exception import ExceptionHandle


class ModelTrainer:
    """Wires a base model and tokenizer into a LoRA/PEFT SFTTrainer built from a config dict."""

    def __init__(self, model, tokenizer, datasets: dict, config: dict):
        self.model = model
        self.tokenizer = tokenizer
        self.datasets = datasets
        self.lora_config = config['lora']
        self.training_config = config['training']
        self.paths_config = config['paths']
        self.trainer = self._setup_trainer()
        logging.info("ModelTrainer initialized.")
    def _get_target_module(self, model) -> list:
        """Return the candidate LoRA target module names actually present in the model."""
        try:
            logging.info("Searching for LoRA target modules")
            candidates = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
            present = set()
            for name, _ in model.named_modules():
                for cand in candidates:
                    if name.endswith(cand):
                        present.add(cand)
            # Fall back to the attention projections if none of the candidates match.
            return list(present) if present else ["q_proj", "v_proj"]
        except Exception as e:
            logging.error("Failed to find LoRA target modules")
            raise ExceptionHandle(e, sys)
    def _peft_model_setup(self):
        """Wrap the base model with LoRA adapters using the configured hyperparameters."""
        try:
            logging.info("Setting up PEFT LoRA model")
            lora_config = LoraConfig(
                r=self.lora_config['r'],
                lora_alpha=self.lora_config['lora_alpha'],
                target_modules=self._get_target_module(self.model),
                lora_dropout=self.lora_config['lora_dropout'],
                bias=self.lora_config['bias'],
                task_type=self.lora_config['task_type'],
                use_rslora=self.lora_config['use_rslora']
            )
            peft_model = get_peft_model(self.model, lora_config)
            logging.info("PEFT model created successfully.")
            peft_model.print_trainable_parameters()
            return peft_model
        except Exception as e:
            logging.error("Failed to set up PEFT model")
            raise ExceptionHandle(e, sys)
    def _get_training_args(self) -> SFTConfig:
        """Build the SFTConfig from the training, dataset, and path settings in the config."""
        try:
            return SFTConfig(
                output_dir=self.paths_config['output_dir'],
                per_device_train_batch_size=self.training_config['per_device_train_batch_size'],
                per_device_eval_batch_size=self.training_config['per_device_eval_batch_size'],
                gradient_accumulation_steps=self.training_config['gradient_accumulation_steps'],
                num_train_epochs=self.training_config['num_train_epochs'],
                learning_rate=self.training_config['learning_rate'],
                warmup_ratio=self.training_config['warmup_ratio'],
                warmup_steps=self.training_config['warmup_steps'],
                # Note: bf16 and fp16 are mutually exclusive; at most one should be True in the config.
                bf16=self.training_config['bf16'],
                tf32=self.training_config['tf32'],
                fp16=self.training_config['fp16'],
                lr_scheduler_type=self.training_config['lr_scheduler_type'],
                optim=self.training_config['optim'],
                gradient_checkpointing=self.training_config['gradient_checkpointing'],
                gradient_checkpointing_kwargs=self.training_config['gradient_checkpointing_kwargs'],
                max_grad_norm=self.training_config['max_grad_norm'],
                weight_decay=self.training_config['weight_decay'],
                logging_steps=self.training_config['logging_steps'],
                eval_steps=self.training_config['eval_steps'],
                save_steps=self.training_config['save_steps'],
                eval_strategy=self.training_config['eval_strategy'],
                save_strategy=self.training_config['save_strategy'],
                save_total_limit=self.training_config['save_total_limit'],
                load_best_model_at_end=self.training_config['load_best_model_at_end'],
                metric_for_best_model=self.training_config['metric_for_best_model'],
                greater_is_better=self.training_config['greater_is_better'],
                prediction_loss_only=self.training_config['prediction_loss_only'],
                report_to=self.training_config['report_to'],
                dataloader_num_workers=self.training_config['dataloader_num_workers'],
                max_seq_length=self.training_config['max_seq_length'],
                dataset_text_field=self.training_config['dataset_text_field'],
                label_names=self.training_config['label_names'],
                neftune_noise_alpha=self.training_config['neftune_noise_alpha']
            )
        except Exception as e:
            logging.error("Failed to create SFTConfig")
            raise ExceptionHandle(e, sys)
    def _data_collator(self):
        """Create a collator that computes loss only on tokens after the assistant marker."""
        try:
            return DataCollatorForCompletionOnlyLM(
                response_template="<|assistant|>",
                tokenizer=self.tokenizer
            )
        except Exception as e:
            logging.error("Failed to create data collator")
            raise ExceptionHandle(e, sys)
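    # Illustrative note (an assumption about the dataset, not stated in the original):
    # the completion-only collator masks every token before "<|assistant|>", so each
    # example in `dataset_text_field` must contain that marker, e.g.:
    #
    #   "<|user|>\nExplain recursion.\n<|assistant|>\nRecursion is ..."
    #
    # Loss is then computed only on the assistant's reply.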
    def _setup_trainer(self) -> SFTTrainer:
        """Assemble the SFTTrainer from the PEFT model, datasets, args, and early stopping."""
        logging.info("Initializing SFTTrainer")
        peft_model = self._peft_model_setup()
        training_args = self._get_training_args()
        trainer = SFTTrainer(
            model=peft_model,
            train_dataset=self.datasets['train'],
            eval_dataset=self.datasets['val'],
            args=training_args,
            data_collator=self._data_collator(),
            callbacks=[EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.001)],
        )
        logging.info("SFTTrainer initialized successfully.")
        return trainer
    def save_adapter(self):
        """Persist the trained LoRA adapter weights to the configured directory."""
        try:
            adapter_path = self.paths_config['adapter_save_dir']
            self.trainer.model.save_pretrained(adapter_path)
            logging.info(f"LoRA adapter saved successfully to {adapter_path}")
        except Exception as e:
            logging.error("Failed to save LoRA adapter")
            raise ExceptionHandle(e, sys)
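

# Minimal usage sketch (not part of the original module). The base model name,
# config values, and dataset file names below are illustrative assumptions.
if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "HuggingFaceTB/SmolLM2-135M"  # hypothetical base model
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # A config dict mirroring the keys this class reads; values are examples only.
    config = {
        "lora": {
            "r": 16, "lora_alpha": 32, "lora_dropout": 0.05,
            "bias": "none", "task_type": "CAUSAL_LM", "use_rslora": True,
        },
        "training": {
            "per_device_train_batch_size": 4, "per_device_eval_batch_size": 4,
            "gradient_accumulation_steps": 4, "num_train_epochs": 3,
            "learning_rate": 2e-4, "warmup_ratio": 0.03, "warmup_steps": 0,
            "bf16": True, "tf32": True, "fp16": False,
            "lr_scheduler_type": "cosine", "optim": "adamw_torch",
            "gradient_checkpointing": True,
            "gradient_checkpointing_kwargs": {"use_reentrant": False},
            "max_grad_norm": 1.0, "weight_decay": 0.01,
            "logging_steps": 10, "eval_steps": 100, "save_steps": 100,
            "eval_strategy": "steps", "save_strategy": "steps",
            "save_total_limit": 2, "load_best_model_at_end": True,
            "metric_for_best_model": "eval_loss", "greater_is_better": False,
            "prediction_loss_only": True, "report_to": "none",
            "dataloader_num_workers": 2, "max_seq_length": 1024,
            "dataset_text_field": "text", "label_names": ["labels"],
            "neftune_noise_alpha": 5,
        },
        "paths": {"output_dir": "./outputs", "adapter_save_dir": "./adapter"},
    }

    raw = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})
    datasets = {"train": raw["train"], "val": raw["val"]}

    trainer = ModelTrainer(model, tokenizer, datasets, config)
    trainer.trainer.train()   # run fine-tuning
    trainer.save_adapter()    # write the LoRA adapter to adapter_save_dir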