Spaces:
Sleeping
Sleeping
| import copy | |
| import torch | |
| from io import BytesIO | |
| from openbabel import openbabel | |
| from torch_geometric.utils import to_networkx | |
| from torch_geometric.data import Data | |
| from torch_scatter import scatter | |
| from rdkit import Chem | |
| from rdkit.Chem.rdchem import Mol, HybridizationType, BondType | |
| from rdkit.Chem.rdchem import BondType as BT | |
# Map each RDKit BondType enum member to an integer index, following the
# iteration order of ``BT.names``.
BOND_TYPES = {
    bond_type: index for index, bond_type in enumerate(BT.names.values())
}
# Reverse lookup: integer index -> bond-type name (the keys of ``BT.names``),
# using the same enumeration order as BOND_TYPES.
BOND_NAMES = dict(enumerate(BT.names.keys()))
def rdmol_to_data(mol, smiles=None):
    """Convert a single-conformer RDKit molecule into a torch_geometric ``Data``.

    Args:
        mol: RDKit molecule carrying exactly one conformer.
        smiles: Optional SMILES string to store on the result. When ``None``
            it is generated from ``mol`` after removing hydrogens.

    Returns:
        A ``Data`` object with:
          * ``atom_type``  - atomic numbers (``torch.long``),
          * ``pos``        - coordinates of conformer 0 (``torch.float32``),
          * ``edge_index`` - both directions of every bond (``torch.long``),
            sorted by source index, then target index,
          * ``edge_type``  - ``BOND_TYPES`` index of each edge (``torch.long``),
          * ``rdmol``      - deep copy of ``mol``,
          * ``smiles``     - the SMILES string,
          * ``nx``         - undirected networkx graph built from the data.

    Raises:
        AssertionError: If ``mol`` does not have exactly one conformer.
    """
    assert mol.GetNumConformers() == 1, "expected exactly one conformer"
    num_atoms = mol.GetNumAtoms()
    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
    z = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.long
    )

    # Every bond contributes two directed edges (start->end and end->start)
    # that share the same bond-type index.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    # Explicit dtype: without it an empty list (molecule without bonds) would
    # be inferred as float32 instead of the integer type used otherwise.
    edge_type = torch.tensor(edge_type, dtype=torch.long)

    # Sort edges by (source, target) through the flattened key source * N + target.
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(Chem.RemoveHs(mol))

    data = Data(
        atom_type=z,
        pos=pos,
        edge_index=edge_index,
        edge_type=edge_type,
        rdmol=copy.deepcopy(mol),
        smiles=smiles,
    )
    data.nx = to_networkx(data, to_undirected=True)
    return data
def generated_to_xyz(data):
    """Render the generated ligand atoms as XYZ-format text.

    Reads ``data.ligand_context_element`` (one element number per atom) and
    ``data.ligand_context_pos`` (three coordinates per atom).

    Returns:
        A string made of the atom count, an empty comment line, and then one
        ``"<symbol> <x> <y> <z>"`` line per atom with 8 decimal places.
    """
    periodic_table = Chem.GetPeriodicTable()
    atom_count = data.ligand_context_element.size(0)

    # Header: atom count followed by the (empty) comment line.
    chunks = ["%d\n\n" % (atom_count, )]
    for idx in range(atom_count):
        symbol = periodic_table.GetElementSymbol(
            data.ligand_context_element[idx].item()
        )
        cx, cy, cz = data.ligand_context_pos[idx].clone().cpu().tolist()
        chunks.append("%s %.8f %.8f %.8f\n" % (symbol, cx, cy, cz))
    return "".join(chunks)
def generated_to_sdf(data):
    """Convert generated ligand data to SDF text using Open Babel.

    The data is first rendered as XYZ text by ``generated_to_xyz``; that text
    is read into an ``OBMol`` and written back out in the "sdf" format.

    Returns:
        The string produced by ``OBConversion.WriteString``.
    """
    xyz_text = generated_to_xyz(data)

    converter = openbabel.OBConversion()
    converter.SetInAndOutFormats("xyz", "sdf")

    ob_mol = openbabel.OBMol()
    # NOTE(review): the status returned by ReadString is not checked here.
    converter.ReadString(ob_mol, xyz_text)
    return converter.WriteString(ob_mol)
def sdf_to_rdmol(sdf):
    """Parse SDF text and return the first molecule it contains.

    Args:
        sdf: SDF content as a ``str``.

    Returns:
        The first item yielded by ``Chem.ForwardSDMolSupplier`` for the
        encoded text, or ``None`` when the supplier yields nothing.
    """
    supplier = Chem.ForwardSDMolSupplier(BytesIO(sdf.encode()))
    # Only the first record is of interest; later records are never read.
    return next(iter(supplier), None)
def generated_to_rdmol(data):
    """Build an RDKit molecule from generated ligand data.

    Chains ``generated_to_sdf`` and ``sdf_to_rdmol``; returns whatever
    ``sdf_to_rdmol`` returns for the intermediate SDF text.
    """
    return sdf_to_rdmol(generated_to_sdf(data))
def filter_rd_mol(rdmol):
    """Check that no two three-membered rings of ``rdmol`` share an atom.

    Args:
        rdmol: Object exposing ``GetRingInfo().AtomRings()``, which yields one
            sequence of atom indices per ring.

    Returns:
        ``False`` if any two three-membered rings have at least one atom in
        common, ``True`` otherwise (including when there are no rings).
    """
    # Atoms that already belong to a three-membered ring seen so far. A new
    # three-membered ring touching any of them overlaps an earlier one, so a
    # single pass replaces the pairwise comparison of all rings.
    triangle_atoms = set()
    for ring in rdmol.GetRingInfo().AtomRings():
        members = set(ring)
        if len(members) != 3:
            continue
        if not triangle_atoms.isdisjoint(members):
            return False
        triangle_atoms |= members
    return True