Upload metrics.py with huggingface_hub
Browse files- metrics.py +76 -13
metrics.py
CHANGED
|
@@ -16,7 +16,7 @@ from scipy.stats import bootstrap
|
|
| 16 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
| 17 |
|
| 18 |
from .artifact import Artifact
|
| 19 |
-
from .dataclass import InternalField, OptionalField
|
| 20 |
from .logging_utils import get_logger
|
| 21 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
| 22 |
from .operator import (
|
|
@@ -58,6 +58,16 @@ def nan_mean(x):
|
|
| 58 |
return np.nanmean(x)
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
class UpdateStream(StreamInstanceOperator):
|
| 62 |
update: dict
|
| 63 |
|
|
@@ -69,11 +79,7 @@ class UpdateStream(StreamInstanceOperator):
|
|
| 69 |
|
| 70 |
|
| 71 |
class Metric(Artifact):
|
| 72 |
-
|
| 73 |
-
@abstractmethod
|
| 74 |
-
def main_score(self):
|
| 75 |
-
pass
|
| 76 |
-
|
| 77 |
# Override 'prediction_type' with the expected type of predictions
|
| 78 |
# and references. Example: "List[str]", "List[Dict]"", "string".
|
| 79 |
# If left with default None, a warning will be displayed.
|
|
@@ -229,6 +235,18 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 229 |
[instance["score"]["instance"][score_name] for instance in instances]
|
| 230 |
)
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
@staticmethod
|
| 233 |
def _all_instance_scores_equal(instances, score_name):
|
| 234 |
instance_scores = [
|
|
@@ -625,13 +643,10 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 625 |
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
| 626 |
subgroup_column = None
|
| 627 |
implemented_reductions: List[str] = field(
|
| 628 |
-
default_factory=lambda: ["mean", "group_mean"]
|
| 629 |
)
|
| 630 |
|
| 631 |
-
|
| 632 |
-
@abstractmethod
|
| 633 |
-
def reduction_map(self) -> dict:
|
| 634 |
-
pass
|
| 635 |
|
| 636 |
def _validate_group_mean_reduction(self, instances: List[dict]):
|
| 637 |
"""Ensure that group_mean reduction_map is properly formatted.
|
|
@@ -739,12 +754,19 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 739 |
|
| 740 |
field_name_full_prefix = ""
|
| 741 |
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
| 742 |
-
aggregation_function =
|
| 743 |
if reduction_type == "mean":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 744 |
reduction_fields = list(set(reduction_params))
|
| 745 |
# no group reduction, so resample instances individually
|
| 746 |
scores_to_resample = instances
|
| 747 |
elif reduction_type == "group_mean":
|
|
|
|
| 748 |
self._validate_group_mean_reduction(instances=instances)
|
| 749 |
reduction_fields = (
|
| 750 |
[self.main_score]
|
|
@@ -941,6 +963,12 @@ class Accuracy(InstanceMetric):
|
|
| 941 |
return result
|
| 942 |
|
| 943 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 944 |
class UnsortedListExactMatch(InstanceMetric):
|
| 945 |
reduction_map = {"mean": ["unsorted_list_exact_match"]}
|
| 946 |
main_score = "unsorted_list_exact_match"
|
|
@@ -988,7 +1016,15 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
| 988 |
self.metric.disable_confidence_interval_calculation()
|
| 989 |
|
| 990 |
def verify(self):
|
| 991 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 992 |
|
| 993 |
def prepare(self):
|
| 994 |
super().prepare()
|
|
@@ -3266,3 +3302,30 @@ class BinaryMaxAccuracy(GlobalMetric):
|
|
| 3266 |
best_thr = thr
|
| 3267 |
|
| 3268 |
return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
| 17 |
|
| 18 |
from .artifact import Artifact
|
| 19 |
+
from .dataclass import AbstractField, InternalField, OptionalField
|
| 20 |
from .logging_utils import get_logger
|
| 21 |
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
| 22 |
from .operator import (
|
|
|
|
| 58 |
return np.nanmean(x)
|
| 59 |
|
| 60 |
|
| 61 |
+
def nan_max(x):
    """Return the maximum of x, ignoring NaN values.

    If ALL values in x are NaN, np.nanmax returns NaN but emits a
    RuntimeWarning about an all-NaN slice; the NaN result is the desired
    behavior here, so that warning is suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmax(x)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
class UpdateStream(StreamInstanceOperator):
|
| 72 |
update: dict
|
| 73 |
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
class Metric(Artifact):
|
| 82 |
+
main_score: str = AbstractField()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Override 'prediction_type' with the expected type of predictions
|
| 84 |
# and references. Example: "List[str]", "List[Dict]"", "string".
|
| 85 |
# If left with default None, a warning will be displayed.
|
|
|
|
| 235 |
[instance["score"]["instance"][score_name] for instance in instances]
|
| 236 |
)
|
| 237 |
|
| 238 |
+
@staticmethod
def max_item_scores(instances: List[dict], score_name: str):
    """Calculate max of a set of instance scores (given by score_name), omitting NaN values.

    Args:
        instances: list of dicts of each instance's instance scores.
        score_name: score field name to compute the max for.
    """
    return nan_max(
        [instance["score"]["instance"][score_name] for instance in instances]
    )
|
| 249 |
+
|
| 250 |
@staticmethod
|
| 251 |
def _all_instance_scores_equal(instances, score_name):
|
| 252 |
instance_scores = [
|
|
|
|
| 643 |
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
| 644 |
subgroup_column = None
|
| 645 |
implemented_reductions: List[str] = field(
|
| 646 |
+
default_factory=lambda: ["mean", "group_mean", "max"]
|
| 647 |
)
|
| 648 |
|
| 649 |
+
reduction_map: Dict[str, List[str]] = AbstractField()
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
def _validate_group_mean_reduction(self, instances: List[dict]):
|
| 652 |
"""Ensure that group_mean reduction_map is properly formatted.
|
|
|
|
| 754 |
|
| 755 |
field_name_full_prefix = ""
|
| 756 |
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
| 757 |
+
aggregation_function = None
|
| 758 |
if reduction_type == "mean":
|
| 759 |
+
aggregation_function = self.average_item_scores
|
| 760 |
+
reduction_fields = list(set(reduction_params))
|
| 761 |
+
# no group reduction, so resample instances individually
|
| 762 |
+
scores_to_resample = instances
|
| 763 |
+
elif reduction_type == "max":
|
| 764 |
+
aggregation_function = self.max_item_scores
|
| 765 |
reduction_fields = list(set(reduction_params))
|
| 766 |
# no group reduction, so resample instances individually
|
| 767 |
scores_to_resample = instances
|
| 768 |
elif reduction_type == "group_mean":
|
| 769 |
+
aggregation_function = self.average_item_scores
|
| 770 |
self._validate_group_mean_reduction(instances=instances)
|
| 771 |
reduction_fields = (
|
| 772 |
[self.main_score]
|
|
|
|
| 963 |
return result
|
| 964 |
|
| 965 |
|
| 966 |
+
class MaxAccuracy(Accuracy):
    """Accuracy variant whose global score is the best (maximum) accuracy achieved by any single instance."""

    reduction_map = {"max": ["accuracy"]}
|
| 970 |
+
|
| 971 |
+
|
| 972 |
class UnsortedListExactMatch(InstanceMetric):
|
| 973 |
reduction_map = {"mean": ["unsorted_list_exact_match"]}
|
| 974 |
main_score = "unsorted_list_exact_match"
|
|
|
|
| 1016 |
self.metric.disable_confidence_interval_calculation()
|
| 1017 |
|
| 1018 |
def verify(self):
    """Validate that this pipeline is fully configured before use.

    Raises:
        AssertionError: if 'metric' or 'main_score' is unset, or if
            'metric' is not a Metric instance.
    """
    assert (
        self.metric is not None
    ), f"'metric' is not set in {self.get_metric_name()}"
    assert (
        self.main_score is not None
    ), f"'main_score' is not set in {self.get_metric_name()}"
    # Report the actual type in the failure message; the original printed the
    # value itself under a "type" label.
    assert isinstance(
        self.metric, Metric
    ), f"'metric' is not set to a Metric class in {self.get_metric_name()} (type: {type(self.metric)})"
|
| 1028 |
|
| 1029 |
def prepare(self):
|
| 1030 |
super().prepare()
|
|
|
|
| 3302 |
best_thr = thr
|
| 3303 |
|
| 3304 |
return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
|
| 3305 |
+
|
| 3306 |
+
|
| 3307 |
+
KO_ERROR_MESSAGE = """
Additional dependencies required. To install them, run:
`pip install "sacrebleu[ko]"`.

For MacOS: If an error about 'mecab-config' shows up during installation, one should run:

`brew install mecab`
`pip install "sacrebleu[ko]"`

"""


class NormalizedSacrebleu(HuggingfaceMetric):
    """SacreBLEU via the HuggingFace 'sacrebleu' metric.

    scale/scaled_fields presumably rescale sacrebleu's 0-100 scores into
    0-1 — confirm against HuggingfaceMetric's handling of these fields.
    """

    hf_metric_name = "sacrebleu"
    hf_main_score = "score"
    prediction_type = "str"
    main_score = "sacrebleu"
    scale = 100.0
    scaled_fields = ["sacrebleu", "precisions"]
    hf_additional_input_fields_pass_one_value = ["tokenize"]
    # Korean tokenization needs extra packages; the message guides installation.
    _requirements_list = {
        "mecab_ko": KO_ERROR_MESSAGE,
        "mecab_ko_dic": KO_ERROR_MESSAGE,
    }
|