first
Browse files
aucpr.py
CHANGED
|
@@ -15,6 +15,7 @@
|
|
| 15 |
|
| 16 |
import evaluate
|
| 17 |
import datasets
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
# TODO: Add BibTeX citation
|
|
@@ -70,9 +71,13 @@ class AUCPR(evaluate.Metric):
|
|
| 70 |
citation=_CITATION,
|
| 71 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 72 |
# This defines the format of each prediction and reference
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
features=datasets.Features({
|
| 74 |
-
|
| 75 |
-
|
| 76 |
}),
|
| 77 |
# Homepage of the module for documentation
|
| 78 |
homepage="http://module.homepage",
|
|
@@ -86,10 +91,11 @@ class AUCPR(evaluate.Metric):
|
|
| 86 |
# TODO: Download external resources if needed
|
| 87 |
pass
|
| 88 |
|
| 89 |
-
def _compute(self,
|
| 90 |
"""Returns the scores"""
|
| 91 |
# TODO: Compute the different scores of the module
|
| 92 |
-
|
|
|
|
| 93 |
return {
|
| 94 |
-
"
|
| 95 |
}
|
|
|
|
| 15 |
|
| 16 |
import evaluate
|
| 17 |
import datasets
|
| 18 |
+
from sklearn.metrics import precision_recall_curve, auc
|
| 19 |
|
| 20 |
|
| 21 |
# TODO: Add BibTeX citation
|
|
|
|
| 71 |
citation=_CITATION,
|
| 72 |
inputs_description=_KWARGS_DESCRIPTION,
|
| 73 |
# This defines the format of each prediction and reference
|
| 74 |
+
# features=datasets.Features({
|
| 75 |
+
# 'predictions': datasets.Value('int64'),
|
| 76 |
+
# 'references': datasets.Value('int64'),
|
| 77 |
+
# }),
|
| 78 |
features=datasets.Features({
|
| 79 |
+
"prediction_scores": datasets.Value("float"),
|
| 80 |
+
"references": datasets.Value("int32"),
|
| 81 |
}),
|
| 82 |
# Homepage of the module for documentation
|
| 83 |
homepage="http://module.homepage",
|
|
|
|
| 91 |
# TODO: Download external resources if needed
|
| 92 |
pass
|
| 93 |
|
| 94 |
+
def _compute(self, references, prediction_scores):
|
| 95 |
"""Returns the scores"""
|
| 96 |
# TODO: Compute the different scores of the module
|
| 97 |
+
precision, recall, _ = precision_recall_curve(references, prediction_scores)
|
| 98 |
+
aucpr = auc(recall, precision)
|
| 99 |
return {
|
| 100 |
+
"aucpr": aucpr,
|
| 101 |
}
|