# CMSSP/code/GNN/featurizer.py
# Uploaded by OliXio ("Upload 13 files", commit d5233a9, verified)
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys, rdMolDescriptors as rdDesc
from utils import *
import torch
import copy
from . import subgraphfp as subfp
# RDKit periodic table handle; queried below for van der Waals radii.
PERIODIC_TABLE = Chem.GetPeriodicTable()

# Element symbols covered by the atom-type one-hot. Because the encoding goes
# through one_of_k_encoding_unk, any symbol not listed here lands on the last
# entry ('B').
POSSIBLE_ATOMS = [
    'H',
    'C',
    'N',
    'O',
    'S',
    'F',
    'Si',
    'P',
    'Cl',
    'Br',
    'I',
    'B',
]

# Hybridization states covered by the hybridization one-hot.
HYBRIDS = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]

# Chiral tags covered by the chirality one-hot.
CHIRALS = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]

# Bond orders that each get their own adjacency channel; the position in this
# list is the channel index.
BOND_TYPES = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
def one_of_k_encoding(x, allowable_set):
    """
    One-hot encode ``x`` against ``allowable_set``, rejecting unknown values.
    :param x: the value to encode
    :param allowable_set: ordered collection of the permitted values
    :return: list of booleans, one per entry of ``allowable_set``, that is
    True wherever the entry equals ``x``
    :raises Exception: if ``x`` is not contained in ``allowable_set``
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
def one_of_k_encoding_unk(x, allowable_set):
    """
    One-hot encode ``x`` against ``allowable_set`` with an "unknown" fallback.
    Values that are not in ``allowable_set`` are encoded as if they were its
    last element.
    :param x: the value to encode
    :param allowable_set: ordered, indexable collection of the permitted values
    :return: list of booleans, one per entry of ``allowable_set``
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]
def calc_atom_features_onehot(atom, feature):
    '''
    Method that computes atom level features from rdkit atom object
    :param atom: the rdkit atom to describe
    :param feature: integer whose binary digits are appended as the trailing
    features (formatted with a minimum width of 6 bits)
    :return: flat list mixing one-hot booleans and numeric descriptors, in the
    order they are appended below
    '''
    # Categorical descriptors. Each goes through one_of_k_encoding_unk, so a
    # value outside the listed range falls onto the last slot of its group.
    atom_features = one_of_k_encoding_unk(atom.GetSymbol(), POSSIBLE_ATOMS)
    atom_features += one_of_k_encoding_unk(atom.GetExplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), list(range(5)))
    atom_features += one_of_k_encoding_unk(atom.GetTotalDegree(), list(range(7)))
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), list(range(-2, 3)))
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), HYBRIDS)
    atom_features += one_of_k_encoding_unk(atom.GetIsAromatic(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.IsInRing(), [False, True])
    atom_features += one_of_k_encoding_unk(atom.GetChiralTag(), CHIRALS)

    # R/S label: two slots, [is_R, is_S]. The label has to be read with
    # GetProp; encoding the HasProp() flag itself can never match 'R' or 'S'
    # and would pin these slots to a constant [False, True] for every atom.
    # Atoms without a '_CIPCode' property get [False, False].
    # NOTE(review): models trained on the constant [False, True] slots will
    # see different inputs here - confirm before reusing old checkpoints.
    cip_code = atom.GetProp('_CIPCode') if atom.HasProp('_CIPCode') else None
    atom_features += [cip_code == 'R', cip_code == 'S']

    # Scalar descriptors.
    atom_features += [PERIODIC_TABLE.GetRvdw(atom.GetSymbol())]
    atom_features += [atom.HasProp('_ChiralityPossible')]
    atom_features += [atom.GetAtomicNum()]
    atom_features += [atom.GetMass() * 0.01]
    atom_features += [atom.GetDegree()]

    # Binary digits of `feature`, zero-padded to at least 6 bits. Values of 64
    # or more would produce extra digits and therefore a longer vector.
    atom_features += [int(i) for i in list('{0:06b}'.format(feature))]
    return atom_features
def calc_adjacent_tensor(bonds, atom_num, with_ring_conj=False):
    '''
    Method that constructs an adjacency tensor made of one adjacency matrix
    per bond channel
    :param bonds: bonds of a rdkit mol
    :param atom_num: the atom number of the rdkit mol
    :param with_ring_conj: if True, append two extra channels: "bond is in a
    ring" (second to last) and "bond is conjugated" (last)
    :return: numpy array A shaped [N, F, N], where N is atom_num and F is
    len(BOND_TYPES), plus 2 when with_ring_conj is set. Entries are 1 where
    a bond of that channel links the two atoms (set symmetrically), else 0
    :raises IndexError: if a bond refers to an atom index outside the tensor
    '''
    bond_types = len(BOND_TYPES)
    if with_ring_conj:
        bond_types += 2
    A = np.zeros([atom_num, bond_types, atom_num])
    for bond in bonds:
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        try:
            bond_type = BOND_TYPES.index(bond.GetBondType())
        except ValueError:
            # Bond order not listed in BOND_TYPES: leave this bond out of every
            # channel, ring/conjugation included. Only this lookup is guarded,
            # so unrelated errors are no longer silently swallowed.
            continue
        A[b, bond_type, e] = 1
        A[e, bond_type, b] = 1
        if with_ring_conj:
            if bond.IsInRing():
                A[b, bond_types - 2, e] = 1
                A[e, bond_types - 2, b] = 1
            if bond.GetIsConjugated():
                A[b, bond_types - 1, e] = 1
                A[e, bond_types - 1, b] = 1
    return A
def calc_data_from_smile(smiles, addh=False, with_ring_conj=False, with_atom_feats=True, with_submol_fp=True, radius=2):
    '''
    Method that constructs the graph data of a molecule.
    :param smiles: SMILES representation of a molecule
    :param addh: should we add all the Hs of the mol
    :param with_ring_conj: should the adjacency tensor contain the "bond in
    ring" and "bond is conjugated" channels
    :param with_atom_feats: include the per-atom features computed by
    calc_atom_features_onehot
    :param with_submol_fp: append, for every atom, the fingerprint that
    subfp.gen_fps_from_mol computes for that atom's sub-molecule
    :param radius: radius handed to subfp.get_atom_submol_radn when the
    per-atom sub-molecules are extracted
    :return: dict with keys 'V' (FloatTensor of per-atom features, one row per
    atom), 'A' (FloatTensor adjacency tensor from calc_adjacent_tensor) and
    'mol_size' (IntTensor holding the atom count); None when the SMILES
    cannot be parsed or the feature matrix does not come out 2-dimensional
    (e.g. a molecule without atoms)
    '''
    mol = Chem.MolFromSmiles(smiles, sanitize=True)
    if mol is None:
        # RDKit signals an unparsable SMILES by returning None. Report it the
        # same way as the other "unusable molecule" case below instead of
        # failing on the first attribute access.
        return None
    if addh:
        mol = Chem.AddHs(mol)

    atom_count = mol.GetNumAtoms()
    mol_size = torch.IntTensor([atom_count])

    if with_atom_feats:
        # One integer per atom; its bits are unpacked by calc_atom_features_onehot.
        features = rdDesc.GetFeatureInvariants(mol)

    submoldict = {}
    if with_submol_fp:
        # Map atom index -> sub-molecule for that atom.
        # NOTE(review): presumably the radius-`radius` environment of each
        # atom; confirm against subfp.get_atom_submol_radn.
        atoms, submols = subfp.get_atom_submol_radn(mol, radius, sanitize=True)
        submoldict = dict(zip([a.GetIdx() for a in atoms], submols))

    V = []
    for i in range(atom_count):
        if with_atom_feats:
            atom_i_features = calc_atom_features_onehot(mol.GetAtomWithIdx(i), features[i])
        else:
            atom_i_features = []
        if with_submol_fp:
            atom_i_features.extend(subfp.gen_fps_from_mol(submoldict[i]))
        V.append(atom_i_features)

    V = torch.FloatTensor(V)
    if len(V.shape) != 2:
        # An empty atom list yields a 1-D tensor; there is nothing to return.
        return None

    A = calc_adjacent_tensor(mol.GetBonds(), atom_count, with_ring_conj)
    A = torch.FloatTensor(A)
    return {'V': V, 'A': A, 'mol_size': mol_size}