|
|
"""A functionally equivalent parser of the numpy.einsum input parser.""" |
|
|
|
|
|
import itertools |
|
|
from typing import Any, Dict, Iterator, List, Sequence, Tuple |
|
|
|
|
|
from opt_einsum.typing import ArrayType, TensorShapeType |
|
|
|
|
|
# Public API of this module: the names exported by a star-import.
# NOTE(review): ``convert_subscripts`` and ``convert_interleaved_input`` are
# defined in this file but are not listed here — confirm that is intentional.
__all__ = [
    "is_valid_einsum_char",
    "has_valid_einsum_chars_only",
    "get_symbol",
    "get_shape",
    "gen_unused_symbols",
    "convert_to_valid_einsum_chars",
    "alpha_canonicalize",
    "find_output_str",
    "find_output_shape",
    "possibly_convert_to_numpy",
    "parse_einsum_input",
]
|
|
|
|
|
# The 52 ASCII letters usable as index labels: lowercase first, then uppercase.
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


def is_valid_einsum_char(x: str) -> bool:
    """Tell whether ``x`` may appear in a numpy einsum expression.

    A character is accepted when it is one of the 52 ASCII letters or one of
    the punctuation characters ``,``, ``-``, ``>`` and ``.``.

    **Examples:**

    ```python
    is_valid_einsum_char("a")
    #> True

    is_valid_einsum_char("Ǵ")
    #> False
    ```
    """
    if x in _einsum_symbols_base:
        return True
    # Not a letter: only the structural punctuation remains acceptable.
    return x in ",->."
|
|
|
|
|
|
|
|
def has_valid_einsum_chars_only(einsum_str: str) -> bool:
    """Tell whether every character of ``einsum_str`` is valid for numpy einsum.

    Each character is checked with ``is_valid_einsum_char``; an empty string
    is considered valid.

    **Examples:**

    ```python
    has_valid_einsum_chars_only("abAZ")
    #> True

    has_valid_einsum_chars_only("Över")
    #> False
    ```
    """
    for char in einsum_str:
        if not is_valid_einsum_char(char):
            # One offending character is enough to reject the whole string.
            return False
    return True
|
|
|
|
|
|
|
|
def get_symbol(i: int) -> str:
    """Return the index symbol associated with the integer ``i``.

    The first 52 integers map onto the ASCII letters (``a-z`` then ``A-Z``).
    Larger integers map onto unicode characters starting at ``chr(192)``;
    integers from 55296 upwards are shifted by a further amount so that they
    land after the surrogate block.

    **Examples:**

    ```python
    get_symbol(2)
    #> 'c'

    get_symbol(200)
    #> 'Ŕ'

    get_symbol(20000)
    #> '京'
    ```
    """
    if i < 52:
        return _einsum_symbols_base[i]

    # Indices at or beyond 55296 jump past chr(55296)-chr(57343) (surrogates).
    # NOTE(review): for 55156 <= i < 55296 the smaller offset still yields
    # chr(55296)..chr(55435), i.e. code points inside the surrogate block —
    # confirm whether that is intended before relying on those symbols.
    offset = 2048 if i >= 55296 else 140
    return chr(i + offset)
|
|
|
|
|
|
|
|
def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
    """Lazily yield ``n`` symbols that do not occur in ``used``.

    Candidates are taken from ``get_symbol(0)``, ``get_symbol(1)``, ... in
    order, skipping any that are already present in ``used``.

    **Examples:**
    ```python
    list(oe.parser.gen_unused_symbols("abd", 2))
    #> ['c', 'e']
    ```
    """
    if n <= 0:
        # Nothing requested: produce an empty generator.
        return

    produced = 0
    for index in itertools.count():
        candidate = get_symbol(index)
        if candidate in used:
            continue
        yield candidate
        produced += 1
        if produced >= n:
            return
|
|
|
|
|
|
|
|
def convert_to_valid_einsum_chars(einsum_str: str) -> str:
    """Convert the str ``einsum_str`` to contain only the alphabetic characters
    valid for numpy einsum. If there are too many symbols, let the backend
    throw an error.

    The distinct index symbols are sorted and renamed to ``get_symbol(0)``,
    ``get_symbol(1)``, ... in that order. The structural characters ``,``,
    ``-``, ``>`` and ``.`` (ellipsis dots) are left untouched.

    Examples:
    --------
    >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
    'cbdda'

    >>> oe.parser.convert_to_valid_einsum_chars("...Ĥ,...Ĥ->...")
    '...a,...a->...'
    """
    # "." must be excluded together with ",->": it is punctuation (ellipsis),
    # not an index, and renaming it would turn "..." into a repeated letter.
    symbols = sorted(set(einsum_str) - set(".,->"))
    replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
    # Characters without an entry (the punctuation) pass through unchanged.
    return "".join(replacer.get(x, x) for x in einsum_str)
|
|
|
|
|
|
|
|
def alpha_canonicalize(equation: str) -> str:
    """Rename the indices of ``equation`` in order of first appearance.

    The first distinct index becomes ``get_symbol(0)``, the second
    ``get_symbol(1)`` and so on; the characters ``.``, ``,``, ``-`` and ``>``
    are copied through unchanged.

    Examples:
    --------
    >>> oe.parser.alpha_canonicalize("dcba")
    'abcd'

    >>> oe.parser.alpha_canonicalize("Ĥěļļö")
    'abccd'
    """
    mapping: Dict[str, str] = {}
    pieces = []
    for char in equation:
        if char in ".,->":
            # Structural punctuation is never renamed.
            pieces.append(char)
            continue
        if char not in mapping:
            # New index: give it the next unused canonical symbol.
            mapping[char] = get_symbol(len(mapping))
        pieces.append(mapping[char])
    return "".join(pieces)
|
|
|
|
|
|
|
|
def find_output_str(subscripts: str) -> str:
    """Derive the implicit output indices for the input ``subscripts``.

    Following the usual einstein summation convention, every index that
    occurs exactly once across all comma-separated terms is kept, and the
    kept indices are returned in sorted order.

    Examples:
    --------
    >>> oe.parser.find_output_str("ab,bc")
    'ac'

    >>> oe.parser.find_output_str("a,b")
    'ab'

    >>> oe.parser.find_output_str("a,a,b,b")
    ''
    """
    # Count every character except the term separator.
    occurrences = {}
    for char in subscripts:
        if char != ",":
            occurrences[char] = occurrences.get(char, 0) + 1

    kept = [char for char, total in occurrences.items() if total == 1]
    kept.sort()
    return "".join(kept)
|
|
|
|
|
|
|
|
def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
    """Compute the shape of the output term from the input terms and shapes.

    For each index in ``output`` the size is the largest size found for that
    index among the inputs that contain it, which accounts for broadcasting.

    Examples:
    --------
    >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
    (2, 4)

    # Broadcasting is accounted for
    >>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
    (4,)
    """
    result = []
    for index in output:
        sizes = []
        for term, shape in zip(inputs, shapes):
            # Only the first occurrence of the index in a term is consulted.
            position = term.find(index)
            if position >= 0:
                sizes.append(shape[position])
        result.append(max(sizes))
    return tuple(result)
|
|
|
|
|
|
|
|
# Scalar-like python types; instances are treated as having shape ``()``.
_BaseTypes = (bool, int, float, complex, str, bytes)


def get_shape(x: Any) -> TensorShapeType:
    """Get the shape of the array-like object `x`. If `x` is not array-like, raise an error.

    Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes.
    BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`.

    For nested sequences the shape is read off the first element at every
    level; an empty sequence contributes a dimension of size 0 and ends the
    descent (e.g. ``[]`` gives ``(0,)`` and ``[[], []]`` gives ``(2, 0)``).

    Raises:
        ValueError: If `x` has no `shape` attribute and is neither a BaseType nor a sequence.
    """
    if hasattr(x, "shape"):
        return x.shape
    elif isinstance(x, _BaseTypes):
        return ()
    elif isinstance(x, Sequence):
        shape = []
        # ``str``/``bytes`` are sequences too, hence the explicit BaseTypes exclusion.
        while isinstance(x, Sequence) and not isinstance(x, _BaseTypes):
            size = len(x)
            shape.append(size)
            if size == 0:
                # There is no first element to descend into.
                break
            x = x[0]
        return tuple(shape)
    else:
        raise ValueError(f"Cannot determine the shape of {x}, can only determine the shape of array-like objects.")
|
|
|
|
|
|
|
|
def possibly_convert_to_numpy(x: Any) -> Any:
    """Return ``x`` unchanged if it has a 'shape', otherwise convert it to an ndarray.

    The conversion uses ``numpy.asanyarray``; numpy is imported only when a
    conversion is actually needed.

    Examples:
    --------
    >>> oe.parser.possibly_convert_to_numpy(5)
    array(5)

    >>> oe.parser.possibly_convert_to_numpy([5, 3])
    array([5, 3])

    >>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
    array([5, 3])

    # Any class with a shape is passed through
    >>> class Shape:
    ...     def __init__(self, shape):
    ...         self.shape = shape
    ...

    >>> myshape = Shape((5, 5))
    >>> oe.parser.possibly_convert_to_numpy(myshape)
    <__main__.Shape object at 0x10f850710>
    """
    if hasattr(x, "shape"):
        # Already array-like enough: hand it back untouched.
        return x

    try:
        import numpy as np
    except ModuleNotFoundError:
        raise ModuleNotFoundError(
            "numpy is required to convert non-array objects to arrays. This function will be deprecated in the future."
        )

    return np.asanyarray(x)
|
|
|
|
|
|
|
|
def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
    """Translate a list of user subscripts into a subscript string via `symbol_map`.

    ``Ellipsis`` entries become ``"..."``; every other entry is looked up in
    `symbol_map`.

    Examples:
    --------
    >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
    'ab'
    >>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
    '...a'
    """
    return "".join("..." if item is Ellipsis else symbol_map[item] for item in old_sub)
|
|
|
|
|
|
|
|
def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]:
    """Turn 'interleaved' einsum input into a subscript string plus operands.

    The input alternates operand, subscript list, operand, subscript list,
    ...; when the number of items is odd, the last item is the output
    subscript list.
    """
    items = list(operands)
    pair_count = len(items) // 2

    # Even positions hold operands, odd positions their subscript lists.
    operand_list = items[0 : 2 * pair_count : 2]
    subscript_list = items[1 : 2 * pair_count : 2]

    # An unpaired trailing item, if any, describes the output.
    output_list = items[-1] if len(items) % 2 else None

    # Map every user symbol onto a single-character symbol from `get_symbol`,
    # assigned in the sorted order of the user symbols.
    try:
        symbol_set = {symbol for sub in subscript_list for symbol in sub}

        # Ellipsis is handled separately and is dropped before sorting.
        symbol_set.discard(Ellipsis)

        symbol_map = {}
        for idx, symbol in enumerate(sorted(symbol_set)):
            symbol_map[symbol] = get_symbol(idx)

    except TypeError:
        raise TypeError(
            "For this input type lists must contain either Ellipsis "
            "or hashable and comparable object (e.g. int, str)."
        )

    converted_terms = [convert_subscripts(sub, symbol_map) for sub in subscript_list]
    subscripts = ",".join(converted_terms)
    if output_list is not None:
        subscripts = subscripts + "->" + convert_subscripts(output_list, symbol_map)

    return subscripts, tuple(operand_list)
|
|
|
|
|
|
|
|
def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
    """A reproduction of einsum c side einsum parsing in python.

    Parameters:
        operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. Either a
            subscript string followed by the operands, or the interleaved form accepted by
            ``convert_interleaved_input``. The only supported keyword argument is:
        shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or
            array shapes have been supplied.

    Returns:
        input_strings: Parsed input strings
        output_string: Parsed output string
        operands: The operands to use in the numpy contraction

    Raises:
        ValueError: If no operands are given, if ``shapes`` is True but an operand has a
            ``shape`` attribute, if ``->`` is malformed, if the number of subscript terms
            differs from the number of operands, if an ellipsis is malformed or does not fit
            the operand, or if an output index is repeated or absent from the inputs.

    Examples:
        The operand list is simplified to reduce printing:

        ```python
        >>> a = np.random.rand(4, 4)
        >>> b = np.random.rand(4, 4, 4)
        >>> parse_einsum_input(('...a,...a->...', a, b))
        ('da,cda', 'cd', (a, b))

        >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
        ('da,cda', 'cd', (a, b))
        ```
    """
    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        if shapes:
            if any(hasattr(o, "shape") for o in operands[1:]):
                raise ValueError(
                    "shapes is set to True but given at least one operand looks like an array"
                    " (at least one operand has a shape attribute). "
                )
        operands = operands[1:]
    else:
        subscripts, operands = convert_interleaved_input(operands)

    if shapes:
        operand_shapes = operands
    else:
        operand_shapes = [get_shape(o) for o in operands]

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Make sure number operands is equivalent to the number of terms. This is
    # checked before the ellipsis pass because that pass indexes
    # ``operand_shapes`` by term position; the number of terms is not changed
    # by anything below, so no later re-check is needed.
    num_terms = len(subscripts.split("->")[0].split(","))
    if num_terms != len(operands):
        raise ValueError(
            f"Number of einsum subscripts, {num_terms}, must be equal to the "
            f"number of operands, {len(operands)}."
        )

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # One fresh symbol per possible broadcast dimension.
        ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
        longest = 0

        # Do we have an output to account for?
        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(",")
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values: a scalar operand has no
                # dimensions for the ellipsis to absorb. The length test (rather
                # than ``== ()``) also recognises an empty shape given as a list.
                term_shape = operand_shapes[num]
                if len(term_shape) == 0:
                    ellipse_count = 0
                else:
                    ellipse_count = len(term_shape) - (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace("...", "")
                else:
                    # Take symbols from the end so that terms with ellipses of
                    # different lengths line up on their trailing dimensions.
                    split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:])

        subscripts = ",".join(split_subscripts)

        # Figure out output ellipses
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: the broadcast indices come
            # first, followed by the remaining unsummed indices in sorted order.
            output_subscript = find_output_str(subscripts)
            normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts, output_subscript = subscripts, find_output_str(subscripts)

    # Make sure output subscripts are unique and in the input
    for char in output_subscript:
        if output_subscript.count(char) != 1:
            raise ValueError(f"Output character '{char}' appeared more than once in the output.")
        if char not in input_subscripts:
            raise ValueError(f"Output character '{char}' did not appear in the input")

    return input_subscripts, output_subscript, operands
|
|
|