Spaces:
Running
Running
Commit
·
84e4a47
1
Parent(s):
e7ede78
Allow user to pass extra torch operators to pysr
Browse files- pysr/sr.py +12 -5
pysr/sr.py
CHANGED
|
@@ -102,6 +102,8 @@ def pysr(X, y, weights=None,
|
|
| 102 |
perturbationFactor=1.0,
|
| 103 |
timeout=None,
|
| 104 |
extra_sympy_mappings=None,
|
|
|
|
|
|
|
| 105 |
equation_file=None,
|
| 106 |
verbosity=1e9,
|
| 107 |
progress=True,
|
|
@@ -336,6 +338,8 @@ def pysr(X, y, weights=None,
|
|
| 336 |
weightSimplify=weightSimplify,
|
| 337 |
constraints=constraints,
|
| 338 |
extra_sympy_mappings=extra_sympy_mappings,
|
|
|
|
|
|
|
| 339 |
julia_project=julia_project, loss=loss,
|
| 340 |
output_jax_format=output_jax_format,
|
| 341 |
output_torch_format=output_torch_format,
|
|
@@ -730,6 +734,7 @@ def run_feature_selection(X, y, select_k_features):
|
|
| 730 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 731 |
extra_sympy_mappings=None, output_jax_format=False,
|
| 732 |
output_torch_format=False,
|
|
|
|
| 733 |
multioutput=None, nout=None, **kwargs):
|
| 734 |
"""Get the equations from a hall of fame file. If no arguments
|
| 735 |
entered, the ones used previously from a call to PySR will be used."""
|
|
@@ -790,20 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 790 |
for i in range(len(output)):
|
| 791 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 792 |
sympy_format.append(eqn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
if output_jax_format:
|
| 794 |
from .export_jax import sympy2jax
|
| 795 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 796 |
jax_format.append({'callable': func, 'parameters': params})
|
| 797 |
-
<<<<<<< HEAD
|
| 798 |
|
| 799 |
-
|
| 800 |
-
=======
|
| 801 |
if output_torch_format:
|
| 802 |
from .export_torch import sympy2torch
|
| 803 |
module = sympy2torch(eqn, sympy_symbols)
|
| 804 |
torch_format.append(module)
|
| 805 |
-
|
| 806 |
-
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
| 807 |
curMSE = output.loc[i, 'MSE']
|
| 808 |
curComplexity = output.loc[i, 'Complexity']
|
| 809 |
|
|
|
|
| 102 |
perturbationFactor=1.0,
|
| 103 |
timeout=None,
|
| 104 |
extra_sympy_mappings=None,
|
| 105 |
+
extra_torch_mappings=None,
|
| 106 |
+
extra_jax_mappings=None,
|
| 107 |
equation_file=None,
|
| 108 |
verbosity=1e9,
|
| 109 |
progress=True,
|
|
|
|
| 338 |
weightSimplify=weightSimplify,
|
| 339 |
constraints=constraints,
|
| 340 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 341 |
+
extra_jax_mappings=extra_jax_mappings,
|
| 342 |
+
extra_torch_mappings=extra_torch_mappings,
|
| 343 |
julia_project=julia_project, loss=loss,
|
| 344 |
output_jax_format=output_jax_format,
|
| 345 |
output_torch_format=output_torch_format,
|
|
|
|
| 734 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 735 |
extra_sympy_mappings=None, output_jax_format=False,
|
| 736 |
output_torch_format=False,
|
| 737 |
+
extra_jax_mappings=None, extra_torch_mappings=None,
|
| 738 |
multioutput=None, nout=None, **kwargs):
|
| 739 |
"""Get the equations from a hall of fame file. If no arguments
|
| 740 |
entered, the ones used previously from a call to PySR will be used."""
|
|
|
|
| 795 |
for i in range(len(output)):
|
| 796 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 797 |
sympy_format.append(eqn)
|
| 798 |
+
|
| 799 |
+
# Numpy:
|
| 800 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
| 801 |
+
|
| 802 |
+
# JAX:
|
| 803 |
if output_jax_format:
|
| 804 |
from .export_jax import sympy2jax
|
| 805 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 806 |
jax_format.append({'callable': func, 'parameters': params})
|
|
|
|
| 807 |
|
| 808 |
+
# Torch:
|
|
|
|
| 809 |
if output_torch_format:
|
| 810 |
from .export_torch import sympy2torch
|
| 811 |
module = sympy2torch(eqn, sympy_symbols)
|
| 812 |
torch_format.append(module)
|
| 813 |
+
|
|
|
|
| 814 |
curMSE = output.loc[i, 'MSE']
|
| 815 |
curComplexity = output.loc[i, 'Complexity']
|
| 816 |
|