# File size: 5,618 Bytes
# d5233a9
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from utils import *
import torch
import copy
from . import subgraphfp as subfp
# RDKit periodic table handle; used below to look up per-element data.
PERIODIC_TABLE = Chem.GetPeriodicTable()

# Element symbols for the atom-symbol one-hot encoding. When used with
# one_of_k_encoding_unk, any symbol not listed here is mapped onto the last
# entry.
POSSIBLE_ATOMS = [
    'H', 'C', 'N', 'O', 'S', 'F',
    'Si', 'P', 'Cl', 'Br', 'I', 'B',
]

# Hybridization states for the hybridization one-hot encoding.
HYBRIDS = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]

# Chiral tags for the chirality one-hot encoding.
CHIRALS = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

# Bond types; the position in this list is the channel index used by
# calc_adjacent_tensor.
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
def one_of_k_encoding(x, allowable_set):
    """
    One-hot encode ``x`` against ``allowable_set``.

    :param x: the value to encode
    :param allowable_set: sequence of permitted values
    :return: list of booleans, one per entry of ``allowable_set``, that is
             True wherever the entry equals ``x``
    :raises Exception: if ``x`` is not contained in ``allowable_set``
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
def one_of_k_encoding_unk(x, allowable_set):
    """
    One-hot encode ``x``, treating unknown values as the last allowed entry.

    :param x: the value to encode
    :param allowable_set: sequence of permitted values; its last element
                          doubles as the bucket for values not in the set
    :return: list of booleans, one per entry of ``allowable_set``
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]
def calc_atom_features_onehot(atom, feature):
    '''
    Method that computes atom level features from rdkit atom object

    The result is a flat list that concatenates, in order: one-hot encodings
    of the symbol, explicit valence, implicit valence, total H count, radical
    electron count, total degree, formal charge, hybridization, aromaticity,
    ring membership, chiral tag and CIP label; followed by the van der Waals
    radius, the '_ChiralityPossible' flag, the atomic number, the mass scaled
    by 0.01, the degree, and the binary digits of ``feature``.

    :param atom: rdkit atom object
    :param feature: integer feature invariant for this atom; it is rendered
                    as a binary string zero-padded to 6 digits and appended
                    digit by digit
    :return: list of per-atom feature values (booleans and numbers)
    '''
    # Read the actual CIP label. The previous code encoded the result of
    # HasProp('_CIPCode') (a presence flag, never 'R' or 'S'), so the CIP
    # slot was the constant [False, True] for every atom.
    cip_code = atom.GetProp('_CIPCode') if atom.HasProp('_CIPCode') else None

    atom_features = one_of_k_encoding_unk(atom.GetSymbol(), POSSIBLE_ATOMS)
    atom_features += one_of_k_encoding_unk(atom.GetExplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalDegree(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), list(range(-2, 3)))
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDS)
    atom_features += one_of_k_encoding_unk(atom.GetIsAromatic(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.IsInRing(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.GetChiralTag(), CHIRALS)
    # Two slots (R, S), same width as before; atoms without a CIP label, or
    # with any other label, get [False, False] instead of being folded into
    # 'S' by the unknown-value fallback.
    atom_features += [cip_code == 'R', cip_code == 'S']
    atom_features += [PERIODIC_TABLE.GetRvdw(atom.GetSymbol())]
    atom_features += [atom.HasProp('_ChiralityPossible')]
    atom_features += [atom.GetAtomicNum()]
    atom_features += [atom.GetMass() * 0.01]
    atom_features += [atom.GetDegree()]
    atom_features += [int(bit) for bit in '{0:06b}'.format(feature)]
    return atom_features
def calc_adjacent_tensor(bonds, atom_num, with_ring_conj=False):
    '''
    Method that constructs an adjacency tensor made of one adjacency matrix
    per bond channel.

    Channels 0..len(BOND_TYPES)-1 follow the order of BOND_TYPES. When
    ``with_ring_conj`` is set, two further channels are appended: one marking
    bonds that are in a ring, then one marking conjugated bonds.

    Bonds whose type is not listed in BOND_TYPES are skipped entirely (they
    contribute to no channel, including the ring/conjugation channels).

    :param bonds: bonds of a rdkit mol
    :param atom_num: the atom number of the rdkit mol
    :param with_ring_conj: should the adjacency tensor contain the in-ring and
                           is-conjugated channels
    :return: numpy array A shaped [N, F, N], where N is atom number and F is
             the number of channels; entries are 0 or 1 and symmetric in the
             two atom axes
    '''
    base_channels = len(BOND_TYPES)
    ring_channel = base_channels
    conj_channel = base_channels + 1
    channel_count = base_channels + 2 if with_ring_conj else base_channels

    A = np.zeros([atom_num, channel_count, atom_num])
    for bond in bonds:
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Only the lookup can legitimately fail (unsupported bond type).
        # Catch just that, instead of a bare except that would also hide
        # unrelated errors such as out-of-range atom indices.
        try:
            bond_type = BOND_TYPES.index(bond.GetBondType())
        except ValueError:
            continue
        A[b, bond_type, e] = 1
        A[e, bond_type, b] = 1
        if with_ring_conj:
            if bond.IsInRing():
                A[b, ring_channel, e] = 1
                A[e, ring_channel, b] = 1
            if bond.GetIsConjugated():
                A[b, conj_channel, e] = 1
                A[e, conj_channel, b] = 1
    return A
def calc_data_from_smile(smiles, addh=False, with_ring_conj=False, with_atom_feats=True, with_submol_fp=True, radius=2):
    '''
    Method that constructs the data of a molecule.

    :param smiles: SMILES representation of a molecule
    :param addh: should we add all the Hs of the mol
    :param with_ring_conj: should the adjacency tensor contain the in-ring and
                           is-conjugated channels
    :param with_atom_feats: include the features computed by
                            calc_atom_features_onehot for every atom
    :param with_submol_fp: append, for every atom, the fingerprint that
                           subfp.gen_fps_from_mol produces for that atom's
                           sub-molecule
    :param radius: radius passed to subfp.get_atom_submol_radn when
                   extracting the per-atom sub-molecules
    :return: dict with keys 'V' (FloatTensor of per-atom feature rows),
             'A' (FloatTensor adjacency tensor from calc_adjacent_tensor) and
             'mol_size' (IntTensor holding the atom count); or None when the
             SMILES cannot be parsed or the feature tensor is not 2-D
    '''
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    if mol is None:
        # RDKit signals an unparsable SMILES by returning None; report it the
        # same way as the other "no usable data" case below instead of
        # failing with an AttributeError.
        return None
    if addh:
        mol = Chem.AddHs(mol)

    num_atoms = mol.GetNumAtoms()
    mol_size = torch.IntTensor([num_atoms])

    if with_atom_feats:
        features = rdDesc.GetFeatureInvariants(mol)

    submoldict = {}
    if with_submol_fp:
        atoms, submols = subfp.get_atom_submol_radn(mol, radius, sanitize=True)
        # Index the sub-molecules by the atom index they belong to.
        submoldict = dict(zip([a.GetIdx() for a in atoms], submols))

    V = []
    for i in range(num_atoms):
        if with_atom_feats:
            atom_i_features = calc_atom_features_onehot(mol.GetAtomWithIdx(i), features[i])
        else:
            atom_i_features = []
        if with_submol_fp:
            atom_i_features.extend(subfp.gen_fps_from_mol(submoldict[i]))
        V.append(atom_i_features)

    V = torch.FloatTensor(V)
    if len(V.shape) != 2:
        # e.g. a molecule without atoms yields an empty 1-D tensor
        return None

    A = torch.FloatTensor(calc_adjacent_tensor(mol.GetBonds(), num_atoms, with_ring_conj))
    return {'V': V, 'A': A, 'mol_size': mol_size}
|