Spaces:
Running
Running
fix: symbolic numbers in torch
Browse files- pysr/export_torch.py +5 -0
- pysr/test/test_torch.py +11 -0
pysr/export_torch.py
CHANGED
|
@@ -116,6 +116,11 @@ def _initialize_torch():
|
|
| 116 |
self._value = int(expr)
|
| 117 |
self._torch_func = lambda: self._value
|
| 118 |
self._args = ()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
elif issubclass(expr.func, sympy.Symbol):
|
| 120 |
self._name = expr.name
|
| 121 |
self._torch_func = lambda value: value
|
|
|
|
| 116 |
self._value = int(expr)
|
| 117 |
self._torch_func = lambda: self._value
|
| 118 |
self._args = ()
|
| 119 |
+
elif issubclass(expr.func, sympy.NumberSymbol):
|
| 120 |
+
# Can get here from exp(1) or exact pi
|
| 121 |
+
self._value = float(expr)
|
| 122 |
+
self._torch_func = lambda: self._value
|
| 123 |
+
self._args = ()
|
| 124 |
elif issubclass(expr.func, sympy.Symbol):
|
| 125 |
self._name = expr.name
|
| 126 |
self._torch_func = lambda value: value
|
pysr/test/test_torch.py
CHANGED
|
@@ -173,6 +173,17 @@ class TestTorch(unittest.TestCase):
|
|
| 173 |
decimal=3,
|
| 174 |
)
|
| 175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
def test_feature_selection_custom_operators(self):
|
| 177 |
rstate = np.random.RandomState(0)
|
| 178 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
|
|
|
| 173 |
decimal=3,
|
| 174 |
)
|
| 175 |
|
| 176 |
+
def test_issue_656(self):
|
| 177 |
+
# Should correctly map numeric symbols to floats
|
| 178 |
+
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
| 179 |
+
m = pysr.export_torch.sympy2torch(E_plus_x1, ["x1"])
|
| 180 |
+
X = np.random.randn(10, 1)
|
| 181 |
+
np.testing.assert_almost_equal(
|
| 182 |
+
m(self.torch.tensor(X)).detach().numpy(),
|
| 183 |
+
np.exp(1) + X[:, 0],
|
| 184 |
+
decimal=3,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
def test_feature_selection_custom_operators(self):
|
| 188 |
rstate = np.random.RandomState(0)
|
| 189 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|