| | from sympy.core.basic import Basic |
| | from sympy.core.expr import Expr |
| | from sympy.core.symbol import Symbol |
| | from sympy.core.numbers import Integer, Rational, Float |
| | from sympy.printing.repr import srepr |
| |
|
| | __all__ = ['dotprint'] |
| |
|
# Default node styling consumed by ``styleof``: every entry whose class
# matches the expression contributes its attributes, in order, so a later
# entry overrides the keys set by an earlier one.
default_styles = (
    (Basic, {'color': 'blue', 'shape': 'ellipse'}),
    (Expr, {'color': 'black'}),
)

# Tuple of the Symbol and numeric classes imported above.
# NOTE(review): not referenced by any code visible in this file -- confirm
# whether external callers still rely on it before removing.
slotClasses = (Symbol, Integer, Rational, Float)
def purestr(x, with_args=False):
    """Return a string that mirrors ``obj = type(obj)(*obj.args)`` exactly.

    Objects that are not ``Basic`` are rendered with ``str``; ``Basic``
    objects without args are rendered with ``srepr``; everything else is
    rendered as the class name applied to ``purestr`` of each argument.

    Parameters
    ==========

    with_args : boolean, optional
        If ``True``, a ``(string, subnode_strings)`` pair is returned,
        where ``subnode_strings`` is a tuple holding ``purestr`` of each
        of the subnodes (empty when there are none).

        If ``False``, only the string is returned.

        Default is ``False``

    Examples
    ========

    >>> from sympy import Float, Symbol, MatrixSymbol
    >>> from sympy import Integer # noqa: F401
    >>> from sympy.core.symbol import Str # noqa: F401
    >>> from sympy.printing.dot import purestr

    Applying ``purestr`` for basic symbolic object:
    >>> code = purestr(Symbol('x'))
    >>> code
    "Symbol('x')"
    >>> eval(code) == Symbol('x')
    True

    For basic numeric object:
    >>> purestr(Float(2))
    "Float('2.0', precision=53)"

    For matrix symbol:
    >>> code = purestr(MatrixSymbol('x', 2, 2))
    >>> code
    "MatrixSymbol(Str('x'), Integer(2), Integer(2))"
    >>> eval(code) == MatrixSymbol('x', 2, 2)
    True

    With ``with_args=True``:
    >>> purestr(Float(2), with_args=True)
    ("Float('2.0', precision=53)", ())
    >>> purestr(MatrixSymbol('x', 2, 2), with_args=True)
    ("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
     ("Str('x')", 'Integer(2)', 'Integer(2)'))
    """
    if not isinstance(x, Basic):
        # Plain Python objects have no SymPy structure to reproduce.
        text, children = str(x), ()
    elif x.args:
        # Compound node: recurse so every level follows the same format.
        children = tuple(purestr(arg) for arg in x.args)
        text = "%s(%s)" % (type(x).__name__, ', '.join(children))
    else:
        # Leaf SymPy object: srepr gives the evaluable representation.
        text, children = srepr(x), ()

    return (text, children) if with_args else text
| |
|
| |
|
def styleof(expr, styles=default_styles):
    """ Build the style dictionary for ``expr``

    ``styles`` is walked in order; each ``(class, mapping)`` entry whose
    class ``expr`` is an instance of is merged into the result, so later
    matching entries override keys set by earlier ones.  A new dictionary
    is returned on every call.

    Examples
    ========

    >>> from sympy import Symbol, Basic, Expr, S
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr,  {'color': 'black'})]

    >>> styleof(Basic(S(1)), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    merged = {}
    for cls, attrs in styles:
        if not isinstance(expr, cls):
            continue
        merged.update(attrs)
    return merged
| |
|
| |
|
def attrprint(d, delimiter=', '):
    """ Render a dictionary of attributes as ``"key"="value"`` pairs

    The pairs are emitted in sorted item order and joined with
    ``delimiter``.

    Examples
    ========

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    pairs = ['"%s"="%s"' % (key, value) for key, value in sorted(d.items())]
    return delimiter.join(pairs)
| |
|
| |
|
def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ Build the DOT statement that declares the node for ``expr``

    The node identifier is ``purestr(expr)``; when ``repeat`` is true the
    position ``pos`` is appended to it.  Non-atomic ``Basic`` objects are
    labeled with their class name, anything else with ``labelfunc(expr)``.

    Examples
    ========

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    attrs = styleof(expr, styles)

    is_compound = isinstance(expr, Basic) and not expr.is_Atom
    attrs['label'] = (str(expr.__class__.__name__) if is_compound
                      else labelfunc(expr))

    node_id = purestr(expr)
    if repeat:
        # Suffix with the position so equal subexpressions get their own node.
        node_id = '%s_%s' % (node_id, str(pos))
    return '"%s" [%s];' % (node_id, attrprint(attrs))
| |
|
| |
|
def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ Build the DOT edge statements for all expr->expr.arg pairs

    An empty list is returned when ``atom(expr)`` is true.  When ``repeat``
    is true, the parent identifier is suffixed with ``pos`` and each child
    identifier with ``pos`` extended by the child's index.

    See the docstring of dotprint for explanations of the options.

    Examples
    ========

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    """
    if atom(expr):
        return []

    parent, children = purestr(expr, with_args=True)
    if repeat:
        parent = '%s_%s' % (parent, str(pos))
        children = ['%s_%s' % (child, str(pos + (index,)))
                    for index, child in enumerate(children)]
    return ['"%s" -> "%s";' % (parent, child) for child in children]
| |
|
# Skeleton of the generated DOT document.  It is filled with %-formatting
# using the ``graphstyle``, ``nodes`` and ``edges`` keys.
template = '\n'.join((
    'digraph{',
    '',
    '# Graph style',
    '%(graphstyle)s',
    '',
    '#########',
    '# Nodes #',
    '#########',
    '',
    '%(nodes)s',
    '',
    '#########',
    '# Edges #',
    '#########',
    '',
    '%(edges)s',
    '}',
))

# Graph-level attributes used by ``dotprint`` before its keyword arguments
# are merged in.
_graphstyle = dict(rankdir='TD', ordering='out')
| |
|
def dotprint(expr,
             styles=default_styles, atom=lambda x: not isinstance(x, Basic),
             maxdepth=None, repeat=True, labelfunc=str, **kwargs):
    """DOT description of a SymPy expression tree

    Parameters
    ==========

    styles : list of lists composed of (Class, mapping), optional
        Styles for different classes.

        The default is

        .. code-block:: python

            (
                (Basic, {'color': 'blue', 'shape': 'ellipse'}),
                (Expr,  {'color': 'black'})
            )

    atom : function, optional
        Function used to determine if an arg is an atom.

        A good choice is ``lambda x: not x.args``.

        The default is ``lambda x: not isinstance(x, Basic)``.

    maxdepth : integer, optional
        The maximum depth.

        The default is ``None``, meaning no limit.  A value of ``0``
        emits only the node of ``expr`` itself, without any edges.

    repeat : boolean, optional
        Whether to use different nodes for common subexpressions.

        The default is ``True``.

        For example, for ``x + x*y`` with ``repeat=True``, it will have
        two nodes for ``x``; with ``repeat=False``, it will have one
        node.

        .. warning::
            Even if a node appears twice in the same object like ``x`` in
            ``Pow(x, x)``, it will still only appear once.
            Hence, with ``repeat=False``, the number of arrows out of an
            object might not equal the number of args it has.

    labelfunc : function, optional
        A function to create a label for a given leaf node.

        The default is ``str``.

        Another good option is ``srepr``.

        For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
        ``x`` and ``1``.  With ``srepr``, they are labeled ``Symbol('x')``
        and ``Integer(1)``.

    **kwargs : optional
        Additional keyword arguments are included as styles for the graph.

    Returns
    =======

    str
        The DOT source: graph attributes, one node statement per visited
        subexpression and one edge statement per parent/argument pair.

    Examples
    ========

    >>> from sympy import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    <BLANKLINE>
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    <BLANKLINE>
    #########
    # Nodes #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    <BLANKLINE>
    #########
    # Edges #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    }

    """
    # Copy so caller-supplied overrides never leak into the module default.
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        # Record the node for ``e``, then descend into its non-atomic args.
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
        # Compare against None explicitly: a plain truthiness test would
        # silently treat ``maxdepth=0`` as "no limit".
        if maxdepth is not None and depth >= maxdepth:
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        for i, arg in enumerate(e.args):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template % {'graphstyle': attrprint(graphstyle, delimiter='\n'),
                       'nodes': '\n'.join(nodes),
                       'edges': '\n'.join(edges)}
| |
|