Spaces:
Running
Running
Commit
·
d4d95e5
1
Parent(s):
7d19ebb
Add test for custom torch operator
Browse files- test/test_torch.py +35 -1
test/test_torch.py
CHANGED
|
@@ -36,7 +36,7 @@ class TestTorch(unittest.TestCase):
|
|
| 36 |
|
| 37 |
equations = get_hof(
|
| 38 |
"equation_file.csv",
|
| 39 |
-
n_features=2,
|
| 40 |
variables_names="x1 x2 x3".split(" "),
|
| 41 |
extra_sympy_mappings={},
|
| 42 |
output_torch_format=True,
|
|
@@ -68,3 +68,37 @@ class TestTorch(unittest.TestCase):
|
|
| 68 |
np.testing.assert_array_almost_equal(
|
| 69 |
true_out.detach(), torch_out.detach(), decimal=4
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
equations = get_hof(
|
| 38 |
"equation_file.csv",
|
| 39 |
+
n_features=2, # TODO: Why is this 2 and not 3?
|
| 40 |
variables_names="x1 x2 x3".split(" "),
|
| 41 |
extra_sympy_mappings={},
|
| 42 |
output_torch_format=True,
|
|
|
|
| 68 |
np.testing.assert_array_almost_equal(
|
| 69 |
true_out.detach(), torch_out.detach(), decimal=4
|
| 70 |
)
|
| 71 |
+
|
| 72 |
+
def test_custom_operator(self):
|
| 73 |
+
X = np.random.randn(100, 3)
|
| 74 |
+
|
| 75 |
+
equations = pd.DataFrame(
|
| 76 |
+
{
|
| 77 |
+
"Equation": ["1.0", "mycustomoperator(x0)"],
|
| 78 |
+
"MSE": [1.0, 0.1],
|
| 79 |
+
"Complexity": [1, 2],
|
| 80 |
+
}
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
equations["Complexity MSE Equation".split(" ")].to_csv(
|
| 84 |
+
"equation_file_custom_operator.csv.bkup", sep="|"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
equations = get_hof(
|
| 88 |
+
"equation_file_custom_operator.csv",
|
| 89 |
+
n_features=3,
|
| 90 |
+
variables_names="x1 x2 x3".split(" "),
|
| 91 |
+
extra_sympy_mappings={"mycustomoperator": sympy.sin},
|
| 92 |
+
extra_torch_mappings={"mycustomoperator": torch.sin},
|
| 93 |
+
output_torch_format=True,
|
| 94 |
+
multioutput=False,
|
| 95 |
+
nout=1,
|
| 96 |
+
selection=[0, 1, 2],
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
tformat = equations.iloc[-1].torch_format
|
| 100 |
+
np.testing.assert_almost_equal(
|
| 101 |
+
tformat(torch.tensor(X)).detach().numpy(),
|
| 102 |
+
np.sin(X[:, 0]), # Selection 1st feature
|
| 103 |
+
decimal=4,
|
| 104 |
+
)
|