Jiqing committed on
Commit
59de2c5
·
1 Parent(s): 401c15c

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +118 -0
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ import logging
3
+ import functools
4
+ from tqdm import tqdm
5
+ import torch
6
+ from datasets import load_dataset
7
+ from transformers import AutoModel, AutoTokenizer, AutoConfig
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
def tokenize_protein(example, protein_tokenizer=None, padding=None):
    """Tokenize one example's protein sequence and attach the encodings in place.

    See https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/tokenization_esm.py#L100

    Args:
        example: mapping with a ``"prot_seq"`` entry holding the raw sequence.
        protein_tokenizer: callable tokenizer producing ``input_ids`` and
            ``attention_mask`` attributes.
        padding: padding strategy forwarded to the tokenizer.

    Returns:
        The same ``example`` mapping, extended with ``"protein_input_ids"``
        and ``"protein_attention_mask"``.
    """
    encoded = protein_tokenizer(
        example["prot_seq"],
        padding=padding,
        # default is True, no need to add cls and eos manually;
        # results in <cls> + seq + <eos> (no <sep> for ESM)
        add_special_tokens=True,
    )
    example["protein_input_ids"] = encoded.input_ids
    example["protein_attention_mask"] = encoded.attention_mask
    return example
24
+
25
+
26
def label_embedding(labels, text_tokenizer, text_model, device):
    """Embed each label description with the text encoder and L2-normalize.

    Args:
        labels: iterable of label description strings.
        text_tokenizer: tokenizer exposing ``encode``, ``cls_token_id`` and
            ``pad_token_id``.
        text_model: model returning a dict with a ``"text_feature"`` entry.
        device: torch device on which input tensors are created.

    Returns:
        A (num_labels, dim) tensor of unit-norm label embeddings.
    """
    features = []
    with torch.inference_mode():
        for text in labels:
            # Truncate to 128 tokens and prepend <cls> manually.
            token_ids = text_tokenizer.encode(
                text, max_length=128, truncation=True, add_special_tokens=False)
            token_ids = [text_tokenizer.cls_token_id] + token_ids
            input_ids = torch.tensor(
                token_ids, dtype=torch.long, device=device).unsqueeze(0)
            # No padding is added above, so this mask is effectively all-True.
            mask = input_ids != text_tokenizer.pad_token_id

            outputs = text_model(
                input_ids,
                attention_mask=mask,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
            )
            features.append(outputs["text_feature"])

    stacked = torch.cat(features, dim=0)
    # Unit-normalize so downstream dot products are cosine similarities.
    return stacked / stacked.norm(dim=-1, keepdim=True)
55
+
56
def zero_shot_eval(logger, device,
                   test_dataset, target_field, protein_model, logit_scale, label_feature):
    """Run zero-shot classification over ``test_dataset`` and log the accuracy.

    Each protein is embedded with ``protein_model``, L2-normalized, and scored
    against the pre-normalized ``label_feature`` matrix with a scaled dot
    product; the argmax over labels is compared with ``data[target_field]``.

    Args:
        logger: logger used to report the final accuracy (at WARNING level).
        device: torch device on which tensors are created.
        test_dataset: map-style dataset whose items provide
            ``protein_input_ids``, ``protein_attention_mask`` and
            ``target_field``.
        target_field: name of the integer class-label field.
        protein_model: model returning a dict with key ``"protein_feature"``.
        logit_scale: scalar temperature multiplying the similarity scores.
        label_feature: (num_labels, dim) matrix of normalized label embeddings.

    Returns:
        None; the accuracy is only logged.
    """
    # get prediction and target
    # batch_size=1: every `data` dict holds exactly one (collated) example.
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
    preds, targets = [], []
    with torch.inference_mode():
        for data in tqdm(test_dataloader):
            target = data[target_field]
            targets.append(target)

            # NOTE(review): assumes the default-collated fields convert cleanly
            # through torch.tensor; unsqueeze(0) restores the batch dimension.
            # TODO confirm against the actual collate output of the tokenized
            # dataset.
            protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)

            protein_outputs = protein_model(
                protein_input_ids,
                attention_mask=attention_mask,
                position_ids=None, # it's ok to set `position_ids`` as None: https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/modeling_esm.py#L195
            )

            # Normalize so the dot product below is a cosine similarity.
            protein_feature = protein_outputs["protein_feature"]
            protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
            pred = logit_scale * protein_feature @ label_feature.t()
            preds.append(pred)

    preds = torch.cat(preds, dim=0)
    targets = torch.tensor(targets, dtype=torch.long, device=device)
    accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
    # WARNING level so the result is visible under the default logging config.
    logger.warning("Zero-shot accuracy: %.6f" % accuracy)
85
+
86
+
87
if __name__ == "__main__":
    # Zero-shot subcellular-localization evaluation with ProtST-ESM1b:
    # embed label descriptions and protein sequences, then score by similarity.

    # get datasets (test split of the subcellular localization benchmark)
    raw_datasets = load_dataset("Jiqing/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"

    #device = torch.device("cuda:0")
    device = torch.device("cpu")

    # Load the joint protein/text model in bfloat16 and split out its parts.
    protst_model = AutoModel.from_pretrained("Jiqing/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16)
    protein_model = protst_model.protein_model
    text_model = protst_model.text_model

    # CLIP-style learned temperature: freeze it, move to device, exponentiate.
    logit_scale = protst_model.logit_scale
    logit_scale.requires_grad = False
    logit_scale = logit_scale.to(device)
    logit_scale = logit_scale.exp()

    protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
    text_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")

    # Tokenize every protein sequence; the raw sequence column is dropped
    # because only the token ids and attention mask are needed downstream.
    func_tokenize_protein = functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding=False)
    test_dataset = raw_datasets.map(
        func_tokenize_protein, batched=False,
        remove_columns=["prot_seq"],
        desc="Running tokenize_proteins on dataset",
    )

    # Natural-language description for each localization class.
    labels = load_dataset("Jiqing/subloc_template", cache_dir="~/.cache/huggingface/datasets")["train"]["name"]

    # (A stray `text_tokenizer.encode(labels[0], ...)` call whose result was
    # discarded has been removed here — it had no effect.)
    label_feature = label_embedding(labels, text_tokenizer, text_model, device)
    zero_shot_eval(logger, device, test_dataset, "localization",
                   protein_model, logit_scale, label_feature)
118
+ ```