Spaces:
Running
Running
Commit
·
dde0ef7
1
Parent(s):
85371bb
Remove extra_sympy_mappings from pickle file
Browse files- pysr/sr.py +17 -2
pysr/sr.py
CHANGED
|
@@ -562,6 +562,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 562 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 563 |
Contents of the equation file output by the Julia backend.
|
| 564 |
|
|
|
|
|
|
|
|
|
|
| 565 |
Notes
|
| 566 |
-----
|
| 567 |
Most default parameters have been tuned over several example equations,
|
|
@@ -873,14 +876,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 873 |
from the pickled instance.
|
| 874 |
"""
|
| 875 |
state = self.__dict__
|
| 876 |
-
|
|
|
|
|
|
|
|
|
|
| 877 |
warnings.warn(
|
| 878 |
"raw_julia_state_ cannot be pickled and will be removed from the "
|
| 879 |
"serialized instance. This will prevent a `warm_start` fit of any "
|
| 880 |
"model that is deserialized via `pickle.load()`."
|
| 881 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 882 |
pickled_state = {
|
| 883 |
-
key: None if key
|
| 884 |
for key, value in state.items()
|
| 885 |
}
|
| 886 |
if ("equations_" in pickled_state) and (
|
|
|
|
| 562 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 563 |
Contents of the equation file output by the Julia backend.
|
| 564 |
|
| 565 |
+
show_pickle_warnings_ : bool
|
| 566 |
+
Whether to show warnings about what attributes can be pickled.
|
| 567 |
+
|
| 568 |
Notes
|
| 569 |
-----
|
| 570 |
Most default parameters have been tuned over several example equations,
|
|
|
|
| 876 |
from the pickled instance.
|
| 877 |
"""
|
| 878 |
state = self.__dict__
|
| 879 |
+
show_pickle_warning = not (
|
| 880 |
+
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
|
| 881 |
+
)
|
| 882 |
+
if "raw_julia_state_" in state and show_pickle_warning:
|
| 883 |
warnings.warn(
|
| 884 |
"raw_julia_state_ cannot be pickled and will be removed from the "
|
| 885 |
"serialized instance. This will prevent a `warm_start` fit of any "
|
| 886 |
"model that is deserialized via `pickle.load()`."
|
| 887 |
)
|
| 888 |
+
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
| 889 |
+
for state_key in state_keys_containing_lambdas:
|
| 890 |
+
if state[state_key] is not None and show_pickle_warning:
|
| 891 |
+
warnings.warn(
|
| 892 |
+
f"`{state_key}` cannot be pickled and will be removed from the "
|
| 893 |
+
"serialized instance. When loading the model, please redefine "
|
| 894 |
+
f"`{state_key}` at runtime."
|
| 895 |
+
)
|
| 896 |
+
state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
|
| 897 |
pickled_state = {
|
| 898 |
+
key: (None if key in state_keys_to_clear else value)
|
| 899 |
for key, value in state.items()
|
| 900 |
}
|
| 901 |
if ("equations_" in pickled_state) and (
|