molecular / molecule /predict.py
ivanm151's picture
init
6796365
from .model import ModelWrapper
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs
import shap
def smiles_to_ecfp(smiles, radius=2, n_bits=1024):
    """Encode a SMILES string as a Morgan (ECFP-style) bit-vector fingerprint.

    Args:
        smiles: SMILES string to parse with RDKit.
        radius: Morgan radius passed to ``GetMorganFingerprintAsBitVect``.
        n_bits: fingerprint length.

    Returns:
        A numpy integer array of length ``n_bits`` with one 0/1 entry per bit.
        When RDKit cannot parse ``smiles`` an all-zero array of the same
        dtype and length is returned, so callers always get a consistent type.
    """
    # Allocate once with a fixed integer dtype; both return paths share it.
    arr = np.zeros(n_bits, dtype=int)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return arr
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr
# One ModelWrapper per property, constructed at import time from the listed
# checkpoint files. The predictor functions index into this list by position
# (0 = solubility, 1 = logp, 2 = clintox, 3 = fdaapprov, 4 = cardiotoxicity),
# so the order of the names below must not change.
models = [
    ModelWrapper(checkpoint_name)
    for checkpoint_name in (
        "solubility.pth",
        "logp.pth",
        "clintox.pth",
        "fdaapprov.pth",
        "cardiotoxicity.pth",
    )
]
def _predict_property(model_wrapper, smiles):
    """Score one SMILES string with ``model_wrapper.model``.

    Args:
        model_wrapper: object exposing ``.model.predict`` over a batch of
            fingerprint rows.
        smiles: SMILES string of the molecule.

    Returns:
        The single prediction as a Python scalar (via ``.item()``), or 0 when
        the SMILES cannot be parsed or any step raises (the error is printed).
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            # Unparseable SMILES: report 0, matching the *_shap variants,
            # instead of scoring the all-zero fingerprint that
            # smiles_to_ecfp produces for invalid input.
            return 0
        features = np.asarray(smiles_to_ecfp(smiles), dtype=float)
        return model_wrapper.model.predict([features]).item()
    except Exception as e:
        print(e)
        return 0


def solubility(X):
    """Return the solubility model's prediction for SMILES ``X`` (0 on failure)."""
    return _predict_property(models[0], X)


def logp(X):
    """Return the logP model's prediction for SMILES ``X`` (0 on failure)."""
    return _predict_property(models[1], X)


def clintox(X):
    """Return the clintox model's prediction for SMILES ``X`` (0 on failure)."""
    return _predict_property(models[2], X)


def fdaapprov(X):
    """Return the fdaapprov model's prediction for SMILES ``X`` (0 on failure)."""
    return _predict_property(models[3], X)


def cardiotoxicity(X):
    """Return the cardiotoxicity model's prediction for SMILES ``X`` (0 on failure)."""
    return _predict_property(models[4], X)
def _first_sample_shap(shap_values):
    """Reduce ``TreeExplainer.shap_values`` output for a one-row input to 1-D.

    Single-output models yield an array of shape (1, n_features). Multi-output
    models (e.g. classifiers) yield either a list with one such array per
    output (older shap releases) or an array of shape
    (1, n_features, n_outputs) (newer releases). In the multi-output case the
    last output is used -- presumably the positive class for a binary
    classifier; confirm against the trained models.

    Returns:
        A float array of length n_features for the single input row.
    """
    if isinstance(shap_values, list):
        shap_values = shap_values[-1]
    values = np.asarray(shap_values, dtype=float)
    if values.ndim == 3:
        values = values[..., -1]
    return values[0]


def _explain_with_shap(smiles, model_wrapper, radius=2, n_bits=1024):
    """Predict one property for a SMILES string and attribute it to atoms.

    Args:
        smiles: SMILES string of the molecule.
        model_wrapper: object whose ``.model`` has ``predict`` and is accepted
            by ``shap.TreeExplainer``. The explainer is created on first use
            and cached on the wrapper as ``shap_explainer``.
        radius: Morgan radius (default matches ``smiles_to_ecfp``).
        n_bits: fingerprint length (default matches ``smiles_to_ecfp``).

    Returns:
        ``{"pred": prediction, "atom_shap": per_atom_scores}`` where each atom
        score is the sum of the SHAP values of every fingerprint bit whose
        environment is centred on that atom. Returns
        ``{"pred": 0, "atom_shap": []}`` for unparseable SMILES or on any
        error (the error is printed).
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {"pred": 0, "atom_shap": []}

        # 1. Morgan fingerprint; bit_info maps each set bit to the
        #    (center_atom_index, radius) environments that produced it.
        bit_info = {}
        fp_vect = AllChem.GetMorganFingerprintAsBitVect(
            mol, radius=radius, nBits=n_bits, bitInfo=bit_info
        )
        fp = np.zeros(n_bits, dtype=int)
        DataStructs.ConvertToNumpyArray(fp_vect, fp)

        # 2. Model prediction on a single-row batch.
        x_input = np.asarray(fp, dtype=float).reshape(1, -1)
        pred = model_wrapper.model.predict(x_input).item()

        # 3. SHAP values per fingerprint bit (explainer built once per wrapper).
        if not hasattr(model_wrapper, "shap_explainer"):
            model_wrapper.shap_explainer = shap.TreeExplainer(model_wrapper.model)
        bit_shap = _first_sample_shap(
            model_wrapper.shap_explainer.shap_values(x_input)
        )

        # 4. Project bit contributions onto their centre atoms. Only set bits
        #    appear in bit_info; iterate in bit order for a stable sum order.
        atom_scores = np.zeros(mol.GetNumAtoms(), dtype=float)
        for bit in sorted(bit_info):
            for center_atom, _env_radius in bit_info[bit]:
                atom_scores[center_atom] += bit_shap[bit]

        return {"pred": pred, "atom_shap": atom_scores.tolist()}
    except Exception as e:
        print(e)
        return {"pred": 0, "atom_shap": []}


def solubility_shap(X, model_wrapper=models[0]):
    """Solubility prediction plus per-atom SHAP scores for the frontend."""
    return _explain_with_shap(X, model_wrapper)


def logp_shap(X, model_wrapper=models[1]):
    """logP prediction plus per-atom SHAP scores for the frontend."""
    return _explain_with_shap(X, model_wrapper)


def clintox_shap(X, model_wrapper=models[2]):
    """Clintox prediction plus per-atom SHAP scores for the frontend."""
    return _explain_with_shap(X, model_wrapper)


def fdaapprov_shap(X, model_wrapper=models[3]):
    """FDA-approval prediction plus per-atom SHAP scores for the frontend."""
    return _explain_with_shap(X, model_wrapper)


def cardiotoxicity_shap(X, model_wrapper=models[4]):
    """Cardiotoxicity prediction plus per-atom SHAP scores for the frontend."""
    return _explain_with_shap(X, model_wrapper)
# Property name -> function returning the bare prediction for a SMILES string.
property_predictors = dict(
    solubility=solubility,
    logp=logp,
    clintox=clintox,
    fdaapprov=fdaapprov,
    cardiotoxicity=cardiotoxicity,
)

# Property name -> function returning {"pred": ..., "atom_shap": [...]}.
property_predictors_shap = dict(
    solubility=solubility_shap,
    logp=logp_shap,
    clintox=clintox_shap,
    fdaapprov=fdaapprov_shap,
    cardiotoxicity=cardiotoxicity_shap,
)
def predict(X, shap=False):
    """Run every property predictor on one SMILES string.

    Args:
        X: SMILES string, passed unchanged to each predictor.
        shap: when true, use ``property_predictors_shap`` (each value is a
            dict with "pred" and "atom_shap"); otherwise use
            ``property_predictors`` (each value is the bare prediction).
            NOTE: this parameter shadows the ``shap`` module inside this
            function only; the predictors still see the module. It is kept
            for backward compatibility with keyword callers.

    Returns:
        A dict mapping property name to predictor result, in the registry's
        order, or None if any predictor raises. The predictors in this
        module catch their own errors, so this is a last-resort guard.
    """
    predictors = property_predictors_shap if shap else property_predictors
    try:
        return {name: predictor(X) for name, predictor in predictors.items()}
    except Exception as e:
        print(e)
        return None