Update README.md
Browse files
README.md
CHANGED
|
@@ -18,13 +18,9 @@ logger = logging.getLogger(__name__)
|
|
| 18 |
|
| 19 |
|
| 20 |
def tokenize_protein(example, protein_tokenizer=None, padding=None):
|
| 21 |
-
# check https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/tokenization_esm.py#L100
|
| 22 |
protein_seqs = example["prot_seq"]
|
| 23 |
|
| 24 |
-
protein_inputs = protein_tokenizer(
|
| 25 |
-
protein_seqs, padding=padding,
|
| 26 |
-
add_special_tokens=True, # default is True, no need to add cls and eos manually
|
| 27 |
-
) # results in <cls> + seq + <eos> (no <sep> for ESM)
|
| 28 |
example["protein_input_ids"] = protein_inputs.input_ids
|
| 29 |
example["protein_attention_mask"] = protein_inputs.attention_mask
|
| 30 |
|
|
@@ -42,18 +38,7 @@ def label_embedding(labels, text_tokenizer, text_model, device):
|
|
| 42 |
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
|
| 43 |
attention_mask = label_input_ids != text_tokenizer.pad_token_id
|
| 44 |
|
| 45 |
-
text_outputs = text_model(
|
| 46 |
-
label_input_ids,
|
| 47 |
-
attention_mask=attention_mask,
|
| 48 |
-
position_ids=None,
|
| 49 |
-
head_mask=None,
|
| 50 |
-
inputs_embeds=None,
|
| 51 |
-
encoder_hidden_states=None,
|
| 52 |
-
encoder_attention_mask=None,
|
| 53 |
-
output_attentions=None,
|
| 54 |
-
output_hidden_states=None,
|
| 55 |
-
return_dict=None,
|
| 56 |
-
)
|
| 57 |
|
| 58 |
label_feature.append(text_outputs["text_feature"])
|
| 59 |
label_feature = torch.cat(label_feature, dim=0)
|
|
@@ -75,11 +60,7 @@ def zero_shot_eval(logger, device,
|
|
| 75 |
protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
|
| 76 |
attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
|
| 77 |
|
| 78 |
-
protein_outputs = protein_model(
|
| 79 |
-
protein_input_ids,
|
| 80 |
-
attention_mask=attention_mask,
|
| 81 |
-
position_ids=None, # it's ok to set `position_ids`` as None: https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/modeling_esm.py#L195
|
| 82 |
-
)
|
| 83 |
|
| 84 |
protein_feature = protein_outputs["protein_feature"]
|
| 85 |
protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def tokenize_protein(example, protein_tokenizer=None, padding=None):
|
|
|
|
| 21 |
protein_seqs = example["prot_seq"]
|
| 22 |
|
| 23 |
+
protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
| 24 |
example["protein_input_ids"] = protein_inputs.input_ids
|
| 25 |
example["protein_attention_mask"] = protein_inputs.attention_mask
|
| 26 |
|
|
|
|
| 38 |
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
|
| 39 |
attention_mask = label_input_ids != text_tokenizer.pad_token_id
|
| 40 |
|
| 41 |
+
text_outputs = text_model(label_input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
label_feature.append(text_outputs["text_feature"])
|
| 44 |
label_feature = torch.cat(label_feature, dim=0)
|
|
|
|
| 60 |
protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
|
| 61 |
attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
|
| 62 |
|
| 63 |
+
protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
protein_feature = protein_outputs["protein_feature"]
|
| 66 |
protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
|