Spaces:
Runtime error
Runtime error
Refactor SemF1 aggregation logic and fix typo in comment
Browse files
semf1.py
CHANGED
|
@@ -422,8 +422,8 @@ class SemF1(evaluate.Metric):
|
|
| 422 |
recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
|
| 423 |
|
| 424 |
results.append(Scores(precision, recall_scores))
|
| 425 |
-
|
| 426 |
-
#
|
| 427 |
if aggregate:
|
| 428 |
mean_prec = np.mean(
|
| 429 |
[score.precision for score in results]
|
|
@@ -432,12 +432,9 @@ class SemF1(evaluate.Metric):
|
|
| 432 |
[np.array(score.recall) for score in results]
|
| 433 |
))
|
| 434 |
aggregated_score = Scores(
|
| 435 |
-
float(mean_prec),
|
| 436 |
[float(mean_recall)]
|
| 437 |
)
|
| 438 |
-
aggregated_score.f1 = float(np.mean(
|
| 439 |
-
[score.f1 for score in results]
|
| 440 |
-
))
|
| 441 |
results = aggregated_score
|
| 442 |
|
| 443 |
-
return results
|
|
|
|
| 422 |
recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
|
| 423 |
|
| 424 |
results.append(Scores(precision, recall_scores))
|
| 425 |
+
|
| 426 |
+
# run aggregation procedure
|
| 427 |
if aggregate:
|
| 428 |
mean_prec = np.mean(
|
| 429 |
[score.precision for score in results]
|
|
|
|
| 432 |
[np.array(score.recall) for score in results]
|
| 433 |
))
|
| 434 |
aggregated_score = Scores(
|
| 435 |
+
float(mean_prec),
|
| 436 |
[float(mean_recall)]
|
| 437 |
)
|
|
|
|
|
|
|
|
|
|
| 438 |
results = aggregated_score
|
| 439 |
|
| 440 |
+
return results
|
tests.py
CHANGED
|
@@ -708,5 +708,6 @@ class TestValidateInputFormat(unittest.TestCase):
|
|
| 708 |
def run_tests():
|
| 709 |
unittest.main(verbosity=2)
|
| 710 |
|
|
|
|
| 711 |
if __name__ == '__main__':
|
| 712 |
run_tests()
|
|
|
|
| 708 |
def run_tests():
|
| 709 |
unittest.main(verbosity=2)
|
| 710 |
|
| 711 |
+
|
| 712 |
if __name__ == '__main__':
|
| 713 |
run_tests()
|