Jiqing committed on
Commit 4a15c97 · verified · 1 Parent(s): 3aead7c

Update README.md

Files changed (1)
README.md: +21, -9
README.md CHANGED
@@ -5,7 +5,7 @@ Current protein language models (PLMs) learn protein representations mainly base
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f0a673f0d40f6aae296b4a/o4F5-Cm-gGdHPpX5rPVKx.png)

## Example
- The following script shows how to run ProtST with [optimum-intel](https://github.com/huggingface/optimum-intel) optimization on a zero-shot classification task.
```diff
import logging
import functools
@@ -13,14 +13,16 @@ from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig

logger = logging.getLogger(__name__)


def tokenize_protein(example, protein_tokenizer=None, padding=None):
protein_seqs = example["prot_seq"]
-
- protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
example["protein_input_ids"] = protein_inputs.input_ids
example["protein_attention_mask"] = protein_inputs.attention_mask

@@ -33,7 +35,8 @@ def label_embedding(labels, text_tokenizer, text_model, device):
with torch.inference_mode():
for label in labels:
label_input_ids = text_tokenizer.encode(label, max_length=128,
- truncation=True, add_special_tokens=False)
label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
attention_mask = label_input_ids != text_tokenizer.pad_token_id
@@ -41,7 +44,8 @@ def label_embedding(labels, text_tokenizer, text_model, device):

text_outputs = text_model(label_input_ids, attention_mask=attention_mask)

- label_feature.append(text_outputs["text_feature"])
label_feature = torch.cat(label_feature, dim=0)
label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)

@@ -82,10 +86,6 @@ if __name__ == "__main__":

protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
- + import intel_extension_for_pytorch as ipex
- + from optimum.intel.generation.modeling import jit_trace
- + protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
- + protein_model = jit_trace(protein_model, "sequence-classification")
text_model = protst_model.text_model
logit_scale = protst_model.logit_scale
logit_scale.requires_grad = False
@@ -108,4 +108,16 @@ if __name__ == "__main__":
label_feature = label_embedding(labels, text_tokenizer, text_model, device)
zero_shot_eval(logger, device, test_dataset, "localization",
protein_model, logit_scale, label_feature)
```
 
![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f0a673f0d40f6aae296b4a/o4F5-Cm-gGdHPpX5rPVKx.png)

## Example
+ The following script shows how to run ProtST with Gaudi and [optimum-intel](https://github.com/huggingface/optimum-intel) optimization on a zero-shot classification task.
```diff
import logging
import functools
 
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig
+ + import habana_frameworks.torch

logger = logging.getLogger(__name__)


def tokenize_protein(example, protein_tokenizer=None, padding=None):
protein_seqs = example["prot_seq"]
+
+ - protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
+ + protein_inputs = protein_tokenizer(protein_seqs, padding="max_length", truncation=True, add_special_tokens=True, max_length=1024)
example["protein_input_ids"] = protein_inputs.input_ids
example["protein_attention_mask"] = protein_inputs.attention_mask

with torch.inference_mode():
for label in labels:
label_input_ids = text_tokenizer.encode(label, max_length=128,
+ - truncation=True, add_special_tokens=False)
+ + truncation=True, add_special_tokens=False, padding="max_length")
label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
attention_mask = label_input_ids != text_tokenizer.pad_token_id

text_outputs = text_model(label_input_ids, attention_mask=attention_mask)

+ - label_feature.append(text_outputs["text_feature"])
+ + label_feature.append(text_outputs["text_feature"].clone())
label_feature = torch.cat(label_feature, dim=0)
label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)

protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
text_model = protst_model.text_model
logit_scale = protst_model.logit_scale
logit_scale.requires_grad = False

label_feature = label_embedding(labels, text_tokenizer, text_model, device)
zero_shot_eval(logger, device, test_dataset, "localization",
protein_model, logit_scale, label_feature)
+ ```
+
+ Run ProtST on CPU with [optimum-intel](https://github.com/huggingface/optimum-intel) optimization.
+ ```diff
+ ...
+ protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
+ protein_model = protst_model.protein_model
+ + import intel_extension_for_pytorch as ipex
+ + from optimum.intel.generation.modeling import jit_trace
+ + protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
+ + protein_model = jit_trace(protein_model, "sequence-classification")
+ ...
```
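
Taken together, the Gaudi and CPU snippets in the updated README differ only in how the protein encoder is prepared before `zero_shot_eval` is called. The sketch below shows one way the two paths could be selected at runtime; the `use_hpu` flag, the explicit `torch.device(...)` selection, and the surrounding structure are illustrative assumptions, not part of the README script.

```python
import torch
from transformers import AutoModel

# Illustrative sketch only: combines the Gaudi (HPU) and CPU paths shown above.
use_hpu = False  # assumption: switch between the two setups described in the README

if use_hpu:
    # Importing habana_frameworks.torch registers the "hpu" device with PyTorch.
    import habana_frameworks.torch  # noqa: F401
    device = torch.device("hpu")
else:
    device = torch.device("cpu")

protst_model = AutoModel.from_pretrained(
    "mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)
protein_model = protst_model.protein_model

if not use_hpu:
    # CPU path: apply the IPEX optimization and TorchScript tracing from the snippet above.
    import intel_extension_for_pytorch as ipex
    from optimum.intel.generation.modeling import jit_trace

    protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
    protein_model = jit_trace(protein_model, "sequence-classification")
```

Either way, the prepared `protein_model` is what the script passes to `zero_shot_eval` together with `logit_scale` and the normalized `label_feature`.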