import os
import warnings

import joblib
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from PIL import Image
from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors, Draw, MACCSkeys
from rdkit.ML.Descriptors import MoleculeDescriptors
from sklearn.exceptions import InconsistentVersionWarning
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding

warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
RDLogger.DisableLog("rdApp.*")


class BBBTokenizer(PreTrainedTokenizer):
    """Multimodal "tokenizer" that maps SMILES strings to tabular, image,
    and text feature tensors for the TITAN-BBB classifier."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # RDKit 2D descriptor calculator over the full descriptor list.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [name for name, _ in Descriptors.descList]
        )

        # ChemBERTa encoder for SMILES text embeddings.
        self.tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-100M-MLM")
        self.chemberta = AutoModel.from_pretrained("DeepChem/ChemBERTa-100M-MLM").eval()

        # ResNet-50 with its classification head removed, used to embed
        # rendered 2D structure images.
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Per-modality sklearn normalizers shipped alongside the model.
        # Note: allow_patterns uses glob syntax, so a trailing "*" is needed
        # for the pattern to match the normalize_*.joblib files below.
        model_dir = snapshot_download(
            "SaeedLab/TITAN-BBB/classification", allow_patterns=["normalize*"]
        )
        transformer_tab_path = os.path.join(model_dir, "normalize_tabular.joblib")
        transformer_img_path = os.path.join(model_dir, "normalize_image.joblib")
        transformer_txt_path = os.path.join(model_dir, "normalize_text.joblib")
        self.feature_transformer_tab = joblib.load(transformer_tab_path)
        self.feature_transformer_img = joblib.load(transformer_img_path)
        self.feature_transformer_txt = joblib.load(transformer_txt_path)

    def generate_tab_features(self, smiles):
        """RDKit 2D descriptors concatenated with MACCS keys, normalized."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: return an all-zero vector of the expected
            # width, not a scalar tensor holding the width itself.
            return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)
        rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
        # nan_to_num maps NaN and +/-inf to 0.0 in a single pass.
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """ResNet-50 embedding of the rendered 2D depiction, normalized."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: fall back to a black placeholder image.
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
        img_input = self.feature_transformer_img.transform(img_input.numpy().reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Mean-pooled ChemBERTa last hidden state, normalized."""
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        hidden_states = outputs.last_hidden_state[0].mean(dim=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        return_tensors: str = "pt",
        **kwargs,
    ):
        # Build one feature tensor per modality for the whole batch.
        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))
        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }
        return BatchEncoding(output, tensor_type=return_tensors)

    def encode(self, batch_smiles: list[str], return_tensors: str = "pt", **kwargs):
        return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)

    def __call__(self, batch_smiles: list[str], return_tensors: str = "pt", **kwargs):
        return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)

    # Minimal stubs so the PreTrainedTokenizer machinery works; this class
    # has no real token vocabulary of its own.
    def _tokenize(self, text, **kwargs):
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        return ()

    def get_vocab(self):
        # Placeholder vocabulary of RoBERTa-style special tokens, matching
        # the five entries mapped to ids 0-4.
        return {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "<mask>": 4}

    @property
    def vocab_size(self):
        return len(self.get_vocab())
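

# Usage sketch (illustrative, not part of the class): encodes a small batch
# of SMILES into the three modality tensors. Assumes the DeepChem/ChemBERTa
# and SaeedLab/TITAN-BBB artifacts are reachable on the Hugging Face Hub;
# the example molecules (caffeine, aspirin) are arbitrary.
if __name__ == "__main__":
    tokenizer = BBBTokenizer()
    batch = tokenizer([
        "CN1C=NC2=C1C(=O)N(C(=O)N2C)C",  # caffeine
        "CC(=O)OC1=CC=CC=C1C(=O)O",      # aspirin
    ])
    # Each modality yields a (batch_size, feature_dim) tensor.
    print(batch["tab"].shape, batch["img"].shape, batch["txt"].shape)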