import torch


def gather(x, indices):
    """Gather rows of ``x`` along dim 0 at every position listed in ``indices``.

    ``indices`` is flattened in row-major order, so the result has shape
    ``(indices.numel(),) + x.shape[1:]`` -- the same as concatenating
    ``x[row]`` for each row of ``indices.view(-1, indices.shape[-1])``.
    Indices are cast to ``long`` and moved to ``x``'s device before indexing.
    """
    flat = indices.reshape(-1).long().to(x.device)
    return x[flat]


def gather_nd(x, indices):
    """Multi-dimensional gather.

    The last dimension of ``indices`` (size ``k``) holds coordinates into the
    first ``k`` dims of ``x``. The result has shape
    ``indices.shape[:-1] + x.shape[k:]``. When ``k == x.ndim`` each coordinate
    selects a single element, which is also supported.
    """
    k = indices.shape[-1]
    newshape = indices.shape[:-1] + x.shape[k:]
    # (M, k) coordinate matrix -> one index tensor per indexed dimension.
    coords = indices.reshape(-1, k).long().to(x.device)
    out = x[tuple(coords.unbind(dim=1))]
    return out.reshape(newshape)


def gen_node_indices(size_list):
    '''generate node index for extraction of nodes of each graph from batched data

    Args:
        size_list: number of nodes of each graph (values are passed through ``int``).

    Returns:
        ``(indices, node_num, node_range)`` where ``node_num[j]`` is the graph
        id of node ``j``, ``node_range[j]`` is its position inside that graph,
        and ``indices`` stacks both as columns, shape ``(total_nodes, 2)``.
        All three are int64, including for an empty ``size_list``.
    '''
    node_num = []
    node_range = []
    for graph_id, n in enumerate(int(s) for s in size_list):
        node_num.extend([graph_id] * n)
        node_range.extend(range(n))
    # Explicit dtype so an empty size_list does not yield float tensors.
    node_num = torch.tensor(node_num, dtype=torch.long)
    node_range = torch.tensor(node_range, dtype=torch.long)
    indices = torch.stack([node_num, node_range], dim=1)
    return indices, node_num, node_range


def segment_max(x, size_list):
    """Max over dim 0 of each consecutive segment of ``x``.

    Segment lengths come from ``size_list`` and must sum to ``x.shape[0]``.
    Returns a tensor of shape ``(len(size_list),) + x.shape[1:]``.
    """
    sizes = [int(n) for n in size_list]
    return torch.stack([torch.max(v, 0).values for v in torch.split(x, sizes)])


def segment_sum(x, size_list):
    """Sum over dim 0 of each consecutive segment of ``x``.

    Segment lengths come from ``size_list`` and must sum to ``x.shape[0]``.
    Returns a tensor of shape ``(len(size_list),) + x.shape[1:]``.
    """
    sizes = [int(n) for n in size_list]
    return torch.stack([torch.sum(v, 0) for v in torch.split(x, sizes)])


def segment_softmax(gate, size_list):
    """Softmax over dim 0 computed independently within each segment of ``gate``.

    Segment lengths come from ``size_list``. The per-segment max is subtracted
    before ``exp`` for numerical stability, and a tiny epsilon guards the
    division. Works for gates of any rank (segments always run along dim 0).
    """
    sizes = [int(n) for n in size_list]
    repeats = torch.tensor(sizes, device=gate.device)
    segmax = segment_max(gate, sizes)
    # Broadcast each segment's max back to every row of that segment.
    segmax_expand = torch.repeat_interleave(segmax, repeats, dim=0)
    exp = torch.exp(gate - segmax_expand)
    segsum = segment_sum(exp, sizes)
    # Broadcast each segment's sum back to every row of that segment.
    segsum_expand = torch.repeat_interleave(segsum, repeats, dim=0)
    return exp / (segsum_expand + 1e-16)


def pad_V(V, max_n):
    """Zero-pad the 2-D tensor ``V`` of shape ``(N, C)`` to ``(max_n, C)``.

    Returns ``V`` unchanged when ``max_n <= N``. Padding rows share ``V``'s
    dtype and device.
    """
    N, C = V.shape
    if max_n > N:
        V = torch.cat([V, V.new_zeros(max_n - N, C)], dim=0)
    return V


def pad_A(A, max_n):
    """Zero-pad the 3-D tensor ``A`` of shape ``(N, L, N)`` to ``(max_n, L, max_n)``.

    The last dimension is padded first, then the first one; this only lines
    up when ``A``'s last dimension equals ``N``. Returns ``A`` unchanged when
    ``max_n <= N``. Padding shares ``A``'s dtype and device.
    """
    N, L, _ = A.shape
    if max_n > N:
        A = torch.cat([A, A.new_zeros(N, L, max_n - N)], dim=-1)
        A = torch.cat([A, A.new_zeros(max_n - N, L, max_n)], dim=0)
    return A


def pad_prot(P, max_n):
    """Zero-pad the 1-D tensor ``P`` to length ``max_n`` and return it as ``torch.IntTensor``.

    Padding uses ``P``'s own dtype, so integer values are not round-tripped
    through float32 (which would corrupt values above 2**24). Note that
    ``.type(torch.IntTensor)`` yields a CPU int32 tensor.
    """
    N, = P.shape
    if max_n > N:
        P = torch.cat([P, P.new_zeros(max_n - N)], dim=0)
    return P.type(torch.IntTensor)


def create_batch(input, pad=False, device=torch.device('cpu')):
    """Collate molecule-graph sample dicts into one batch dict.

    Each sample must provide ``'V'``, ``'A'``, ``'G'``, ``'mol_size'``,
    ``'subgraph_size'``, ``'label'``, ``'index'`` and ``'smiles'``.

    Args:
        input: iterable of sample dicts (consumed once).
        pad: if True, zero-pad every ``V``/``A`` to the largest node count in
            the batch (``pad_V``/``pad_A``) before stacking; otherwise they
            must already share a shape.
        device: device the stacked tensors are moved to.

    Returns:
        dict with the same keys. ``V``/``A``/``subgraph_size``/``label`` are
        stacked, ``mol_size`` is concatenated, all moved to ``device``.
        ``G`` is stacked only when the first sample's ``G`` is not None
        (otherwise it stays a list). ``index``/``smiles`` stay lists.
    """
    samples = list(input)
    vl = [d['V'] for d in samples]
    al = [d['A'] for d in samples]
    gsl = [d['G'] for d in samples]
    if gsl[0] is not None:
        gsl = torch.stack(gsl, dim=0).to(device)
    if pad:
        max_n = max(v.shape[0] for v in vl)
        vl = [pad_V(v, max_n) for v in vl]
        al = [pad_A(a, max_n) for a in al]
    return {'V': torch.stack(vl, dim=0).to(device),
            'A': torch.stack(al, dim=0).to(device),
            'G': gsl,
            'mol_size': torch.cat([d['mol_size'] for d in samples], dim=0).to(device),
            'subgraph_size': torch.stack([d['subgraph_size'] for d in samples], dim=0).to(device),
            'label': torch.stack([d['label'] for d in samples], dim=0).to(device),
            'index': [d['index'] for d in samples],
            'smiles': [d['smiles'] for d in samples]}


def create_mol_protein_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate molecule-graph + protein sample dicts into one batch dict.

    Each sample must provide ``'V'``, ``'A'``, ``'G'``, ``'mol_size'``,
    ``'subgraph_size'``, ``'protein_seq'``, ``'protein'``, ``'label'``,
    ``'index'`` and ``'smiles'``; ``'fp'`` is optional.

    Args:
        input: iterable of sample dicts (consumed once).
        pad: if True, pad ``V``/``A`` to the largest node count and
            ``protein_seq`` to the longest sequence, then stack everything
            on ``device``. If False, ``V``/``A``/``G``/``subgraph_size``/
            ``protein_seq`` are returned as lists of per-sample tensors with a
            leading batch dim of 1 (not moved to ``device``).
        device: device the stacked tensors are moved to.
        pr: print padding progress messages when ``pad`` is True.

    Returns:
        batch dict. ``'fp'`` is the stacked fingerprints of the samples that
        have one (on ``device``), or None when no sample has ``'fp'``.
        ``'label'`` is flattened to 1-D.
    """
    samples = list(input)
    vl = [d['V'] for d in samples]
    al = [d['A'] for d in samples]
    gsl = [d['G'] for d in samples]
    ssl = [d['subgraph_size'] for d in samples]
    prot = [d['protein_seq'] for d in samples]
    # Fingerprints are optional. Built before the pad/no-pad split so that
    # both return paths can use ``fpt`` (the unpadded path needs it too).
    fpl = [d['fp'] for d in samples if 'fp' in d]
    fpt = torch.stack(fpl, dim=0).to(device) if fpl else None

    if gsl[0] is not None:
        if pad:
            gsl = torch.stack(gsl, dim=0).to(device)
        else:
            gsl = [torch.unsqueeze(g, 0) for g in gsl]

    if pad:
        max_n = max(v.shape[0] for v in vl)
        if pr:
            print('\tPadding V to max_n:', max_n)
        V = torch.stack([pad_V(v, max_n) for v in vl], dim=0).to(device)
        if pr:
            print('\tPadding A to max_n:', max_n)
        A = torch.stack([pad_A(a, max_n) for a in al], dim=0).to(device)
        max_prot = max(p.shape[0] for p in prot)
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        protein_seq = torch.stack([pad_prot(p, max_prot) for p in prot], dim=0).to(device)
        subgraph_size = torch.stack(ssl, dim=0).to(device)
    else:
        V = [torch.unsqueeze(v, 0) for v in vl]
        A = [torch.unsqueeze(a, 0) for a in al]
        subgraph_size = [torch.unsqueeze(s, 0) for s in ssl]
        protein_seq = [torch.unsqueeze(p, 0) for p in prot]

    return {'V': V,
            'A': A,
            'G': gsl,
            'fp': fpt,
            'mol_size': torch.cat([d['mol_size'] for d in samples], dim=0).to(device),
            'subgraph_size': subgraph_size,
            'protein_seq': protein_seq,
            'label': torch.stack([d['label'] for d in samples], dim=0).view(-1).to(device),
            'index': [d['index'] for d in samples],
            'smiles': [d['smiles'] for d in samples],
            'protein': [d['protein'] for d in samples]}


def create_mol_protein_fp_batch(input, pad=False, device=torch.device('cpu'), pr=True):
    """Collate fingerprint + protein sample dicts into one batch dict.

    Each sample must provide ``'fp'``, ``'protein_seq'``, ``'label'``,
    ``'index'`` and ``'smiles'``.

    Args:
        input: iterable of sample dicts (consumed once).
        pad: if True, pad ``protein_seq`` to the longest sequence and stack
            ``fp``/``protein_seq`` on ``device``; otherwise return them as
            lists of per-sample tensors with a leading dim of 1.
        device: device the stacked tensors are moved to.
        pr: print a padding message when ``pad`` is True.
    """
    samples = list(input)
    fp = [d['fp'] for d in samples]
    prot = [d['protein_seq'] for d in samples]
    if pad:
        max_prot = max(p.shape[0] for p in prot)
        if pr:
            print('\tPadding protein_seq to max_n:', max_prot)
        fp_out = torch.stack(fp, dim=0).to(device)
        protein_seq = torch.stack([pad_prot(p, max_prot) for p in prot], dim=0).to(device)
    else:
        fp_out = [torch.unsqueeze(f, 0) for f in fp]
        protein_seq = [torch.unsqueeze(p, 0) for p in prot]
    return {'fp': fp_out,
            'protein_seq': protein_seq,
            'label': torch.stack([d['label'] for d in samples], dim=0).view(-1).to(device),
            'index': [d['index'] for d in samples],
            'smiles': [d['smiles'] for d in samples]}