# NOTE: The three lines below are residue from the hosting page (file-upload
# metadata), not Python code; kept as comments so the file parses.
# LCA-PORVID's picture
# Upload 34 files
# ebdb5af verified
import torch
from tqdm import tqdm
import logging
from pt_variety_identifier.src.bert.model import LanguageIdentfier
from pt_variety_identifier.src.bert.tester import Tester
import math
import os
class Trainer:
    """Trains a binary ``LanguageIdentfier`` model with BCE loss.

    Optionally validates after every epoch via a ``Tester`` and checkpoints
    the best model (by validation loss AND f1) to
    ``<CURRENT_PATH>/out/<CURRENT_TIME>/models/<training_domain>.pt``.
    """

    def __init__(self, train_dataset, params, validation_dataset_dict=None) -> None:
        """Build the model, optimizer, scheduler, and optional validator.

        Args:
            train_dataset: Iterable of batches, each a dict with
                ``input_ids``, ``attention_mask`` and ``label`` tensors.
            params: Dict with keys ``model_name``, ``epochs``,
                ``early_stoping``, ``device``, ``CURRENT_PATH``,
                ``CURRENT_TIME`` and optionally ``training_domain``.
            validation_dataset_dict: If given, a per-domain dataset dict
                used to run validation after each epoch.
        """
        self.train_dataset = train_dataset
        self.model = LanguageIdentfier(params['model_name'])
        self.epochs = params['epochs']
        self.lr = 1e-5  # fixed learning rate; not configurable via params
        self.loss_fn = torch.nn.BCELoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        self.early_stoping = params['early_stoping']
        # NOTE(review): despite the name, `early_stoping` only sizes the LR
        # scheduler's patience — no actual early stopping on validation is
        # performed anywhere in this class.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=self.early_stoping // 2, verbose=True)
        self.device = params['device']
        self.CURRENT_PATH = params['CURRENT_PATH']
        self.CURRENT_TIME = params['CURRENT_TIME']
        self.training_domain = params.get('training_domain', 'all')
        self.validator = None
        print(f"Using {self.device} device")
        if validation_dataset_dict:
            self.validator = Tester(
                test_dataset_dict=validation_dataset_dict,
                model=self.model,
                train_domain=self.training_domain,
            )

    def _epoch_iter(self):
        """Run one full pass over ``train_dataset``.

        Returns:
            Mean per-batch training loss for the epoch.
        """
        self.model.train()
        self.model.to(self.device)
        self.optimizer.zero_grad()
        total_loss = 0.0
        with torch.enable_grad():
            for batch in tqdm(self.train_dataset):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                # BCELoss needs float targets matching the model output dtype.
                labels = batch['label'].to(self.device, dtype=torch.float)
                outputs = self.model(
                    input_ids, attention_mask=attention_mask).squeeze(dim=1)
                loss = self.loss_fn(outputs, labels)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                total_loss += loss.item()
        # NOTE(review): the plateau scheduler is stepped on the *summed*
        # training loss rather than a validation metric — confirm intended.
        self.scheduler.step(total_loss)
        return total_loss / len(self.train_dataset)

    def train(self):
        """Train for ``self.epochs`` epochs, tracking the best validation run.

        Returns:
            Dict with the best observed ``f1``, ``accuracy``, ``precision``,
            ``recall`` and ``loss`` (``-inf``/``inf`` sentinels if no
            validator was configured or no epoch improved).
        """
        logging.info(f"Training model in {self.device}...")
        best_results = {
            'f1': -math.inf,
            'accuracy': -math.inf,
            'precision': -math.inf,
            'recall': -math.inf,
            'loss': math.inf
        }
        for epoch in tqdm(range(self.epochs)):
            training_loss = self._epoch_iter()
            if self.validator:
                results = self.validator.validate()
                logging.info(f"Results for {self.training_domain} domain: {results} Epoch: {epoch}")
                # NOTE(review): checkpointing requires BOTH loss and f1 to
                # improve simultaneously; epochs improving only one metric
                # are skipped — confirm this strict AND is intended.
                if results['loss'] < best_results['loss'] and results['f1'] > best_results['f1']:
                    logging.info(
                        f"Saving best model... Domain:{self.training_domain} F1:{results['f1']} and Test Loss:{results['loss']}")
                    best_results['loss'] = results['loss']
                    best_results['accuracy'] = results['accuracy']
                    best_results['f1'] = results['f1']
                    best_results['recall'] = results['recall']
                    best_results['precision'] = results['precision']
                    # Fix: create the checkpoint directory before saving —
                    # torch.save raises FileNotFoundError on a missing path.
                    save_dir = os.path.join(
                        self.CURRENT_PATH, "out", str(self.CURRENT_TIME), "models")
                    os.makedirs(save_dir, exist_ok=True)
                    torch.save(self.model.state_dict(),
                               os.path.join(save_dir, f'{self.training_domain}.pt'))
                else:
                    logging.info(f"Not saving model... F1:{results['f1']} and Test Loss:{results['loss']}")
            logging.info(f"Epoch {epoch} Training Loss: {training_loss}")
            # Hard stop once training loss drops below 0.1 (overfitting guard).
            if training_loss < 0.1:
                logging.info(f"Training Loss is too low, stoping training...")
                break
        return best_results