import torch
import torch.nn as nn
import torch.nn.functional as F

from .graphtrans_module import Normalize, PositionWiseFeedForward, NeighborAttention


class Local_Module(nn.Module):
    '''Local message-passing layer: aggregates messages from the K nearest neighbors of each node.'''

    def __init__(self, num_hidden, num_in, is_attention, dropout=0.1, scale=30):
        super(Local_Module, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.is_attention = is_attention
        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
        # Message MLP: maps concatenated node/neighbor features to a message vector
        self.W = nn.Sequential(
            nn.Linear(num_hidden + num_in, num_hidden),
            nn.LeakyReLU(inplace=True),
            nn.Linear(num_hidden, num_hidden),
            nn.LeakyReLU(inplace=True),
            nn.Linear(num_hidden, num_hidden),
        )
        # Attention projection vector; torch.empty leaves values undefined, so initialize explicitly
        self.A = nn.Parameter(torch.empty(size=(num_hidden + num_in, 1)))
        nn.init.xavier_uniform_(self.A.data)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        '''
        h_V:         [batch, num_nodes, num_hidden]   node features
        h_E:         [batch, num_nodes, K, num_in]    neighbor/edge features
        mask_V:      [batch, num_nodes]               1 for real nodes, 0 for padding
        mask_attend: [batch, num_nodes, K]            1 for valid neighbors
        '''
        # Concatenate h_V_i onto each h_E_ij
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_E], -1)          # [batch, num_nodes, K, num_hidden + num_in]

        # Compute messages
        h_message = self.W(h_EV)                         # [batch, num_nodes, K, num_hidden]

        # GAT-style attention over the K neighbors (applied when is_attention == 0)
        if self.is_attention == 0:
            e = torch.sigmoid(F.leaky_relu(torch.matmul(h_EV, self.A))).squeeze(-1).exp()  # [batch, num_nodes, K]
            e = e / e.sum(-1).unsqueeze(-1)              # normalize weights over the K neighbors
            h_message = h_message * e.unsqueeze(-1)      # [batch, num_nodes, K, num_hidden]

        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message

        # Message aggregation, residual connection, and position-wise feed-forward block
        dh = torch.sum(h_message, -2) / self.scale       # [batch, num_nodes, num_hidden]
        h_V = self.norm[0](h_V + self.dropout(dh))
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))

        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V


class Global_Module(nn.Module):
    '''Global attention layer: multi-head neighbor attention followed by a position-wise feed-forward block.'''

    def __init__(self, num_hidden, num_in, num_heads=4, dropout=0.1):
        super(Global_Module, self).__init__()
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.ModuleList([Normalize(num_hidden) for _ in range(2)])
        self.attention = NeighborAttention(num_hidden, num_in, num_heads)
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        dh = self.attention(h_V, h_E, mask_attend)
        h_V = self.norm[0](h_V + self.dropout(dh))
        dh = self.dense(h_V)
        h_V = self.norm[1](h_V + self.dropout(dh))
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
        return h_V
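

# ---------------------------------------------------------------------------
# Minimal smoke test (sketch, not part of the original module). The dimensions
# below are assumptions taken from the shape annotations above (num_hidden=128,
# num_in=256, K=30 neighbors). Because of the relative import at the top, run
# this as a module inside its package, e.g. `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, num_nodes, K = 2, 16, 30
    num_hidden, num_in = 128, 256

    h_V = torch.randn(batch, num_nodes, num_hidden)       # node features
    h_E = torch.randn(batch, num_nodes, K, num_in)        # neighbor/edge features
    mask_V = torch.ones(batch, num_nodes)                 # 1 = real node, 0 = padding
    mask_attend = torch.ones(batch, num_nodes, K)         # 1 = valid neighbor

    local_layer = Local_Module(num_hidden, num_in, is_attention=0)
    global_layer = Global_Module(num_hidden, num_in, num_heads=4)

    out_local = local_layer(h_V, h_E, mask_V, mask_attend)
    out_global = global_layer(h_V, h_E, mask_V, mask_attend)
    print(out_local.shape, out_global.shape)              # both: [batch, num_nodes, num_hidden]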