|
|
import numpy as np |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc |
|
|
from utils import * |
|
|
import torch |
|
|
import copy |
|
|
from . import subgraphfp as subfp |
|
|
|
|
|
# Shared RDKit periodic table, used for per-element lookups such as van der Waals radii.
PERIODIC_TABLE = Chem.GetPeriodicTable()

# Element symbols that get their own one-hot slot in the atom features.
# When encoded with one_of_k_encoding_unk, any other element lands in the
# final slot ('B').
POSSIBLE_ATOMS = [
    'H', 'C', 'N', 'O', 'S', 'F',
    'Si', 'P', 'Cl', 'Br', 'I', 'B',
]

# Hybridization states that get their own one-hot slot.
HYBRIDS = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]

# Chiral tags that get their own one-hot slot.
CHIRALS = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

# Bond orders that get a dedicated channel in the adjacency tensor, listed in
# channel order.
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
|
|
|
|
|
def one_of_k_encoding(x, allowable_set):
    """Return a strict one-hot encoding of ``x`` over ``allowable_set``.

    :param x: value to encode
    :param allowable_set: ordered sequence of allowed values; output position i
        is ``True`` iff ``x == allowable_set[i]``
    :return: list of bools, one per element of ``allowable_set``
    :raises ValueError: if ``x`` is not a member of ``allowable_set``
        (``ValueError`` subclasses ``Exception``, so ``except Exception``
        handlers keep working)
    """
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set{1}:".format(x, allowable_set))
    return [x == s for s in allowable_set]
|
|
|
|
|
def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set``, tolerating unknown values.

    A value that is not a member of ``allowable_set`` is encoded as if it were
    the set's final element (the "unknown" bucket). Each output position i is
    ``True`` iff the (possibly substituted) value equals ``allowable_set[i]``.
    An empty ``allowable_set`` raises ``IndexError``.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]
|
|
|
|
|
def calc_atom_features_onehot(atom, feature):
    '''
    Method that computes atom level features from rdkit atom object.

    :param atom: RDKit ``Atom`` to describe
    :param feature: integer whose binary digits are appended as 0/1 flags
        (zero-padded to at least six digits); callers in this module pass the
        per-atom value from ``rdDesc.GetFeatureInvariants``
    :return: flat list mixing bools (one-hot slots), floats and ints
    '''
    # Categorical descriptors: one_of_k_encoding_unk maps values that are not
    # in the allowed list onto that list's last slot.
    atom_features = one_of_k_encoding_unk(atom.GetSymbol(), POSSIBLE_ATOMS)
    atom_features += one_of_k_encoding_unk(atom.GetExplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalDegree(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), list(range(-2, 3)))
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDS)
    atom_features += one_of_k_encoding_unk(atom.GetIsAromatic(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.IsInRing(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.GetChiralTag(), CHIRALS)

    # CIP label slot [is_R, is_S]. Encode the label value itself: HasProp()
    # returns a bool that can never equal 'R' or 'S', which would pin this
    # slot to [False, True] for every atom. Unlabelled atoms get [False, False].
    if atom.HasProp('_CIPCode'):
        atom_features += one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
    else:
        atom_features += [False, False]

    # Scalar descriptors.
    atom_features += [PERIODIC_TABLE.GetRvdw(atom.GetSymbol())]
    atom_features += [atom.HasProp('_ChiralityPossible')]
    atom_features += [atom.GetAtomicNum()]
    atom_features += [atom.GetMass() * 0.01]  # mass scaled down by 100
    atom_features += [atom.GetDegree()]

    # Binary expansion of `feature`, zero-padded to six digits.
    atom_features += [int(bit) for bit in '{0:06b}'.format(feature)]

    return atom_features
|
|
|
|
|
def calc_adjacent_tensor(bonds, atom_num, with_ring_conj=False):
    '''
    Method that constructs an adjacency tensor made of one adjacency matrix
    per bond channel.

    :param bonds: bonds of a rdkit mol (e.g. ``mol.GetBonds()``)
    :param atom_num: the atom number of the rdkit mol
    :param with_ring_conj: should the AdjecentTensor contain bond-in-ring and
        is-conjugated channels
    :return: AdjecentTensor A (numpy float array) shaped [N, F, N], where N is
        ``atom_num`` and F is ``len(BOND_TYPES)`` (+2 when ``with_ring_conj``).
        For a bond between atoms i and j on channel f,
        A[i, f, j] = A[j, f, i] = 1.
        Channels 0..len(BOND_TYPES)-1 follow BOND_TYPES order. With
        ``with_ring_conj``, channel F-2 marks ring bonds and channel F-1 marks
        conjugated bonds.
        Bonds whose type is not listed in BOND_TYPES are skipped entirely.
    '''
    n_channels = len(BOND_TYPES) + (2 if with_ring_conj else 0)
    ring_channel, conj_channel = n_channels - 2, n_channels - 1

    A = np.zeros([atom_num, n_channels, atom_num])

    for bond in bonds:
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        try:
            channel = BOND_TYPES.index(bond.GetBondType())
        except ValueError:
            # Bond type without a dedicated channel (e.g. dative/unspecified):
            # skip this bond entirely, but do not hide unrelated errors.
            continue

        A[b, channel, e] = A[e, channel, b] = 1
        if with_ring_conj:
            if bond.IsInRing():
                A[b, ring_channel, e] = A[e, ring_channel, b] = 1
            if bond.GetIsConjugated():
                A[b, conj_channel, e] = A[e, conj_channel, b] = 1

    return A
|
|
|
|
|
def calc_data_from_smile(smiles, addh=False, with_ring_conj=False, with_atom_feats=True, with_submol_fp=True, radius=2):
    '''
    Method that constructs the graph data of a molecule.

    :param smiles: SMILES representation of a molecule
    :param addh: should we add all the Hs of the mol
    :param with_ring_conj: should the AdjecentTensor contain bond-in-ring and
        is-conjugated channels
    :param with_atom_feats: include the per-atom descriptors from
        ``calc_atom_features_onehot`` in each row of V
    :param with_submol_fp: append the fingerprint of each atom's sub-molecule
        (from ``subfp.get_atom_submol_radn`` / ``subfp.gen_fps_from_mol``) to
        each row of V
    :param radius: value passed to ``subfp.get_atom_submol_radn``
    :return: dict with keys
        'V' -- FloatTensor of per-atom feature rows, shape [N, D];
        'A' -- FloatTensor adjacency tensor from ``calc_adjacent_tensor``,
               shape [N, F, N];
        'mol_size' -- IntTensor([N]);
        or None when the SMILES cannot be parsed or V is not 2-D (e.g. a
        molecule with no atoms).
    '''
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    # RDKit returns None for SMILES it cannot parse or sanitize; report that
    # the same way as the other "no usable data" case below.
    if mol is None:
        return None

    if addh:
        mol = Chem.AddHs(mol)

    num_atoms = mol.GetNumAtoms()
    mol_size = torch.IntTensor([num_atoms])

    if with_atom_feats:
        # One invariant per atom, indexed by atom index.
        features = rdDesc.GetFeatureInvariants(mol)

    submoldict = {}
    if with_submol_fp:
        atoms, submols = subfp.get_atom_submol_radn(mol, radius, sanitize=True)
        # Map atom index -> that atom's sub-molecule.
        submoldict = dict(zip([a.GetIdx() for a in atoms], submols))

    V = []
    for i in range(num_atoms):
        atom_i = mol.GetAtomWithIdx(i)
        if with_atom_feats:
            atom_i_features = calc_atom_features_onehot(atom_i, features[i])
        else:
            atom_i_features = []

        if with_submol_fp:
            atom_i_features.extend(subfp.gen_fps_from_mol(submoldict[i]))

        V.append(atom_i_features)

    V = torch.FloatTensor(V)
    # An atom-less molecule yields a 1-D empty tensor; signal it with None.
    if len(V.shape) != 2:
        return None

    A = calc_adjacent_tensor(mol.GetBonds(), num_atoms, with_ring_conj)
    A = torch.FloatTensor(A)

    return {'V': V, 'A': A, 'mol_size': mol_size}
|
|
|