import torch
from tqdm import tqdm
import logging
from pt_variety_identifier.src.bert.model import LanguageIdentfier
from pt_variety_identifier.src.bert.tester import Tester
import math
import os
class Trainer:
    """Trains a binary language-variety classifier with BCE loss.

    Optionally validates after every epoch via a ``Tester`` and checkpoints
    the best model (by validation loss *and* F1) to
    ``<CURRENT_PATH>/out/<CURRENT_TIME>/models/<training_domain>.pt``.
    """

    def __init__(self, train_dataset, params, validation_dataset_dict=None) -> None:
        """Build model, optimizer, scheduler and (optionally) a validator.

        Args:
            train_dataset: iterable of batches with ``input_ids``,
                ``attention_mask`` and ``label`` tensors.
            params: dict with keys ``model_name``, ``epochs``,
                ``early_stoping``, ``device``, ``CURRENT_PATH``,
                ``CURRENT_TIME`` and optionally ``training_domain``.
            validation_dataset_dict: per-domain validation datasets; when
                given, a ``Tester`` is created and used each epoch.
        """
        self.train_dataset = train_dataset
        self.model = LanguageIdentfier(params['model_name'])
        self.epochs = params['epochs']
        self.lr = 1e-5
        # BCELoss: the model is expected to emit probabilities in [0, 1].
        self.loss_fn = torch.nn.BCELoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        # NOTE(review): despite its name, this value only sets the scheduler
        # patience below — no actual early stopping is driven by it; confirm
        # whether stopping on stalled validation loss was intended.
        self.early_stoping = params['early_stoping']
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=self.early_stoping // 2, verbose=True)
        self.device = params['device']
        self.CURRENT_PATH = params['CURRENT_PATH']
        self.CURRENT_TIME = params['CURRENT_TIME']
        # Single lookup with a default instead of `in` test + second lookup.
        self.training_domain = params.get('training_domain', 'all')
        self.validator = None
        print(f"Using {self.device} device")
        if validation_dataset_dict:
            self.validator = Tester(
                test_dataset_dict=validation_dataset_dict,
                model=self.model,
                train_domain=self.training_domain,
            )

    def _epoch_iter(self):
        """Run one full pass over ``train_dataset``.

        Returns:
            Mean per-batch training loss for the epoch.
        """
        self.model.train()
        self.model.to(self.device)
        self.optimizer.zero_grad()
        with torch.enable_grad():
            total_loss = 0.0
            for batch in tqdm(self.train_dataset):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                # BCELoss requires float targets.
                labels = batch['label'].to(self.device, dtype=torch.float)
                outputs = self.model(
                    input_ids, attention_mask=attention_mask).squeeze(dim=1)
                loss = self.loss_fn(outputs, labels)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_loss += loss.item()
            # NOTE(review): the plateau scheduler is stepped on the SUMMED
            # epoch loss while the mean is returned below — confirm the sum
            # is intentional (it scales with dataset size).
            self.scheduler.step(total_loss)
            return total_loss / len(self.train_dataset)

    def train(self):
        """Train for up to ``self.epochs`` epochs.

        Validates each epoch (when a validator exists), checkpoints the best
        model, and stops early once training loss drops below 0.1.

        Returns:
            dict with the best validation ``f1``/``accuracy``/``precision``/
            ``recall``/``loss`` seen (infinities if never validated).
        """
        logging.info(f"Training model in {self.device}...")
        best_results = {
            'f1': -math.inf,
            'accuracy': -math.inf,
            'precision': -math.inf,
            'recall': -math.inf,
            'loss': math.inf
        }
        for epoch in tqdm(range(self.epochs)):
            training_loss = self._epoch_iter()
            if self.validator:
                results = self.validator.validate()
                logging.info(f"Results for {self.training_domain} domain: {results} Epoch: {epoch}")
                # NOTE(review): requiring BOTH lower loss AND higher F1 means
                # no checkpoint is written when only one metric improves —
                # confirm this strict AND is intended.
                if results['loss'] < best_results['loss'] and results['f1'] > best_results['f1']:
                    logging.info(
                        f"Saving best model... Domain:{self.training_domain} F1:{results['f1']} and Test Loss:{results['loss']}")
                    best_results['loss'] = results['loss']
                    best_results['accuracy'] = results['accuracy']
                    best_results['f1'] = results['f1']
                    best_results['recall'] = results['recall']
                    best_results['precision'] = results['precision']
                    model_dir = os.path.join(
                        self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models")
                    # Fix: torch.save does not create parent directories, so
                    # the first save would raise FileNotFoundError otherwise.
                    os.makedirs(model_dir, exist_ok=True)
                    torch.save(self.model.state_dict(),
                               os.path.join(model_dir, f'{self.training_domain}.pt'))
                else:
                    logging.info(f"Not saving model... F1:{results['f1']} and Test Loss:{results['loss']}")
            logging.info(f"Epoch {epoch} Training Loss: {training_loss}")
            if training_loss < 0.1:
                # Fix: typo "stoping" in the log message; no placeholders, so
                # a plain string (not an f-string) is used.
                logging.info("Training Loss is too low, stopping training...")
                break
        return best_results