Spaces:
Sleeping
Sleeping
Merge pull request #47 from MilesCranmer/coverage
Browse files- .github/workflows/CI.yml +16 -2
- test/test.py +35 -31
- test/test_jax.py +10 -8
.github/workflows/CI.yml
CHANGED
|
@@ -59,14 +59,28 @@ jobs:
|
|
| 59 |
python -m pip install --upgrade pip
|
| 60 |
pip install -r requirements.txt
|
| 61 |
python setup.py install
|
|
|
|
|
|
|
| 62 |
- name: "Install JAX"
|
| 63 |
if: matrix.os != 'windows-latest'
|
| 64 |
run: pip install jax jaxlib # (optional import)
|
| 65 |
shell: bash
|
| 66 |
- name: "Run tests"
|
| 67 |
-
run:
|
|
|
|
|
|
|
|
|
|
| 68 |
shell: bash
|
| 69 |
- name: "Run JAX tests"
|
| 70 |
if: matrix.os != 'windows-latest'
|
| 71 |
-
run:
|
|
|
|
|
|
|
|
|
|
| 72 |
shell: bash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
python -m pip install --upgrade pip
|
| 60 |
pip install -r requirements.txt
|
| 61 |
python setup.py install
|
| 62 |
+
- name: "Install Coverage tool"
|
| 63 |
+
run: pip install coverage coveralls
|
| 64 |
- name: "Install JAX"
|
| 65 |
if: matrix.os != 'windows-latest'
|
| 66 |
run: pip install jax jaxlib # (optional import)
|
| 67 |
shell: bash
|
| 68 |
- name: "Run tests"
|
| 69 |
+
run: |
|
| 70 |
+
cd test
|
| 71 |
+
coverage run --source=pysr --omit=pysr.feynman_problems -m unittest test
|
| 72 |
+
cd ..
|
| 73 |
shell: bash
|
| 74 |
- name: "Run JAX tests"
|
| 75 |
if: matrix.os != 'windows-latest'
|
| 76 |
+
run: |
|
| 77 |
+
cd test
|
| 78 |
+
coverage run --append --source=pysr --omit=pysr.feynman_problems -m unittest test_jax
|
| 79 |
+
cd ..
|
| 80 |
shell: bash
|
| 81 |
+
- name: Coveralls
|
| 82 |
+
env:
|
| 83 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 84 |
+
run: |
|
| 85 |
+
cd test
|
| 86 |
+
coveralls --service=github
|
test/test.py
CHANGED
|
@@ -1,38 +1,42 @@
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pysr import pysr
|
| 3 |
import sympy
|
| 4 |
-
X = np.random.randn(100, 5)
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
-
y = X[:, 0]
|
| 16 |
-
equations = pysr(X, y,
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
print(equations)
|
| 27 |
-
assert equations[0].iloc[-1]['MSE'] < 1e-4
|
| 28 |
-
assert equations[1].iloc[-1]['MSE'] < 1e-4
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print("Test 3 - empty operator list, and single dimension input")
|
| 33 |
-
equations = pysr(X, y,
|
| 34 |
-
unary_operators=[], binary_operators=["plus"],
|
| 35 |
-
**default_test_kwargs)
|
| 36 |
-
|
| 37 |
-
print(equations)
|
| 38 |
-
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
import numpy as np
|
| 3 |
from pysr import pysr
|
| 4 |
import sympy
|
|
|
|
| 5 |
|
| 6 |
+
class TestPipeline(unittest.TestCase):
|
| 7 |
+
def setUp(self):
|
| 8 |
+
self.default_test_kwargs = dict(
|
| 9 |
+
niterations=10,
|
| 10 |
+
populations=4,
|
| 11 |
+
user_input=False,
|
| 12 |
+
annealing=True,
|
| 13 |
+
useFrequency=False,
|
| 14 |
+
)
|
| 15 |
+
np.random.seed(0)
|
| 16 |
+
self.X = np.random.randn(100, 5)
|
| 17 |
+
|
| 18 |
+
def test_linear_relation(self):
|
| 19 |
+
y = self.X[:, 0]
|
| 20 |
+
equations = pysr(self.X, y, **self.default_test_kwargs)
|
| 21 |
+
print(equations)
|
| 22 |
+
self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
|
| 23 |
|
| 24 |
+
def test_multioutput_custom_operator(self):
|
| 25 |
+
y = self.X[:, [0, 1]]**2
|
| 26 |
+
equations = pysr(self.X, y,
|
| 27 |
+
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
| 28 |
+
extra_sympy_mappings={'square': lambda x: x**2},
|
| 29 |
+
**self.default_test_kwargs)
|
| 30 |
+
print(equations)
|
| 31 |
+
self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
|
| 32 |
+
self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
|
| 33 |
|
| 34 |
+
def test_empty_operators_single_input(self):
|
| 35 |
+
X = np.random.randn(100, 1)
|
| 36 |
+
y = X[:, 0] + 3.0
|
| 37 |
+
equations = pysr(X, y,
|
| 38 |
+
unary_operators=[], binary_operators=["plus"],
|
| 39 |
+
**self.default_test_kwargs)
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
+
print(equations)
|
| 42 |
+
self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_jax.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
from pysr import pysr, sympy2jax
|
| 3 |
from jax import numpy as jnp
|
|
@@ -5,11 +6,12 @@ from jax import random
|
|
| 5 |
from jax import grad
|
| 6 |
import sympy
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
import numpy as np
|
| 3 |
from pysr import pysr, sympy2jax
|
| 4 |
from jax import numpy as jnp
|
|
|
|
| 6 |
from jax import grad
|
| 7 |
import sympy
|
| 8 |
|
| 9 |
+
class TestJAX(unittest.TestCase):
|
| 10 |
+
def test_sympy2jax(self):
|
| 11 |
+
x, y, z = sympy.symbols('x y z')
|
| 12 |
+
cosx = 1.0 * sympy.cos(x) + y
|
| 13 |
+
key = random.PRNGKey(0)
|
| 14 |
+
X = random.normal(key, (1000, 2))
|
| 15 |
+
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
| 16 |
+
f, params = sympy2jax(cosx, [x, y, z])
|
| 17 |
+
self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
|