import gzip
import math
import os.path as op
import pickle
import warnings
from typing import List

import numpy as np
import pandas as pd
import torch
from rdkit import Chem, DataStructs, rdBase
from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors

rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Fragment-contribution table used by calculateScore(), lazily populated by
# readFragmentScores().  Must be initialised here: without it the lazy-load
# check in calculateScore() raises NameError on first use.
_fscores = None


def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Create an ECFP (Morgan) fingerprint of a molecule.

    Parameters:
        molecule: RDKit Mol object.
        radius: Morgan radius (3 corresponds to ECFP6).
        size: number of bits/buckets in the fingerprint.
        hashed: if True use a hashed count fingerprint instead of a bit vector.

    Returns:
        numpy array of shape (1, size).
    """
    if hashed:
        fp_bits = AllChem.GetHashedMorganFingerprint(molecule, radius, nBits=size)
    else:
        fp_bits = AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=size)
    fp_np = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fp_bits, fp_np)
    return fp_np.reshape(1, -1)


def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max',
                         device='cuda', p=1):
    """For each molecule in gen_vecs find the closest molecule in stock_vecs
    and return the average Tanimoto score between these molecules.

    Parameters:
        stock_vecs: numpy array of reference fingerprints.
        gen_vecs: numpy array of generated fingerprints.
        batch_size: fingerprints processed per device batch.
        agg: 'max' (nearest neighbour) or 'mean' aggregation.
        device: torch device used for the pairwise computation.
        p: power for averaging: (mean x^p)^(1/p).
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            # Tanimoto on binary vectors: |A&B| / (|A| + |B| - |A&B|).
            jac = (tp / (x_stock.sum(1, keepdim=True) +
                         y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            # 0/0 from all-zero fingerprints: define the similarity as 1.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            if agg == 'max':
                agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
                    agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
                total[i:i + y_gen.shape[1]] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = agg_tanimoto ** (1 / p)
    return np.mean(agg_tanimoto)


def get_mol(smiles_or_mol):
    """Load a SMILES string into an RDKit Mol (Mol inputs pass through).

    Returns None for empty or unparsable SMILES and for molecules that fail
    sanitization.
    """
    if isinstance(smiles_or_mol, str):
        if len(smiles_or_mol) == 0:
            return None
        mol = Chem.MolFromSmiles(smiles_or_mol)
        if mol is None:
            return None
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return mol
    return smiles_or_mol


def canonic_smiles(smiles_or_mol):
    """Return the canonical SMILES of the input, or None if it is invalid."""
    mol = get_mol(smiles_or_mol)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol)


class SAScorer:
    """Scores synthetic accessibility with a pickled sklearn classifier.

    The classifier consumes [SAscore, exact MW, Morgan fingerprint bits]
    descriptors and emits P(synthesizable) per molecule.
    """

    def __init__(self,
                 model_path='/home/st512/peptune/scripts/peptide-mdlm-mcts/utils/sascore/SA_score_prediction.pkl.gz',
                 input_type='smiles'):
        # NOTE(review): unpickling is arbitrary code execution — only load
        # trusted model files.
        with gzip.open(model_path, "rb") as fh:
            self.clf = pickle.load(fh)
        # Bug fix: the parameter used to be ignored ('smiles' was hard-coded).
        self.input_type = input_type

    def __call__(self, smiles_file):
        """Score every molecule in a CSV file with a 'SMILES' column.

        Returns the score array twice (a 2-tuple) to match the caller API.
        """
        df = pd.read_csv(smiles_file)
        smiles = df["SMILES"].tolist()
        scores = self.get_scores(smiles)
        return scores, scores

    @staticmethod
    def _get_descriptors_from_smiles(smiles: List, radius=3, size=4096):
        """Build per-molecule descriptors: [SAscore, exact MW, fp bits].

        Invalid SMILES yield an all-zero descriptor and a 0 in valid_mask.
        """
        fps = []
        valid_mask = []
        for smile in smiles:
            mol = Chem.MolFromSmiles(smile) if smile is not None else None
            valid_mask.append(int(mol is not None))
            fp = (fingerprints_from_mol(mol, radius, size=size)
                  if mol else np.zeros((1, size)))
            others = (np.array([calculateScore(mol), Descriptors.ExactMolWt(mol)])
                      if mol else np.zeros(2))
            fps.append(np.concatenate([others.T, fp.T[:, 0]]))
        return fps, valid_mask

    def get_scores(self, smiles: List, valid_only=False):
        """Return P(synthesizable) per input SMILES.

        Invalid molecules score 0 unless valid_only=True, in which case they
        are dropped from the output entirely.
        """
        descriptors, valid_mask = self._get_descriptors_from_smiles(smiles)
        scores = self.clf.predict_proba(descriptors)[:, 1]
        if valid_only:
            return np.float32([scores[i] for i in range(len(scores))
                               if valid_mask[i]])
        return np.float32(scores * np.array(valid_mask))


class Metrics:
    """Computes validity/uniqueness/diversity/SNN against a reference set."""

    def __init__(self, prior_path='/scratch/pranamlab/tong/data/smiles/30K_all.csv',
                 n_jobs=100, input_type='smiles'):
        train_set_cano_smi = pd.read_csv(prior_path)['SMILES'].astype(str).tolist()
        # Parse each training SMILES exactly once (was parsed up to 3x).
        train_mols = [Chem.MolFromSmiles(smi) for smi in train_set_cano_smi]
        for smi, mol in zip(train_set_cano_smi, train_mols):
            if mol is None:
                print(f"Invalid SMILES: {smi}")
        self.train_set_cano_smi = train_set_cano_smi
        self.n_jobs = n_jobs
        # Reference fingerprints from all valid training molecules, used as
        # the 'stock' set for the SNN metric.
        self.ref_fps = np.vstack([fingerprints_from_mol(mol)
                                  for mol in train_mols if mol is not None])

    def get_metrics(self, generated_path):
        """Compute validity, uniqueness, internal diversity and SNN for a CSV
        of generated molecules ('SMILES' column); prints and returns them.
        """
        generated_smi = pd.read_csv(generated_path)['SMILES'].astype(str).tolist()
        mols = [Chem.MolFromSmiles(smi) if smi else None for smi in generated_smi]
        is_valid = [1 if mol else 0 for mol in mols]
        validity = sum(is_valid) / len(is_valid)
        valid_canon_smiles = [Chem.MolToSmiles(mol) for mol in mols if mol]
        uniqueness = len(set(valid_canon_smiles)) / len(valid_canon_smiles)
        uniq_smis = list(set(valid_canon_smiles))
        uniq_mols = [Chem.MolFromSmiles(smi) for smi in uniq_smis]
        fps = np.vstack([fingerprints_from_mol(mol) for mol in uniq_mols])
        # Internal diversity: 1 - mean pairwise Tanimoto within the set.
        diversity = 1 - (average_agg_tanimoto(fps, fps, agg='mean', p=1)).mean()
        # SNN: mean similarity to the nearest neighbour in the training set.
        snn = average_agg_tanimoto(self.ref_fps, fps, agg='max', p=1)
        # Structural novelty vs. the training set is intentionally disabled.
        print("validity\tuniqueness\tdiversity\tsnn")
        print(f"{validity:.3f}\t{uniqueness:.3f}\t{diversity:.3f}\t{snn:.3f}")
        return {
            "validity": validity,
            "uniqueness": uniqueness,
            "diversity": diversity,
            "snn": snn,
        }


def readFragmentScores(name='fpscores'):
    """Load Ertl/Schuffenhauer fragment scores into module-level _fscores.

    Expects '<name>.pkl.gz' next to this module when the default name is used.
    """
    global _fscores
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    data = pickle.load(gzip.open('%s.pkl.gz' % name))
    outDict = {}
    for i in data:
        # Each entry is [score, bitId1, bitId2, ...]: every bit maps to score.
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict


def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (number of bridgehead atoms, number of spiro atoms).

    `ri` is accepted for call-site compatibility but unused — RDKit computes
    ring info internally.
    """
    nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return nBridgehead, nSpiro


def calculateScore(m):
    """Compute the Ertl/Schuffenhauer synthetic accessibility score of `m`.

    Returns a float in [1, 10]; lower means easier to synthesize.  Lazily
    loads the fragment-score table on first call.
    """
    if _fscores is None:
        readFragmentScores()

    # Fragment score: mean fragment contribution over the Morgan environment
    # counts (radius 2); unknown fragments contribute -4.
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v
    try:
        score1 /= nf
    except ZeroDivisionError:  # molecule produced no fragments
        score1 = 1

    # Feature/complexity penalties.
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms ** 1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # This differs from the paper, which defines
    # macrocyclePenalty = log10(nMacrocycles + 1); the flat log10(2) form
    # gives better results when 2 or more macrocycles are present.
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = (0. - sizePenalty - stereoPenalty - spiroPenalty -
              bridgePenalty - macrocyclePenalty)

    # Fingerprint-density correction (not in the original publication, added
    # in v1.1) so highly symmetrical molecules score as easier to synthesize.
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # Transform the raw value onto a 1..10 scale (renamed from min/max to
    # avoid shadowing the builtins).
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # Smooth the high end of the scale.
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore