| import torch | |
| from torch import einsum | |
| from einops import rearrange#, einsum | |
| import torch._dynamo | |
# Suppress TorchDynamo (torch.compile) compilation errors so the affected code
# runs eagerly instead of raising. This mutates a module-level config object,
# so it affects every caller in the process, not just this module.
torch._dynamo.config.suppress_errors = True


def _exists(val):
    """Return True if ``val`` is not None."""
    return val is not None


def _default(val, d):
    """Return ``val`` unless it is None, in which case return ``d``.

    Only None triggers the fallback; falsy values such as ``0`` or ``""``
    are returned unchanged.
    """
    return val if _exists(val) else d


def _einsum(*args, **kwargs):
    """Call ``torch.einsum`` and return a contiguous copy/view of the result.

    NOTE(review): contiguity is presumably forced so callers can ``.view()``
    the output safely — confirm against callers.
    """
    return einsum(*args, **kwargs).contiguous()


def _rearrange(*args, **kwargs):
    """Call ``einops.rearrange`` and return a contiguous result."""
    return rearrange(*args, **kwargs).contiguous()
| from torch.nn.modules.utils import _ntuple, _single, _pair, _triple, _quadruple | |
# Public aliases for torch's private n-tuple converters in
# torch.nn.modules.utils (_pair / _triple / _quadruple); see that module's
# _ntuple for the exact conversion rules.
to_2tuple, to_3tuple, to_4tuple = _pair, _triple, _quadruple
def calc_rel_pos(n, device=None):
    """Return pairwise 2-D offsets between the cells of an ``n x n`` grid.

    Cells are flattened row-major: flat index ``k = i * n + j`` holds the
    coordinate ``pos[k] = (i, j)``. The result satisfies::

        rel_pos[a, b] == pos[b] - pos[a] + (n - 1)

    The ``+ (n - 1)`` shift moves each per-axis offset from ``[-(n-1), n-1]``
    into ``[0, 2n-2]`` so every entry is non-negative (presumably so it can be
    used as an index into a relative-position table — confirm against callers).

    Args:
        n: side length of the grid.
        device: optional device to build the tensor on; ``None`` uses torch's
            default device, matching the previous behavior.

    Returns:
        int64 tensor of shape ``(n * n, n * n, 2)``.
    """
    coords = torch.arange(n, device=device)
    # indexing='ij' is the historical default; passing it explicitly silences
    # torch's warning that the argument will become mandatory.
    grid = torch.meshgrid(coords, coords, indexing='ij')
    # (n, n, 2) -> (n*n, 2): row i*n + j is (i, j). Same layout as
    # rearrange('n i j -> (i j) n') on the stacked grids, and contiguous.
    pos = torch.stack(grid, dim=-1).reshape(-1, 2)
    # Broadcast (1, N, 2) - (N, 1, 2) -> (N, N, 2) with entry [a, b] = pos[b] - pos[a].
    rel_pos = pos[None, :] - pos[:, None]
    rel_pos += n - 1
    return rel_pos