Spaces:
Sleeping
Sleeping
| import networkx as nx | |
| import numpy as np | |
| import torch | |
| from rdkit import Chem | |
| from torch_geometric.utils import from_smiles | |
| from torch_geometric.data import Data | |
| from deepscreen.data.featurizers.categorical import one_of_k_encoding_unk, one_of_k_encoding | |
| from deepscreen.utils import get_logger | |
| log = get_logger(__name__) | |
def atom_features(atom, explicit_H=False, use_chirality=True):
    """
    Encode an RDKit atom as a 1-D numpy feature vector (adapted from TransformerCPI 2.0).

    Segments, in order:
        - element symbol one-hot over C, N, O, F, P, S, Cl, Br, I, 'other' (10)
        - degree one-hot over 0..6 via ``one_of_k_encoding`` (7)
        - formal charge and number of radical electrons, as raw integers (2)
        - hybridization one-hot over SP, SP2, SP3, SP3D, SP3D2, 'other' (6)
        - aromaticity flag (1)
        - total hydrogen count one-hot over 0..4, only when ``explicit_H`` is False (5)
        - CIP label one-hot over 'R', 'S' plus a "_ChiralityPossible" flag,
          only when ``use_chirality`` is True (3)

    With the default arguments the vector therefore has 10+7+2+6+1+5+3 = 34 entries.
    NOTE(review): handling of out-of-list values (e.g. unlisted elements, degree > 6)
    is delegated to the ``one_of_k_encoding*`` helpers; presumably the ``_unk``
    variant maps them to the last bucket while ``one_of_k_encoding`` may raise — confirm.

    Args:
        atom: an RDKit ``Atom``.
        explicit_H: True for molecules carrying explicit hydrogens (e.g. QM8, QM9);
            skips ``GetTotalNumHs`` and the hydrogen-count segment.
        use_chirality: append the CIP/chirality segment.

    Returns:
        ``np.ndarray`` built from the concatenated feature list.
    """
    symbol = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other']  # 10-dim
    degree = [0, 1, 2, 3, 4, 5, 6]  # 7-dim
    hybridization_type = [Chem.rdchem.HybridizationType.SP,
                          Chem.rdchem.HybridizationType.SP2,
                          Chem.rdchem.HybridizationType.SP3,
                          Chem.rdchem.HybridizationType.SP3D,
                          Chem.rdchem.HybridizationType.SP3D2,
                          'other']  # 6-dim

    # 10 + 7 + 2 + 6 + 1 = 26
    results = (
        one_of_k_encoding_unk(atom.GetSymbol(), symbol)
        + one_of_k_encoding(atom.GetDegree(), degree)
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
        + one_of_k_encoding_unk(atom.GetHybridization(), hybridization_type)
        + [atom.GetIsAromatic()]
    )

    # With explicit hydrogens (QM8, QM9) avoid calling `GetTotalNumHs`.
    # 26 + 5 = 31
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])

    # 31 + 3 = 34
    if use_chirality:
        # Atoms without an assigned CIP label have no '_CIPCode' property; test for it
        # explicitly instead of catching every exception with a bare `except:`.
        if atom.HasProp('_CIPCode'):
            cip = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            cip = [False, False]
        results = results + cip + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)
def bond_features(bond):
    """
    Encode an RDKit bond as a 6-element numpy array of flags.

    Layout: [is_single, is_double, is_triple, is_aromatic, is_conjugated, is_in_ring].
    """
    bond_type = bond.GetBondType()
    type_flags = [
        bond_type == kind
        for kind in (
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC,
        )
    ]
    topology_flags = [bond.GetIsConjugated(), bond.IsInRing()]
    return np.array(type_flags + topology_flags)
def smiles_to_graph_pyg(smiles):
    """
    Featurize a SMILES string with PyTorch Geometric's built-in ``from_smiles``.

    Returns whatever ``from_smiles`` produces, or ``None`` (after logging a
    warning) when it raises.
    """
    graph = None
    try:
        graph = from_smiles(smiles)
    except Exception as err:
        log.warning(f"Failed to featurize the following SMILES to graph: {smiles} due to {str(err)}")
    return graph
def smiles_to_graph(smiles, atom_features: callable = atom_features):
    """
    Convert a SMILES string to a PyTorch Geometric ``Data`` graph using a custom atom featurizer.

    Each atom's feature vector is divided by its own sum before stacking. Every bond
    becomes two directed edges. A molecule without bonds gets a single ``[0, 0]``
    self-loop, so ``edge_index`` is never empty.

    Args:
        smiles: SMILES string to parse with RDKit.
        atom_features: callable mapping an RDKit atom to a 1-D numeric vector;
            defaults to this module's ``atom_features``.

    Returns:
        ``Data(x=..., edge_index=...)``, where ``x`` has one row per atom and
        ``edge_index`` has shape ``(2, num_edges)``. Returns ``None`` (after
        logging a warning) when the SMILES cannot be parsed, has no atoms, or
        featurization raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        # RDKit signals a parse failure by returning None rather than raising.
        if mol is None:
            log.warning(f"Failed to convert SMILES ({smiles}) to graph due to RDKit being unable to parse it")
            return None
        # Without atoms, the [0, 0] self-loop below would point at a node that does not exist.
        if mol.GetNumAtoms() == 0:
            log.warning(f"Failed to convert SMILES ({smiles}) to graph due to it containing no atoms")
            return None

        features = []
        for atom in mol.GetAtoms():
            feature = np.asarray(atom_features(atom), dtype=np.float64)
            total = feature.sum()
            # Scale each atom's vector to sum to 1; leave it unscaled when the sum is 0
            # so the graph does not silently pick up NaN/inf features.
            features.append(feature / total if total != 0 else feature)
        features = np.array(features)

        edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]
        if edges:
            # to_directed() yields every undirected bond in both directions.
            directed = nx.Graph(edges).to_directed()
            edge_index = [[u, v] for u, v in directed.edges]
        else:
            # Bond-less molecule: a single self-loop on node 0 keeps edge_index non-empty.
            edge_index = [[0, 0]]

        return Data(x=torch.Tensor(features),
                    edge_index=torch.tensor(edge_index, dtype=torch.long).t().contiguous())
    except Exception as e:
        log.warning(f"Failed to convert SMILES ({smiles}) to graph due to {str(e)}")
        return None
def smiles_to_mol_features(smiles, num_atom_feat=None):
    """
    Featurize a SMILES string into a per-atom feature matrix and an adjacency matrix.

    Args:
        smiles: SMILES string to parse with RDKit.
        num_atom_feat: ignored. It is kept only so existing call sites that pass a
            second argument keep working. The feature width is taken from the
            output of ``atom_features`` itself.

    Returns:
        ``(atom_feat, adj_mat)``, where ``atom_feat`` is a float64 array whose row
        ``i`` holds ``atom_features`` of the atom with index ``i``, and ``adj_mat``
        is ``Chem.GetAdjacencyMatrix(mol)`` as a numpy array. Returns ``None``
        (after logging a warning) when the SMILES cannot be parsed, has no atoms,
        or featurization raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        # RDKit signals a parse failure by returning None rather than raising.
        if mol is None:
            log.warning(f"Failed to featurize the following SMILES to molecular features: {smiles} due to RDKit being unable to parse it")
            return None
        if mol.GetNumAtoms() == 0:
            log.warning(f"Failed to featurize the following SMILES to molecular features: {smiles} due to it containing no atoms")
            return None

        # Featurize every atom exactly once, remembering its index for row placement.
        per_atom = [(atom.GetIdx(), atom_features(atom)) for atom in mol.GetAtoms()]
        atom_feat = np.zeros((mol.GetNumAtoms(), len(per_atom[0][1])))
        for idx, feat in per_atom:
            atom_feat[idx, :] = feat

        adj_mat = np.array(Chem.GetAdjacencyMatrix(mol))
        return atom_feat, adj_mat
    except Exception as e:
        log.warning(f"Failed to featurize the following SMILES to molecular features: {smiles} due to {str(e)}")
        return None