# src/task/task.py
import logging
from enum import Enum
from typing import Dict, List, Optional, Tuple

from src.metrics.metric_factory import metric_factory
from src.dataset.datasets_data import datasets


class TaskType(Enum):
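    """Enumeration of the supported task types."""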
    GENERATIVE = 0
    INFERENCE = 1


class Task:
"""
Class representing a task to be executed.
:param: name (str): The name of the task.
:param: metric (str): The name of the metric to use.
:param: ground_truths_column_name (str): The ground truths column name in the dataset.
"""
    def __init__(
        self,
        task_name: str,
        metric: str,
        task_type: TaskType,
    ) -> None:
        self._metric_name = metric
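        # Build the metric computer once at construction time so that
        # compute() can reuse it on every call.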
        self._metric_computer = metric_factory(metric_name=self.metric_name)
        self.task_name = task_name
        self.dataset = datasets[task_name]
        self.task_type = task_type

    @property
    def metric_name(self) -> str:
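        """The name of the metric used for this task."""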
        return self._metric_name

    def compute(self, predictions: Optional[List]) -> Tuple[Dict, Optional[str]]:
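        """
        Compute the metric over the given predictions.

        :param predictions: The predictions for the task, or ``None`` when no
            predictions were found.
        :return: A tuple of the metric scores dictionary and a warning message
            (``None`` when there is nothing to report).
        """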
        warning = None

        if predictions is None:
            # Case where we did not find any prediction for the task.
            warning = "No predictions found for this task."
            return {self.metric_name: 0.0}, warning

        sample_size = len(predictions)
        if sample_size < len(self.dataset):
            # The predictions cover only a sample of the dataset, so we align
            # them with the first `sample_size` ground truths.
            ground_truths = self.dataset.ground_truths[:sample_size]
            warning = (
                f"You provided {sample_size} predictions, while there are "
                f"{len(self.dataset)} ground truths. We computed the metric "
                f"over the first {sample_size} elements."
            )
        elif sample_size > len(self.dataset):
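            # More predictions than ground truths cannot be aligned with the
            # dataset, so we fail loudly.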
            error = "There are more predictions than ground truths."
            logging.error(error)
            raise ValueError(error)
        else:
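            # The predictions cover the whole dataset; use all the ground truths.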
            ground_truths = self.dataset.ground_truths

        metric_score = self._metric_computer.compute(
            predictions=predictions, references=ground_truths
        )

        return metric_score, warning
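

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The task and
    # metric names below are hypothetical: valid values depend on the keys of
    # src.dataset.datasets_data.datasets and on the metrics supported by
    # metric_factory.
    task = Task(
        task_name="example_task",
        metric="accuracy",
        task_type=TaskType.GENERATIVE,
    )
    scores, warning = task.compute(predictions=["a prediction"])
    if warning is not None:
        logging.warning(warning)
    print(scores)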