Spaces:
Running
Running
Commit
·
7a5a9a0
1
Parent(s):
f07f6e6
Clean up mutation_weights setting
Browse files- pysr/sr.py +14 -11
pysr/sr.py
CHANGED
|
@@ -1314,16 +1314,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1314 |
custom_loss = Main.eval(self.loss)
|
| 1315 |
early_stop_condition = Main.eval(str(self.early_stop_condition))
|
| 1316 |
|
| 1317 |
-
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
|
|
|
|
|
|
|
|
|
| 1327 |
|
| 1328 |
# Call to Julia backend.
|
| 1329 |
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
|
|
@@ -1342,7 +1345,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1342 |
npopulations=int(self.populations),
|
| 1343 |
batching=self.batching,
|
| 1344 |
batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
|
| 1345 |
-
mutationWeights=
|
| 1346 |
probPickFirst=self.tournament_selection_p,
|
| 1347 |
ns=self.tournament_selection_n,
|
| 1348 |
# These have the same name:
|
|
|
|
| 1314 |
custom_loss = Main.eval(self.loss)
|
| 1315 |
early_stop_condition = Main.eval(str(self.early_stop_condition))
|
| 1316 |
|
| 1317 |
+
mutation_weights = np.array(
|
| 1318 |
+
[
|
| 1319 |
+
self.weight_mutate_constant,
|
| 1320 |
+
self.weight_mutate_operator,
|
| 1321 |
+
self.weight_add_node,
|
| 1322 |
+
self.weight_insert_node,
|
| 1323 |
+
self.weight_delete_node,
|
| 1324 |
+
self.weight_simplify,
|
| 1325 |
+
self.weight_randomize,
|
| 1326 |
+
self.weight_do_nothing,
|
| 1327 |
+
],
|
| 1328 |
+
dtype=float,
|
| 1329 |
+
)
|
| 1330 |
|
| 1331 |
# Call to Julia backend.
|
| 1332 |
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
|
|
|
|
| 1345 |
npopulations=int(self.populations),
|
| 1346 |
batching=self.batching,
|
| 1347 |
batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
|
| 1348 |
+
mutationWeights=mutation_weights,
|
| 1349 |
probPickFirst=self.tournament_selection_p,
|
| 1350 |
ns=self.tournament_selection_n,
|
| 1351 |
# These have the same name:
|