import math
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


class Adaptive2DPositionalEncoding(nn.Module):
    """Implement Adaptive 2D positional encoder for SATRN, see `SATRN
    <https://arxiv.org/abs/1910.04396>`_.

    Modified from https://github.com/Media-Smart/vedastr
    Licensed under the Apache License, Version 2.0 (the "License");

    Args:
        d_hid (int): Dimensions of hidden layer.
        n_height (int): Max height of the 2D feature output.
        n_width (int): Max width of the 2D feature output.
        dropout (float): Dropout rate.
    """

    def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
        super().__init__()

        h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
        h_position_encoder = h_position_encoder.transpose(0, 1)
        h_position_encoder = h_position_encoder.view(1, d_hid, n_height, 1)

        w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
        w_position_encoder = w_position_encoder.transpose(0, 1)
        w_position_encoder = w_position_encoder.view(1, d_hid, 1, n_width)

        self.register_buffer('h_position_encoder', h_position_encoder)
        self.register_buffer('w_position_encoder', w_position_encoder)

        self.h_scale = self.scale_factor_generate(d_hid)
        self.w_scale = self.scale_factor_generate(d_hid)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Sinusoid position encoding table."""
        denominator = torch.Tensor([
            1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid)
            for hid_j in range(d_hid)
        ])
        denominator = denominator.view(1, -1)
        pos_tensor = torch.arange(n_position).unsqueeze(-1).float()
        sinusoid_table = pos_tensor * denominator
        sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])
        sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])

        return sinusoid_table

    def scale_factor_generate(self, d_hid):
        scale_factor = nn.Sequential(
            nn.Conv2d(d_hid, d_hid, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(d_hid, d_hid, kernel_size=1),
            nn.Sigmoid())

        return scale_factor

    def init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # calculate_gain expects the lowercase nonlinearity name.
                nn.init.xavier_uniform_(
                    m.weight, gain=nn.init.calculate_gain('relu'))

    def forward(self, x):
        b, c, h, w = x.size()

        avg_pool = self.pool(x)

        h_pos_encoding = \
            self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
        w_pos_encoding = \
            self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]

        out = x + h_pos_encoding + w_pos_encoding

        out = self.dropout(out)

        return out
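# Usage sketch (illustrative; this demo helper is not part of the original
# module): the encoding is added to a (B, C, H, W) feature map, with per-channel
# scales inferred from the input via the pooled h/w scale branches. The shapes
# below are assumptions chosen to fit the defaults above (d_hid=512, feature
# maps at most 100x100).
def _demo_adaptive_2d_positional_encoding():
    enc = Adaptive2DPositionalEncoding(d_hid=512, n_height=100, n_width=100)
    feat = torch.rand(2, 512, 8, 25)  # (B, C, H, W), e.g. from a recognition backbone
    out = enc(feat)                   # same shape, now position-aware
    assert out.shape == (2, 512, 8, 25)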
class PositionalEncoding2D(nn.Module):
    """2-D positional encodings for the feature maps produced by the encoder.

    Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.

    Reference:
    https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2021-labs/blob/main/lab9/text_recognizer/models/transformer_util.py
    """

    def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
        super().__init__()
        self.d_model = d_model
        assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
        pe = self.make_pe(d_model, max_h, max_w)  # (d_model, max_h, max_w)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(d_model: int, max_h: int, max_w: int) -> Tensor:
        """Compute positional encoding."""
        pe_h = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_h)  # (max_h, 1, d_model // 2)
        pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w)  # (d_model // 2, max_h, max_w)

        pe_w = PositionalEncoding1D.make_pe(d_model=d_model // 2, max_len=max_w)  # (max_w, 1, d_model // 2)
        pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1)  # (d_model // 2, max_h, max_w)

        pe = torch.cat([pe_h, pe_w], dim=0)  # (d_model, max_h, max_w)
        return pe

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: (B, d_model, H, W)

        Returns:
            (B, d_model, H, W)
        """
        assert x.shape[1] == self.pe.shape[0]  # type: ignore
        x = x + self.pe[:, : x.size(2), : x.size(3)]  # type: ignore
        return x


class PositionalEncoding1D(nn.Module):
    """Classic Attention-is-all-you-need positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = self.make_pe(d_model, max_len)  # (max_len, 1, d_model)
        self.register_buffer("pe", pe)

    @staticmethod
    def make_pe(d_model: int, max_len: int) -> Tensor:
        """Compute positional encoding."""
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)
        return pe
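# Usage sketch (illustrative; this demo helper is not part of the original
# module): PositionalEncoding2D adds row/column sinusoids to (B, d_model, H, W)
# encoder features (d_model must be even, half the channels per axis), while
# PositionalEncoding1D adds the classic sinusoids to (S, B, d_model) token
# sequences. The sizes below are arbitrary assumptions.
def _demo_positional_encodings():
    pe2d = PositionalEncoding2D(d_model=128, max_h=64, max_w=64)
    assert pe2d(torch.zeros(2, 128, 32, 48)).shape == (2, 128, 32, 48)

    pe1d = PositionalEncoding1D(d_model=128, max_len=256)
    assert pe1d(torch.zeros(100, 2, 128)).shape == (100, 2, 128)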
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: (S, B, d_model)

        Returns:
            (S, B, d_model)
        """
        assert x.shape[2] == self.pe.shape[2]  # type: ignore
        x = x + self.pe[: x.size(0)]  # type: ignore
        return self.dropout(x)


Size_ = Tuple[int, int]


class PosConv(nn.Module):
    # PEG from https://arxiv.org/abs/2102.10882
    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        # Depthwise 3x3 conv (groups=embed_dim); in practice in_chans is
        # expected to equal embed_dim.
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: Size_):
        B, N, C = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        cnn_feat_token = feat_token.transpose(1, 2).view(B, C, *size)
        x = self.proj(cnn_feat_token)
        if self.stride == 1:
            x += cnn_feat_token
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x

    def no_weight_decay(self):
        # Carried over from the upstream Twins code; only 'proj.0.weight'
        # exists in this single-layer Sequential, the extra names are harmless
        # in name-based skip lists.
        return ['proj.%d.weight' % i for i in range(4)]


class PosConv1D(nn.Module):
    # PEG from https://arxiv.org/abs/2102.10882
    def __init__(self, in_chans, embed_dim=768, stride=1):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv1d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
        self.stride = stride

    def forward(self, x, size: int):
        B, N, C = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]
        cnn_feat_token = feat_token.transpose(1, 2).view(B, C, size)
        x = self.proj(cnn_feat_token)
        if self.stride == 1:
            x += cnn_feat_token
        x = x.transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x

    def no_weight_decay(self):
        return ['proj.%d.weight' % i for i in range(4)]


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=(), old_grid_shape=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    print('Resized position embedding: %s to %s' % (posemb.shape, posemb_new.shape))
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    print('Position embedding grid-size from %s to %s' % (old_grid_shape, gs_new))
    posemb_grid = posemb_grid.reshape(1, old_grid_shape[0], old_grid_shape[1], -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb
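# Usage sketch (illustrative; this demo helper is not part of the original
# module): resizing a ViT position embedding from a 14x14 patch grid to 16x16
# with one leading class token. The grid sizes and embedding width below are
# arbitrary assumptions.
def _demo_resize_pos_embed():
    old = torch.randn(1, 1 + 14 * 14, 768)  # checkpoint embedding, (1, 1 + H*W, C)
    new = torch.zeros(1, 1 + 16 * 16, 768)  # target model's embedding shape
    resized = resize_pos_embed(old, new, num_tokens=1, gs_new=(16, 16),
                               old_grid_shape=(14, 14))
    assert resized.shape == (1, 1 + 16 * 16, 768)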
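# Usage sketch (illustrative; this demo helper is not part of the original
# module) for the PEG modules defined above: PosConv expects token sequences of
# shape (B, 1 + H*W, C) with a leading class token and the spatial grid passed
# as `size`; PosConv1D is the analogue for 1-D token layouts. The sizes below
# are arbitrary assumptions.
def _demo_pos_conv():
    peg = PosConv(in_chans=768, embed_dim=768)
    tokens = torch.randn(2, 1 + 14 * 14, 768)
    assert peg(tokens, (14, 14)).shape == (2, 1 + 14 * 14, 768)

    peg1d = PosConv1D(in_chans=768, embed_dim=768)
    seq = torch.randn(2, 1 + 196, 768)
    assert peg1d(seq, 196).shape == (2, 1 + 196, 768)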