|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedTokenizer |
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import Descriptors, AllChem, MACCSkeys |
|
|
from rdkit.ML.Descriptors import MoleculeDescriptors |
|
|
from rdkit import RDLogger |
|
|
from rdkit.Chem import Draw |
|
|
import joblib |
|
|
import numpy as np |
|
|
import os |
|
|
from huggingface_hub import snapshot_download |
|
|
import warnings |
|
|
from sklearn.exceptions import InconsistentVersionWarning |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
# Normalizer artifacts on the hub may have been pickled with a different
# scikit-learn version than the one installed; joblib.load emits an
# InconsistentVersionWarning in that case. Suppress the noise —
# NOTE(review): assumes the artifacts remain loadable across versions.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# Silence RDKit's per-molecule parse warnings; invalid SMILES are handled
# explicitly in the featurizer methods below.
RDLogger.DisableLog('rdApp.*')
|
|
|
|
|
class BBBTokenizer(PreTrainedTokenizer):
    """Multimodal SMILES "tokenizer" producing tabular, image, and text features.

    For each SMILES string it computes three normalized feature vectors:

    * ``"tab"`` — RDKit 2D descriptors concatenated with MACCS keys.
    * ``"img"`` — ResNet-50 embedding of the rendered 2D structure image.
    * ``"txt"`` — mean-pooled ChemBERTa last hidden states.

    The per-modality sklearn normalizers are downloaded lazily from the
    ``SaeedLab/TITAN-BBB`` HuggingFace repo the first time a given ``task``
    ('classification' or 'regression') is requested, and reloaded only when
    the task changes.
    """

    # Maps a task name to the filename prefix of its normalizer artifacts.
    _TASK_PREFIXES = {'classification': 'cls', 'regression': 'reg'}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # RDKit calculator covering every available 2D molecular descriptor.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [name for name, _ in Descriptors.descList]
        )

        # Frozen ChemBERTa encoder for SMILES text embeddings.
        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()

        # ImageNet-pretrained ResNet-50 with the final FC layer stripped,
        # leaving the global-average-pooled feature vector as output.
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Per-modality sklearn normalizers; populated lazily per task.
        self.feature_transformer_tab = None
        self.feature_transformer_img = None
        self.feature_transformer_txt = None
        self.task = None

    def _download_transformer(self, filename):
        """Fetch one normalizer artifact from the hub; return its local path."""
        model_dir = snapshot_download("SaeedLab/TITAN-BBB", allow_patterns=[filename])
        return os.path.join(model_dir, filename)

    def _load_transformers(self, task):
        """Load the tabular/image/text normalizers for *task* if not already active.

        No-op when *task* matches the currently loaded task.

        Raises:
            ValueError: if *task* is neither 'classification' nor 'regression'.
        """
        if self.task == task:
            return
        prefix = self._TASK_PREFIXES.get(task)
        if prefix is None:
            raise ValueError('task not defined')
        self.feature_transformer_tab = joblib.load(
            self._download_transformer(f"normalize_{prefix}_tabular.joblib"))
        self.feature_transformer_img = joblib.load(
            self._download_transformer(f"normalize_{prefix}_image.joblib"))
        self.feature_transformer_txt = joblib.load(
            self._download_transformer(f"normalize_{prefix}_text.joblib"))
        # Record the task only after all three normalizers loaded successfully.
        self.task = task

    def generate_tab_features(self, smiles):
        """Return the normalized tabular feature vector for *smiles*.

        Invalid SMILES yield an all-zero vector of the normalizer's expected
        input width so the batch can still be stacked.
        """
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            # BUG FIX: the original returned torch.tensor(n_features_in_) — a
            # 0-d scalar holding the feature *count*, which breaks torch.stack
            # against the 1-d vectors below. Return a zero vector instead.
            return torch.zeros(self.feature_transformer_tab.n_features_in_,
                               dtype=torch.float32)

        rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
        # Some descriptors return inf for unusual molecules; map them (and any
        # NaNs) to 0 before normalization.
        rdkit_2d[np.isinf(rdkit_2d)] = np.nan
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        # Guard against extreme post-normalization values.
        tab_input = np.clip(tab_input, -1e5, 1e5)
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """Return the normalized ResNet-50 embedding of the drawn molecule."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: embed a black placeholder image.
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            # (1, C, 1, 1) pooled output -> (1, C).
            img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
        img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Return the normalized mean-pooled ChemBERTa embedding of *smiles*."""
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Mean over the sequence dimension of the single batch element.
        hidden_states = outputs.last_hidden_state[0].mean(dim=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        task: str = 'classification',
        return_tensors: str = "pt",
        **kwargs
    ):
        """Encode a batch of SMILES into ``{'tab', 'img', 'txt'}`` tensors.

        Args:
            batch_smiles: SMILES strings to featurize.
            task: 'classification' or 'regression'; selects which normalizer
                artifacts are downloaded and applied.
            return_tensors: tensor type forwarded to :class:`BatchEncoding`.

        Returns:
            A :class:`BatchEncoding` with stacked per-modality feature tensors.

        Raises:
            ValueError: if *task* is not recognized.
        """
        self._load_transformers(task)

        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))

        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }
        return BatchEncoding(output, tensor_type=return_tensors)

    def encode(self,
               batch_smiles: list[str],
               task: str = 'classification',
               return_tensors: str = "pt",
               **kwargs):
        """Alias for :meth:`_batch_encode_plus`."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def __call__(self,
                 batch_smiles: list[str],
                 task: str = 'classification',
                 return_tensors: str = "pt",
                 **kwargs):
        """Alias for :meth:`_batch_encode_plus`."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def _tokenize(self, text, **kwargs):
        # No subword tokenization: features are computed whole-string.
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # Nothing to persist — the vocab is a fixed set of special tokens.
        return ()

    def get_vocab(self):
        # Minimal fixed vocabulary of special tokens, required by the base class.
        return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}

    @property
    def vocab_size(self):
        """Number of entries in the fixed special-token vocabulary."""
        return len(self.get_vocab())