| | """ |
| | Mathematica code printer |
| | """ |
| |
|
| | from __future__ import annotations |
| | from typing import Any |
| |
|
| | from sympy.core import Basic, Expr, Float |
| | from sympy.core.sorting import default_sort_key |
| |
|
| | from sympy.printing.codeprinter import CodePrinter |
| | from sympy.printing.precedence import precedence |
| |
|
| | |
# Maps a SymPy function class name (``expr.func.__name__``) to a list of
# ``(condition, mathematica_name)`` pairs.  ``MCodePrinter._print_Function``
# calls ``condition(*expr.args)`` for each pair in order and prints
# ``mathematica_name[args]`` for the first one that returns true.
# Conditions written as ``lambda x`` accept exactly one argument; calling
# them with a different number of arguments raises TypeError, so functions
# that may carry several arguments use ``lambda *x``.
known_functions = {
    "exp": [(lambda x: True, "Exp")],
    "log": [(lambda x: True, "Log")],
    "sin": [(lambda x: True, "Sin")],
    "cos": [(lambda x: True, "Cos")],
    "tan": [(lambda x: True, "Tan")],
    "cot": [(lambda x: True, "Cot")],
    "sec": [(lambda x: True, "Sec")],
    "csc": [(lambda x: True, "Csc")],
    "asin": [(lambda x: True, "ArcSin")],
    "acos": [(lambda x: True, "ArcCos")],
    "atan": [(lambda x: True, "ArcTan")],
    "acot": [(lambda x: True, "ArcCot")],
    "asec": [(lambda x: True, "ArcSec")],
    "acsc": [(lambda x: True, "ArcCsc")],
    "sinh": [(lambda x: True, "Sinh")],
    "cosh": [(lambda x: True, "Cosh")],
    "tanh": [(lambda x: True, "Tanh")],
    "coth": [(lambda x: True, "Coth")],
    "sech": [(lambda x: True, "Sech")],
    "csch": [(lambda x: True, "Csch")],
    "asinh": [(lambda x: True, "ArcSinh")],
    "acosh": [(lambda x: True, "ArcCosh")],
    "atanh": [(lambda x: True, "ArcTanh")],
    "acoth": [(lambda x: True, "ArcCoth")],
    "asech": [(lambda x: True, "ArcSech")],
    "acsch": [(lambda x: True, "ArcCsch")],
    "sinc": [(lambda x: True, "Sinc")],
    "conjugate": [(lambda x: True, "Conjugate")],
    "Max": [(lambda *x: True, "Max")],
    "Min": [(lambda *x: True, "Min")],
    "erf": [(lambda x: True, "Erf")],
    "erf2": [(lambda *x: True, "Erf")],
    "erfc": [(lambda x: True, "Erfc")],
    "erfi": [(lambda x: True, "Erfi")],
    "erfinv": [(lambda x: True, "InverseErf")],
    "erfcinv": [(lambda x: True, "InverseErfc")],
    "erf2inv": [(lambda *x: True, "InverseErf")],
    "expint": [(lambda *x: True, "ExpIntegralE")],
    "Ei": [(lambda x: True, "ExpIntegralEi")],
    "fresnelc": [(lambda x: True, "FresnelC")],
    "fresnels": [(lambda x: True, "FresnelS")],
    "gamma": [(lambda x: True, "Gamma")],
    "uppergamma": [(lambda *x: True, "Gamma")],
    "polygamma": [(lambda *x: True, "PolyGamma")],
    "loggamma": [(lambda x: True, "LogGamma")],
    "beta": [(lambda *x: True, "Beta")],
    "Ci": [(lambda x: True, "CosIntegral")],
    "Si": [(lambda x: True, "SinIntegral")],
    "Chi": [(lambda x: True, "CoshIntegral")],
    "Shi": [(lambda x: True, "SinhIntegral")],
    "li": [(lambda x: True, "LogIntegral")],
    "factorial": [(lambda x: True, "Factorial")],
    "factorial2": [(lambda x: True, "Factorial2")],
    "subfactorial": [(lambda x: True, "Subfactorial")],
    "catalan": [(lambda x: True, "CatalanNumber")],
    "harmonic": [(lambda *x: True, "HarmonicNumber")],
    "lucas": [(lambda x: True, "LucasL")],
    "RisingFactorial": [(lambda *x: True, "Pochhammer")],
    "FallingFactorial": [(lambda *x: True, "FactorialPower")],
    "laguerre": [(lambda *x: True, "LaguerreL")],
    "assoc_laguerre": [(lambda *x: True, "LaguerreL")],
    "hermite": [(lambda *x: True, "HermiteH")],
    "jacobi": [(lambda *x: True, "JacobiP")],
    "gegenbauer": [(lambda *x: True, "GegenbauerC")],
    "chebyshevt": [(lambda *x: True, "ChebyshevT")],
    "chebyshevu": [(lambda *x: True, "ChebyshevU")],
    "legendre": [(lambda *x: True, "LegendreP")],
    "assoc_legendre": [(lambda *x: True, "LegendreP")],
    "mathieuc": [(lambda *x: True, "MathieuC")],
    "mathieus": [(lambda *x: True, "MathieuS")],
    "mathieucprime": [(lambda *x: True, "MathieuCPrime")],
    "mathieusprime": [(lambda *x: True, "MathieuSPrime")],
    # stieltjes(n) and the generalized stieltjes(n, a) both map onto
    # StieltjesGamma[n] / StieltjesGamma[n, a].
    "stieltjes": [(lambda *x: True, "StieltjesGamma")],
    "elliptic_e": [(lambda *x: True, "EllipticE")],
    # Incomplete elliptic integral of the first kind is EllipticF, not
    # EllipticE (which is the second kind).
    "elliptic_f": [(lambda *x: True, "EllipticF")],
    "elliptic_k": [(lambda x: True, "EllipticK")],
    "elliptic_pi": [(lambda *x: True, "EllipticPi")],
    "zeta": [(lambda *x: True, "Zeta")],
    "dirichlet_eta": [(lambda x: True, "DirichletEta")],
    "riemann_xi": [(lambda x: True, "RiemannXi")],
    "besseli": [(lambda *x: True, "BesselI")],
    "besselj": [(lambda *x: True, "BesselJ")],
    "besselk": [(lambda *x: True, "BesselK")],
    "bessely": [(lambda *x: True, "BesselY")],
    "hankel1": [(lambda *x: True, "HankelH1")],
    "hankel2": [(lambda *x: True, "HankelH2")],
    "airyai": [(lambda x: True, "AiryAi")],
    "airybi": [(lambda x: True, "AiryBi")],
    "airyaiprime": [(lambda x: True, "AiryAiPrime")],
    "airybiprime": [(lambda x: True, "AiryBiPrime")],
    "polylog": [(lambda *x: True, "PolyLog")],
    "lerchphi": [(lambda *x: True, "LerchPhi")],
    "gcd": [(lambda *x: True, "GCD")],
    "lcm": [(lambda *x: True, "LCM")],
    "jn": [(lambda *x: True, "SphericalBesselJ")],
    "yn": [(lambda *x: True, "SphericalBesselY")],
    "hyper": [(lambda *x: True, "HypergeometricPFQ")],
    "meijerg": [(lambda *x: True, "MeijerG")],
    "appellf1": [(lambda *x: True, "AppellF1")],
    "DiracDelta": [(lambda x: True, "DiracDelta")],
    "Heaviside": [(lambda x: True, "HeavisideTheta")],
    "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
    "sqrt": [(lambda x: True, "Sqrt")],
}
| |
|
| |
|
class MCodePrinter(CodePrinter):
    """A printer to convert Python expressions to
    strings of the Wolfram's Mathematica code
    """
    printmethod = "_mcode"
    language = "Wolfram Language"

    _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
        'precision': 15,
        'user_functions': {},
    })

    _number_symbols: set[tuple[Expr, Float]] = set()
    _not_supported: set[Basic] = set()

    def __init__(self, settings=None):
        """Register function mappings supplied by user.

        ``settings`` is forwarded to ``CodePrinter``.  Entries of
        ``settings['user_functions']`` are merged over the module-level
        ``known_functions`` table: a list value is used as-is (a list of
        ``(condition, name)`` pairs); any other value ``v`` is wrapped as
        ``[(lambda *x: True, v)]`` so it applies regardless of arity.
        """
        # ``None`` replaces the former mutable ``{}`` default; an omitted
        # argument still reaches CodePrinter as an empty dict.
        if settings is None:
            settings = {}
        CodePrinter.__init__(self, settings)
        # Work on a copy so per-printer user functions do not modify the
        # module-level table.
        self.known_functions = dict(known_functions)
        self.known_functions.update(
            (name, spec if isinstance(spec, list) else [(lambda *x: True, spec)])
            for name, spec in settings.get('user_functions', {}).items()
        )

    def _format_code(self, lines):
        # Lines are returned unchanged; no re-indentation is applied.
        return lines

    def _print_Pow(self, expr):
        """Print ``base^exp``, parenthesizing operands by Pow precedence."""
        PREC = precedence(expr)
        return '%s^%s' % (self.parenthesize(expr.base, PREC),
                          self.parenthesize(expr.exp, PREC))

    def _print_Mul(self, expr):
        """Print a product.

        Commutative factors are printed by ``CodePrinter._print_Mul``;
        non-commutative factors are appended in order, joined with ``**``
        (Mathematica's NonCommutativeMultiply).
        """
        PREC = precedence(expr)
        c, nc = expr.args_cnc()
        nc_str = '**'.join(self.parenthesize(a, PREC) for a in nc)
        if nc and not c:
            # Purely non-commutative product: printing the empty commutative
            # part would otherwise emit a spurious "1*" prefix.
            return nc_str
        res = super()._print_Mul(expr.func(*c))
        if nc:
            res += '*' + nc_str
        return res

    def _print_Relational(self, expr):
        """Print ``lhs op rhs`` using SymPy's ``rel_op`` symbol."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        return "{} {} {}".format(lhs_code, op, rhs_code)

    # Primitive numbers
    def _print_Zero(self, expr):
        return '0'

    def _print_One(self, expr):
        return '1'

    def _print_NegativeOne(self, expr):
        return '-1'

    def _print_Half(self, expr):
        return '1/2'

    def _print_ImaginaryUnit(self, expr):
        return 'I'

    # Infinity and invalid numbers
    def _print_Infinity(self, expr):
        return 'Infinity'

    def _print_NegativeInfinity(self, expr):
        return '-Infinity'

    def _print_ComplexInfinity(self, expr):
        return 'ComplexInfinity'

    def _print_NaN(self, expr):
        return 'Indeterminate'

    # Mathematical constants
    def _print_Exp1(self, expr):
        return 'E'

    def _print_Pi(self, expr):
        return 'Pi'

    def _print_GoldenRatio(self, expr):
        return 'GoldenRatio'

    def _print_TribonacciConstant(self, expr):
        # Printed via its closed form from ``expand(func=True)`` rather than
        # by name, parenthesized at the constant's own precedence.
        expanded = expr.expand(func=True)
        PREC = precedence(expr)
        return self.parenthesize(expanded, PREC)

    def _print_EulerGamma(self, expr):
        return 'EulerGamma'

    def _print_Catalan(self, expr):
        return 'Catalan'

    def _print_list(self, expr):
        """Print a list/tuple/Tuple as a Mathematica list ``{a, b, ...}``."""
        return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
    _print_tuple = _print_list
    _print_Tuple = _print_list

    def _print_ImmutableDenseMatrix(self, expr):
        """Print a dense matrix as a nested list of rows."""
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseMatrix(self, expr):
        """Print ``SparseArray[{{i, j} -> v, ...}, {rows, cols}]``.

        Indices are shifted from Python's 0-based to Mathematica's
        1-based convention, and rules are emitted in sorted order.
        """

        def print_rule(pos, val):
            return '{} -> {}'.format(
                self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))

        def print_data():
            items = sorted(expr.todok().items(), key=default_sort_key)
            return '{' + \
                ', '.join(print_rule(k, v) for k, v in items) + \
                '}'

        def print_dims():
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_ImmutableDenseNDimArray(self, expr):
        """Print a dense N-dim array as nested lists."""
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseNDimArray(self, expr):
        """Print ``SparseArray[{{i1, i2, ...} -> v, ...}, {d1, d2, ...}]``."""
        def print_string_list(string_list):
            return '{' + ', '.join(a for a in string_list) + '}'

        def to_mathematica_index(*args):
            """Helper function to change Python style indexing to
            Mathematica indexing.

            Python indexing (0, 1 ... n-1)
            -> Mathematica indexing (1, 2 ... n)
            """
            return tuple(i + 1 for i in args)

        def print_rule(pos, val):
            """Helper function to print a rule of Mathematica"""
            return '{} -> {}'.format(self.doprint(pos), self.doprint(val))

        def print_data():
            """Helper function to print data part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html

            ``data`` must be formatted with rule.
            """
            return print_string_list(
                [print_rule(
                    to_mathematica_index(*(expr._get_tuple_index(key))),
                    value)
                 for key, value in sorted(expr._sparse_array.items())]
            )

        def print_dims():
            """Helper function to print dimensions part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html
            """
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_Function(self, expr):
        """Print an applied function as ``Name[arg1, arg2, ...]``.

        A name found in ``self.known_functions`` is translated using the
        first ``(condition, name)`` pair whose condition accepts the
        arguments.  Otherwise, a function listed in
        ``_rewriteable_functions`` is rewritten when its target can be
        printed.  Anything else keeps its SymPy name.
        """
        if expr.func.__name__ in self.known_functions:
            cond_mfunc = self.known_functions[expr.func.__name__]
            for cond, mfunc in cond_mfunc:
                if cond(*expr.args):
                    return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
        elif expr.func.__name__ in self._rewriteable_functions:
            # Simple rewrite to supported function possible
            target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
            if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
                return self._print(expr.rewrite(target_f))
        return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")

    _print_MinMaxBase = _print_Function

    def _print_LambertW(self, expr):
        """``LambertW(x)`` -> ``ProductLog[x]``;
        ``LambertW(x, k)`` -> ``ProductLog[k, x]`` (branch index first)."""
        if len(expr.args) == 1:
            return "ProductLog[{}]".format(self._print(expr.args[0]))
        return "ProductLog[{}, {}]".format(
            self._print(expr.args[1]), self._print(expr.args[0]))

    def _print_atan2(self, expr):
        """``atan2(y, x)`` is printed as ``ArcTan[x, y]`` (arguments swapped)."""
        return "ArcTan[{}, {}]".format(
            self._print(expr.args[1]), self._print(expr.args[0]))

    def _print_Integral(self, expr):
        """Print ``Hold[Integrate[f, spec1, spec2, ...]]``.

        An indefinite limit (a bare variable) prints as the variable itself,
        e.g. ``Integrate[f, x, y]``.  Any other limit prints as a list such
        as ``{x, a, b}``.
        """
        args = [expr.function]
        args.extend(limit[0] if len(limit) == 1 else limit
                    for limit in expr.limits)
        return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"

    def _print_Sum(self, expr):
        """Print ``Hold[Sum[f, {i, a, b}, ...]]`` (held, i.e. unevaluated)."""
        return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"

    def _print_Derivative(self, expr):
        """Print ``Hold[D[f, x, {y, n}, ...]]``.

        A variable differentiated once prints bare; a higher count prints
        as its ``(variable, count)`` pair, i.e. ``{y, n}``.
        """
        dexpr = expr.expr
        dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
        return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"

    def _get_comment(self, text):
        """Wrap ``text`` in a Mathematica comment ``(* ... *)``."""
        return "(* {} *)".format(text)
| |
|
| |
|
def mathematica_code(expr, **settings):
    r"""Converts an expr to a string of the Wolfram Mathematica code

    Parameters
    ==========

    expr
        The SymPy object to print.
    settings
        Keyword settings collected into a dict and handed to
        ``MCodePrinter`` (for example ``precision`` or
        ``user_functions``).

    Examples
    ========

    >>> from sympy import mathematica_code as mcode, symbols, sin
    >>> x = symbols('x')
    >>> mcode(sin(x).series(x).removeO())
    '(1/120)*x^5 - 1/6*x^3 + x'
    """
    printer = MCodePrinter(settings)
    return printer.doprint(expr)
| |
|