import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import rdmolfiles
from rdkit.Chem.Draw import IPythonConsole
from molvs import standardize_smiles
import os
import gc
import sys
import time
import json
import math
import random
import argparse
import itertools
import numpy as np
import mxnet as mx
import pandas as pd
import networkx as nx
from scipy import sparse
from mxnet.gluon import nn
from collections import Counter
from mxnet.autograd import Function
from mxnet.gluon.data import Dataset
from mxnet import gluon, autograd, nd
from mxnet.gluon.data import DataLoader
from abc import ABCMeta, abstractmethod
from mxnet.gluon.data.sampler import Sampler

class MoleculeSpec(object):

    def __init__(self, file_name='models_folder/atom_types.txt'):
        self.atom_types = []
        self.atom_symbols = []
        with open(file_name) as f:
            for line in f:
                atom_type_i = line.strip('\n').split(',')
                self.atom_types.append((atom_type_i[0], int(atom_type_i[1]), int(atom_type_i[2])))
                if atom_type_i[0] not in self.atom_symbols:
                    self.atom_symbols.append(atom_type_i[0])
        self.bond_orders = [Chem.BondType.AROMATIC,
                            Chem.BondType.SINGLE,
                            Chem.BondType.DOUBLE,
                            Chem.BondType.TRIPLE]
        self.max_iter = 120

    def get_atom_type(self, atom):
        atom_symbol = atom.GetSymbol()
        atom_charge = atom.GetFormalCharge()
        atom_hs = atom.GetNumExplicitHs()
        return self.atom_types.index((atom_symbol, atom_charge, atom_hs))

    def get_bond_type(self, bond):
        return self.bond_orders.index(bond.GetBondType())

    def index_to_atom(self, idx):
        atom_symbol, atom_charge, atom_hs = self.atom_types[idx]
        a = Chem.Atom(atom_symbol)
        a.SetFormalCharge(atom_charge)
        a.SetNumExplicitHs(atom_hs)
        return a

    def index_to_bond(self, mol, begin_id, end_id, idx):
        mol.AddBond(begin_id, end_id, self.bond_orders[idx])
    # exposed as properties: callers access these as attributes,
    # e.g. get_mol_spec().num_bond_types in merge_single_0 and _decode_step
    @property
    def num_atom_types(self):
        return len(self.atom_types)

    @property
    def num_bond_types(self):
        return len(self.bond_orders)

_mol_spec = None

def get_mol_spec():
    global _mol_spec
    if _mol_spec is None:
        _mol_spec = MoleculeSpec()
    return _mol_spec
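
# Assumed format of `models_folder/atom_types.txt`, inferred from the parser in
# MoleculeSpec.__init__: one comma-separated record per line,
# `symbol,formal_charge,num_explicit_hs`, e.g.
#   C,0,0
#   N,1,1
# Usage sketch (assuming the atom's type is listed in the file):
#   spec = get_mol_spec()
#   a = Chem.MolFromSmiles('CO').GetAtomWithIdx(0)
#   spec.get_atom_type(a)  # -> index of ('C', 0, 0) in spec.atom_types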

def get_graph_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    # build graph
    atom_types, atom_ranks, bonds, bond_types = [], [], [], []
    for a, r in zip(mol.GetAtoms(), Chem.CanonicalRankAtoms(mol)):
        atom_types.append(get_mol_spec().get_atom_type(a))
        atom_ranks.append(r)
    for b in mol.GetBonds():
        idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
        bonds.append([idx_1, idx_2])
        bond_types.append(bt)
    # build nx graph
    graph = nx.Graph()
    graph.add_nodes_from(range(len(atom_types)))
    graph.add_edges_from(bonds)
    return graph, atom_types, atom_ranks, bonds, bond_types
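
# Illustrative call (ethanol):
#   graph, atom_types, atom_ranks, bonds, bond_types = get_graph_from_smiles('CCO')
#   # graph: networkx.Graph with 3 nodes; bonds: [[0, 1], [1, 2]];
#   # atom_ranks: RDKit canonical ranks, used below to order the decoding traversal.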

def get_graph_from_smiles_list(smiles_list):
    graph_list = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        # build graph
        atom_types, bonds, bond_types = [], [], []
        for a in mol.GetAtoms():
            atom_types.append(get_mol_spec().get_atom_type(a))
        for b in mol.GetBonds():
            idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
            bonds.append([idx_1, idx_2])
            bond_types.append(bt)
        X_0 = np.array(atom_types, dtype=np.int64)
        A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
                              np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
                             axis=1)
        graph_list.append([X_0, A_0])
    return graph_list

def traverse_graph(graph, atom_ranks, current_node=None, step_ids=None, p=0.9, log_p=0.0):
    if current_node is None:
        next_nodes = range(len(atom_ranks))
        step_ids = [-1, ] * len(next_nodes)
        next_node_ranks = atom_ranks
    else:
        next_nodes = graph.neighbors(current_node)  # get neighbor nodes
        next_nodes = [n for n in next_nodes if step_ids[n] < 0]  # filter visited nodes
        next_node_ranks = [atom_ranks[n] for n in next_nodes]  # get ranks for neighbors
    next_nodes = [n for n, r in sorted(zip(next_nodes, next_node_ranks), key=lambda _x: _x[1])]  # sort by rank
    # iterate through neighbors
    while len(next_nodes) > 0:
        if len(next_nodes) == 1:
            next_node = next_nodes[0]
        elif random.random() >= (1 - p):
            next_node = next_nodes[0]
            log_p += np.log(p)
        else:
            next_node = next_nodes[random.randint(1, len(next_nodes) - 1)]
            log_p += np.log((1.0 - p) / (len(next_nodes) - 1))
        step_ids[next_node] = max(step_ids) + 1
        _, log_p = traverse_graph(graph, atom_ranks, next_node, step_ids, p, log_p)
        next_nodes = [n for n in next_nodes if step_ids[n] < 0]  # filter visited nodes
    return step_ids, log_p
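
# Sketch of a single (stochastic) traversal:
#   graph, _, atom_ranks, _, _ = get_graph_from_smiles('CCO')
#   step_ids, log_p = traverse_graph(graph, atom_ranks, p=0.9)
#   # step_ids[i] is the generation step at which atom i is visited; log_p is
#   # the log-probability of sampling this particular ordering, used later
#   # for importance weighting.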

def single_reorder(X_0, A_0, step_ids):
    X_0, A_0 = np.copy(X_0), np.copy(A_0)
    step_ids = np.array(step_ids, dtype=np.int64)
    # sort by step_ids
    sorted_ids = np.argsort(step_ids)
    X_0 = X_0[sorted_ids]
    A_0[:, 0], A_0[:, 1] = step_ids[A_0[:, 0]], step_ids[A_0[:, 1]]
    max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
    A_0 = A_0[np.lexsort([-min_b, max_b]), :]
    # separate append and connect
    max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
    is_append = np.concatenate([np.array([True]), max_b[1:] > max_b[:-1]])
    A_0 = np.concatenate([np.where(is_append[:, np.newaxis],
                                   np.stack([min_b, max_b], axis=1),
                                   np.stack([max_b, min_b], axis=1)),
                          A_0[:, -1:]], axis=1)
    return X_0, A_0

def single_expand(X_0, A_0):
    X_0, A_0 = np.copy(X_0), np.copy(A_0)
    # expand X
    is_append_iter = np.less(A_0[:, 0], A_0[:, 1]).astype(np.int64)
    NX = np.cumsum(np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1))
    shift = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
    X_index = np.arange(NX.sum(), dtype=np.int64) - np.repeat(shift, NX)
    X = X_0[X_index]
    # expand A
    _, A_index = np.tril_indices(A_0.shape[0])
    A = A_0[A_index, :]
    NA = np.arange(A_0.shape[0] + 1)
    # get action
    # action_type, atom_type, bond_type, append_pos, connect_pos
    action_type = 1 - is_append_iter
    atom_type = np.where(action_type == 0, X_0[A_0[:, 1]], 0)
    bond_type = A_0[:, 2]
    append_pos = np.where(action_type == 0, A_0[:, 0], 0)
    connect_pos = np.where(action_type == 1, A_0[:, 1], 0)
    actions = np.stack([action_type, atom_type, bond_type, append_pos, connect_pos],
                       axis=1)
    last_action = [[2, 0, 0, 0, 0]]
    actions = np.append(actions, last_action, axis=0)
    action_0 = np.array([X_0[0]], dtype=np.int64)
    # get mask
    last_atom_index = shift + NX - 1
    last_atom_mask = np.zeros_like(X)
    last_atom_mask[last_atom_index] = np.where(
        np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1) == 1,
        np.ones_like(last_atom_index),
        np.ones_like(last_atom_index) * 2)
    return action_0, X, NX, A, NA, actions, last_atom_mask
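
# single_expand unrolls one ordered molecule into its step-by-step decoding
# trajectory: step t sees a prefix of the atoms/bonds (X, A with per-step
# counts NX, NA) and must predict actions[t], which is one of
#   [0, atom_type, bond_type, append_pos, 0]   append a new atom,
#   [1, 0,         bond_type, 0, connect_pos]  connect to an existing atom,
#   [2, 0, 0, 0, 0]                            terminate generation.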

def get_d(A, X):
    _to_sparse = lambda _A, _X: sparse.coo_matrix((np.ones([_A.shape[0] * 2], dtype=np.int64),
                                                   (np.concatenate([_A[:, 0], _A[:, 1]], axis=0),
                                                    np.concatenate([_A[:, 1], _A[:, 0]], axis=0))),
                                                  shape=[_X.shape[0], ] * 2)
    A_sparse = _to_sparse(A, X)
    d2 = A_sparse * A_sparse
    d3 = d2 * A_sparse
    # get D_2
    D_2 = np.stack(d2.nonzero(), axis=1)
    D_2 = D_2[D_2[:, 0] < D_2[:, 1], :]
    # get D_3
    D_3 = np.stack(d3.nonzero(), axis=1)
    D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]
    # remove D_1 elements from D_3
    D_3_sparse = _to_sparse(D_3, X)
    D_3_sparse = D_3_sparse - D_3_sparse.multiply(A_sparse)
    D_3 = np.stack(D_3_sparse.nonzero(), axis=1)
    D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]
    return D_2, D_3
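
# D_2 / D_3 collect atom pairs joined by a two-bond / three-bond walk (with
# directly bonded pairs removed from D_3); they act as the extra "virtual"
# edge sets fed to the graph convolution. For a linear chain 0-1-2-3,
# i.e. A = [[0, 1, t], [1, 2, t], [2, 3, t]]:
#   D_2 -> [[0, 2], [1, 3]]
#   D_3 -> [[0, 3]]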

def merge_single_0(X_0, A_0, NX_0, NA_0):
    # shift ids
    cumsum = np.cumsum(np.pad(NX_0, [[1, 0]], mode='constant')[:-1])
    A_0[:, :2] += np.stack([np.repeat(cumsum, NA_0), ] * 2, axis=1)
    # get D
    D_0_2, D_0_3 = get_d(A_0, X_0)
    # split A by bond type
    A_split = []
    for i in range(get_mol_spec().num_bond_types):
        A_i = A_0[A_0[:, 2] == i, :2]
        A_split.append(A_i)
    A_split.extend([D_0_2, D_0_3])
    A_0 = A_split
    # NX_rep
    NX_rep_0 = np.repeat(np.arange(NX_0.shape[0]), NX_0)
    return X_0, A_0, NX_0, NX_rep_0

def merge_single(X, A,
                 NX, NA,
                 mol_ids, rep_ids, iw_ids,
                 action_0, actions,
                 last_append_mask,
                 log_p):
    X, A, NX, NX_rep = merge_single_0(X, A, NX, NA)
    cumsum = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
    actions[:, -2] += cumsum * (actions[:, 0] == 0)
    actions[:, -1] += cumsum * (actions[:, 0] == 1)
    mol_ids_rep = np.repeat(mol_ids, NX)
    rep_ids_rep = np.repeat(rep_ids, NX)
    return X, A, \
        mol_ids_rep, rep_ids_rep, iw_ids, \
        last_append_mask, \
        NX, NX_rep, \
        action_0, actions, \
        log_p

def process_single(smiles, k, p):
    graph, atom_types, atom_ranks, bonds, bond_types = get_graph_from_smiles(smiles)
    # original graph
    X_0 = np.array(atom_types, dtype=np.int64)
    A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
                          np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
                         axis=1)
    X, A = [], []
    NX, NA = [], []
    mol_ids, rep_ids, iw_ids = [], [], []
    action_0, actions = [], []
    last_append_mask = []
    log_p = []
    # randomly sample k decoding routes
    for i in range(k):
        step_ids_i, log_p_i = traverse_graph(graph, atom_ranks, p=p)
        X_i, A_i = single_reorder(X_0, A_0, step_ids_i)
        action_0_i, X_i, NX_i, A_i, NA_i, actions_i, last_atom_mask_i = single_expand(X_i, A_i)
        # append
        X.append(X_i)
        A.append(A_i)
        NX.append(NX_i)
        NA.append(NA_i)
        action_0.append(action_0_i)
        actions.append(actions_i)
        last_append_mask.append(last_atom_mask_i)
        mol_ids.append(np.zeros_like(NX_i, dtype=np.int64))
        rep_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
        iw_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
        log_p.append(log_p_i)
    # concatenate
    X = np.concatenate(X, axis=0)
    A = np.concatenate(A, axis=0)
    NX = np.concatenate(NX, axis=0)
    NA = np.concatenate(NA, axis=0)
    action_0 = np.concatenate(action_0, axis=0)
    actions = np.concatenate(actions, axis=0)
    last_append_mask = np.concatenate(last_append_mask, axis=0)
    mol_ids = np.concatenate(mol_ids, axis=0)
    rep_ids = np.concatenate(rep_ids, axis=0)
    iw_ids = np.concatenate(iw_ids, axis=0)
    log_p = np.array(log_p, dtype=np.float32)
    return X, A, NX, NA, mol_ids, rep_ids, iw_ids, action_0, actions, last_append_mask, log_p
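
# Example (hypothetical molecule): build k=5 importance-weighted decoding
# trajectories for one SMILES string,
#   outputs = process_single('c1ccccc1O', k=5, p=0.9)
#   (X, A, NX, NA, mol_ids, rep_ids, iw_ids,
#    action_0, actions, last_append_mask, log_p) = outputs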

# noinspection PyArgumentList
def get_mol_from_graph(X, A, sanitize=True):
    try:
        mol = Chem.RWMol(Chem.Mol())
        X, A = X.tolist(), A.tolist()
        for i, atom_type in enumerate(X):
            mol.AddAtom(get_mol_spec().index_to_atom(atom_type))
        for atom_id1, atom_id2, bond_type in A:
            get_mol_spec().index_to_bond(mol, atom_id1, atom_id2, bond_type)
    except Exception:
        return None
    if sanitize:
        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            return mol
        except Exception:
            return None
    else:
        return mol

def get_mol_from_graph_list(graph_list, sanitize=True):
    mol_list = [get_mol_from_graph(X, A, sanitize) for X, A in graph_list]
    return mol_list
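
# Round-trip sketch: SMILES -> graphs -> RDKit mols (entries are None where
# graph-to-molecule conversion or sanitization fails):
#   graphs = get_graph_from_smiles_list(['CCO', 'c1ccccc1'])
#   mols = get_mol_from_graph_list(graphs, sanitize=True)
#   smiles = [Chem.MolToSmiles(m) for m in mols if m is not None]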

class GraphConvFn(Function):

    def __init__(self, A):
        self.A = A  # type: nd.sparse.CSRNDArray
        self.A_T = self.A  # assume A is symmetric
        super(GraphConvFn, self).__init__()

    def forward(self, X):
        if self.A is not None:
            if len(X.shape) > 2:
                X_resized = X.reshape((X.shape[0], -1))
                output = nd.sparse.dot(self.A, X_resized)
                output = output.reshape([-1, ] + [X.shape[i] for i in range(1, len(X.shape))])
            else:
                output = nd.sparse.dot(self.A, X)
            return output
        else:
            return nd.zeros_like(X)

    def backward(self, grad_output):
        if self.A is not None:
            if len(grad_output.shape) > 2:
                grad_output_resized = grad_output.reshape((grad_output.shape[0], -1))
                grad_input = nd.sparse.dot(self.A_T, grad_output_resized)
                grad_input = grad_input.reshape([-1] + [grad_output.shape[i]
                                                        for i in range(1, len(grad_output.shape))])
            else:
                grad_input = nd.sparse.dot(self.A_T, grad_output)
            return grad_input
        else:
            return nd.zeros_like(grad_output)
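
# GraphConvFn is a custom mxnet.autograd.Function computing X -> A @ X for a
# sparse adjacency A; since d(A @ X)/dX = A^T, backward multiplies the incoming
# gradient by A_T (equal to A here under the symmetry assumption above).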

class EfficientGraphConvFn(Function):
    """Save memory by re-computing the concatenated input during backward."""

    def __init__(self, A_list):
        self.A_list = A_list
        super(EfficientGraphConvFn, self).__init__()

    def forward(self, X, W):
        X_list = [X]
        for A in self.A_list:
            if A is not None:
                X_list.append(nd.sparse.dot(A, X))
            else:
                X_list.append(nd.zeros_like(X))
        X_out = nd.concat(*X_list, dim=1)
        self.save_for_backward(X, W)
        return nd.dot(X_out, W)

    def backward(self, grad_output):
        X, W = self.saved_tensors
        # recompute X_out instead of storing it
        X_list = [X, ]
        for A in self.A_list:
            if A is not None:
                X_list.append(nd.sparse.dot(A, X))
            else:
                X_list.append(nd.zeros_like(X))
        X_out = nd.concat(*X_list, dim=1)
        grad_W = nd.dot(X_out.T, grad_output)
        grad_X_out = nd.dot(grad_output, W.T)
        grad_X_out_list = nd.split(grad_X_out, num_outputs=len(self.A_list) + 1)
        grad_X = [grad_X_out_list[0], ]
        for A, grad_X_out_i in zip(self.A_list, grad_X_out_list[1:]):
            if A is not None:
                grad_X.append(nd.sparse.dot(A, grad_X_out_i))
            else:
                grad_X.append(nd.zeros_like(grad_X_out_i))
        grad_X = sum(grad_X)
        return grad_X, grad_W

class SegmentSumFn(GraphConvFn):

    def __init__(self, idx, num_seg):
        # build the segment-sum operator as a sparse (num_seg x N) indicator
        # matrix with a one at (idx[j], j) for each row j of the input
        data = nd.ones(idx.shape[0], ctx=idx.context, dtype='int64')
        row, col = idx, nd.arange(idx.shape[0], ctx=idx.context, dtype='int64')
        shape = (num_seg, int(idx.shape[0]))
        A = nd.sparse.csr_matrix((data, (row, col)), shape=shape,
                                 ctx=idx.context, dtype='float32')
        super(SegmentSumFn, self).__init__(A)
        # explicit transpose for backward: this matrix is not symmetric
        A_T = nd.sparse.csr_matrix((data, (col, row)), shape=(shape[1], shape[0]),
                                   ctx=idx.context, dtype='float32')
        self.A_T = A_T
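
# Illustrative segment sum (assuming a CPU context):
#   idx = nd.array([0, 0, 1], dtype='int64')   # row i of X belongs to segment idx[i]
#   X = nd.array([[1.], [2.], [3.]])
#   SegmentSumFn(idx, 2)(X)                    # -> [[3.], [3.]]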

def squeeze(input, axis):
    assert input.shape[axis] == 1
    new_shape = list(input.shape)
    del new_shape[axis]
    return input.reshape(new_shape)

def unsqueeze(input, axis):
    return nd.expand_dims(input, axis=axis)

def logsumexp(inputs, axis=None, keepdims=False):
    """Numerically stable log-sum-exp.

    Args:
        inputs: an NDArray of any shape.
        axis: an integer; if None, reduce over the flattened array.
        keepdims: a boolean.

    Returns:
        Equivalent of log(sum(exp(inputs), axis=axis, keepdims=keepdims)).

    Adapted from: https://github.com/pytorch/pytorch/issues/2591
    """
    # For an array x reduced along a single axis,
    #     log sum exp(x) = s + log sum exp(x - s)
    # with s = max(x) being a common choice.
    if axis is None:
        inputs = inputs.reshape([-1])
        axis = 0
    s = nd.max(inputs, axis=axis, keepdims=True)
    outputs = s + (inputs - s).exp().sum(axis=axis, keepdims=True).log()
    if not keepdims:
        outputs = nd.sum(outputs, axis=axis, keepdims=False)
    return outputs
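
# e.g. logsumexp(nd.array([0., 0.])) -> log(2) ≈ 0.693, matching
# nd.log(nd.sum(nd.exp(x))) but without overflow for large inputs.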

def get_activation(name):
    activation_dict = {
        'relu': nd.relu,
        'tanh': nd.tanh
    }
    return activation_dict[name]

class Linear_BN(nn.Sequential):

    def __init__(self, F_in, F_out):
        super(Linear_BN, self).__init__()
        self.add(nn.Dense(F_out, in_units=F_in, use_bias=False))
        # uses the custom BatchNorm block defined later in this module
        self.add(BatchNorm(in_channels=F_out))

class GraphConv(nn.Block):

    def __init__(self, Fin, Fout, D):
        super(GraphConv, self).__init__()
        # model settings
        self.Fin = Fin
        self.Fout = Fout
        self.D = D
        # model parameters
        self.W = self.params.get('w', shape=(self.Fin * (self.D + 1), self.Fout),
                                 init=None, allow_deferred_init=False)

    def forward(self, X, A_list):
        try:
            assert len(A_list) == self.D
        except AssertionError as e:
            print(self.D, len(A_list))
            raise e
        return EfficientGraphConvFn(A_list)(X, self.W.data(X.context))

class Policy(nn.Block):

    def __init__(self, F_in, F_h, N_A, N_B, k=1):
        super(Policy, self).__init__()
        self.F_in = F_in  # number of input features for each atom
        self.F_h = F_h  # number of context variables
        self.N_A = N_A  # number of atom types
        self.N_B = N_B  # number of bond types
        self.k = k  # number of softmaxes used in the mixture
        with self.name_scope():
            self.linear_h = Linear_BN(F_in * 2, self.F_h * k)
            self.linear_h_t = Linear_BN(F_in, self.F_h * k)
            self.linear_x = nn.Dense(self.N_B + self.N_B * self.N_A, in_units=self.F_h)
            self.linear_x_t = nn.Dense(1, in_units=self.F_h)
            if self.k > 1:
                self.linear_pi = nn.Dense(self.k, in_units=self.F_in)
            else:
                self.linear_pi = None

    def forward(self, X, NX, NX_rep, X_end=None):
        # segment mean for X
        if X_end is None:
            X_end = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X = nd.concat(X, X_end[NX_rep, :], dim=1)
        X_h = nd.relu(self.linear_h(X)).reshape([-1, self.F_h])
        X_h_end = nd.relu(self.linear_h_t(X_end)).reshape([-1, self.F_h])
        X_x = nd.exp(self.linear_x(X_h)).reshape([-1, self.k, self.N_B + self.N_B * self.N_A])
        X_x_end = nd.exp(self.linear_x_t(X_h_end)).reshape([-1, self.k, 1])
        X_sum = nd.sum(SegmentSumFn(NX_rep, NX.shape[0])(X_x), -1, keepdims=True) + X_x_end
        X_sum_gathered = X_sum[NX_rep, :, :]
        X_softmax = X_x / X_sum_gathered
        X_softmax_end = X_x_end / X_sum
        if self.k > 1:
            pi = unsqueeze(nd.softmax(self.linear_pi(X_end), axis=1), -1)
            pi_gathered = pi[NX_rep, :, :]
            X_softmax = nd.sum(X_softmax * pi_gathered, axis=1)
            X_softmax_end = nd.sum(X_softmax_end * pi, axis=1)
        else:
            X_softmax = squeeze(X_softmax, 1)
            X_softmax_end = squeeze(X_softmax_end, 1)
        # generate output probabilities
        connect, append = X_softmax[:, :self.N_B], X_softmax[:, self.N_B:]
        append = append.reshape([-1, self.N_A, self.N_B])
        end = squeeze(X_softmax_end, -1)
        return append, connect, end
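
# Output shapes per forward pass over a batch of partial graphs:
#   append : [num_atoms, N_A, N_B]  p(append an atom of type a via a bond of type b at each position)
#   connect: [num_atoms, N_B]       p(connect the latest atom to each position via bond type b)
#   end    : [num_graphs]           p(terminate generation)
# For each graph, its append/connect entries together with `end` form a single
# normalized distribution over all possible next actions.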

class BatchNorm(nn.Block):

    def __init__(self, in_channels, momentum=0.9, eps=1e-5):
        super(BatchNorm, self).__init__()
        self.F = in_channels
        self.bn_weight = self.params.get('bn_weight', shape=(self.F,), init=mx.init.One(),
                                         allow_deferred_init=False)
        self.bn_bias = self.params.get('bn_bias', shape=(self.F,), init=mx.init.Zero(),
                                       allow_deferred_init=False)
        self.running_mean = self.params.get('running_mean', grad_req='null',
                                            shape=(self.F,),
                                            init=mx.init.Zero(),
                                            allow_deferred_init=False,
                                            differentiable=False)
        self.running_var = self.params.get('running_var', grad_req='null',
                                           shape=(self.F,),
                                           init=mx.init.One(),
                                           allow_deferred_init=False,
                                           differentiable=False)
        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        if autograd.is_training():
            return nd.BatchNorm(x,
                                gamma=self.bn_weight.data(x.context),
                                beta=self.bn_bias.data(x.context),
                                moving_mean=self.running_mean.data(x.context),
                                moving_var=self.running_var.data(x.context),
                                eps=self.eps, momentum=self.momentum,
                                use_global_stats=False)
        else:
            return nd.BatchNorm(x,
                                gamma=self.bn_weight.data(x.context),
                                beta=self.bn_bias.data(x.context),
                                moving_mean=self.running_mean.data(x.context),
                                moving_var=self.running_var.data(x.context),
                                eps=self.eps, momentum=self.momentum,
                                use_global_stats=True)

class MoleculeGenerator(nn.Block):
    __metaclass__ = ABCMeta

    def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
                 *args, **kwargs):
        super(MoleculeGenerator, self).__init__()
        self.N_A = N_A
        self.N_B = N_B
        self.D = D
        self.F_e = F_e
        self.F_skip = F_skip
        self.F_c = list(F_c) if isinstance(F_c, tuple) else F_c
        self.Fh_policy = Fh_policy
        self.activation = get_activation(activation)
        with self.name_scope():
            # embeddings
            self.embedding_atom = nn.Embedding(self.N_A, self.F_e)
            self.embedding_mask = nn.Embedding(3, self.F_e)
            # graph conv
            self._build_graph_conv(*args, **kwargs)
            # fully connected
            self.dense = nn.Sequential()
            for i, (f_in, f_out) in enumerate(zip([self.F_skip, ] + self.F_c[:-1], self.F_c)):
                self.dense.add(Linear_BN(f_in, f_out))
            # policy
            self.policy_0 = self.params.get('policy_0', shape=[self.N_A, ],
                                            init=mx.init.Zero(),
                                            allow_deferred_init=False)
            self.policy_h = Policy(self.F_c[-1], self.Fh_policy, self.N_A, self.N_B)
        self.mode = 'loss'

    def _build_graph_conv(self, *args, **kwargs):
        raise NotImplementedError

    def _graph_conv_forward(self, X, A):
        raise NotImplementedError

    def _policy_0(self, ctx):
        policy_0 = nd.exp(self.policy_0.data(ctx))
        policy_0 = policy_0 / policy_0.sum()
        return policy_0

    def _policy(self, X, A, NX, NX_rep, last_append_mask):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
        # convolution
        X = self._graph_conv_forward(X, A)
        # linear
        X = self.dense(X)
        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep)
        return append, connect, end

    def _likelihood(self, init, append, connect, end,
                    action_0, actions, iw_ids, log_p_sigma,
                    batch_size, iw_size):
        # decompose action
        action_type, node_type, edge_type, append_pos, connect_pos = \
            actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]
        _log_mask = lambda _x, _mask: _mask * nd.log(_x + 1e-10) + (1 - _mask) * nd.zeros_like(_x)
        # init
        init = init.reshape([batch_size * iw_size, self.N_A])
        index = nd.stack(nd.arange(action_0.shape[0], ctx=action_0.context, dtype='int64'), action_0, axis=0)
        loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)
        # end
        loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))
        # append
        index = nd.stack(append_pos, node_type, edge_type, axis=0)
        loss_append = _log_mask(nd.gather_nd(append, index), nd.cast(action_type == 0, 'float32'))
        # connect
        index = nd.stack(connect_pos, edge_type, axis=0)
        loss_connect = _log_mask(nd.gather_nd(connect, index), nd.cast(action_type == 1, 'float32'))
        # sum up results
        log_p_x = loss_end + loss_append + loss_connect
        log_p_x = squeeze(SegmentSumFn(iw_ids, batch_size * iw_size)(unsqueeze(log_p_x, -1)), -1)
        log_p_x = log_p_x + loss_init
        # reshape
        log_p_x = log_p_x.reshape([batch_size, iw_size])
        log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
        l = log_p_x - log_p_sigma
        l = logsumexp(l, axis=1) - math.log(float(iw_size))
        return l

    def forward(self, *input):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size = input
            init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask)
            l = self._likelihood(init, append, connect, end, action_0, actions, iw_ids, log_p, batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(input[0])
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask = input
            return self._policy(X, A, NX, NX_rep, last_append_mask)
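
# _likelihood computes the importance-weighted log-likelihood estimate
#   log p(G) ≈ logsumexp_i[ log p(G, route_i) - log q(route_i) ] - log K
# where q is the route-sampling distribution from traverse_graph (log_p_sigma)
# and K = iw_size is the number of sampled decoding routes per molecule.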

class MoleculeGenerator_RNN(MoleculeGenerator):
    __metaclass__ = ABCMeta

    def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
                 N_rnn, *args, **kwargs):
        super(MoleculeGenerator_RNN, self).__init__(N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
                                                    *args, **kwargs)
        self.N_rnn = N_rnn
        with self.name_scope():
            self.rnn = gluon.rnn.GRU(hidden_size=self.F_c[-1],
                                     num_layers=self.N_rnn,
                                     layout='NTC', input_size=self.F_c[-1] * 2)

    def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
        X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X_curr = nd.take(X, indices=NX_cum - 1)
        X = nd.concat(X_avg, X_curr, dim=1)
        # rnn
        X = nd.take(X, indices=graph_to_rnn)  # batch_size, iw_size, length, num_features
        batch_size, iw_size, length, num_features = X.shape
        X = X.reshape([batch_size * iw_size, length, num_features])
        X = self.rnn(X)
        X = X.reshape([batch_size, iw_size, length, -1])
        X = nd.gather_nd(X, indices=rnn_to_graph)
        return X

    def _rnn_test(self, X, NX, NX_rep, NX_cum, h):
        # note: one partition for one molecule
        X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X_curr = nd.take(X, indices=NX_cum - 1)
        X = nd.concat(X_avg, X_curr, dim=1)  # size: [NX, F_in * 2]
        # rnn
        X = unsqueeze(X, axis=1)
        X, h = self.rnn(X, h)
        X = squeeze(X, axis=1)
        return X, h

    def _policy(self, X, A, NX, NX_rep, last_append_mask, graph_to_rnn, rnn_to_graph, NX_cum):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
        # convolution
        X = self._graph_conv_forward(X, A)
        # linear
        X = self.dense(X)
        # rnn
        X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)
        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end

    def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
        # convolution
        X = self._graph_conv_forward(X, A)
        # linear
        X = self.dense(X)
        # rnn
        X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)
        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end, h

    def forward(self, *input):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size, \
                graph_to_rnn, rnn_to_graph, NX_cum = input
            init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask, graph_to_rnn, rnn_to_graph, NX_cum)
            l = self._likelihood(init, append, connect, end, action_0, actions, iw_ids, log_p, batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(input[0])
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask, NX_cum, h = input
            return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h)
        else:
            raise ValueError

class _TwoLayerDense(nn.Block):

    def __init__(self, input_size, hidden_size, output_size):
        super(_TwoLayerDense, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size
        with self.name_scope():
            # config 1 (active): Dense -> BatchNorm -> ReLU -> Dense
            self.input = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
            self.bn_input = BatchNorm(in_channels=hidden_size)
            self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)
            # config 2: single linear layer
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)
            # config 3: two hidden layers
            # self.input1 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
            # self.bn_input1 = BatchNorm(in_channels=self.hidden_size)
            # self.input2 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.hidden_size)
            # self.bn_input2 = BatchNorm(in_channels=self.hidden_size)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)
            # config 4: BatchNorm directly on the input
            # self.bn_input = BatchNorm(in_channels=self.input_size)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)
            # config 5: fixed 1024-dim input
            # self.bn_input = BatchNorm(in_channels=1024)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=1024)

    def forward(self, c):
        # config 1 (active)
        return nd.softmax(self.output(nd.relu(self.bn_input(self.input(c)))), axis=-1)
        # config 2
        # return nd.softmax(self.output(c), axis=-1)
        # config 3
        # return nd.softmax(self.output(nd.relu(self.bn_input2(self.input2(nd.relu(self.bn_input1(self.input1(c))))))), axis=-1)
        # config 4
        # return nd.softmax(self.output(nd.relu(self.bn_input(c))), axis=-1)
        # config 5
        # return nd.softmax(self.output(c), axis=-1)

class CMoleculeGenerator_RNN(MoleculeGenerator_RNN):
    __metaclass__ = ABCMeta

    def __init__(self, N_A, N_B, N_C, D,
                 F_e, F_skip, F_c, Fh_policy,
                 activation, N_rnn,
                 *args, **kwargs):
        self.N_C = N_C  # number of conditional variables
        super(CMoleculeGenerator_RNN, self).__init__(N_A, N_B, D,
                                                     F_e, F_skip, F_c, Fh_policy,
                                                     activation, N_rnn,
                                                     *args, **kwargs)
        with self.name_scope():
            self.dense_policy_0 = _TwoLayerDense(self.N_C, self.N_A * 3, self.N_A)

    def _graph_conv_forward(self, X, A, c, ids):
        raise NotImplementedError

    def _policy_0(self, c):
        # the 0.0 * ... term keeps the otherwise unused base-class parameter in the graph
        return self.dense_policy_0(c) + 0.0 * self.policy_0.data(c.context)

    def _policy(self, X, A, NX, NX_rep, last_append_mask,
                graph_to_rnn, rnn_to_graph, NX_cum,
                c, ids):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
        # convolution (conditioned on c)
        X = self._graph_conv_forward(X, A, c, ids)
        # linear
        X = self.dense(X)
        # rnn
        X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)
        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end

    def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
        # convolution (conditioned on c)
        X = self._graph_conv_forward(X, A, c, ids)
        # linear
        X = self.dense(X)
        # rnn
        X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)
        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end, h

    def forward(self, *input):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size, \
                graph_to_rnn, rnn_to_graph, NX_cum, \
                c, ids = input
            init = nd.tile(unsqueeze(self._policy_0(c), axis=1), [1, iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask,
                                                graph_to_rnn, rnn_to_graph, NX_cum,
                                                c, ids)
            l = self._likelihood(init, append, connect, end,
                                 action_0, actions, iw_ids, log_p,
                                 batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(*input)
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids = input
            return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids)
        else:
            raise ValueError

class CVanillaMolGen_RNN(CMoleculeGenerator_RNN):

    def __init__(self, N_A, N_B, N_C, D,
                 F_e, F_h, F_skip, F_c, Fh_policy,
                 activation, N_rnn, rename=False):
        self.rename = rename
        super(CVanillaMolGen_RNN, self).__init__(N_A, N_B, N_C, D,
                                                 F_e, F_skip, F_c, Fh_policy,
                                                 activation, N_rnn,
                                                 F_h)

    def _build_graph_conv(self, F_h):
        self.F_h = list(F_h) if isinstance(F_h, tuple) else F_h
        self.conv, self.bn = [], []
        for i, (f_in, f_out) in enumerate(zip([self.F_e] + self.F_h[:-1], self.F_h)):
            conv = GraphConv(f_in, f_out, self.N_B + self.D)
            self.conv.append(conv)
            self.register_child(conv)
            if i != 0:
                bn = BatchNorm(in_channels=f_in)
                self.register_child(bn)
            else:
                bn = None
            self.bn.append(bn)
        self.bn_skip = BatchNorm(in_channels=sum(self.F_h))
        self.linear_skip = Linear_BN(sum(self.F_h), self.F_skip)
        # projectors for the conditional variable (protein embedding)
        self.linear_c = []
        for i, f_out in enumerate(self.F_h):
            if self.rename:
                linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C, prefix='cond_{}'.format(i))
            else:
                linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C)
            self.register_child(linear_c)
            self.linear_c.append(linear_c)

    def _graph_conv_forward(self, X, A, c, ids):
        X_out = [X]
        for conv, bn, linear_c in zip(self.conv, self.bn, self.linear_c):
            X = X_out[-1]
            if bn is not None:
                X_out.append(conv(self.activation(bn(X)), A) + linear_c(c)[ids, :])
            else:
                X_out.append(conv(X, A) + linear_c(c)[ids, :])
        X_out = nd.concat(*X_out[1:], dim=1)
        return self.activation(self.linear_skip(self.activation(self.bn_skip(X_out))))

def _decode_step(X, A, NX, NA, last_action, finished,
                 get_init, get_action,
                 random=False, n_node_types=get_mol_spec().num_atom_types,
                 n_edge_types=get_mol_spec().num_bond_types):
    if X is None:
        init = get_init()
        if random:
            X = []
            for i in range(init.shape[0]):
                # initial probabilities (for the first atom)
                p = init[i, :]
                # sample the first atom from the initial distribution
                selected_atom = np.random.choice(np.arange(init.shape[1]), 1, p=p)[0]
                X.append(selected_atom)
            X = np.array(X, dtype=np.int64)
        else:
            X = np.argmax(init, axis=1)
        A = np.zeros((0, 3), dtype=np.int64)
        NX = last_action = np.ones([X.shape[0]], dtype=np.int64)
        NA = np.zeros([X.shape[0]], dtype=np.int64)
        finished = np.array([False, ] * X.shape[0], dtype=bool)  # np.bool is removed in NumPy >= 1.24
        return X, A, NX, NA, last_action, finished
    else:
        X_u = X[np.repeat(np.logical_not(finished), NX)]
        A_u = A[np.repeat(np.logical_not(finished), NA), :]
        NX_u = NX[np.logical_not(finished)]
        NA_u = NA[np.logical_not(finished)]
        last_action_u = last_action[np.logical_not(finished)]
        # conv
        mol_ids_rep = NX_rep = np.repeat(np.arange(NX_u.shape[0]), NX_u)
        rep_ids_rep = np.zeros_like(mol_ids_rep)
        if A.shape[0] == 0:
            D_2 = D_3 = np.zeros((0, 2), dtype=np.int64)
            A_u = [np.zeros((0, 2), dtype=np.int64) for _ in range(get_mol_spec().num_bond_types)]
            A_u += [D_2, D_3]
        else:
            cumsum = np.cumsum(np.pad(NX_u, [[1, 0]], mode='constant')[:-1])
            shift = np.repeat(cumsum, NA_u)
            A_u[:, :2] += np.stack([shift, ] * 2, axis=1)
            D_2, D_3 = get_d(A_u, X_u)
            A_u = [A_u[A_u[:, 2] == _i, :2] for _i in range(n_edge_types)]
            A_u += [D_2, D_3]
        mask = np.zeros([X_u.shape[0]], dtype=np.int64)
        last_append_index = np.cumsum(NX_u) - 1
        mask[last_append_index] = np.where(last_action_u == 1,
                                           np.ones_like(last_append_index, dtype=np.int64),
                                           np.ones_like(last_append_index, dtype=np.int64) * 2)
        decode_input = [X_u, A_u, NX_u, NX_rep, mask, mol_ids_rep, rep_ids_rep]
        append, connect, end = get_action(decode_input)
        if A.shape[0] == 0:
            max_index = np.argmax(np.reshape(append, [-1, n_node_types * n_edge_types]), axis=1)
            atom_type, bond_type = np.unravel_index(max_index, [n_node_types, n_edge_types])
            X = np.reshape(np.stack([X, atom_type], axis=1), [-1])
            NX = np.array([2, ] * len(finished), dtype=np.int64)
            A = np.stack([np.zeros([len(finished), ], dtype=np.int64),
                          np.ones([len(finished), ], dtype=np.int64),
                          bond_type], axis=1)
            NA = np.ones([len(finished), ], dtype=np.int64)
            last_action = np.ones_like(NX, dtype=np.int64)
        else:
            # process each molecule separately
            append, connect = np.split(append, np.cumsum(NX_u)), np.split(connect, np.cumsum(NX_u))
            end = end.tolist()
            unfinished_ids = np.where(np.logical_not(finished))[0].tolist()
            cumsum = np.cumsum(NX)
            cumsum_a = np.cumsum(NA)
            X_insert = []
            X_insert_ids = []
            A_insert = []
            A_insert_ids = []
            finished_ids = []
            for i, (unfinished_id, append_i, connect_i, end_i) \
                    in enumerate(zip(unfinished_ids, append, connect, end)):
                if random:
                    def _rand_id(*_x):
                        _x_reshaped = [np.reshape(_xi, [-1]) for _xi in _x]
                        _x_length = np.array([_x_reshape_i.shape[0] for _x_reshape_i in _x_reshaped],
                                             dtype=np.int64)
                        _begin = np.cumsum(np.pad(_x_length, [[1, 0]], mode='constant')[:-1])
                        _end = np.cumsum(_x_length) - 1
                        _p = np.concatenate(_x_reshaped)
                        _p = _p / np.sum(_p)
                        # count NaN values; fall back to uniform sampling if any are present
                        num_nan = np.isnan(_p).sum()
                        if num_nan > 0:
                            print(f'Number of NaN values in _p: {num_nan}')
                            _rand_index = np.random.choice(np.arange(len(_p)), 1)[0]
                        else:
                            _rand_index = np.random.choice(np.arange(_p.shape[0]), 1, p=_p)[0]
                        _p_step = _p[_rand_index]
                        _x_index = np.where(np.logical_and(_begin <= _rand_index, _end >= _rand_index))[0][0]
                        _rand_index = _rand_index - _begin[_x_index]
                        _rand_index = np.unravel_index(_rand_index, _x[_x_index].shape)
                        return _x_index, _rand_index, _p_step
                    action_type, action_index, p_step = _rand_id(append_i, connect_i, np.array([end_i]))
                else:
                    _argmax = lambda _x: np.unravel_index(np.argmax(_x), _x.shape)
                    append_id, append_val = _argmax(append_i), np.max(append_i)
                    connect_id, connect_val = _argmax(connect_i), np.max(connect_i)
                    end_val = end_i
                    if end_val >= append_val and end_val >= connect_val:
                        action_type = 2
                        action_index = None
                    elif append_val >= connect_val and append_val >= end_val:
                        action_type = 0
                        action_index = append_id
                    else:
                        action_type = 1
                        action_index = connect_id
                if action_type == 2:
                    # finish growth
                    finished_ids.append(unfinished_id)
                elif action_type == 0:
                    # append action
                    append_pos, atom_type, bond_type = action_index
                    X_insert.append(atom_type)
                    X_insert_ids.append(unfinished_id)
                    A_insert.append([append_pos, NX[unfinished_id], bond_type])
                    A_insert_ids.append(unfinished_id)
                else:
                    # connect action
                    connect_ps, bond_type = action_index
                    A_insert.append([NX[unfinished_id] - 1, connect_ps, bond_type])
                    A_insert_ids.append(unfinished_id)
            if len(A_insert_ids) > 0:
                A = np.insert(A, cumsum_a[A_insert_ids], A_insert, axis=0)
                NA[A_insert_ids] += 1
                last_action[A_insert_ids] = 0
            if len(X_insert_ids) > 0:
                X = np.insert(X, cumsum[X_insert_ids], X_insert, axis=0)
                NX[X_insert_ids] += 1
                last_action[X_insert_ids] = 1
            if len(finished_ids) > 0:
                finished[finished_ids] = True
        return X, A, NX, NA, last_action, finished
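
# One _decode_step call advances every unfinished molecule by a single action:
# step 0 samples the first atom from get_init(); the first bond step (the
# A.shape[0] == 0 branch) always uses argmax, even when random=True; subsequent
# steps sample or argmax over the joint (append, connect, end) distribution
# returned by get_action.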

class Builder(object, metaclass=ABCMeta):

    def __init__(self, model_loc, gpu_id=None):
        with open(os.path.join(model_loc, 'configs.json')) as f:
            configs = json.load(f)
        self.mdl = self.__class__._get_model(configs)
        self.ctx = mx.gpu(gpu_id) if gpu_id is not None else mx.cpu()
        self.mdl.load_parameters(os.path.join(model_loc, 'ckpt.params'), ctx=self.ctx, allow_missing=True)

    @staticmethod
    def _get_model(configs):
        raise NotImplementedError

    def sample(self, num_samples, *args, **kwargs):
        raise NotImplementedError

class CVanilla_RNN_Builder(Builder):

    @staticmethod
    def _get_model(configs):
        return CVanillaMolGen_RNN(get_mol_spec().num_atom_types, get_mol_spec().num_bond_types, D=2, **configs)

    def sample(self, num_samples, c, output_type='mol', sanitize=True, random=True):
        if len(c.shape) == 1:
            c = np.stack([c, ] * num_samples, axis=0)
        with autograd.predict_mode():
            # step one
            finished = [False, ] * num_samples

            def get_init():
                self.mdl.mode = 'decode_0'
                _c = nd.array(c, dtype='float32', ctx=self.ctx)
                init = self.mdl(_c).asnumpy()
                return init

            outputs = _decode_step(X=None, A=None, NX=None, NA=None, last_action=None, finished=finished,
                                   get_init=get_init, get_action=None,
                                   n_node_types=self.mdl.N_A, n_edge_types=self.mdl.N_B,
                                   random=random)
            # defensive guard; _decode_step currently never returns None
            if outputs is None:
                return None
            X, A, NX, NA, last_action, finished = outputs
            count = 1
            h = np.zeros([self.mdl.N_rnn, num_samples, self.mdl.F_c[-1]], dtype=np.float32)
            while not np.all(finished) and count < 100:
                def get_action(inputs):
                    self.mdl.mode = 'decode_step'
                    _h = nd.array(h[:, np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
                    _c = nd.array(c[np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
                    _X, _A_sparse, _NX, _NX_rep, _mask, _NX_cum = self.to_nd(inputs)
                    _append, _connect, _end, _h = self.mdl(_X, _A_sparse, _NX, _NX_rep, _mask, _NX_cum, _h, _c, _NX_rep)
                    h[:, np.logical_not(finished), :] = _h[0].asnumpy()
                    return _append.asnumpy(), _connect.asnumpy(), _end.asnumpy()

                outputs = _decode_step(X, A, NX, NA, last_action, finished,
                                       get_init=None, get_action=get_action,
                                       n_node_types=self.mdl.N_A, n_edge_types=self.mdl.N_B,
                                       random=random)
                X, A, NX, NA, last_action, finished = outputs
                count += 1
        graph_list = []
        cumsum_X_ = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')).tolist()
        cumsum_A_ = np.cumsum(np.pad(NA, [[1, 0]], mode='constant')).tolist()
        for cumsum_A_pre, cumsum_A_post, \
                cumsum_X_pre, cumsum_X_post in zip(cumsum_A_[:-1], cumsum_A_[1:],
                                                   cumsum_X_[:-1], cumsum_X_[1:]):
            graph_list.append([X[cumsum_X_pre:cumsum_X_post], A[cumsum_A_pre:cumsum_A_post, :]])
        if output_type == 'graph':
            return graph_list
        elif output_type == 'mol':
            return get_mol_from_graph_list(graph_list, sanitize)
        elif output_type == 'smiles':
            mol_list = get_mol_from_graph_list(graph_list, sanitize=True)
            smiles_list = [Chem.MolToSmiles(m) if m is not None else None for m in mol_list]
            return smiles_list
        else:
            raise ValueError('Unrecognized output type')
    def to_nd(self, inputs):
        X, A, NX, NX_rep, mask = inputs[:-2]
        NX_cum = np.cumsum(NX)
        # convert to ndarray
        _to_ndarray = lambda _x: nd.array(_x, self.ctx, 'int64')
        X, NX, NX_rep, mask, NX_cum = \
            _to_ndarray(X), _to_ndarray(NX), _to_ndarray(NX_rep), _to_ndarray(mask), _to_ndarray(NX_cum)
        A_sparse = []
        for _A_i in A:
            if _A_i.shape[0] == 0:
                A_sparse.append(None)
            else:
                # transpose may not be supported on gpu; symmetrize on the host instead
                _A_i = np.concatenate([_A_i, _A_i[:, [1, 0]]], axis=0)
                # construct csr matrix
                _data = np.ones((_A_i.shape[0],), dtype=np.float32)
                _row, _col = _A_i[:, 0], _A_i[:, 1]
                _A_sparse_i = nd.sparse.csr_matrix((_data, (_row, _col)),
                                                   shape=tuple([int(X.shape[0]), ] * 2),
                                                   ctx=self.ctx, dtype='float32')
                # append to list
                A_sparse.append(_A_sparse_i)
        return X, A_sparse, NX, NX_rep, mask, NX_cum
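
# Minimal end-to-end sketch (assumes `models_folder/` contains the trained
# artifacts configs.json and ckpt.params alongside atom_types.txt, and that the
# condition vector matches the model's conditional dimension N_C; the directory
# name and zero-vector condition below are placeholders):
if __name__ == '__main__':
    builder = CVanilla_RNN_Builder('models_folder', gpu_id=None)
    c = np.zeros(builder.mdl.N_C, dtype=np.float32)  # hypothetical condition
    samples = builder.sample(10, c, output_type='smiles', random=True)
    print([s for s in samples if s is not None])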