Spaces:
Running
Running
Commit
·
44dcbea
1
Parent(s):
4c9fe98
Allow functional versions of early stop condition
Browse files- pysr/sr.py +8 -3
pysr/sr.py
CHANGED
|
@@ -312,8 +312,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 312 |
annealing : bool, default=True
|
| 313 |
Whether to use annealing. You should (and it is default).
|
| 314 |
|
| 315 |
-
early_stop_condition : float, default=None
|
| 316 |
-
Stop the search early if this loss is reached.
|
|
|
|
|
|
|
|
|
|
| 317 |
|
| 318 |
ncyclesperiteration : int, default=550
|
| 319 |
Number of total mutations to run, per 10 samples of the
|
|
@@ -971,6 +974,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 971 |
|
| 972 |
# 'Mutable' parameter validation
|
| 973 |
buffer_available = "buffer" in sys.stdout.__dir__()
|
|
|
|
| 974 |
modifiable_params = {
|
| 975 |
"binary_operators": "+ * - /".split(" "),
|
| 976 |
"unary_operators": [],
|
|
@@ -1308,6 +1312,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1308 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
| 1309 |
|
| 1310 |
custom_loss = Main.eval(self.loss)
|
|
|
|
| 1311 |
|
| 1312 |
mutationWeights = [
|
| 1313 |
float(self.weight_mutate_constant),
|
|
@@ -1369,7 +1374,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1369 |
crossoverProbability=self.crossover_probability,
|
| 1370 |
skip_mutation_failures=self.skip_mutation_failures,
|
| 1371 |
max_evals=self.max_evals,
|
| 1372 |
-
earlyStopCondition=
|
| 1373 |
seed=seed,
|
| 1374 |
deterministic=self.deterministic,
|
| 1375 |
)
|
|
|
|
| 312 |
annealing : bool, default=True
|
| 313 |
Whether to use annealing. You should (and it is default).
|
| 314 |
|
| 315 |
+
early_stop_condition : { float | str }, default=None
|
| 316 |
+
Stop the search early if this loss is reached. You may also
|
| 317 |
+
pass a string containing a Julia function which
|
| 318 |
+
takes a loss and complexity as input, for example:
|
| 319 |
+
`"f(loss, complexity) = (loss < 0.1) && (complexity < 10)"`.
|
| 320 |
|
| 321 |
ncyclesperiteration : int, default=550
|
| 322 |
Number of total mutations to run, per 10 samples of the
|
|
|
|
| 974 |
|
| 975 |
# 'Mutable' parameter validation
|
| 976 |
buffer_available = "buffer" in sys.stdout.__dir__()
|
| 977 |
+
# Params and their default values, if None is given:
|
| 978 |
modifiable_params = {
|
| 979 |
"binary_operators": "+ * - /".split(" "),
|
| 980 |
"unary_operators": [],
|
|
|
|
| 1312 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
| 1313 |
|
| 1314 |
custom_loss = Main.eval(self.loss)
|
| 1315 |
+
early_stop_condition = Main.eval(self.early_stop_condition)
|
| 1316 |
|
| 1317 |
mutationWeights = [
|
| 1318 |
float(self.weight_mutate_constant),
|
|
|
|
| 1374 |
crossoverProbability=self.crossover_probability,
|
| 1375 |
skip_mutation_failures=self.skip_mutation_failures,
|
| 1376 |
max_evals=self.max_evals,
|
| 1377 |
+
earlyStopCondition=early_stop_condition,
|
| 1378 |
seed=seed,
|
| 1379 |
deterministic=self.deterministic,
|
| 1380 |
)
|