import logging

import evaluate
import torch
from tqdm import tqdm


class Tester:
    """Evaluate a binary classifier on one or more named datasets.

    For every dataset the class reports accuracy, F1, precision, recall
    (computed with the Hugging Face ``evaluate`` metrics) and the mean
    per-batch ``BCELoss``.
    """

    # Keys of every per-dataset result dict, in reporting order.
    _METRIC_NAMES = ('accuracy', 'f1', 'precision', 'recall', 'loss')

    def __init__(self, test_dataset_dict, model, train_domain) -> None:
        """Store the datasets and model, and load the metric objects.

        Args:
            test_dataset_dict: Mapping from a dataset/domain name to an
                iterable of batches. Each batch is indexed with the keys
                ``'input_ids'``, ``'attention_mask'`` and ``'label'``.
            model: Callable invoked as
                ``model(input_ids, attention_mask=attention_mask)``.
            train_domain: Name of the domain the model was trained on; it
                is excluded from the results returned by ``validate``.
        """
        self.test_dataset_dict = test_dataset_dict
        self.model = model
        self.train_domain = train_domain
        self.accuracy = evaluate.load("accuracy")
        self.f1 = evaluate.load("f1")
        self.precision = evaluate.load("precision")
        self.recall = evaluate.load("recall")
        # NOTE(review): BCELoss expects probabilities in [0, 1], so the model
        # output (called "logits" below) is presumably already passed through
        # a sigmoid -- confirm against the model definition.
        self.loss_fn = torch.nn.BCELoss()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

    def _prepare_model(self):
        """Switch the model to eval mode and move it to ``self.device``."""
        self.model.eval()
        self.model.to(self.device)

    def _evaluate(self, dataset, aggregate=False):
        """Run the model over ``dataset`` and compute metrics and mean loss.

        Args:
            dataset: Iterable of batches that also supports ``len()``.
            aggregate: When true, the model output is reduced with
                ``_bagging`` before the loss and predictions are computed.

        Returns:
            Tuple ``(accuracy, f1, precision, recall, mean_loss)`` where
            ``mean_loss`` is the sum of per-batch losses divided by
            ``len(dataset)``.
        """
        metrics = (self.accuracy, self.f1, self.precision, self.recall)
        total_loss = 0
        with torch.no_grad():
            for batch in tqdm(dataset):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                logits = self.model(
                    input_ids, attention_mask=attention_mask).squeeze(dim=1)
                if aggregate:
                    logits = self._bagging(logits)
                loss = self.loss_fn(logits, labels.float())

                # Outputs above 0.5 are class 1, everything else class 0.
                predictions = (logits > 0.5).long().cpu()
                references = labels.cpu()
                for metric in metrics:
                    metric.add_batch(
                        predictions=predictions, references=references)

                total_loss += loss.item()

        accuracy = self.accuracy.compute()['accuracy']
        f1 = self.f1.compute()['f1']
        precision = self.precision.compute()['precision']
        recall = self.recall.compute()['recall']
        return accuracy, f1, precision, recall, total_loss / len(dataset)

    def _validate(self, test_dataset):
        """Evaluate ``test_dataset`` using the raw model output."""
        return self._evaluate(test_dataset, aggregate=False)

    def validate(self):
        """Evaluate every dataset and average over the non-train domains.

        The model is put in eval mode and moved to ``self.device`` first.
        The entry for ``self.train_domain`` (if present) is removed from the
        results after evaluation.

        Returns:
            ``(results, average_results)`` where ``results`` maps each
            remaining domain to its metric dict and ``average_results`` maps
            each metric name to its mean over those domains.

            NOTE(review): when no domain remains after removing the train
            domain, only the (empty) ``results`` dict is returned instead of
            a tuple. Kept as-is for backward compatibility; callers that
            unpack two values must handle this case.
        """
        self._prepare_model()

        results = {}
        for domain, dataset in self.test_dataset_dict.items():
            logging.info(f"Testing {domain} domain...")
            scores = self._validate(dataset)
            results[domain] = dict(zip(self._METRIC_NAMES, scores))

        # The train domain does not take part in the reported results.
        results.pop(self.train_domain, None)

        if not results:
            logging.info("Only one domain to test, returning results")
            return results

        # Average each metric over all domains except the train domain.
        average_results = {
            metric: sum(scores[metric] for scores in results.values())
            / len(results)
            for metric in self._METRIC_NAMES
        }
        return results, average_results

    # TODO: migrate this method to Model.
    def _bagging(self, logits):
        """Average the model outputs over their first dimension."""
        return torch.mean(logits, dim=0)

    def _test(self, test_dataset):
        """Evaluate ``test_dataset`` using the ``_bagging``-averaged output."""
        return self._evaluate(test_dataset, aggregate=True)

    def test(self):
        """Evaluate every dataset with bagged outputs and log the results.

        The model is put in eval mode and moved to ``self.device`` first, in
        the same way ``validate`` does, because the batches are moved to
        ``self.device`` during evaluation.

        Returns:
            Dict mapping each dataset name to its metric dict with the keys
            ``'accuracy'``, ``'f1'``, ``'precision'``, ``'recall'`` and
            ``'loss'``.
        """
        self._prepare_model()

        results = {}
        for test_set, dataset in self.test_dataset_dict.items():
            logging.info(f"Testing {test_set} dataset")
            scores = self._test(dataset)
            results[test_set] = dict(zip(self._METRIC_NAMES, scores))
            logging.info(f"Results for {test_set} dataset: {results[test_set]}")
        return results