Spaces:
Running
Running
Commit
·
ed19905
1
Parent(s):
bae75db
Start on state saving
Browse files- pysr/sr.py +23 -3
pysr/sr.py
CHANGED
|
@@ -636,9 +636,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 636 |
|
| 637 |
# Stored equations:
|
| 638 |
self.equations = None
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
self.multioutput = None
|
| 641 |
-
self.raw_julia_output = None
|
| 642 |
self.equation_file = equation_file
|
| 643 |
self.n_features = None
|
| 644 |
self.extra_sympy_mappings = extra_sympy_mappings
|
|
@@ -654,7 +656,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 654 |
self.surface_parameters = [
|
| 655 |
"model_selection",
|
| 656 |
"multioutput",
|
| 657 |
-
"raw_julia_output",
|
| 658 |
"equation_file",
|
| 659 |
"n_features",
|
| 660 |
"extra_sympy_mappings",
|
|
@@ -1046,6 +1047,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1046 |
float(weightDoNothing),
|
| 1047 |
]
|
| 1048 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1049 |
options = Main.Options(
|
| 1050 |
binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
|
| 1051 |
unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
|
|
@@ -1085,6 +1101,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1085 |
optimizer_iterations=self.params["optimizer_iterations"],
|
| 1086 |
perturbationFactor=self.params["perturbationFactor"],
|
| 1087 |
annealing=self.params["annealing"],
|
|
|
|
| 1088 |
)
|
| 1089 |
|
| 1090 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
|
@@ -1106,7 +1123,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1106 |
|
| 1107 |
cprocs = 0 if multithreading else procs
|
| 1108 |
|
| 1109 |
-
|
|
|
|
|
|
|
| 1110 |
Main.X,
|
| 1111 |
Main.y,
|
| 1112 |
weights=Main.weights,
|
|
@@ -1119,6 +1138,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1119 |
options=options,
|
| 1120 |
numprocs=int(cprocs),
|
| 1121 |
multithreading=bool(multithreading),
|
|
|
|
| 1122 |
)
|
| 1123 |
|
| 1124 |
self.variable_names = variable_names
|
|
|
|
| 636 |
|
| 637 |
# Stored equations:
|
| 638 |
self.equations = None
|
| 639 |
+
self.params_hash = None
|
| 640 |
+
self.raw_julia_state = None
|
| 641 |
+
self.raw_julia_hof = None
|
| 642 |
|
| 643 |
self.multioutput = None
|
|
|
|
| 644 |
self.equation_file = equation_file
|
| 645 |
self.n_features = None
|
| 646 |
self.extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
| 656 |
self.surface_parameters = [
|
| 657 |
"model_selection",
|
| 658 |
"multioutput",
|
|
|
|
| 659 |
"equation_file",
|
| 660 |
"n_features",
|
| 661 |
"extra_sympy_mappings",
|
|
|
|
| 1047 |
float(weightDoNothing),
|
| 1048 |
]
|
| 1049 |
|
| 1050 |
+
all_params = {
|
| 1051 |
+
**{k: self.__getattribute__(k) for k in self.surface_parameters}
|
| 1052 |
+
** self.params
|
| 1053 |
+
}
|
| 1054 |
+
if self.params_hash is not None:
|
| 1055 |
+
if hash(all_params) != self.params_hash:
|
| 1056 |
+
warnings.warn(
|
| 1057 |
+
"Warning: PySR options have changed since the last run. "
|
| 1058 |
+
"This is experimental and may not work. "
|
| 1059 |
+
"For example, if the operators change, or even their order,",
|
| 1060 |
+
" the saved equations will be in the wrong format."
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
self.params_hash = hash(all_params)
|
| 1064 |
+
|
| 1065 |
options = Main.Options(
|
| 1066 |
binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
|
| 1067 |
unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
|
|
|
|
| 1101 |
optimizer_iterations=self.params["optimizer_iterations"],
|
| 1102 |
perturbationFactor=self.params["perturbationFactor"],
|
| 1103 |
annealing=self.params["annealing"],
|
| 1104 |
+
stateReturn=True, # Required for state saving.
|
| 1105 |
)
|
| 1106 |
|
| 1107 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
|
|
|
| 1123 |
|
| 1124 |
cprocs = 0 if multithreading else procs
|
| 1125 |
|
| 1126 |
+
# Julia return value:
|
| 1127 |
+
# state = (returnPops, hallOfFame)
|
| 1128 |
+
self.raw_julia_state, self.raw_julia_hof = Main.EquationSearch(
|
| 1129 |
Main.X,
|
| 1130 |
Main.y,
|
| 1131 |
weights=Main.weights,
|
|
|
|
| 1138 |
options=options,
|
| 1139 |
numprocs=int(cprocs),
|
| 1140 |
multithreading=bool(multithreading),
|
| 1141 |
+
saved_state=self.raw_julia_state,
|
| 1142 |
)
|
| 1143 |
|
| 1144 |
self.variable_names = variable_names
|