|
|
import math |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import torch.nn as nn |
|
|
|
|
|
from .conv import ConvModule |
|
|
from ..builder import MODELS |
|
|
|
|
|
|
|
|
@MODELS.register_module()
class TransformerBlock(nn.Module):
    """
    Adapted from https://github.com/happyharrycn/actionformer_release/blob/main/libs/modeling/blocks.py#L644

    Originally modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

    A pre-norm transformer block over (B, C, T) features with masking:
    masked (optionally windowed) conv-attention followed by a 1x1-conv MLP,
    each wrapped in a residual connection with optional stochastic depth.
    """

    def __init__(
        self,
        in_channels,
        n_head,
        n_ds_strides=(1, 1),
        n_out=None,
        n_hidden=None,
        act_layer=nn.GELU,
        attn_pdrop=0.0,
        proj_pdrop=0.0,
        path_pdrop=0.0,
        mha_win_size=-1,
    ):
        super().__init__()
        assert len(n_ds_strides) == 2

        # pre-norm layer norms (applied channels-last via permute in forward)
        self.ln1 = nn.LayerNorm(in_channels)
        self.ln2 = nn.LayerNorm(in_channels)

        # windowed local attention when a window size > 1 is given,
        # otherwise full masked attention
        shared_attn_kwargs = dict(
            n_qx_stride=n_ds_strides[0],
            n_kv_stride=n_ds_strides[1],
            attn_pdrop=attn_pdrop,
            proj_pdrop=proj_pdrop,
        )
        if mha_win_size > 1:
            self.attn = LocalMaskedMHCA(
                in_channels,
                n_head,
                window_size=mha_win_size,
                **shared_attn_kwargs,
            )
        else:
            self.attn = MaskedMHCA(in_channels, n_head, **shared_attn_kwargs)

        # when the attention path downsamples the queries, pool the residual
        # path with the same stride so lengths still match
        q_stride = n_ds_strides[0]
        if q_stride > 1:
            self.pool_skip = nn.MaxPool1d(
                q_stride + 1, stride=q_stride, padding=(q_stride + 1) // 2
            )
        else:
            self.pool_skip = nn.Identity()

        # position-wise feed-forward network built from 1x1 convs
        n_hidden = 4 * in_channels if n_hidden is None else n_hidden
        n_out = in_channels if n_out is None else n_out
        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels, n_hidden, 1),
            act_layer(),
            nn.Dropout(proj_pdrop, inplace=True),
            nn.Conv1d(n_hidden, n_out, 1),
            nn.Dropout(proj_pdrop, inplace=True),
        )

        # per-sample stochastic depth (with learned per-channel scale)
        # on both residual branches; identity when disabled
        if path_pdrop > 0.0:
            self.drop_path_attn = AffineDropPath(in_channels, drop_prob=path_pdrop)
            self.drop_path_mlp = AffineDropPath(n_out, drop_prob=path_pdrop)
        else:
            self.drop_path_attn = nn.Identity()
            self.drop_path_mlp = nn.Identity()

    def forward(self, x, mask):
        """
        Args:
            x: features of shape (B, C, T).
            mask: boolean valid-position mask passed to the attention
                module (assumed (B, T) — confirm against ConvModule).

        Returns:
            tuple: (out, out_mask) — features (possibly downsampled) and
            the corresponding query mask.
        """
        # attention branch; LayerNorm wants channels last, hence the permutes
        out, out_mask = self.attn(self.ln1(x.permute(0, 2, 1)).permute(0, 2, 1), mask)
        keep = out_mask.to(out.dtype).unsqueeze(1)
        # residual: pooled skip path masked to valid positions
        out = self.pool_skip(x) * keep + self.drop_path_attn(out)
        # FFN branch, also masked to valid positions
        out = out + self.drop_path_mlp(
            self.mlp(self.ln2(out.permute(0, 2, 1)).permute(0, 2, 1)) * keep
        )
        return out, out_mask
|
|
|
|
|
|
|
|
@MODELS.register_module()
class MaskedMHCA(nn.Module):
    """
    Multi Head Conv Attention with mask

    Add a depthwise convolution within a standard MHA
    The extra conv op can be used to
    (1) encode relative position information (replace position encoding);
    (2) downsample the features if needed;
    (3) match the feature channels

    Note: With current implementation, the downsample feature will be aligned
    to every s+1 time step, where s is the downsampling stride. This allows us
    to easily interpolate the corresponding positional embeddings.

    Modified from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

    Args:
        n_embd (int): embedding dimension; must be divisible by ``n_head``.
        n_head (int): number of attention heads.
        n_qx_stride (int): stride on the query/output path (1 or even).
        n_kv_stride (int): stride on the key/value path (1 or even).
        attn_pdrop (float): dropout rate on the attention map.
        proj_pdrop (float): dropout rate on the output projection.
    """

    def __init__(
        self,
        n_embd,
        n_head,
        n_qx_stride=1,
        n_kv_stride=1,
        attn_pdrop=0.0,
        proj_pdrop=0.0,
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        # 1/sqrt(d_head) attention scaling
        self.scale = 1.0 / math.sqrt(self.n_channels)

        # strides must be 1 or even so downsampled features stay aligned
        # (see class docstring)
        assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
        assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
        self.n_qx_stride = n_qx_stride
        self.n_kv_stride = n_kv_stride

        # depthwise conv on the query path
        # BUGFIX: the query conv must use the *query* stride (n_qx_stride);
        # it previously used n_kv_stride, which gives the query path (and
        # qx_mask) the wrong length whenever the two strides differ.
        kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
        stride, padding = self.n_qx_stride, kernel_size // 2
        self.query_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.query_norm = nn.LayerNorm(self.n_embd)

        # depthwise conv on the key/value path (key and value share
        # kernel/stride/padding)
        kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
        stride, padding = self.n_kv_stride, kernel_size // 2
        self.key_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.key_norm = nn.LayerNorm(self.n_embd)

        self.value_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.value_norm = nn.LayerNorm(self.n_embd)

        # 1x1 convs acting as the q/k/v linear projections
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

    def forward(self, x, mask):
        """
        Args:
            x: features of shape (B, C, T).
            mask: boolean valid-position mask (assumed (B, T), given the
                ``kv_mask[:, None, None, :]`` indexing below — confirm
                against ConvModule).

        Returns:
            tuple: (out, qx_mask) — attention output of shape (B, C, T')
            masked to valid query positions, and the query mask.
        """
        B, C, T = x.size()

        # depthwise conv + layer norm on each path
        # (LayerNorm expects channels last, hence the permutes)
        q, qx_mask = self.query_conv(x, mask)
        q = self.query_norm(q.permute(0, 2, 1)).permute(0, 2, 1)

        k, kv_mask = self.key_conv(x, mask)
        k = self.key_norm(k.permute(0, 2, 1)).permute(0, 2, 1)
        v, _ = self.value_conv(x, mask)
        v = self.value_norm(v.permute(0, 2, 1)).permute(0, 2, 1)

        # 1x1 conv projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # split heads: (B, n_head, T', n_channels)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)

        # scaled dot-product attention; invalid keys masked to -inf
        # so softmax assigns them zero weight
        att = (q * self.scale) @ k.transpose(-2, -1)
        att = att.masked_fill(torch.logical_not(kv_mask[:, None, None, :]), float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        # zero padded values before aggregation (belt-and-braces with the
        # -inf masking above)
        out = att @ (v * kv_mask[:, None, :, None].to(v.dtype))

        # merge heads back to (B, C, T')
        out = out.transpose(2, 3).contiguous().view(B, C, -1)

        # output projection, masked to valid query positions
        out = self.proj_drop(self.proj(out)) * qx_mask.unsqueeze(1).to(out.dtype)
        return out, qx_mask
|
|
|
|
|
|
|
|
@MODELS.register_module()
class LocalMaskedMHCA(nn.Module):
    """
    Local Multi Head Conv Attention with mask

    Add a depthwise convolution within a standard MHA
    The extra conv op can be used to
    (1) encode relative position information (replace position encoding);
    (2) downsample the features if needed;
    (3) match the feature channels

    Note: With current implementation, the downsample feature will be aligned
    to every s+1 time step, where s is the downsampling stride. This allows us
    to easily interpolate the corresponding positional embeddings.

    The implementation is fairly tricky, code reference from
    https://github.com/huggingface/transformers/blob/master/src/transformers/models/longformer/modeling_longformer.py

    Args:
        n_embd (int): embedding dimension; must be divisible by ``n_head``.
        n_head (int): number of attention heads.
        window_size (int): local attention window; must be > 1 and should
            be odd so the window is symmetric around each time step.
        n_qx_stride (int): stride on the query/output path (1 or even).
        n_kv_stride (int): stride on the key/value path (1 or even).
        attn_pdrop (float): dropout rate on the attention map.
        proj_pdrop (float): dropout rate on the output projection.
        use_rel_pe (bool): add a learned per-head relative position bias
            over the attention window.
    """

    def __init__(
        self,
        n_embd,
        n_head,
        window_size,
        n_qx_stride=1,
        n_kv_stride=1,
        attn_pdrop=0.0,
        proj_pdrop=0.0,
        use_rel_pe=False,
    ):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_channels = n_embd // n_head
        # 1/sqrt(d_head) attention scaling
        self.scale = 1.0 / math.sqrt(self.n_channels)
        self.window_size = window_size
        # half-window; chunked attention operates on 2w chunks overlapping by w
        self.window_overlap = window_size // 2

        assert self.window_size > 1 and self.n_head >= 1
        self.use_rel_pe = use_rel_pe

        # strides must be 1 or even so downsampled features stay aligned
        assert (n_qx_stride == 1) or (n_qx_stride % 2 == 0)
        assert (n_kv_stride == 1) or (n_kv_stride % 2 == 0)
        self.n_qx_stride = n_qx_stride
        self.n_kv_stride = n_kv_stride

        # depthwise conv on the query path
        # BUGFIX: the query conv must use the *query* stride (n_qx_stride);
        # it previously used n_kv_stride, which gives the query path (and
        # qx_mask) the wrong length whenever the two strides differ.
        kernel_size = self.n_qx_stride + 1 if self.n_qx_stride > 1 else 3
        stride, padding = self.n_qx_stride, kernel_size // 2
        self.query_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.query_norm = nn.LayerNorm(self.n_embd)

        # depthwise conv on the key/value path (key and value share
        # kernel/stride/padding)
        kernel_size = self.n_kv_stride + 1 if self.n_kv_stride > 1 else 3
        stride, padding = self.n_kv_stride, kernel_size // 2
        self.key_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.key_norm = nn.LayerNorm(self.n_embd)

        self.value_conv = ConvModule(
            self.n_embd,
            self.n_embd,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            conv_cfg=dict(groups=n_embd, bias=False),
        )
        self.value_norm = nn.LayerNorm(self.n_embd)

        # 1x1 convs acting as the q/k/v linear projections
        self.key = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.query = nn.Conv1d(self.n_embd, self.n_embd, 1)
        self.value = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.proj_drop = nn.Dropout(proj_pdrop)

        # output projection
        self.proj = nn.Conv1d(self.n_embd, self.n_embd, 1)

        # BUGFIX: forward() adds ``self.rel_pe`` when use_rel_pe=True, but
        # the parameter was never created, raising AttributeError. One bias
        # per head across the (odd) attention window, broadcast over
        # (B, T, n_head, window_size); init follows the ActionFormer
        # reference implementation.
        if self.use_rel_pe:
            self.rel_pe = nn.Parameter(torch.zeros(1, 1, self.n_head, self.window_size))
            nn.init.trunc_normal_(self.rel_pe, std=(2.0 / self.n_embd) ** 0.5)

    @staticmethod
    def _chunk(x, window_overlap):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
        # first reshape into non-overlapping chunks of size 2w:
        # (B*nh, T // 2w, 2w, head_dim)
        x = x.view(
            x.size(0),
            x.size(1) // (window_overlap * 2),
            window_overlap * 2,
            x.size(2),
        )

        # then use as_strided to make the chunks overlap by w steps,
        # without copying: halving the chunk stride doubles the chunk
        # count (minus one)
        chunk_size = list(x.size())
        chunk_size[1] = chunk_size[1] * 2 - 1
        chunk_stride = list(x.stride())
        chunk_stride[1] = chunk_stride[1] // 2

        return x.as_strided(size=chunk_size, stride=chunk_stride)

    @staticmethod
    def _pad_and_transpose_last_two_dims(x, padding):
        """pads rows and then flips rows and columns"""
        # pad, then reinterpret the storage so the last two dims swap —
        # this "transpose" is a free view thanks to the added padding
        x = nn.functional.pad(x, padding)
        x = x.view(*x.size()[:-2], x.size(-1), x.size(-2))
        return x

    @staticmethod
    def _mask_invalid_locations(input_tensor, affected_seq_len):
        """Mask (in place, with -inf) window positions that fall outside the
        sequence at the beginning and end."""
        # lower-triangular mask flipped so it covers positions "before the
        # start" of the sequence
        beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
        beginning_mask = beginning_mask_2d[None, :, None, :]
        # mirrored mask for positions "past the end"
        ending_mask = beginning_mask.flip(dims=(1, 3))
        beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
        beginning_mask = beginning_mask.expand(beginning_input.size())

        # in-place fill so the change propagates to input_tensor (the
        # slices above are views)
        beginning_input.masked_fill_(beginning_mask == 1, -float("inf"))
        ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
        ending_mask = ending_mask.expand(ending_input.size())

        ending_input.masked_fill_(ending_mask == 1, -float("inf"))

    @staticmethod
    def _pad_and_diagonalize(x):
        """
        shift every row 1 step right, converting columns into diagonals.
        Example::
            chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
                                     -1.8348, 0.7672, 0.2986, 0.0285,
                                     -0.7584, 0.4206, -0.0405, 0.1599,
                                     2.0514, -1.1600, 0.5372, 0.2629 ]
            window_overlap = num_rows = 4
            (pad & diagonalize) =>
            [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
              0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
              0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
              0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
        """
        total_num_heads, num_chunks, window_overlap, hidden_dim = x.size()

        # pad each row, then flatten and trim so each row ends up shifted
        # one step right relative to the previous one
        x = nn.functional.pad(x, (0, window_overlap + 1))

        x = x.view(total_num_heads, num_chunks, -1)

        x = x[:, :, :-window_overlap]
        x = x.view(total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim)
        x = x[:, :, :, :-1]
        return x

    def _sliding_chunks_query_key_matmul(self, query, key, num_heads, window_overlap):
        """
        Matrix multiplication of query and key tensors using with a sliding window attention pattern. This implementation splits the input into overlapping chunks of size 2w with an overlap of size w (window_overlap)
        """
        # query/key: (B * num_heads, seq_len, head_dim);
        # seq_len must be a multiple of 2w
        bnh, seq_len, head_dim = query.size()
        batch_size = bnh // num_heads
        assert seq_len % (window_overlap * 2) == 0
        assert query.size() == key.size()

        chunks_count = seq_len // window_overlap - 1

        # overlapping chunks of size 2w, overlap w
        chunk_query = self._chunk(query, window_overlap)
        chunk_key = self._chunk(key, window_overlap)

        # per-chunk attention scores: (B*nh, chunks, 2w, 2w)
        diagonal_chunked_attention_scores = torch.einsum("bcxd,bcyd->bcxy", (chunk_query, chunk_key))

        # pad + "transpose" trick to line chunk diagonals up as rows
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(
            diagonal_chunked_attention_scores, padding=(0, 0, 0, 1)
        )

        # allocate the final diagonal layout:
        # (B*nh, chunks+1, w, 2w+1) — one score per query position for each
        # of its 2w+1 window neighbors
        diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
            (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
        )

        # copy scores for keys at and after each query position
        diagonal_attention_scores[:, :-1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, :, :window_overlap, : window_overlap + 1
        ]
        diagonal_attention_scores[:, -1, :, window_overlap:] = diagonal_chunked_attention_scores[
            :, -1, window_overlap:, : window_overlap + 1
        ]

        # copy scores for keys before each query position
        diagonal_attention_scores[:, 1:, :, :window_overlap] = diagonal_chunked_attention_scores[
            :, :, -(window_overlap + 1) : -1, window_overlap + 1 :
        ]

        diagonal_attention_scores[:, 0, 1:window_overlap, 1:window_overlap] = diagonal_chunked_attention_scores[
            :, 0, : window_overlap - 1, 1 - window_overlap :
        ]

        # reshape to (B, seq_len, num_heads, 2w+1)
        diagonal_attention_scores = diagonal_attention_scores.view(
            batch_size, num_heads, seq_len, 2 * window_overlap + 1
        ).transpose(2, 1)

        # -inf out window positions that fall outside the sequence
        self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores

    def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, num_heads, window_overlap):
        """
        Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
        same shape as `attn_probs`
        """
        bnh, seq_len, head_dim = value.size()
        batch_size = bnh // num_heads
        assert seq_len % (window_overlap * 2) == 0
        assert attn_probs.size(3) == 2 * window_overlap + 1
        chunks_count = seq_len // window_overlap - 1

        # group attention probs into chunks of size w
        chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
            batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
        )

        # pad value on both ends so edge chunks have full 3w context
        padded_value = nn.functional.pad(value, (0, 0, window_overlap, window_overlap), value=-1)

        # zero-copy view of overlapping 3w value chunks via as_strided
        chunked_value_size = (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim)
        chunked_value_stride = padded_value.stride()
        chunked_value_stride = (
            chunked_value_stride[0],
            window_overlap * chunked_value_stride[1],
            chunked_value_stride[1],
            chunked_value_stride[2],
        )
        chunked_value = padded_value.as_strided(size=chunked_value_size, stride=chunked_value_stride)

        # align attention probs with the 3w value layout
        chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)

        context = torch.einsum("bcwd,bcdh->bcwh", (chunked_attn_probs, chunked_value))
        return context.view(batch_size, num_heads, seq_len, head_dim)

    def forward(self, x, mask):
        """
        Args:
            x: features of shape (B, C, T); T must be divisible by
                ``window_size - 1`` after any key/value downsampling
                (enforced by asserts in the chunk helpers).
            mask: boolean valid-position mask (assumed (B, T) — confirm
                against ConvModule).

        Returns:
            tuple: (out, qx_mask) — attention output of shape (B, C, T')
            masked to valid query positions, and the query mask.
        """
        B, C, T = x.size()

        # depthwise conv + layer norm on each path
        # (LayerNorm expects channels last, hence the permutes)
        q, qx_mask = self.query_conv(x, mask)
        q = self.query_norm(q.permute(0, 2, 1)).permute(0, 2, 1)

        k, kv_mask = self.key_conv(x, mask)
        k = self.key_norm(k.permute(0, 2, 1)).permute(0, 2, 1)
        v, _ = self.value_conv(x, mask)
        v = self.value_norm(v.permute(0, 2, 1)).permute(0, 2, 1)

        # 1x1 conv projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # split heads, then fold heads into the batch dim for the sliding
        # chunk helpers: (B * n_head, T', n_channels)
        q = q.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        k = k.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)
        v = v.view(B, self.n_head, self.n_channels, -1).transpose(2, 3)

        q = q.view(B * self.n_head, -1, self.n_channels).contiguous()
        k = k.view(B * self.n_head, -1, self.n_channels).contiguous()
        v = v.view(B * self.n_head, -1, self.n_channels).contiguous()

        # pre-scale queries (in place) instead of scaling the scores
        q *= self.scale

        # banded attention scores: (B, T', n_head, 2w+1)
        att = self._sliding_chunks_query_key_matmul(q, k, self.n_head, self.window_overlap)

        # learned relative position bias, broadcast over batch and time
        if self.use_rel_pe:
            att += self.rel_pe

        # turn the kv padding mask into an additive -1e4 bias and push it
        # through the same sliding-window layout as the scores
        inverse_kv_mask = torch.logical_not(kv_mask[:, None, :, None].view(B, -1, 1))

        float_inverse_kv_mask = inverse_kv_mask.type_as(q).masked_fill(inverse_kv_mask, -1e4)

        diagonal_mask = self._sliding_chunks_query_key_matmul(
            float_inverse_kv_mask.new_ones(size=float_inverse_kv_mask.size()),
            float_inverse_kv_mask,
            1,
            self.window_overlap,
        )
        att += diagonal_mask

        att = nn.functional.softmax(att, dim=-1)

        # zero attention rows at padded query positions
        att = att.masked_fill(torch.logical_not(kv_mask[:, :, None, None]), 0.0)
        att = self.attn_drop(att)

        # aggregate values through the same sliding-window pattern
        out = self._sliding_chunks_matmul_attn_probs_value(att, v, self.n_head, self.window_overlap)

        # merge heads back to (B, C, T')
        out = out.transpose(2, 3).contiguous().view(B, C, -1)

        # output projection, masked to valid query positions
        out = self.proj_drop(self.proj(out)) * qx_mask.unsqueeze(1).to(out.dtype)
        return out, qx_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def drop_path(x, drop_prob=0.0, training=False):
    """
    Stochastic Depth per sample.

    Zeroes entire samples (along dim 0) with probability ``drop_prob`` and
    rescales the survivors by ``1 / (1 - drop_prob)`` so the expected value
    is unchanged. Identity when ``drop_prob == 0`` or not in training mode.
    """
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast across all remaining dims
    bern_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = torch.rand(bern_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * keep
|
|
|
|
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # probability of dropping the whole residual branch for a sample
        self.drop_prob = drop_prob

    def forward(self, x):
        # delegate to the functional form; only active in training mode
        return drop_path(x, self.drop_prob, self.training)
|
|
|
|
|
|
|
|
class AffineDropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks) with a per channel scaling factor (and zero init)
    See: https://arxiv.org/pdf/2103.17239.pdf
    """

    def __init__(self, num_dim, drop_prob=0.0, init_scale_value=1e-4):
        super().__init__()
        # learnable per-channel scale (LayerScale), initialized near zero so
        # the branch starts as an approximate no-op
        initial = torch.ones((1, num_dim, 1)) * init_scale_value
        self.scale = nn.Parameter(initial, requires_grad=True)
        self.drop_prob = drop_prob

    def forward(self, x):
        # scale channels first, then apply per-sample stochastic depth
        return drop_path(self.scale * x, self.drop_prob, self.training)
|
|
|