Spaces:
Sleeping
Sleeping
Commit
·
1f4e612
1
Parent(s):
eccca5d
Reset to last working copy
Browse files- julia/sr.jl +65 -59
- pysr/sr.py +12 -20
julia/sr.jl
CHANGED
|
@@ -96,9 +96,9 @@ end
|
|
| 96 |
|
| 97 |
# Copy an equation (faster than deepcopy)
|
| 98 |
function copyNode(tree::Node)::Node
|
| 99 |
-
if tree.degree
|
| 100 |
return Node(tree.val)
|
| 101 |
-
elseif tree.degree
|
| 102 |
return Node(tree.op, copyNode(tree.l))
|
| 103 |
else
|
| 104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
@@ -107,9 +107,9 @@ end
|
|
| 107 |
|
| 108 |
# Count the operators, constants, variables in an equation
|
| 109 |
function countNodes(tree::Node)::Integer
|
| 110 |
-
if tree.degree
|
| 111 |
return 1
|
| 112 |
-
elseif tree.degree
|
| 113 |
return 1 + countNodes(tree.l)
|
| 114 |
else
|
| 115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
@@ -118,9 +118,9 @@ end
|
|
| 118 |
|
| 119 |
# Count the max depth of a tree
|
| 120 |
function countDepth(tree::Node)::Integer
|
| 121 |
-
if tree.degree
|
| 122 |
return 1
|
| 123 |
-
elseif tree.degree
|
| 124 |
return 1 + countDepth(tree.l)
|
| 125 |
else
|
| 126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
@@ -129,7 +129,7 @@ end
|
|
| 129 |
|
| 130 |
# Convert an equation to a string
|
| 131 |
function stringTree(tree::Node)::String
|
| 132 |
-
if tree.degree
|
| 133 |
if tree.constant
|
| 134 |
return string(tree.val)
|
| 135 |
else
|
|
@@ -139,7 +139,7 @@ function stringTree(tree::Node)::String
|
|
| 139 |
return "x$(tree.val - 1)"
|
| 140 |
end
|
| 141 |
end
|
| 142 |
-
elseif tree.degree
|
| 143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
| 144 |
else
|
| 145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
@@ -153,7 +153,7 @@ end
|
|
| 153 |
|
| 154 |
# Return a random node from the tree
|
| 155 |
function randomNode(tree::Node)::Node
|
| 156 |
-
if tree.degree
|
| 157 |
return tree
|
| 158 |
end
|
| 159 |
a = countNodes(tree)
|
|
@@ -162,14 +162,14 @@ function randomNode(tree::Node)::Node
|
|
| 162 |
if tree.degree >= 1
|
| 163 |
b = countNodes(tree.l)
|
| 164 |
end
|
| 165 |
-
if tree.degree
|
| 166 |
c = countNodes(tree.r)
|
| 167 |
end
|
| 168 |
|
| 169 |
i = rand(1:1+b+c)
|
| 170 |
if i <= b
|
| 171 |
return randomNode(tree.l)
|
| 172 |
-
elseif i
|
| 173 |
return tree
|
| 174 |
end
|
| 175 |
|
|
@@ -178,9 +178,9 @@ end
|
|
| 178 |
|
| 179 |
# Count the number of unary operators in the equation
|
| 180 |
function countUnaryOperators(tree::Node)::Integer
|
| 181 |
-
if tree.degree
|
| 182 |
return 0
|
| 183 |
-
elseif tree.degree
|
| 184 |
return 1 + countUnaryOperators(tree.l)
|
| 185 |
else
|
| 186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
@@ -189,9 +189,9 @@ end
|
|
| 189 |
|
| 190 |
# Count the number of binary operators in the equation
|
| 191 |
function countBinaryOperators(tree::Node)::Integer
|
| 192 |
-
if tree.degree
|
| 193 |
return 0
|
| 194 |
-
elseif tree.degree
|
| 195 |
return 0 + countBinaryOperators(tree.l)
|
| 196 |
else
|
| 197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
@@ -206,14 +206,14 @@ end
|
|
| 206 |
# Randomly convert an operator into another one (binary->binary;
|
| 207 |
# unary->unary)
|
| 208 |
function mutateOperator(tree::Node)::Node
|
| 209 |
-
if countOperators(tree)
|
| 210 |
return tree
|
| 211 |
end
|
| 212 |
node = randomNode(tree)
|
| 213 |
-
while node.degree
|
| 214 |
node = randomNode(tree)
|
| 215 |
end
|
| 216 |
-
if node.degree
|
| 217 |
node.op = rand(1:length(unaops))
|
| 218 |
else
|
| 219 |
node.op = rand(1:length(binops))
|
|
@@ -223,9 +223,9 @@ end
|
|
| 223 |
|
| 224 |
# Count the number of constants in an equation
|
| 225 |
function countConstants(tree::Node)::Integer
|
| 226 |
-
if tree.degree
|
| 227 |
return convert(Integer, tree.constant)
|
| 228 |
-
elseif tree.degree
|
| 229 |
return 0 + countConstants(tree.l)
|
| 230 |
else
|
| 231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
@@ -238,11 +238,11 @@ function mutateConstant(
|
|
| 238 |
probNegate::Float32=0.01f0)::Node
|
| 239 |
# T is between 0 and 1.
|
| 240 |
|
| 241 |
-
if countConstants(tree)
|
| 242 |
return tree
|
| 243 |
end
|
| 244 |
node = randomNode(tree)
|
| 245 |
-
while node.degree
|
| 246 |
node = randomNode(tree)
|
| 247 |
end
|
| 248 |
|
|
@@ -273,19 +273,21 @@ end
|
|
| 273 |
# Evaluate an equation over an array of datapoints
|
| 274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
| 275 |
clen = size(cX)[1]
|
| 276 |
-
if tree.degree
|
| 277 |
if tree.constant
|
| 278 |
return fill(tree.val, clen)
|
| 279 |
else
|
| 280 |
return copy(cX[:, tree.val])
|
| 281 |
end
|
| 282 |
-
elseif tree.degree
|
| 283 |
cumulator = evalTreeArray(tree.l, cX)
|
| 284 |
if cumulator === nothing
|
| 285 |
return nothing
|
| 286 |
end
|
| 287 |
op_idx = tree.op
|
| 288 |
-
|
|
|
|
|
|
|
| 289 |
@inbounds for i=1:clen
|
| 290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 291 |
return nothing
|
|
@@ -301,8 +303,12 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
| 301 |
if array2 === nothing
|
| 302 |
return nothing
|
| 303 |
end
|
|
|
|
| 304 |
op_idx = tree.op
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
| 306 |
@inbounds for i=1:clen
|
| 307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 308 |
return nothing
|
|
@@ -350,7 +356,7 @@ end
|
|
| 350 |
# Add a random unary/binary operation to the end of a tree
|
| 351 |
function appendRandomOp(tree::Node)::Node
|
| 352 |
node = randomNode(tree)
|
| 353 |
-
while node.degree
|
| 354 |
node = randomNode(tree)
|
| 355 |
end
|
| 356 |
|
|
@@ -458,7 +464,7 @@ end
|
|
| 458 |
|
| 459 |
# Return a random node from the tree with parent
|
| 460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
| 461 |
-
if tree.degree
|
| 462 |
return tree, parent
|
| 463 |
end
|
| 464 |
a = countNodes(tree)
|
|
@@ -467,14 +473,14 @@ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{No
|
|
| 467 |
if tree.degree >= 1
|
| 468 |
b = countNodes(tree.l)
|
| 469 |
end
|
| 470 |
-
if tree.degree
|
| 471 |
c = countNodes(tree.r)
|
| 472 |
end
|
| 473 |
|
| 474 |
i = rand(1:1+b+c)
|
| 475 |
if i <= b
|
| 476 |
return randomNodeAndParent(tree.l, tree)
|
| 477 |
-
elseif i
|
| 478 |
return tree, parent
|
| 479 |
end
|
| 480 |
|
|
@@ -487,7 +493,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
| 487 |
node, parent = randomNodeAndParent(tree, nothing)
|
| 488 |
isroot = (parent === nothing)
|
| 489 |
|
| 490 |
-
if node.degree
|
| 491 |
# Replace with new constant
|
| 492 |
newnode = randomConstantNode()
|
| 493 |
node.l = newnode.l
|
|
@@ -496,7 +502,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
| 496 |
node.degree = newnode.degree
|
| 497 |
node.val = newnode.val
|
| 498 |
node.constant = newnode.constant
|
| 499 |
-
elseif node.degree
|
| 500 |
# Join one of the children with the parent
|
| 501 |
if isroot
|
| 502 |
return node.l
|
|
@@ -536,17 +542,17 @@ function combineOperators(tree::Node)::Node
|
|
| 536 |
# ((const - var) - const) => (const - var)
|
| 537 |
# (want to add anything commutative!)
|
| 538 |
# TODO - need to combine plus/sub if they are both there.
|
| 539 |
-
if tree.degree
|
| 540 |
return tree
|
| 541 |
-
elseif tree.degree
|
| 542 |
tree.l = combineOperators(tree.l)
|
| 543 |
-
elseif tree.degree
|
| 544 |
tree.l = combineOperators(tree.l)
|
| 545 |
tree.r = combineOperators(tree.r)
|
| 546 |
end
|
| 547 |
|
| 548 |
-
top_level_constant = tree.degree
|
| 549 |
-
if tree.degree
|
| 550 |
op = tree.op
|
| 551 |
# Put the constant in r
|
| 552 |
if tree.l.constant
|
|
@@ -557,7 +563,7 @@ function combineOperators(tree::Node)::Node
|
|
| 557 |
topconstant = tree.r.val
|
| 558 |
# Simplify down first
|
| 559 |
below = tree.l
|
| 560 |
-
if below.degree
|
| 561 |
if below.l.constant
|
| 562 |
tree = below
|
| 563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
@@ -568,11 +574,11 @@ function combineOperators(tree::Node)::Node
|
|
| 568 |
end
|
| 569 |
end
|
| 570 |
|
| 571 |
-
if tree.degree
|
| 572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
| 573 |
# Not commutative, so use different op.
|
| 574 |
if tree.l.constant
|
| 575 |
-
if tree.r.degree
|
| 576 |
if tree.r.l.constant
|
| 577 |
#(const - (const - var)) => (var - const)
|
| 578 |
l = tree.l
|
|
@@ -591,7 +597,7 @@ function combineOperators(tree::Node)::Node
|
|
| 591 |
end
|
| 592 |
end
|
| 593 |
else #tree.r.constant is true
|
| 594 |
-
if tree.l.degree
|
| 595 |
if tree.l.l.constant
|
| 596 |
#((const - var) - const) => (const - var)
|
| 597 |
l = tree.l
|
|
@@ -616,17 +622,17 @@ end
|
|
| 616 |
|
| 617 |
# Simplify tree
|
| 618 |
function simplifyTree(tree::Node)::Node
|
| 619 |
-
if tree.degree
|
| 620 |
tree.l = simplifyTree(tree.l)
|
| 621 |
-
if tree.l.degree
|
| 622 |
return Node(unaops[tree.op](tree.l.val))
|
| 623 |
end
|
| 624 |
-
elseif tree.degree
|
| 625 |
tree.l = simplifyTree(tree.l)
|
| 626 |
tree.r = simplifyTree(tree.r)
|
| 627 |
constantsBelow = (
|
| 628 |
-
tree.l.degree
|
| 629 |
-
tree.r.degree
|
| 630 |
)
|
| 631 |
if constantsBelow
|
| 632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
@@ -648,9 +654,9 @@ end
|
|
| 648 |
|
| 649 |
# Check if any power operator is to the power of a complex expression
|
| 650 |
function deepPow(tree::Node)::Integer
|
| 651 |
-
if tree.degree
|
| 652 |
return 0
|
| 653 |
-
elseif tree.degree
|
| 654 |
return 0 + deepPow(tree.l)
|
| 655 |
else
|
| 656 |
if binops[tree.op] === pow
|
|
@@ -857,7 +863,7 @@ function run(
|
|
| 857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
| 858 |
end
|
| 859 |
|
| 860 |
-
if verbosity > 0 && (iT % verbosity
|
| 861 |
bestPops = bestSubPop(pop)
|
| 862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
| 863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
@@ -870,13 +876,13 @@ end
|
|
| 870 |
|
| 871 |
# Get all the constants from a tree
|
| 872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
| 873 |
-
if tree.degree
|
| 874 |
if tree.constant
|
| 875 |
return [tree.val]
|
| 876 |
else
|
| 877 |
return Float32[]
|
| 878 |
end
|
| 879 |
-
elseif tree.degree
|
| 880 |
return getConstants(tree.l)
|
| 881 |
else
|
| 882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
@@ -886,11 +892,11 @@ end
|
|
| 886 |
|
| 887 |
# Set all the constants inside a tree
|
| 888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
| 889 |
-
if tree.degree
|
| 890 |
if tree.constant
|
| 891 |
tree.val = constants[1]
|
| 892 |
end
|
| 893 |
-
elseif tree.degree
|
| 894 |
setConstants(tree.l, constants)
|
| 895 |
else
|
| 896 |
numberLeft = countConstants(tree.l)
|
|
@@ -909,12 +915,12 @@ end
|
|
| 909 |
# Use Nelder-Mead to optimize the constants in an equation
|
| 910 |
function optimizeConstants(member::PopMember)::PopMember
|
| 911 |
nconst = countConstants(member.tree)
|
| 912 |
-
if nconst
|
| 913 |
return member
|
| 914 |
end
|
| 915 |
x0 = getConstants(member.tree)
|
| 916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
| 917 |
-
if size(x0)[1]
|
| 918 |
algorithm = Optim.Newton
|
| 919 |
else
|
| 920 |
algorithm = Optim.NelderMead
|
|
@@ -998,7 +1004,7 @@ function fullRun(niterations::Integer;
|
|
| 998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
| 999 |
hallOfFame = HallOfFame()
|
| 1000 |
curmaxsize = 3
|
| 1001 |
-
if warmupMaxsize
|
| 1002 |
curmaxsize = maxsize
|
| 1003 |
end
|
| 1004 |
|
|
@@ -1067,7 +1073,7 @@ function fullRun(niterations::Integer;
|
|
| 1067 |
numberSmallerAndBetter += 1
|
| 1068 |
end
|
| 1069 |
end
|
| 1070 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
| 1071 |
if betterThanAllSmaller
|
| 1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
| 1073 |
push!(dominating, member)
|
|
@@ -1117,7 +1123,7 @@ function fullRun(niterations::Integer;
|
|
| 1117 |
|
| 1118 |
cycles_complete -= 1
|
| 1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
| 1120 |
-
if warmupMaxsize
|
| 1121 |
curmaxsize += 1
|
| 1122 |
if curmaxsize > maxsize
|
| 1123 |
curmaxsize = maxsize
|
|
@@ -1167,7 +1173,7 @@ function fullRun(niterations::Integer;
|
|
| 1167 |
numberSmallerAndBetter += 1
|
| 1168 |
end
|
| 1169 |
end
|
| 1170 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
| 1171 |
if betterThanAllSmaller
|
| 1172 |
delta_c = size - lastComplexity
|
| 1173 |
delta_l_mse = log(curMSE/lastMSE)
|
|
|
|
| 96 |
|
| 97 |
# Copy an equation (faster than deepcopy)
|
| 98 |
function copyNode(tree::Node)::Node
|
| 99 |
+
if tree.degree == 0
|
| 100 |
return Node(tree.val)
|
| 101 |
+
elseif tree.degree == 1
|
| 102 |
return Node(tree.op, copyNode(tree.l))
|
| 103 |
else
|
| 104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
|
|
| 107 |
|
| 108 |
# Count the operators, constants, variables in an equation
|
| 109 |
function countNodes(tree::Node)::Integer
|
| 110 |
+
if tree.degree == 0
|
| 111 |
return 1
|
| 112 |
+
elseif tree.degree == 1
|
| 113 |
return 1 + countNodes(tree.l)
|
| 114 |
else
|
| 115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
|
|
| 118 |
|
| 119 |
# Count the max depth of a tree
|
| 120 |
function countDepth(tree::Node)::Integer
|
| 121 |
+
if tree.degree == 0
|
| 122 |
return 1
|
| 123 |
+
elseif tree.degree == 1
|
| 124 |
return 1 + countDepth(tree.l)
|
| 125 |
else
|
| 126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
|
|
| 129 |
|
| 130 |
# Convert an equation to a string
|
| 131 |
function stringTree(tree::Node)::String
|
| 132 |
+
if tree.degree == 0
|
| 133 |
if tree.constant
|
| 134 |
return string(tree.val)
|
| 135 |
else
|
|
|
|
| 139 |
return "x$(tree.val - 1)"
|
| 140 |
end
|
| 141 |
end
|
| 142 |
+
elseif tree.degree == 1
|
| 143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
| 144 |
else
|
| 145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
|
|
| 153 |
|
| 154 |
# Return a random node from the tree
|
| 155 |
function randomNode(tree::Node)::Node
|
| 156 |
+
if tree.degree == 0
|
| 157 |
return tree
|
| 158 |
end
|
| 159 |
a = countNodes(tree)
|
|
|
|
| 162 |
if tree.degree >= 1
|
| 163 |
b = countNodes(tree.l)
|
| 164 |
end
|
| 165 |
+
if tree.degree == 2
|
| 166 |
c = countNodes(tree.r)
|
| 167 |
end
|
| 168 |
|
| 169 |
i = rand(1:1+b+c)
|
| 170 |
if i <= b
|
| 171 |
return randomNode(tree.l)
|
| 172 |
+
elseif i == b + 1
|
| 173 |
return tree
|
| 174 |
end
|
| 175 |
|
|
|
|
| 178 |
|
| 179 |
# Count the number of unary operators in the equation
|
| 180 |
function countUnaryOperators(tree::Node)::Integer
|
| 181 |
+
if tree.degree == 0
|
| 182 |
return 0
|
| 183 |
+
elseif tree.degree == 1
|
| 184 |
return 1 + countUnaryOperators(tree.l)
|
| 185 |
else
|
| 186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
|
|
| 189 |
|
| 190 |
# Count the number of binary operators in the equation
|
| 191 |
function countBinaryOperators(tree::Node)::Integer
|
| 192 |
+
if tree.degree == 0
|
| 193 |
return 0
|
| 194 |
+
elseif tree.degree == 1
|
| 195 |
return 0 + countBinaryOperators(tree.l)
|
| 196 |
else
|
| 197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
|
|
| 206 |
# Randomly convert an operator into another one (binary->binary;
|
| 207 |
# unary->unary)
|
| 208 |
function mutateOperator(tree::Node)::Node
|
| 209 |
+
if countOperators(tree) == 0
|
| 210 |
return tree
|
| 211 |
end
|
| 212 |
node = randomNode(tree)
|
| 213 |
+
while node.degree == 0
|
| 214 |
node = randomNode(tree)
|
| 215 |
end
|
| 216 |
+
if node.degree == 1
|
| 217 |
node.op = rand(1:length(unaops))
|
| 218 |
else
|
| 219 |
node.op = rand(1:length(binops))
|
|
|
|
| 223 |
|
| 224 |
# Count the number of constants in an equation
|
| 225 |
function countConstants(tree::Node)::Integer
|
| 226 |
+
if tree.degree == 0
|
| 227 |
return convert(Integer, tree.constant)
|
| 228 |
+
elseif tree.degree == 1
|
| 229 |
return 0 + countConstants(tree.l)
|
| 230 |
else
|
| 231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
|
|
| 238 |
probNegate::Float32=0.01f0)::Node
|
| 239 |
# T is between 0 and 1.
|
| 240 |
|
| 241 |
+
if countConstants(tree) == 0
|
| 242 |
return tree
|
| 243 |
end
|
| 244 |
node = randomNode(tree)
|
| 245 |
+
while node.degree != 0 || node.constant == false
|
| 246 |
node = randomNode(tree)
|
| 247 |
end
|
| 248 |
|
|
|
|
| 273 |
# Evaluate an equation over an array of datapoints
|
| 274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
| 275 |
clen = size(cX)[1]
|
| 276 |
+
if tree.degree == 0
|
| 277 |
if tree.constant
|
| 278 |
return fill(tree.val, clen)
|
| 279 |
else
|
| 280 |
return copy(cX[:, tree.val])
|
| 281 |
end
|
| 282 |
+
elseif tree.degree == 1
|
| 283 |
cumulator = evalTreeArray(tree.l, cX)
|
| 284 |
if cumulator === nothing
|
| 285 |
return nothing
|
| 286 |
end
|
| 287 |
op_idx = tree.op
|
| 288 |
+
@inbounds @simd for i=1:clen
|
| 289 |
+
cumulator[i] = UNAOP(op_idx, cumulator[i])
|
| 290 |
+
end
|
| 291 |
@inbounds for i=1:clen
|
| 292 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 293 |
return nothing
|
|
|
|
| 303 |
if array2 === nothing
|
| 304 |
return nothing
|
| 305 |
end
|
| 306 |
+
|
| 307 |
op_idx = tree.op
|
| 308 |
+
|
| 309 |
+
@inbounds @simd for i=1:clen
|
| 310 |
+
cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
|
| 311 |
+
end
|
| 312 |
@inbounds for i=1:clen
|
| 313 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 314 |
return nothing
|
|
|
|
| 356 |
# Add a random unary/binary operation to the end of a tree
|
| 357 |
function appendRandomOp(tree::Node)::Node
|
| 358 |
node = randomNode(tree)
|
| 359 |
+
while node.degree != 0
|
| 360 |
node = randomNode(tree)
|
| 361 |
end
|
| 362 |
|
|
|
|
| 464 |
|
| 465 |
# Return a random node from the tree with parent
|
| 466 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
| 467 |
+
if tree.degree == 0
|
| 468 |
return tree, parent
|
| 469 |
end
|
| 470 |
a = countNodes(tree)
|
|
|
|
| 473 |
if tree.degree >= 1
|
| 474 |
b = countNodes(tree.l)
|
| 475 |
end
|
| 476 |
+
if tree.degree == 2
|
| 477 |
c = countNodes(tree.r)
|
| 478 |
end
|
| 479 |
|
| 480 |
i = rand(1:1+b+c)
|
| 481 |
if i <= b
|
| 482 |
return randomNodeAndParent(tree.l, tree)
|
| 483 |
+
elseif i == b + 1
|
| 484 |
return tree, parent
|
| 485 |
end
|
| 486 |
|
|
|
|
| 493 |
node, parent = randomNodeAndParent(tree, nothing)
|
| 494 |
isroot = (parent === nothing)
|
| 495 |
|
| 496 |
+
if node.degree == 0
|
| 497 |
# Replace with new constant
|
| 498 |
newnode = randomConstantNode()
|
| 499 |
node.l = newnode.l
|
|
|
|
| 502 |
node.degree = newnode.degree
|
| 503 |
node.val = newnode.val
|
| 504 |
node.constant = newnode.constant
|
| 505 |
+
elseif node.degree == 1
|
| 506 |
# Join one of the children with the parent
|
| 507 |
if isroot
|
| 508 |
return node.l
|
|
|
|
| 542 |
# ((const - var) - const) => (const - var)
|
| 543 |
# (want to add anything commutative!)
|
| 544 |
# TODO - need to combine plus/sub if they are both there.
|
| 545 |
+
if tree.degree == 0
|
| 546 |
return tree
|
| 547 |
+
elseif tree.degree == 1
|
| 548 |
tree.l = combineOperators(tree.l)
|
| 549 |
+
elseif tree.degree == 2
|
| 550 |
tree.l = combineOperators(tree.l)
|
| 551 |
tree.r = combineOperators(tree.r)
|
| 552 |
end
|
| 553 |
|
| 554 |
+
top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
|
| 555 |
+
if tree.degree == 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
|
| 556 |
op = tree.op
|
| 557 |
# Put the constant in r
|
| 558 |
if tree.l.constant
|
|
|
|
| 563 |
topconstant = tree.r.val
|
| 564 |
# Simplify down first
|
| 565 |
below = tree.l
|
| 566 |
+
if below.degree == 2 && below.op == op
|
| 567 |
if below.l.constant
|
| 568 |
tree = below
|
| 569 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
|
|
| 574 |
end
|
| 575 |
end
|
| 576 |
|
| 577 |
+
if tree.degree == 2 && binops[tree.op] === sub && top_level_constant
|
| 578 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
| 579 |
# Not commutative, so use different op.
|
| 580 |
if tree.l.constant
|
| 581 |
+
if tree.r.degree == 2 && binops[tree.r.op] === sub
|
| 582 |
if tree.r.l.constant
|
| 583 |
#(const - (const - var)) => (var - const)
|
| 584 |
l = tree.l
|
|
|
|
| 597 |
end
|
| 598 |
end
|
| 599 |
else #tree.r.constant is true
|
| 600 |
+
if tree.l.degree == 2 && binops[tree.l.op] === sub
|
| 601 |
if tree.l.l.constant
|
| 602 |
#((const - var) - const) => (const - var)
|
| 603 |
l = tree.l
|
|
|
|
| 622 |
|
| 623 |
# Simplify tree
|
| 624 |
function simplifyTree(tree::Node)::Node
|
| 625 |
+
if tree.degree == 1
|
| 626 |
tree.l = simplifyTree(tree.l)
|
| 627 |
+
if tree.l.degree == 0 && tree.l.constant
|
| 628 |
return Node(unaops[tree.op](tree.l.val))
|
| 629 |
end
|
| 630 |
+
elseif tree.degree == 2
|
| 631 |
tree.l = simplifyTree(tree.l)
|
| 632 |
tree.r = simplifyTree(tree.r)
|
| 633 |
constantsBelow = (
|
| 634 |
+
tree.l.degree == 0 && tree.l.constant &&
|
| 635 |
+
tree.r.degree == 0 && tree.r.constant
|
| 636 |
)
|
| 637 |
if constantsBelow
|
| 638 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
|
|
| 654 |
|
| 655 |
# Check if any power operator is to the power of a complex expression
|
| 656 |
function deepPow(tree::Node)::Integer
|
| 657 |
+
if tree.degree == 0
|
| 658 |
return 0
|
| 659 |
+
elseif tree.degree == 1
|
| 660 |
return 0 + deepPow(tree.l)
|
| 661 |
else
|
| 662 |
if binops[tree.op] === pow
|
|
|
|
| 863 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
| 864 |
end
|
| 865 |
|
| 866 |
+
if verbosity > 0 && (iT % verbosity == 0)
|
| 867 |
bestPops = bestSubPop(pop)
|
| 868 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
| 869 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
|
|
| 876 |
|
| 877 |
# Get all the constants from a tree
|
| 878 |
function getConstants(tree::Node)::Array{Float32, 1}
|
| 879 |
+
if tree.degree == 0
|
| 880 |
if tree.constant
|
| 881 |
return [tree.val]
|
| 882 |
else
|
| 883 |
return Float32[]
|
| 884 |
end
|
| 885 |
+
elseif tree.degree == 1
|
| 886 |
return getConstants(tree.l)
|
| 887 |
else
|
| 888 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
|
|
| 892 |
|
| 893 |
# Set all the constants inside a tree
|
| 894 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
| 895 |
+
if tree.degree == 0
|
| 896 |
if tree.constant
|
| 897 |
tree.val = constants[1]
|
| 898 |
end
|
| 899 |
+
elseif tree.degree == 1
|
| 900 |
setConstants(tree.l, constants)
|
| 901 |
else
|
| 902 |
numberLeft = countConstants(tree.l)
|
|
|
|
| 915 |
# Use Nelder-Mead to optimize the constants in an equation
|
| 916 |
function optimizeConstants(member::PopMember)::PopMember
|
| 917 |
nconst = countConstants(member.tree)
|
| 918 |
+
if nconst == 0
|
| 919 |
return member
|
| 920 |
end
|
| 921 |
x0 = getConstants(member.tree)
|
| 922 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
| 923 |
+
if size(x0)[1] == 1
|
| 924 |
algorithm = Optim.Newton
|
| 925 |
else
|
| 926 |
algorithm = Optim.NelderMead
|
|
|
|
| 1004 |
bestSubPops = [Population(1) for j=1:npopulations]
|
| 1005 |
hallOfFame = HallOfFame()
|
| 1006 |
curmaxsize = 3
|
| 1007 |
+
if warmupMaxsize == 0
|
| 1008 |
curmaxsize = maxsize
|
| 1009 |
end
|
| 1010 |
|
|
|
|
| 1073 |
numberSmallerAndBetter += 1
|
| 1074 |
end
|
| 1075 |
end
|
| 1076 |
+
betterThanAllSmaller = (numberSmallerAndBetter == 0)
|
| 1077 |
if betterThanAllSmaller
|
| 1078 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
| 1079 |
push!(dominating, member)
|
|
|
|
| 1123 |
|
| 1124 |
cycles_complete -= 1
|
| 1125 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
| 1126 |
+
if warmupMaxsize != 0 && cycles_elapsed % warmupMaxsize == 0
|
| 1127 |
curmaxsize += 1
|
| 1128 |
if curmaxsize > maxsize
|
| 1129 |
curmaxsize = maxsize
|
|
|
|
| 1173 |
numberSmallerAndBetter += 1
|
| 1174 |
end
|
| 1175 |
end
|
| 1176 |
+
betterThanAllSmaller = (numberSmallerAndBetter == 0)
|
| 1177 |
if betterThanAllSmaller
|
| 1178 |
delta_c = size - lastComplexity
|
| 1179 |
delta_l_mse = log(curMSE/lastMSE)
|
pysr/sr.py
CHANGED
|
@@ -286,35 +286,27 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
|
|
| 286 |
|
| 287 |
op_runner = ""
|
| 288 |
if len(binary_operators) > 0:
|
| 289 |
-
op_runner += """
|
| 290 |
-
@inline function BINOP
|
| 291 |
-
if i
|
| 292 |
-
|
| 293 |
-
x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
|
| 294 |
-
end"""
|
| 295 |
for i in range(1, len(binary_operators)):
|
| 296 |
op_runner += f"""
|
| 297 |
-
elseif i
|
| 298 |
-
|
| 299 |
-
x[j] = {binary_operators[i]}(x[j], y[j])
|
| 300 |
-
end"""
|
| 301 |
op_runner += """
|
| 302 |
end
|
| 303 |
end"""
|
| 304 |
|
| 305 |
if len(unary_operators) > 0:
|
| 306 |
-
op_runner += """
|
| 307 |
-
@inline function UNAOP
|
| 308 |
-
if i
|
| 309 |
-
|
| 310 |
-
x[j] = """f"{unary_operators[0]}(x[j])""""
|
| 311 |
-
end"""
|
| 312 |
for i in range(1, len(unary_operators)):
|
| 313 |
op_runner += f"""
|
| 314 |
-
elseif i
|
| 315 |
-
|
| 316 |
-
x[j] = {unary_operators[i]}(x[j])
|
| 317 |
-
end"""
|
| 318 |
op_runner += """
|
| 319 |
end
|
| 320 |
end"""
|
|
|
|
| 286 |
|
| 287 |
op_runner = ""
|
| 288 |
if len(binary_operators) > 0:
|
| 289 |
+
op_runner += f"""
|
| 290 |
+
@inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
|
| 291 |
+
if i == 1
|
| 292 |
+
return @fastmath {binary_operators[0]}(x, y)"""
|
|
|
|
|
|
|
| 293 |
for i in range(1, len(binary_operators)):
|
| 294 |
op_runner += f"""
|
| 295 |
+
elseif i == {i+1}
|
| 296 |
+
return @fastmath {binary_operators[i]}(x, y)"""
|
|
|
|
|
|
|
| 297 |
op_runner += """
|
| 298 |
end
|
| 299 |
end"""
|
| 300 |
|
| 301 |
if len(unary_operators) > 0:
|
| 302 |
+
op_runner += f"""
|
| 303 |
+
@inline function UNAOP(i::Int, x::Float32)::Float32
|
| 304 |
+
if i == 1
|
| 305 |
+
return @fastmath {unary_operators[0]}(x)"""
|
|
|
|
|
|
|
| 306 |
for i in range(1, len(unary_operators)):
|
| 307 |
op_runner += f"""
|
| 308 |
+
elseif i == {i+1}
|
| 309 |
+
return @fastmath {unary_operators[i]}(x)"""
|
|
|
|
|
|
|
| 310 |
op_runner += """
|
| 311 |
end
|
| 312 |
end"""
|