Update README.md
## Example

The following script shows how to run ProtST with Gaudi and [optimum-intel](https://github.com/huggingface/optimum-intel) optimization on a zero-shot classification task.
```diff
import logging
import functools
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig
+import habana_frameworks.torch

logger = logging.getLogger(__name__)


def tokenize_protein(example, protein_tokenizer=None, padding=None):
    protein_seqs = example["prot_seq"]
-    protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
+    protein_inputs = protein_tokenizer(protein_seqs, padding="max_length", truncation=True, add_special_tokens=True, max_length=1024)
    example["protein_input_ids"] = protein_inputs.input_ids
    example["protein_attention_mask"] = protein_inputs.attention_mask
    ...

def label_embedding(labels, text_tokenizer, text_model, device):
    ...
    with torch.inference_mode():
        for label in labels:
            label_input_ids = text_tokenizer.encode(label, max_length=128,
-                                                    truncation=True, add_special_tokens=False)
+                                                    truncation=True, add_special_tokens=False, padding="max_length")
            label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
            label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
            attention_mask = label_input_ids != text_tokenizer.pad_token_id
            ...
            text_outputs = text_model(label_input_ids, attention_mask=attention_mask)
-            label_feature.append(text_outputs["text_feature"])
+            label_feature.append(text_outputs["text_feature"].clone())
    label_feature = torch.cat(label_feature, dim=0)
    label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)
    ...

if __name__ == "__main__":
    ...
    protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
    protein_model = protst_model.protein_model
    text_model = protst_model.text_model
    logit_scale = protst_model.logit_scale
    logit_scale.requires_grad = False
    ...
    label_feature = label_embedding(labels, text_tokenizer, text_model, device)
    zero_shot_eval(logger, device, test_dataset, "localization",
                   protein_model, logit_scale, label_feature)
```
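
The hunks above elide the data-loading and evaluation glue between the function definitions and the `if __name__ == "__main__":` block. For orientation, here is a minimal sketch of what that glue might look like; the dataset id, tokenizer checkpoints, and label prompts are assumptions for illustration, while `tokenize_protein`, `label_embedding`, and `zero_shot_eval` are the functions from the script above.

```python
import functools

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

# Gaudi device; the CPU variant below would use torch.device("cpu") instead.
device = torch.device("hpu")

# Assumed checkpoints: ProtST-esm1b pairs an ESM-1b protein encoder with a
# PubMedBERT text encoder, so their tokenizers are loaded here.
protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
text_tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
)

# Hypothetical dataset id; the real script evaluates a subcellular-localization
# test split whose "prot_seq" column tokenize_protein consumes.
test_dataset = load_dataset("mila-intel/ProtST-SubcellularLocalization", split="test")
test_dataset = test_dataset.map(
    functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding="max_length")
)

# Hypothetical text prompts for the zero-shot "localization" labels.
labels = ["the protein locates at nucleus", "the protein locates at cytoplasm"]
```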

Run ProtST on CPU with [optimum-intel](https://github.com/huggingface/optimum-intel) optimization.
```diff
...
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
+import intel_extension_for_pytorch as ipex
+from optimum.intel.generation.modeling import jit_trace
+protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
+protein_model = jit_trace(protein_model, "sequence-classification")
...
```
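
Here `ipex.optimize` prepares the protein encoder for bfloat16 CPU inference (operator fusion and weight prepacking), and `jit_trace` compiles it to TorchScript so later calls run through the optimized graph. Below is a hedged sketch of invoking the traced model; the exact calling convention of the traced module is an assumption.

```python
import torch

# Assumption: the traced module still takes (input_ids, attention_mask) like the
# original forward. Run under inference_mode with bf16 autocast on CPU.
with torch.inference_mode(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    protein_outputs = protein_model(protein_input_ids, protein_attention_mask)
```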