| | ''' |
| | Use llvmlite to create executable functions from SymPy expressions |
| | |
| | This module requires llvmlite (https://github.com/numba/llvmlite). |
| | ''' |
| |
|
| | import ctypes |
| |
|
| | from sympy.external import import_module |
| | from sympy.printing.printer import Printer |
| | from sympy.core.singleton import S |
| | from sympy.tensor.indexed import IndexedBase |
| | from sympy.utilities.decorator import doctest_depends_on |
| |
|
# Optional dependency: everything below degrades to ImportError at call time
# when llvmlite is not installed (see the `if not llvmlite` checks).
llvmlite = import_module('llvmlite')
if llvmlite:
    ll = import_module('llvmlite.ir').ir
    llvm = import_module('llvmlite.binding').binding
    # Set up LLVM and the native target once, at import time, so that
    # modules can later be JIT-compiled for the host machine.
    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()


# Keys are tuples of object names; a trailing comma is required so the key is
# a 1-tuple rather than the bare string 'llvm_callable'.
__doctest_requires__ = {('llvm_callable',): ['llvmlite']}
| |
|
| |
|
class LLVMJitPrinter(Printer):
    '''Convert expressions to LLVM IR

    Each ``_print_*`` method emits instructions through ``self.builder`` and
    returns the resulting LLVM value. All arithmetic is done in double
    precision (``self.fp_type``).
    '''
    def __init__(self, module, builder, fn, *args, **kwargs):
        # Maps SymPy symbols to the LLVM values (function arguments) they
        # stand for.
        self.func_arg_map = kwargs.pop("func_arg_map", {})
        if not llvmlite:
            raise ImportError("llvmlite is required for LLVMJITPrinter")
        super().__init__(*args, **kwargs)
        self.fp_type = ll.DoubleType()
        self.module = module
        self.builder = builder
        self.fn = fn
        # Name -> external function already declared in self.module.
        self.ext_fn = {}
        # Symbol -> LLVM value of a temporary (e.g. a cse replacement).
        self.tmp_var = {}

    def _add_tmp_var(self, name, value):
        """Register *value* as the LLVM value of the temporary symbol *name*."""
        self.tmp_var[name] = value

    def _get_ext_fn(self, name, nargs):
        """Return the external ``double name(double, ...)`` function.

        The declaration is added to the module on first use and cached in
        ``self.ext_fn``, so each name is declared at most once per module.
        The symbol itself is resolved by name when the module is compiled.
        """
        fn = self.ext_fn.get(name)
        if fn is None:
            fn_type = ll.FunctionType(self.fp_type, [self.fp_type] * nargs)
            fn = ll.Function(self.module, fn_type, name)
            self.ext_fn[name] = fn
        return fn

    def _print_Number(self, n):
        return ll.Constant(self.fp_type, float(n))

    def _print_Integer(self, expr):
        return ll.Constant(self.fp_type, float(expr.p))

    def _print_Symbol(self, s):
        # Temporaries take precedence over function arguments.
        val = self.tmp_var.get(s)
        if val is None:
            val = self.func_arg_map.get(s)
        if val is None:
            raise LookupError("Symbol not found: %s" % s)
        return val

    def _print_Pow(self, expr):
        base0 = self._print(expr.base)
        # Special-case common exponents to avoid a call to pow().
        if expr.exp == S.NegativeOne:
            return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0)
        if expr.exp == S.Half:
            return self.builder.call(self._get_ext_fn("sqrt", 1), [base0],
                                     "sqrt")
        if expr.exp == 2:
            return self.builder.fmul(base0, base0)

        exp0 = self._print(expr.exp)
        return self.builder.call(self._get_ext_fn("pow", 2), [base0, exp0],
                                 "pow")

    def _print_Mul(self, expr):
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.fmul(e, node)
        return e

    def _print_Add(self, expr):
        nodes = [self._print(a) for a in expr.args]
        e = nodes[0]
        for node in nodes[1:]:
            e = self.builder.fadd(e, node)
        return e

    def _print_Function(self, expr):
        """Emit a call to the external function named after ``expr.func``.

        Every argument is passed, and the external function is declared
        with one double parameter per argument. Previously only the first
        argument was passed, which silently broke functions such as
        ``atan2(y, x)``.
        """
        name = expr.func.__name__
        call_args = [self._print(a) for a in expr.args]
        fn = self._get_ext_fn(name, len(call_args))
        return self.builder.call(fn, call_args, name)

    def emptyPrinter(self, expr):
        raise TypeError("Unsupported type for LLVM JIT conversion: %s"
                        % type(expr))
| |
|
| |
|
| | |
| | |
class LLVMJitCallbackPrinter(LLVMJitPrinter):
    '''Convert expressions to LLVM IR for callback-style functions.

    Here ``func_arg_map`` maps each symbol (or IndexedBase) to a pair
    ``(llvm_array_argument, index)``. Values are loaded from memory through
    the array argument instead of being received as scalar arguments.
    '''

    def _load_element(self, array, index):
        """Emit IR that loads the double at position *index* of *array*."""
        element_ptr = self.builder.gep(array,
                                       [ll.Constant(ll.IntType(32), index)])
        fp_ptr = self.builder.bitcast(element_ptr,
                                      ll.PointerType(self.fp_type))
        return self.builder.load(fp_ptr)

    def _print_Indexed(self, expr):
        # Only the first index is used, and it must evaluate to a number.
        array, _ = self.func_arg_map[expr.base]
        offset = int(expr.indices[0].evalf())
        return self._load_element(array, offset)

    def _print_Symbol(self, s):
        # Temporaries take precedence over values read from the input array.
        val = self.tmp_var.get(s)
        if val is not None:
            return val

        entry = self.func_arg_map.get(s)
        if entry is None:
            raise LookupError("Symbol not found: %s" % s)
        array, idx = entry
        return self._load_element(array, idx)
| |
|
| |
|
| | |
| | |
# Every MCJIT execution engine created by the compiler is appended here and
# never removed. NOTE(review): presumably this keeps the engines, and the
# machine code they own, alive while ctypes wrappers still hold pointers into
# them — confirm.
exe_engines = list()

# Link names handed out so far, plus the counter used to make each new
# 'jit_func<N>' name unique within the process.
link_names, current_link_suffix = set(), 0
| |
|
| |
|
class LLVMJitCode:
    '''Build and JIT-compile one LLVM function from SymPy expressions.

    Driven by ``_llvm_jit_code``. The steps are: ``_create_args`` maps the
    signature's ctypes to LLVM types; ``_create_function_base`` declares the
    function under a unique link name; ``_create_param_dict`` binds symbols
    to arguments; ``_create_function`` emits the body and returns IR text;
    ``_compile_function`` compiles that IR and returns the function address.
    '''
    def __init__(self, signature):
        self.signature = signature
        self.fp_type = ll.DoubleType()
        self.module = ll.Module('mod1')
        self.fn = None
        self.llvm_arg_types = []
        self.llvm_ret_type = self.fp_type
        self.param_dict = {}
        self.link_name = ''

    def _from_ctype(self, ctype):
        """Translate a ctypes type into the corresponding LLVM IR type.

        Raises
        ======
        TypeError
            For ctypes without an LLVM mapping. Previously these were only
            printed and None was returned, which failed later in IR
            construction.
        """
        if ctype == ctypes.c_int:
            return ll.IntType(32)
        if ctype == ctypes.c_double:
            return self.fp_type
        if ctype == ctypes.POINTER(ctypes.c_double):
            return ll.PointerType(self.fp_type)
        if ctype in (ctypes.c_void_p, ctypes.py_object):
            # Opaque pointers (including PyObject*) are modelled as i32*.
            return ll.PointerType(ll.IntType(32))
        raise TypeError("Unhandled ctype = %s" % str(ctype))

    def _create_args(self, func_args):
        """Create types for function arguments"""
        self.llvm_ret_type = self._from_ctype(self.signature.ret_type)
        self.llvm_arg_types = \
            [self._from_ctype(a) for a in self.signature.arg_ctypes]

    def _create_function_base(self):
        """Create function with name and type signature"""
        global current_link_suffix
        default_link_name = 'jit_func'
        current_link_suffix += 1
        self.link_name = default_link_name + str(current_link_suffix)
        link_names.add(self.link_name)

        fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types)
        self.fn = ll.Function(self.module, fn_type, name=self.link_name)

    def _create_param_dict(self, func_args):
        """Mapping of symbolic values to function arguments"""
        for i, a in enumerate(func_args):
            self.fn.args[i].name = str(a)
            self.param_dict[a] = self.fn.args[i]

    def _create_function(self, expr):
        """Create function body and return LLVM IR"""
        bb_entry = self.fn.append_basic_block('entry')
        builder = ll.IRBuilder(bb_entry)

        lj = LLVMJitPrinter(self.module, builder, self.fn,
                            func_arg_map=self.param_dict)

        ret = self._convert_expr(lj, expr)
        lj.builder.ret(self._wrap_return(lj, ret))

        strmod = str(self.module)
        return strmod

    def _wrap_return(self, lj, vals):
        """Emit IR producing the function's return value from *vals*.

        For a ``c_double`` return type the first value is returned as is.
        Otherwise every double is boxed with ``PyFloat_FromDouble``. A single
        result is returned directly; several are packed into a tuple with
        ``PyTuple_Pack``.
        """
        if self.signature.ret_type == ctypes.c_double:
            return vals[0]

        # PyObject* is treated as an opaque pointer.
        void_ptr = ll.PointerType(ll.IntType(32))

        # PyObject* PyFloat_FromDouble(double)
        wrap_type = ll.FunctionType(void_ptr, [self.fp_type])
        wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble")

        wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals]
        if len(vals) == 1:
            return wrapped_vals[0]

        # PyObject* PyTuple_Pack(Py_ssize_t n, ...): declare it as variadic
        # with a Py_ssize_t-sized count, so the call follows the platform's
        # varargs ABI instead of a fixed i32 prototype.
        ssize_type = ll.IntType(8 * ctypes.sizeof(ctypes.c_ssize_t))
        tuple_type = ll.FunctionType(void_ptr, [ssize_type], var_arg=True)
        tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack")

        tuple_args = [ll.Constant(ssize_type, len(wrapped_vals))]
        tuple_args.extend(wrapped_vals)
        return lj.builder.call(tuple_fn, tuple_args)

    def _convert_expr(self, lj, expr):
        """Emit IR for *expr* and return the list of LLVM result values.

        *expr* is either a single SymPy expression or a pair
        ``(replacements, exprs)`` as returned by ``cse``. *exprs* may also
        be a single expression. The replacement definitions are emitted
        first and registered as temporaries so the final expressions can
        refer to them.

        Raises
        ======
        ValueError
            If *expr* is a sequence whose length is not 2.
        NotImplementedError
            If several results are requested from a function that returns
            a single ``c_double``.
        """
        # Only len() is guarded: a TypeError raised while printing (e.g. for
        # an unsupported node) must propagate, not be mistaken for
        # "expr is a plain expression".
        try:
            n_parts = len(expr)
        except TypeError:
            tmp_exprs = []
            final_exprs = [expr]
        else:
            if n_parts != 2:
                raise ValueError(
                    "Expected an expression or a (replacements, exprs) pair "
                    "as returned by cse; got a sequence of length %d"
                    % n_parts)
            tmp_exprs = expr[0]
            final_exprs = expr[1]
            try:
                len(final_exprs)
            except TypeError:
                # Accept (replacements, expr) with a single expression.
                final_exprs = [final_exprs]
            if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double:
                raise NotImplementedError("Return of multiple expressions not supported for this callback")

        for name, e in tmp_exprs:
            lj._add_tmp_var(name, lj._print(e))

        return [lj._print(e) for e in final_exprs]

    def _compile_function(self, strmod):
        """Optimise and JIT-compile IR text *strmod*; return the fn address."""
        llmod = llvm.parse_assembly(strmod)

        pmb = llvm.create_pass_manager_builder()
        pmb.opt_level = 2
        pass_manager = llvm.create_module_pass_manager()
        pmb.populate(pass_manager)

        pass_manager.run(llmod)

        target_machine = \
            llvm.Target.from_default_triple().create_target_machine()
        exe_eng = llvm.create_mcjit_compiler(llmod, target_machine)
        exe_eng.finalize_object()
        # Keep a module-level reference to the engine that owns the
        # compiled code.
        exe_engines.append(exe_eng)

        fptr = exe_eng.get_function_address(self.link_name)

        return fptr
| |
|
| |
|
class LLVMJitCodeCallback(LLVMJitCode):
    '''LLVMJitCode variant for callback signatures.

    Scalar symbols are read from the array argument at index
    ``signature.input_arg``. When ``signature.ret_arg`` is set, results are
    written to the array argument at that index and the function returns 0.
    '''

    def _create_param_dict(self, func_args):
        """Map each argument to ``(llvm_array_argument, index)``."""
        for i, a in enumerate(func_args):
            if isinstance(a, IndexedBase):
                self.param_dict[a] = (self.fn.args[i], i)
                self.fn.args[i].name = str(a)
            else:
                self.param_dict[a] = (self.fn.args[self.signature.input_arg],
                                      i)

    def _create_function(self, expr):
        """Create function body and return LLVM IR"""
        bb_entry = self.fn.append_basic_block('entry')
        builder = ll.IRBuilder(bb_entry)

        lj = LLVMJitCallbackPrinter(self.module, builder, self.fn,
                                    func_arg_map=self.param_dict)

        ret = self._convert_expr(lj, expr)

        # Compare against None: 0 is a valid argument index.
        if self.signature.ret_arg is not None:
            output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg],
                                            ll.PointerType(self.fp_type))
            for i, val in enumerate(ret):
                index = ll.Constant(ll.IntType(32), i)
                output_array_ptr = builder.gep(output_fp_ptr, [index])
                builder.store(val, output_array_ptr)
            builder.ret(ll.Constant(ll.IntType(32), 0))
        else:
            builder.ret(self._wrap_return(lj, ret))

        strmod = str(self.module)
        return strmod
| |
|
| |
|
class CodeSignature:
    '''Calling convention of a function to be JIT-compiled.

    Holds the ctypes return type, the ctypes argument types, and two
    argument indices: the argument that holds input values, and the
    argument that receives output values. ``ret_arg`` is None when results
    are returned directly.
    '''

    def __init__(self, ret_type):
        # Return type, plus argument types (filled in later by the caller).
        self.ret_type, self.arg_ctypes = ret_type, []
        # Index of the argument holding input values, and index of the
        # argument receiving outputs (None: results are returned directly).
        self.input_arg, self.ret_arg = 0, None
| |
|
| |
|
def _llvm_jit_code(args, expr, signature, callback_type):
    """Create a native code function from a SymPy expression

    Uses LLVMJitCode for plain functions (``callback_type is None``) and
    LLVMJitCodeCallback otherwise. Builds the LLVM function for *expr* over
    *args* according to *signature*, and returns the address of the compiled
    machine code.
    """
    builder_cls = LLVMJitCode if callback_type is None else LLVMJitCodeCallback
    jit = builder_cls(signature)

    # The order matters: argument types, then the function declaration, then
    # the symbol-to-argument mapping, then the body.
    jit._create_args(args)
    jit._create_function_base()
    jit._create_param_dict(args)
    ir_text = jit._create_function(expr)
    return jit._compile_function(ir_text)
| |
|
| |
|
@doctest_depends_on(modules=('llvmlite', 'scipy'))
def llvm_callable(args, expr, callback_type=None):
    '''Compile function from a SymPy expression

    Expressions are evaluated using double precision arithmetic.
    Some single argument math functions (exp, sin, cos, etc.) are supported
    in expressions.

    Parameters
    ==========

    args : List of Symbol
        Arguments to the generated function. Usually the free symbols in
        the expression. Currently each one is assumed to convert to
        a double precision scalar.
    expr : Expr, or (Replacements, Expr) as returned from 'cse'
        Expression to compile.
    callback_type : string
        Create function with signature appropriate to use as a callback.
        Currently supported:
           'scipy.integrate'
           'scipy.integrate.test'
           'cubature'

    Returns
    =======

    Compiled function that can evaluate the expression.

    Examples
    ========

    >>> import sympy.printing.llvmjitcode as jit
    >>> from sympy.abc import a
    >>> e = a*a + a + 1
    >>> e1 = jit.llvm_callable([a], e)
    >>> e.subs(a, 1.1)   # Evaluate via substitution
    3.31000000000000
    >>> e1(1.1)  # Evaluate using JIT-compiled code
    3.3100000000000005


    Callbacks for integration functions can be JIT compiled.

    >>> import sympy.printing.llvmjitcode as jit
    >>> from sympy.abc import a
    >>> from sympy import integrate
    >>> from scipy.integrate import quad
    >>> e = a*a
    >>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate')
    >>> integrate(e, (a, 0.0, 2.0))
    2.66666666666667
    >>> quad(e1, 0.0, 2.0)[0]
    2.66666666666667

    The 'cubature' callback is for the Python wrapper around the
    cubature package ( https://github.com/saullocastro/cubature )
    and ( http://ab-initio.mit.edu/wiki/index.php/Cubature )

    There are two signatures for the SciPy integration callbacks.
    The first ('scipy.integrate') is the function to be passed to the
    integration routine, and will pass the signature checks.
    The second ('scipy.integrate.test') is only useful for directly calling
    the function using ctypes variables. It will not pass the signature checks
    for scipy.integrate.

    The return value from the cse module can also be compiled.  This
    can improve the performance of the compiled function.  If multiple
    expressions are given to cse, the compiled function returns a tuple.
    The 'cubature' callback handles multiple expressions (set `fdim`
    to match in the integration call.)

    >>> import sympy.printing.llvmjitcode as jit
    >>> from sympy import cse
    >>> from sympy.abc import x,y
    >>> e1 = x*x + y*y
    >>> e2 = 4*(x*x + y*y) + 8.0
    >>> after_cse = cse([e1,e2])
    >>> after_cse
    ([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0])
    >>> j1 = jit.llvm_callable([x,y], after_cse)
    >>> j1(1.0, 2.0)
    (5.0, 28.0)
    '''

    if not llvmlite:
        raise ImportError("llvmlite is required for llvmjitcode")

    signature = CodeSignature(ctypes.py_object)
    # ctypes used to wrap the function pointer. For 'scipy.integrate' these
    # differ from the types used to build the LLVM function; None means
    # "same as compile_ctypes".
    wrapper_ctypes = None

    if callback_type is None:
        # One double per argument; the result is returned as a Python object.
        compile_ctypes = [ctypes.c_double for _ in args]
    elif callback_type in ('scipy.integrate', 'scipy.integrate.test'):
        signature.ret_type = ctypes.c_double
        compile_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
        if callback_type == 'scipy.integrate':
            wrapper_ctypes = [ctypes.c_int, ctypes.c_double]
        signature.input_arg = 1
    elif callback_type == 'cubature':
        compile_ctypes = [ctypes.c_int,
                          ctypes.POINTER(ctypes.c_double),
                          ctypes.c_void_p,
                          ctypes.c_int,
                          ctypes.POINTER(ctypes.c_double)
                          ]
        signature.ret_type = ctypes.c_int
        signature.input_arg = 1
        signature.ret_arg = 4
    else:
        raise ValueError("Unknown callback type: %s" % callback_type)

    signature.arg_ctypes = compile_ctypes

    fptr = _llvm_jit_code(args, expr, signature, callback_type)

    if wrapper_ctypes is None:
        wrapper_ctypes = compile_ctypes

    # A py_object return means the generated code calls Python C-API
    # functions (PyFloat_FromDouble / PyTuple_Pack), which require the GIL.
    # PYFUNCTYPE keeps the GIL held during the call; CFUNCTYPE releases it.
    make_prototype = (ctypes.PYFUNCTYPE
                      if signature.ret_type == ctypes.py_object
                      else ctypes.CFUNCTYPE)

    return make_prototype(signature.ret_type, *wrapper_ctypes)(fptr)
| |
|