Spaces:
Sleeping
Sleeping
Commit
·
2ca2654
1
Parent(s):
3a557a9
Add parameter for batching
Browse files- julia/sr.jl +26 -4
- pysr/sr.py +8 -1
julia/sr.jl
CHANGED
|
@@ -616,8 +616,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
|
|
| 616 |
prev = member.tree
|
| 617 |
tree = copyNode(prev)
|
| 618 |
#TODO - reconsider this
|
| 619 |
-
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
mutationChoice = rand()
|
| 623 |
weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
|
|
@@ -648,7 +651,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
|
|
| 648 |
return PopMember(tree, beforeLoss)
|
| 649 |
end
|
| 650 |
|
| 651 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
|
| 653 |
if annealing
|
| 654 |
delta = afterLoss - beforeLoss
|
|
@@ -697,6 +704,16 @@ function bestOfSample(pop::Population)::PopMember
|
|
| 697 |
return sample.members[best_idx]
|
| 698 |
end
|
| 699 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
# Return best 10 examples
|
| 701 |
function bestSubPop(pop::Population; topn::Integer=10)::Population
|
| 702 |
best_idx = sortperm([pop.members[member].score for member=1:pop.n])
|
|
@@ -1000,7 +1017,7 @@ function fullRun(niterations::Integer;
|
|
| 1000 |
@async begin
|
| 1001 |
allPops[i] = @spawnat :any let
|
| 1002 |
tmp_pop = run(cur_pop, ncyclesperiteration, verbosity=verbosity)
|
| 1003 |
-
for j=1:tmp_pop.n
|
| 1004 |
if rand() < 0.1
|
| 1005 |
tmp_pop.members[j].tree = simplifyTree(tmp_pop.members[j].tree)
|
| 1006 |
tmp_pop.members[j].tree = combineOperators(tmp_pop.members[j].tree)
|
|
@@ -1009,6 +1026,11 @@ function fullRun(niterations::Integer;
|
|
| 1009 |
end
|
| 1010 |
end
|
| 1011 |
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
tmp_pop
|
| 1013 |
end
|
| 1014 |
put!(channels[i], fetch(allPops[i]))
|
|
|
|
| 616 |
prev = member.tree
|
| 617 |
tree = copyNode(prev)
|
| 618 |
#TODO - reconsider this
|
| 619 |
+
if batching
|
| 620 |
+
beforeLoss = scoreFuncBatch(member.tree)
|
| 621 |
+
else
|
| 622 |
+
beforeLoss = member.score
|
| 623 |
+
end
|
| 624 |
|
| 625 |
mutationChoice = rand()
|
| 626 |
weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
|
|
|
|
| 651 |
return PopMember(tree, beforeLoss)
|
| 652 |
end
|
| 653 |
|
| 654 |
+
if batching
|
| 655 |
+
afterLoss = scoreFuncBatch(tree)
|
| 656 |
+
else
|
| 657 |
+
afterLoss = scoreFunc(tree)
|
| 658 |
+
end
|
| 659 |
|
| 660 |
if annealing
|
| 661 |
delta = afterLoss - beforeLoss
|
|
|
|
| 704 |
return sample.members[best_idx]
|
| 705 |
end
|
| 706 |
|
| 707 |
+
function finalizeScores(pop::Population)::Population
|
| 708 |
+
need_recalculate = batching
|
| 709 |
+
if need_recalculate
|
| 710 |
+
@inbounds @simd for member=1:pop.n
|
| 711 |
+
pop.members[member].score = scoreFunc(pop.members[member].tree)
|
| 712 |
+
end
|
| 713 |
+
end
|
| 714 |
+
return pop
|
| 715 |
+
end
|
| 716 |
+
|
| 717 |
# Return best 10 examples
|
| 718 |
function bestSubPop(pop::Population; topn::Integer=10)::Population
|
| 719 |
best_idx = sortperm([pop.members[member].score for member=1:pop.n])
|
|
|
|
| 1017 |
@async begin
|
| 1018 |
allPops[i] = @spawnat :any let
|
| 1019 |
tmp_pop = run(cur_pop, ncyclesperiteration, verbosity=verbosity)
|
| 1020 |
+
@inbounds @simd for j=1:tmp_pop.n
|
| 1021 |
if rand() < 0.1
|
| 1022 |
tmp_pop.members[j].tree = simplifyTree(tmp_pop.members[j].tree)
|
| 1023 |
tmp_pop.members[j].tree = combineOperators(tmp_pop.members[j].tree)
|
|
|
|
| 1026 |
end
|
| 1027 |
end
|
| 1028 |
end
|
| 1029 |
+
if shouldOptimizeConstants
|
| 1030 |
+
#pass #(We already calculate full scores in the optimizer)
|
| 1031 |
+
else
|
| 1032 |
+
tmp_pop = finalizeScores(tmp_pop)
|
| 1033 |
+
end
|
| 1034 |
tmp_pop
|
| 1035 |
end
|
| 1036 |
put!(channels[i], fetch(allPops[i]))
|
pysr/sr.py
CHANGED
|
@@ -76,6 +76,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 76 |
fast_cycle=False,
|
| 77 |
maxdepth=None,
|
| 78 |
variable_names=[],
|
|
|
|
|
|
|
| 79 |
threads=None, #deprecated
|
| 80 |
julia_optimization=3,
|
| 81 |
):
|
|
@@ -138,6 +140,10 @@ def pysr(X=None, y=None, weights=None,
|
|
| 138 |
15% faster. May be algorithmically less efficient.
|
| 139 |
:param variable_names: list, a list of names for the variables, other
|
| 140 |
than "x0", "x1", etc.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 142 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 143 |
(as strings).
|
|
@@ -227,7 +233,8 @@ const nrestarts = {nrestarts:d}
|
|
| 227 |
const perturbationFactor = {perturbationFactor:f}f0
|
| 228 |
const annealing = {"true" if annealing else "false"}
|
| 229 |
const weighted = {"true" if weights is not None else "false"}
|
| 230 |
-
const
|
|
|
|
| 231 |
const useVarMap = {"false" if len(variable_names) == 0 else "true"}
|
| 232 |
const mutationWeights = [
|
| 233 |
{weightMutateConstant:f},
|
|
|
|
| 76 |
fast_cycle=False,
|
| 77 |
maxdepth=None,
|
| 78 |
variable_names=[],
|
| 79 |
+
batching=False,
|
| 80 |
+
batchSize=50,
|
| 81 |
threads=None, #deprecated
|
| 82 |
julia_optimization=3,
|
| 83 |
):
|
|
|
|
| 140 |
15% faster. May be algorithmically less efficient.
|
| 141 |
:param variable_names: list, a list of names for the variables, other
|
| 142 |
than "x0", "x1", etc.
|
| 143 |
+
:param batching: bool, whether to compare population members on small batches
|
| 144 |
+
during evolution. Still uses full dataset for comparing against
|
| 145 |
+
hall of fame.
|
| 146 |
+
:param batchSize: int, the amount of data to use if doing batching.
|
| 147 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
| 148 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 149 |
(as strings).
|
|
|
|
| 233 |
const perturbationFactor = {perturbationFactor:f}f0
|
| 234 |
const annealing = {"true" if annealing else "false"}
|
| 235 |
const weighted = {"true" if weights is not None else "false"}
|
| 236 |
+
const batching = {"true" if batching else "false"}
|
| 237 |
+
const batchSize = {min([batchSize, len(X)]) if batching else len(X):d}
|
| 238 |
const useVarMap = {"false" if len(variable_names) == 0 else "true"}
|
| 239 |
const mutationWeights = [
|
| 240 |
{weightMutateConstant:f},
|