# InPeerReview's picture
# Upload 161 files
# 226675b verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from torch.nn.functional import linear, softmax
from torch.nn.parameter import Parameter
from torch.nn import Linear
from torch.nn.init import xavier_uniform_
from torch.nn.init import constant_
class Conv2d_BN(nn.Sequential):
    """A 1x1..kxk 2-D convolution immediately followed by a norm layer.

    The conv is registered under the name ``'c'`` and the norm under
    ``'bn'`` so existing checkpoints keep their parameter names. The norm
    weight is initialised to ``bn_weight_init`` and its bias to zero.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, bias=False,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        # Keep the configuration around for introspection by callers.
        self.inp_channel = a
        self.out_channel = b
        self.ks = ks
        self.pad = pad
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        conv = nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=bias)
        self.add_module('c', conv)
        # build_norm_layer returns (name, module); only the module is needed.
        norm = build_norm_layer(norm_cfg, b)[1]
        nn.init.constant_(norm.weight, bn_weight_init)
        nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)
class Conv1d_BN(nn.Sequential):
    """A 1-D convolution immediately followed by a norm layer.

    Mirrors :class:`Conv2d_BN` for sequence-shaped (N, C, L) inputs. The
    conv is registered as ``'c'`` and the norm as ``'bn'``; the norm
    weight starts at ``bn_weight_init`` and its bias at zero.
    """

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, bias=False,
                 norm_cfg=dict(type='BN1d', requires_grad=True)):
        super().__init__()
        # Keep the configuration around for introspection by callers.
        self.inp_channel = a
        self.out_channel = b
        self.ks = ks
        self.pad = pad
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        conv = nn.Conv1d(a, b, ks, stride, pad, dilation, groups, bias=bias)
        self.add_module('c', conv)
        # build_norm_layer returns (name, module); only the module is needed.
        norm = build_norm_layer(norm_cfg, b)[1]
        nn.init.constant_(norm.weight, bn_weight_init)
        nn.init.constant_(norm.bias, 0)
        self.add_module('bn', norm)
class DEACA_attention(torch.nn.Module):
    """Decoupled row/column cross-attention with a channel gate on the keys.

    Queries attend separately along the two spatial axes of the key/value
    map (one attention over rows, one over columns); the two attended
    results are summed and passed through a final projection. Projections
    are 1x1 Conv+BN blocks (Conv1d for the sequence-shaped queries and
    outputs, Conv2d for the map-shaped keys/values).

    Args:
        dim: channel width C; assumed divisible by num_heads.
        num_heads: number of attention heads.
        activation: activation class placed before each output projection.
    """
    def __init__(self, dim, num_heads,
                 activation=nn.ReLU, ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # Per-head channel width and the usual 1/sqrt(d) attention scale.
        self.head_dim = head_dim = dim // num_heads
        self.scaling = float(head_dim) ** -0.5
        self.to_q_row = Conv1d_BN(dim, dim, 1)
        self.to_q_col = Conv1d_BN(dim, dim, 1)
        self.to_k_row = Conv2d_BN(dim, dim, 1)
        self.to_k_col = Conv2d_BN(dim, dim, 1)
        self.to_v = Conv2d_BN(dim, dim, 1)
        # bn_weight_init=0 makes each output projection contribute nothing
        # at the start of training.
        self.proj = torch.nn.Sequential(activation(), Conv1d_BN(
            dim, dim, bn_weight_init=0))
        self.proj_encode_row = torch.nn.Sequential(activation(), Conv1d_BN(
            dim, dim, bn_weight_init=0))
        self.proj_encode_column = torch.nn.Sequential(activation(), Conv1d_BN(
            dim, dim, bn_weight_init=0))
        # Channel gate: global-average-pool the values, 1x1 conv, sigmoid.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.pwconv = Conv2d_BN(head_dim, head_dim, ks=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, q_row, q_col, k_row, k_col, v):
        """Run decoupled attention.

        Shapes (per the author's inline examples — confirm at call site):
            q_row, q_col: (B, Q, C), e.g. [4, 5, 256]
            k_row, k_col, v: (B, H, W, C), e.g. [4, 128, 128, 256]
        Returns:
            (Q, B, C) tensor.
        """
        _, tgt_len, _ = q_row.shape
        B, H, W, C = v.shape  # [4,128,128,256]
        # 1x1 Conv+BN projections; conv layers want channels-first layouts.
        q_row = self.to_q_row(q_row.transpose(1, 2))  # [4,256,5]
        q_col = self.to_q_col(q_col.transpose(1, 2))  # [4,256,5]
        k_row = self.to_k_row(k_row.permute(0, 3, 1, 2))  # [4,256,128,128]
        k_col = self.to_k_col(k_col.permute(0, 3, 1, 2))  # [4,256,128,128]
        v = self.to_v(v.permute(0, 3, 1, 2))  # [4,256,128,128] (B,C,H,W)
        q_row = q_row.permute(2, 0, 1)  # [5,4,256]
        q_col = q_col.permute(2, 0, 1)  # [5,4,256]
        # Collapse one spatial axis per branch: the row branch keeps H
        # (averages over W), the column branch keeps W (averages over H).
        k_row = k_row.mean(-1).permute(2, 0, 1)  # [128,4,256]
        k_col = k_col.mean(-2).permute(2, 0, 1)  # [128,4,256]
        # Split channels into heads and fold heads into the batch dim.
        q_row = q_row.contiguous().view(tgt_len, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,5,32]
        q_col = q_col.contiguous().view(tgt_len, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,5,32]
        k_row = k_row.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,128,32]
        k_col = k_col.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,128,32]
        # BUGFIX: v is channels-first (B, C, H, W) here, so the original
        # `permute(1, 2, 0, 3)` produced (C, H, B, W) and the following
        # reshape silently scrambled channel and spatial axes (it only
        # worked because the element counts match). Move to (H, W, B, C)
        # first, matching the head split used for q/k (and the channels-last
        # version of this line in DEACA_attention_v3).
        v = v.contiguous().permute(2, 3, 0, 1).reshape(H, W, B * self.num_heads, self.head_dim).permute(2, 0, 1, 3)  # [32,128,128,32]
        # Channel gate from globally pooled values; scales both key tensors.
        v_avg = self.sigmoid(self.pwconv(self.avg_pool(v.permute(0, 3, 1, 2)))).squeeze(-1).permute(0, 2, 1)  # [32, 1, 32]
        k_row = k_row * v_avg
        k_col = k_col * v_avg
        v_row = v.mean(2)  # average over W, keep H -> [32,128,32]
        v_col = v.mean(1)  # average over H, keep W -> [32,128,32]
        attn_row = torch.matmul(q_row, k_row.transpose(1, 2)) * self.scaling  # [32,5,128]
        attn_row = attn_row.softmax(dim=-1)
        xx_row = torch.matmul(attn_row, v_row)  # [32,5,32]
        xx_row = self.proj_encode_row(xx_row.permute(0, 2, 1).reshape(B, self.dim, tgt_len))  # [4,256,5]
        attn_col = torch.matmul(q_col, k_col.transpose(1, 2)) * self.scaling
        attn_col = attn_col.softmax(dim=-1)
        xx_col = torch.matmul(attn_col, v_col)  # [32,5,32]
        xx_col = self.proj_encode_column(xx_col.permute(0, 2, 1).reshape(B, self.dim, tgt_len))  # [4,256,5]
        # Fuse the two branches, project, return as (Q, B, C).
        xx = xx_row.add(xx_col)
        xx = self.proj(xx)
        return xx.squeeze(-1).permute(2, 0, 1)
class ChannelAttention(nn.Module):
    """CBAM-style channel attention.

    Average- and max-pooled channel descriptors are passed through a shared
    two-layer 1x1-conv bottleneck, summed, and squashed with a sigmoid to
    produce a per-channel gate in (0, 1).
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _squeeze(self, pooled):
        """Shared bottleneck MLP applied to a pooled (B, C, 1, 1) tensor."""
        return self.fc2(self.relu(self.fc1(pooled)))

    def forward(self, x):
        """Return a (B, C, 1, 1) channel gate for input x of shape (B, C, H, W)."""
        descriptors = (self.avg_pool(x), self.max_pool(x))
        fused = sum(self._squeeze(d) for d in descriptors)
        return self.sigmoid(fused)
class DEACA_attention_v3(torch.nn.Module):
    """Decoupled row/column cross-attention with a packed input projection.

    Variant of DEACA_attention ("avg + max pool" family) that uses a single
    packed weight for all five input projections and plain ``Linear``
    output projections instead of Conv+BN blocks.

    Args:
        embed_dim: total channel width C; must be divisible by num_heads.
        num_heads: number of attention heads.
    """
    # avg + max pool
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim = embed_dim // num_heads
        self.scaling = float(head_dim) ** -0.5
        # Packed projection: weight rows [0:C) -> q_row, [C:2C) -> q_col,
        # [2C:3C) -> k_row, [3C:4C) -> k_col, [4C:5C) -> v.
        self.in_proj_weight = Parameter(torch.empty(5 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(5 * embed_dim))
        self.proj_encode_row = Linear(embed_dim, embed_dim, bias=True)
        self.proj_encode_col = Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = Linear(embed_dim, embed_dim, bias=True)
        # Channel gate: GAP over the value map -> 1x1 conv -> sigmoid.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.conv = torch.nn.Conv2d(head_dim, head_dim, 1)
        self.activate = torch.nn.ReLU()  # kept for checkpoint compat; unused in forward
        self.sigmoid = torch.nn.Sigmoid()
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init the packed projection; zero the projection biases."""
        xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
        constant_(self.out_proj.bias, 0.)

    def _in_proj(self, x, idx):
        """Apply the idx-th (0..4) embed_dim-row slice of the packed projection."""
        start = idx * self.embed_dim
        end = start + self.embed_dim
        w = self.in_proj_weight[start:end, :]
        b = self.in_proj_bias
        if b is not None:
            b = b[start:end]
        return linear(x, w, b)

    def forward(self, query_row, query_col, key_row, key_col, value):
        """Run decoupled attention.

        Shapes (per the author's inline examples — confirm at call site):
            query_row, query_col: (B, Q, C), e.g. [4, 5, 256]
            key_row, key_col, value: (B, H, W, C), e.g. [4, 128, 128, 256]
        Returns:
            (Q, B, C) tensor.
        """
        bsz, tgt_len, embed_dim = query_row.size()
        # Packed input projections (previously five duplicated inline blocks).
        q_row = self._in_proj(query_row, 0)
        q_col = self._in_proj(query_col, 1)
        k_row = self._in_proj(key_row, 2)
        k_col = self._in_proj(key_col, 3)
        v = self._in_proj(value, 4)
        B, H, W, C = v.shape  # [4,128,128,256]
        q_row = q_row.transpose(0, 1)  # [5,4,256]
        q_col = q_col.transpose(0, 1)  # [5,4,256]
        # Collapse one spatial axis per branch: the row branch averages over
        # H (attends over W), the column branch averages over W (over H).
        k_row = k_row.mean(1).transpose(0, 1)  # [128,4,256]
        k_col = k_col.mean(2).transpose(0, 1)  # [128,4,256]
        # Split channels into heads and fold heads into the batch dim.
        q_row = q_row.contiguous().view(tgt_len, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,5,32]
        q_col = q_col.contiguous().view(tgt_len, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,5,32]
        k_row = k_row.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,128,32]
        k_col = k_col.contiguous().view(-1, B * self.num_heads, self.head_dim).transpose(0, 1)  # [32,128,32]
        v = v.contiguous().permute(1, 2, 0, 3).reshape(H, W, B * self.num_heads, self.head_dim).permute(2, 0, 1, 3)  # [32,128,128,32]
        # Channel gate from globally pooled values; scales both key tensors.
        v_avg = self.sigmoid(self.conv(self.avg_pool(v.permute(0, 3, 1, 2)))).squeeze(-1).permute(0, 2, 1)  # [32, 1, 32]
        k_row = k_row * v_avg  # [32,128,32]
        k_col = k_col * v_avg  # [32,128,32]
        v_row = v.mean(1)  # average over H, keep W
        v_col = v.mean(2)  # average over W, keep H
        attn_row = torch.matmul(q_row, k_row.transpose(1, 2)) * self.scaling  # [32,5,128]
        attn_row = attn_row.softmax(dim=-1)
        xx_row = torch.matmul(attn_row, v_row)  # [32,5,32]
        xx_row = self.proj_encode_row(xx_row.permute(0, 2, 1).reshape(B, self.embed_dim, tgt_len).permute(2, 0, 1))
        attn_col = torch.matmul(q_col, k_col.transpose(1, 2)) * self.scaling
        attn_col = attn_col.softmax(dim=-1)
        xx_col = torch.matmul(attn_col, v_col)  # [32,5,32]
        xx_col = self.proj_encode_col(xx_col.permute(0, 2, 1).reshape(B, self.embed_dim, tgt_len).permute(2, 0, 1))
        # Fuse the two branches and apply the output projection.
        xx = xx_row.add(xx_col)
        xx = self.out_proj(xx)
        return xx
# class SEBlock(nn.Module):
# def __init__(self, in_channels, reduction=16):
# super(SEBlock, self).__init__()
# self.avg_pool = nn.AdaptiveAvgPool2d(1)
# self.fc = nn.Sequential(
# nn.Linear(in_channels, in_channels // reduction),
# nn.ReLU(inplace=True),
# nn.Linear(in_channels // reduction, in_channels),
# nn.Sigmoid()
# )
# def forward(self, x):
# b, c, _, _ = x.size()
# y = self.avg_pool(x).view(b, c)
# y = self.fc(y).view(b, c, 1, 1)
# return y
class rcda_rebuild(nn.Module):
    """Row-Column Decoupled Attention (RCDA) with a channel gate on the keys.

    Attention is factorised along the two spatial axes of the value map:
    a column attention (over H) selects rows, then a row attention (over W)
    reads within the selected rows. A sigmoid gate computed from the
    globally pooled raw values rescales both key tensors.

    Args:
        embed_dim: channel width C; must be divisible by num_heads.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Packed projection: weight rows [0:C) -> q_row, [C:2C) -> q_col,
        # [2C:3C) -> k_row, [3C:4C) -> k_col, [4C:5C) -> v.
        self.in_proj_weight = Parameter(torch.empty(5 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(5 * embed_dim))
        self.out_proj = Linear(embed_dim, embed_dim, bias=True)
        # Channel gate: GAP over the value map -> Linear -> sigmoid.
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.lin = torch.nn.Linear(embed_dim, embed_dim)
        self.sigmoid = torch.nn.Sigmoid()
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-init the packed projection; zero the projection biases."""
        xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
        constant_(self.out_proj.bias, 0.)

    def _in_proj(self, x, idx):
        """Apply the idx-th (0..4) embed_dim-row slice of the packed projection."""
        e = self.in_proj_weight.size(1)
        start = idx * e
        w = self.in_proj_weight[start:start + e, :]
        b = self.in_proj_bias
        if b is not None:
            b = b[start:start + e]
        return linear(x, w, b)

    def forward(self, query_row, query_col, key_row, key_col, value):
        """Run factorised row/column attention.

        Shapes:
            query_row, query_col: (B, Q, C)
            key_row, key_col, value: (B, H, W, C)
        Returns:
            attn_output: (Q, B, C)
            attn_weights: (Q, B, H, W) — outer product of the row and
                column attention maps, averaged over heads.
        """
        # Channel gate from the *unprojected* values, broadcast as (1, B, C).
        v_avg = self.sigmoid(self.lin(self.avg_pool(value.permute(0, 3, 1, 2)).squeeze(-1).squeeze(-1))).unsqueeze(-1).permute(2, 0, 1)
        bsz, tgt_len, embed_dim = query_row.size()
        src_len_row = key_row.size()[2]  # W
        src_len_col = key_col.size()[1]  # H
        num_heads = self.num_heads
        head_dim = embed_dim // num_heads
        scaling = float(head_dim) ** -0.5
        # Packed input projections (previously five duplicated inline blocks).
        q_row = self._in_proj(query_row, 0)
        q_col = self._in_proj(query_col, 1)
        k_row = self._in_proj(key_row, 2)
        k_col = self._in_proj(key_col, 3)
        v = self._in_proj(value, 4)
        q_row = q_row.transpose(0, 1)
        q_col = q_col.transpose(0, 1)
        # Collapse one spatial axis per branch and apply the channel gate:
        # the row branch attends over W, the column branch over H.
        k_row = k_row.mean(1).transpose(0, 1) * v_avg  # (W, B, C)
        k_col = k_col.mean(2).transpose(0, 1) * v_avg  # (H, B, C)
        q_row = q_row * scaling
        q_col = q_col * scaling
        # Split channels into heads and fold heads into the batch dim.
        q_row = q_row.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        q_col = q_col.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
        k_row = k_row.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
        k_col = k_col.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
        v = v.contiguous().permute(1, 2, 0, 3).reshape(src_len_col, src_len_row, bsz * num_heads, head_dim).permute(2, 0, 1, 3)
        attn_output_weights_row = torch.bmm(q_row, k_row.transpose(1, 2))
        attn_output_weights_col = torch.bmm(q_col, k_col.transpose(1, 2))
        attn_output_weights_col = softmax(attn_output_weights_col, dim=-1)
        attn_output_weights_row = softmax(attn_output_weights_row, dim=-1)
        # Factorised read-out: gather rows with the column attention, then
        # gather along W with the row attention.
        b_ein, q_ein, h_ein = attn_output_weights_col.shape
        b_ein, h_ein, w_ein, c_ein = v.shape
        attn_output_col = torch.matmul(attn_output_weights_col, v.reshape(b_ein, h_ein, w_ein * c_ein)).reshape(b_ein, q_ein, w_ein, c_ein)
        attn_output = torch.matmul(attn_output_weights_row[:, :, None, :], attn_output_col).squeeze(-2).permute(1, 0, 2).reshape(tgt_len, bsz, embed_dim)
        ### the following code base on einsum get the same results
        # attn_output_col = torch.einsum("bqh,bhwc->bqwc", attn_output_weights_col, v)
        # attn_output = torch.einsum("bqw,bqwc->qbc", attn_output_weights_row, attn_output_col).reshape(tgt_len, bsz,embed_dim)
        attn_output = linear(attn_output, self.out_proj.weight, self.out_proj.bias)
        # Joint (row x column) map, head-averaged, for visualisation/aux use.
        return attn_output, torch.einsum("bqw,bqh->qbhw", attn_output_weights_row, attn_output_weights_col).reshape(tgt_len, bsz, num_heads, src_len_col, src_len_row).mean(2)
if __name__ == '__main__':
    # Smoke test: run DEACA_attention on random inputs and print the
    # resulting output shape.
    embed_dim, heads = 256, 8
    attn = DEACA_attention(embed_dim, num_heads=heads, activation=nn.ReLU)
    # queries: (B, Q, C); keys/values: (B, H, W, C)
    queries_row = torch.randn((4, 5, 256))
    queries_col = torch.randn((4, 5, 256))
    keys_row = torch.randn((4, 128, 128, 256))
    keys_col = torch.randn((4, 128, 128, 256))
    values = torch.randn((4, 128, 128, 256))
    result = attn(queries_row, queries_col, keys_row, keys_col, values)
    print(result.shape)