Spaces:
Running
Running
Merge pull request #281 from MilesCranmer/complex-numbers
Browse files
- docs/examples.md +35 -1
- pysr/__init__.py +1 -0
- pysr/sklearn_monkeypatch.py +13 -0
- pysr/sr.py +28 -1
- pysr/test/test.py +15 -2
- pysr/version.py +2 -2
docs/examples.md
CHANGED
|
@@ -284,7 +284,41 @@ You can get the sympy version of the best equation with:
|
|
| 284 |
model.sympy()
|
| 285 |
```
|
| 286 |
|
| 287 |
-
## 8. Additional features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
For the many other features available in PySR, please
|
| 290 |
read the [Options section](options.md).
|
|
|
|
| 284 |
model.sympy()
|
| 285 |
```
|
| 286 |
|
| 287 |
+
## 8. Complex numbers
|
| 288 |
+
|
| 289 |
+
PySR can also search for complex-valued expressions. Simply pass
|
| 290 |
+
data with a complex datatype (e.g., `np.complex128`),
|
| 291 |
+
and PySR will automatically search for complex-valued expressions:
|
| 292 |
+
|
| 293 |
+
```python
|
| 294 |
+
import numpy as np
|
| 295 |
+
|
| 296 |
+
X = np.random.randn(100, 1) + 1j * np.random.randn(100, 1)
|
| 297 |
+
y = (1 + 2j) * np.cos(X[:, 0] * (0.5 - 0.2j))
|
| 298 |
+
|
| 299 |
+
model = PySRRegressor(
|
| 300 |
+
binary_operators=["+", "-", "*"], unary_operators=["cos"], niterations=100,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
model.fit(X, y)
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
You can see that all of the learned constants are now complex numbers.
|
| 307 |
+
We can get the sympy version of the best equation with:
|
| 308 |
+
|
| 309 |
+
```python
|
| 310 |
+
model.sympy()
|
| 311 |
+
```
|
| 312 |
+
|
| 313 |
+
We can also make predictions normally, by passing complex data:
|
| 314 |
+
|
| 315 |
+
```python
|
| 316 |
+
model.predict(X, -1)
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
to make predictions with the most accurate expression.
|
| 320 |
+
|
| 321 |
+
## 9. Additional features
|
| 322 |
|
| 323 |
For the many other features available in PySR, please
|
| 324 |
read the [Options section](options.md).
|
pysr/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from .version import __version__
|
| 2 |
from .sr import (
|
| 3 |
pysr,
|
|
|
|
| 1 |
+
from . import sklearn_monkeypatch
|
| 2 |
from .version import __version__
|
| 3 |
from .sr import (
|
| 4 |
pysr,
|
pysr/sklearn_monkeypatch.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Here, we monkey patch scikit-learn until this
|
| 2 |
+
# issue is fixed: https://github.com/scikit-learn/scikit-learn/issues/25922
|
| 3 |
+
from sklearn.utils import validation
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _ensure_no_complex_data(*args, **kwargs):
|
| 7 |
+
...
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
validation._ensure_no_complex_data = _ensure_no_complex_data
|
| 12 |
+
except AttributeError:
|
| 13 |
+
...
|
pysr/sr.py
CHANGED
|
@@ -498,6 +498,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 498 |
What precision to use for the data. By default this is `32`
|
| 499 |
(float32), but you can select `64` or `16` as well, giving
|
| 500 |
you 64 or 16 bits of floating point precision, respectively.
|
|
|
|
|
|
|
| 501 |
Default is `32`.
|
| 502 |
random_state : int, Numpy RandomState instance or None
|
| 503 |
Pass an int for reproducible results across multiple function calls.
|
|
@@ -1619,7 +1621,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1619 |
)
|
| 1620 |
|
| 1621 |
# Convert data to desired precision
|
| 1622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1623 |
|
| 1624 |
# This converts the data into a Julia array:
|
| 1625 |
Main.X = np.array(X, dtype=np_dtype).T
|
|
@@ -2007,6 +2015,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2007 |
|
| 2008 |
def _read_equation_file(self):
|
| 2009 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
|
|
|
| 2010 |
try:
|
| 2011 |
if self.nout_ > 1:
|
| 2012 |
all_outputs = []
|
|
@@ -2024,6 +2033,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2024 |
},
|
| 2025 |
inplace=True,
|
| 2026 |
)
|
|
|
|
| 2027 |
|
| 2028 |
all_outputs.append(df)
|
| 2029 |
else:
|
|
@@ -2039,6 +2049,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 2039 |
},
|
| 2040 |
inplace=True,
|
| 2041 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2042 |
except FileNotFoundError:
|
| 2043 |
raise RuntimeError(
|
| 2044 |
"Couldn't find equation file! The equation search likely exited "
|
|
@@ -2329,3 +2343,16 @@ def _csv_filename_to_pkl_filename(csv_filename) -> str:
|
|
| 2329 |
pkl_basename = base + ".pkl"
|
| 2330 |
|
| 2331 |
return os.path.join(dirname, pkl_basename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
What precision to use for the data. By default this is `32`
|
| 499 |
(float32), but you can select `64` or `16` as well, giving
|
| 500 |
you 64 or 16 bits of floating point precision, respectively.
|
| 501 |
+
If you pass complex data, the corresponding complex precision
|
| 502 |
+
will be used (i.e., `64` for complex128, `32` for complex64).
|
| 503 |
Default is `32`.
|
| 504 |
random_state : int, Numpy RandomState instance or None
|
| 505 |
Pass an int for reproducible results across multiple function calls.
|
|
|
|
| 1621 |
)
|
| 1622 |
|
| 1623 |
# Convert data to desired precision
|
| 1624 |
+
test_X = np.array(X)
|
| 1625 |
+
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
| 1626 |
+
is_real = not is_complex
|
| 1627 |
+
if is_real:
|
| 1628 |
+
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
| 1629 |
+
else:
|
| 1630 |
+
np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
|
| 1631 |
|
| 1632 |
# This converts the data into a Julia array:
|
| 1633 |
Main.X = np.array(X, dtype=np_dtype).T
|
|
|
|
| 2015 |
|
| 2016 |
def _read_equation_file(self):
|
| 2017 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
| 2018 |
+
|
| 2019 |
try:
|
| 2020 |
if self.nout_ > 1:
|
| 2021 |
all_outputs = []
|
|
|
|
| 2033 |
},
|
| 2034 |
inplace=True,
|
| 2035 |
)
|
| 2036 |
+
df["equation"] = df["equation"].apply(_preprocess_julia_floats)
|
| 2037 |
|
| 2038 |
all_outputs.append(df)
|
| 2039 |
else:
|
|
|
|
| 2049 |
},
|
| 2050 |
inplace=True,
|
| 2051 |
)
|
| 2052 |
+
all_outputs[-1]["equation"] = all_outputs[-1]["equation"].apply(
|
| 2053 |
+
_preprocess_julia_floats
|
| 2054 |
+
)
|
| 2055 |
+
|
| 2056 |
except FileNotFoundError:
|
| 2057 |
raise RuntimeError(
|
| 2058 |
"Couldn't find equation file! The equation search likely exited "
|
|
|
|
| 2343 |
pkl_basename = base + ".pkl"
|
| 2344 |
|
| 2345 |
return os.path.join(dirname, pkl_basename)
|
| 2346 |
+
|
| 2347 |
+
|
| 2348 |
+
_regexp_im = re.compile(r"\b(\d+\.\d+)im\b")
|
| 2349 |
+
_regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b")
|
| 2350 |
+
_regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b")
|
| 2351 |
+
|
| 2352 |
+
_apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x)
|
| 2353 |
+
_apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x)
|
| 2354 |
+
_apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x)
|
| 2355 |
+
|
| 2356 |
+
|
| 2357 |
+
def _preprocess_julia_floats(s: str) -> str:
|
| 2358 |
+
return _apply_regexp_sci(_apply_regexp_im_sci(_apply_regexp_im(s)))
|
pysr/test/test.py
CHANGED
|
@@ -181,6 +181,20 @@ class TestPipeline(unittest.TestCase):
|
|
| 181 |
print("Model equations: ", model.sympy()[1])
|
| 182 |
print("True equation: x1^2")
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
def test_empty_operators_single_input_warm_start(self):
|
| 185 |
X = self.rstate.randn(100, 1)
|
| 186 |
y = X[:, 0] + 3.0
|
|
@@ -230,7 +244,6 @@ class TestPipeline(unittest.TestCase):
|
|
| 230 |
regressor.fit(self.X, y)
|
| 231 |
|
| 232 |
def test_noisy(self):
|
| 233 |
-
|
| 234 |
y = self.X[:, [0, 1]] ** 2 + self.rstate.randn(self.X.shape[0], 1) * 0.05
|
| 235 |
model = PySRRegressor(
|
| 236 |
# Test that passing a single operator works:
|
|
@@ -664,7 +677,7 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 664 |
|
| 665 |
check_generator = check_estimator(model, generate_only=True)
|
| 666 |
exception_messages = []
|
| 667 |
-
for
|
| 668 |
try:
|
| 669 |
with warnings.catch_warnings():
|
| 670 |
warnings.simplefilter("ignore")
|
|
|
|
| 181 |
print("Model equations: ", model.sympy()[1])
|
| 182 |
print("True equation: x1^2")
|
| 183 |
|
| 184 |
+
def test_complex_equations_anonymous_stop(self):
|
| 185 |
+
X = self.rstate.randn(100, 3) + 1j * self.rstate.randn(100, 3)
|
| 186 |
+
y = (2 + 1j) * np.cos(X[:, 0] * (0.5 - 0.3j))
|
| 187 |
+
model = PySRRegressor(
|
| 188 |
+
binary_operators=["+", "-", "*"],
|
| 189 |
+
unary_operators=["cos"],
|
| 190 |
+
**self.default_test_kwargs,
|
| 191 |
+
early_stop_condition="(loss, complexity) -> loss <= 1e-4 && complexity <= 6",
|
| 192 |
+
)
|
| 193 |
+
model.fit(X, y)
|
| 194 |
+
test_y = model.predict(X)
|
| 195 |
+
self.assertTrue(np.issubdtype(test_y.dtype, np.complexfloating))
|
| 196 |
+
self.assertLessEqual(np.average(np.abs(test_y - y) ** 2), 1e-4)
|
| 197 |
+
|
| 198 |
def test_empty_operators_single_input_warm_start(self):
|
| 199 |
X = self.rstate.randn(100, 1)
|
| 200 |
y = X[:, 0] + 3.0
|
|
|
|
| 244 |
regressor.fit(self.X, y)
|
| 245 |
|
| 246 |
def test_noisy(self):
|
|
|
|
| 247 |
y = self.X[:, [0, 1]] ** 2 + self.rstate.randn(self.X.shape[0], 1) * 0.05
|
| 248 |
model = PySRRegressor(
|
| 249 |
# Test that passing a single operator works:
|
|
|
|
| 677 |
|
| 678 |
check_generator = check_estimator(model, generate_only=True)
|
| 679 |
exception_messages = []
|
| 680 |
+
for _, check in check_generator:
|
| 681 |
try:
|
| 682 |
with warnings.catch_warnings():
|
| 683 |
warnings.simplefilter("ignore")
|
pysr/version.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
__version__ = "0.
|
| 2 |
-
__symbolic_regression_jl_version__ = "0.
|
|
|
|
| 1 |
+
__version__ = "0.12.0"
|
| 2 |
+
__symbolic_regression_jl_version__ = "0.16.1"
|