| from __future__ import annotations |
|
|
| from sympy.core import Basic, S |
| from sympy.core.function import Lambda |
| from sympy.core.numbers import equal_valued |
| from sympy.printing.codeprinter import CodePrinter |
| from sympy.printing.precedence import precedence |
| from functools import reduce |
|
|
# SymPy function name -> GLSL built-in used to print it.  GLSLPrinter copies
# this table per instance and merges its ``user_functions`` setting over it.
known_functions = dict(
    Abs='abs',
    sin='sin',
    cos='cos',
    tan='tan',
    acos='acos',
    asin='asin',
    atan='atan',
    # GLSL spells the two-argument arctangent as the atan(y, x) overload.
    atan2='atan',
    ceiling='ceil',
    floor='floor',
    sign='sign',
    exp='exp',
    log='log',
    # Emitted by GLSLPrinter._print_Add / _print_Mul when use_operators=False.
    add='add',
    sub='sub',
    mul='mul',
    # Emitted by GLSLPrinter._print_Pow for general exponents.
    pow='pow',
)
|
|
class GLSLPrinter(CodePrinter):
    """
    Rudimentary, generic GLSL printing tools.

    Additional settings:
    'use_operators': Boolean (should the printer use operators for +,-,*, or functions?)
    'zero': printed in place of the positive part of a sum whose terms all
        carry a minus sign (only used when ``use_operators`` is False).
    'mat_nested', 'mat_separator', 'mat_transpose', 'array_type',
    'glsl_types': control how matrices and sequences are rendered; see
        ``glsl_code`` for their meaning.
    """
    _not_supported: set[Basic] = set()
    printmethod = "_glsl"
    language = "GLSL"

    _default_settings = dict(CodePrinter._default_settings, **{
        'use_operators': True,
        'zero': 0,
        'mat_nested': False,
        'mat_separator': ',\n',
        'mat_transpose': False,
        'array_type': 'float',
        'glsl_types': True,

        'precision': 9,
        'user_functions': {},
        'contract': True,
    })

    def __init__(self, settings=None):
        """Create a printer whose ``settings`` override ``_default_settings``.

        The module-level ``known_functions`` table is copied and the
        ``user_functions`` setting is merged over the copy.
        """
        # ``None`` instead of a shared mutable ``{}`` default; an empty dict
        # behaves exactly like the old default below.
        if settings is None:
            settings = {}
        CodePrinter.__init__(self, settings)
        self.known_functions = dict(known_functions)
        self.known_functions.update(settings.get('user_functions', {}))

    def _rate_index_position(self, p):
        # Hook for the CodePrinter base class (presumably used to order index
        # loops -- see the base class); later positions get a larger weight.
        return p*5

    def _get_statement(self, codestring):
        """Terminate a statement with a semicolon."""
        return "%s;" % codestring

    def _get_comment(self, text):
        """Render ``text`` as a single-line ``//`` comment."""
        return "// {}".format(text)

    def _declare_number_const(self, name, value):
        """Declare a numeric constant as a GLSL ``float``."""
        return "float {} = {};".format(name, value)

    def _format_code(self, lines):
        """Final formatting pass: re-indent the generated lines."""
        return self.indent_code(lines)

    def indent_code(self, code):
        """Accepts a string of code or a list of code lines.

        Lines ending in an opening bracket indent the following lines; lines
        starting with a closing bracket dedent themselves. A string input
        returns a string and a list input returns a list.
        """
        if isinstance(code, str):
            code_lines = self.indent_code(code.splitlines(True))
            return ''.join(code_lines)

        tab = " "
        inc_token = ('{', '(', '{\n', '(\n')
        dec_token = ('}', ')')

        # Discard any existing indentation; it is recomputed below.
        code = [line.lstrip(' \t') for line in code]

        increase = [int(any(map(line.endswith, inc_token))) for line in code]
        decrease = [int(any(map(line.startswith, dec_token))) for line in code]

        pretty = []
        level = 0
        for n, line in enumerate(code):
            if line in ('', '\n'):
                # Blank lines are kept as-is and do not affect the level.
                pretty.append(line)
                continue
            level -= decrease[n]
            pretty.append("%s%s" % (tab*level, line))
            level += increase[n]
        return pretty

    def _print_MatrixBase(self, mat):
        """Print an explicit matrix as a scalar, ``vecN``, ``matN``/``matCxR``
        or (nested) array constructor, depending on its shape and settings.
        """
        mat_separator = self._settings['mat_separator']
        mat_transpose = self._settings['mat_transpose']
        column_vector = (mat.rows == 1) if mat_transpose else (mat.cols == 1)
        # Vectors of either orientation end up as a single row; any other
        # matrix is transposed exactly when ``mat_transpose`` is set.
        A = mat.transpose() if mat_transpose != column_vector else mat

        glsl_types = self._settings['glsl_types']
        array_type = self._settings['array_type']
        array_size = A.cols*A.rows
        array_constructor = "{}[{}]".format(array_type, array_size)

        if A.cols == 1:
            # Only a 1x1 matrix reaches here: print its single entry.
            return self._print(A[0])
        if A.rows <= 4 and A.cols <= 4 and glsl_types:
            if A.rows == 1:
                return "vec{}{}".format(
                    A.cols, A.table(self, rowstart='(', rowend=')')
                )
            elif A.rows == A.cols:
                return "mat{}({})".format(
                    A.rows, A.table(self, rowsep=', ',
                                    rowstart='', rowend='')
                )
            else:
                return "mat{}x{}({})".format(
                    A.cols, A.rows,
                    A.table(self, rowsep=', ',
                            rowstart='', rowend='')
                )
        elif S.One in A.shape:
            return "{}({})".format(
                array_constructor,
                A.table(self, rowsep=mat_separator, rowstart='', rowend='')
            )
        elif not self._settings['mat_nested']:
            return "{}(\n{}\n) /* a {}x{} matrix */".format(
                array_constructor,
                A.table(self, rowsep=mat_separator, rowstart='', rowend=''),
                A.rows, A.cols
            )
        else:
            # Inner row constructors use the same element type as the outer
            # array; hard-coding ``float`` produced mixed-type GLSL for any
            # other ``array_type``.
            return "{}[{}][{}](\n{}\n)".format(
                array_type, A.rows, A.cols,
                A.table(self, rowsep=mat_separator,
                        rowstart='{}[]('.format(array_type), rowend=')')
            )

    def _print_SparseRepMatrix(self, mat):
        """Sparse matrices are reported as not supported."""
        return self._print_not_supported(mat)

    def _traverse_matrix_indices(self, mat):
        """Yield every valid ``(row, col)`` index of ``mat`` in row-major order.

        ``mat_transpose`` is applied when each element is printed
        (``_print_MatrixElement`` swaps the indices), so the traversal must
        always walk the matrix's real shape.
        """
        # The transposed branch used to iterate the swapped shape, which only
        # matched for square matrices and raised IndexError otherwise; for
        # square matrices the visiting order is unchanged.
        rows, cols = mat.shape
        return ((i, j) for i in range(rows) for j in range(cols))

    def _print_MatrixElement(self, expr):
        """Print ``M[i, j]`` as ``M[i][j]`` when ``M`` fits a GLSL matrix type
        (or nested arrays are enabled), otherwise as the flat index
        ``M[i + j*rows]``. ``mat_transpose`` swaps the roles of ``i`` and ``j``.
        """
        nest = self._settings['mat_nested']
        glsl_types = self._settings['glsl_types']
        mat_transpose = self._settings['mat_transpose']
        if mat_transpose:
            cols, rows = expr.parent.shape
            i, j = expr.j, expr.i
        else:
            rows, cols = expr.parent.shape
            i, j = expr.i, expr.j
        pnt = self._print(expr.parent)
        if glsl_types and ((rows <= 4 and cols <= 4) or nest):
            return "{}[{}][{}]".format(pnt, i, j)
        else:
            return "{}[{}]".format(pnt, i + j*rows)

    def _print_list(self, expr):
        """Print a list/tuple as ``vecN(...)`` or ``<array_type>[N](...)``."""
        elements = ', '.join(self._print(item) for item in expr)
        glsl_types = self._settings['glsl_types']
        array_type = self._settings['array_type']
        array_size = len(expr)
        # GLSL only provides vec2..vec4 (there is no vec0/vec1), so shorter
        # sequences fall back to the array constructor like longer ones do.
        if glsl_types and 2 <= array_size <= 4:
            return 'vec{}({})'.format(array_size, elements)
        return '{}[{}]({})'.format(array_type, array_size, elements)

    _print_tuple = _print_list
    _print_Tuple = _print_list

    def _get_loop_opening_ending(self, indices):
        """Return ``(open_lines, close_lines)`` with one C-style ``for`` loop
        per index, running from ``lower`` to ``upper`` inclusive."""
        open_lines = []
        close_lines = []
        loopstart = "for (int %(varble)s=%(start)s; %(varble)s<%(end)s; %(varble)s++){"
        for i in indices:
            # ``upper + 1`` with ``<`` makes the upper bound inclusive.
            open_lines.append(loopstart % {
                'varble': self._print(i.label),
                'start': self._print(i.lower),
                'end': self._print(i.upper + 1)})
            close_lines.append("}")
        return open_lines, close_lines

    def _print_Function_with_args(self, func, func_args):
        """Print a call of ``func`` on ``func_args``.

        ``func`` is either a key of ``known_functions`` (within this class:
        ``'add'``, ``'sub'``, ``'mul'`` or ``'pow'`` with already-printed
        string arguments) or a ``Lambda`` to be inlined. A ``known_functions``
        value is either a name, or a list of ``(condition, name_or_callable)``
        pairs. Each condition receives the whole ``func_args`` tuple as its
        single argument. The first pair whose condition holds is used, or the
        last pair if none holds.
        """
        if func in self.known_functions:
            cond_func = self.known_functions[func]
            func = None
            if isinstance(cond_func, str):
                func = cond_func
            else:
                for cond, func in cond_func:
                    if cond(func_args):
                        break
            if isinstance(func, str):
                # Plain names are formatted directly rather than relying on
                # the TypeError raised by calling a string.
                return '{}({})'.format(func, self.stringify(func_args, ", "))
            if func is not None:
                try:
                    return func(*[self.parenthesize(item, 0) for item in func_args])
                except TypeError:
                    return '{}({})'.format(func, self.stringify(func_args, ", "))
        elif isinstance(func, Lambda):
            # Inline the lambda's body applied to the arguments.
            return self._print(func(*func_args))
        else:
            return self._print_not_supported(func)

    def _print_Piecewise(self, expr):
        """Print a ``Piecewise`` as if/else blocks when it contains an
        ``Assignment``, otherwise as a nested ternary expression.

        Raises ValueError if the last condition is not ``True``.
        """
        from sympy.codegen.ast import Assignment
        if expr.args[-1].cond != True:
            # Without a default branch the generated expression may not
            # evaluate to anything for some inputs.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        lines = []
        if expr.has(Assignment):
            for i, (e, c) in enumerate(expr.args):
                if i == 0:
                    lines.append("if (%s) {" % self._print(c))
                elif i == len(expr.args) - 1 and c == True:
                    lines.append("else {")
                else:
                    lines.append("else if (%s) {" % self._print(c))
                code0 = self._print(e)
                lines.append(code0)
                lines.append("}")
            return "\n".join(lines)
        else:
            # ((c1) ? (e1) : ((c2) ? (e2) : (default))) -- one closing
            # parenthesis per conditional pair.
            ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c),
                                               self._print(e))
                       for e, c in expr.args[:-1]]
            last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr)
            return ": ".join(ecpairs) + last_line + ")"*len(ecpairs)

    def _print_Indexed(self, expr):
        """Print an ``Indexed`` as a flat, row-major offset into its base."""
        dims = expr.shape
        elem = S.Zero
        offset = S.One
        for i in reversed(range(expr.rank)):
            elem += expr.indices[i]*offset
            offset *= dims[i]
        return "{}[{}]".format(
            self._print(expr.base.label),
            self._print(elem)
        )

    def _print_Pow(self, expr):
        """Print ``b**-1`` as ``1.0/b`` and ``b**(1/2)`` as ``sqrt(b)``.

        Any other power is printed as ``pow(b, e)``, with a numeric exponent
        printed as a float.
        """
        PREC = precedence(expr)
        if equal_valued(expr.exp, -1):
            return '1.0/%s' % (self.parenthesize(expr.base, PREC))
        elif equal_valued(expr.exp, 0.5):
            return 'sqrt(%s)' % self._print(expr.base)
        else:
            try:
                e = self._print(float(expr.exp))
            except TypeError:
                # Symbolic exponent: print it as an expression.
                e = self._print(expr.exp)
            return self._print_Function_with_args('pow', (
                self._print(expr.base),
                e
            ))

    def _print_int(self, expr):
        """Print Python ints as float literals (e.g. ``0`` -> ``0.0``)."""
        return str(float(expr))

    def _print_Rational(self, expr):
        """Print ``p/q`` as ``p.0/q.0``, avoiding integer division of the literals."""
        return "{}.0/{}.0".format(expr.p, expr.q)

    def _print_Relational(self, expr):
        """Print ``lhs <op> rhs`` with the relation's own operator."""
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        return "{} {} {}".format(lhs_code, op, rhs_code)

    def _print_Add(self, expr, order=None):
        """Print a sum with ``+``/``-`` or, when ``use_operators`` is False,
        as ``sub(add(<positive terms>), add(<negated negative terms>))``."""
        if self._settings['use_operators']:
            return CodePrinter._print_Add(self, expr, order=order)

        def add(a, b):
            return self._print_Function_with_args('add', (a, b))

        # Split the terms, keeping their order, by whether a minus sign can
        # be pulled out of them.
        neg, pos = [], []
        for term in expr.as_ordered_terms():
            (neg if term.could_extract_minus_sign() else pos).append(term)

        if pos:
            s = reduce(add, (self._print(t) for t in pos))
        else:
            s = self._print(self._settings['zero'])
        if neg:
            # Negative terms are negated, summed, and subtracted in one go.
            neg_sum = reduce(add, (self._print(-n) for n in neg))
            s = self._print_Function_with_args('sub', (s, neg_sum))
        return s

    def _print_Mul(self, expr, **kwargs):
        """Print a product with ``*`` or, when ``use_operators`` is False, as
        left-nested ``mul(a, b)`` calls over the ordered factors."""
        if self._settings['use_operators']:
            return CodePrinter._print_Mul(self, expr, **kwargs)

        def mul(a, b):
            return self._print_Function_with_args('mul', (a, b))

        return reduce(mul, (self._print(t) for t in expr.as_ordered_factors()))
|
|
def glsl_code(expr,assign_to=None,**settings):
    """Converts an expr to a string of GLSL code

    Parameters
    ==========

    expr : Expr
        A SymPy expression to be converted.
    assign_to : optional
        When given, the argument is used for naming the variable or variables
        to which the expression is assigned. Can be a string, ``Symbol``,
        ``MatrixSymbol`` or ``Indexed`` type object. In cases where ``expr``
        would be printed as an array, a list of string or ``Symbol`` objects
        can also be passed.

        This is helpful in case of line-wrapping, or for expressions that
        generate multi-line statements. It can also be used to spread an array-like
        expression into multiple assignments.
    use_operators: bool, optional
        If set to False, then the *, +, - operators will be replaced with the
        functions mul, add, and sub, which must be implemented by the user,
        e.g. for implementing non-standard rings or emulated quad/octal
        precision. Reciprocals (``x**-1``) and rational constants are still
        printed with ``/``.
        [default=True]
    glsl_types: bool, optional
        Set this argument to ``False`` in order to avoid using the ``vec`` and ``mat``
        types. The printer will instead use arrays (or nested arrays).
        [default=True]
    mat_nested: bool, optional
        GLSL version 4.3 and above support nested arrays (arrays of arrays). Set this to ``True``
        to render matrices as nested arrays.
        [default=False]
    mat_separator: str, optional
        By default, matrices are rendered with newlines using this separator,
        making them easier to read, but less compact. By removing the newline
        this option can be used to make them more vertically compact.
        [default=',\n']
    mat_transpose: bool, optional
        GLSL's matrix multiplication implementation assumes column-major indexing.
        By default, this printer ignores that convention. Setting this option to
        ``True`` transposes all matrix output.
        [default=False]
    array_type: str, optional
        The GLSL array constructor type.
        [default='float']
    precision : integer, optional
        The precision for numbers such as pi [default=9].
    user_functions : dict, optional
        A dictionary where keys are ``FunctionClass`` instances and values are
        their string representations. Alternatively, the dictionary value can
        be a list of tuples i.e. [(argument_test, glsl_function_string)]. See
        below for examples.
    human : bool, optional
        If True, the result is a single string that may contain some constant
        declarations for the number symbols. If False, the same information is
        returned in a tuple of (symbols_to_declare, not_supported_functions,
        code_text). [default=True].
    contract: bool, optional
        If True, ``Indexed`` instances are assumed to obey tensor contraction
        rules and the corresponding nested loops over indices are generated.
        Setting contract=False will not generate loops, instead the user is
        responsible to provide values for the indices in the code.
        [default=True].

    Examples
    ========

    >>> from sympy import glsl_code, symbols, Rational, sin, ceiling, Abs
    >>> x, tau = symbols("x, tau")
    >>> glsl_code((2*tau)**Rational(7, 2))
    '8*sqrt(2)*pow(tau, 3.5)'
    >>> glsl_code(sin(x), assign_to="float y")
    'float y = sin(x);'

    Various GLSL types are supported:
    >>> from sympy import Matrix, glsl_code
    >>> glsl_code(Matrix([1,2,3]))
    'vec3(1, 2, 3)'

    >>> glsl_code(Matrix([[1, 2],[3, 4]]))
    'mat2(1, 2, 3, 4)'

    Pass ``mat_transpose = True`` to switch to column-major indexing:
    >>> glsl_code(Matrix([[1, 2],[3, 4]]), mat_transpose = True)
    'mat2(1, 3, 2, 4)'

    By default, larger matrices get collapsed into float arrays:
    >>> print(glsl_code( Matrix([[1,2,3,4,5],[6,7,8,9,10]]) ))
    float[10](
    1, 2, 3, 4, 5,
    6, 7, 8, 9, 10
    ) /* a 2x5 matrix */

    The type of array constructor used to print GLSL arrays can be controlled
    via the ``array_type`` parameter:
    >>> glsl_code(Matrix([1,2,3,4,5]), array_type='int')
    'int[5](1, 2, 3, 4, 5)'

    Passing a list of strings or ``symbols`` to the ``assign_to`` parameter will yield
    a multi-line assignment for each item in an array-like expression:
    >>> x_struct_members = symbols('x.a x.b x.c x.d')
    >>> print(glsl_code(Matrix([1,2,3,4]), assign_to=x_struct_members))
    x.a = 1;
    x.b = 2;
    x.c = 3;
    x.d = 4;

    This could be useful in cases where it's desirable to modify members of a
    GLSL ``Struct``. It could also be used to spread items from an array-like
    expression into various miscellaneous assignments:
    >>> misc_assignments = ('x[0]', 'x[1]', 'float y', 'float z')
    >>> print(glsl_code(Matrix([1,2,3,4]), assign_to=misc_assignments))
    x[0] = 1;
    x[1] = 2;
    float y = 3;
    float z = 4;

    Passing ``mat_nested = True`` instead prints out nested float arrays, which are
    supported in GLSL 4.3 and above.
    >>> mat = Matrix([
    ... [ 0, 1, 2],
    ... [ 3, 4, 5],
    ... [ 6, 7, 8],
    ... [ 9, 10, 11],
    ... [12, 13, 14]])
    >>> print(glsl_code( mat, mat_nested = True ))
    float[5][3](
    float[]( 0, 1, 2),
    float[]( 3, 4, 5),
    float[]( 6, 7, 8),
    float[]( 9, 10, 11),
    float[](12, 13, 14)
    )

    Custom printing can be defined for certain types by passing a dictionary of
    "type" : "function" to the ``user_functions`` kwarg. Alternatively, the
    dictionary value can be a list of tuples i.e. [(argument_test,
    glsl_function_string)].

    >>> custom_functions = {
    ... "ceiling": "CEIL",
    ... "Abs": [(lambda x: not x.is_integer, "fabs"),
    ... (lambda x: x.is_integer, "ABS")]
    ... }
    >>> glsl_code(Abs(x) + ceiling(x), user_functions=custom_functions)
    'fabs(x) + CEIL(x)'

    If further control is needed, the addition, subtraction and multiplication
    operators can be replaced with ``add``, ``sub``, and ``mul`` functions.
    This is done by passing ``use_operators = False``:

    >>> x,y,z = symbols('x,y,z')
    >>> glsl_code(x*(y+z), use_operators = False)
    'mul(x, add(y, z))'
    >>> glsl_code(x*(y+z*(x-y)**z), use_operators = False)
    'mul(x, add(y, mul(z, pow(sub(x, y), z))))'

    ``Piecewise`` expressions are converted into conditionals. If an
    ``assign_to`` variable is provided an if statement is created, otherwise
    the ternary operator is used. Note that if the ``Piecewise`` lacks a
    default term, represented by ``(expr, True)`` then an error will be thrown.
    This is to prevent generating an expression that may not evaluate to
    anything.

    >>> from sympy import Piecewise
    >>> expr = Piecewise((x + 1, x > 0), (x, True))
    >>> print(glsl_code(expr, tau))
    if (x > 0) {
    tau = x + 1;
    }
    else {
    tau = x;
    }

    Support for loops is provided through ``Indexed`` types. With
    ``contract=True`` these expressions will be turned into loops, whereas
    ``contract=False`` will just print the assignment expression that should be
    looped over:

    >>> from sympy import Eq, IndexedBase, Idx
    >>> len_y = 5
    >>> y = IndexedBase('y', shape=(len_y,))
    >>> t = IndexedBase('t', shape=(len_y,))
    >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
    >>> i = Idx('i', len_y-1)
    >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
    >>> glsl_code(e.rhs, assign_to=e.lhs, contract=False)
    'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'

    >>> from sympy import Matrix, MatrixSymbol
    >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
    >>> A = MatrixSymbol('A', 3, 1)
    >>> print(glsl_code(mat, A))
    A[0][0] = pow(x, 2.0);
    if (x > 0) {
    A[1][0] = x + 1;
    }
    else {
    A[1][0] = x;
    }
    A[2][0] = sin(x);
    """
    return GLSLPrinter(settings).doprint(expr,assign_to)
|
|
def print_glsl(expr, **settings):
    """Write the GLSL code for ``expr`` to standard output.

    Takes the same keyword settings as ``glsl_code`` (i.e. those understood
    by ``GLSLPrinter``); returns ``None``.
    """
    code = glsl_code(expr, **settings)
    print(code)
|
|