|
|
import math
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
import utils
|
|
|
import pickle
|
|
|
|
|
|
DEVICE = torch.cuda.is_available() and torch.device('cuda') or torch.device('cpu')
|
|
|
|
|
|
class GraphCNNLayer(nn.Module):
|
|
|
def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True):
|
|
|
super(GraphCNNLayer, self).__init__()
|
|
|
self.n_feats = n_feats
|
|
|
self.adj_chans = adj_chans
|
|
|
self.n_filters = n_filters
|
|
|
self.has_bias = bias
|
|
|
|
|
|
|
|
|
self.weight_e = nn.Parameter(torch.FloatTensor(adj_chans*n_feats, n_filters))
|
|
|
|
|
|
self.weight_i = nn.Parameter(torch.FloatTensor(n_feats, self.n_filters))
|
|
|
|
|
|
if bias:
|
|
|
self.bias = nn.Parameter(torch.FloatTensor(n_filters))
|
|
|
else:
|
|
|
self.register_parameter('bias', None)
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
nn.init.xavier_uniform_(self.weight_e)
|
|
|
nn.init.xavier_uniform_(self.weight_i)
|
|
|
|
|
|
if self.bias is not None:
|
|
|
self.bias.data.fill_(0.01)
|
|
|
|
|
|
def forward(self, V, A):
|
|
|
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
|
|
|
b, N, C = V.shape
|
|
|
b, N, L, _ = A.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
A_reshape = A.view(-1, N*L, N)
|
|
|
|
|
|
n = torch.bmm(A_reshape, V)
|
|
|
|
|
|
n = n.view(-1, N, L*self.n_feats)
|
|
|
|
|
|
|
|
|
|
|
|
output = torch.matmul(n, self.weight_e) + torch.matmul(V, self.weight_i)
|
|
|
|
|
|
if self.has_bias:
|
|
|
output += self.bias
|
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
|
|
|
|
|
|
class GraphResidualCNNLayer(nn.Module):
|
|
|
def __init__(self, n_feats, adj_chans=4, bias=True):
|
|
|
super(GraphResidualCNNLayer, self).__init__()
|
|
|
self.n_feats = n_feats
|
|
|
self.adj_chans = adj_chans
|
|
|
self.has_bias = bias
|
|
|
|
|
|
|
|
|
self.weight_layers = nn.ModuleList([nn.Linear(n_feats, n_feats) for _ in range(adj_chans)])
|
|
|
|
|
|
if bias:
|
|
|
self.bias = nn.Parameter(torch.FloatTensor(n_feats))
|
|
|
else:
|
|
|
self.register_parameter('bias', None)
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
if self.bias is not None:
|
|
|
self.bias.data.fill_(0.01)
|
|
|
|
|
|
def forward(self, V, A):
|
|
|
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
|
|
|
b, N, C = V.shape
|
|
|
b, N, L, _ = A.shape
|
|
|
|
|
|
for i in range(self.adj_chans):
|
|
|
|
|
|
hs = F.relu(self.weight_layers[i](V))
|
|
|
|
|
|
a = A[:, :, i, :]
|
|
|
a = a.view(-1, N, N)
|
|
|
|
|
|
V = V + torch.bmm(a, hs)
|
|
|
|
|
|
if self.has_bias:
|
|
|
V += self.bias
|
|
|
|
|
|
|
|
|
return V
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_feats}]'
|
|
|
|
|
|
class GraphAttentionLayer(nn.Module):
|
|
|
def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True, dropout=0., alpha=0.2):
|
|
|
super(GraphAttentionLayer, self).__init__()
|
|
|
self.n_feats = n_feats
|
|
|
self.adj_chans = adj_chans
|
|
|
self.n_filters = n_filters
|
|
|
self.has_bias = bias
|
|
|
self.dropout = dropout
|
|
|
self.alpha = alpha
|
|
|
|
|
|
|
|
|
self.weight_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_feats, n_filters)) for _ in range(adj_chans)])
|
|
|
self.a1_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
|
|
|
self.a2_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
|
|
|
|
|
|
if bias:
|
|
|
self.bias = nn.Parameter(torch.FloatTensor(n_filters))
|
|
|
else:
|
|
|
self.register_parameter('bias', None)
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
for w in self.weight_list:
|
|
|
nn.init.xavier_uniform_(w)
|
|
|
for w in self.a1_list:
|
|
|
nn.init.xavier_uniform_(w)
|
|
|
for w in self.a2_list:
|
|
|
nn.init.xavier_uniform_(w)
|
|
|
if self.bias is not None:
|
|
|
self.bias.data.fill_(0.01)
|
|
|
|
|
|
def forward(self, V, A):
|
|
|
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
|
|
|
b, N, C = V.shape
|
|
|
b, N, L, _ = A.shape
|
|
|
|
|
|
output = None
|
|
|
|
|
|
|
|
|
for i in range(self.adj_chans):
|
|
|
|
|
|
adj = A[:, :, i, :].view(-1, N, N)
|
|
|
|
|
|
|
|
|
h = torch.matmul(V, self.weight_list[i])
|
|
|
|
|
|
f_1 = torch.matmul(h, self.a1_list[i])
|
|
|
|
|
|
f_2 = torch.matmul(h, self.a2_list[i])
|
|
|
|
|
|
|
|
|
e = F.leaky_relu(f_1 + f_2.transpose(1, 2), self.alpha)
|
|
|
|
|
|
zero_vec = -9e15 * torch.ones_like(e)
|
|
|
|
|
|
att = torch.where(adj > 0, e, zero_vec)
|
|
|
att = F.softmax(att, dim=1)
|
|
|
att = F.dropout(att, self.dropout, training=self.training)
|
|
|
|
|
|
if output is None:
|
|
|
output = torch.matmul(att, h)
|
|
|
else:
|
|
|
output += torch.matmul(att, h)
|
|
|
|
|
|
if self.has_bias:
|
|
|
output += self.bias
|
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias},dropout={self.dropout},alpha={self.alpha}) -> [b, N, {self.n_filters}]'
|
|
|
|
|
|
class GraphNodeCatGlobalFeatures(nn.Module):
|
|
|
def __init__(self, global_feats, out_feats, mols=1, bias=True):
|
|
|
super(GraphNodeCatGlobalFeatures, self).__init__()
|
|
|
self.global_feats = global_feats
|
|
|
self.out_feats = out_feats
|
|
|
self.mols = mols
|
|
|
self.has_bias = bias
|
|
|
|
|
|
self.weights = nn.ParameterList([nn.Parameter(torch.FloatTensor(int(global_feats/mols), out_feats)) for _ in range(mols)])
|
|
|
|
|
|
self.biass = []
|
|
|
if bias:
|
|
|
self.biass = nn.ParameterList([nn.Parameter(torch.FloatTensor(out_feats)) for _ in range(mols)])
|
|
|
else:
|
|
|
self.register_parameter('bias', None)
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
for weight in self.weights:
|
|
|
nn.init.xavier_uniform_(weight)
|
|
|
for bias in self.biass:
|
|
|
bias.data.fill_(0.01)
|
|
|
|
|
|
def forward(self, V, global_state, graph_size, subgraph_size=None):
|
|
|
|
|
|
b, N, Ov = V.shape
|
|
|
O = self.out_feats
|
|
|
if self.mols == 1:
|
|
|
subgraph_size = graph_size.view(-1, 1)
|
|
|
global_state = torch.mm(global_state, self.weights[0])
|
|
|
else:
|
|
|
|
|
|
global_state_view = global_state.view(b*self.mols, -1)
|
|
|
|
|
|
|
|
|
idxmols = []
|
|
|
for i in range(self.mols):
|
|
|
idxmols.append(torch.IntTensor(list(range(i, b*self.mols, self.mols))).to(self.weights[0].device))
|
|
|
|
|
|
global_states = []
|
|
|
for i, idx in enumerate(idxmols):
|
|
|
|
|
|
gs = global_state_view.index_select(dim=0, index=idx)
|
|
|
|
|
|
gs = torch.mm(gs, self.weights[i])
|
|
|
|
|
|
if self.has_bias:
|
|
|
gs += self.biass[i]
|
|
|
|
|
|
global_states.append(F.relu(gs))
|
|
|
|
|
|
|
|
|
|
|
|
global_state = torch.cat(global_states, dim=1)
|
|
|
|
|
|
|
|
|
global_state_new = torch.cat([global_state, torch.zeros(b, O).to(self.weights[0].device)], dim=-1)
|
|
|
|
|
|
global_state_new = global_state_new.view(-1, O)
|
|
|
|
|
|
repeats = []
|
|
|
for sz in subgraph_size:
|
|
|
repeats.extend(sz.tolist() + [N-sz.sum()])
|
|
|
repeats = torch.tensor(repeats).to(self.weights[0].device)
|
|
|
|
|
|
|
|
|
global_state_new = global_state_new.repeat_interleave(repeats, dim=0)
|
|
|
|
|
|
|
|
|
output = torch.cat([V.contiguous().view(-1, Ov), global_state_new], dim=1)
|
|
|
|
|
|
|
|
|
return output.view(-1, N, Ov+O), global_state
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(global_feats={self.global_feats},out_feats={self.out_feats},bias={self.has_bias}) -> [b, N, {self.global_feats+self.out_feats}], [b, out_feats]'
|
|
|
|
|
|
class MultiHeadGlobalAttention(nn.Module):
|
|
|
'''Input [b, N, C] -> output [b, n_head*C] if concat or else [b, n_head]'''
|
|
|
def __init__(self, n_feats, n_head=5, alpha=0.2, concat=True, bias=True):
|
|
|
super(MultiHeadGlobalAttention, self).__init__()
|
|
|
|
|
|
self.n_feats = n_feats
|
|
|
self.n_head = n_head
|
|
|
self.alpha = alpha
|
|
|
self.concat = concat
|
|
|
self.has_bias = bias
|
|
|
|
|
|
self.weight = nn.Parameter(torch.FloatTensor(n_feats, n_head*n_feats))
|
|
|
self.tune_weight = nn.Parameter(torch.FloatTensor(1, n_head, n_feats))
|
|
|
|
|
|
if bias:
|
|
|
self.bias = nn.Parameter(torch.FloatTensor(n_head*n_feats))
|
|
|
else:
|
|
|
self.register_parameter('bias', None)
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
nn.init.xavier_uniform_(self.weight)
|
|
|
nn.init.xavier_uniform_(self.tune_weight)
|
|
|
if self.bias is not None:
|
|
|
self.bias.data.fill_(0.01)
|
|
|
|
|
|
def forward(self, V, graph_size):
|
|
|
|
|
|
|
|
|
if V.shape[0] == 1:
|
|
|
Vg = torch.squeeze(V)
|
|
|
graph_size = [graph_size]
|
|
|
else:
|
|
|
Vg = torch.cat([torch.split(v.view(-1, v.shape[-1]), graph_size[i])[0] for i,v in enumerate(torch.split(V, 1))], dim=0)
|
|
|
|
|
|
Vg = torch.matmul(Vg, self.weight)
|
|
|
if self.has_bias:
|
|
|
Vg += self.bias
|
|
|
Vg = Vg.view(-1, self.n_head, self.n_feats)
|
|
|
|
|
|
alpha = torch.mul(self.tune_weight, Vg)
|
|
|
alpha = torch.sum(alpha, dim=-1)
|
|
|
alpha = F.leaky_relu(alpha, self.alpha)
|
|
|
alpha = utils.segment_softmax(alpha, graph_size)
|
|
|
|
|
|
|
|
|
alpha = alpha.view(-1, self.n_head, 1)
|
|
|
V = torch.mul(Vg, alpha)
|
|
|
|
|
|
if self.concat:
|
|
|
V = utils.segment_sum(V, graph_size)
|
|
|
V = V.view(-1, self.n_head*self.n_feats)
|
|
|
else:
|
|
|
V = torch.mean(V, dim=1)
|
|
|
V = utils.segment_sum(V, graph_size)
|
|
|
|
|
|
return V
|
|
|
|
|
|
def __repr__(self):
|
|
|
if self.concat:
|
|
|
outc = self.n_head*self.n_feats
|
|
|
else:
|
|
|
outc = self.n_head
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_head={self.n_head},alpha={self.alpha},concat={self.concat},bias={self.has_bias}) -> [b, {outc}]'
|
|
|
|
|
|
class GraphEmbedPoolingLayer(nn.Module):
|
|
|
def __init__(self, n_feats, n_filters=1, mask=None, bias=True):
|
|
|
super(GraphEmbedPoolingLayer, self).__init__()
|
|
|
self.n_feats = n_feats
|
|
|
self.n_filters = n_filters
|
|
|
self.mask = mask
|
|
|
self.has_bias = bias
|
|
|
|
|
|
self.emb = nn.Linear(n_feats, n_filters, bias=bias)
|
|
|
|
|
|
def forward(self, V, A):
|
|
|
|
|
|
factors = self.emb(V)
|
|
|
|
|
|
if self.mask is not None:
|
|
|
factors = torch.mul(factors, self.mask)
|
|
|
|
|
|
factors = F.softmax(factors, dim=1)
|
|
|
|
|
|
result = torch.matmul(factors.transpose(1, 2).contiguous(), V)
|
|
|
|
|
|
if self.n_filters == 1:
|
|
|
return result.view(-1, self.n_feats), A
|
|
|
|
|
|
result_A = A.view(A.shape[0], -1, A.shape[-1])
|
|
|
result_A = torch.matmul(result_A, factors)
|
|
|
result_A = result_A.view(A.shape[0], A.shape[-1], -1)
|
|
|
result_A = torch.matmul(factors.transpose(1, 2).contiguous(), result_A)
|
|
|
result_A = result_A.view(A.shape[0], self.n_filters, A.shape[2], self.n_filters)
|
|
|
|
|
|
return result, result_A
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mask={self.mask},bias={self.has_bias}) -> [b, {self.n_filters}, {self.n_feats}], [b, {self.n_filters}, L, {self.n_filters}]'
|
|
|
|
|
|
class GConvBlockWithGF(nn.Module):
|
|
|
def __init__( self,
|
|
|
n_feats,
|
|
|
n_filters,
|
|
|
global_feats,
|
|
|
global_out_feats,
|
|
|
mols=1,
|
|
|
adj_chans=4,
|
|
|
bias=True,
|
|
|
usegat=False):
|
|
|
|
|
|
super(GConvBlockWithGF, self).__init__()
|
|
|
|
|
|
self.n_feats = n_feats
|
|
|
self.n_filters = n_filters
|
|
|
self.global_out_feats = global_out_feats
|
|
|
self.global_feats = global_feats
|
|
|
self.mols = mols
|
|
|
self.adj_chans = adj_chans
|
|
|
self.has_bias = bias
|
|
|
self.usegat = usegat
|
|
|
|
|
|
self.broadcast_global_state = GraphNodeCatGlobalFeatures(global_feats, global_out_feats, mols, bias)
|
|
|
if usegat:
|
|
|
self.graph_conv = GraphAttentionLayer(n_feats+global_out_feats, adj_chans, n_filters)
|
|
|
else:
|
|
|
self.graph_conv = GraphCNNLayer(n_feats+global_out_feats, adj_chans, n_filters, bias)
|
|
|
|
|
|
self.bn_global = nn.BatchNorm1d(global_out_feats*mols)
|
|
|
self.bn_graph = nn.BatchNorm1d(n_filters)
|
|
|
|
|
|
def forward(self, V, A, global_state, graph_size, subgraph_size):
|
|
|
|
|
|
|
|
|
V, global_state = self.broadcast_global_state(V, global_state, graph_size, subgraph_size)
|
|
|
|
|
|
|
|
|
|
|
|
V = self.graph_conv(V, A)
|
|
|
V = self.bn_graph(V.transpose(1, 2).contiguous())
|
|
|
V = F.relu(V.transpose(1, 2))
|
|
|
|
|
|
global_state = F.relu(self.bn_global(global_state))
|
|
|
|
|
|
return V, global_state
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},global_feats={self.global_feats},global_out_feats={self.global_out_feats},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias},usegat={self.usegat}) -> [b, N, {self.n_filters}], [b, {self.global_out_feats*self.mols}]'
|
|
|
|
|
|
class GConvBlockNoGF(nn.Module):
|
|
|
def __init__( self,
|
|
|
n_feats,
|
|
|
n_filters,
|
|
|
mols=1,
|
|
|
adj_chans=4,
|
|
|
bias=True):
|
|
|
|
|
|
super(GConvBlockNoGF, self).__init__()
|
|
|
|
|
|
self.n_feats = n_feats
|
|
|
self.n_filters = n_filters
|
|
|
self.mols = mols
|
|
|
self.adj_chans = adj_chans
|
|
|
self.has_bias = bias
|
|
|
|
|
|
|
|
|
self.graph_conv = GraphCNNLayer(n_feats, adj_chans, n_filters, bias)
|
|
|
|
|
|
|
|
|
self.bn_graph = nn.BatchNorm1d(n_filters)
|
|
|
|
|
|
def forward(self, V, A):
|
|
|
|
|
|
|
|
|
V = self.graph_conv(V, A)
|
|
|
V = self.bn_graph(V.transpose(1, 2).contiguous())
|
|
|
V = F.relu(V.transpose(1, 2))
|
|
|
|
|
|
return V
|
|
|
|
|
|
def __repr__(self):
|
|
|
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_filters}]' |