| from rdkit import Chem |
| from rdkit import RDConfig |
| from rdkit.Chem import Descriptors |
| from rdkit.Chem import FragmentCatalog |
| from rdkit import DataStructs |
| from rdkit.Chem import AllChem |
| import os |
| import pandas as pd |
| import numpy as np |
| from tqdm import tqdm |
| import torch |
|
|
|
|
def smiles_to_fingerprint(smiles, n_bits=2048, radius=2):
    """Return the Morgan (circular) bit-vector fingerprint of a SMILES string.

    Args:
        smiles: SMILES string of the molecule.
        n_bits: Length of the folded bit vector.
        radius: Morgan radius (2 is the historical default of this module).

    Returns:
        An RDKit bit vector of length ``n_bits``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Without this guard RDKit fails inside the fingerprint call with an
        # opaque argument-type error that hides the offending SMILES.
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
|
|
def calculate_similarity(smiles1, smiles2):
    """Return the Tanimoto similarity of two molecules given as SMILES.

    Both molecules are encoded with ``smiles_to_fingerprint`` using its
    default settings before they are compared.
    """
    first_fp, second_fp = (smiles_to_fingerprint(s) for s in (smiles1, smiles2))
    return DataStructs.TanimotoSimilarity(first_fp, second_fp)
|
|
|
|
def fingerprints_to_tensor(fps):
    """Pack bit-vector fingerprints into a float32 tensor, one row per fingerprint.

    Each fingerprint only needs to be iterable over its bits (RDKit bit
    vectors and plain sequences of 0/1 values both work).
    """
    rows = []
    for fingerprint in fps:
        rows.append([bit for bit in fingerprint])
    return torch.tensor(rows, dtype=torch.float32)
|
|
def calculate_novelty(new_smiles_list,
                      known_smiles_path="./data/sources/zinc250k/zinc250k_selfies.csv",
                      chunk_size=10000):
    """Score how novel each molecule is relative to a reference set.

    The novelty of a molecule is ``1 - max Tanimoto similarity`` between its
    default Morgan fingerprint and the fingerprints of every SMILES in the
    "smiles" column of ``known_smiles_path``.

    Args:
        new_smiles_list: Iterable of SMILES strings to score.
        known_smiles_path: CSV file with a "smiles" column used as the
            reference set.
        chunk_size: Number of reference molecules fingerprinted and compared
            at a time. Peak memory is bounded by this chunk rather than by
            the whole reference set.

    Returns:
        numpy float32 array with one novelty score per input SMILES (an empty
        array for empty input).

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    new_smiles_list = list(new_smiles_list)
    if not new_smiles_list:
        return np.zeros(0, dtype=np.float32)
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")

    known_smiles_list = pd.read_csv(known_smiles_path)["smiles"].tolist()
    new_fps = fingerprints_to_tensor([smiles_to_fingerprint(s) for s in new_smiles_list])
    # Number of set bits per new molecule, shape (n_new, 1) for broadcasting.
    bits_new = new_fps.sum(dim=1, keepdim=True)

    # Running maximum similarity; stays 0 (novelty 1) if the reference set is empty.
    max_similarities = torch.zeros(len(new_smiles_list), dtype=torch.float32)
    for start in range(0, len(known_smiles_list), chunk_size):
        chunk = known_smiles_list[start:start + chunk_size]
        known_fps = fingerprints_to_tensor([smiles_to_fingerprint(s) for s in chunk])
        # For 0/1 vectors the dot product counts the shared set bits.
        intersection = new_fps @ known_fps.t()
        union = bits_new + known_fps.sum(dim=1).unsqueeze(0) - intersection
        # Tanimoto = intersection / union; two all-zero fingerprints would give
        # 0/0 (NaN, which poisons max()), so they are scored as 0 instead.
        similarity = torch.where(
            union > 0,
            intersection / union.clamp(min=1),
            torch.zeros_like(intersection),
        )
        max_similarities = torch.maximum(max_similarities, similarity.max(dim=1).values)

    novelties = 1 - max_similarities
    return novelties.cpu().numpy()
|
|
|
|
# Whole-molecule RDKit descriptors, keyed by the property name ``mol_prop`` accepts.
_DESCRIPTOR_PROPS = {
    'logP': Descriptors.MolLogP,
    'weight': Descriptors.MolWt,
    'qed': Descriptors.qed,
    'TPSA': Descriptors.TPSA,
    'HBA': Descriptors.NumHAcceptors,
    'HBD': Descriptors.NumHDonors,
    'rot_bonds': Descriptors.NumRotatableBonds,
    'num_rotatable_bonds': Descriptors.NumRotatableBonds,
    'ring_count': Descriptors.RingCount,
    'mr': Descriptors.MolMR,
    'MR': Descriptors.MolMR,
    'balabanJ': Descriptors.BalabanJ,
    'hall_kier_alpha': Descriptors.HallKierAlpha,
    # NOTE: no logD model is used; 'logD' is approximated by Crippen logP.
    'logD': Descriptors.MolLogP,
}

# Property name -> bond type whose occurrences are counted.
_BOND_COUNT_PROPS = {
    'num_single_bonds': Chem.rdchem.BondType.SINGLE,
    'num_double_bonds': Chem.rdchem.BondType.DOUBLE,
    'num_triple_bonds': Chem.rdchem.BondType.TRIPLE,
    'num_aromatic_bonds': Chem.rdchem.BondType.AROMATIC,
}

# Property name -> atomic number whose atoms are counted.
_ATOM_COUNT_PROPS = {
    'num_boron': 5,
    'num_carbon': 6,
    'num_nitrogen': 7,
    'num_oxygen': 8,
    'num_fluorine': 9,
    'num_silicon': 14,
    'num_phosphorus': 15,
    'num_sulfur': 16,
    'num_chlorine': 17,
    'num_arsenic': 33,
    'num_selenium': 34,
    'num_bromine': 35,
    'num_antimony': 51,
    'num_tellurium': 52,
    'num_iodine': 53,
    'num_bismuth': 83,
    'num_polonium': 84,
}

# Property name -> functional-group SMARTS; the property value is the number
# of (uniquified) substructure matches.
_FUNCTIONAL_GROUP_SMARTS = {
    # Six aromatic carbons that each lie in exactly one SSSR ring, so benzene
    # rings fused to another ring (e.g. in naphthalene) are not counted.
    'num_benzene_ring': '[cR1]1[cR1][cR1][cR1][cR1][cR1]1',
    'num_hydroxyl': '[OX2H]',
    'num_anhydride': '[CX3](=[OX1])[OX2][CX3](=[OX1])',
    'num_aldehyde': '[CX3H1](=O)[#6]',
    'num_ketone': '[#6][CX3](=O)[#6]',
    'num_carboxyl': '[CX3](=O)[OX2H1]',
    'num_ester': '[#6][CX3](=O)[OX2H0][#6]',
    'num_amide': '[NX3][CX3](=[OX1])[#6]',
    # Primary/secondary amine nitrogens, excluding amide nitrogens.
    'num_amine': '[NX3;H2,H1;!$(NC=O)]',
    'num_nitro': '[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]',
    'num_halo': '[F,Cl,Br,I]',
    'num_thioether': '[SX2][CX4]',
    'num_nitrile': '[NX1]#[CX2]',
    'num_thiol': '[#16X2H]',
    # Divalent H-free sulfur that is not half of a disulfide. (Subtracting the
    # uniquified S-S match count removed only one of the two S atoms per bond.)
    'num_sulfide': '[#16X2H0;!$([#16][#16X2H0])]',
    'num_disulfide': '[#16X2H0][#16X2H0]',
    'num_sulfoxide': '[$([#16X3]=[OX1]),$([#16X3+][OX1-])]',
    'num_sulfone': '[$([#16X4](=[OX1])=[OX1]),$([#16X4+2]([OX1-])[OX1-])]',
    'num_borane': '[BX3]',
}

# Compiled once at import instead of on every mol_prop call.
_FUNCTIONAL_GROUP_PATTERNS = {
    name: Chem.MolFromSmarts(smarts) for name, smarts in _FUNCTIONAL_GROUP_SMARTS.items()
}


def mol_prop(mol, prop):
    """Compute one property of a molecule given as a SMILES string.

    Args:
        mol: SMILES string of the molecule.
        prop: Property name. Supported names are 'validity' (True for any
            parseable SMILES), the keys of ``_DESCRIPTOR_PROPS`` (RDKit
            descriptors), ``_BOND_COUNT_PROPS`` (bond-type counts),
            ``_ATOM_COUNT_PROPS`` (element counts) and
            ``_FUNCTIONAL_GROUP_PATTERNS`` (functional-group match counts).

    Returns:
        The property value, or None when ``mol`` cannot be parsed.

    Raises:
        ValueError: if ``mol`` parses but ``prop`` is not supported.
    """
    try:
        mol = Chem.MolFromSmiles(mol)
    except Exception:
        # RDKit raises (rather than returning None) for non-string input;
        # treat that as an unparseable molecule.
        return None

    # NOTE(review): Chem.MolFromSmiles('') appears to return an empty
    # (0-atom) molecule rather than None, so '' counts as valid — confirm intent.
    if mol is None:
        return None

    if prop == 'validity':
        return True
    if prop in _DESCRIPTOR_PROPS:
        return _DESCRIPTOR_PROPS[prop](mol)
    if prop in _BOND_COUNT_PROPS:
        bond_type = _BOND_COUNT_PROPS[prop]
        return sum(bond.GetBondType() == bond_type for bond in mol.GetBonds())
    if prop in _ATOM_COUNT_PROPS:
        atomic_num = _ATOM_COUNT_PROPS[prop]
        return sum(atom.GetAtomicNum() == atomic_num for atom in mol.GetAtoms())
    if prop in _FUNCTIONAL_GROUP_PATTERNS:
        return len(mol.GetSubstructMatches(_FUNCTIONAL_GROUP_PATTERNS[prop]))
    raise ValueError(f'Property {prop} not supported')
|
|
# Property keywords that are recognised but not implemented yet (they return
# None). "low bioling point" is a historical misspelling kept for callers.
_UNIMPLEMENTED_BASIC_PROPS = (
    "toxic", "non-toxic", "toxicity", "non-toxicity",
    "high-boiling", "low-boiling", "high boiling point", "low boiling point",
    "low bioling point",
    "high-melting", "low-melting", "high melting point", "low melting point",
    "water-soluble", "water-insoluble", "soluble in water", "insoluble in water",
)


def calculate_basic_property(smiles, prop):
    """Evaluate a coarse, named property of a molecule as a boolean.

    Implemented properties:
        * "heavy" / "light": molecular weight above 250 / at most 250.
        * "complex" / "simple": complex when the molecule has more than 3
          rings, more than 3 rotatable bonds, or more than 20 carbon atoms;
          "simple" is exactly the negation of "complex".

    Args:
        smiles: SMILES string of the molecule.
        prop: Property keyword (see above and ``_UNIMPLEMENTED_BASIC_PROPS``).

    Returns:
        True/False for implemented properties; None for recognised but
        unimplemented keywords, or when ``smiles`` cannot be parsed.

    Raises:
        ValueError: if ``prop`` is not recognised.
    """
    if prop in ("heavy", "light"):
        weight = mol_prop(smiles, 'weight')
        if weight is None:
            return None
        return (weight > 250) == (prop == "heavy")

    if prop in ("complex", "simple"):
        ring_count = mol_prop(smiles, 'ring_count')
        if ring_count is None:
            return None
        is_complex = (
            ring_count > 3
            or mol_prop(smiles, 'rot_bonds') > 3
            or mol_prop(smiles, 'num_carbon') > 20
        )
        # "simple" previously looked only at the carbon count; it must be the
        # complement of "complex".
        return is_complex == (prop == "complex")

    if prop in _UNIMPLEMENTED_BASIC_PROPS:
        # TODO: implement; kept so callers can pass these keywords.
        return None

    raise ValueError(f'Property {prop} not supported')
| |
|
|
|
|
def _normalize_group_name(group):
    """Map a readable group name (e.g. "benzene ring") to its mol_prop suffix ("benzene_ring")."""
    return group.strip().replace(" ", "_")


def _eval_moledit(data, target, *, added_column=None, removed_column=None):
    """Shared scorer for the molecule-editing benchmarks.

    Row ``idx`` pairs ``data["molecule"]`` (taken in positional order) with
    ``target[idx]``. An edit succeeds when the target parses and
    - if ``removed_column`` is given, the count of that row's removed group
      fell by exactly one, and
    - if ``added_column`` is given, the count of that row's added group
      rose by exactly one.
    Tanimoto similarity to the source molecule is averaged over parseable
    targets only.

    Returns:
        dict with "success_rate" (over all rows), "similarity" (over valid
        targets; 0.0 if there are none), "validity" (fraction of parseable
        targets) and the legacy misspelled key "validty" with the same value.
        Rates are 0.0 for empty ``data``.
    """
    # list(...) gives positional order regardless of the DataFrame index,
    # matching how ``target`` is indexed.
    molecules = list(data["molecule"])
    added_groups = list(data[added_column]) if added_column else None
    removed_groups = list(data[removed_column]) if removed_column else None
    n_rows = len(molecules)

    n_valid = 0
    n_success = 0
    similarities = []
    for idx in tqdm(range(n_rows)):
        raw = molecules[idx]
        target_mol = target[idx]
        if not mol_prop(target_mol, "validity"):
            continue
        n_valid += 1

        success = True
        if removed_groups is not None:
            prop = "num_" + _normalize_group_name(removed_groups[idx])
            success = mol_prop(target_mol, prop) == mol_prop(raw, prop) - 1
        if success and added_groups is not None:
            prop = "num_" + _normalize_group_name(added_groups[idx])
            success = mol_prop(target_mol, prop) == mol_prop(raw, prop) + 1
        n_success += success

        similarities.append(calculate_similarity(raw, target_mol))

    validity = n_valid / n_rows if n_rows else 0.0
    return {
        "success_rate": n_success / n_rows if n_rows else 0.0,
        "similarity": sum(similarities) / len(similarities) if similarities else 0.0,
        # "validty" is the historical (misspelled) key, kept for existing callers.
        "validty": validity,
        "validity": validity,
    }


def eval_moledit_add_component(data, target):
    """Score edits that should add one instance of each row's "added_group".

    data: pandas DataFrame with "molecule" and "added_group" columns.
    target: list of edited SMILES strings, aligned positionally with ``data``.
    """
    return _eval_moledit(data, target, added_column="added_group")


def eval_moledit_delete_component(data, target):
    """Score edits that should remove one instance of each row's "removed_group".

    data: pandas DataFrame with "molecule" and "removed_group" columns.
    target: list of edited SMILES strings, aligned positionally with ``data``.
    """
    return _eval_moledit(data, target, removed_column="removed_group")


def eval_moledit_sub_component(data, target):
    """Score edits that should swap one "removed_group" for one "added_group".

    data: pandas DataFrame with "molecule", "added_group" and "removed_group"
        columns.
    target: list of edited SMILES strings, aligned positionally with ``data``.
    """
    return _eval_moledit(
        data, target, added_column="added_group", removed_column="removed_group"
    )
|
|
|
|
def _eval_molopt(data, target, prop):
    """Shared scorer for the property-optimisation benchmarks.

    Row ``idx`` pairs ``data["molecule"]`` (taken in positional order) with
    ``target[idx]``. If the row's "Instruction" contains "lower" or
    "decrease" (case-insensitive), success means ``prop`` strictly decreased.
    Otherwise success means it strictly increased. Invalid targets never
    succeed. Tanimoto similarity is averaged over parseable targets only.

    Returns:
        dict with "success_rate" (over all rows), "similarity" (over valid
        targets; 0.0 if there are none), "validity" (fraction of parseable
        targets) and the legacy misspelled key "validty" with the same value.
        Rates are 0.0 for empty ``data``.
    """
    # list(...) gives positional order regardless of the DataFrame index,
    # matching how ``target`` is indexed.
    molecules = list(data["molecule"])
    instructions = list(data["Instruction"])
    n_rows = len(molecules)

    n_valid = 0
    n_success = 0
    similarities = []
    for idx in tqdm(range(n_rows)):
        raw = molecules[idx]
        target_mol = target[idx]
        if not mol_prop(target_mol, "validity"):
            continue
        n_valid += 1
        similarities.append(calculate_similarity(raw, target_mol))

        # Case-insensitive so "Decrease ..." / "Lower ..." are not read as increases.
        instruction = instructions[idx].lower()
        new_value = mol_prop(target_mol, prop)
        old_value = mol_prop(raw, prop)
        if "lower" in instruction or "decrease" in instruction:
            n_success += new_value < old_value
        else:
            n_success += new_value > old_value

    validity = n_valid / n_rows if n_rows else 0.0
    return {
        "success_rate": n_success / n_rows if n_rows else 0.0,
        "similarity": sum(similarities) / len(similarities) if similarities else 0.0,
        # "validty" is the historical (misspelled) key, kept for existing callers.
        "validty": validity,
        "validity": validity,
    }


def eval_molopt_logP(data, target):
    """Score logP optimisation.

    data: pandas DataFrame with "molecule" and "Instruction" columns.
    target: list of optimised SMILES strings, aligned positionally with ``data``.
    """
    return _eval_molopt(data, target, "logP")


def eval_molopt_MR(data, target):
    """Score molar-refractivity (MR) optimisation.

    data: pandas DataFrame with "molecule" and "Instruction" columns.
    target: list of optimised SMILES strings, aligned positionally with ``data``.
    """
    return _eval_molopt(data, target, "MR")


def eval_molopt_QED(data, target):
    """Score QED optimisation.

    data: pandas DataFrame with "molecule" and "Instruction" columns.
    target: list of optimised SMILES strings, aligned positionally with ``data``.
    """
    return _eval_molopt(data, target, "qed")