|
|
import torch |
|
|
import warnings |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from typing import List |
|
|
from rdkit import Chem, rdBase, DataStructs |
|
|
import pickle |
|
|
import gzip |
|
|
from rdkit.Chem import AllChem, Descriptors |
|
|
|
|
|
import math |
|
|
import os.path as op |
|
|
from rdkit.Chem import rdMolDescriptors |
|
|
|
|
|
# Silence RDKit's error-level log channel.
rdBase.DisableLog('rdApp.error')

# Suppress DeprecationWarning, UserWarning and FutureWarning process-wide.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Fragment-score table read by calculateScore() and populated lazily by
# readFragmentScores(). Defined here so the ``_fscores is None`` check does
# not raise NameError before the first load.
_fscores = None
|
|
|
|
|
def fingerprints_from_mol(molecule, radius=3, size=2048, hashed=False):
    """Return the Morgan (ECFP-style) fingerprint of ``molecule`` as a row vector.

    Args:
        molecule: RDKit molecule to fingerprint.
        radius: Morgan radius passed to RDKit.
        size: Fingerprint length (``nBits``).
        hashed: If True use ``GetHashedMorganFingerprint``; otherwise use
            ``GetMorganFingerprintAsBitVect``.

    Returns:
        A 2-D numpy array with a single row holding the fingerprint values.
    """
    make_fp = (AllChem.GetHashedMorganFingerprint if hashed
               else AllChem.GetMorganFingerprintAsBitVect)
    fingerprint = make_fp(molecule, radius, nBits=size)
    # RDKit resizes the destination array to the fingerprint length.
    dense = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(fingerprint, dense)
    return dense.reshape(1, -1)
|
|
|
|
|
def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cuda', p=1):
    """
    For each molecule in gen_vecs, aggregate its Tanimoto similarity to the
    molecules in stock_vecs (closest match for agg='max', average for
    agg='mean'), then return the mean of those per-molecule values.

    Parameters:
        stock_vecs: numpy array <n_vectors x dim>
        gen_vecs: numpy array <n_vectors' x dim>
        batch_size: number of rows of each array processed per chunk
        agg: 'max' or 'mean'
        device: torch device used for the pairwise products. If a CUDA device
            is requested but CUDA is unavailable, the CPU is used instead.
        p: power for averaging: (mean x^p)^(1/p)

    Returns:
        Mean aggregated similarity over gen_vecs. A pair of all-zero vectors
        (0/0) counts as similarity 1.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"

    # The default 'cuda' would otherwise crash on CPU-only machines.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        device = 'cpu'

    agg_tanimoto = np.zeros(len(gen_vecs))
    total = np.zeros(len(gen_vecs))
    for j in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
        # Row sums of the stock batch do not depend on the generated batch.
        x_sum = x_stock.sum(1, keepdim=True)
        for i in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            tp = torch.mm(x_stock, y_gen)
            jac = (tp / (x_sum + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
            # 0/0 arises only for two empty fingerprints; treat them as identical.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            end = i + y_gen.shape[1]
            if agg == 'max':
                agg_tanimoto[i:end] = np.maximum(agg_tanimoto[i:end], jac.max(0))
            else:
                agg_tanimoto[i:end] += jac.sum(0)
                total[i:end] += jac.shape[0]
    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = agg_tanimoto ** (1 / p)
    return np.mean(agg_tanimoto)
|
|
|
|
|
def get_mol(smiles_or_mol):
    '''
    Return an RDKit molecule for ``smiles_or_mol``.

    Strings are parsed as SMILES; any non-string object is returned unchanged.
    Returns None for an empty string, a SMILES that fails to parse, or a
    molecule whose sanitization raises ValueError.
    '''
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
|
|
|
|
def canonic_smiles(smiles_or_mol):
    """Return RDKit's canonical SMILES for ``smiles_or_mol``.

    Accepts a SMILES string or a molecule (see ``get_mol``). Returns None when
    the input cannot be loaded.
    """
    molecule = get_mol(smiles_or_mol)
    return None if molecule is None else Chem.MolToSmiles(molecule)
|
|
|
|
|
class SAScorer:
    """Score SMILES with a pickled classifier's positive-class probability.

    Per molecule, the classifier receives ``[calculateScore(mol),
    ExactMolWt(mol), *morgan_fingerprint]`` (see ``_get_descriptors_from_smiles``).
    """

    def __init__(self, model_path='/home/st512/peptune/scripts/peptide-mdlm-mcts/utils/sascore/SA_score_prediction.pkl.gz', input_type='smiles'):
        """Load the classifier.

        Args:
            model_path: Gzipped pickle of a model exposing ``predict_proba``.
            input_type: Stored on the instance; not otherwise used by this class.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # trusted model files.
        with gzip.open(model_path, "rb") as handle:
            self.clf = pickle.load(handle)
        self.input_type = input_type

    def __call__(self, smiles_file):
        """Score the ``SMILES`` column of a CSV file.

        Returns:
            The score array twice, as a ``(scores, scores)`` tuple.
        """
        df = pd.read_csv(smiles_file)
        smiles = df["SMILES"].tolist()
        scores = self.get_scores(smiles)
        return scores, scores

    @staticmethod
    def _get_descriptors_from_smiles(smiles: List, radius=3, size=4096):
        """
        Build one feature vector per SMILES: the SA score and exact molecular
        weight, followed by the Morgan fingerprint.

        SMILES that are None or fail to parse get an all-zero vector.

        Returns:
            (features, valid_mask): a list of 1-D arrays and a list of 0/1 ints.
        """
        fps = []
        valid_mask = []
        for smile in smiles:
            mol = Chem.MolFromSmiles(smile) if smile is not None else None
            valid_mask.append(int(mol is not None))
            fp = fingerprints_from_mol(mol, radius, size=size) if mol else np.zeros((1, size))
            others = np.array([calculateScore(mol), Descriptors.ExactMolWt(mol)]) if mol else np.zeros(2)
            # fp is a single-row 2-D array; fp[0] flattens it to 1-D.
            fps.append(np.concatenate([others, fp[0]]))
        return fps, valid_mask

    def get_scores(self, smiles: List, valid_only=False):
        """Return float32 classifier probabilities for ``smiles``.

        Args:
            smiles: SMILES strings (None entries are treated as invalid).
            valid_only: If True, drop invalid entries from the result; otherwise
                invalid entries score 0.

        Returns:
            A float32 numpy array (empty for empty input).
        """
        if len(smiles) == 0:
            # Avoid calling predict_proba on an empty batch.
            return np.zeros(0, dtype=np.float32)
        descriptors, valid_mask = self._get_descriptors_from_smiles(smiles)
        scores = self.clf.predict_proba(descriptors)[:, 1]
        if valid_only:
            return np.float32([score for score, ok in zip(scores, valid_mask) if ok])
        return np.float32(scores * np.array(valid_mask))
|
|
|
|
|
|
|
|
class Metrics:
    """Metrics for generated SMILES relative to a reference (training) set.

    Reports validity, uniqueness, internal diversity and SNN (mean maximum
    Tanimoto similarity to the reference set).
    """

    def __init__(self, prior_path='/scratch/pranamlab/tong/data/smiles/30K_all.csv', n_jobs=100, input_type='smiles'):
        """Load reference SMILES and precompute their fingerprints.

        Args:
            prior_path: CSV file with a ``SMILES`` column.
            n_jobs: Stored on the instance; not used by this class.
            input_type: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If no reference SMILES can be parsed.
        """
        train_set_cano_smi = pd.read_csv(prior_path)['SMILES'].astype(str).tolist()

        # Parse each reference SMILES once; report and skip unparseable entries.
        ref_mols = []
        for smi in train_set_cano_smi:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                print(f"Invalid SMILES: {smi}")
            else:
                ref_mols.append(mol)
        if not ref_mols:
            raise ValueError(f"No valid reference SMILES found in {prior_path}")

        self.train_set_cano_smi = train_set_cano_smi
        self.n_jobs = n_jobs
        self.ref_fps = np.vstack([fingerprints_from_mol(mol) for mol in ref_mols])

    def get_metrics(self, generated_path):
        """Compute and print metrics for the SMILES in ``generated_path``.

        Args:
            generated_path: CSV file with a ``SMILES`` column.

        Returns:
            dict with keys:
              - ``validity``: fraction of rows that parse.
              - ``uniqueness``: distinct canonical SMILES / valid SMILES.
              - ``diversity``: 1 - mean pairwise Tanimoto among unique valid
                molecules (each fingerprint is also compared with itself).
              - ``snn``: mean, over unique valid molecules, of the maximum
                Tanimoto to the reference set.
            A metric whose denominator would be zero (empty file, or no valid
            molecules) is ``nan`` instead of raising.
        """
        generated_smi = pd.read_csv(generated_path)['SMILES'].astype(str).tolist()
        nan = float('nan')

        mols = [Chem.MolFromSmiles(smi) if smi else None for smi in generated_smi]
        valid_mols = [mol for mol in mols if mol]
        validity = len(valid_mols) / len(mols) if mols else nan

        valid_canon_smiles = [Chem.MolToSmiles(mol) for mol in valid_mols]
        if valid_canon_smiles:
            uniq_smis = list(set(valid_canon_smiles))
            uniqueness = len(uniq_smis) / len(valid_canon_smiles)

            # Fingerprint each unique canonical structure once.
            uniq_mols = [Chem.MolFromSmiles(smi) for smi in uniq_smis]
            fps = np.vstack([fingerprints_from_mol(mol) for mol in uniq_mols])
            diversity = 1 - average_agg_tanimoto(fps, fps, agg='mean', p=1)
            snn = average_agg_tanimoto(self.ref_fps, fps, agg='max', p=1)
        else:
            uniqueness = diversity = snn = nan

        print("validity\tuniqueness\tdiversity\tsnn")
        print(f"{validity:.3f}\t{uniqueness:.3f}\t{diversity:.3f}\t{snn:.3f}")

        return {
            "validity": validity,
            "uniqueness": uniqueness,
            "diversity": diversity,
            "snn": snn,
        }
|
|
|
|
|
|
|
|
def readFragmentScores(name='fpscores'):
    """Load the fragment-score table into the module-level ``_fscores`` dict.

    Args:
        name: Path prefix of the gzipped pickle to read; ``'.pkl.gz'`` is
            appended. The default ``'fpscores'`` resolves to ``fpscores.pkl.gz``
            in this module's directory.

    The pickle is expected to contain rows whose first element is a score and
    whose remaining elements are fragment ids. Each id maps to
    ``float(row[0])``; later rows override earlier ones for repeated ids.
    """
    global _fscores

    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # NOTE(security): pickle.load can execute arbitrary code; only read
    # trusted score files.
    with gzip.open('%s.pkl.gz' % name) as handle:
        data = pickle.load(handle)
    _fscores = {fragment_id: float(row[0]) for row in data for fragment_id in row[1:]}
|
|
|
|
|
|
|
|
def numBridgeheadsAndSpiro(mol, ri=None):
    """Count bridgehead and spiro atoms in ``mol``.

    ``ri`` is accepted for call compatibility and is not used.

    Returns:
        ``(n_bridgehead, n_spiro)``.
    """
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
|
|
|
|
|
|
|
|
def calculateScore(m):
    """Return the synthetic-accessibility (SA) score of RDKit molecule ``m``.

    Combines three parts:
      - a fragment term: the count-weighted mean of radius-2 Morgan fragment
        scores from ``_fscores``, with unknown fragments scoring -4;
      - size, stereo, spiro, bridgehead and macrocycle penalties;
      - a correction when atoms outnumber distinct fragments.
    The sum is rescaled, values above 8 are log-smoothed, and the result is
    clamped to [1, 10]. The structure appears to follow RDKit's Contrib
    ``sascorer``.

    The fragment table is loaded via ``readFragmentScores()`` on first use.
    """
    # ``_fscores`` may not be defined at module level yet, so look it up
    # without triggering NameError.
    if globals().get('_fscores') is None:
        readFragmentScores()
    fscores = _fscores

    # Fragment term: count-weighted mean of per-fragment contributions.
    fp = rdMolDescriptors.GetMorganFingerprint(m, 2)
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bit_id, count in fps.items():
        nf += count
        score1 += fscores.get(bit_id, -4) * count
    score1 = score1 / nf if nf else 1

    # Complexity penalties.
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = sum(1 for ring in ri.AtomRings() if len(ring) > 8)

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    # Any ring larger than 8 atoms adds a single fixed log10(2) penalty.
    macrocyclePenalty = math.log10(2) if nMacrocycles > 0 else 0.

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # Correction when the molecule has more atoms than distinct fragments.
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # Map the raw score onto the 1..10 scale (named to avoid shadowing builtins).
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.

    # Smooth the high end logarithmically, then clamp to [1, 10].
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
|
|
|