# MViT-TR / Model / attention.py
# (uploaded by serdaryildiz, commit 9b6af3b)
import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
    r"""Fixed sinusoidal positional encoding (Vaswani et al., "Attention Is
    All You Need").

    Adds position-dependent sine/cosine signals to a sequence of embeddings
    so the model can exploit token order. The table has the same width as the
    embeddings, so the two are simply summed:

    .. math::
        \text{PE}(pos, 2i)   = \sin(pos / 10000^{2i/d\_model})
        \text{PE}(pos, 2i+1) = \cos(pos / 10000^{2i/d\_model})

    Args:
        d_model: embedding dimensionality (required).
        dropout: dropout probability applied after the addition (default 0.1).
        max_len: maximum supported sequence length (default 5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the table once. Even columns get sine, odd columns get
        # cosine, with geometrically decreasing frequencies across the width.
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        angles = position * div_term                # (max_len, d_model // 2)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Shape (max_len, 1, d_model): broadcasts over the batch dimension.
        # A buffer (not a Parameter) so it moves with .to() but takes no grads.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        r"""Add positional codes to ``x`` and apply dropout.

        Args:
            x: sequence fed to the positional encoder (required).

        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        return self.dropout(x + self.pe[:x.size(0), :])
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    """Conv -> BatchNorm -> ReLU downsampling block for the key encoder.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        k, s, p: kernel size, stride, padding for the convolution.
    """
    modules = [
        nn.Conv2d(in_c, out_c, k, s, p),
        nn.BatchNorm2d(out_c),
        nn.ReLU(True),  # in-place: saves a feature-map-sized allocation
    ]
    return nn.Sequential(*modules)
def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
    """Upsample -> Conv -> BatchNorm -> ReLU block for the key decoder.

    Exactly one of ``scale_factor`` / ``size`` should be supplied (both are
    forwarded to ``nn.Upsample``). ``align_corners`` is only legal for
    interpolating modes, so it stays unset for 'nearest'.

    Args:
        in_c: input channel count.
        out_c: output channel count.
        k, s, p: kernel size, stride, padding for the convolution.
        mode: interpolation mode passed to ``nn.Upsample``.
        scale_factor: multiplicative spatial upscaling factor.
        size: explicit target (H, W) of the output.
    """
    align = True if mode != 'nearest' else None
    modules = [
        nn.Upsample(size=size, scale_factor=scale_factor,
                    mode=mode, align_corners=align),
        nn.Conv2d(in_c, out_c, k, s, p),
        nn.BatchNorm2d(out_c),
        nn.ReLU(True),
    ]
    return nn.Sequential(*modules)
class PositionAttention(nn.Module):
    """Attend over a 2-D feature map with one query per output position.

    Keys are produced by a small U-Net (``k_encoder`` down, ``k_decoder`` up
    with skip connections) over the input feature map; queries are a learned
    projection of fixed sinusoidal position codes, one per character slot.
    The values are the raw input features.

    Args:
        max_length: number of query positions T (e.g. max text length).
        in_channels: channel count E of the input feature map.
        num_channels: channel count inside the key U-Net.
        h, w: spatial size the last decoder stage resizes the keys back to.
        mode: upsampling mode for the decoder stages.
    """

    def __init__(self, max_length, in_channels=512, num_channels=64,
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        # Downsampling path: first stage halves only the width, the rest
        # halve both dimensions.
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
        )
        # Upsampling path; the final stage restores (h, w) and in_channels.
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode),
        )
        self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Run scaled dot-product attention over the spatial positions.

        Args:
            x: feature map of shape (N, E, H, W).

        Returns:
            attn_vecs: attended features, shape (N, T, E).
            attention map reshaped to (N, T, H, W).
        """
        N, E, H, W = x.size()
        k = x
        v = x

        # Key path: encode down, keeping each stage's output as a skip,
        # then decode up, adding the matching skip after each stage.
        skips = []
        for enc in self.k_encoder:
            k = enc(k)
            skips.append(k)
        for dec, skip in zip(self.k_decoder[:-1], reversed(skips[:-1])):
            k = dec(k) + skip
        k = self.k_decoder[-1](k)

        # Query path: project sinusoidal codes for T positions (content-free,
        # hence the zeros input — only the added position signal survives).
        q = self.pos_encoder(x.new_zeros((self.max_length, N, E)))  # (T, N, E)
        q = self.project(q.permute(1, 0, 2))                        # (N, T, E)

        # Scaled dot-product attention over the H*W spatial locations.
        attn_scores = torch.bmm(q, k.flatten(2, 3))                 # (N, T, H*W)
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)

        v = v.permute(0, 2, 3, 1).view(N, -1, E)                    # (N, H*W, E)
        attn_vecs = torch.bmm(attn_scores, v)                       # (N, T, E)
        return attn_vecs, attn_scores.view(N, -1, H, W)