Spaces:
Running
Running
Commit
·
c96b30c
1
Parent(s):
8c55475
Clean up global variables into single dict
Browse files
- pysr/sr.py +60 -49
- test/test.py +2 -1
pysr/sr.py
CHANGED
|
@@ -14,12 +14,19 @@ from pathlib import Path
|
|
| 14 |
from datetime import datetime
|
| 15 |
import warnings
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
sympy_mappings = {
|
| 25 |
'div': lambda x, y : x/y,
|
|
@@ -62,16 +69,20 @@ sympy_mappings = {
|
|
| 62 |
|
| 63 |
class CallableEquation(object):
|
| 64 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
| 65 |
-
def __init__(self, sympy_symbols, eqn):
|
| 66 |
self._sympy = eqn
|
| 67 |
self._sympy_symbols = sympy_symbols
|
|
|
|
| 68 |
self._lambda = lambdify(sympy_symbols, eqn)
|
| 69 |
|
| 70 |
def __repr__(self):
|
| 71 |
return f"PySRFunction(X=>{self._sympy})"
|
| 72 |
|
| 73 |
def __call__(self, X):
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def pysr(X, y, weights=None,
|
| 77 |
binary_operators=None,
|
|
@@ -284,7 +295,7 @@ def pysr(X, y, weights=None,
|
|
| 284 |
if maxsize > 40:
|
| 285 |
warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
|
| 286 |
|
| 287 |
-
X, variable_names = _handle_feature_selection(
|
| 288 |
X, select_k_features,
|
| 289 |
use_custom_variable_names, variable_names, y
|
| 290 |
)
|
|
@@ -343,6 +354,7 @@ def pysr(X, y, weights=None,
|
|
| 343 |
julia_project=julia_project, loss=loss,
|
| 344 |
output_jax_format=output_jax_format,
|
| 345 |
output_torch_format=output_torch_format,
|
|
|
|
| 346 |
multioutput=multioutput, nout=nout)
|
| 347 |
|
| 348 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
@@ -391,21 +403,13 @@ def pysr(X, y, weights=None,
|
|
| 391 |
return equations
|
| 392 |
|
| 393 |
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
global global_variable_names
|
| 400 |
-
global global_extra_sympy_mappings
|
| 401 |
-
global global_multioutput
|
| 402 |
-
global global_nout
|
| 403 |
-
global_n_features = X.shape[1]
|
| 404 |
-
global_equation_file = equation_file
|
| 405 |
-
global_variable_names = variable_names
|
| 406 |
-
global_extra_sympy_mappings = extra_sympy_mappings
|
| 407 |
-
global_multioutput = multioutput
|
| 408 |
-
global_nout = nout
|
| 409 |
|
| 410 |
|
| 411 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
@@ -668,7 +672,9 @@ def _handle_feature_selection(X, select_k_features, use_custom_variable_names, v
|
|
| 668 |
|
| 669 |
if use_custom_variable_names:
|
| 670 |
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
| 671 |
-
|
|
|
|
|
|
|
| 672 |
|
| 673 |
|
| 674 |
def _set_paths(tempdir):
|
|
@@ -732,33 +738,38 @@ def run_feature_selection(X, y, select_k_features):
|
|
| 732 |
return selector.get_support(indices=True)
|
| 733 |
|
| 734 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 735 |
-
|
| 736 |
-
|
| 737 |
extra_jax_mappings=None, extra_torch_mappings=None,
|
| 738 |
multioutput=None, nout=None, **kwargs):
|
| 739 |
"""Get the equations from a hall of fame file. If no arguments
|
| 740 |
entered, the ones used previously from a call to PySR will be used."""
|
| 741 |
|
| 742 |
-
global
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
if
|
| 750 |
-
if
|
| 751 |
-
if
|
| 752 |
-
if
|
| 753 |
-
if
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
|
| 763 |
try:
|
| 764 |
if multioutput:
|
|
@@ -797,18 +808,18 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 797 |
sympy_format.append(eqn)
|
| 798 |
|
| 799 |
# Numpy:
|
| 800 |
-
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
| 801 |
|
| 802 |
# JAX:
|
| 803 |
if output_jax_format:
|
| 804 |
from .export_jax import sympy2jax
|
| 805 |
-
func, params = sympy2jax(eqn, sympy_symbols)
|
| 806 |
jax_format.append({'callable': func, 'parameters': params})
|
| 807 |
|
| 808 |
# Torch:
|
| 809 |
if output_torch_format:
|
| 810 |
from .export_torch import sympy2torch
|
| 811 |
-
module = sympy2torch(eqn, sympy_symbols)
|
| 812 |
torch_format.append(module)
|
| 813 |
|
| 814 |
curMSE = output.loc[i, 'MSE']
|
|
|
|
| 14 |
from datetime import datetime
|
| 15 |
import warnings
|
| 16 |
|
| 17 |
+
# Shared module-level settings. `_set_globals` and `get_hof` write to this
# dict, and `get_hof` falls back to these values for most arguments that
# are passed as None.
global_state = {
    'equation_file': 'hall_of_fame.csv',
    'n_features': None,
    'variable_names': [],
    'extra_sympy_mappings': {},
    'extra_torch_mappings': {},
    'extra_jax_mappings': {},
    'output_jax_format': False,
    'output_torch_format': False,
    'multioutput': False,
    'nout': 1,
    'selection': None,
}
|
| 30 |
|
| 31 |
sympy_mappings = {
|
| 32 |
'div': lambda x, y : x/y,
|
|
|
|
| 69 |
|
| 70 |
class CallableEquation(object):
    """Callable wrapper around a sympy expression compiled with `lambdify`.

    Calling an instance with `X` unpacks the rows of `X.T` as the positional
    arguments of the compiled function. When a `selection` was supplied,
    `X[:, selection]` is taken first.
    """

    def __init__(self, sympy_symbols, eqn, selection=None):
        self._sympy_symbols = sympy_symbols
        self._sympy = eqn
        # Column indices applied to X before evaluation; None means use X as-is.
        self._selection = selection
        self._lambda = lambdify(sympy_symbols, eqn)

    def __repr__(self):
        return f"PySRFunction(X=>{self._sympy})"

    def __call__(self, X):
        """Evaluate the compiled expression on the columns of `X`."""
        # NOTE(review): X is indexed as X[:, selection], so it is presumably a
        # 2-D array of shape (rows, features) -- confirm against callers.
        if self._selection is None:
            columns = X.T
        else:
            columns = X[:, self._selection].T
        return self._lambda(*columns)
|
| 86 |
|
| 87 |
def pysr(X, y, weights=None,
|
| 88 |
binary_operators=None,
|
|
|
|
| 295 |
if maxsize > 40:
|
| 296 |
warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
|
| 297 |
|
| 298 |
+
X, variable_names, selection = _handle_feature_selection(
|
| 299 |
X, select_k_features,
|
| 300 |
use_custom_variable_names, variable_names, y
|
| 301 |
)
|
|
|
|
| 354 |
julia_project=julia_project, loss=loss,
|
| 355 |
output_jax_format=output_jax_format,
|
| 356 |
output_torch_format=output_torch_format,
|
| 357 |
+
selection=selection,
|
| 358 |
multioutput=multioutput, nout=nout)
|
| 359 |
|
| 360 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
|
| 403 |
return equations
|
| 404 |
|
| 405 |
|
| 406 |
+
def _set_globals(X, **kwargs):
    """Store the feature count of `X` and any recognised settings in `global_state`.

    Keyword arguments whose names are not already keys of `global_state`
    are ignored.
    """
    # The dict is mutated in place and never rebound, so no `global`
    # statement is required.
    global_state['n_features'] = X.shape[1]
    recognised = {
        name: setting
        for name, setting in kwargs.items()
        if name in global_state
    }
    global_state.update(recognised)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
|
| 415 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
|
|
| 672 |
|
| 673 |
if use_custom_variable_names:
|
| 674 |
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
| 675 |
+
else:
|
| 676 |
+
selection = None
|
| 677 |
+
return X, variable_names, selection
|
| 678 |
|
| 679 |
|
| 680 |
def _set_paths(tempdir):
|
|
|
|
| 738 |
return selector.get_support(indices=True)
|
| 739 |
|
| 740 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 741 |
+
output_jax_format=None, output_torch_format=None,
|
| 742 |
+
selection=None, extra_sympy_mappings=None,
|
| 743 |
extra_jax_mappings=None, extra_torch_mappings=None,
|
| 744 |
multioutput=None, nout=None, **kwargs):
|
| 745 |
"""Get the equations from a hall of fame file. If no arguments
|
| 746 |
entered, the ones used previously from a call to PySR will be used."""
|
| 747 |
|
| 748 |
+
global global_state
|
| 749 |
+
|
| 750 |
+
if equation_file is None: equation_file = global_state['equation_file']
|
| 751 |
+
if n_features is None: n_features = global_state['n_features']
|
| 752 |
+
if variable_names is None: variable_names = global_state['variable_names']
|
| 753 |
+
if extra_sympy_mappings is None: extra_sympy_mappings = global_state['extra_sympy_mappings']
|
| 754 |
+
if extra_jax_mappings is None: extra_jax_mappings = global_state['extra_jax_mappings']
|
| 755 |
+
if extra_torch_mappings is None: extra_torch_mappings = global_state['extra_torch_mappings']
|
| 756 |
+
if output_torch_format is None: output_torch_format = global_state['output_torch_format']
|
| 757 |
+
if output_jax_format is None: output_jax_format = global_state['output_jax_format']
|
| 758 |
+
if multioutput is None: multioutput = global_state['multioutput']
|
| 759 |
+
if nout is None: nout = global_state['nout']
|
| 760 |
+
|
| 761 |
+
global_state['selection'] = selection
|
| 762 |
+
global_state['equation_file'] = equation_file
|
| 763 |
+
global_state['n_features'] = n_features
|
| 764 |
+
global_state['variable_names'] = variable_names
|
| 765 |
+
global_state['extra_sympy_mappings'] = extra_sympy_mappings
|
| 766 |
+
global_state['extra_jax_mappings'] = extra_jax_mappings
|
| 767 |
+
global_state['extra_torch_mappings'] = extra_torch_mappings
|
| 768 |
+
global_state['output_torch_format'] = output_torch_format
|
| 769 |
+
global_state['output_jax_format'] = output_jax_format
|
| 770 |
+
global_state['multioutput'] = multioutput
|
| 771 |
+
global_state['nout'] = nout
|
| 772 |
+
global_state['selection'] = selection
|
| 773 |
|
| 774 |
try:
|
| 775 |
if multioutput:
|
|
|
|
| 808 |
sympy_format.append(eqn)
|
| 809 |
|
| 810 |
# Numpy:
|
| 811 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn, selection))
|
| 812 |
|
| 813 |
# JAX:
|
| 814 |
if output_jax_format:
|
| 815 |
from .export_jax import sympy2jax
|
| 816 |
+
func, params = sympy2jax(eqn, sympy_symbols, selection)
|
| 817 |
jax_format.append({'callable': func, 'parameters': params})
|
| 818 |
|
| 819 |
# Torch:
|
| 820 |
if output_torch_format:
|
| 821 |
from .export_torch import sympy2torch
|
| 822 |
+
module = sympy2torch(eqn, sympy_symbols, selection)
|
| 823 |
torch_format.append(module)
|
| 824 |
|
| 825 |
curMSE = output.loc[i, 'MSE']
|
test/test.py
CHANGED
|
@@ -112,11 +112,12 @@ class TestFeatureSelection(unittest.TestCase):
|
|
| 112 |
X = np.random.randn(20000, 5)
|
| 113 |
y = X[:, 2]**2 + X[:, 3]**2
|
| 114 |
var_names = [f'x{i}' for i in range(5)]
|
| 115 |
-
selected_X, selected_var_names = _handle_feature_selection(
|
| 116 |
X, select_k_features=2,
|
| 117 |
use_custom_variable_names=True,
|
| 118 |
variable_names=[f'x{i}' for i in range(5)],
|
| 119 |
y=y)
|
|
|
|
| 120 |
self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
|
| 121 |
np.testing.assert_array_equal(
|
| 122 |
np.sort(selected_X, axis=1),
|
|
|
|
| 112 |
X = np.random.randn(20000, 5)
|
| 113 |
y = X[:, 2]**2 + X[:, 3]**2
|
| 114 |
var_names = [f'x{i}' for i in range(5)]
|
| 115 |
+
selected_X, selected_var_names, selection = _handle_feature_selection(
|
| 116 |
X, select_k_features=2,
|
| 117 |
use_custom_variable_names=True,
|
| 118 |
variable_names=[f'x{i}' for i in range(5)],
|
| 119 |
y=y)
|
| 120 |
+
self.assertTrue((2 in selection) and (3 in selection))
|
| 121 |
self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
|
| 122 |
np.testing.assert_array_equal(
|
| 123 |
np.sort(selected_X, axis=1),
|