import logging
import math
import os

import torch
from tqdm import tqdm

from pt_variety_identifier.src.bert.model import LanguageIdentfier
from pt_variety_identifier.src.bert.tester import Tester


class Trainer:
    """Fine-tune a ``LanguageIdentfier`` on a binary task with BCE loss.

    Runs epoch-based training with AdamW, a ReduceLROnPlateau scheduler,
    and (optionally) per-epoch validation via a ``Tester``; the best
    checkpoint (by validation loss *and* F1) is saved to disk.
    """

    def __init__(self, train_dataset, params, validation_dataset_dict=None) -> None:
        """Build the model, optimizer, scheduler and optional validator.

        Args:
            train_dataset: Iterable of batches, each a dict with
                ``input_ids``, ``attention_mask`` and ``label`` tensors
                (assumed — TODO confirm against the data loader).
            params: Config dict; must contain ``model_name``, ``epochs``,
                ``early_stoping``, ``device``, ``CURRENT_PATH`` and
                ``CURRENT_TIME``; ``training_domain`` is optional.
            validation_dataset_dict: Optional per-domain validation sets;
                when given, a ``Tester`` validator is created.
        """
        self.train_dataset = train_dataset
        self.model = LanguageIdentfier(params['model_name'])
        self.epochs = params['epochs']
        self.lr = 1e-5
        # Model is expected to emit probabilities in [0, 1] (BCELoss, not
        # BCEWithLogitsLoss) — presumably a sigmoid lives in the model.
        self.loss_fn = torch.nn.BCELoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        self.early_stoping = params['early_stoping']
        # LR is halved (default factor) after early_stoping//2 epochs
        # without improvement in the quantity passed to scheduler.step().
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=self.early_stoping // 2, verbose=True)
        self.device = params['device']
        self.CURRENT_PATH = params['CURRENT_PATH']
        self.CURRENT_TIME = params['CURRENT_TIME']
        self.training_domain = params['training_domain'] if 'training_domain' in params else 'all'
        self.validator = None
        print(f"Using {self.device} device")
        if validation_dataset_dict:
            self.validator = Tester(
                test_dataset_dict=validation_dataset_dict,
                model=self.model,
                train_domain=self.training_domain,
            )

    def _epoch_iter(self):
        """Run one training epoch; return the mean per-batch loss."""
        self.model.train()
        self.model.to(self.device)
        self.optimizer.zero_grad()
        with torch.enable_grad():
            total_loss = 0
            for batch in tqdm(self.train_dataset):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device, dtype=torch.float)
                # squeeze(dim=1): model output is (batch, 1) -> (batch,)
                outputs = self.model(
                    input_ids, attention_mask=attention_mask).squeeze(dim=1)
                loss = self.loss_fn(outputs, labels)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_loss += loss.item()
            # NOTE(review): the plateau scheduler is stepped on the SUMMED
            # training loss, not validation loss — confirm this is intended.
            self.scheduler.step(total_loss)
            return total_loss / len(self.train_dataset)

    def train(self):
        """Train for ``self.epochs`` epochs; return the best validation metrics.

        When a validator exists, the model state is checkpointed to
        ``<CURRENT_PATH>/out/<CURRENT_TIME>/models/<domain>.pt`` whenever
        validation improves. Training halts early once the epoch's mean
        training loss drops below 0.1.
        """
        logging.info(f"Training model in {self.device}...")
        best_results = {
            'f1': -math.inf,
            'accuracy': -math.inf,
            'precision': -math.inf,
            'recall': -math.inf,
            'loss': math.inf
        }
        for epoch in tqdm(range(self.epochs)):
            training_loss = self._epoch_iter()
            if self.validator:
                results = self.validator.validate()
                logging.info(f"Results for {self.training_domain} domain: {results} Epoch: {epoch}")
                # NOTE(review): a checkpoint requires BOTH lower loss AND
                # higher F1 than the best so far — a strictly better F1 with
                # equal loss is skipped; confirm this strictness is intended.
                if results['loss'] < best_results['loss'] and results['f1'] > best_results['f1']:
                    logging.info(
                        f"Saving best model... Domain:{self.training_domain} F1:{results['f1']} and Test Loss:{results['loss']}")
                    best_results['loss'] = results['loss']
                    best_results['accuracy'] = results['accuracy']
                    best_results['f1'] = results['f1']
                    best_results['recall'] = results['recall']
                    best_results['precision'] = results['precision']
                    model_dir = os.path.join(
                        self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models")
                    # Fix: the output directory was never created, which
                    # makes torch.save raise on a fresh run.
                    os.makedirs(model_dir, exist_ok=True)
                    torch.save(self.model.state_dict(),
                               os.path.join(model_dir, f'{self.training_domain}.pt'))
                else:
                    logging.info(f"Not saving model... F1:{results['f1']} and Test Loss:{results['loss']}")
            logging.info(f"Epoch {epoch} Training Loss: {training_loss}")
            if training_loss < 0.1:
                logging.info(f"Training Loss is too low, stoping training...")
                break
        return best_results