import torch import torch.nn as nn import math class MAUCell(nn.Module): def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, tau, cell_mode): super(MAUCell, self).__init__() self.num_hidden = num_hidden # self.padding = (filter_size[0] // 2, filter_size[1] // 2) self.padding = filter_size // 2 self.cell_mode = cell_mode self.d = num_hidden * height * width self.tau = tau self.states = ['residual', 'normal'] if not self.cell_mode in self.states: raise AssertionError self.conv_t = nn.Sequential( nn.Conv2d(in_channel, 3 * num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), nn.LayerNorm([3 * num_hidden, height, width]) ) self.conv_t_next = nn.Sequential( nn.Conv2d(in_channel, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), nn.LayerNorm([num_hidden, height, width]) ) self.conv_s = nn.Sequential( nn.Conv2d(num_hidden, 3 * num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), nn.LayerNorm([3 * num_hidden, height, width]) ) self.conv_s_next = nn.Sequential( nn.Conv2d(num_hidden, num_hidden, kernel_size=filter_size, stride=stride, padding=self.padding), nn.LayerNorm([num_hidden, height, width]) ) self.softmax = nn.Softmax(dim=0) def forward(self, T_t, S_t, t_att, s_att): s_next = self.conv_s_next(S_t) t_next = self.conv_t_next(T_t) weights_list = [] for i in range(self.tau): weights_list.append((s_att[i] * s_next).sum(dim=(1, 2, 3)) / math.sqrt(self.d)) weights_list = torch.stack(weights_list, dim=0) weights_list = torch.reshape(weights_list, (*weights_list.shape, 1, 1, 1)) weights_list = self.softmax(weights_list) T_trend = t_att * weights_list T_trend = T_trend.sum(dim=0) t_att_gate = torch.sigmoid(t_next) T_fusion = T_t * t_att_gate + (1 - t_att_gate) * T_trend T_concat = self.conv_t(T_fusion) S_concat = self.conv_s(S_t) t_g, t_t, t_s = torch.split(T_concat, self.num_hidden, dim=1) s_g, s_t, s_s = torch.split(S_concat, self.num_hidden, dim=1) T_gate = torch.sigmoid(t_g) S_gate = torch.sigmoid(s_g) T_new = T_gate * t_t + (1 - T_gate) * s_t S_new = S_gate * s_s + (1 - S_gate) * t_s if self.cell_mode == 'residual': S_new = S_new + S_t return T_new, S_new