| | """Reusable metrics functions for evaluating models |
| | """ |
| | import multiprocessing as mp |
| | from typing import List |
| |
|
| | import torch |
| | from torch.utils.data import DataLoader |
| | from transformers import default_data_collator |
| | from tqdm import tqdm |
| |
|
def get_predictions(
    model: torch.nn.Module,
    dataset: torch.utils.data.Dataset,
    batch_size: int = 64,
) -> List[List[float]]:
    """Compute model predictions (logits) for every example in `dataset`.

    The model is moved to GPU when one is available and run in eval mode
    with gradients disabled.

    Args:
        model (torch.nn.Module): Model to evaluate. Must accept
            `input_ids` and `attention_mask` keyword tensors and return
            logits as the first element of its output.
        dataset (torch.utils.data.Dataset): Dataset to get predictions for.
            Each collated batch must contain "input_ids" and
            "attention_mask" keys (transformers-style tokenized features).
        batch_size (int, optional): Batch size for the evaluation loader.
            Defaults to 64 (the previous hard-coded value).

    Returns:
        List[List[float]]: One row of logits per example, each value
        rounded to 4 decimal places, in dataset order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=default_data_collator,
        drop_last=False,  # keep the final partial batch so every example is scored
        num_workers=mp.cpu_count(),
    )
    pred_labels: List[List[float]] = []
    for batch in tqdm(loader):
        # Only the model inputs are moved to the device; labels (if any) stay put.
        inputs = {k: batch[k].to(device) for k in ["attention_mask", "input_ids"]}
        with torch.no_grad():
            outputs = model(**inputs)
        # Transformers models return logits as the first output element.
        logits = outputs[0].cpu().tolist()
        pred_labels.extend([round(value, 4) for value in row] for row in logits)

    return pred_labels