| """Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
|
|
|
| MIT License
|
|
|
| Copyright (c) 2018 Alex Rogozhnikov
|
|
|
| Permission is hereby granted, free of charge, to any person obtaining a copy
|
| of this software and associated documentation files (the "Software"), to deal
|
| in the Software without restriction, including without limitation the rights
|
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| copies of the Software, and to permit persons to whom the Software is
|
| furnished to do so, subject to the following conditions:
|
|
|
| The above copyright notice and this permission notice shall be included in all
|
| copies or substantial portions of the Software.
|
|
|
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| SOFTWARE.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import keyword
|
| import warnings
|
| from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
|
|
|
|
|
# Single-character stand-in for "...": `ParsedExpression` replaces the three dots with this
# character before scanning the expression one character at a time, and stores it as the
# ellipsis identifier in `identifiers` / `composition`.
_ellipsis: str = "\u2026"
|
|
|
|
|
class AnonymousAxis:
    """Axis that carries a positive integer size but has no name of its own.

    `ParsedExpression` creates one of these for every numeric token it keeps in a pattern.

    Note: no `__eq__`/`__hash__` is defined, so two instances never compare equal, even when
    their sizes match; each instance is a distinct set member.
    """

    def __init__(self, value: str) -> None:
        """Store `value` converted to `int`.

        Args:
            value (str): textual size of the axis; anything `int()` accepts

        Raises:
            ValueError: if the converted size is smaller than 1 (or `int()` rejects the text)
        """
        self.value = int(value)
        if not self.value >= 1:
            message = f"Anonymous axis should have positive length, not {self.value}"
            raise ValueError(message)

    def __repr__(self) -> str:
        """Return the size followed by the literal suffix `-axis`."""
        return "{}-axis".format(self.value)
|
|
|
|
|
class ParsedExpression:
    """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)').

    Attributes set by `__init__`:
        has_ellipsis: True when the expression contains an ellipsis (written as `...`, or as the
            single ellipsis character used internally).
        has_ellipsis_parenthesized: None when there is no ellipsis; otherwise whether the (last
            seen) ellipsis sits inside parentheses.
        identifiers: set of every axis identifier seen: names, `AnonymousAxis` objects for
            numeric axes other than 1, and the ellipsis character.
        has_non_unitary_anonymous_axes: True when a numeric axis other than 1 was seen.
        composition: the expression's structure in order of appearance. Each top-level axis is a
            one-element list, each parenthesised group is a list of its members, a top-level `1`
            is an empty list, and a top-level ellipsis is the bare ellipsis character.
    """

    def __init__(
        self,
        expression: str,
        *,
        allow_underscore: bool = False,
        allow_duplicates: bool = False,
    ) -> None:
        """Parse the expression and store relevant metadata.

        Args:
            expression (str): the `einops`-pattern to parse
            allow_underscore (bool): whether a lone underscore (`_`) is accepted as an axis
                identifier; a lone underscore may then also be repeated
            allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression

        Raises:
            ValueError: on stray dots or more than one `...`, duplicate identifiers (unless
                allowed), invalid identifiers, unknown characters, nested or unbalanced
                parentheses, or a numeric axis smaller than 1
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
        # Set for numeric axes such as 2, 3, 4; axes of size 1 are special-cased below and
        # never produce an AnonymousAxis.
        self.has_non_unitary_anonymous_axes: bool = False
        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []

        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            # Collapse the three dots into one character so the character scanner below can
            # treat the ellipsis like any other single-token identifier.
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # None while scanning top-level axes; a list while inside "( ... )".
        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None

        def add_axis_name(x: str) -> None:
            """Register one complete token (name, number or ellipsis) at the current nesting level."""
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise ValueError(
                        f"Indexing expression contains duplicate dimension '{x}'"
                    )
            if x == _ellipsis:
                # The ellipsis character can also be typed directly in the expression, which
                # bypasses the "." handling above. Set the flag here so `has_ellipsis` always
                # agrees with `identifiers`/`composition`; otherwise callers that gate their
                # ellipsis checks on `has_ellipsis` would skip them.
                self.has_ellipsis = True
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
                return

            is_number = str.isdecimal(x)
            if is_number and int(x) == 1:
                # Anonymous axis of length 1: an empty group at top level, and nothing at all
                # inside parentheses (a factor of 1 does not change the group).
                if bracket_group is None:
                    self.composition.append([])
                return

            is_axis_name, reason = self.check_axis_name_return_reason(
                x, allow_underscore=allow_underscore
            )
            if not (is_number or is_axis_name):
                raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
            axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x
            self.identifiers.add(axis_name)
            if is_number:
                self.has_non_unitary_anonymous_axes = True
            if bracket_group is None:
                self.composition.append([axis_name])
            else:
                bracket_group.append(axis_name)

        # Character scanner: accumulate identifier characters, flush on space/parenthesis.
        current_identifier = None
        for char in expression:
            if char in "() ":
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
        # Flush the trailing token, if the expression did not end with a delimiter.
        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
    ) -> Tuple[bool, str]:
        """Check if the given axis name is valid, and return a message explaining why if not.

        Valid axis names are python identifiers that neither start nor end with an underscore.
        Python keywords and the name `axis` are accepted, but a warning is emitted for them.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether a lone underscore (`_`) counts as a valid name

        Returns:
            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
                (empty string when valid)
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        if name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        if keyword.iskeyword(name):
            warnings.warn(
                f"It is discouraged to use axes names that are keywords: {name}",
                RuntimeWarning,
            )
        if name == "axis":
            warnings.warn(
                "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                FutureWarning,
            )
        return True, ""

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """Check if the name is a valid axis name.

        Uses the default rules of `check_axis_name_return_reason`, so a lone underscore is
        rejected.

        Args:
            name (str): the axis name to check

        Returns:
            bool: whether the axis name is valid
        """
        is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
|
|
|
|
|
def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
    """Split an `einops`-style pattern at `->` and parse each half into a `ParsedExpression`.

    Args:
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions;
            here it is only checked for the reserved ellipsis key

    Returns:
        Tuple[ParsedExpression, ParsedExpression]: the parsed left-hand side and right-hand side

    Raises:
        ValueError: if the pattern does not contain exactly one `->`, if the ellipsis character
            is used as a key of `axes_lengths`, if only the right side has an ellipsis, if the
            left side's ellipsis is parenthesized, or if either side fails to parse
    """
    halves = pattern.split("->")
    if len(halves) != 2:
        raise ValueError("Pattern must contain a single '->' separator") from None

    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    lhs_text, rhs_text = halves
    left = ParsedExpression(lhs_text)
    right = ParsedExpression(rhs_text)

    if left.has_ellipsis:
        # An ellipsis on the input side must stand on its own, outside any parentheses.
        if left.has_ellipsis_parenthesized:
            raise ValueError(
                f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
            )
    elif right.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )

    return left, right
|
|
|
|
|
def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Run the checks that apply only to `rearrange` on an already-parsed pattern.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Raises:
        TypeError: if any value in `axes_lengths` is not exactly of type `int`
        ValueError: if either side has an anonymous axis of size other than 1, if an identifier
            appears on only one side, or if a key of `axes_lengths` is not an identifier of
            the left side
    """
    # Exact type match on purpose: subclasses of int (e.g. bool) are rejected as well.
    for value in axes_lengths.values():
        value_type = type(value)
        if value_type is int:
            continue
        raise TypeError(f"rearrange axis lengths must be integers, got: {value_type}")

    for side in (left, right):
        if side.has_non_unitary_anonymous_axes:
            raise ValueError("rearrange only supports unnamed axes of size 1")

    one_sided = left.identifiers.symmetric_difference(right.identifiers)
    if one_sided:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {one_sided}"
        )

    unknown_lengths = axes_lengths.keys() - left.identifiers
    if unknown_lengths:
        raise ValueError(
            f"Identifiers not found in rearrange expression: {unknown_lengths}"
        )
|
|
|
|
|
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Render a collection of first class dim names as a single comma-separated string.

    Plain strings are emitted unchanged. A nested collection is rendered recursively and
    wrapped in parentheses, with a trailing comma when it holds exactly one element, the
    same way a Python tuple is written.

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(("d0",))
        'd0'

        >>> comma_separate(("d0", "d1", "d2", "d3"))
        'd0, d1, d2, d3'

        >>> comma_separate([("d1", "d4")])
        '(d1, d4)'

        >>> comma_separate([("d0",), (), ("d1",), ("d2",), ("d3", "d4")])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """

    def render(entry: Union[str, Collection[str]]) -> str:
        # Strings pass through; anything else is treated as a nested group.
        if isinstance(entry, str):
            return entry
        inner = comma_separate(entry)
        trailing_comma = "," if len(entry) == 1 else ""
        return "(" + inner + trailing_comma + ")"

    return ", ".join(map(render, collection))
|
|
|