Jiqing committed "Update README.md"
Commit e66b0e5 · 1 Parent(s): 3b48f52

Files changed (1): README.md (+3, -22)
README.md CHANGED
```diff
@@ -18,13 +18,9 @@ logger = logging.getLogger(__name__)
 
 
 def tokenize_protein(example, protein_tokenizer=None, padding=None):
-    # check https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/tokenization_esm.py#L100
     protein_seqs = example["prot_seq"]
 
-    protein_inputs = protein_tokenizer(
-        protein_seqs, padding=padding,
-        add_special_tokens=True,  # default is True, no need to add cls and eos manually
-    )  # results in <cls> + seq + <eos> (no <sep> for ESM)
+    protein_inputs = protein_tokenizer(protein_seqs, padding=padding, add_special_tokens=True)
     example["protein_input_ids"] = protein_inputs.input_ids
     example["protein_attention_mask"] = protein_inputs.attention_mask
 
```
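The comments removed above documented ESM tokenizer behavior: `add_special_tokens` defaults to `True` and yields `<cls> + seq + <eos>` (ESM has no `<sep>`), so the shortened call is equivalent. A minimal sketch checking this; the checkpoint name is an illustrative assumption, not necessarily the tokenizer this repo uses:

```python
# Sketch: the ESM tokenizer adds <cls>/<eos> by default, so the verbose and
# the shortened calls from the diff produce identical input ids.
# The checkpoint is an assumption for illustration; any ESM tokenizer behaves the same.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

seq = "MKTAYIAKQR"
explicit = tokenizer(seq, add_special_tokens=True)  # old, spelled-out call
implicit = tokenizer(seq)                           # add_special_tokens already defaults to True

assert explicit.input_ids == implicit.input_ids
print(tokenizer.decode(explicit.input_ids))  # "<cls> M K T A Y I A K Q R <eos>" (no <sep> for ESM)
```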
 
```diff
@@ -42,18 +38,7 @@ def label_embedding(labels, text_tokenizer, text_model, device):
     label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
     attention_mask = label_input_ids != text_tokenizer.pad_token_id
 
-    text_outputs = text_model(
-        label_input_ids,
-        attention_mask=attention_mask,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    )
+    text_outputs = text_model(label_input_ids, attention_mask=attention_mask)
 
     label_feature.append(text_outputs["text_feature"])
     label_feature = torch.cat(label_feature, dim=0)
```
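Every keyword argument dropped from the `text_model` call (`position_ids`, `head_mask`, `inputs_embeds`, the encoder inputs, and the output flags) already defaults to `None` in transformers-style `forward()` signatures, so the one-line call is behaviorally identical. A toy sketch of that equivalence; the model class and its `text_feature` output below are stand-ins, not this repo's actual text encoder:

```python
# Toy sketch: arguments that default to None can be omitted without changing
# the result. The model below only mimics the README's forward() signature and
# its dict output with a "text_feature" key; it is not the real text model.
import torch

class ToyTextModel(torch.nn.Module):
    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                head_mask=None, inputs_embeds=None, encoder_hidden_states=None,
                encoder_attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None):
        hidden = self.embed(input_ids)
        if attention_mask is not None:
            hidden = hidden * attention_mask.unsqueeze(-1)  # zero out padding positions
        return {"text_feature": hidden.mean(dim=1)}  # pooled sentence-level feature

model = ToyTextModel()
ids = torch.randint(0, 100, (1, 8))
mask = torch.ones(1, 8, dtype=torch.long)

verbose = model(ids, attention_mask=mask, position_ids=None, head_mask=None,
                inputs_embeds=None, encoder_hidden_states=None,
                encoder_attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None)
short = model(ids, attention_mask=mask)
assert torch.equal(verbose["text_feature"], short["text_feature"])
```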
```diff
@@ -75,11 +60,7 @@ def zero_shot_eval(logger, device,
     protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
     attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
 
-    protein_outputs = protein_model(
-        protein_input_ids,
-        attention_mask=attention_mask,
-        position_ids=None,  # it's ok to set `position_ids` as None: https://github.com/huggingface/transformers/blob/41aef33758ae166291d72bc381477f2db84159cf/src/transformers/models/esm/modeling_esm.py#L195
-    )
+    protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)
 
     protein_feature = protein_outputs["protein_feature"]
     protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
```
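The kept normalization line divides `protein_feature` by its L2 norm, so later dot products against equally normalized label features become cosine similarities, which is what the zero-shot matching relies on. A minimal sketch of that scoring step; all shapes here are made-up assumptions:

```python
# Sketch of the zero-shot scoring the normalization enables: unit-norm
# features turn a matrix product into a table of cosine similarities.
# Feature dimension and counts are illustrative assumptions.
import torch

protein_feature = torch.randn(4, 512)  # 4 proteins, assumed 512-dim features
label_feature = torch.randn(10, 512)   # 10 candidate labels, same dim

protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)

similarity = protein_feature @ label_feature.t()  # (4, 10) cosine similarities
predictions = similarity.argmax(dim=-1)           # best-matching label per protein
print(predictions)
```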
 