import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, MACCSkeys
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit import RDLogger
from rdkit.Chem import Draw
import joblib
import numpy as np
import os
from huggingface_hub import snapshot_download
import warnings
from sklearn.exceptions import InconsistentVersionWarning
from torchvision import models, transforms
from PIL import Image

# scikit-learn emits InconsistentVersionWarning when unpickling estimators that
# were saved with a different scikit-learn version (the joblib normalizers
# loaded below); silence it.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
# Silence RDKit's logging (e.g. messages for SMILES that fail to parse).
RDLogger.DisableLog('rdApp.*')

# Hugging Face Hub repository that hosts the fitted feature normalizers.
_NORMALIZER_REPO = "SaeedLab/TITAN-BBB"
# Task name -> infix used in that task's normalizer filenames.
_TASK_FILE_TAGS = {"classification": "cls", "regression": "reg"}


class BBBTokenizer(PreTrainedTokenizer):
    """Featurize SMILES strings into tabular, image and text tensors.

    Although this subclasses ``PreTrainedTokenizer``, it does not produce
    token ids. ``__call__`` / ``encode`` return a ``BatchEncoding`` holding
    three per-molecule feature matrices:

    * ``"tab"`` -- RDKit 2D descriptors concatenated with MACCS key bits,
    * ``"img"`` -- pooled ResNet-50 features of an RDKit depiction,
    * ``"txt"`` -- ChemBERTa last hidden states averaged over tokens,

    each passed through a task-specific normalizer downloaded from the Hub.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Calculator over every descriptor listed in Descriptors.descList.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [entry[0] for entry in Descriptors.descList]
        )
        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        # Drop the final fully-connected layer so the network returns the
        # pooled feature map instead of ImageNet class logits.
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        # Normalizers and the task they were loaded for; filled lazily by
        # _batch_encode_plus on first use or when the task changes.
        self.feature_transformer_tab = None
        self.feature_transformer_img = None
        self.feature_transformer_txt = None
        self.task = None

    def generate_tab_features(self, smiles):
        """Return the normalized RDKit-descriptor + MACCS-key vector for *smiles*.

        An unparseable SMILES yields an all-zero vector whose length is the
        tabular normalizer's ``n_features_in_``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # torch.zeros (not torch.tensor): torch.tensor(n) would build a 0-d
            # tensor holding the number n, which cannot be stacked with the
            # feature vectors of valid molecules.
            return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)
        # Explicit float dtype so any None descriptor becomes NaN rather than
        # producing an object array.
        rdkit_2d = np.array(self.calc.CalcDescriptors(mol), dtype=np.float64)
        # nan_to_num maps NaN and +/-inf in one pass.
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        # Keep normalized values within +/-1e5.
        tab_input = np.clip(tab_input, -1e5, 1e5)
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """Return normalized ResNet-50 features of a 300x300 depiction of *smiles*.

        An unparseable SMILES is represented by an all-black image.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            # Ensure three channels in case the depiction carries an alpha
            # channel; Normalize below expects exactly three.
            img = Draw.MolToImage(mol, size=(300, 300)).convert("RGB")
        pixels = self.img_preprocess(img).unsqueeze(0)
        with torch.no_grad():
            # (1, C, 1, 1) pooled map -> (1, C)
            features = self.resnet(pixels).squeeze(-1).squeeze(-1)
        img_input = self.feature_transformer_img.transform(features.numpy().reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Return normalized ChemBERTa embeddings of *smiles*, averaged over tokens."""
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Mean over all token positions of the single encoded sequence.
        hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _load_normalizers(self, task):
        """Download (if needed) and load the tabular/image/text normalizers for *task*.

        Raises:
            ValueError: if *task* is neither 'classification' nor 'regression'.
        """
        tag = _TASK_FILE_TAGS.get(task)
        if tag is None:
            raise ValueError('task not defined')
        filenames = {
            modality: f"normalize_{tag}_{modality}.joblib"
            for modality in ("tabular", "image", "text")
        }
        # One download call fetches all three files into the same snapshot dir.
        model_dir = snapshot_download(_NORMALIZER_REPO, allow_patterns=list(filenames.values()))
        self.feature_transformer_tab = joblib.load(os.path.join(model_dir, filenames["tabular"]))
        self.feature_transformer_img = joblib.load(os.path.join(model_dir, filenames["image"]))
        self.feature_transformer_txt = joblib.load(os.path.join(model_dir, filenames["text"]))
        # Record the task only after every normalizer loaded successfully, so a
        # failed load is retried on the next call.
        self.task = task

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        task: str = 'classification',
        return_tensors: str = "pt",
        **kwargs
    ):
        """Featurize a batch of SMILES strings.

        Args:
            batch_smiles: SMILES strings to featurize. A single ``str`` is
                treated as a batch of one.
            task: 'classification' or 'regression'; selects which normalizers
                are applied. They are (re)loaded when the task changes.
            return_tensors: forwarded to ``BatchEncoding`` as ``tensor_type``.
            **kwargs: accepted for tokenizer-API compatibility; unused.

        Returns:
            BatchEncoding with keys "tab", "img" and "txt", each a tensor whose
            first dimension is the batch size.

        Raises:
            ValueError: if *task* is unknown or *batch_smiles* is empty.
        """
        if self.task is None or self.task != task:
            self._load_normalizers(task)

        if isinstance(batch_smiles, str):
            # Iterating a bare string would featurize each character separately.
            batch_smiles = [batch_smiles]
        batch_smiles = list(batch_smiles)
        if not batch_smiles:
            raise ValueError("batch_smiles must contain at least one SMILES string")

        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))

        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }
        return BatchEncoding(output, tensor_type=return_tensors)

    def encode(self, batch_smiles: list[str], task: str = 'classification', return_tensors: str = "pt", **kwargs):
        """Alias for ``_batch_encode_plus``."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def __call__(self, batch_smiles: list[str], task: str = 'classification', return_tensors: str = "pt", **kwargs):
        """Featurize *batch_smiles*; see ``_batch_encode_plus``."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def _tokenize(self, text, **kwargs):
        # No text tokenization is performed by this featurizer.
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # There is no vocabulary file to write.
        return ()

    def get_vocab(self):
        """Return the placeholder vocabulary required by the tokenizer base class."""
        # NOTE(review): every key in this literal is the empty string, so the
        # dict collapses to {"": 4} and vocab_size evaluates to 1. The token
        # strings appear to be missing (possibly special tokens lost when the
        # file was reformatted); confirm the intended vocabulary before changing.
        return {"":0, "":1, "":2, "":3, "":4}

    @property
    def vocab_size(self):
        return len(self.get_vocab())