import torch def fake_decorator(*a, **k): if len(a) == 1 and not k: return a[0] else: return fake_decorator def make_list(x, n=None, **kwargs): """Ensure that the input is a list (of a given size) Parameters ---------- x : list or tuple or scalar Input object n : int, optional Required length default : scalar, optional Value to right-pad with. Use last value of the input by default. Returns ------- x : list """ if not isinstance(x, (list, tuple)): x = [x] x = list(x) if n and len(x) < n: default = kwargs.get('default', x[-1]) x = x + [default] * max(0, n - len(x)) return x def expanded_shape(*shapes, side='left'): """Expand input shapes according to broadcasting rules Parameters ---------- *shapes : sequence[int] Input shapes side : {'left', 'right'}, default='left' Side to add singleton dimensions. Returns ------- shape : tuple[int] Output shape Raises ------ ValueError If shapes are not compatible for broadcast. """ def error(s0, s1): raise ValueError('Incompatible shapes for broadcasting: {} and {}.' .format(s0, s1)) # 1. nb dimensions nb_dim = 0 for shape in shapes: nb_dim = max(nb_dim, len(shape)) # 2. enumerate shape = [1] * nb_dim for i, shape1 in enumerate(shapes): pad_size = nb_dim - len(shape1) ones = [1] * pad_size if side == 'left': shape1 = [*ones, *shape1] else: shape1 = [*shape1, *ones] shape = [max(s0, s1) if s0 == 1 or s1 == 1 or s0 == s1 else error(s0, s1) for s0, s1 in zip(shape, shape1)] return tuple(shape) def matvec(mat, vec, out=None): """Matrix-vector product (supports broadcasting) Parameters ---------- mat : (..., M, N) tensor Input matrix. vec : (..., N) tensor Input vector. out : (..., M) tensor, optional Placeholder for the output tensor. Returns ------- mv : (..., M) tensor Matrix vector product of the inputs """ vec = vec[..., None] if out is not None: out = out[..., None] mv = torch.matmul(mat, vec, out=out) mv = mv[..., 0] if out is not None: out = out[..., 0] return mv def _compare_versions(version1, mode, version2): for v1, v2 in zip(version1, version2): if mode in ('gt', '>'): if v1 > v2: return True elif v1 < v2: return False elif mode in ('ge', '>='): if v1 > v2: return True elif v1 < v2: return False elif mode in ('lt', '<'): if v1 < v2: return True elif v1 > v2: return False elif mode in ('le', '<='): if v1 < v2: return True elif v1 > v2: return False if mode in ('gt', 'lt', '>', '<'): return False else: return True def torch_version(mode, version): """Check torch version Parameters ---------- mode : {'<', '<=', '>', '>='} version : tuple[int] Returns ------- True if "torch.version version" """ current_version, *cuda_variant = torch.__version__.split('+') major, minor, patch, *_ = current_version.split('.') # strip alpha tags for x in 'abcdefghijklmnopqrstuvwxy': if x in patch: patch = patch[:patch.index(x)] current_version = (int(major), int(minor), int(patch)) version = make_list(version) return _compare_versions(current_version, mode, version) if torch_version('>=', (1, 10)): meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij') meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy') else: meshgrid_ij = lambda *x: torch.meshgrid(*x) def meshgrid_xy(*x): grid = list(torch.meshgrid(*x)) if len(grid) > 1: grid[0] = grid[0].transpose(0, 1) grid[1] = grid[1].transpose(0, 1) return grid meshgrid = meshgrid_ij