Spaces:
Running
Running
Commit ·
5c0ad55
1
Parent(s): b16d9ef
Allow loading from pickle file
Browse files- pysr/sr.py +24 -3
pysr/sr.py
CHANGED
|
@@ -2061,9 +2061,9 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
|
|
| 2061 |
def load(
|
| 2062 |
equation_file,
|
| 2063 |
*,
|
| 2064 |
-
binary_operators,
|
| 2065 |
-
unary_operators,
|
| 2066 |
-
n_features_in,
|
| 2067 |
feature_names_in=None,
|
| 2068 |
selection_mask=None,
|
| 2069 |
nout=1,
|
|
@@ -2097,12 +2097,33 @@ def load(
|
|
| 2097 |
|
| 2098 |
pysr_kwargs : dict
|
| 2099 |
Any other keyword arguments to initialize the PySRRegressor object.
|
|
|
|
| 2100 |
|
| 2101 |
Returns
|
| 2102 |
-------
|
| 2103 |
model : PySRRegressor
|
| 2104 |
The model with fitted equations.
|
| 2105 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2106 |
|
| 2107 |
# TODO: copy .bkup file if exists.
|
| 2108 |
model = PySRRegressor(
|
|
|
|
| 2061 |
def load(
|
| 2062 |
equation_file,
|
| 2063 |
*,
|
| 2064 |
+
binary_operators=None,
|
| 2065 |
+
unary_operators=None,
|
| 2066 |
+
n_features_in=None,
|
| 2067 |
feature_names_in=None,
|
| 2068 |
selection_mask=None,
|
| 2069 |
nout=1,
|
|
|
|
| 2097 |
|
| 2098 |
pysr_kwargs : dict
|
| 2099 |
Any other keyword arguments to initialize the PySRRegressor object.
|
| 2100 |
+
These will overwrite those stored in the pickle file.
|
| 2101 |
|
| 2102 |
Returns
|
| 2103 |
-------
|
| 2104 |
model : PySRRegressor
|
| 2105 |
The model with fitted equations.
|
| 2106 |
"""
|
| 2107 |
+
# Try to load model from <equation_file>.pkl
|
| 2108 |
+
print(f"Checking if {equation_file}.pkl exists...")
|
| 2109 |
+
if os.path.exists(str(equation_file) + ".pkl"):
|
| 2110 |
+
assert binary_operators is None
|
| 2111 |
+
assert unary_operators is None
|
| 2112 |
+
assert n_features_in is None
|
| 2113 |
+
with open(str(equation_file) + ".pkl", "rb") as f:
|
| 2114 |
+
model = pkl.load(f)
|
| 2115 |
+
model.set_params(**pysr_kwargs)
|
| 2116 |
+
model.refresh()
|
| 2117 |
+
return model
|
| 2118 |
+
|
| 2119 |
+
# Else, we re-create it.
|
| 2120 |
+
print(
|
| 2121 |
+
f"{equation_file}.pkl does not exist, "
|
| 2122 |
+
"so we must create the model from scratch."
|
| 2123 |
+
)
|
| 2124 |
+
assert binary_operators is not None
|
| 2125 |
+
assert unary_operators is not None
|
| 2126 |
+
assert n_features_in is not None
|
| 2127 |
|
| 2128 |
# TODO: copy .bkup file if exists.
|
| 2129 |
model = PySRRegressor(
|