Spaces:
Running
Running
Commit
·
17c9b1a
1
Parent(s):
beaf20b
Fix sympy2jax for rational numbers
Browse files- pysr/export_jax.py +4 -2
pysr/export_jax.py
CHANGED
|
@@ -58,9 +58,11 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
|
| 58 |
if issubclass(expr.func, sympy.Float):
|
| 59 |
parameters.append(float(expr))
|
| 60 |
return f"parameters[{len(parameters) - 1}]"
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
return f"{int(expr)}"
|
| 63 |
-
|
| 64 |
return (
|
| 65 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 66 |
)
|
|
|
|
| 58 |
if issubclass(expr.func, sympy.Float):
|
| 59 |
parameters.append(float(expr))
|
| 60 |
return f"parameters[{len(parameters) - 1}]"
|
| 61 |
+
elif issubclass(expr.func, sympy.Rational):
|
| 62 |
+
return f"{float(expr)}"
|
| 63 |
+
elif issubclass(expr.func, sympy.Integer):
|
| 64 |
return f"{int(expr)}"
|
| 65 |
+
elif issubclass(expr.func, sympy.Symbol):
|
| 66 |
return (
|
| 67 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
| 68 |
)
|