|
|
import bisect |
|
|
from collections import defaultdict |
|
|
|
|
|
from sympy.combinatorics import Permutation |
|
|
from sympy.core.containers import Tuple |
|
|
from sympy.core.numbers import Integer |
|
|
|
|
|
|
|
|
def _get_mapping_from_subranks(subranks): |
|
|
mapping = {} |
|
|
counter = 0 |
|
|
for i, rank in enumerate(subranks): |
|
|
for j in range(rank): |
|
|
mapping[counter] = (i, j) |
|
|
counter += 1 |
|
|
return mapping |
|
|
|
|
|
|
|
|
def _get_contraction_links(args, subranks, *contraction_indices):
    """
    Describe pairwise contractions as links between argument positions.

    Every flat index in ``contraction_indices`` is translated to an
    ``(argument, position)`` pair through ``_get_mapping_from_subranks``.
    For each group made of exactly two indices, both endpoints are recorded
    as pointing at each other; groups of any other size are ignored.

    Returns ``(args, links)`` where ``args`` is passed through unchanged and
    ``links[arg][pos]`` is the ``(other_arg, other_pos)`` it is linked to.
    """
    position_of = _get_mapping_from_subranks(subranks)
    dlinks = defaultdict(dict)
    for group in contraction_indices:
        endpoints = [position_of[flat] for flat in group]
        if len(endpoints) != 2:
            # Only two-index contractions are representable as a link.
            continue
        (left_arg, left_pos), (right_arg, right_pos) = endpoints
        dlinks[left_arg][left_pos] = (right_arg, right_pos)
        dlinks[right_arg][right_pos] = (left_arg, left_pos)
    return args, dict(dlinks)
|
|
|
|
|
|
|
|
def _sort_contraction_indices(pairing_indices):
    """
    Return the index groups in a canonical order.

    Each group is sorted internally and wrapped in a ``Tuple``; the groups
    themselves are then ordered by their smallest member. The result is a
    new list; the argument is not modified.
    """
    normalized = [Tuple(*sorted(group)) for group in pairing_indices]
    return sorted(normalized, key=min)
|
|
|
|
|
|
|
|
def _get_diagonal_indices(flattened_indices):
    """
    Split a flat index sequence into diagonalized and remaining indices.

    Returns ``(diagonal_indices, ret_indices)``:

    * ``diagonal_indices`` is a list of tuples, one per non-integer index
      occurring more than once, holding the positions of its occurrences.
      The tuples are ordered by the first occurrence of the index.
    * ``ret_indices`` is a tuple with every entry of ``flattened_indices``
      that is not such a repeated index (in original order), followed by
      the repeated indices themselves, once each.
    """
    positions_by_index = defaultdict(list)
    for position, index in enumerate(flattened_indices):
        # Integer entries (``int`` or SymPy ``Integer``) are never grouped.
        if not isinstance(index, (int, Integer)):
            positions_by_index[index].append(position)

    # Keep only the indices that appear at least twice.
    repeated = {
        index: positions
        for index, positions in positions_by_index.items()
        if len(positions) > 1
    }

    diag_indices = sorted(
        repeated, key=lambda index: flattened_indices.index(index)
    )
    diagonal_indices = [tuple(repeated[index]) for index in diag_indices]

    remaining = [index for index in flattened_indices if index not in repeated]
    ret_indices = tuple(remaining + diag_indices)
    return diagonal_indices, ret_indices
|
|
|
|
|
|
|
|
def _get_argindex(subindices, ind): |
|
|
for i, sind in enumerate(subindices): |
|
|
if ind == sind: |
|
|
return i |
|
|
if isinstance(sind, (set, frozenset)) and ind in sind: |
|
|
return i |
|
|
raise IndexError("%s not found in %s" % (ind, subindices)) |
|
|
|
|
|
|
|
|
def _apply_recursively_over_nested_lists(func, arr):
    """
    Apply ``func`` to every leaf of an arbitrarily nested sequence.

    Parameters
    ==========

    func : callable
        Called with each non-sequence element.
    arr : object
        Either a leaf value, or a ``tuple``, ``list`` or SymPy ``Tuple``
        whose items are processed recursively.

    Returns
    =======

    ``func(arr)`` for a leaf; otherwise a plain ``tuple`` with the same
    nesting structure as ``arr`` and ``func`` applied to every leaf.
    """
    # ``Tuple`` is matched here together with the built-in sequences, so every
    # container level (SymPy ``Tuple`` included) comes back as a plain
    # ``tuple``.  A separate ``Tuple`` branch placed after this check could
    # never be reached, so there is none.
    if isinstance(arr, (tuple, list, Tuple)):
        return tuple(
            _apply_recursively_over_nested_lists(func, item) for item in arr
        )
    return func(arr)
|
|
|
|
|
|
|
|
def _build_push_indices_up_func_transformation(flattened_contraction_indices): |
|
|
shifts = {0: 0} |
|
|
i = 0 |
|
|
cumulative = 0 |
|
|
while i < len(flattened_contraction_indices): |
|
|
j = 1 |
|
|
while i+j < len(flattened_contraction_indices): |
|
|
if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]: |
|
|
break |
|
|
j += 1 |
|
|
cumulative += j |
|
|
shifts[flattened_contraction_indices[i]] = cumulative |
|
|
i += j |
|
|
shift_keys = sorted(shifts.keys()) |
|
|
|
|
|
def func(idx): |
|
|
return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]] |
|
|
|
|
|
def transform(j): |
|
|
if j in flattened_contraction_indices: |
|
|
return None |
|
|
else: |
|
|
return j - func(j) |
|
|
|
|
|
return transform |
|
|
|
|
|
|
|
|
def _build_push_indices_down_func_transformation(flattened_contraction_indices): |
|
|
N = flattened_contraction_indices[-1]+2 |
|
|
|
|
|
shifts = [i for i in range(N) if i not in flattened_contraction_indices] |
|
|
|
|
|
def transform(j): |
|
|
if j < len(shifts): |
|
|
return shifts[j] |
|
|
else: |
|
|
return j + shifts[-1] - len(shifts) + 1 |
|
|
|
|
|
return transform |
|
|
|
|
|
|
|
|
def _apply_permutation_to_list(perm: Permutation, target_list: list):
    """
    Return a permuted copy of ``target_list``.

    The element found at position ``i`` of ``target_list`` is stored at
    position ``perm(i)`` of the result. The result has ``perm.size`` slots;
    any slot that receives no element is left as ``None``.
    """
    permuted = [None] * perm.size
    for source_pos, item in enumerate(target_list):
        permuted[perm(source_pos)] = item
    return permuted
|
|
|