# CMSSP/code/GNN/layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class GraphCNNLayer(nn.Module):
def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True):
super(GraphCNNLayer, self).__init__()
self.n_feats = n_feats
self.adj_chans = adj_chans
self.n_filters = n_filters
self.has_bias = bias
        # [L*C, F], C = n_feats, L = adj_chans, F = n_filters; applied to the
        # neighbour features gathered through all adjacency channels
        self.weight_e = nn.Parameter(torch.FloatTensor(adj_chans*n_feats, n_filters))
        # [C, F]; applied to the identity (self-loop) term I*V_in*W_0
        self.weight_i = nn.Parameter(torch.FloatTensor(n_feats, n_filters))
if bias:
self.bias = nn.Parameter(torch.FloatTensor(n_filters))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight_e)
nn.init.xavier_uniform_(self.weight_i)
if self.bias is not None:
self.bias.data.fill_(0.01)
def forward(self, V, A):
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
b, N, C = V.shape
b, N, L, _ = A.shape
        # formula: V_out = I*V_in*W_0 + GConv(V_in, F) + b; since I*V_in = V_in, the first term is V_in*W_0
# A [b, N, L, N] -> [b, N*L, N]
A_reshape = A.view(-1, N*L, N)
# [b, N*L, N] * [b, N, C] -> [b, N*L, C]
n = torch.bmm(A_reshape, V)
# [b, N*L, C] -> [b, N, L*C]
n = n.view(-1, N, L*self.n_feats)
        # n [b, N, L*C], W_e [L*C, F], V [b, N, C], W_i [C, F]
# -> [b, N, F] + [b, N, F] + b
output = torch.matmul(n, self.weight_e) + torch.matmul(V, self.weight_i)
if self.has_bias:
output += self.bias
# output: [b, N, F]
return output
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
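
# Usage sketch (illustrative only; the concrete sizes below are assumptions,
# dimension names follow the forward() docstring):
#   layer = GraphCNNLayer(n_feats=32, adj_chans=4, n_filters=64)
#   V = torch.randn(2, 10, 32)         # [b, N, C] node features
#   A = torch.randn(2, 10, 4, 10)      # [b, N, L, N] multi-channel adjacency
#   out = layer(V, A)                  # -> [2, 10, 64]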
class GraphResidualCNNLayer(nn.Module):
def __init__(self, n_feats, adj_chans=4, bias=True):
super(GraphResidualCNNLayer, self).__init__()
self.n_feats = n_feats
self.adj_chans = adj_chans
self.has_bias = bias
        # one Linear [C, C] per adjacency channel, C = n_feats, L = adj_chans
        self.weight_layers = nn.ModuleList([nn.Linear(n_feats, n_feats) for _ in range(adj_chans)])
if bias:
self.bias = nn.Parameter(torch.FloatTensor(n_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
if self.bias is not None:
self.bias.data.fill_(0.01)
def forward(self, V, A):
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
b, N, C = V.shape
b, N, L, _ = A.shape
for i in range(self.adj_chans):
# [b, N, C] -> [b, N, C]
hs = F.relu(self.weight_layers[i](V))
# [b, N, N]
a = A[:, :, i, :]
a = a.view(-1, N, N)
# [b, N, N] * [b, N, C] -> [b, N, C]
V = V + torch.bmm(a, hs)
if self.has_bias:
V += self.bias
# output: [b, N, C]
return V
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_feats}]'
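
# Usage sketch (illustrative; sizes are assumptions). The residual layer keeps
# the channel count, so the output width equals n_feats:
#   layer = GraphResidualCNNLayer(n_feats=32, adj_chans=4)
#   out = layer(torch.randn(2, 10, 32), torch.randn(2, 10, 4, 10))   # -> [2, 10, 32]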
class GraphAttentionLayer(nn.Module):
def __init__(self, n_feats, adj_chans=4, n_filters=64, bias=True, dropout=0., alpha=0.2):
super(GraphAttentionLayer, self).__init__()
self.n_feats = n_feats
self.adj_chans = adj_chans
self.n_filters = n_filters
self.has_bias = bias
self.dropout = dropout
self.alpha = alpha
        # one [C, F] weight per adjacency channel, C = n_feats, F = n_filters
self.weight_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_feats, n_filters)) for _ in range(adj_chans)])
self.a1_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
self.a2_list = nn.ParameterList([nn.Parameter(torch.FloatTensor(n_filters, 1)) for _ in range(adj_chans)])
if bias:
self.bias = nn.Parameter(torch.FloatTensor(n_filters))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
for w in self.weight_list:
nn.init.xavier_uniform_(w)
for w in self.a1_list:
nn.init.xavier_uniform_(w)
for w in self.a2_list:
nn.init.xavier_uniform_(w)
if self.bias is not None:
self.bias.data.fill_(0.01)
def forward(self, V, A):
'''V node features: [b, N, C], A adjs: [b, N, L, N], L = adj_chans'''
b, N, C = V.shape
b, N, L, _ = A.shape
output = None
        # GAT-style attention per adjacency channel; the per-channel outputs are summed
for i in range(self.adj_chans):
# [b, N, 1, N] -> [b, N, N]
adj = A[:, :, i, :].view(-1, N, N)
# [b, N, C] * [C, F] -> [b, N, F]
h = torch.matmul(V, self.weight_list[i])
# [b, N, F] * [F, 1] -> [b, N, 1]
f_1 = torch.matmul(h, self.a1_list[i])
# [b, N, F] * [F, 1] -> [b, N, 1]
f_2 = torch.matmul(h, self.a2_list[i])
# leaky_relu([b, N, 1] + [b, 1, N]) -> [b, N, N]
e = F.leaky_relu(f_1 + f_2.transpose(1, 2), self.alpha)
zero_vec = -9e15 * torch.ones_like(e)
# [b, N, N]
            att = torch.where(adj > 0, e, zero_vec)
            # normalise over the neighbour axis (last dim); torch.matmul(att, h) sums over it
            att = F.softmax(att, dim=-1)
            att = F.dropout(att, self.dropout, training=self.training)
# [b, N, N] * [b, N, F] -> [b, N, F]
if output is None:
output = torch.matmul(att, h)
else:
output += torch.matmul(att, h)
if self.has_bias:
output += self.bias
# output: [b, N, F]
return output
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},adj_chans={self.adj_chans},n_filters={self.n_filters},bias={self.has_bias},dropout={self.dropout},alpha={self.alpha}) -> [b, N, {self.n_filters}]'
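
# Usage sketch (illustrative; sizes are assumptions). Entries of A that are
# <= 0 are treated as non-edges and masked out of the attention:
#   layer = GraphAttentionLayer(n_feats=32, adj_chans=4, n_filters=64, dropout=0.1)
#   A = (torch.rand(2, 10, 4, 10) > 0.5).float()    # binary multi-channel adjacency
#   out = layer(torch.randn(2, 10, 32), A)          # -> [2, 10, 64]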
class GraphNodeCatGlobalFeatures(nn.Module):
def __init__(self, global_feats, out_feats, mols=1, bias=True):
super(GraphNodeCatGlobalFeatures, self).__init__()
self.global_feats = global_feats
self.out_feats = out_feats
self.mols = mols
self.has_bias = bias
        self.weights = nn.ParameterList([nn.Parameter(torch.FloatTensor(global_feats // mols, out_feats)) for _ in range(mols)])
        self.biases = []
        if bias:
            self.biases = nn.ParameterList([nn.Parameter(torch.FloatTensor(out_feats)) for _ in range(mols)])
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
for weight in self.weights:
nn.init.xavier_uniform_(weight)
        for bias in self.biases:
            bias.data.fill_(0.01)
def forward(self, V, global_state, graph_size, subgraph_size=None):
# V: [b, N, Ov], global_state: [b, F], subgraph_size: [b, mols]
b, N, Ov = V.shape
O = self.out_feats
if self.mols == 1:
subgraph_size = graph_size.view(-1, 1)
global_state = torch.mm(global_state, self.weights[0])
else:
# global_state: [b, F] view -> [b*mols, F/mols]
global_state_view = global_state.view(b*self.mols, -1)
            # split global_state into per-molecule slices
            idxmols = [torch.arange(i, b*self.mols, self.mols, device=self.weights[0].device)
                       for i in range(self.mols)]
global_states = []
for i, idx in enumerate(idxmols):
                # select molecule i's rows from global_state_view [b*mols, F/mols] -> [b, F/mols]
gs = global_state_view.index_select(dim=0, index=idx)
# gs: [b, F/mols] * weight: [F/mols, O] -> [b, O]; F = global_feats, O = out_feats
gs = torch.mm(gs, self.weights[i])
if self.has_bias:
                    gs += self.biases[i]
global_states.append(F.relu(gs))
# convert global_states back to global_state
# [[b, O] ... ] -> [b, mols*O]
global_state = torch.cat(global_states, dim=1)
            # append a zero block for padded nodes: [b, mols*O] || [b, O] -> [b, (mols+1)*O]
global_state_new = torch.cat([global_state, torch.zeros(b, O).to(self.weights[0].device)], dim=-1)
# [b*(mols+1), O]
global_state_new = global_state_new.view(-1, O)
            repeats = []
            for sz in subgraph_size:
                # per sample: one count per molecule, plus one for the padded tail
                repeats.extend(sz.tolist() + [N - int(sz.sum())])
            repeats = torch.tensor(repeats).to(self.weights[0].device)
            # repeat from [b*(mols+1), O] -> [b*N, O]; per sample the rows read [mol1 feats, mol2 feats, ..., pad rows]
global_state_new = global_state_new.repeat_interleave(repeats, dim=0)
# V view: [b*N, Ov], global_state_new: [b*N, O]
output = torch.cat([V.contiguous().view(-1, Ov), global_state_new], dim=1)
# output: [b, N, Ov+O]
return output.view(-1, N, Ov+O), global_state
def __repr__(self):
return f'{self.__class__.__name__}(global_feats={self.global_feats},out_feats={self.out_feats},bias={self.has_bias}) -> [b, N, {self.global_feats+self.out_feats}], [b, out_feats]'
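
# Usage sketch (illustrative; sizes are assumptions). With mols=2 the global
# state is split per molecule, projected to out_feats, and broadcast onto that
# molecule's nodes; padded nodes get zeros. Each subgraph_size row must sum to <= N:
#   layer = GraphNodeCatGlobalFeatures(global_feats=16, out_feats=8, mols=2)
#   V = torch.randn(2, 10, 32)                         # [b, N, Ov]
#   gs = torch.randn(2, 16)                            # [b, global_feats]
#   graph_size = torch.tensor([7, 9])
#   subgraph_size = torch.tensor([[4, 3], [5, 4]])     # nodes per molecule
#   out, gs_new = layer(V, gs, graph_size, subgraph_size)   # [2, 10, 40], [2, 16]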
class MultiHeadGlobalAttention(nn.Module):
    '''Input [b, N, C] -> output [b, n_head*C] if concat else [b, C]'''
def __init__(self, n_feats, n_head=5, alpha=0.2, concat=True, bias=True):
super(MultiHeadGlobalAttention, self).__init__()
self.n_feats = n_feats
self.n_head = n_head
self.alpha = alpha
self.concat = concat
self.has_bias = bias
self.weight = nn.Parameter(torch.FloatTensor(n_feats, n_head*n_feats))
self.tune_weight = nn.Parameter(torch.FloatTensor(1, n_head, n_feats))
if bias:
self.bias = nn.Parameter(torch.FloatTensor(n_head*n_feats))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight)
nn.init.xavier_uniform_(self.tune_weight)
if self.bias is not None:
self.bias.data.fill_(0.01)
def forward(self, V, graph_size):
        # Gather the valid (unpadded) nodes of every graph in the batch; padding rows are removed.
        if V.shape[0] == 1:
            Vg = torch.squeeze(V, dim=0)
            graph_size = [graph_size]
else:
Vg = torch.cat([torch.split(v.view(-1, v.shape[-1]), graph_size[i])[0] for i,v in enumerate(torch.split(V, 1))], dim=0)
Vg = torch.matmul(Vg, self.weight)
if self.has_bias:
Vg += self.bias
Vg = Vg.view(-1, self.n_head, self.n_feats)
alpha = torch.mul(self.tune_weight, Vg)
alpha = torch.sum(alpha, dim=-1)
alpha = F.leaky_relu(alpha, self.alpha) # original code is "alpha = tf.nn.leaky_relu(alpha, alpha=0.2)"
alpha = utils.segment_softmax(alpha, graph_size)
        # the original code also computed alpha_collect = torch.mean(alpha, dim=-1), but never used it
alpha = alpha.view(-1, self.n_head, 1)
V = torch.mul(Vg, alpha)
if self.concat:
V = utils.segment_sum(V, graph_size)
V = V.view(-1, self.n_head*self.n_feats)
else:
V = torch.mean(V, dim=1)
V = utils.segment_sum(V, graph_size)
return V
def __repr__(self):
if self.concat:
outc = self.n_head*self.n_feats
else:
            outc = self.n_feats
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_head={self.n_head},alpha={self.alpha},concat={self.concat},bias={self.has_bias}) -> [b, {outc}]'
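
# Usage sketch (illustrative; sizes are assumptions, and the layer relies on
# utils.segment_softmax / utils.segment_sum from this repo):
#   att = MultiHeadGlobalAttention(n_feats=64, n_head=5, concat=True)
#   V = torch.randn(2, 10, 64)       # [b, N, C], padded per sample
#   out = att(V, [7, 9])             # per-sample node counts -> [2, 5*64]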
class GraphEmbedPoolingLayer(nn.Module):
def __init__(self, n_feats, n_filters=1, mask=None, bias=True):
super(GraphEmbedPoolingLayer, self).__init__()
self.n_feats = n_feats
self.n_filters = n_filters
self.mask = mask
self.has_bias = bias
self.emb = nn.Linear(n_feats, n_filters, bias=bias)
def forward(self, V, A):
# [b, N, F]
factors = self.emb(V)
if self.mask is not None:
factors = torch.mul(factors, self.mask)
factors = F.softmax(factors, dim=1)
# [b, N, F] trans -> [b, F, N] * [b, N, C] -> [b, F, C]
result = torch.matmul(factors.transpose(1, 2).contiguous(), V)
if self.n_filters == 1:
return result.view(-1, self.n_feats), A
        # A: [b, N, L, N] -> [b, N*L, N], then pool the target-node axis: [b, N*L, F]
        result_A = A.view(A.shape[0], -1, A.shape[-1])
        result_A = torch.matmul(result_A, factors)
        # regroup to [b, N, L*F], then pool the source-node axis: [b, F, L*F]
        result_A = result_A.view(A.shape[0], A.shape[-1], -1)
        result_A = torch.matmul(factors.transpose(1, 2).contiguous(), result_A)
        # final pooled adjacency: [b, F, L, F]
        result_A = result_A.view(A.shape[0], self.n_filters, A.shape[2], self.n_filters)
return result, result_A
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mask={self.mask},bias={self.has_bias}) -> [b, {self.n_filters}, {self.n_feats}], [b, {self.n_filters}, L, {self.n_filters}]'
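
# Usage sketch (illustrative; sizes are assumptions). The layer pools N nodes
# down to n_filters "super-nodes" and pools the adjacency the same way:
#   pool = GraphEmbedPoolingLayer(n_feats=64, n_filters=8)
#   V2, A2 = pool(torch.randn(2, 10, 64), torch.randn(2, 10, 4, 10))
#   # V2: [2, 8, 64], A2: [2, 8, 4, 8]; with n_filters=1, V2 collapses to
#   # [b, n_feats] and A is returned unchanged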
class GConvBlockWithGF(nn.Module):
def __init__( self,
n_feats,
n_filters,
global_feats,
global_out_feats,
mols=1,
adj_chans=4,
bias=True,
usegat=False):
super(GConvBlockWithGF, self).__init__()
self.n_feats = n_feats
self.n_filters = n_filters
self.global_out_feats = global_out_feats
self.global_feats = global_feats
self.mols = mols
self.adj_chans = adj_chans
self.has_bias = bias
self.usegat = usegat
self.broadcast_global_state = GraphNodeCatGlobalFeatures(global_feats, global_out_feats, mols, bias)
        if usegat:
            self.graph_conv = GraphAttentionLayer(n_feats+global_out_feats, adj_chans, n_filters, bias)
        else:
            self.graph_conv = GraphCNNLayer(n_feats+global_out_feats, adj_chans, n_filters, bias)
self.bn_global = nn.BatchNorm1d(global_out_feats*mols)
self.bn_graph = nn.BatchNorm1d(n_filters)
def forward(self, V, A, global_state, graph_size, subgraph_size):
        ######## broadcast global_state onto the nodes #########
        # V shape from [b, N, C] to [b, N, C+O], O is global_out_feats
V, global_state = self.broadcast_global_state(V, global_state, graph_size, subgraph_size)
######## Graph Convolution #########
# V shape from [b, N, C+F] to [b, N, F1], F1 is n_filters
V = self.graph_conv(V, A)
V = self.bn_graph(V.transpose(1, 2).contiguous())
V = F.relu(V.transpose(1, 2))
global_state = F.relu(self.bn_global(global_state))
return V, global_state
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},global_feats={self.global_feats},global_out_feats={self.global_out_feats},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias},usegat={self.usegat}) -> [b, N, {self.n_filters}], [b, {self.global_out_feats*self.mols}]'
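
# Usage sketch (illustrative; sizes are assumptions). The block broadcasts the
# global state onto the nodes, applies one graph convolution, and batch-norms
# both streams:
#   block = GConvBlockWithGF(n_feats=32, n_filters=64, global_feats=16,
#                            global_out_feats=8, mols=2)
#   V2, gs2 = block(torch.randn(2, 10, 32), torch.randn(2, 10, 4, 10),
#                   torch.randn(2, 16), torch.tensor([7, 9]),
#                   torch.tensor([[4, 3], [5, 4]]))    # [2, 10, 64], [2, 16]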
class GConvBlockNoGF(nn.Module):
def __init__( self,
n_feats,
n_filters,
mols=1,
adj_chans=4,
bias=True):
super(GConvBlockNoGF, self).__init__()
self.n_feats = n_feats
self.n_filters = n_filters
self.mols = mols
self.adj_chans = adj_chans
self.has_bias = bias
        self.graph_conv = GraphCNNLayer(n_feats, adj_chans, n_filters, bias)
        self.bn_graph = nn.BatchNorm1d(n_filters)
def forward(self, V, A):
######## Graph Convolution #########
        # V shape from [b, N, C] to [b, N, F], F is n_filters
V = self.graph_conv(V, A)
V = self.bn_graph(V.transpose(1, 2).contiguous())
V = F.relu(V.transpose(1, 2))
return V
def __repr__(self):
return f'{self.__class__.__name__}(n_feats={self.n_feats},n_filters={self.n_filters},mols={self.mols},adj_chans={self.adj_chans},bias={self.has_bias}) -> [b, N, {self.n_filters}]'
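
if __name__ == '__main__':
    # Minimal smoke test (illustrative shapes, not part of the original training
    # pipeline): push random inputs through the stand-alone layers and print shapes.
    b, N, C, L, F_out = 2, 10, 32, 4, 64
    V = torch.randn(b, N, C)
    A = torch.randn(b, N, L, N)
    print(GraphCNNLayer(C, L, F_out)(V, A).shape)         # [2, 10, 64]
    print(GraphResidualCNNLayer(C, L)(V, A).shape)        # [2, 10, 32]
    print(GraphAttentionLayer(C, L, F_out)(V, A).shape)   # [2, 10, 64]
    V2, A2 = GraphEmbedPoolingLayer(C, n_filters=8)(V, A)
    print(V2.shape, A2.shape)                             # [2, 8, 32], [2, 8, 4, 8]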