|
|
import torch
|
|
|
|
|
|
def gather(x, indices):
    """Torch analogue of a batched gather along dim 0 of `x`.

    `indices` is flattened to a 2-D list of row-index lists; the rows of
    `x` selected by each list are concatenated along dim 0.
    NOTE(review): the leading batch shape of `indices` is not restored on
    the output — callers get a flat concatenation.
    """
    row_lists = indices.view(-1, indices.shape[-1]).tolist()
    gathered = [x[rows] for rows in row_lists]
    return torch.cat(gathered)
|
|
|
|
|
|
def gather_nd(x, indices):
    """Torch analogue of tf.gather_nd.

    Each row along the last axis of `indices` is an index tuple into the
    leading dimensions of `x`.  The result has shape
    ``indices.shape[:-1] + x.shape[indices.shape[-1]:]``.

    Fix: use ``torch.stack`` instead of ``torch.cat``.  When
    ``indices.shape[-1] == x.dim()`` each lookup is a 0-d scalar, and
    ``torch.cat`` raises "zero-dimensional tensor cannot be concatenated";
    ``torch.stack`` handles 0-d elements.  For partial index tuples,
    stack-then-reshape produces exactly the same element order as the old
    cat-then-reshape, so existing callers are unaffected.
    """
    newshape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
    index_tuples = indices.view(-1, indices.shape[-1]).tolist()
    out = torch.stack([x[tuple(i)] for i in index_tuples])
    return out.reshape(newshape)
|
|
|
|
|
|
def gen_node_indices(size_list):
    """Generate node indices for extracting each graph's nodes from batched data.

    Returns:
        indices: (total_nodes, 2) tensor pairing [graph index, node offset].
        node_num: (total_nodes,) graph index of every node.
        node_range: (total_nodes,) within-graph offset of every node.
    """
    sizes = [int(s) for s in size_list]
    # graph id repeated once per node, and a 0..n-1 counter per graph
    graph_ids = [g for g, n in enumerate(sizes) for _ in range(n)]
    offsets = [o for n in sizes for o in range(n)]

    graph_ids = torch.tensor(graph_ids)
    offsets = torch.tensor(offsets)
    paired = torch.stack([graph_ids, offsets], dim=1)
    return paired, graph_ids, offsets
|
|
|
|
|
|
def segment_max(x, size_list):
    """Per-segment max over dim 0; segment i spans size_list[i] consecutive rows."""
    lengths = [int(n) for n in size_list]
    maxima = [chunk.max(dim=0).values for chunk in torch.split(x, lengths)]
    return torch.stack(maxima)
|
|
|
|
|
|
def segment_sum(x, size_list):
    """Per-segment sum over dim 0; segment i spans size_list[i] consecutive rows."""
    lengths = [int(n) for n in size_list]
    totals = [chunk.sum(dim=0) for chunk in torch.split(x, lengths)]
    return torch.stack(totals)
|
|
|
|
|
|
def segment_softmax(gate, size_list):
    """Softmax computed independently within each row-segment of `gate`.

    `gate` is 2-D with rows grouped consecutively per size_list; the
    returned attention tensor has the same shape as `gate`.
    """
    # subtract each segment's max from its own rows for numerical stability
    seg_max = segment_max(gate, size_list)
    max_rows = torch.cat([seg_max[k].repeat(cnt, 1) for k, cnt in enumerate(size_list)], dim=0)
    exp = torch.exp(gate - max_rows)

    # normalise by the per-segment sum; epsilon guards against division by zero
    seg_sum = segment_sum(exp, size_list)
    sum_rows = torch.cat([seg_sum[k].repeat(cnt, 1) for k, cnt in enumerate(size_list)], dim=0)
    return exp / (sum_rows + 1e-16)
|
|
|
|
|
|
def pad_V(V, max_n):
    """Zero-pad a (N, C) feature matrix along dim 0 up to max_n rows.

    Returned unchanged when N >= max_n.
    """
    n_rows, n_cols = V.shape
    if n_rows < max_n:
        filler = torch.zeros(max_n - n_rows, n_cols)
        V = torch.cat([V, filler], dim=0)
    return V
|
|
|
|
|
|
def pad_A(A, max_n):
    """Zero-pad a (N, L, N) tensor to (max_n, L, max_n).

    Presumably an adjacency stack with L edge types — TODO confirm.
    Padding widens the last dim first, then appends zero rows on dim 0;
    returned unchanged when N >= max_n.
    """
    n, num_rel, _ = A.shape
    extra = max_n - n
    if extra > 0:
        A = torch.cat([A, torch.zeros(n, num_rel, extra)], dim=-1)    # widen columns
        A = torch.cat([A, torch.zeros(extra, num_rel, max_n)], dim=0)  # append rows
    return A
|
|
|
|
|
|
def pad_prot(P, max_n):
    """Zero-pad a 1-D tensor to length max_n and cast to int32.

    The cast via `torch.IntTensor` matches the original contract (CPU
    int32 output) and happens whether or not padding was needed.
    """
    (length,) = P.shape
    if length < max_n:
        P = torch.cat([P, torch.zeros(max_n - length)], dim=0)
    return P.type(torch.IntTensor)
|
|
|
|
|
|
def create_batch(input, pad=False, device=torch.device('cpu')):
    """Collate a list of per-sample dicts into one batch dict.

    Each sample provides 'V', 'A', 'G', 'mol_size', 'subgraph_size',
    'label', 'index' and 'smiles'.  With pad=True, V and A are zero-padded
    to the largest node count in the batch before stacking; either way the
    returned 'V'/'A' are stacked tensors on `device`.
    """
    vl = [d['V'] for d in input]
    al = [d['A'] for d in input]
    gsl = [d['G'] for d in input]
    msl = [d['mol_size'] for d in input]
    ssl = [d['subgraph_size'] for d in input]
    lbl = [d['label'] for d in input]
    idxs = [d['index'] for d in input]
    smis = [d['smiles'] for d in input]

    # 'G' is treated as all-or-nothing: only the first sample is inspected.
    if gsl[0] is not None:
        gsl = torch.stack(gsl, dim=0).to(device)

    if pad:
        max_n = max(v.shape[0] for v in vl)
        vl = [pad_V(v, max_n) for v in vl]
        al = [pad_A(a, max_n) for a in al]

    return {'V': torch.stack(vl, dim=0).to(device),
            'A': torch.stack(al, dim=0).to(device),
            'G': gsl,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': torch.stack(ssl, dim=0).to(device),
            'label': torch.stack(lbl, dim=0).to(device),
            'index': idxs,
            'smiles': smis}
|
|
|
|
|
|
def create_mol_protein_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate molecule + protein sample dicts into one batch dict.

    Each sample provides 'V', 'A', 'G', 'mol_size', 'subgraph_size',
    'protein_seq', 'protein', 'label', 'index', 'smiles' and optionally
    'fp'.  With pad=True, V/A/protein_seq are zero-padded to the batch
    maxima and stacked; with pad=False, per-sample tensors are returned
    as lists of unsqueezed tensors.  `pr` toggles padding progress prints.

    Bug fix: `fpt` used to be assigned only inside the `if pad:` branch,
    so calling with pad=False raised NameError at the final return.  The
    fingerprint stacking is now hoisted above the branch.
    """
    vl, al, gsl, msl, ssl = [], [], [], [], []
    prot, seq, lbl, idxs, smis, fpl = [], [], [], [], [], []

    for d in input:
        vl.append(d['V'])
        al.append(d['A'])
        gsl.append(d['G'])
        msl.append(d['mol_size'])
        ssl.append(d['subgraph_size'])
        prot.append(d['protein_seq'])
        seq.append(d['protein'])
        lbl.append(d['label'])
        idxs.append(d['index'])
        smis.append(d['smiles'])
        if 'fp' in d:
            fpl.append(d['fp'])

    # 'G' is treated as all-or-nothing: only the first sample is inspected.
    if gsl[0] is not None:
        if pad:
            gsl = torch.stack(gsl, dim=0).to(device)
        else:
            gsl = [torch.unsqueeze(g, 0) for g in gsl]

    # Fingerprints are optional; stay None when absent.  Hoisted out of the
    # pad branch so the pad=False return also sees a defined value (bug fix).
    fpt = None
    if fpl:
        fpt = torch.stack(fpl, dim=0).to(device)

    if pad:
        max_n = max(map(lambda x: x.shape[0], vl))
        if pr:
            print('\tPadding V to max_n:', max_n)
        vl1 = [pad_V(v, max_n) for v in vl]

        if pr:
            print('\tPadding A to max_n:', max_n)
        al1 = [pad_A(a, max_n) for a in al]

        max_prot = max(map(lambda x: x.shape[0], prot))
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        prot1 = [pad_prot(p, max_prot) for p in prot]

        return {'V': torch.stack(vl1, dim=0).to(device),
                'A': torch.stack(al1, dim=0).to(device),
                'G': gsl,
                'fp': fpt,
                'mol_size': torch.cat(msl, dim=0).to(device),
                'subgraph_size': torch.stack(ssl, dim=0).to(device),
                'protein_seq': torch.stack(prot1, dim=0).to(device),
                'label': torch.stack(lbl, dim=0).view(-1).to(device),
                'index': idxs,
                'smiles': smis,
                'protein': seq}

    return {'V': [torch.unsqueeze(v, 0) for v in vl],
            'A': [torch.unsqueeze(a, 0) for a in al],
            'G': gsl,
            'fp': fpt,
            'mol_size': torch.cat(msl, dim=0).to(device),
            'subgraph_size': [torch.unsqueeze(s, 0) for s in ssl],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prot],
            'label': torch.stack(lbl, dim=0).view(-1).to(device),
            'index': idxs,
            'smiles': smis,
            'protein': seq}
|
|
|
|
|
|
def create_mol_protein_fp_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate fingerprint + protein sample dicts into one batch dict.

    Each sample provides 'fp', 'protein_seq', 'label', 'index' and
    'smiles'.  With pad=True, protein sequences are zero-padded to the
    batch maximum and stacked with the fingerprints; with pad=False,
    per-sample tensors are returned as lists of unsqueezed tensors.
    `pr` toggles the padding progress print.
    """
    fp = [d['fp'] for d in input]
    prot = [d['protein_seq'] for d in input]
    lbl = [d['label'] for d in input]
    idxs = [d['index'] for d in input]
    smis = [d['smiles'] for d in input]

    labels = torch.stack(lbl, dim=0).view(-1).to(device)

    if pad:
        max_prot = max(p.shape[0] for p in prot)
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        padded = [pad_prot(p, max_prot) for p in prot]
        return {'fp': torch.stack(fp, dim=0).to(device),
                'protein_seq': torch.stack(padded, dim=0).to(device),
                'label': labels,
                'index': idxs,
                'smiles': smis}

    return {'fp': [torch.unsqueeze(f, 0) for f in fp],
            'protein_seq': [torch.unsqueeze(p, 0) for p in prot],
            'label': labels,
            'index': idxs,
            'smiles': smis}