import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from .weight_init import trunc_normal_


class MaskedConv1D(nn.Module):
    """
    Masked 1D convolution. Interface remains the same as Conv1d.
    Only supports a subset of 1d convs:
      * kernel_size must be odd and padding must equal kernel_size // 2
      * the input length must be divisible by the stride
    The output features are multiplied by the (possibly downsampled) mask,
    and the boolean output mask is returned alongside them.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode='zeros'
    ):
        super().__init__()
        # element must be aligned
        assert (kernel_size % 2 == 1) and (kernel_size // 2 == padding)
        # stride
        self.stride = stride
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode)
        # zero out the bias term if it exists
        if bias:
            torch.nn.init.constant_(self.conv.bias, 0.)

    def forward(self, x, mask):
        """
        Args:
            x: batch size, feature channel, sequence length
            mask: batch size, 1, sequence length (bool)
        Returns:
            (out_conv, out_mask): masked conv output and its boolean mask,
            with sequence length T // stride.
        """
        B, C, T = x.size()
        # input length must be divisible by stride
        assert T % self.stride == 0

        # conv
        out_conv = self.conv(x)
        # compute the mask
        if self.stride > 1:
            # downsample the mask using nearest neighbor
            out_mask = F.interpolate(
                mask.to(x.dtype), size=T // self.stride, mode='nearest'
            )
        else:
            # masking out the features
            out_mask = mask.to(x.dtype)

        # masking the output, stop grad to mask
        out_conv = out_conv * out_mask.detach()
        out_mask = out_mask.bool()
        return out_conv, out_mask


class LayerNorm(nn.Module):
    """
    LayerNorm that supports inputs of size B, C, T.
    Normalization is computed over the channel dimension (dim=1).
    """
    def __init__(
        self,
        num_channels,
        eps=1e-5,
        affine=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if self.affine:
            self.weight = nn.Parameter(
                torch.ones([1, num_channels, 1], **factory_kwargs))
            self.bias = nn.Parameter(
                torch.zeros([1, num_channels, 1], **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        assert x.dim() == 3
        assert x.shape[1] == self.num_channels

        # normalization along C channels
        mu = torch.mean(x, dim=1, keepdim=True)
        res_x = x - mu
        sigma = torch.mean(res_x ** 2, dim=1, keepdim=True)
        out = res_x / torch.sqrt(sigma + self.eps)

        # apply weight and bias (in place is safe: `out` is a fresh tensor)
        if self.affine:
            out *= self.weight
            out += self.bias

        return out


# helper functions for Transformer blocks
def get_sinusoid_encoding(n_position, d_hid):
    ''' Sinusoid position encoding table, returned as a float tensor of size 1 x d_hid x n_position '''

    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    # return a tensor of size 1 C T
    return torch.FloatTensor(sinusoid_table).unsqueeze(0).transpose(1, 2)


# attention / transformers
class MaskedMHA(nn.Module):
    """
    Multi Head Attention with mask
    Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    """

    def __init__(
        self,
        n_embd,          # dimension of the input embedding
        n_head,          # number of heads in multi-head self-attention
        attn_pdrop=0.0,  # dropout rate for the attention map
        proj_pdrop=0.0   # dropout rate for projection op
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        self.scale = 1.0 / math.sqrt(self.n_channels)

        # key, query, value projections for all heads
        # it is OK to ignore masking, as the mask will be attached on the attention
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

    def forward(self, x, mask):
        # x: batch size, feature channel, sequence length,
        # mask: batch size, 1, sequence length (bool)
        B, C, T = x.size()

        # calculate query, key, values for all heads in batch
        # (B, nh * hs, T)
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # move head forward to be the batch dim
        # (B, nh * hs, T) -> (B, nh, T, hs)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)

        # self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q * self.scale) @ k.transpose(-2, -1)
        # prevent q from attending to invalid tokens
        att = att.masked_fill(torch.logical_not(mask[:, :, None, :]), float('-inf'))
        # softmax attn
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        out = att @ (v * mask[:, :, :, None].to(v.dtype))
        # re-assemble all head outputs side by side
        out = out.transpose(2, 3).contiguous().view(B, C, -1)

        # output projection (masked); no skip connection is added here
        out = self.proj_drop(self.proj(out)) * mask.to(out.dtype)
        return out, mask


class MaskedMHCA(nn.Module):
    """
    Multi Head Conv Attention with mask

    Add a depthwise convolution within a standard MHA
    The extra conv op can be used to
    (1) encode relative position information (replacing position encoding);
    (2) downsample the features if needed;
    (3) match the feature channels

    Note: With current implementation, a feature downsampled by stride s is
    computed with a kernel spanning s+1 time steps and is centered on every
    s-th input time step. This allows us to easily interpolate the
    corresponding positional embeddings.

    The query path is downsampled by n_qx_stride and the key / value paths by
    n_kv_stride; the returned mask follows the query path.

    Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    """

    def __init__(
        self,
        n_embd,          # dimension of the output features
        n_head,          # number of heads in multi-head self-attention
        n_qx_stride=1,   # downsampling stride for query and input
        n_kv_stride=1,   # downsampling stride for key and value
        attn_pdrop=0.0,  # dropout rate for the attention map
        proj_pdrop=0.0,  # dropout rate for projection op
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        self.scale = 1.0 / math.sqrt(self.n_channels)

        # conv/pooling operations
        assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
        assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
        self.n_qx_stride = n_qx_stride
        self.n_kv_stride = n_kv_stride

        # query conv (depthwise)
        kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
        # the query path must be downsampled by its own stride (n_qx_stride),
        # consistent with the kernel size above and with the returned qx_mask
        stride, padding = self.n_qx_stride, kernel_size // 2
        # 1d depthwise conv
        self.query_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        # layernorm
        self.query_norm = LayerNorm(self.n_embd)

        # key, value conv (depthwise)
        kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
        stride, padding = self.n_kv_stride, kernel_size // 2
        # 1d depthwise conv
        self.key_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        self.key_norm = LayerNorm(self.n_embd)
        self.value_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        # layernorm
        self.value_norm = LayerNorm(self.n_embd)

        # key, query, value projections for all heads
        # it is OK to ignore masking, as the mask will be attached on the attention
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

    def forward(self, x, mask):
        # x: batch size, feature channel, sequence length,
        # mask: batch size, 1, sequence length (bool)
        B, C, T = x.size()

        # query conv -> (B, nh * hs, T')
        q, qx_mask = self.query_conv(x, mask)
        q = self.query_norm(q)
        # key, value conv -> (B, nh * hs, T'')
        k, kv_mask = self.key_conv(x, mask)
        k = self.key_norm(k)
        v, _ = self.value_conv(x, mask)
        v = self.value_norm(v)

        # projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # move head forward to be the batch dim
        # (B, nh * hs, T'/T'') -> (B, nh, T'/T'', hs)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)

        # self-attention: (B, nh, T', hs) x (B, nh, hs, T'') -> (B, nh, T', T'')
        att = (q * self.scale) @ k.transpose(-2, -1)
        # prevent q from attending to invalid tokens
        att = att.masked_fill(torch.logical_not(kv_mask[:, :, None, :]), float('-inf'))
        # softmax attn
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        # (B, nh, T', T'') x (B, nh, T'', hs) -> (B, nh, T', hs)
        out = att @ (v * kv_mask[:, :, :, None].to(v.dtype))
        # re-assemble all head outputs side by side
        out = out.transpose(2, 3).contiguous().view(B, C, -1)

        # output projection (masked by the query mask); no skip connection
        # here -- TransformerBlock adds the residual path itself
        out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype)
        return out, qx_mask


class LocalMaskedMHCA(nn.Module):
    """
    Local Multi Head Conv Attention with mask

    Add a depthwise convolution within a standard MHA
    The extra conv op can be used to
    (1) encode relative position information (replacing position encoding);
    (2) downsample the features if needed;
    (3) match the feature channels

    Note: With current implementation, a feature downsampled by stride s is
    computed with a kernel spanning s+1 time steps and is centered on every
    s-th input time step. This allows us to easily interpolate the
    corresponding positional embeddings.

    Requirements enforced by the sliding-window helpers (via asserts): the
    query and key sequences must have the same length after downsampling, and
    that length must be divisible by 2 * (window_size // 2).

    The implementation is fairly tricky, code reference from
    https://github.com/huggingface/transformers/blob/master/src/transformers/models/longformer/modeling_longformer.py
    """

    def __init__(
        self,
        n_embd,          # dimension of the output features
        n_head,          # number of heads in multi-head self-attention
        window_size,     # size of the local attention window
        n_qx_stride=1,   # downsampling stride for query and input
        n_kv_stride=1,   # downsampling stride for key and value
        attn_pdrop=0.0,  # dropout rate for the attention map
        proj_pdrop=0.0,  # dropout rate for projection op
        use_rel_pe=False # use relative position encoding
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        self.scale = 1.0 / math.sqrt(self.n_channels)
        self.window_size = window_size
        self.window_overlap = window_size // 2
        # an odd window size is expected: attention spans
        # 2 * (window_size // 2) + 1 positions, and rel_pe (if used) has
        # window_size entries. Only window_size > 1 is enforced here.
        assert self.window_size > 1 and self.n_head >= 1
        self.use_rel_pe = use_rel_pe

        # conv/pooling operations
        assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
        assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
        self.n_qx_stride = n_qx_stride
        self.n_kv_stride = n_kv_stride

        # query conv (depthwise)
        kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
        # the query path must be downsampled by its own stride (n_qx_stride),
        # consistent with the kernel size above and with the returned qx_mask
        stride, padding = self.n_qx_stride, kernel_size // 2
        # 1d depthwise conv
        self.query_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        # layernorm
        self.query_norm = LayerNorm(self.n_embd)

        # key, value conv (depthwise)
        kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
        stride, padding = self.n_kv_stride, kernel_size // 2
        # 1d depthwise conv
        self.key_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        self.key_norm = LayerNorm(self.n_embd)
        self.value_conv = MaskedConv1D(
            self.n_embd, self.n_embd, kernel_size,
            stride=stride, padding=padding, groups=self.n_embd, bias=False
        )
        # layernorm
        self.value_norm = LayerNorm(self.n_embd)

        # key, query, value projections for all heads
        # it is OK to ignore masking, as the mask will be attached on the attention
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # relative position encoding, broadcast over (B, T, nh, window)
        if self.use_rel_pe:
            self.rel_pe = nn.Parameter(
                torch.zeros(1, 1, self.n_head, self.window_size))
            trunc_normal_(self.rel_pe, std=(2.0 / self.n_embd) ** 0.5)

    @staticmethod
    def _chunk(x, window_overlap):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
        # x: B x nh, T, hs
        # non-overlapping chunks of size = 2w -> B x nh, T//2w, 2w, hs
        x = x.view(
            x.size(0),
            x.size(1) // (window_overlap * 2),
            window_overlap * 2,
            x.size(2),
        )

        # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
        chunk_size = list(x.size())
        chunk_size[1] = chunk_size[1] * 2 - 1
        chunk_stride = list(x.stride())
        chunk_stride[1] = chunk_stride[1] // 2

        # B x nh, #chunks = T//w - 1, 2w, hs
        return x.as_strided(size=chunk_size, stride=chunk_stride)

    @staticmethod
    def _pad_and_transpose_last_two_dims(x, padding):
        """pads rows and then flips rows and columns"""
        # padding value is not important because it will be overwritten
        x = nn.functional.pad(x, padding)
        x = x.view(*x.size()[:-2], x.size(-1), x.size(-2))
        return x

    @staticmethod
    def _mask_invalid_locations(input_tensor, affected_seq_len):
        """
        In place, fill with -inf the window slots of the first / last
        `affected_seq_len` positions that would fall before the start / after
        the end of the sequence.
        """
        beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
        beginning_mask = beginning_mask_2d[None, :, None, :]
        ending_mask = beginning_mask.flip(dims=(1, 3))
        beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
        beginning_mask = beginning_mask.expand(beginning_input.size())
        # `== 1` converts to bool or uint8
        beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))
        ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1):]
        ending_mask = ending_mask.expand(ending_input.size())
        # `== 1` converts to bool or uint8
        ending_input.masked_fill_(ending_mask == 1, -float("inf"))

    @staticmethod
    def _pad_and_diagonalize(x):
        """
        shift every row 1 step right, converting columns into diagonals.
        Example::
              chunked_hidden_states: [ 0.4983,  2.6918, -0.0071,  1.0492,
                                       -1.8348,  0.7672,  0.2986,  0.0285,
                                       -0.7584,  0.4206, -0.0405,  0.1599,
                                       2.0514, -1.1600,  0.5372,  0.2629 ]
              window_overlap = num_rows = 4
             (pad & diagonalize) =>
             [ 0.4983,  2.6918, -0.0071,  1.0492, 0.0000,  0.0000,  0.0000
               0.0000,  -1.8348,  0.7672,  0.2986,  0.0285, 0.0000,  0.0000
               0.0000,  0.0000, -0.7584,  0.4206, -0.0405,  0.1599, 0.0000
               0.0000,  0.0000,  0.0000, 2.0514, -1.1600,  0.5372,  0.2629 ]
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = x.size()
        # total_num_heads x num_chunks x window_overlap x (hidden_dim+window_overlap+1).
        x = nn.functional.pad(
            x, (0, window_overlap + 1)
        )
        # total_num_heads x num_chunks x window_overlap*window_overlap+window_overlap
        x = x.view(total_num_heads, num_chunks, -1)
        # total_num_heads x num_chunks x window_overlap*window_overlap
        x = x[:, :, :-window_overlap]
        x = x.view(
            total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim
        )
        x = x[:, :, :, :-1]
        return x

    def _sliding_chunks_query_key_matmul(
        self, query, key, num_heads, window_overlap
    ):
        """
        Matrix multiplication of query and key tensors using with a sliding window attention pattern.
        This implementation splits the input into overlapping chunks of size 2w
        with an overlap of size w (window_overlap)

        Args:
            query / key: (B * num_heads, T, head_dim), same size
        Returns:
            scores of size (B, T, num_heads, 2 * window_overlap + 1)
        """
        # query / key: B*nh, T, hs
        bnh, seq_len, head_dim = query.size()
        batch_size = bnh // num_heads
        assert seq_len % (window_overlap * 2) == 0
        assert query.size() == key.size()

        chunks_count = seq_len // window_overlap - 1

        # B * num_heads, #chunks=(T//w - 1), 2w, head_dim
        chunk_query = self._chunk(query, window_overlap)
        chunk_key = self._chunk(key, window_overlap)

        # matrix multiplication
        # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
        diagonal_chunked_attention_scores = torch.einsum(
            "bcxd,bcyd->bcxy", (chunk_query, chunk_key))

        # convert diagonals into columns
        # B * num_heads, #chunks, 2w, 2w+1
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
            diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
        )

        # allocate space for the overall attention matrix where the chunks are combined. The last dimension
        # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
        # window_overlap previous words). The following column is attention score from each word to itself, then
        # followed by window_overlap columns for the upper triangle.
        diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
            (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
        )

        # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
        # - copying the main diagonal and the upper triangle
        diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, :, :window_overlap, : window_overlap + 1
        ]
        diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, -1, window_overlap:, : window_overlap + 1
        ]
        # - copying the lower triangle
        diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
            :, :, -(window_overlap + 1): -1, window_overlap + 1:
        ]

        diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
            :, 0, : window_overlap - 1, 1 - window_overlap:
        ]

        # separate batch_size and num_heads dimensions again
        diagonal_attention_scores = diagonal_attention_scores.view(
            batch_size, num_heads, seq_len, 2 * window_overlap + 1
        ).transpose(2, 1)

        # slots never written above (before the sequence start) are overwritten with -inf here
        self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores

    def _sliding_chunks_matmul_attn_probs_value(
        self, attn_probs, value, num_heads, window_overlap
    ):
        """
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.

        Args:
            attn_probs: (B, T, num_heads, 2 * window_overlap + 1)
            value: (B * num_heads, T, head_dim)
        Returns:
            context of size (B, num_heads, T, head_dim)
        """
        bnh, seq_len, head_dim = value.size()
        batch_size = bnh // num_heads
        assert seq_len % (window_overlap * 2) == 0
        assert attn_probs.size(3) == 2 * window_overlap + 1
        chunks_count = seq_len // window_overlap - 1
        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap

        chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
            batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
        )

        # pad seq_len with w at the beginning of the sequence and another window overlap at the end
        padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

        # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
        chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
        chunked_value_stride = padded_value.stride()
        chunked_value_stride = (
            chunked_value_stride[0],
            window_overlap * chunked_value_stride[1],
            chunked_value_stride[1],
            chunked_value_stride[2],
        )
        chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
        return context.view(batch_size, num_heads, seq_len, head_dim)

    def forward(self, x, mask):
        # x: batch size, feature channel, sequence length,
        # mask: batch size, 1, sequence length (bool)
        B, C, T = x.size()

        # step 1: depth convolutions
        # query conv -> (B, nh * hs, T')
        q, qx_mask = self.query_conv(x, mask)
        q = self.query_norm(q)
        # key, value conv -> (B, nh * hs, T'')
        k, kv_mask = self.key_conv(x, mask)
        k = self.key_norm(k)
        v, _ = self.value_conv(x, mask)
        v = self.value_norm(v)

        # step 2: query, key, value transforms & reshape
        # projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)
        # (B, nh * hs, T) -> (B, nh, T, hs)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        # view as (B * nh, T, hs)
        q = q.view(B * self.n_head, -1, self.n_channels).contiguous()
        k = k.view(B * self.n_head, -1, self.n_channels).contiguous()
        v = v.view(B * self.n_head, -1, self.n_channels).contiguous()

        # step 3: compute local self-attention with rel pe and masking
        q *= self.scale
        # chunked query key attention -> B, T, nh, 2w+1 = window_size
        att = self._sliding_chunks_query_key_matmul(
            q, k, self.n_head, self.window_overlap)

        # rel pe
        if self.use_rel_pe:
            att += self.rel_pe
        # kv_mask -> B, T'', 1
        inverse_kv_mask = torch.logical_not(
            kv_mask[:, :, :, None].view(B, -1, 1))
        # 0 for valid slot, -1e4 (a large negative bias) for masked ones
        float_inverse_kv_mask = inverse_kv_mask.type_as(q).masked_fill(
            inverse_kv_mask, -1e4)
        # compute the diagonal mask (for each local window)
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_inverse_kv_mask.new_ones(size=float_inverse_kv_mask.size()),
            float_inverse_kv_mask,
            1,
            self.window_overlap
        )
        att += diagonal_mask

        # ignore input masking for now
        att = nn.functional.softmax(att, dim=-1)
        # zero out the attention rows of invalid (masked) positions; this also
        # removes NaNs that softmax may produce when a whole row is -inf
        att = att.masked_fill(
            torch.logical_not(kv_mask.squeeze(1)[:, :, None, None]), 0.0)
        att = self.attn_drop(att)

        # step 4: compute attention value product + output projection
        # chunked attn value product -> B, nh, T, hs
        out = self._sliding_chunks_matmul_attn_probs_value(
            att, v, self.n_head, self.window_overlap)
        # transpose to B, nh, hs, T -> B, nh*hs, T
        out = out.transpose(2, 3).contiguous().view(B, C, -1)
        # output projection (masked by the query mask); no skip connection
        # here -- TransformerBlock adds the residual path itself
        out = self.proj_drop(self.proj(out)) * qx_mask.to(out.dtype)
        return out, qx_mask


class TransformerBlock(nn.Module):
    """
    A simple (pre layer norm) Transformer block:
        out = pool_skip(x) + drop_path(attn(ln1(x)))
        out = out + drop_path(mlp(ln2(out)))
    Uses LocalMaskedMHCA when mha_win_size > 1, otherwise MaskedMHCA.
    Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    """
    def __init__(
        self,
        n_embd,                # dimension of the input features
        n_head,                # number of attention heads
        n_ds_strides=(1, 1),   # downsampling strides for q & x, k & v
        n_out=None,            # output dimension, if None, set to input dim
        n_hidden=None,         # dimension of the hidden layer in MLP
        act_layer=nn.GELU,     # nonlinear activation used in MLP, default GELU
        attn_pdrop=0.0,        # dropout rate for the attention map
        proj_pdrop=0.0,        # dropout rate for the projection / MLP
        path_pdrop=0.0,        # drop path rate
        mha_win_size=-1,       # > 0 to use window mha
        use_rel_pe=False       # if to add rel position encoding to attention
    ):
        super().__init__()
        assert len(n_ds_strides) == 2
        # layer norm for order (B C T)
        self.ln1 = LayerNorm(n_embd)
        self.ln2 = LayerNorm(n_embd)

        # specify the attention module
        if mha_win_size > 1:
            self.attn = LocalMaskedMHCA(
                n_embd,
                n_head,
                window_size=mha_win_size,
                n_qx_stride=n_ds_strides[0],
                n_kv_stride=n_ds_strides[1],
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop,
                use_rel_pe=use_rel_pe  # only valid for local attention
            )
        else:
            self.attn = MaskedMHCA(
                n_embd,
                n_head,
                n_qx_stride=n_ds_strides[0],
                n_kv_stride=n_ds_strides[1],
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop
            )

        # skip path: downsample the input with the same stride as the query
        if n_ds_strides[0] > 1:
            kernel_size, stride, padding = \
                n_ds_strides[0] + 1, n_ds_strides[0], (n_ds_strides[0] + 1) // 2
            self.pool_skip = nn.MaxPool1d(
                kernel_size, stride=stride, padding=padding)
        else:
            self.pool_skip = nn.Identity()

        # two layer mlp
        if n_hidden is None:
            n_hidden = 4 * n_embd  # default
        if n_out is None:
            n_out = n_embd
        # ok to use conv1d here with stride=1
        self.mlp = nn.Sequential(
            nn.Conv1d(n_embd, n_hidden, 1),
            act_layer(),
            nn.Dropout(proj_pdrop, inplace=True),
            nn.Conv1d(n_hidden, n_out, 1),
            nn.Dropout(proj_pdrop, inplace=True),
        )

        # drop path
        if path_pdrop > 0.0:
            self.drop_path_attn = AffineDropPath(n_embd, drop_prob=path_pdrop)
            self.drop_path_mlp = AffineDropPath(n_out, drop_prob=path_pdrop)
        else:
            self.drop_path_attn = nn.Identity()
            self.drop_path_mlp = nn.Identity()

    def forward(self, x, mask, pos_embd=None):
        # pre-LN transformer: https://arxiv.org/pdf/2002.04745.pdf
        out, out_mask = self.attn(self.ln1(x), mask)
        out_mask_float = out_mask.to(out.dtype)
        out = self.pool_skip(x) * out_mask_float + self.drop_path_attn(out)
        # FFN
        out = out + self.drop_path_mlp(self.mlp(self.ln2(out)) * out_mask_float)
        # optionally add pos_embd to the output
        if pos_embd is not None:
            out += pos_embd * out_mask_float
        return out, out_mask


class ConvBlock(nn.Module):
    """
    A simple conv block similar to the basic block used in ResNet
    """
    def __init__(
        self,
        n_embd,                # dimension of the input features
        kernel_size=3,         # conv kernel size
        n_ds_stride=1,         # downsampling stride for the current layer
        expansion_factor=2,    # expansion factor of feat dims
        n_out=None,            # output dimension, if None, set to input dim
        act_layer=nn.ReLU,     # nonlinear activation used after conv, default ReLU
    ):
        super().__init__()
        # must use odd sized kernel
        assert (kernel_size % 2 == 1) and (kernel_size > 1)
        padding = kernel_size // 2
        if n_out is None:
            n_out = n_embd

        # 1x3 (strided) -> 1x3 (basic block in resnet)
        width = n_embd * expansion_factor
        self.conv1 = MaskedConv1D(
            n_embd, width, kernel_size, n_ds_stride, padding=padding)
        self.conv2 = MaskedConv1D(
            width, n_out, kernel_size, 1, padding=padding)

        # attach downsampling conv op
        if n_ds_stride > 1:
            # 1x1 strided conv (same as resnet)
            self.downsample = MaskedConv1D(n_embd, n_out, 1, n_ds_stride)
        else:
            self.downsample = None

        self.act = act_layer()

    def forward(self, x, mask, pos_embd=None):
        # pos_embd is accepted for interface parity with TransformerBlock; unused here
        identity = x
        out, out_mask = self.conv1(x, mask)
        out = self.act(out)
        out, out_mask = self.conv2(out, out_mask)

        # downsampling
        if self.downsample is not None:
            identity, _ = self.downsample(x, mask)

        # residual connection
        out += identity
        out = self.act(out)

        return out, out_mask


class Scale(nn.Module):
    """
    Multiply the output regression range by a learnable constant value
    """
    def __init__(self, init_value=1.0):
        """
        init_value : initial value for the scalar
        """
        super().__init__()
        self.scale = nn.Parameter(
            torch.tensor(init_value, dtype=torch.float32),
            requires_grad=True
        )

    def forward(self, x):
        """
        input -> scale * input
        """
        return x * self.scale


# The following code is modified from
# https://github.com/facebookresearch/SlowFast/blob/master/slowfast/models/common.py
def drop_path(x, drop_prob=0.0, training=False):
    """
    Stochastic Depth per sample.

    Each sample along dim 0 is zeroed with probability drop_prob and the kept
    samples are rescaled by 1 / (1 - drop_prob). Returns x unchanged when not
    training or when drop_prob is 0 / None.
    """
    # `not drop_prob` also covers None (the DropPath default), which would
    # otherwise fail on `1 - None` in training mode
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    output = x.div(keep_prob) * mask
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class AffineDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)
    with a per channel scaling factor, initialized to a small value
    (init_scale_value, default 1e-4).
    See: https://arxiv.org/pdf/2103.17239.pdf
    """

    def __init__(self, num_dim, drop_prob=0.0, init_scale_value=1e-4):
        super().__init__()
        self.scale = nn.Parameter(
            init_scale_value * torch.ones((1, num_dim, 1)),
            requires_grad=True
        )
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(self.scale * x, self.drop_prob, self.training)