Spaces:
Running
Running
| import logging | |
| from enum import Enum | |
| from typing import Dict, Union, Tuple, List | |
| from src.metrics.metric_factory import metric_factory | |
| from src.dataset.datasets_data import datasets | |
class TaskType(Enum):
    """Enumeration of the kinds of task a ``Task`` can be tagged with."""

    # Member values are part of the public contract; keep them stable.
    GENERATIVE = 0
    INFERENCE = 1
class Task:
    """
    Class representing a task to be executed.

    :param: task_name (str): The name of the task. It is also the key used to look up
        the task's dataset in ``datasets``.
    :param: metric (str): The name of the metric to use. It is handed to
        ``metric_factory`` to build the object that computes the score.
    :param: task_type (TaskType): The kind of task.
    """

    def __init__(
        self,
        task_name: str,
        metric: str,
        task_type: "TaskType",
    ) -> None:
        self._metric_name = metric
        # Hand the factory the metric name string itself. ``self.metric_name`` is a
        # method, so referencing it without calling it would pass a bound method.
        self._metric_computer = metric_factory(metric_name=self._metric_name)
        self.task_name = task_name
        self.dataset = datasets[task_name]
        self.task_type = task_type

    def metric_name(self) -> str:
        """Return the name of the metric used by this task."""
        return self._metric_name

    def compute(self, predictions: Union[List, None]) -> Tuple[Dict, Union[str, None]]:
        """
        Score ``predictions`` against the ground truths of the task's dataset.

        :param: predictions (Union[List, None]): The predictions to score, or ``None``
            when no prediction was found for the task.
        :return: A tuple ``(metric_score, warning)``. ``metric_score`` is the value
            returned by the metric computer, or ``{<metric name>: 0.0}`` when
            ``predictions`` is ``None``. ``warning`` is a human-readable message when
            predictions are missing or cover only part of the dataset, else ``None``.
        :raises ValueError: If there are more predictions than ground truths.
        """
        if predictions is None:
            # Case where we did not find any prediction for the task.
            return {self._metric_name: 0.0}, "No predictions found for this task."

        sample_size = len(predictions)
        dataset_size = len(self.dataset)

        if sample_size > dataset_size:
            error = "There are more predictions than ground truths."
            logging.error(error)
            raise ValueError(error)

        warning = None
        ground_truths = self.dataset.ground_truths
        if sample_size < dataset_size:
            # Only a sample of the predictions was provided: score it against the
            # leading ground truths. Slice the ground truths (not the dataset object)
            # so both branches feed the metric the same kind of references.
            # NOTE(review): assumes ``ground_truths`` is sliceable and that
            # ``len(self.dataset)`` equals its length -- confirm against the dataset class.
            ground_truths = ground_truths[:sample_size]
            warning = (
                f"Your prediction size is of '{sample_size}', while the ground truths size is "
                f"of '{dataset_size}'. We computed the metric over the first "
                f"{sample_size} elements."
            )

        metric_score = self._metric_computer.compute(
            predictions=predictions, references=ground_truths
        )
        return metric_score, warning