gabrielbianchin committed on
Commit 3a19a3f · 1 Parent(s): f14c1f8

Clean re-commit using proper Git LFS

classification/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "architectures": [
+     "BBBModelForSequenceClassification"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_bbb.BBBConfig",
+     "AutoModelForSequenceClassification": "modeling_bbb.BBBModelForSequenceClassification"
+   },
+   "d_img": 2048,
+   "d_tab": 384,
+   "d_txt": 768,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "model_type": "bbb-model",
+   "problem_type": "single_label_classification",
+   "proj_dim": 2048,
+   "task": "classification",
+   "transformers_version": "4.57.3"
+ }
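The auto_map entries above let transformers resolve the custom BBB* classes at load time. A minimal loading sketch, assuming these files are published under the classification/ subfolder of the SaeedLab/TITAN-BBB Hub repo (the repo id comes from tokenizer_bbb.py below, not verified here):

import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "SaeedLab/TITAN-BBB",
    subfolder="classification",
    trust_remote_code=True,  # needed so auto_map can import the BBB* classes
)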
classification/configuration_bbb.py ADDED
@@ -0,0 +1,22 @@
+ from transformers import PretrainedConfig
+
+ class BBBConfig(PretrainedConfig):
+     model_type = "bbb-model"
+
+     def __init__(
+         self,
+         d_tab: int = 384,
+         d_img: int = 2048,
+         d_txt: int = 768,
+         proj_dim: int = 2048,
+         dropout: float = 0.1,
+         task: str = 'classification',
+         **kwargs):
+
+         self.d_tab = d_tab
+         self.d_img = d_img
+         self.d_txt = d_txt
+         self.proj_dim = proj_dim
+         self.dropout = dropout
+         self.task = task
+         super().__init__(**kwargs)
classification/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa5ab19333e753976e402ac0c8c70615fa65065ca2ed3f5e5b6fcb5f54564c58
+ size 59853476
classification/model_classification.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:580e8747c45b9db60ff2b9b38384a0089a29d3066d02a50a09c7ab3fab3d3253
+ size 59859439
classification/modeling_bbb.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .configuration_bbb import BBBConfig
+
+ class BBBModelForSequenceClassification(PreTrainedModel):
+     config_class = BBBConfig
+
+     def __init__(self, config: BBBConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         # one projection head per modality, all mapping into proj_dim
+         self.proj_tab = nn.Sequential(
+             nn.LayerNorm(config.d_tab),
+             nn.Linear(config.d_tab, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         self.proj_img = nn.Sequential(
+             nn.LayerNorm(config.d_img),
+             nn.Linear(config.d_img, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         self.proj_txt = nn.Sequential(
+             nn.LayerNorm(config.d_txt),
+             nn.Linear(config.d_txt, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         # scores each modality embedding; softmax is taken over the three modalities
+         self.attention_pooling = nn.Sequential(
+             nn.Linear(config.proj_dim, config.proj_dim),
+             nn.Tanh(),
+             nn.Linear(config.proj_dim, 1, bias=False)
+         )
+
+         self.classifier = nn.Sequential(
+             nn.Linear(config.proj_dim, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(config.proj_dim, 1)
+         )
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=1.0)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=1.0)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self,
+                 tab: torch.Tensor = None,
+                 img: torch.Tensor = None,
+                 txt: torch.Tensor = None,
+                 labels: torch.Tensor = None):
+
+         if tab is None or img is None or txt is None:
+             raise ValueError('You have to specify tabular, image, and text features')
+
+         h_tab = self.proj_tab(tab)
+         h_img = self.proj_img(img)
+         h_txt = self.proj_txt(txt)
+
+         embeddings = torch.stack([h_tab, h_img, h_txt], dim=1)   # (batch, 3, proj_dim)
+         scores = self.attention_pooling(embeddings)              # (batch, 3, 1)
+         weights = F.softmax(scores, dim=1)
+         weighted_embeddings = embeddings * weights
+         pooled_embeddings = torch.sum(weighted_embeddings, dim=1)
+         logits = self.classifier(pooled_embeddings).squeeze(-1)  # (batch,)
+
+         loss = None
+         if labels is not None:
+             if self.config.task == "regression":
+                 loss_fct = nn.MSELoss()
+                 loss = loss_fct(logits.squeeze(-1), labels.squeeze(-1))
+             elif self.config.task == "classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 # logits is already (batch,), so the target is flattened to match
+                 loss = loss_fct(logits, labels.float().view(-1))
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=pooled_embeddings,
+             attentions=weights.squeeze(-1)
+         )
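The forward pass consumes pre-extracted features rather than raw inputs: 384-dim tabular descriptors, 2048-dim image embeddings, and 768-dim text embeddings, matching config.json. A quick smoke-test sketch with random tensors (absolute imports are assumed for a local checkout, since the module itself uses a relative one):

import torch
from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification

model = BBBModelForSequenceClassification(BBBConfig(task="classification")).eval()

batch = 4
out = model(
    tab=torch.randn(batch, 384),           # tabular descriptor features
    img=torch.randn(batch, 2048),          # ResNet-50 image features
    txt=torch.randn(batch, 768),           # ChemBERTa text features
    labels=torch.randint(0, 2, (batch,)),  # binary BBB labels
)
print(out.logits.shape)      # torch.Size([4])
print(out.attentions.shape)  # torch.Size([4, 3]): one weight per modality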
classification/normalize_image.joblib ADDED
Binary file (49.8 kB)
classification/normalize_tabular.joblib ADDED
Binary file (9.83 kB)
classification/normalize_text.joblib ADDED
Binary file (19 kB)
classification/tokenizer_bbb.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedTokenizer
+ from transformers.tokenization_utils_base import BatchEncoding
+ from transformers import AutoTokenizer, AutoModel
+ from rdkit import Chem
+ from rdkit.Chem import Descriptors, AllChem, MACCSkeys
+ from rdkit.ML.Descriptors import MoleculeDescriptors
+ from rdkit import RDLogger
+ from rdkit.Chem import Draw
+ import joblib
+ import numpy as np
+ import os
+ from huggingface_hub import snapshot_download
+ import warnings
+ from sklearn.exceptions import InconsistentVersionWarning
+ from torchvision import models, transforms
+ from PIL import Image
+ warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
+ RDLogger.DisableLog('rdApp.*')
+
+ class BBBTokenizer(PreTrainedTokenizer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         # RDKit 2D descriptor calculator (tabular features)
+         self.calc = MoleculeDescriptors.MolecularDescriptorCalculator([i[0] for i in Descriptors.descList])
+
+         # frozen ChemBERTa encoder (text features)
+         self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
+         self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()
+
+         # frozen ResNet-50 backbone without its fc head (image features)
+         self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
+         self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
+         self.img_preprocess = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+             )
+         ])
+
+         # fetch the normalization transformers (assumes they live under
+         # classification/ in the SaeedLab/TITAN-BBB repo)
+         model_dir = snapshot_download("SaeedLab/TITAN-BBB", allow_patterns=["classification/normalize_*"])
+         transformer_tab_path = os.path.join(model_dir, "classification", "normalize_tabular.joblib")
+         transformer_img_path = os.path.join(model_dir, "classification", "normalize_image.joblib")
+         transformer_txt_path = os.path.join(model_dir, "classification", "normalize_text.joblib")
+
+         self.feature_transformer_tab = joblib.load(transformer_tab_path)
+         self.feature_transformer_img = joblib.load(transformer_img_path)
+         self.feature_transformer_txt = joblib.load(transformer_txt_path)
+
+     def generate_tab_features(self, smiles):
+         mol = Chem.MolFromSmiles(smiles)
+
+         if mol is None:
+             # fall back to an all-zero feature vector for unparseable SMILES
+             return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)
+
+         rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
+         rdkit_2d[np.isinf(rdkit_2d)] = np.nan
+         rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
+         maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
+         tab_input = np.concatenate([rdkit_2d, maccs])
+         tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
+         return torch.tensor(tab_input, dtype=torch.float32)
+
+     def generate_img_features(self, smiles):
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             img = Image.new("RGB", (300, 300), color=(0, 0, 0))
+         else:
+             img = Draw.MolToImage(mol, size=(300, 300))
+         img = self.img_preprocess(img)
+         with torch.no_grad():
+             img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
+         img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
+         return torch.tensor(img_input, dtype=torch.float32)
+
+     def generate_txt_features(self, smiles):
+         encoded = self.tokenizer(smiles, return_tensors="pt")
+         with torch.no_grad():
+             outputs = self.chemberta(**encoded)
+         hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
+         txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
+         return torch.tensor(txt_input, dtype=torch.float32)
+
+     def _batch_encode_plus(
+         self,
+         batch_smiles: list[str],
+         return_tensors: str = "pt",
+         **kwargs
+     ):
+         tab, img, txt = [], [], []
+
+         for smiles in batch_smiles:
+             tab.append(self.generate_tab_features(smiles))
+             img.append(self.generate_img_features(smiles))
+             txt.append(self.generate_txt_features(smiles))
+
+         output = {
+             "tab": torch.stack(tab),
+             "img": torch.stack(img),
+             "txt": torch.stack(txt),
+         }
+
+         return BatchEncoding(output, tensor_type=return_tensors)
+
+     def encode(self,
+                batch_smiles: list[str],
+                return_tensors: str = "pt",
+                **kwargs):
+         return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)
+
+     def __call__(self,
+                  batch_smiles: list[str],
+                  return_tensors: str = "pt",
+                  **kwargs):
+         return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)
+
+     def _tokenize(self, text, **kwargs):
+         return []
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         return ()
+
+     def get_vocab(self):
+         return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}
+
+     @property
+     def vocab_size(self):
+         return len(self.get_vocab())
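The tokenizer produces exactly the tab/img/txt tensors the model's forward expects, so its output unpacks straight into the model. An end-to-end sketch (it needs network access to download ChemBERTa, the ResNet-50 weights, and the normalizers; the model below is randomly initialized unless loaded from the published checkpoint):

import torch
from tokenizer_bbb import BBBTokenizer
from configuration_bbb import BBBConfig
from modeling_bbb import BBBModelForSequenceClassification

tokenizer = BBBTokenizer()
model = BBBModelForSequenceClassification(BBBConfig()).eval()

enc = tokenizer(["CCO", "CC(=O)Oc1ccccc1C(=O)O"])  # ethanol, aspirin
with torch.no_grad():
    out = model(**enc)
print(torch.sigmoid(out.logits))  # per-molecule BBB permeability probabilities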
classification/tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": ["tokenizer_bbb.BBBTokenizer", "tokenizer_bbb.BBBTokenizer"]
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "<bos>",
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "sep_token": "<eos>",
+   "tokenizer_class": "BBBTokenizer",
+   "unk_token": "<unk>"
+ }
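Since the tokenizer needs no vocabulary files, the auto_map is enough for AutoTokenizer to build it from tokenizer_bbb.py. A hedged sketch, under the same repo/subfolder assumption as above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "SaeedLab/TITAN-BBB",
    subfolder="classification",
    trust_remote_code=True,
)
enc = tokenizer(["CCO"])  # BatchEncoding with "tab", "img", and "txt" tensors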
convert_weights.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ from configuration_bbb import BBBConfig
+ from modeling_bbb import BBBModelForSequenceClassification
+ import os
+
+ BASE_ARCH_PARAMS = {
+     "d_tab": 384,
+     "d_img": 2048,
+     "d_txt": 768,
+     "proj_dim": 2048
+ }
+
+ def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str):
+     config_params = BASE_ARCH_PARAMS.copy()
+     config_params["task"] = task_name
+     config_params["problem_type"] = problem_type
+     config_params["dropout"] = dropout
+
+     config = BBBConfig(**config_params)
+
+     hf_model = BBBModelForSequenceClassification(config)
+     hf_model.eval()
+
+     if not os.path.exists(checkpoint_path):
+         print(f"Checkpoint not found: {checkpoint_path}")
+         return
+
+     old_state_dict = torch.load(checkpoint_path, map_location="cpu")
+
+     new_state_dict = {}
+     for key, value in old_state_dict.items():
+         if key.startswith("proj") or key.startswith("attention_pooling"):
+             new_state_dict[key] = value
+
+         elif key.startswith("classifier."):
+             # 'classifier.*' keys already match the HF module names
+             new_state_dict[key] = value
+
+         else:
+             print(f"[Warning] Unmapped key found: {key}")
+             new_state_dict[key] = value
+
+     print("State dict key names adjusted.")
+
+     try:
+         hf_model.load_state_dict(new_state_dict, strict=True)
+         print("State dict loaded successfully into HF")
+     except RuntimeError as e:
+         print("\n--- ERROR LOADING STATE DICT ---")
+         print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
+         print(e)
+         return
+
+     print(f"Saving HF-formatted model to {save_directory}")
+     hf_model.save_pretrained(save_directory)
+
+ if __name__ == "__main__":
+     convert_model(
+         checkpoint_path="model_classification.pth",
+         task_name="classification",
+         dropout=0.1,
+         problem_type="single_label_classification",
+         save_directory="./classification"
+     )
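The __main__ block only converts the classification checkpoint. The regression checkpoint presumably goes through the same call with the settings recorded in regression/config.json; a sketch:

convert_model(
    checkpoint_path="model_regression.pth",
    task_name="regression",
    problem_type="regression",
    dropout=0.3,                 # matches regression/config.json
    save_directory="./regression",
)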
regression/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "architectures": [
+     "BBBModelForSequenceClassification"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_bbb.BBBConfig",
+     "AutoModelForSequenceClassification": "modeling_bbb.BBBModelForSequenceClassification"
+   },
+   "d_img": 2048,
+   "d_tab": 384,
+   "d_txt": 768,
+   "dropout": 0.3,
+   "dtype": "float32",
+   "model_type": "bbb-model",
+   "problem_type": "regression",
+   "proj_dim": 2048,
+   "task": "regression",
+   "transformers_version": "4.57.3"
+ }
regression/configuration_bbb.py ADDED
@@ -0,0 +1,22 @@
+ from transformers import PretrainedConfig
+
+ class BBBConfig(PretrainedConfig):
+     model_type = "bbb-model"
+
+     def __init__(
+         self,
+         d_tab: int = 384,
+         d_img: int = 2048,
+         d_txt: int = 768,
+         proj_dim: int = 2048,
+         dropout: float = 0.1,
+         task: str = 'classification',
+         **kwargs):
+
+         self.d_tab = d_tab
+         self.d_img = d_img
+         self.d_txt = d_txt
+         self.proj_dim = proj_dim
+         self.dropout = dropout
+         self.task = task
+         super().__init__(**kwargs)
regression/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d281e439bcf72bcccf563020f14589d8388c5ae6c1cfd1ec03610458e8bef64
+ size 59853476
regression/model_regression.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:093e549d68ce6137dd0ca89a4742f6717e9a628b97a0835210402d60b2a1266c
+ size 59859339
regression/modeling_bbb.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ from .configuration_bbb import BBBConfig
+
+ class BBBModelForSequenceClassification(PreTrainedModel):
+     config_class = BBBConfig
+
+     def __init__(self, config: BBBConfig):
+         super().__init__(config)
+
+         self.config = config
+
+         # one projection head per modality, all mapping into proj_dim
+         self.proj_tab = nn.Sequential(
+             nn.LayerNorm(config.d_tab),
+             nn.Linear(config.d_tab, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         self.proj_img = nn.Sequential(
+             nn.LayerNorm(config.d_img),
+             nn.Linear(config.d_img, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         self.proj_txt = nn.Sequential(
+             nn.LayerNorm(config.d_txt),
+             nn.Linear(config.d_txt, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout)
+         )
+
+         # scores each modality embedding; softmax is taken over the three modalities
+         self.attention_pooling = nn.Sequential(
+             nn.Linear(config.proj_dim, config.proj_dim),
+             nn.Tanh(),
+             nn.Linear(config.proj_dim, 1, bias=False)
+         )
+
+         self.classifier = nn.Sequential(
+             nn.Linear(config.proj_dim, config.proj_dim),
+             nn.ReLU(),
+             nn.Dropout(config.dropout),
+             nn.Linear(config.proj_dim, 1)
+         )
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=1.0)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=1.0)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self,
+                 tab: torch.Tensor = None,
+                 img: torch.Tensor = None,
+                 txt: torch.Tensor = None,
+                 labels: torch.Tensor = None):
+
+         if tab is None or img is None or txt is None:
+             raise ValueError('You have to specify tabular, image, and text features')
+
+         h_tab = self.proj_tab(tab)
+         h_img = self.proj_img(img)
+         h_txt = self.proj_txt(txt)
+
+         embeddings = torch.stack([h_tab, h_img, h_txt], dim=1)   # (batch, 3, proj_dim)
+         scores = self.attention_pooling(embeddings)              # (batch, 3, 1)
+         weights = F.softmax(scores, dim=1)
+         weighted_embeddings = embeddings * weights
+         pooled_embeddings = torch.sum(weighted_embeddings, dim=1)
+         logits = self.classifier(pooled_embeddings).squeeze(-1)  # (batch,)
+
+         loss = None
+         if labels is not None:
+             if self.config.task == "regression":
+                 loss_fct = nn.MSELoss()
+                 loss = loss_fct(logits.squeeze(-1), labels.squeeze(-1))
+             elif self.config.task == "classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 # logits is already (batch,), so the target is flattened to match
+                 loss = loss_fct(logits, labels.float().view(-1))
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=pooled_embeddings,
+             attentions=weights.squeeze(-1)
+         )
regression/normalize_image.joblib ADDED
Binary file (49.8 kB)
regression/normalize_tabular.joblib ADDED
Binary file (9.83 kB)
regression/normalize_text.joblib ADDED
Binary file (19 kB)
regression/tokenizer_bbb.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedTokenizer
+ from transformers.tokenization_utils_base import BatchEncoding
+ from transformers import AutoTokenizer, AutoModel
+ from rdkit import Chem
+ from rdkit.Chem import Descriptors, AllChem, MACCSkeys
+ from rdkit.ML.Descriptors import MoleculeDescriptors
+ from rdkit import RDLogger
+ from rdkit.Chem import Draw
+ import joblib
+ import numpy as np
+ import os
+ from huggingface_hub import snapshot_download
+ import warnings
+ from sklearn.exceptions import InconsistentVersionWarning
+ from torchvision import models, transforms
+ from PIL import Image
+ warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
+ RDLogger.DisableLog('rdApp.*')
+
+ class BBBTokenizer(PreTrainedTokenizer):
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         # RDKit 2D descriptor calculator (tabular features)
+         self.calc = MoleculeDescriptors.MolecularDescriptorCalculator([i[0] for i in Descriptors.descList])
+
+         # frozen ChemBERTa encoder (text features)
+         self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
+         self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()
+
+         # frozen ResNet-50 backbone without its fc head (image features)
+         self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
+         self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
+         self.img_preprocess = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+             )
+         ])
+
+         #model_dir = snapshot_download("SaeedLab/BBBP-Classification", allow_patterns=["transformer.pkl"])
+         #transformer_path = os.path.join(model_dir, "transformer.pkl")
+
+         # load the normalization transformers that ship in regression/
+         # (paths are relative to the repository root)
+         self.feature_transformer_tab = joblib.load('regression/normalize_tabular.joblib')
+         self.feature_transformer_img = joblib.load('regression/normalize_image.joblib')
+         self.feature_transformer_txt = joblib.load('regression/normalize_text.joblib')
+
+     def generate_tab_features(self, smiles):
+         mol = Chem.MolFromSmiles(smiles)
+
+         if mol is None:
+             # fall back to an all-zero feature vector for unparseable SMILES
+             return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)
+
+         rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
+         rdkit_2d[np.isinf(rdkit_2d)] = np.nan
+         rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
+         maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
+         tab_input = np.concatenate([rdkit_2d, maccs])
+         tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
+         return torch.tensor(tab_input, dtype=torch.float32)
+
+     def generate_img_features(self, smiles):
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             img = Image.new("RGB", (300, 300), color=(0, 0, 0))
+         else:
+             img = Draw.MolToImage(mol, size=(300, 300))
+         img = self.img_preprocess(img)
+         with torch.no_grad():
+             img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
+         img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
+         return torch.tensor(img_input, dtype=torch.float32)
+
+     def generate_txt_features(self, smiles):
+         encoded = self.tokenizer(smiles, return_tensors="pt")
+         with torch.no_grad():
+             outputs = self.chemberta(**encoded)
+         hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
+         txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
+         return torch.tensor(txt_input, dtype=torch.float32)
+
+     def _batch_encode_plus(
+         self,
+         batch_smiles: list[str],
+         return_tensors: str = "pt",
+         **kwargs
+     ):
+         tab, img, txt = [], [], []
+
+         for smiles in batch_smiles:
+             tab.append(self.generate_tab_features(smiles))
+             img.append(self.generate_img_features(smiles))
+             txt.append(self.generate_txt_features(smiles))
+
+         output = {
+             "tab": torch.stack(tab),
+             "img": torch.stack(img),
+             "txt": torch.stack(txt),
+         }
+
+         return BatchEncoding(output, tensor_type=return_tensors)
+
+     def encode(self,
+                batch_smiles: list[str],
+                return_tensors: str = "pt",
+                **kwargs):
+         return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)
+
+     def __call__(self,
+                  batch_smiles: list[str],
+                  return_tensors: str = "pt",
+                  **kwargs):
+         return self._batch_encode_plus(batch_smiles, return_tensors, **kwargs)
+
+     def _tokenize(self, text, **kwargs):
+         return []
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         return ()
+
+     def get_vocab(self):
+         return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}
+
+     @property
+     def vocab_size(self):
+         return len(self.get_vocab())
regression/tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "auto_map": {
+     "AutoTokenizer": ["tokenizer_bbb.BBBTokenizer", "tokenizer_bbb.BBBTokenizer"]
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "<bos>",
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "sep_token": "<eos>",
+   "tokenizer_class": "BBBTokenizer",
+   "unk_token": "<unk>"
+ }