Spaces:
Running
Running
Commit
·
43bc86a
1
Parent(s):
406ae3e
Process params in __init__ instead of fit
Browse files- pysr/sr.py +6 -13
pysr/sr.py
CHANGED
|
@@ -754,6 +754,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 754 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
| 755 |
)
|
| 756 |
|
|
|
|
|
|
|
| 757 |
def __repr__(self):
|
| 758 |
"""
|
| 759 |
Prints all current equations fitted by the model.
|
|
@@ -858,20 +860,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 858 |
f"{self.model_selection} is not a valid model selection strategy."
|
| 859 |
)
|
| 860 |
|
| 861 |
-
def
|
| 862 |
"""
|
| 863 |
Perform validation on the parameters defined in init for the
|
| 864 |
-
dataset specified in :term`fit
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
----------
|
| 868 |
-
n_samples : int
|
| 869 |
-
Number of samples in the dataset to be fitted.
|
| 870 |
-
|
| 871 |
-
Returns
|
| 872 |
-
-------
|
| 873 |
-
self : object
|
| 874 |
-
Reference to `self` with validated parameters.
|
| 875 |
|
| 876 |
Raises
|
| 877 |
------
|
|
@@ -1406,7 +1400,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
| 1406 |
self.raw_julia_state_ = None
|
| 1407 |
|
| 1408 |
# Parameter input validation (for parameters defined in __init__)
|
| 1409 |
-
self._validate_params(n_samples=X.shape[0])
|
| 1410 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
| 1411 |
X, y, Xresampled, variable_names
|
| 1412 |
)
|
|
|
|
| 754 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
| 755 |
)
|
| 756 |
|
| 757 |
+
self._process_params()
|
| 758 |
+
|
| 759 |
def __repr__(self):
|
| 760 |
"""
|
| 761 |
Prints all current equations fitted by the model.
|
|
|
|
| 860 |
f"{self.model_selection} is not a valid model selection strategy."
|
| 861 |
)
|
| 862 |
|
| 863 |
+
def _process_params(self):
|
| 864 |
"""
|
| 865 |
Perform validation on the parameters defined in init for the
|
| 866 |
+
dataset specified in :term`fit`, and update them if necessary.
|
| 867 |
+
For example, this will change :param`binary_operators`
|
| 868 |
+
into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
|
| 870 |
Raises
|
| 871 |
------
|
|
|
|
| 1400 |
self.raw_julia_state_ = None
|
| 1401 |
|
| 1402 |
# Parameter input validation (for parameters defined in __init__)
|
|
|
|
| 1403 |
X, y, Xresampled, variable_names = self._validate_fit_params(
|
| 1404 |
X, y, Xresampled, variable_names
|
| 1405 |
)
|