Spaces:
Running
Running
Commit
·
2ceb526
1
Parent(s):
66dcb6d
Add JAX export functionality
Browse files- pysr/__init__.py +1 -0
- pysr/export.py +158 -0
- pysr/sr.py +2 -2
pysr/__init__.py
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
| 2 |
from .feynman_problems import Problem, FeynmanProblem
|
|
|
|
|
|
| 1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
| 2 |
from .feynman_problems import Problem, FeynmanProblem
|
| 3 |
+
from .export import sympy2jax
|
pysr/export.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools as ft
|
| 2 |
+
import sympy
|
| 3 |
+
import string
|
| 4 |
+
import random
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import jax
|
| 8 |
+
from jax import numpy as jnp
|
| 9 |
+
from jax.scipy import special as jsp
|
| 10 |
+
|
| 11 |
+
# Special since need to reduce arguments.
|
| 12 |
+
MUL = 0
|
| 13 |
+
ADD = 1
|
| 14 |
+
|
| 15 |
+
_jnp_func_lookup = {
|
| 16 |
+
sympy.Mul: MUL,
|
| 17 |
+
sympy.Add: ADD,
|
| 18 |
+
sympy.div: "jnp.div",
|
| 19 |
+
sympy.Abs: "jnp.abs",
|
| 20 |
+
sympy.sign: "jnp.sign",
|
| 21 |
+
# Note: May raise error for ints.
|
| 22 |
+
sympy.ceiling: "jnp.ceil",
|
| 23 |
+
sympy.floor: "jnp.floor",
|
| 24 |
+
sympy.log: "jnp.log",
|
| 25 |
+
sympy.exp: "jnp.exp",
|
| 26 |
+
sympy.sqrt: "jnp.sqrt",
|
| 27 |
+
sympy.cos: "jnp.cos",
|
| 28 |
+
sympy.acos: "jnp.acos",
|
| 29 |
+
sympy.sin: "jnp.sin",
|
| 30 |
+
sympy.asin: "jnp.asin",
|
| 31 |
+
sympy.tan: "jnp.tan",
|
| 32 |
+
sympy.atan: "jnp.atan",
|
| 33 |
+
sympy.atan2: "jnp.atan2",
|
| 34 |
+
# Note: Also may give NaN for complex results.
|
| 35 |
+
sympy.cosh: "jnp.cosh",
|
| 36 |
+
sympy.acosh: "jnp.acosh",
|
| 37 |
+
sympy.sinh: "jnp.sinh",
|
| 38 |
+
sympy.asinh: "jnp.asinh",
|
| 39 |
+
sympy.tanh: "jnp.tanh",
|
| 40 |
+
sympy.atanh: "jnp.atanh",
|
| 41 |
+
sympy.Pow: "jnp.power",
|
| 42 |
+
sympy.re: "jnp.real",
|
| 43 |
+
sympy.im: "jnp.imag",
|
| 44 |
+
sympy.arg: "jnp.angle",
|
| 45 |
+
# Note: May raise error for ints and complexes
|
| 46 |
+
sympy.erf: "jsp.erf",
|
| 47 |
+
sympy.erfc: "jsp.erfc",
|
| 48 |
+
sympy.LessThan: "jnp.le",
|
| 49 |
+
sympy.GreaterThan: "jnp.ge",
|
| 50 |
+
sympy.And: "jnp.logical_and",
|
| 51 |
+
sympy.Or: "jnp.logical_or",
|
| 52 |
+
sympy.Not: "jnp.logical_not",
|
| 53 |
+
sympy.Max: "jnp.max",
|
| 54 |
+
sympy.Min: "jnp.min",
|
| 55 |
+
sympy.Mod: "jnp.mod",
|
| 56 |
+
sympy.round: 'jnp.round'
|
| 57 |
+
}
|
| 58 |
+
except ImportError:
|
| 59 |
+
...
|
| 60 |
+
|
| 61 |
+
def sympy2jaxtext(expr, parameters, symbols_in):
|
| 62 |
+
if issubclass(expr.func, sympy.Float):
|
| 63 |
+
parameters.append(float(expr))
|
| 64 |
+
return f"parameters[{len(parameters) - 1}]"
|
| 65 |
+
elif issubclass(expr.func, sympy.Integer):
|
| 66 |
+
return "{int(expr)}"
|
| 67 |
+
elif issubclass(expr.func, sympy.Symbol):
|
| 68 |
+
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 69 |
+
else:
|
| 70 |
+
_func = _jnp_func_lookup[expr.func]
|
| 71 |
+
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
|
| 72 |
+
if _func == MUL:
|
| 73 |
+
return ' * '.join(['(' + arg + ')' for arg in args])
|
| 74 |
+
elif _func == ADD:
|
| 75 |
+
return ' + '.join(['(' + arg + ')' for arg in args])
|
| 76 |
+
else:
|
| 77 |
+
return f'{_func}({", ".join(args)})'
|
| 78 |
+
|
| 79 |
+
def sympy2jax(equation, symbols_in):
|
| 80 |
+
"""Returns a function f and its parameters;
|
| 81 |
+
the function takes an input matrix, and a list of arguments:
|
| 82 |
+
f(X, parameters)
|
| 83 |
+
where the parameters appear in the JAX equation.
|
| 84 |
+
|
| 85 |
+
# Examples:
|
| 86 |
+
|
| 87 |
+
Let's create a function in SymPy:
|
| 88 |
+
```python
|
| 89 |
+
x, y = symbols('x y')
|
| 90 |
+
cosx = 1.0 * sympy.cos(x) + 3.2 * y
|
| 91 |
+
```
|
| 92 |
+
Let's get the JAX version. We pass the equation, and
|
| 93 |
+
the symbols required.
|
| 94 |
+
```python
|
| 95 |
+
f, params = sympy2jax(cosx, [x, y])
|
| 96 |
+
```
|
| 97 |
+
The order you supply the symbols is the same order
|
| 98 |
+
you should supply the features when calling
|
| 99 |
+
the function `f` (shape `[nrows, nfeatures]`).
|
| 100 |
+
In this case, features=2 for x and y.
|
| 101 |
+
The `params` in this case will be
|
| 102 |
+
`jnp.array([1.0, 3.2])`. You pass these parameters
|
| 103 |
+
when calling the function, which will let you change them
|
| 104 |
+
and take gradients.
|
| 105 |
+
|
| 106 |
+
Let's generate some JAX data to pass:
|
| 107 |
+
```python
|
| 108 |
+
key = random.PRNGKey(0)
|
| 109 |
+
X = random.normal(key, (10, 2))
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
We can call the function with:
|
| 113 |
+
```python
|
| 114 |
+
f(X, params)
|
| 115 |
+
|
| 116 |
+
#> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
|
| 117 |
+
# 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
|
| 118 |
+
# 3.5427954 , -2.7479894 ], dtype=float32)
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
We can take gradients with respect
|
| 122 |
+
to the parameters for each row with JAX
|
| 123 |
+
gradient parameters now:
|
| 124 |
+
```python
|
| 125 |
+
jac_f = jax.jacobian(f, argnums=1)
|
| 126 |
+
jac_f(X, params)
|
| 127 |
+
|
| 128 |
+
#> DeviceArray([[ 0.49364874, -0.9692889 ],
|
| 129 |
+
# [ 0.8283714 , -0.0318858 ],
|
| 130 |
+
# [-0.7447336 , -1.8784496 ],
|
| 131 |
+
# [ 0.70755106, -0.3137085 ],
|
| 132 |
+
# [ 0.944834 , 1.767703 ],
|
| 133 |
+
# [ 0.51673377, 1.4111717 ],
|
| 134 |
+
# [ 0.87347716, -0.52637756],
|
| 135 |
+
# [ 0.8760679 , 1.0549792 ],
|
| 136 |
+
# [ 0.9961824 , 0.79581654],
|
| 137 |
+
# [-0.88465923, -0.5822907 ]], dtype=float32)
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
We can also JIT-compile our function:
|
| 141 |
+
```python
|
| 142 |
+
compiled_f = jax.jit(f)
|
| 143 |
+
compiled_f(X, params)
|
| 144 |
+
|
| 145 |
+
#> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
|
| 146 |
+
# 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
|
| 147 |
+
# 3.5427954 , -2.7479894 ], dtype=float32)
|
| 148 |
+
```
|
| 149 |
+
"""
|
| 150 |
+
parameters = []
|
| 151 |
+
functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
|
| 152 |
+
hash_string = 'A' + str(hash([equation, symbols_in]))
|
| 153 |
+
text = f"def {hash_string}(X, parameters):\n"
|
| 154 |
+
text += " return "
|
| 155 |
+
text += functional_form_text
|
| 156 |
+
ldict = {}
|
| 157 |
+
exec(text, globals(), ldict)
|
| 158 |
+
return ldict['f'], jnp.array(parameters)
|
pysr/sr.py
CHANGED
|
@@ -47,8 +47,8 @@ sympy_mappings = {
|
|
| 47 |
'erf': lambda x : sympy.erf(x),
|
| 48 |
'erfc': lambda x : sympy.erfc(x),
|
| 49 |
'logm': lambda x : sympy.log(abs(x)),
|
| 50 |
-
'logm10':lambda x : sympy.log(abs(x),
|
| 51 |
-
'logm2': lambda x : sympy.log(abs(x),
|
| 52 |
'log1p': lambda x : sympy.log(x + 1),
|
| 53 |
'floor': lambda x : sympy.floor(x),
|
| 54 |
'ceil': lambda x : sympy.ceil(x),
|
|
|
|
| 47 |
'erf': lambda x : sympy.erf(x),
|
| 48 |
'erfc': lambda x : sympy.erfc(x),
|
| 49 |
'logm': lambda x : sympy.log(abs(x)),
|
| 50 |
+
'logm10':lambda x : sympy.log(abs(x), 10),
|
| 51 |
+
'logm2': lambda x : sympy.log(abs(x), 2),
|
| 52 |
'log1p': lambda x : sympy.log(x + 1),
|
| 53 |
'floor': lambda x : sympy.floor(x),
|
| 54 |
'ceil': lambda x : sympy.ceil(x),
|