Spaces:
Sleeping
Sleeping
Commit
·
6fa78c9
1
Parent(s):
1d23dc2
Use egal instead of equal for many ops
Browse files- julia/sr.jl +59 -59
- pysr/sr.py +18 -10
julia/sr.jl
CHANGED
|
@@ -96,9 +96,9 @@ end
|
|
| 96 |
|
| 97 |
# Copy an equation (faster than deepcopy)
|
| 98 |
function copyNode(tree::Node)::Node
|
| 99 |
-
if tree.degree
|
| 100 |
return Node(tree.val)
|
| 101 |
-
elseif tree.degree
|
| 102 |
return Node(tree.op, copyNode(tree.l))
|
| 103 |
else
|
| 104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
@@ -107,9 +107,9 @@ end
|
|
| 107 |
|
| 108 |
# Count the operators, constants, variables in an equation
|
| 109 |
function countNodes(tree::Node)::Integer
|
| 110 |
-
if tree.degree
|
| 111 |
return 1
|
| 112 |
-
elseif tree.degree
|
| 113 |
return 1 + countNodes(tree.l)
|
| 114 |
else
|
| 115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
@@ -118,9 +118,9 @@ end
|
|
| 118 |
|
| 119 |
# Count the max depth of a tree
|
| 120 |
function countDepth(tree::Node)::Integer
|
| 121 |
-
if tree.degree
|
| 122 |
return 1
|
| 123 |
-
elseif tree.degree
|
| 124 |
return 1 + countDepth(tree.l)
|
| 125 |
else
|
| 126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
@@ -129,7 +129,7 @@ end
|
|
| 129 |
|
| 130 |
# Convert an equation to a string
|
| 131 |
function stringTree(tree::Node)::String
|
| 132 |
-
if tree.degree
|
| 133 |
if tree.constant
|
| 134 |
return string(tree.val)
|
| 135 |
else
|
|
@@ -139,7 +139,7 @@ function stringTree(tree::Node)::String
|
|
| 139 |
return "x$(tree.val - 1)"
|
| 140 |
end
|
| 141 |
end
|
| 142 |
-
elseif tree.degree
|
| 143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
| 144 |
else
|
| 145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
@@ -153,7 +153,7 @@ end
|
|
| 153 |
|
| 154 |
# Return a random node from the tree
|
| 155 |
function randomNode(tree::Node)::Node
|
| 156 |
-
if tree.degree
|
| 157 |
return tree
|
| 158 |
end
|
| 159 |
a = countNodes(tree)
|
|
@@ -162,14 +162,14 @@ function randomNode(tree::Node)::Node
|
|
| 162 |
if tree.degree >= 1
|
| 163 |
b = countNodes(tree.l)
|
| 164 |
end
|
| 165 |
-
if tree.degree
|
| 166 |
c = countNodes(tree.r)
|
| 167 |
end
|
| 168 |
|
| 169 |
i = rand(1:1+b+c)
|
| 170 |
if i <= b
|
| 171 |
return randomNode(tree.l)
|
| 172 |
-
elseif i
|
| 173 |
return tree
|
| 174 |
end
|
| 175 |
|
|
@@ -178,9 +178,9 @@ end
|
|
| 178 |
|
| 179 |
# Count the number of unary operators in the equation
|
| 180 |
function countUnaryOperators(tree::Node)::Integer
|
| 181 |
-
if tree.degree
|
| 182 |
return 0
|
| 183 |
-
elseif tree.degree
|
| 184 |
return 1 + countUnaryOperators(tree.l)
|
| 185 |
else
|
| 186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
@@ -189,9 +189,9 @@ end
|
|
| 189 |
|
| 190 |
# Count the number of binary operators in the equation
|
| 191 |
function countBinaryOperators(tree::Node)::Integer
|
| 192 |
-
if tree.degree
|
| 193 |
return 0
|
| 194 |
-
elseif tree.degree
|
| 195 |
return 0 + countBinaryOperators(tree.l)
|
| 196 |
else
|
| 197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
@@ -206,14 +206,14 @@ end
|
|
| 206 |
# Randomly convert an operator into another one (binary->binary;
|
| 207 |
# unary->unary)
|
| 208 |
function mutateOperator(tree::Node)::Node
|
| 209 |
-
if countOperators(tree)
|
| 210 |
return tree
|
| 211 |
end
|
| 212 |
node = randomNode(tree)
|
| 213 |
-
while node.degree
|
| 214 |
node = randomNode(tree)
|
| 215 |
end
|
| 216 |
-
if node.degree
|
| 217 |
node.op = rand(1:length(unaops))
|
| 218 |
else
|
| 219 |
node.op = rand(1:length(binops))
|
|
@@ -223,9 +223,9 @@ end
|
|
| 223 |
|
| 224 |
# Count the number of constants in an equation
|
| 225 |
function countConstants(tree::Node)::Integer
|
| 226 |
-
if tree.degree
|
| 227 |
return convert(Integer, tree.constant)
|
| 228 |
-
elseif tree.degree
|
| 229 |
return 0 + countConstants(tree.l)
|
| 230 |
else
|
| 231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
@@ -238,11 +238,11 @@ function mutateConstant(
|
|
| 238 |
probNegate::Float32=0.01f0)::Node
|
| 239 |
# T is between 0 and 1.
|
| 240 |
|
| 241 |
-
if countConstants(tree)
|
| 242 |
return tree
|
| 243 |
end
|
| 244 |
node = randomNode(tree)
|
| 245 |
-
while node.degree
|
| 246 |
node = randomNode(tree)
|
| 247 |
end
|
| 248 |
|
|
@@ -273,19 +273,19 @@ end
|
|
| 273 |
# Evaluate an equation over an array of datapoints
|
| 274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
| 275 |
clen = size(cX)[1]
|
| 276 |
-
if tree.degree
|
| 277 |
if tree.constant
|
| 278 |
return fill(tree.val, clen)
|
| 279 |
else
|
| 280 |
return copy(cX[:, tree.val])
|
| 281 |
end
|
| 282 |
-
elseif tree.degree
|
| 283 |
cumulator = evalTreeArray(tree.l, cX)
|
| 284 |
if cumulator === nothing
|
| 285 |
return nothing
|
| 286 |
end
|
| 287 |
op_idx = tree.op
|
| 288 |
-
UNAOP!(op_idx,
|
| 289 |
@inbounds for i=1:clen
|
| 290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 291 |
return nothing
|
|
@@ -302,7 +302,7 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
| 302 |
return nothing
|
| 303 |
end
|
| 304 |
op_idx = tree.op
|
| 305 |
-
BINOP!(
|
| 306 |
@inbounds for i=1:clen
|
| 307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 308 |
return nothing
|
|
@@ -350,7 +350,7 @@ end
|
|
| 350 |
# Add a random unary/binary operation to the end of a tree
|
| 351 |
function appendRandomOp(tree::Node)::Node
|
| 352 |
node = randomNode(tree)
|
| 353 |
-
while node.degree
|
| 354 |
node = randomNode(tree)
|
| 355 |
end
|
| 356 |
|
|
@@ -458,7 +458,7 @@ end
|
|
| 458 |
|
| 459 |
# Return a random node from the tree with parent
|
| 460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
| 461 |
-
if tree.degree
|
| 462 |
return tree, parent
|
| 463 |
end
|
| 464 |
a = countNodes(tree)
|
|
@@ -467,14 +467,14 @@ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{No
|
|
| 467 |
if tree.degree >= 1
|
| 468 |
b = countNodes(tree.l)
|
| 469 |
end
|
| 470 |
-
if tree.degree
|
| 471 |
c = countNodes(tree.r)
|
| 472 |
end
|
| 473 |
|
| 474 |
i = rand(1:1+b+c)
|
| 475 |
if i <= b
|
| 476 |
return randomNodeAndParent(tree.l, tree)
|
| 477 |
-
elseif i
|
| 478 |
return tree, parent
|
| 479 |
end
|
| 480 |
|
|
@@ -487,7 +487,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
| 487 |
node, parent = randomNodeAndParent(tree, nothing)
|
| 488 |
isroot = (parent === nothing)
|
| 489 |
|
| 490 |
-
if node.degree
|
| 491 |
# Replace with new constant
|
| 492 |
newnode = randomConstantNode()
|
| 493 |
node.l = newnode.l
|
|
@@ -496,7 +496,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
| 496 |
node.degree = newnode.degree
|
| 497 |
node.val = newnode.val
|
| 498 |
node.constant = newnode.constant
|
| 499 |
-
elseif node.degree
|
| 500 |
# Join one of the children with the parent
|
| 501 |
if isroot
|
| 502 |
return node.l
|
|
@@ -536,17 +536,17 @@ function combineOperators(tree::Node)::Node
|
|
| 536 |
# ((const - var) - const) => (const - var)
|
| 537 |
# (want to add anything commutative!)
|
| 538 |
# TODO - need to combine plus/sub if they are both there.
|
| 539 |
-
if tree.degree
|
| 540 |
return tree
|
| 541 |
-
elseif tree.degree
|
| 542 |
tree.l = combineOperators(tree.l)
|
| 543 |
-
elseif tree.degree
|
| 544 |
tree.l = combineOperators(tree.l)
|
| 545 |
tree.r = combineOperators(tree.r)
|
| 546 |
end
|
| 547 |
|
| 548 |
-
top_level_constant = tree.degree
|
| 549 |
-
if tree.degree
|
| 550 |
op = tree.op
|
| 551 |
# Put the constant in r
|
| 552 |
if tree.l.constant
|
|
@@ -557,7 +557,7 @@ function combineOperators(tree::Node)::Node
|
|
| 557 |
topconstant = tree.r.val
|
| 558 |
# Simplify down first
|
| 559 |
below = tree.l
|
| 560 |
-
if below.degree
|
| 561 |
if below.l.constant
|
| 562 |
tree = below
|
| 563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
@@ -568,11 +568,11 @@ function combineOperators(tree::Node)::Node
|
|
| 568 |
end
|
| 569 |
end
|
| 570 |
|
| 571 |
-
if tree.degree
|
| 572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
| 573 |
# Not commutative, so use different op.
|
| 574 |
if tree.l.constant
|
| 575 |
-
if tree.r.degree
|
| 576 |
if tree.r.l.constant
|
| 577 |
#(const - (const - var)) => (var - const)
|
| 578 |
l = tree.l
|
|
@@ -591,7 +591,7 @@ function combineOperators(tree::Node)::Node
|
|
| 591 |
end
|
| 592 |
end
|
| 593 |
else #tree.r.constant is true
|
| 594 |
-
if tree.l.degree
|
| 595 |
if tree.l.l.constant
|
| 596 |
#((const - var) - const) => (const - var)
|
| 597 |
l = tree.l
|
|
@@ -616,17 +616,17 @@ end
|
|
| 616 |
|
| 617 |
# Simplify tree
|
| 618 |
function simplifyTree(tree::Node)::Node
|
| 619 |
-
if tree.degree
|
| 620 |
tree.l = simplifyTree(tree.l)
|
| 621 |
-
if tree.l.degree
|
| 622 |
return Node(unaops[tree.op](tree.l.val))
|
| 623 |
end
|
| 624 |
-
elseif tree.degree
|
| 625 |
tree.l = simplifyTree(tree.l)
|
| 626 |
tree.r = simplifyTree(tree.r)
|
| 627 |
constantsBelow = (
|
| 628 |
-
tree.l.degree
|
| 629 |
-
tree.r.degree
|
| 630 |
)
|
| 631 |
if constantsBelow
|
| 632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
@@ -648,9 +648,9 @@ end
|
|
| 648 |
|
| 649 |
# Check if any power operator is to the power of a complex expression
|
| 650 |
function deepPow(tree::Node)::Integer
|
| 651 |
-
if tree.degree
|
| 652 |
return 0
|
| 653 |
-
elseif tree.degree
|
| 654 |
return 0 + deepPow(tree.l)
|
| 655 |
else
|
| 656 |
if binops[tree.op] === pow
|
|
@@ -857,7 +857,7 @@ function run(
|
|
| 857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
| 858 |
end
|
| 859 |
|
| 860 |
-
if verbosity > 0 && (iT % verbosity
|
| 861 |
bestPops = bestSubPop(pop)
|
| 862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
| 863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
@@ -870,13 +870,13 @@ end
|
|
| 870 |
|
| 871 |
# Get all the constants from a tree
|
| 872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
| 873 |
-
if tree.degree
|
| 874 |
if tree.constant
|
| 875 |
return [tree.val]
|
| 876 |
else
|
| 877 |
return Float32[]
|
| 878 |
end
|
| 879 |
-
elseif tree.degree
|
| 880 |
return getConstants(tree.l)
|
| 881 |
else
|
| 882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
@@ -886,11 +886,11 @@ end
|
|
| 886 |
|
| 887 |
# Set all the constants inside a tree
|
| 888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
| 889 |
-
if tree.degree
|
| 890 |
if tree.constant
|
| 891 |
tree.val = constants[1]
|
| 892 |
end
|
| 893 |
-
elseif tree.degree
|
| 894 |
setConstants(tree.l, constants)
|
| 895 |
else
|
| 896 |
numberLeft = countConstants(tree.l)
|
|
@@ -909,12 +909,12 @@ end
|
|
| 909 |
# Use Nelder-Mead to optimize the constants in an equation
|
| 910 |
function optimizeConstants(member::PopMember)::PopMember
|
| 911 |
nconst = countConstants(member.tree)
|
| 912 |
-
if nconst
|
| 913 |
return member
|
| 914 |
end
|
| 915 |
x0 = getConstants(member.tree)
|
| 916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
| 917 |
-
if size(x0)[1]
|
| 918 |
algorithm = Optim.Newton
|
| 919 |
else
|
| 920 |
algorithm = Optim.NelderMead
|
|
@@ -998,7 +998,7 @@ function fullRun(niterations::Integer;
|
|
| 998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
| 999 |
hallOfFame = HallOfFame()
|
| 1000 |
curmaxsize = 3
|
| 1001 |
-
if warmupMaxsize
|
| 1002 |
curmaxsize = maxsize
|
| 1003 |
end
|
| 1004 |
|
|
@@ -1067,7 +1067,7 @@ function fullRun(niterations::Integer;
|
|
| 1067 |
numberSmallerAndBetter += 1
|
| 1068 |
end
|
| 1069 |
end
|
| 1070 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
| 1071 |
if betterThanAllSmaller
|
| 1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
| 1073 |
push!(dominating, member)
|
|
@@ -1117,7 +1117,7 @@ function fullRun(niterations::Integer;
|
|
| 1117 |
|
| 1118 |
cycles_complete -= 1
|
| 1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
| 1120 |
-
if warmupMaxsize
|
| 1121 |
curmaxsize += 1
|
| 1122 |
if curmaxsize > maxsize
|
| 1123 |
curmaxsize = maxsize
|
|
@@ -1167,7 +1167,7 @@ function fullRun(niterations::Integer;
|
|
| 1167 |
numberSmallerAndBetter += 1
|
| 1168 |
end
|
| 1169 |
end
|
| 1170 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
| 1171 |
if betterThanAllSmaller
|
| 1172 |
delta_c = size - lastComplexity
|
| 1173 |
delta_l_mse = log(curMSE/lastMSE)
|
|
|
|
| 96 |
|
| 97 |
# Copy an equation (faster than deepcopy)
|
| 98 |
function copyNode(tree::Node)::Node
|
| 99 |
+
if tree.degree === 0
|
| 100 |
return Node(tree.val)
|
| 101 |
+
elseif tree.degree === 1
|
| 102 |
return Node(tree.op, copyNode(tree.l))
|
| 103 |
else
|
| 104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
|
|
| 107 |
|
| 108 |
# Count the operators, constants, variables in an equation
|
| 109 |
function countNodes(tree::Node)::Integer
|
| 110 |
+
if tree.degree === 0
|
| 111 |
return 1
|
| 112 |
+
elseif tree.degree === 1
|
| 113 |
return 1 + countNodes(tree.l)
|
| 114 |
else
|
| 115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
|
|
| 118 |
|
| 119 |
# Count the max depth of a tree
|
| 120 |
function countDepth(tree::Node)::Integer
|
| 121 |
+
if tree.degree === 0
|
| 122 |
return 1
|
| 123 |
+
elseif tree.degree === 1
|
| 124 |
return 1 + countDepth(tree.l)
|
| 125 |
else
|
| 126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
|
|
| 129 |
|
| 130 |
# Convert an equation to a string
|
| 131 |
function stringTree(tree::Node)::String
|
| 132 |
+
if tree.degree === 0
|
| 133 |
if tree.constant
|
| 134 |
return string(tree.val)
|
| 135 |
else
|
|
|
|
| 139 |
return "x$(tree.val - 1)"
|
| 140 |
end
|
| 141 |
end
|
| 142 |
+
elseif tree.degree === 1
|
| 143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
| 144 |
else
|
| 145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
|
|
| 153 |
|
| 154 |
# Return a random node from the tree
|
| 155 |
function randomNode(tree::Node)::Node
|
| 156 |
+
if tree.degree === 0
|
| 157 |
return tree
|
| 158 |
end
|
| 159 |
a = countNodes(tree)
|
|
|
|
| 162 |
if tree.degree >= 1
|
| 163 |
b = countNodes(tree.l)
|
| 164 |
end
|
| 165 |
+
if tree.degree === 2
|
| 166 |
c = countNodes(tree.r)
|
| 167 |
end
|
| 168 |
|
| 169 |
i = rand(1:1+b+c)
|
| 170 |
if i <= b
|
| 171 |
return randomNode(tree.l)
|
| 172 |
+
elseif i === b + 1
|
| 173 |
return tree
|
| 174 |
end
|
| 175 |
|
|
|
|
| 178 |
|
| 179 |
# Count the number of unary operators in the equation
|
| 180 |
function countUnaryOperators(tree::Node)::Integer
|
| 181 |
+
if tree.degree === 0
|
| 182 |
return 0
|
| 183 |
+
elseif tree.degree === 1
|
| 184 |
return 1 + countUnaryOperators(tree.l)
|
| 185 |
else
|
| 186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
|
|
| 189 |
|
| 190 |
# Count the number of binary operators in the equation
|
| 191 |
function countBinaryOperators(tree::Node)::Integer
|
| 192 |
+
if tree.degree === 0
|
| 193 |
return 0
|
| 194 |
+
elseif tree.degree === 1
|
| 195 |
return 0 + countBinaryOperators(tree.l)
|
| 196 |
else
|
| 197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
|
|
| 206 |
# Randomly convert an operator into another one (binary->binary;
|
| 207 |
# unary->unary)
|
| 208 |
function mutateOperator(tree::Node)::Node
|
| 209 |
+
if countOperators(tree) === 0
|
| 210 |
return tree
|
| 211 |
end
|
| 212 |
node = randomNode(tree)
|
| 213 |
+
while node.degree === 0
|
| 214 |
node = randomNode(tree)
|
| 215 |
end
|
| 216 |
+
if node.degree === 1
|
| 217 |
node.op = rand(1:length(unaops))
|
| 218 |
else
|
| 219 |
node.op = rand(1:length(binops))
|
|
|
|
| 223 |
|
| 224 |
# Count the number of constants in an equation
|
| 225 |
function countConstants(tree::Node)::Integer
|
| 226 |
+
if tree.degree === 0
|
| 227 |
return convert(Integer, tree.constant)
|
| 228 |
+
elseif tree.degree === 1
|
| 229 |
return 0 + countConstants(tree.l)
|
| 230 |
else
|
| 231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
|
|
| 238 |
probNegate::Float32=0.01f0)::Node
|
| 239 |
# T is between 0 and 1.
|
| 240 |
|
| 241 |
+
if countConstants(tree) === 0
|
| 242 |
return tree
|
| 243 |
end
|
| 244 |
node = randomNode(tree)
|
| 245 |
+
while node.degree !== 0 || node.constant === false
|
| 246 |
node = randomNode(tree)
|
| 247 |
end
|
| 248 |
|
|
|
|
| 273 |
# Evaluate an equation over an array of datapoints
|
| 274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
| 275 |
clen = size(cX)[1]
|
| 276 |
+
if tree.degree === 0
|
| 277 |
if tree.constant
|
| 278 |
return fill(tree.val, clen)
|
| 279 |
else
|
| 280 |
return copy(cX[:, tree.val])
|
| 281 |
end
|
| 282 |
+
elseif tree.degree === 1
|
| 283 |
cumulator = evalTreeArray(tree.l, cX)
|
| 284 |
if cumulator === nothing
|
| 285 |
return nothing
|
| 286 |
end
|
| 287 |
op_idx = tree.op
|
| 288 |
+
UNAOP!(cumulator, op_idx, clen)
|
| 289 |
@inbounds for i=1:clen
|
| 290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 291 |
return nothing
|
|
|
|
| 302 |
return nothing
|
| 303 |
end
|
| 304 |
op_idx = tree.op
|
| 305 |
+
BINOP!(cumulator, array2, op_idx, clen)
|
| 306 |
@inbounds for i=1:clen
|
| 307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
| 308 |
return nothing
|
|
|
|
| 350 |
# Add a random unary/binary operation to the end of a tree
|
| 351 |
function appendRandomOp(tree::Node)::Node
|
| 352 |
node = randomNode(tree)
|
| 353 |
+
while node.degree !== 0
|
| 354 |
node = randomNode(tree)
|
| 355 |
end
|
| 356 |
|
|
|
|
| 458 |
|
| 459 |
# Return a random node from the tree with parent
|
| 460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
| 461 |
+
if tree.degree === 0
|
| 462 |
return tree, parent
|
| 463 |
end
|
| 464 |
a = countNodes(tree)
|
|
|
|
| 467 |
if tree.degree >= 1
|
| 468 |
b = countNodes(tree.l)
|
| 469 |
end
|
| 470 |
+
if tree.degree === 2
|
| 471 |
c = countNodes(tree.r)
|
| 472 |
end
|
| 473 |
|
| 474 |
i = rand(1:1+b+c)
|
| 475 |
if i <= b
|
| 476 |
return randomNodeAndParent(tree.l, tree)
|
| 477 |
+
elseif i === b + 1
|
| 478 |
return tree, parent
|
| 479 |
end
|
| 480 |
|
|
|
|
| 487 |
node, parent = randomNodeAndParent(tree, nothing)
|
| 488 |
isroot = (parent === nothing)
|
| 489 |
|
| 490 |
+
if node.degree === 0
|
| 491 |
# Replace with new constant
|
| 492 |
newnode = randomConstantNode()
|
| 493 |
node.l = newnode.l
|
|
|
|
| 496 |
node.degree = newnode.degree
|
| 497 |
node.val = newnode.val
|
| 498 |
node.constant = newnode.constant
|
| 499 |
+
elseif node.degree === 1
|
| 500 |
# Join one of the children with the parent
|
| 501 |
if isroot
|
| 502 |
return node.l
|
|
|
|
| 536 |
# ((const - var) - const) => (const - var)
|
| 537 |
# (want to add anything commutative!)
|
| 538 |
# TODO - need to combine plus/sub if they are both there.
|
| 539 |
+
if tree.degree === 0
|
| 540 |
return tree
|
| 541 |
+
elseif tree.degree === 1
|
| 542 |
tree.l = combineOperators(tree.l)
|
| 543 |
+
elseif tree.degree === 2
|
| 544 |
tree.l = combineOperators(tree.l)
|
| 545 |
tree.r = combineOperators(tree.r)
|
| 546 |
end
|
| 547 |
|
| 548 |
+
top_level_constant = tree.degree === 2 && (tree.l.constant || tree.r.constant)
|
| 549 |
+
if tree.degree === 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
|
| 550 |
op = tree.op
|
| 551 |
# Put the constant in r
|
| 552 |
if tree.l.constant
|
|
|
|
| 557 |
topconstant = tree.r.val
|
| 558 |
# Simplify down first
|
| 559 |
below = tree.l
|
| 560 |
+
if below.degree === 2 && below.op === op
|
| 561 |
if below.l.constant
|
| 562 |
tree = below
|
| 563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
|
|
| 568 |
end
|
| 569 |
end
|
| 570 |
|
| 571 |
+
if tree.degree === 2 && binops[tree.op] === sub && top_level_constant
|
| 572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
| 573 |
# Not commutative, so use different op.
|
| 574 |
if tree.l.constant
|
| 575 |
+
if tree.r.degree === 2 && binops[tree.r.op] === sub
|
| 576 |
if tree.r.l.constant
|
| 577 |
#(const - (const - var)) => (var - const)
|
| 578 |
l = tree.l
|
|
|
|
| 591 |
end
|
| 592 |
end
|
| 593 |
else #tree.r.constant is true
|
| 594 |
+
if tree.l.degree === 2 && binops[tree.l.op] === sub
|
| 595 |
if tree.l.l.constant
|
| 596 |
#((const - var) - const) => (const - var)
|
| 597 |
l = tree.l
|
|
|
|
| 616 |
|
| 617 |
# Simplify tree
|
| 618 |
function simplifyTree(tree::Node)::Node
|
| 619 |
+
if tree.degree === 1
|
| 620 |
tree.l = simplifyTree(tree.l)
|
| 621 |
+
if tree.l.degree === 0 && tree.l.constant
|
| 622 |
return Node(unaops[tree.op](tree.l.val))
|
| 623 |
end
|
| 624 |
+
elseif tree.degree === 2
|
| 625 |
tree.l = simplifyTree(tree.l)
|
| 626 |
tree.r = simplifyTree(tree.r)
|
| 627 |
constantsBelow = (
|
| 628 |
+
tree.l.degree === 0 && tree.l.constant &&
|
| 629 |
+
tree.r.degree === 0 && tree.r.constant
|
| 630 |
)
|
| 631 |
if constantsBelow
|
| 632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
|
|
| 648 |
|
| 649 |
# Check if any power operator is to the power of a complex expression
|
| 650 |
function deepPow(tree::Node)::Integer
|
| 651 |
+
if tree.degree === 0
|
| 652 |
return 0
|
| 653 |
+
elseif tree.degree === 1
|
| 654 |
return 0 + deepPow(tree.l)
|
| 655 |
else
|
| 656 |
if binops[tree.op] === pow
|
|
|
|
| 857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
| 858 |
end
|
| 859 |
|
| 860 |
+
if verbosity > 0 && (iT % verbosity === 0)
|
| 861 |
bestPops = bestSubPop(pop)
|
| 862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
| 863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
|
|
| 870 |
|
| 871 |
# Get all the constants from a tree
|
| 872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
| 873 |
+
if tree.degree === 0
|
| 874 |
if tree.constant
|
| 875 |
return [tree.val]
|
| 876 |
else
|
| 877 |
return Float32[]
|
| 878 |
end
|
| 879 |
+
elseif tree.degree === 1
|
| 880 |
return getConstants(tree.l)
|
| 881 |
else
|
| 882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
|
|
| 886 |
|
| 887 |
# Set all the constants inside a tree
|
| 888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
| 889 |
+
if tree.degree === 0
|
| 890 |
if tree.constant
|
| 891 |
tree.val = constants[1]
|
| 892 |
end
|
| 893 |
+
elseif tree.degree === 1
|
| 894 |
setConstants(tree.l, constants)
|
| 895 |
else
|
| 896 |
numberLeft = countConstants(tree.l)
|
|
|
|
| 909 |
# Use Nelder-Mead to optimize the constants in an equation
|
| 910 |
function optimizeConstants(member::PopMember)::PopMember
|
| 911 |
nconst = countConstants(member.tree)
|
| 912 |
+
if nconst === 0
|
| 913 |
return member
|
| 914 |
end
|
| 915 |
x0 = getConstants(member.tree)
|
| 916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
| 917 |
+
if size(x0)[1] === 1
|
| 918 |
algorithm = Optim.Newton
|
| 919 |
else
|
| 920 |
algorithm = Optim.NelderMead
|
|
|
|
| 998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
| 999 |
hallOfFame = HallOfFame()
|
| 1000 |
curmaxsize = 3
|
| 1001 |
+
if warmupMaxsize === 0
|
| 1002 |
curmaxsize = maxsize
|
| 1003 |
end
|
| 1004 |
|
|
|
|
| 1067 |
numberSmallerAndBetter += 1
|
| 1068 |
end
|
| 1069 |
end
|
| 1070 |
+
betterThanAllSmaller = (numberSmallerAndBetter === 0)
|
| 1071 |
if betterThanAllSmaller
|
| 1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
| 1073 |
push!(dominating, member)
|
|
|
|
| 1117 |
|
| 1118 |
cycles_complete -= 1
|
| 1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
| 1120 |
+
if warmupMaxsize !== 0 && cycles_elapsed % warmupMaxsize === 0
|
| 1121 |
curmaxsize += 1
|
| 1122 |
if curmaxsize > maxsize
|
| 1123 |
curmaxsize = maxsize
|
|
|
|
| 1167 |
numberSmallerAndBetter += 1
|
| 1168 |
end
|
| 1169 |
end
|
| 1170 |
+
betterThanAllSmaller = (numberSmallerAndBetter === 0)
|
| 1171 |
if betterThanAllSmaller
|
| 1172 |
delta_c = size - lastComplexity
|
| 1173 |
delta_l_mse = log(curMSE/lastMSE)
|
pysr/sr.py
CHANGED
|
@@ -287,26 +287,34 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
|
|
| 287 |
op_runner = ""
|
| 288 |
if len(binary_operators) > 0:
|
| 289 |
op_runner += """
|
| 290 |
-
function BINOP!(
|
| 291 |
-
if i
|
| 292 |
-
|
|
|
|
|
|
|
| 293 |
for i in range(1, len(binary_operators)):
|
| 294 |
op_runner += f"""
|
| 295 |
-
elseif i
|
| 296 |
-
|
|
|
|
|
|
|
| 297 |
op_runner += """
|
| 298 |
end
|
| 299 |
end"""
|
| 300 |
|
| 301 |
if len(unary_operators) > 0:
|
| 302 |
op_runner += """
|
| 303 |
-
function UNAOP!(
|
| 304 |
-
if i
|
| 305 |
-
|
|
|
|
|
|
|
| 306 |
for i in range(1, len(unary_operators)):
|
| 307 |
op_runner += """
|
| 308 |
-
elseif i
|
| 309 |
-
|
|
|
|
|
|
|
| 310 |
op_runner += """
|
| 311 |
end
|
| 312 |
end"""
|
|
|
|
| 287 |
op_runner = ""
|
| 288 |
if len(binary_operators) > 0:
|
| 289 |
op_runner += """
|
| 290 |
+
@inline function BINOP!(x::Array{Float32, 1}, y::Array{Float32, 1}, i::Int, clen::Int)
|
| 291 |
+
if i === 1
|
| 292 |
+
@inbounds @simd for j=1:clen
|
| 293 |
+
x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
|
| 294 |
+
end"""
|
| 295 |
for i in range(1, len(binary_operators)):
|
| 296 |
op_runner += f"""
|
| 297 |
+
elseif i === {i+1}
|
| 298 |
+
@inbounds @simd for j=1:clen
|
| 299 |
+
x[j] = {binary_operators[i]}(x[j], y[j])
|
| 300 |
+
end"""
|
| 301 |
op_runner += """
|
| 302 |
end
|
| 303 |
end"""
|
| 304 |
|
| 305 |
if len(unary_operators) > 0:
|
| 306 |
op_runner += """
|
| 307 |
+
@inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
|
| 308 |
+
if i === 1
|
| 309 |
+
@inbounds @simd for j=1:clen
|
| 310 |
+
x[j] = """f"{unary_operators[0]}(x[j])""""
|
| 311 |
+
end"""
|
| 312 |
for i in range(1, len(unary_operators)):
|
| 313 |
op_runner += """
|
| 314 |
+
elseif i === {i+1}
|
| 315 |
+
@inbounds @simd for j=1:clen
|
| 316 |
+
x[j] = {unary_operators[i]}(x[j])
|
| 317 |
+
end"""
|
| 318 |
op_runner += """
|
| 319 |
end
|
| 320 |
end"""
|