Spaces:
Sleeping
Sleeping
Nathan
commited on
update call to `mean_squared_error` to comply with sklearn v1.6
Browse filesThe `squared` argument of the scimitar-learn method `mean_squared_error` is depreciated since v1.4 and has been removed in v1.6, making this evaluation module incompatible with the latest version.
This PR makes it call the `root_mean_squared_error` when `squared=True`, solving this error.
For reference: https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.mean_squared_error.html (docs of v1.5)
cc
@lvwerra
mse.py
CHANGED
|
@@ -14,7 +14,7 @@
|
|
| 14 |
"""MSE - Mean Squared Error Metric"""
|
| 15 |
|
| 16 |
import datasets
|
| 17 |
-
from sklearn.metrics import mean_squared_error
|
| 18 |
|
| 19 |
import evaluate
|
| 20 |
|
|
@@ -112,8 +112,13 @@ class Mse(evaluate.Metric):
|
|
| 112 |
|
| 113 |
def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
return {"mse": mse}
|
|
|
|
| 14 |
"""MSE - Mean Squared Error Metric"""
|
| 15 |
|
| 16 |
import datasets
|
| 17 |
+
from sklearn.metrics import mean_squared_error, root_mean_squared_error
|
| 18 |
|
| 19 |
import evaluate
|
| 20 |
|
|
|
|
| 112 |
|
| 113 |
def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
|
| 114 |
|
| 115 |
+
if squared:
|
| 116 |
+
mse = mean_squared_error(
|
| 117 |
+
references, predictions, sample_weight=sample_weight, multioutput=multioutput
|
| 118 |
+
)
|
| 119 |
+
else:
|
| 120 |
+
mse = root_mean_squared_error(
|
| 121 |
+
references, predictions, sample_weight=sample_weight, multioutput=multioutput
|
| 122 |
+
)
|
| 123 |
|
| 124 |
return {"mse": mse}
|