Spaces:
Running
Running
Save raw bytes so can warm-restart in new python session
Browse files- pysr/julia_helpers.py +10 -2
- pysr/sr.py +34 -22
pysr/julia_helpers.py
CHANGED
|
@@ -22,8 +22,7 @@ import juliapkg
|
|
| 22 |
from juliacall import Main as jl
|
| 23 |
from juliacall import convert as jl_convert
|
| 24 |
|
| 25 |
-
jl.seval("using
|
| 26 |
-
PythonCall = jl.PythonCall
|
| 27 |
|
| 28 |
juliainfo = None
|
| 29 |
julia_initialized = False
|
|
@@ -63,3 +62,12 @@ def jl_array(x):
|
|
| 63 |
if x is None:
|
| 64 |
return None
|
| 65 |
return jl_convert(jl.Array, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from juliacall import Main as jl
|
| 23 |
from juliacall import convert as jl_convert
|
| 24 |
|
| 25 |
+
jl.seval("using Serialization: Serialization")
|
|
|
|
| 26 |
|
| 27 |
juliainfo = None
|
| 28 |
julia_initialized = False
|
|
|
|
| 62 |
if x is None:
|
| 63 |
return None
|
| 64 |
return jl_convert(jl.Array, x)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def jl_deserialize_s(s):
|
| 68 |
+
if s is None:
|
| 69 |
+
return s
|
| 70 |
+
buf = jl.IOBuffer()
|
| 71 |
+
jl.write(buf, jl_array(s))
|
| 72 |
+
jl.seekstart(buf)
|
| 73 |
+
return jl.Serialization.deserialize(buf)
|
pysr/sr.py
CHANGED
|
@@ -34,12 +34,11 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
|
|
| 34 |
from .export_torch import sympy2torch
|
| 35 |
from .feature_selection import run_feature_selection
|
| 36 |
from .julia_helpers import (
|
| 37 |
-
PythonCall,
|
| 38 |
_escape_filename,
|
| 39 |
_load_cluster_manager,
|
| 40 |
jl,
|
| 41 |
jl_array,
|
| 42 |
-
|
| 43 |
)
|
| 44 |
from .utils import (
|
| 45 |
_csv_filename_to_pkl_filename,
|
|
@@ -614,8 +613,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 614 |
Path to the temporary equations directory.
|
| 615 |
equation_file_ : str
|
| 616 |
Output equation file name produced by the julia backend.
|
| 617 |
-
|
| 618 |
-
The state for the julia SymbolicRegression.jl backend
|
| 619 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 620 |
Contents of the equation file output by the Julia backend.
|
| 621 |
show_pickle_warnings_ : bool
|
|
@@ -1048,22 +1047,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1048 |
serialization.
|
| 1049 |
|
| 1050 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
| 1051 |
-
`
|
| 1052 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
| 1053 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
| 1054 |
to be serialized. Note: Jax and Torch format equations are also removed
|
| 1055 |
from the pickled instance.
|
| 1056 |
"""
|
| 1057 |
state = self.__dict__
|
| 1058 |
-
show_pickle_warning = not (
|
| 1059 |
-
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
|
| 1060 |
-
)
|
| 1061 |
-
if "raw_julia_state_" in state and show_pickle_warning:
|
| 1062 |
-
warnings.warn(
|
| 1063 |
-
"raw_julia_state_ cannot be pickled and will be removed from the "
|
| 1064 |
-
"serialized instance. This will prevent a `warm_start` fit of any "
|
| 1065 |
-
"model that is deserialized via `pickle.load()`."
|
| 1066 |
-
)
|
| 1067 |
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
| 1068 |
for state_key in state_keys_containing_lambdas:
|
| 1069 |
if state[state_key] is not None and show_pickle_warning:
|
|
@@ -1072,7 +1062,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1072 |
"serialized instance. When loading the model, please redefine "
|
| 1073 |
f"`{state_key}` at runtime."
|
| 1074 |
)
|
| 1075 |
-
state_keys_to_clear =
|
| 1076 |
pickled_state = {
|
| 1077 |
key: (None if key in state_keys_to_clear else value)
|
| 1078 |
for key, value in state.items()
|
|
@@ -1122,6 +1112,20 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1122 |
)
|
| 1123 |
return self.equations_
|
| 1124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1125 |
def get_best(self, index=None):
|
| 1126 |
"""
|
| 1127 |
Get best equation using `model_selection`.
|
|
@@ -1724,7 +1728,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1724 |
# Python's garbage collection is unaware of them.
|
| 1725 |
jl._equation_search_args = (jl_X, jl_y)
|
| 1726 |
jl._equation_search_kwargs = namedtuple(
|
| 1727 |
-
"
|
| 1728 |
(
|
| 1729 |
"weights",
|
| 1730 |
"niterations",
|
|
@@ -1754,18 +1758,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1754 |
options=options,
|
| 1755 |
numprocs=cprocs,
|
| 1756 |
parallelism=parallelism,
|
| 1757 |
-
saved_state=self.
|
| 1758 |
return_state=True,
|
| 1759 |
addprocs_function=cluster_manager,
|
| 1760 |
heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
|
| 1761 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
| 1762 |
verbosity=int(self.verbosity),
|
| 1763 |
)
|
| 1764 |
-
|
| 1765 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1766 |
)
|
| 1767 |
jl._equation_search_args = None
|
| 1768 |
jl._equation_search_kwargs = None
|
|
|
|
| 1769 |
|
| 1770 |
# Set attributes
|
| 1771 |
self.equations_ = self.get_hof()
|
|
@@ -1829,10 +1841,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1829 |
Fitted estimator.
|
| 1830 |
"""
|
| 1831 |
# Init attributes that are not specified in BaseEstimator
|
| 1832 |
-
if self.warm_start and hasattr(self, "
|
| 1833 |
pass
|
| 1834 |
else:
|
| 1835 |
-
if hasattr(self, "
|
| 1836 |
warnings.warn(
|
| 1837 |
"The discovered expressions are being reset. "
|
| 1838 |
"Please set `warm_start=True` if you wish to continue "
|
|
@@ -1842,7 +1854,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1842 |
self.equations_ = None
|
| 1843 |
self.nout_ = 1
|
| 1844 |
self.selection_mask_ = None
|
| 1845 |
-
self.
|
| 1846 |
self.X_units_ = None
|
| 1847 |
self.y_units_ = None
|
| 1848 |
|
|
|
|
| 34 |
from .export_torch import sympy2torch
|
| 35 |
from .feature_selection import run_feature_selection
|
| 36 |
from .julia_helpers import (
|
|
|
|
| 37 |
_escape_filename,
|
| 38 |
_load_cluster_manager,
|
| 39 |
jl,
|
| 40 |
jl_array,
|
| 41 |
+
jl_deserialize_s,
|
| 42 |
)
|
| 43 |
from .utils import (
|
| 44 |
_csv_filename_to_pkl_filename,
|
|
|
|
| 613 |
Path to the temporary equations directory.
|
| 614 |
equation_file_ : str
|
| 615 |
Output equation file name produced by the julia backend.
|
| 616 |
+
raw_julia_state_stream_ : ndarray
|
| 617 |
+
The serialized state for the julia SymbolicRegression.jl backend (after fitting).
|
| 618 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 619 |
Contents of the equation file output by the Julia backend.
|
| 620 |
show_pickle_warnings_ : bool
|
|
|
|
| 1047 |
serialization.
|
| 1048 |
|
| 1049 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
| 1050 |
+
`raw_julia_state_stream_` attribute must be hidden from pickle. This will
|
| 1051 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
| 1052 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
| 1053 |
to be serialized. Note: Jax and Torch format equations are also removed
|
| 1054 |
from the pickled instance.
|
| 1055 |
"""
|
| 1056 |
state = self.__dict__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1057 |
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
| 1058 |
for state_key in state_keys_containing_lambdas:
|
| 1059 |
if state[state_key] is not None and show_pickle_warning:
|
|
|
|
| 1062 |
"serialized instance. When loading the model, please redefine "
|
| 1063 |
f"`{state_key}` at runtime."
|
| 1064 |
)
|
| 1065 |
+
state_keys_to_clear = state_keys_containing_lambdas
|
| 1066 |
pickled_state = {
|
| 1067 |
key: (None if key in state_keys_to_clear else value)
|
| 1068 |
for key, value in state.items()
|
|
|
|
| 1112 |
)
|
| 1113 |
return self.equations_
|
| 1114 |
|
| 1115 |
+
@property
|
| 1116 |
+
def julia_state(self):
|
| 1117 |
+
return jl_deserialize_s(self.raw_julia_state_stream_)
|
| 1118 |
+
|
| 1119 |
+
@property
|
| 1120 |
+
def raw_julia_state_(self):
|
| 1121 |
+
warnings.warn(
|
| 1122 |
+
"PySRRegressor.raw_julia_state_ is now deprecated. "
|
| 1123 |
+
"Please use PySRRegressor.julia_state instead, or `raw_julia_state_stream_` "
|
| 1124 |
+
"for the raw stream of bytes.",
|
| 1125 |
+
FutureWarning,
|
| 1126 |
+
)
|
| 1127 |
+
return self.julia_state
|
| 1128 |
+
|
| 1129 |
def get_best(self, index=None):
|
| 1130 |
"""
|
| 1131 |
Get best equation using `model_selection`.
|
|
|
|
| 1728 |
# Python's garbage collection is unaware of them.
|
| 1729 |
jl._equation_search_args = (jl_X, jl_y)
|
| 1730 |
jl._equation_search_kwargs = namedtuple(
|
| 1731 |
+
"equation_search_kwargs",
|
| 1732 |
(
|
| 1733 |
"weights",
|
| 1734 |
"niterations",
|
|
|
|
| 1758 |
options=options,
|
| 1759 |
numprocs=cprocs,
|
| 1760 |
parallelism=parallelism,
|
| 1761 |
+
saved_state=self.julia_state,
|
| 1762 |
return_state=True,
|
| 1763 |
addprocs_function=cluster_manager,
|
| 1764 |
heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
|
| 1765 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
| 1766 |
verbosity=int(self.verbosity),
|
| 1767 |
)
|
| 1768 |
+
output_stream = jl.seval(
|
| 1769 |
+
"""
|
| 1770 |
+
let args = deepcopy(_equation_search_args), kwargs=deepcopy(_equation_search_kwargs)
|
| 1771 |
+
out = SymbolicRegression.equation_search(args...; kwargs...)
|
| 1772 |
+
buf = IOBuffer()
|
| 1773 |
+
Serialization.serialize(buf, out)
|
| 1774 |
+
take!(buf)
|
| 1775 |
+
end
|
| 1776 |
+
"""
|
| 1777 |
)
|
| 1778 |
jl._equation_search_args = None
|
| 1779 |
jl._equation_search_kwargs = None
|
| 1780 |
+
self.raw_julia_state_stream_ = np.array(output_stream).copy()
|
| 1781 |
|
| 1782 |
# Set attributes
|
| 1783 |
self.equations_ = self.get_hof()
|
|
|
|
| 1841 |
Fitted estimator.
|
| 1842 |
"""
|
| 1843 |
# Init attributes that are not specified in BaseEstimator
|
| 1844 |
+
if self.warm_start and hasattr(self, "raw_julia_state_stream_"):
|
| 1845 |
pass
|
| 1846 |
else:
|
| 1847 |
+
if hasattr(self, "raw_julia_state_stream_"):
|
| 1848 |
warnings.warn(
|
| 1849 |
"The discovered expressions are being reset. "
|
| 1850 |
"Please set `warm_start=True` if you wish to continue "
|
|
|
|
| 1854 |
self.equations_ = None
|
| 1855 |
self.nout_ = 1
|
| 1856 |
self.selection_mask_ = None
|
| 1857 |
+
self.raw_julia_state_stream_ = None
|
| 1858 |
self.X_units_ = None
|
| 1859 |
self.y_units_ = None
|
| 1860 |
|