Spaces:
Sleeping
Sleeping
Commit
·
898f500
1
Parent(s):
3772652
Add mechanism for extracting JAX functions
Browse files- pysr/sr.py +18 -4
- setup.py +1 -1
pysr/sr.py
CHANGED
|
@@ -12,7 +12,7 @@ import shutil
|
|
| 12 |
from pathlib import Path
|
| 13 |
from datetime import datetime
|
| 14 |
import warnings
|
| 15 |
-
|
| 16 |
|
| 17 |
global_equation_file = 'hall_of_fame.csv'
|
| 18 |
global_n_features = None
|
|
@@ -106,6 +106,7 @@ def pysr(X=None, y=None, weights=None,
|
|
| 106 |
user_input=True,
|
| 107 |
update=True,
|
| 108 |
temp_equation_file=False,
|
|
|
|
| 109 |
warmupMaxsize=None, #Deprecated
|
| 110 |
):
|
| 111 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
|
@@ -216,6 +217,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 216 |
:param temp_equation_file: Whether to put the hall of fame file in
|
| 217 |
the temp directory. Deletion is then controlled with the
|
| 218 |
delete_tempfiles argument.
|
|
|
|
|
|
|
| 219 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 220 |
(as strings).
|
| 221 |
|
|
@@ -281,7 +284,8 @@ def pysr(X=None, y=None, weights=None,
|
|
| 281 |
weightSimplify=weightSimplify,
|
| 282 |
constraints=constraints,
|
| 283 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 284 |
-
julia_project=julia_project, loss=loss
|
|
|
|
| 285 |
|
| 286 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
| 287 |
|
|
@@ -633,7 +637,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
| 633 |
max_features=select_k_features, prefit=True)
|
| 634 |
return selector.get_support(indices=True)
|
| 635 |
|
| 636 |
-
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
|
|
| 637 |
"""Get the equations from a hall of fame file. If no arguments
|
| 638 |
entered, the ones used previously from a call to PySR will be used."""
|
| 639 |
|
|
@@ -663,6 +668,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
| 663 |
lastComplexity = 0
|
| 664 |
sympy_format = []
|
| 665 |
lambda_format = []
|
|
|
|
|
|
|
| 666 |
use_custom_variable_names = (len(variable_names) != 0)
|
| 667 |
local_sympy_mappings = {
|
| 668 |
**extra_sympy_mappings,
|
|
@@ -677,6 +684,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
| 677 |
for i in range(len(output)):
|
| 678 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 679 |
sympy_format.append(eqn)
|
|
|
|
|
|
|
|
|
|
| 680 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 681 |
curMSE = output.loc[i, 'MSE']
|
| 682 |
curComplexity = output.loc[i, 'Complexity']
|
|
@@ -693,8 +703,12 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
| 693 |
output['score'] = np.array(scores)
|
| 694 |
output['sympy_format'] = sympy_format
|
| 695 |
output['lambda_format'] = lambda_format
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
|
| 697 |
-
return output[
|
| 698 |
|
| 699 |
def best_row(equations=None):
|
| 700 |
"""Return the best row of a hall of fame file using the score column.
|
|
|
|
| 12 |
from pathlib import Path
|
| 13 |
from datetime import datetime
|
| 14 |
import warnings
|
| 15 |
+
from .export import sympy2jax
|
| 16 |
|
| 17 |
global_equation_file = 'hall_of_fame.csv'
|
| 18 |
global_n_features = None
|
|
|
|
| 106 |
user_input=True,
|
| 107 |
update=True,
|
| 108 |
temp_equation_file=False,
|
| 109 |
+
output_jax_format=False,
|
| 110 |
warmupMaxsize=None, #Deprecated
|
| 111 |
):
|
| 112 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
|
|
|
| 217 |
:param temp_equation_file: Whether to put the hall of fame file in
|
| 218 |
the temp directory. Deletion is then controlled with the
|
| 219 |
delete_tempfiles argument.
|
| 220 |
+
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
| 221 |
+
containing jax-callable functions and the default parameters in a jax array.
|
| 222 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
| 223 |
(as strings).
|
| 224 |
|
|
|
|
| 284 |
weightSimplify=weightSimplify,
|
| 285 |
constraints=constraints,
|
| 286 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 287 |
+
julia_project=julia_project, loss=loss,
|
| 288 |
+
output_jax_format=output_jax_format)
|
| 289 |
|
| 290 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
| 291 |
|
|
|
|
| 637 |
max_features=select_k_features, prefit=True)
|
| 638 |
return selector.get_support(indices=True)
|
| 639 |
|
| 640 |
+
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 641 |
+
extra_sympy_mappings=None, output_jax_format=False, **kwargs):
|
| 642 |
"""Get the equations from a hall of fame file. If no arguments
|
| 643 |
entered, the ones used previously from a call to PySR will be used."""
|
| 644 |
|
|
|
|
| 668 |
lastComplexity = 0
|
| 669 |
sympy_format = []
|
| 670 |
lambda_format = []
|
| 671 |
+
if output_jax_format:
|
| 672 |
+
jax_format = []
|
| 673 |
use_custom_variable_names = (len(variable_names) != 0)
|
| 674 |
local_sympy_mappings = {
|
| 675 |
**extra_sympy_mappings,
|
|
|
|
| 684 |
for i in range(len(output)):
|
| 685 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 686 |
sympy_format.append(eqn)
|
| 687 |
+
if output_jax_format:
|
| 688 |
+
func, params = sympy2jax(eqn, sympy_symbols)
|
| 689 |
+
jax_format.append({'callable': func, 'parameters': parameters})
|
| 690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 691 |
curMSE = output.loc[i, 'MSE']
|
| 692 |
curComplexity = output.loc[i, 'Complexity']
|
|
|
|
| 703 |
output['score'] = np.array(scores)
|
| 704 |
output['sympy_format'] = sympy_format
|
| 705 |
output['lambda_format'] = lambda_format
|
| 706 |
+
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
| 707 |
+
if output_jax_format:
|
| 708 |
+
output_cols += 'jax_format'
|
| 709 |
+
output['jax_format'] = jax_format
|
| 710 |
|
| 711 |
+
return output[output_cols]
|
| 712 |
|
| 713 |
def best_row(equations=None):
|
| 714 |
"""Return the best row of a hall of fame file using the score column.
|
setup.py
CHANGED
|
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
|
|
| 5 |
|
| 6 |
setuptools.setup(
|
| 7 |
name="pysr", # Replace with your own username
|
| 8 |
-
version="0.5.
|
| 9 |
author="Miles Cranmer",
|
| 10 |
author_email="miles.cranmer@gmail.com",
|
| 11 |
description="Simple and efficient symbolic regression",
|
|
|
|
| 5 |
|
| 6 |
setuptools.setup(
|
| 7 |
name="pysr", # Replace with your own username
|
| 8 |
+
version="0.5.13",
|
| 9 |
author="Miles Cranmer",
|
| 10 |
author_email="miles.cranmer@gmail.com",
|
| 11 |
description="Simple and efficient symbolic regression",
|