gabrielbianchin committed on
Commit
413a158
·
1 Parent(s): 5fcbd53
Files changed (1) hide show
  1. tokenizer_bbb.py +135 -0
tokenizer_bbb.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedTokenizer
4
+ from transformers.tokenization_utils_base import BatchEncoding
5
+ from transformers import AutoTokenizer, AutoModel
6
+ from rdkit import Chem
7
+ from rdkit.Chem import Descriptors, AllChem, MACCSkeys
8
+ from rdkit.ML.Descriptors import MoleculeDescriptors
9
+ from rdkit import RDLogger
10
+ from rdkit.Chem import Draw
11
+ import joblib
12
+ import numpy as np
13
+ import os
14
+ from huggingface_hub import snapshot_download
15
+ import warnings
16
+ from sklearn.exceptions import InconsistentVersionWarning
17
+ from torchvision import models, transforms
18
+ from PIL import Image
19
+ warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
20
+ RDLogger.DisableLog('rdApp.*')
21
+
22
class BBBTokenizer(PreTrainedTokenizer):
    """Turn SMILES strings into tabular, image and text feature tensors.

    Despite the name, this class does not produce token ids. For every SMILES
    string it computes three feature vectors:

    * ``tab`` -- RDKit descriptors concatenated with MACCS keys,
    * ``img`` -- a ResNet-50 embedding of the rendered molecule image,
    * ``txt`` -- the mean-pooled last hidden state of ChemBERTa,

    each passed through a fitted feature transformer loaded from the
    ``SaeedLab/TITAN-BBB`` Hub repository.
    """

    def __init__(self, **kwargs):
        """Load the descriptor calculator, both backbones and the transformers.

        Args:
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizer.__init__``.
        """
        super().__init__(**kwargs)

        # One calculator covering every descriptor RDKit exposes in descList.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [desc[0] for desc in Descriptors.descList]
        )

        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()

        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        # Keep every child module except the last one, so the network returns
        # pooled features rather than class logits.
        self.resnet = nn.Sequential(
            *list(self.resnet50_backbone.children())[:-1]
        ).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

        # Download exactly the files that are loaded below. Previously the
        # allow_patterns requested "normalize_reg_*" while "normalize_cls_*"
        # was loaded, so the load could fail on a clean cache.
        # NOTE(review): confirm "cls" (not "reg") is the intended variant.
        transformer_files = {
            "tab": "normalize_cls_tabular.joblib",
            "img": "normalize_cls_image.joblib",
            "txt": "normalize_cls_text.joblib",
        }
        model_dir = snapshot_download(
            "SaeedLab/TITAN-BBB",
            allow_patterns=list(transformer_files.values()),
        )

        # NOTE(security): joblib.load unpickles the downloaded files, which can
        # execute arbitrary code; only use with a trusted repository/revision.
        self.feature_transformer_tab = joblib.load(
            os.path.join(model_dir, transformer_files["tab"])
        )
        self.feature_transformer_img = joblib.load(
            os.path.join(model_dir, transformer_files["img"])
        )
        self.feature_transformer_txt = joblib.load(
            os.path.join(model_dir, transformer_files["txt"])
        )

    def generate_tab_features(self, smiles):
        """Return the transformed descriptor + MACCS vector for one SMILES.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            A 1-D ``float32`` tensor. If RDKit cannot parse ``smiles`` an
            all-zero vector of length ``n_features_in_`` is returned.
        """
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            # Placeholder vector so the batch can still be stacked. The old
            # code built a 0-d tensor holding the feature *count* instead.
            # Assumes the transformer keeps the feature count unchanged
            # (true for scalers) -- TODO confirm for the shipped artifact.
            return torch.zeros(
                self.feature_transformer_tab.n_features_in_,
                dtype=torch.float32,
            )

        rdkit_2d = np.array(self.calc.CalcDescriptors(mol), dtype=float)
        # Some descriptors come back as NaN/inf; zero them out.
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """Return the transformed ResNet embedding of the molecule drawing.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            A 1-D ``float32`` tensor. An all-black 300x300 image is embedded
            when RDKit cannot parse ``smiles``.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            embedding = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
        # Hand scikit-learn a NumPy array explicitly rather than a torch tensor.
        img_input = embedding.numpy().reshape(1, -1)
        img_input = self.feature_transformer_img.transform(img_input)[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Return the transformed mean-pooled ChemBERTa embedding.

        Args:
            smiles: SMILES string of the molecule.

        Returns:
            A 1-D ``float32`` tensor.
        """
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Average the last hidden state over the token axis of the single
        # sequence in the batch.
        hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        return_tensors: str = "pt",
        **kwargs
    ) -> BatchEncoding:
        """Featurize a batch of SMILES strings.

        Args:
            batch_smiles: List of SMILES strings. A single string is treated
                as a batch of one.
            return_tensors: Tensor type passed to ``BatchEncoding``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A ``BatchEncoding`` with keys ``"tab"``, ``"img"`` and ``"txt"``,
            each stacked along a leading batch dimension.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(batch_smiles, str):
            batch_smiles = [batch_smiles]

        tab, img, txt = [], [], []
        for smiles in batch_smiles:
            tab.append(self.generate_tab_features(smiles))
            img.append(self.generate_img_features(smiles))
            txt.append(self.generate_txt_features(smiles))

        output = {
            "tab": torch.stack(tab),
            "img": torch.stack(img),
            "txt": torch.stack(txt),
        }

        return BatchEncoding(output, tensor_type=return_tensors)

    def encode(self,
               batch_smiles: list[str],
               return_tensors: str = "pt",
               **kwargs) -> BatchEncoding:
        """Featurize ``batch_smiles``; same behavior as calling the instance."""
        return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)

    def __call__(self,
                 batch_smiles: list[str],
                 return_tensors: str = "pt",
                 **kwargs) -> BatchEncoding:
        """Featurize ``batch_smiles``; see ``_batch_encode_plus``."""
        return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)

    def _tokenize(self, text, **kwargs):
        """Return an empty list; this class does no real tokenization."""
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write nothing and return an empty tuple of file paths."""
        return ()

    def get_vocab(self):
        """Return the fixed five-entry special-token vocabulary."""
        return {"<pad>":0, "<bos>":1, "<eos>":2, "<unk>":3, "<mask>":4}

    @property
    def vocab_size(self):
        """Number of entries in ``get_vocab()``."""
        return len(self.get_vocab())