import math
import warnings

import torch
import torch.nn as nn

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

class cross_attention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.):
        super(cross_attention, self).__init__()
        if dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (dim, num_heads)
            )
        self.num_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        # Depthwise-separable convolutions produce the query/key/value maps.
        self.query = Depth_conv(in_ch=dim, out_ch=dim)
        self.key = Depth_conv(in_ch=dim, out_ch=dim)
        self.value = Depth_conv(in_ch=dim, out_ch=dim)
        self.dropout = nn.Dropout(dropout)

    def transpose_for_scores(self, x):
        # (B, C, H, W) -> (B, H, C, W). The heads are not split into a separate
        # dimension here; attention is computed over the channel axis, independently
        # for each row of the feature map.
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, ctx):
        # Queries come from hidden_states; keys and values come from the context tensor.
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(ctx)
        mixed_value_layer = self.value(ctx)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot-product attention: (B, H, C, W) x (B, H, W, C) -> (B, H, C, C).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.dropout(attention_probs)

        # (B, H, C, C) x (B, H, C, W) -> (B, H, C, W), then back to (B, C, H, W).
        ctx_layer = torch.matmul(attention_probs, value_layer)
        ctx_layer = ctx_layer.permute(0, 2, 1, 3).contiguous()

        return ctx_layer

class Depth_conv(nn.Module):
    """Depthwise-separable convolution: a 3x3 depthwise conv followed by a 1x1 pointwise conv."""

    def __init__(self, in_ch, out_ch):
        super(Depth_conv, self).__init__()
        self.depth_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=1,
            groups=in_ch
        )
        self.point_conv = nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=0,
            groups=1
        )

    def forward(self, x):
        out = self.depth_conv(x)
        out = self.point_conv(out)
        return out

class Dilated_Resblock(nn.Module):
    """Residual block with a 1-2-3-2-1 dilation pyramid. The intermediate convs keep
    out_channels; the last conv maps back to in_channels so the residual add is valid."""

    def __init__(self, in_channels, out_channels):
        super(Dilated_Resblock, self).__init__()
        sequence = [
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=1, dilation=(1, 1)),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=2, dilation=(2, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=3, dilation=(3, 3)),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=2, dilation=(2, 2)),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, in_channels, kernel_size=(3, 3), stride=(1, 1),
                      padding=1, dilation=(1, 1))
        ]
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        # Residual connection around the dilated convolution stack.
        return self.model(x) + x

class HFRM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(HFRM, self).__init__()
        self.conv_head = Depth_conv(in_channels, out_channels)
        self.dilated_block_LH = Dilated_Resblock(out_channels, out_channels)
        self.dilated_block_HL = Dilated_Resblock(out_channels, out_channels)
        self.cross_attention0 = cross_attention(out_channels, num_heads=8)
        self.dilated_block_HH = Dilated_Resblock(out_channels, out_channels)
        self.conv_HH = nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, stride=1, padding=1)
        self.cross_attention1 = cross_attention(out_channels, num_heads=8)
        self.conv_tail = Depth_conv(out_channels, in_channels)

    def forward(self, x):
        # The three high-frequency subbands (HL, LH, HH) are stacked along the batch
        # dimension, so the batch size is expected to be divisible by 3.
        b, _, _, _ = x.shape
        residual = x
        x = self.conv_head(x)
        x_HL, x_LH, x_HH = x[:b // 3, ...], x[b // 3:2 * b // 3, ...], x[2 * b // 3:, ...]

        # Cross-attention: LH and HL each attend to the HH subband.
        x_HH_LH = self.cross_attention0(x_LH, x_HH)
        x_HH_HL = self.cross_attention1(x_HL, x_HH)

        x_HL = self.dilated_block_HL(x_HL)
        x_LH = self.dilated_block_LH(x_LH)
        x_HH = self.dilated_block_HH(self.conv_HH(torch.cat((x_HH_LH, x_HH_HL), dim=1)))

        out = self.conv_tail(torch.cat((x_HL, x_LH, x_HH), dim=0))

        return out + residual
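

# Minimal usage sketch (illustrative, not part of the original module): HFRM takes
# the three high-frequency subbands stacked along the batch dimension, so the batch
# size must be a multiple of 3. The channel count and spatial size below are
# assumptions chosen only to exercise the forward pass.
if __name__ == "__main__":
    subbands = torch.randn(3, 64, 32, 32)  # (HL, LH, HH) stacked on dim 0
    model = HFRM(in_channels=64, out_channels=64)
    with torch.no_grad():
        out = model(subbands)
    print(out.shape)  # expected: torch.Size([3, 64, 32, 32])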