Spaces:
Running
Running
Commit
·
3da0df5
1
Parent(s):
03d5a42
Fix pickling for multi-output
Browse files- pysr/sr.py +18 -6
pysr/sr.py
CHANGED
|
@@ -884,12 +884,24 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 884 |
if "equations_" in pickled_state:
|
| 885 |
pickled_state["output_torch_format"] = False
|
| 886 |
pickled_state["output_jax_format"] = False
|
| 887 |
-
|
| 888 |
-
["
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
pickled_state["equations_"]
|
| 892 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
return pickled_state
|
| 894 |
|
| 895 |
@property
|
|
|
|
| 884 |
if "equations_" in pickled_state:
|
| 885 |
pickled_state["output_torch_format"] = False
|
| 886 |
pickled_state["output_jax_format"] = False
|
| 887 |
+
if self.nout_ == 1:
|
| 888 |
+
pickled_columns = ~pickled_state["equations_"].columns.isin(
|
| 889 |
+
["jax_format", "torch_format"]
|
| 890 |
+
)
|
| 891 |
+
pickled_state["equations_"] = (
|
| 892 |
+
pickled_state["equations_"].loc[:, pickled_columns].copy()
|
| 893 |
+
)
|
| 894 |
+
else:
|
| 895 |
+
pickled_columns = [
|
| 896 |
+
~dataframe.columns.isin(["jax_format", "torch_format"])
|
| 897 |
+
for dataframe in pickled_state["equations_"]
|
| 898 |
+
]
|
| 899 |
+
pickled_state["equations_"] = [
|
| 900 |
+
dataframe.loc[:, signle_pickled_columns]
|
| 901 |
+
for dataframe, signle_pickled_columns in zip(
|
| 902 |
+
pickled_state["equations_"], pickled_columns
|
| 903 |
+
)
|
| 904 |
+
]
|
| 905 |
return pickled_state
|
| 906 |
|
| 907 |
@property
|