# NOTE(review): the following three lines were non-Python extraction artifacts
# ("Spaces:" / "Sleeping" status text); commented out so the module parses.
# Spaces:
# Sleeping
# Sleeping
import math

import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
    r"""Inject information about the absolute position of the tokens in the
    sequence. The positional encodings have the same dimension as the
    embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}})

    where ``pos`` is the word position and ``i`` is the embed idx.

    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric progression of inverse wavelengths, one per sine column.
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column;
        # truncating div_term keeps the assignment shape-consistent (this is a
        # no-op when d_model is even, where the original code already worked).
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # (max_len, 1, d_model): the singleton dim broadcasts over the batch
        # dimension of a (seq_len, batch, d_model) input in forward().
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Add positional encodings to ``x``, then apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
def encoder_layer(in_c, out_c, k=3, s=2, p=1):
    """Conv2d -> BatchNorm2d -> ReLU block; stride ``s`` downsamples."""
    conv = nn.Conv2d(in_c, out_c, k, s, p)
    norm = nn.BatchNorm2d(out_c)
    act = nn.ReLU(True)
    return nn.Sequential(conv, norm, act)
def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
    """Upsample -> Conv2d -> BatchNorm2d -> ReLU block.

    ``align_corners`` is only meaningful for interpolating modes, so it stays
    ``None`` when ``mode`` is ``'nearest'`` and is ``True`` otherwise.
    """
    align_corners = True if mode != 'nearest' else None
    upsample = nn.Upsample(size=size, scale_factor=scale_factor,
                           mode=mode, align_corners=align_corners)
    return nn.Sequential(upsample,
                         nn.Conv2d(in_c, out_c, k, s, p),
                         nn.BatchNorm2d(out_c),
                         nn.ReLU(True))
class PositionAttention(nn.Module):
    """Attend over a 2-D feature map using learned positional queries.

    Keys come from a small U-Net (``k_encoder``/``k_decoder`` with skip
    connections) over the input feature map; queries are projected sinusoidal
    position encodings for ``max_length`` output steps; values are the raw
    input features.
    """

    def __init__(self, max_length, in_channels=512, num_channels=64,
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        # Downsampling path: halves W first, then halves H and W three times.
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2))
        )
        # Upsampling path; the last stage restores the (h, w) map exactly.
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
        )
        self.pos_encoder = PositionalEncoding(in_channels, dropout=0., max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Return ``(attn_vecs, attn_maps)`` for feature map ``x``.

        Shape:
            x: [N, E, H, W]
            attn_vecs: [N, max_length, E]
            attn_maps: [N, max_length, H, W]
        """
        batch, embed_dim, height, width = x.size()
        keys = x
        values = x  # (N, E, H, W)

        # Key vector: U-Net over the feature map, with skip connections
        # pairing decoder stage i against encoder stage (depth - 2 - i).
        skips = []
        for enc_stage in self.k_encoder:
            keys = enc_stage(keys)
            skips.append(keys)
        depth = len(self.k_decoder)
        for idx, dec_stage in enumerate(self.k_decoder[:-1]):
            keys = dec_stage(keys) + skips[depth - 2 - idx]
        keys = self.k_decoder[-1](keys)

        # Query vector: position encodings of an all-zero sequence.
        queries = self.pos_encoder(x.new_zeros((self.max_length, batch, embed_dim)))  # (T, N, E)
        queries = self.project(queries.permute(1, 0, 2))  # (N, T, E)

        # Scaled dot-product attention over the H*W key/value locations.
        logits = torch.bmm(queries, keys.flatten(start_dim=2))  # (N, T, H*W)
        attn = torch.softmax(logits / (embed_dim ** 0.5), dim=-1)
        values = values.permute(0, 2, 3, 1).reshape(batch, -1, embed_dim)  # (N, H*W, E)
        attn_vecs = torch.bmm(attn, values)  # (N, T, E)
        return attn_vecs, attn.view(batch, -1, height, width)