| | """ |
| | @Date: 2021/09/01 |
| | @description: |
| | """ |
| | import warnings |
| | import math |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from torch import nn, einsum |
| | from einops import rearrange |
| |
|
| |
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values are drawn by inverse-CDF sampling: a uniform sample between the CDF
    values of the two bounds is mapped back through the inverse normal CDF,
    then scaled, shifted and finally clamped to ``[a, b]``.

    :param tensor: tensor to overwrite (modified in place, gradients disabled)
    :param mean: mean of the underlying normal distribution
    :param std: standard deviation of the underlying normal distribution
    :param a: lower truncation bound
    :param b: upper truncation bound
    :return: the same ``tensor`` object
    """
    sqrt_two = math.sqrt(2.)

    def standard_normal_cdf(value):
        # CDF of the standard normal distribution, expressed with erf.
        return (1. + math.erf(value / sqrt_two)) / 2.

    lower_limit = a - 2 * std
    upper_limit = b + 2 * std
    if mean < lower_limit or mean > upper_limit:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values at the truncation bounds.
        cdf_low = norm_low = standard_normal_cdf((a - mean) / std)
        cdf_high = standard_normal_cdf((b - mean) / std)

        # Sample uniformly in [2 * cdf_low - 1, 2 * cdf_high - 1], which is the
        # domain erfinv needs to map back onto the truncated interval.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)

        # Inverse error function -> truncated standard normal (up to sqrt(2)),
        # then rescale to the requested std and shift to the requested mean.
        tensor.erfinv_().mul_(std * sqrt_two).add_(mean)

        # Clamp so that rounding cannot push a value outside [a, b].
        tensor.clamp_(min=a, max=b)

    return tensor
| |
|
| |
|
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last ``dim`` features before calling ``fn``."""

    def __init__(self, dim, fn):
        """
        :param dim: size of the feature dimension that is normalised
        :param fn: callable (usually a module) invoked on the normalised input
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, with any extra keyword arguments, to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)
| |
|
| |
|
| | |
class GELU(nn.Module):
    """Module wrapper around the functional GELU activation."""

    def forward(self, input):
        """Return ``F.gelu`` applied element-wise to ``input``."""
        activated = F.gelu(input)
        return activated
| |
|
| |
|
class Attend(nn.Module):
    """Softmax along a configurable dimension, keeping the input dtype."""

    def __init__(self, dim=None):
        """
        :param dim: dimension along which softmax is computed
        """
        super().__init__()
        self.dim = dim

    def forward(self, input):
        """Return the softmax of ``input`` over ``self.dim`` in ``input``'s own dtype."""
        axis = self.dim
        result_dtype = input.dtype
        return F.softmax(input, dim=axis, dtype=result_dtype)
| |
|
| |
|
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        """
        :param dim: input and output feature size
        :param hidden_dim: feature size between the two linear layers
        :param dropout: dropout probability used after the activation and after the output layer
        """
        super().__init__()
        # Layers are created in the same order in which they are applied.
        expand = nn.Linear(dim, hidden_dim)
        activation = GELU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, dim)
        output_drop = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, hidden_drop, project, output_drop)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension of ``x`` must be ``dim``."""
        return self.net(x)
| |
|
| |
|
class RelativePosition(nn.Module):
    """Relative position bias that is added to a per-head attention map.

    A table with one row per relative-distance bucket and one column per head
    is looked up with a ``patch_num x patch_num`` matrix of bucket indices.

    Supported ``rpe`` modes:

    * ``'lr_parameter'``: learnable table with ``2 * patch_num - 1`` rows, one
      per signed distance ``j - i``.
    * ``'lr_parameter_mirror'``: learnable table with ``patch_num // 2 + 1``
      rows; distances are made unsigned and distances larger than
      ``patch_num // 2`` are folded back to ``patch_num - distance``.
    * ``'lr_parameter_half'``: learnable table with ``patch_num`` rows; signed
      distances are wrapped by ``patch_num`` before the lookup.
    * ``'fix_angle'``: fixed (buffer) table with the same bucketing as
      ``'lr_parameter_mirror'``; values fall linearly from 1 to ``1 / count``.
    """

    _LEARNED_MODES = ('lr_parameter', 'lr_parameter_mirror', 'lr_parameter_half')
    _MODES = _LEARNED_MODES + ('fix_angle',)

    def __init__(self, heads, patch_num=None, rpe=None):
        """
        :param heads: number of attention heads (columns of the table)
        :param patch_num: sequence length the bias is built for
        :param rpe: relative position embedding mode, see the class docstring
        """
        super().__init__()
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num

        if rpe in self._LEARNED_MODES:
            if rpe == 'lr_parameter':
                count = patch_num * 2 - 1
            elif rpe == 'lr_parameter_mirror':
                count = patch_num // 2 + 1
            else:  # 'lr_parameter_half'
                count = patch_num
            self.rpe_table = nn.Parameter(torch.Tensor(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'fix_angle':
            count = patch_num // 2 + 1
            # Explicit float dtype so the division is a true division on every
            # torch version instead of depending on integer `/` semantics.
            weights = torch.arange(count, 0, -1, dtype=torch.float) / count
            rpe_table = weights[..., None].repeat(1, heads)
            self.register_buffer('rpe_table', rpe_table)

    def get_relative_pos_embed(self):
        """Build the bias tensor of shape ``(1, heads, patch_num, patch_num)``.

        :return: the table rows selected by the relative-distance buckets
        :raises ValueError: if ``rpe`` is not one of the supported modes
        """
        if self.rpe not in self._MODES:
            # Previously this fell through, returned None and made forward()
            # fail with an unrelated "tensor + NoneType" TypeError.
            raise ValueError(
                'unsupported rpe mode {!r}; expected one of {}'.format(self.rpe, self._MODES))

        patch_num = self.patch_num
        half = patch_num // 2
        # Build the indices on the table's device to avoid a host -> device
        # copy on every forward pass.
        positions = torch.arange(patch_num, device=self.rpe_table.device)
        distance = positions[None, :] - positions[:, None]  # entry (i, j) is j - i

        if self.rpe == 'lr_parameter':
            # Shift the signed distances into [0, 2 * patch_num - 2].
            index = distance + (patch_num - 1)
        elif self.rpe == 'lr_parameter_half':
            index = torch.where(distance > half, distance - patch_num, distance)
            # NOTE: unary minus binds tighter than //, i.e. (-patch_num) // 2.
            lower = -patch_num // 2 + 1
            index = torch.where(index < lower, index + patch_num, index)
            # For odd patch_num this can produce -1, which indexes the last
            # table row; that behaviour is kept as it was.
            index = index + (half - 1)
        else:  # 'lr_parameter_mirror' or 'fix_angle'
            index = distance.abs()
            index = torch.where(index > half, patch_num - index, index)

        # (patch_num, patch_num, heads) -> (1, heads, patch_num, patch_num)
        return self.rpe_table[index].permute(2, 0, 1)[None]

    def forward(self, attn):
        """Return ``attn`` plus the position bias.

        ``attn`` must broadcast against ``(1, heads, patch_num, patch_num)``.
        """
        return attn + self.get_relative_pos_embed()
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with an optional relative position bias."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        """
        :param dim: feature size of the input (and of the output)
        :param heads: number of attention heads
        :param dim_head: feature size of each head
        :param dropout: dropout probability applied after the output projection
        :param patch_num: sequence length, needed only for the relative position bias
        :param rpe: relative position embedding mode; no bias is used when this
            or ``patch_num`` is None
        :param rpe_pos: 0 adds the bias to the scaled scores before softmax,
            1 adds it to the attention weights after softmax
        """
        super().__init__()

        if patch_num is not None and rpe is not None:
            self.relative_pos_embed = RelativePosition(heads, patch_num, rpe)
        else:
            self.relative_pos_embed = None

        inner_dim = dim_head * heads
        # A single head that already has the model width needs no projection.
        project_out = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos

        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over the second dimension of ``x``, shaped ``(batch, tokens, dim)``."""
        # Unpacking doubles as a check that x has exactly three dimensions.
        _batch, _tokens, _channels = x.shape
        head_count = self.heads

        # One projection produces q, k and v; each is split into heads.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=head_count)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        bias = self.relative_pos_embed
        if bias is not None and self.rpe_pos == 0:
            dots = bias(dots)

        attn = self.attend(dots)

        if bias is not None and self.rpe_pos == 1:
            attn = bias(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged)
| |
|
| |
|
class AbsolutePosition(nn.Module):
    """Absolute position embedding that is added to the input sequence.

    Supported ``ape`` modes:

    * ``'lr_parameter'``: learnable ``(1, patch_num, dim)`` parameter
      initialised with a truncated normal (std 0.02).
    * ``'fix_angle'``: fixed ``(1, patch_num, dim)`` embedding whose value at
      position ``p`` is ``sin(2 * pi * p / patch_num)``, repeated over ``dim``.
    """

    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        """
        :param dim: feature size of the embedding
        :param dropout: accepted for interface compatibility; not used here
        :param patch_num: sequence length the embedding is built for
        :param ape: absolute position embedding mode, see the class docstring
        """
        super().__init__()
        self.ape = ape

        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            embed = torch.sin(angle)[..., None].repeat(1, dim)[None]
            # Register as a buffer so the embedding follows the module through
            # .to() / .cuda() / .double(); a plain tensor attribute would stay
            # on the CPU. persistent=False keeps it out of state_dict, so
            # existing checkpoints still load.
            self.register_buffer('absolute_pos_embed', embed, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the embedding; ``x`` must broadcast against ``(1, patch_num, dim)``."""
        return x + self.absolute_pos_embed
| |
|
| |
|
class WinAttention(nn.Module):
    """Self-attention restricted to fixed-size windows along the sequence.

    The sequence can be rolled by ``shift`` positions before it is split into
    windows; the roll is undone after attention.
    """

    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        """
        :param dim: feature size of the input
        :param win_size: number of consecutive tokens in one window
        :param shift: number of positions the sequence is rolled before windowing
        :param heads: number of attention heads
        :param dim_head: feature size of each head
        :param dropout: dropout probability forwarded to the attention module
        :param rpe: any non-None value enables the 'lr_parameter' relative
            position bias inside each window
        :param rpe_pos: forwarded to the attention module
        """
        super().__init__()

        self.win_size = win_size
        self.shift = shift

        # Inside a window only the 'lr_parameter' mode is used, whichever mode was requested.
        window_rpe = 'lr_parameter' if rpe is not None else None
        self.attend = Attention(
            dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            patch_num=win_size,
            rpe=window_rpe,
            rpe_pos=rpe_pos,
        )

    def forward(self, x):
        """Apply windowed attention to ``x``, shaped ``(batch, tokens, dim)``."""
        batch = x.shape[0]
        shifted = self.shift != 0

        if shifted:
            x = torch.roll(x, shifts=self.shift, dims=-2)

        # Fold the windows into the batch dimension so that each window attends only to itself.
        windows = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)
        attended = self.attend(windows)
        out = rearrange(attended, '(b m) w d -> b (m w) d ', b=batch)

        if shifted:
            # Undo the roll so the tokens return to their original positions.
            out = torch.roll(out, shifts=-self.shift, dims=-2)

        return out
| |
|
| |
|
class Conv(nn.Module):
    """Kernel-size-3 1-D convolution along the sequence with wrap-around padding."""

    def __init__(self, dim, dropout=0.):
        """
        :param dim: number of input and output channels
        :param dropout: dropout probability applied after the convolution
        """
        super().__init__()
        self.dim = dim
        conv = nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0)
        self.net = nn.Sequential(conv, nn.Dropout(dropout))

    def forward(self, x):
        """Convolve ``x``, shaped ``(batch, tokens, dim)``; the output has the same shape."""
        channels_first = x.transpose(1, 2)
        # Pad one element on each side by wrapping around: the last element is
        # prepended and the first is appended, so the un-padded kernel-3
        # convolution keeps the sequence length.
        wrapped_head = channels_first[..., -1:]
        wrapped_tail = channels_first[..., :1]
        padded = torch.cat([wrapped_head, channels_first, wrapped_tail], dim=-1)
        return self.net(padded).transpose(1, 2)
| |
|