# moebius/model_lib/nets/layers/utils.py
# From the "Moebius" Gradio Space (commit 166ab04, "Implement Moebius Gradio Space", committed by Mike0021).
import torch
from torch import einsum
from einops import rearrange#, einsum
import torch._dynamo
# Module-level side effect: importing this file mutates torch._dynamo's global
# config, so the setting applies to every user of torch._dynamo in the process.
# NOTE(review): per the torch._dynamo config, this presumably makes compilation
# failures fall back to eager execution instead of raising — confirm intended.
torch._dynamo.config.suppress_errors = True
def _exists(val):
    """Return True if ``val`` is not None (falsy values such as 0 count as existing)."""
    return val is not None


def _default(val, d):
    """Return ``val`` unless it is None, in which case return the fallback ``d``.

    Only ``None`` triggers the fallback; falsy-but-present values such as
    ``0`` or ``""`` are returned unchanged.
    """
    return val if _exists(val) else d


def _einsum(*args, **kwargs):
    """Call ``torch.einsum`` and return a contiguous copy/view of the result."""
    return einsum(*args, **kwargs).contiguous()


def _rearrange(*args, **kwargs):
    """Call ``einops.rearrange`` and return a contiguous copy/view of the result."""
    return rearrange(*args, **kwargs).contiguous()
from torch.nn.modules.utils import _ntuple, _single, _pair, _triple, _quadruple
# Short public aliases for torch's private n-tuple helpers, e.g.
# to_2tuple(3) -> (3, 3); a tuple argument is passed through as a tuple.
to_2tuple, to_3tuple, to_4tuple = _pair, _triple, _quadruple
def calc_rel_pos(n):
    """Return pairwise relative (row, col) offsets between all cells of a grid.

    Cells are enumerated in row-major order, so cell ``k`` sits at
    ``(k // w, k % w)``. For every ordered pair of cells ``(a, b)`` the result
    holds ``coord(b) - coord(a)``, shifted by ``(h - 1, w - 1)`` so that every
    entry is non-negative: row offsets lie in ``[0, 2h - 2]`` and column
    offsets in ``[0, 2w - 2]``.

    Args:
        n: Side length of a square ``n x n`` grid, or an ``(h, w)`` pair
            (tuple/list) for a rectangular grid.

    Returns:
        An int64 tensor of shape ``(h * w, h * w, 2)``.
    """
    h, w = n if isinstance(n, (tuple, list)) else (n, n)
    # indexing="ij" is the historical default; passing it explicitly avoids the
    # torch>=1.10 warning without changing the result.
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    # (h, w, 2) -> (h*w, 2): one (row, col) coordinate per cell, row-major.
    pos = torch.stack((rows, cols), dim=-1).reshape(-1, 2)
    # Broadcast (1, N, 2) - (N, 1, 2) -> (N, N, 2); entry [a, b] = pos[b] - pos[a].
    rel_pos = pos[None, :] - pos[:, None]
    # Shift each axis so the most negative offset maps to 0.
    rel_pos[..., 0] += h - 1
    rel_pos[..., 1] += w - 1
    return rel_pos