File size: 750 Bytes
166ab04 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | import torch
from torch import einsum
from einops import rearrange#, einsum
import torch._dynamo
# Import-time, process-wide side effect: sets TorchDynamo's global
# ``suppress_errors`` config flag for everything running in this process.
# NOTE(review): per the torch._dynamo docs this makes torch.compile fall back
# to eager execution instead of raising on compilation errors, which can hide
# real failures — confirm this global setting is intended for every importer.
torch._dynamo.config.suppress_errors = True
_exists = lambda val: val is not None
_default = lambda val, d: val if _exists(val) else d
_einsum = lambda *args, **kwargs: einsum(*args, **kwargs).contiguous()
_rearrange = lambda *args, **kwargs: rearrange(*args, **kwargs).contiguous()
from torch.nn.modules.utils import _ntuple, _single, _pair, _triple, _quadruple
# Public aliases for torch.nn.modules.utils' private n-tuple converters
# (_pair / _triple / _quadruple), e.g. to_2tuple(3) -> (3, 3).
to_2tuple, to_3tuple, to_4tuple = _pair, _triple, _quadruple
def calc_rel_pos(n):
    """Return pairwise relative positions between all cells of an ``n x n`` grid.

    Grid cells are enumerated in row-major order, giving ``n * n`` coordinates
    ``(i, j)``. Entry ``[a, b]`` of the result is ``pos[b] - pos[a]`` shifted by
    ``n - 1``, so every component lies in ``[0, 2 * n - 2]``.

    Args:
        n: side length of the square grid (passed to ``torch.arange``).

    Returns:
        An int64 tensor of shape ``(n * n, n * n, 2)``.
    """
    # Explicit "ij" indexing matches meshgrid's legacy default and avoids the
    # warning torch emits when ``indexing`` is omitted.
    coords = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    # (n, n, 2) -> (n * n, 2): row-major flattening, row index = i * n + j.
    pos = torch.stack(coords, dim=-1).reshape(-1, 2)
    # Broadcast (1, N, 2) - (N, 1, 2) -> (N, N, 2): entry [a, b] = pos[b] - pos[a].
    rel_pos = pos[None, :] - pos[:, None]
    # Offsets range over [-(n - 1), n - 1]; shift them to be non-negative
    # (presumably so they can index a relative-position table).
    rel_pos += n - 1
    return rel_pos
|