| | """Module with functions operating on IndexedBase, Indexed and Idx objects |
| | |
| | - Check shape conformance |
| | - Determine indices in resulting expression |
| | |
| | etc. |
| | |
| | Methods in this module could be implemented by calling methods on Expr |
| | objects instead. When things stabilize this could be a useful |
| | refactoring. |
| | """ |
| |
|
| | from functools import reduce |
| |
|
| | from sympy.core.function import Function |
| | from sympy.functions import exp, Piecewise |
| | from sympy.tensor.indexed import Idx, Indexed |
| | from sympy.utilities import sift |
| |
|
| | from collections import OrderedDict |
| |
|
class IndexConformanceException(Exception):
    """Raised when the terms of an expression have inconsistent outer indices."""
| |
|
| | def _unique_and_repeated(inds): |
| | """ |
| | Returns the unique and repeated indices. Also note, from the examples given below |
| | that the order of indices is maintained as given in the input. |
| | |
| | Examples |
| | ======== |
| | |
| | >>> from sympy.tensor.index_methods import _unique_and_repeated |
| | >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0]) |
| | ([2, 1, 4], [3, 0]) |
| | """ |
| | uniq = OrderedDict() |
| | for i in inds: |
| | if i in uniq: |
| | uniq[i] = 0 |
| | else: |
| | uniq[i] = 1 |
| | return sift(uniq, lambda x: uniq[x], binary=True) |
| |
|
def _remove_repeated(inds):
    """
    Separate the objects occurring once in a sequence from the repeated ones.

    Returns a set with the objects that occur exactly once and a tuple that
    lists each repeated object a single time.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _remove_repeated
    >>> l1 = [1, 2, 3, 2]
    >>> _remove_repeated(l1)
    ({1, 3}, (2,))

    """
    singles, repeats = _unique_and_repeated(inds)
    return set(singles), tuple(repeats)
| |
|
| |
|
def _get_indices_Mul(expr, return_dummies=False):
    """Determine the outer indices of a Mul object.

    The outer indices of all factors are collected; an index that shows up
    more than once in that collection is treated as a summation (dummy) index
    and removed from the result. With ``return_dummies=True`` the dummies are
    returned as a third item.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Mul
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Mul(x[i, k]*y[j, k])
    ({i, j}, {})
    >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True)
    ({i, j}, {}, (k,))

    """
    per_factor = [get_indices(factor) for factor in expr.args]
    index_groups, factor_syms = zip(*per_factor)

    # Concatenate the outer indices of every factor, then split off the ones
    # that occur more than once: those are contracted between factors.
    collected = [index for group in index_groups for index in group]
    inds, dummies = _remove_repeated(collected)

    # Merge the per-factor symmetry information, multiplying the values of
    # entries that are reported by more than one factor.
    symmetry = {}
    for factor_sym in factor_syms:
        for pair in factor_sym:
            if pair in symmetry:
                symmetry[pair] *= factor_sym[pair]
            else:
                symmetry[pair] = factor_sym[pair]

    if return_dummies:
        return inds, symmetry, dummies
    return inds, symmetry
| |
|
| |
|
def _get_indices_Pow(expr):
    """Determine outer indices of a power or an exponential.

    A power is considered a universal function, so that the indices of a Pow is
    just the collection of indices present in the expression.  This may be
    viewed as a bit inconsistent in the special case:

        x[i]**2 = x[i]*x[i]                                                      (1)

    The above expression could have been interpreted as the contraction of x[i]
    with itself, but we choose instead to interpret it as a function

        lambda y: y**2

    applied to each element of x (a universal function in numpy terms).  In
    order to allow an interpretation of (1) as a contraction, we need
    contravariant and covariant Idx subclasses.  (FIXME: this is not yet
    implemented)

    Expressions in the base or exponent are subject to contraction as usual,
    but an index that is present in the exponent, will not be considered
    contractable with its own base.  Note however, that indices in the same
    exponent can be contracted with each other.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Pow
    >>> from sympy import Pow, exp, IndexedBase, Idx
    >>> A = IndexedBase('A')
    >>> x = IndexedBase('x')
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> _get_indices_Pow(exp(A[i, j]*x[j]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(x[i], x[i]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i]))
    ({i}, {})

    """
    base, exponent = expr.as_base_exp()
    base_inds, _ = get_indices(base)
    exponent_inds, _ = get_indices(exponent)

    # Base and exponent are never contracted with each other, so the outer
    # indices are the plain union of both sides. No symmetry information is
    # reported for powers.
    return base_inds | exponent_inds, {}
| |
|
| |
|
def _get_indices_Add(expr):
    """Determine outer indices of an Add object.

    In a sum, each term must have the same set of outer indices.  A valid
    expression could be

        x(i)*y(j) - x(j)*y(i)

    But we do not allow expressions like:

        x(i)*y(j) - z(j)*z(j)

    Terms without any outer index (scalars) are ignored when comparing.

    FIXME: Add support for Numpy broadcasting

    Raises
    ======

    IndexConformanceException
        If two non-scalar terms have different sets of outer indices.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Add
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Add(x[i] + x[k]*y[i, k])
    ({i}, {})

    """
    per_term = [get_indices(term) for term in expr.args]
    term_inds, syms = zip(*per_term)

    # Scalars broadcast over whatever indices the other terms have.
    non_scalars = [ind for ind in term_inds if ind != set()]
    if not non_scalars:
        return set(), {}

    reference = non_scalars[0]
    if any(ind != reference for ind in non_scalars[1:]):
        raise IndexConformanceException("Indices are not consistent: %s" % expr)

    # The fold is truthy as soon as two neighbouring entries differ, or when
    # all entries are equal but non-empty; only otherwise are the symmetries
    # of the first term passed on.
    # NOTE(review): as written this can only ever pass on an empty mapping --
    # confirm whether equal, non-empty symmetries were meant to be kept.
    if reduce(lambda x, y: x != y or y, syms):
        symmetries = {}
    else:
        symmetries = syms[0]

    return reference, symmetries
| |
|
| |
|
def get_indices(expr):
    """Determine the outer indices of expression ``expr``

    By *outer* we mean indices that are not summation indices.  Returns a set
    and a dict.  The set contains outer indices and the dict contains
    information about index symmetries.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_indices
    >>> from sympy import symbols
    >>> from sympy.tensor import IndexedBase
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, a, z = symbols('i j a z', integer=True)

    The indices of the total expression are determined.  Repeated indices
    imply a summation, for instance the trace of a matrix A:

    >>> get_indices(A[i, i])
    (set(), {})

    In the case of many terms, the terms are required to have identical
    outer indices.  Else an IndexConformanceException is raised.

    >>> get_indices(x[i] + A[i, j]*y[j])
    ({i}, {})

    :Exceptions:

    An IndexConformanceException means that the terms are not compatible, e.g.

    >>> get_indices(x[i] + y[j])                #doctest: +SKIP
            (...)
    IndexConformanceException: Indices are not consistent: x(i) + y(j)

    A NotImplementedError is raised for expression types that contain Indexed
    objects but have no dedicated handling here.

    .. warning::
       The concept of *outer* indices applies recursively, starting on the deepest
       level.  This implies that dummies inside parenthesis are assumed to be
       summed first, so that the following expression is handled gracefully:

       >>> get_indices((x[i] + A[i, j]*y[j])*x[j])
       ({i, j}, {})

       This is correct and may appear convenient, but you need to be careful
       with this as SymPy will happily .expand() the product, if requested.  The
       resulting expression would mix the outer ``j`` with the dummies inside
       the parenthesis, which makes it a different expression.  To be on the
       safe side, it is best to avoid such ambiguities by using unique indices
       for all contractions that should be held separate.

    """
    # The recursion bottoms out in one of the cases below: Indexed objects,
    # None, bare Idx objects and all other atoms.
    if isinstance(expr, Indexed):
        # Indices repeated within a single Indexed are summed over.
        inds, _dummies = _remove_repeated(expr.indices)
        return inds, {}
    elif expr is None:
        return set(), {}
    elif isinstance(expr, Idx):
        return {expr}, {}
    elif expr.is_Atom:
        return set(), {}

    # Composite expressions: dispatch on the kind of node and recurse.
    else:
        if expr.is_Mul:
            return _get_indices_Mul(expr)
        elif expr.is_Add:
            return _get_indices_Add(expr)
        elif expr.is_Pow or isinstance(expr, exp):
            return _get_indices_Pow(expr)

        elif isinstance(expr, Piecewise):
            # Piecewise expressions are not analysed; no indices are reported.
            return set(), {}
        elif isinstance(expr, Function):
            # The outer indices of a function application are the union of
            # the outer indices of its arguments; an index repeated across
            # different arguments is not contracted.
            ind0 = set()
            # Start from an empty mapping so that a function applied to no
            # arguments returns ``(set(), {})`` instead of failing on an
            # unbound name.  Otherwise the symmetries of the last argument
            # are returned, as before.
            sym = {}
            for arg in expr.args:
                ind, sym = get_indices(arg)
                ind0 |= ind
            return ind0, sym

        # Any other node is fine as long as it contains no Indexed object.
        elif not expr.has(Indexed):
            return set(), {}
        raise NotImplementedError(
            "FIXME: No specialized handling of type %s" % type(expr))
| |
|
| |
|
def get_contraction_structure(expr):
    """Determine dummy indices of ``expr`` and describe its structure

    By *dummy* we mean indices that are summation indices.

    The structure of the expression is determined and described as follows:

    1) A conforming summation of Indexed objects is described with a dict where
       the keys are summation indices and the corresponding values are sets
       containing all terms for which the summation applies.  All Add objects
       in the SymPy expression tree are described like this.

    2) For all nodes in the SymPy expression tree that are *not* of type Add, the
       following applies:

       If a node discovers contractions in one of its arguments, the node
       itself will be stored as a key in the dict.  For that key, the
       corresponding value is a list of dicts, each of which is the result of a
       recursive call to get_contraction_structure().  The list contains only
       dicts for the non-trivial deeper contractions, omitting dicts with None
       as the one and only key.

    .. Note:: The presence of expressions among the dictionary keys indicates
       multiple levels of index contractions.  A nested dict displays nested
       contractions and may itself contain dicts from a deeper level.  In
       practical calculations the summation in the deepest nested level must be
       calculated first so that the outer expression can access the resulting
       indexed object.

    Piecewise expressions are not analysed; they are reported as a single term
    without contractions.  A NotImplementedError is raised for other expression
    types that contain Indexed objects but have no dedicated handling here.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_contraction_structure
    >>> from sympy import default_sort_key
    >>> from sympy.tensor import IndexedBase, Idx
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l'])
    >>> get_contraction_structure(x[i]*y[i] + A[j, j])
    {(i,): {x[i]*y[i]}, (j,): {A[j, j]}}
    >>> get_contraction_structure(x[i]*y[j])
    {None: {x[i]*y[j]}}

    A multiplication of contracted factors results in nested dicts representing
    the internal contractions.

    >>> d = get_contraction_structure(x[i, i]*y[j, j])
    >>> sorted(d.keys(), key=default_sort_key)
    [None, x[i, i]*y[j, j]]

    In this case, the product has no contractions:

    >>> d[None]
    {x[i, i]*y[j, j]}

    Factors are contracted "first":

    >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key)
    [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}]

    A parenthesized Add object is also returned as a nested dictionary.  The
    term containing the parenthesis is a Mul with a contraction among the
    arguments, so it will be found as a key in the result.  It stores the
    dictionary resulting from a recursive call on the Add expression.

    >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j]))
    >>> sorted(d.keys(), key=default_sort_key)
    [(A[i, j]*x[j] + y[i])*x[i], (i,)]
    >>> d[(i,)]
    {(A[i, j]*x[j] + y[i])*x[i]}
    >>> d[x[i]*(A[i, j]*x[j] + y[i])]
    [{None: {y[i]}, (j,): {A[i, j]*x[j]}}]

    Powers with contractions in either base or exponent will also be found as
    keys in the dictionary, mapping to a list of results from recursive calls:

    >>> d = get_contraction_structure(A[j, j]**A[i, i])
    >>> d[None]
    {A[j, j]**A[i, i]}
    >>> nested_contractions = d[A[j, j]**A[i, i]]
    >>> nested_contractions[0]
    {(j,): {A[j, j]}}
    >>> nested_contractions[1]
    {(i,): {A[i, i]}}

    The description of the contraction structure may appear complicated when
    represented with a string in the above examples, but it is easy to iterate
    over:

    >>> from sympy import Expr
    >>> for key in d:
    ...     if isinstance(key, Expr):
    ...         continue
    ...     for term in d[key]:
    ...         if term in d:
    ...             # treat deepest contraction first
    ...             pass
    ...     # treat outermost contractions here

    """
    # The function calls itself recursively to inspect sub-expressions.

    if isinstance(expr, Indexed):
        # Indices repeated inside one Indexed are its dummies; an empty tuple
        # of dummies is reported under the key None.
        _, key = _remove_repeated(expr.indices)
        return {key or None: {expr}}
    elif expr.is_Atom:
        return {None: {expr}}
    elif expr.is_Mul:
        _, _, key = _get_indices_Mul(expr, return_dummies=True)
        result = {key or None: {expr}}
        # Factors with contractions of their own are recorded under the
        # product itself; trivial ``{None: ...}`` results are skipped.
        nested = []
        for fac in expr.args:
            facd = get_contraction_structure(fac)
            if not (None in facd and len(facd) == 1):
                nested.append(facd)
        if nested:
            result[expr] = nested
        return result
    elif expr.is_Pow or isinstance(expr, exp):
        # Base and exponent are never contracted with each other, so the
        # power itself goes under None; only contractions inside the base or
        # inside the exponent are recorded as nested structures.
        b, e = expr.as_base_exp()
        dbase = get_contraction_structure(b)
        dexp = get_contraction_structure(e)

        dicts = []
        for d in dbase, dexp:
            if not (None in d and len(d) == 1):
                dicts.append(d)
        result = {None: {expr}}
        if dicts:
            result[expr] = dicts
        return result
    elif expr.is_Add:
        # The structures of all terms are merged into one dict: terms that
        # share a key (the same dummies, or None) end up in the same set.
        result = {}
        for term in expr.args:
            d = get_contraction_structure(term)
            for key in d:
                if key in result:
                    result[key] |= d[key]
                else:
                    result[key] = d[key]
        return result

    elif isinstance(expr, Piecewise):
        # Piecewise is not analysed.  The expression is wrapped in a set like
        # in every other branch, so that the result can be merged with ``|=``
        # when the Piecewise is a term of an Add.
        return {None: {expr}}
    elif isinstance(expr, Function):
        # The function application itself goes under None; arguments with
        # contractions of their own are recorded as nested structures.
        deeplist = []
        for arg in expr.args:
            deep = get_contraction_structure(arg)
            if not (None in deep and len(deep) == 1):
                deeplist.append(deep)
        d = {None: {expr}}
        if deeplist:
            d[expr] = deeplist
        return d

    # Any other node is fine as long as it contains no Indexed object.
    elif not expr.has(Indexed):
        return {None: {expr}}
    raise NotImplementedError(
        "FIXME: No specialized handling of type %s" % type(expr))
| |
|