hazemessam committed · Commit dbdea49 · verified · Parent(s): 99a62fb

Update README.md

Files changed (1): README.md (+96 −0)
README.md CHANGED
---
license: cc-by-nc-sa-4.0
---

# Model Details:
Ankh3 is a protein language model that is jointly optimized on two objectives:
* Masked language modeling with multiple masking probabilities
* Protein sequence completion

1. Masked Language Modeling:
   - The idea of this task is to intentionally "corrupt" an input protein sequence by masking a certain percentage (X%) of its individual tokens (amino acids), and then train the model to reconstruct the original sequence.

   - Example of a protein sequence before and after corruption (a minimal code sketch follows this list):

     Original protein sequence: MKAYVLINSRGP

     This sequence is masked/corrupted using sentinel tokens, as shown below:

     Sequence after corruption: M <extra_id_0> A Y <extra_id_1> L I <extra_id_2> S R G <extra_id_3>

     The decoder learns to map each sentinel token to the actual amino acid that was masked. In this example, <extra_id_0> K means that <extra_id_0> corresponds to the "K" amino acid, and so on.

     Decoder output: <extra_id_0> K <extra_id_1> V <extra_id_2> N <extra_id_3> P

2. Protein Sequence Completion:
   - The idea of this task is to cut the input sequence into two segments, where the first segment is fed to the encoder and the decoder is tasked to auto-regressively generate the second segment, conditioned on the first segment's representation output by the encoder.

   - Example of protein sequence completion:

     Original sequence: MKAYVLINSRGP

     We pass "MKAYVL" to the encoder, and the decoder is trained so that, given the encoder's representation of the first segment, it outputs the second segment: "INSRGP"
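The following minimal sketch (plain Python; the `corrupt` helper and the hand-picked masked positions are illustrative, not the actual Ankh3 pre-training code) reproduces the masking example above:

```python
# Illustrative T5-style corruption: each masked position is replaced by a
# sentinel token in the input, and the target pairs each sentinel with the
# amino acid it hides. Real span corruption samples positions randomly and
# can hide a multi-token span behind a single sentinel.
def corrupt(sequence: str, masked_positions: set) -> tuple:
    input_tokens, target_tokens = [], []
    sentinel = 0
    for i, amino_acid in enumerate(sequence):
        if i in masked_positions:
            input_tokens.append(f"<extra_id_{sentinel}>")
            target_tokens += [f"<extra_id_{sentinel}>", amino_acid]
            sentinel += 1
        else:
            input_tokens.append(amino_acid)
    return " ".join(input_tokens), " ".join(target_tokens)

corrupted, target = corrupt("MKAYVLINSRGP", {1, 4, 7, 11})
print(corrupted)  # M <extra_id_0> A Y <extra_id_1> L I <extra_id_2> S R G <extra_id_3>
print(target)     # <extra_id_0> K <extra_id_1> V <extra_id_2> N <extra_id_3> P
```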

# How to use:

## For Embedding Extraction:
```python
from transformers import T5Tokenizer, T5EncoderModel
import torch

# Random sequence from UniProt; most likely Ankh3 saw it during pre-training.
sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"

ckpt = "ElnaggarLab/ankh3-xl"

# Make sure to use `T5Tokenizer`, not `AutoTokenizer`.
tokenizer = T5Tokenizer.from_pretrained(ckpt)

# To get encoder representations, prepend the [NLU] prefix and use T5EncoderModel:
encoder_model = T5EncoderModel.from_pretrained(ckpt).eval()

# For extracting embeddings, consider also trying the '[S2S]' prefix.
# Since this prefix was specifically used to denote sequence completion
# during the model's pre-training, its use can sometimes
# lead to improved embedding quality.

nlu_sequence = "[NLU]" + sequence
encoded_nlu_sequence = tokenizer(nlu_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)

with torch.no_grad():
    embedding = encoder_model(**encoded_nlu_sequence)
```
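The encoder returns per-residue embeddings in `last_hidden_state`. If you want one fixed-size vector per protein, a common choice (ours here, not something the model card prescribes) is to mean-pool over the residue dimension. Continuing from the snippet above:

```python
# Per-residue embeddings: shape (batch, sequence_length, hidden_dim).
residue_embeddings = embedding.last_hidden_state

# Mean-pool over residues, excluding padding, to get one vector per protein.
mask = encoded_nlu_sequence["attention_mask"].unsqueeze(-1)
protein_embedding = (residue_embeddings * mask).sum(dim=1) / mask.sum(dim=1)
print(protein_embedding.shape)  # torch.Size([1, hidden_dim])
```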

## For Sequence Completion:
```python
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig
import torch

sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"

ckpt = "ElnaggarLab/ankh3-xl"
tokenizer = T5Tokenizer.from_pretrained(ckpt)

# For the sequence-to-sequence task, prepend the [S2S] prefix and use the
# full encoder-decoder model:
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()

half_length = int(len(sequence) * 0.5)
s2s_sequence = "[S2S]" + sequence[:half_length]
encoded_s2s_sequence = tokenizer(s2s_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)

# + 1 to account for the start-of-sequence token.
gen_config = GenerationConfig(min_length=half_length + 1, max_length=half_length + 1, do_sample=False, num_beams=1)

with torch.no_grad():
    generated_sequence = model.generate(encoded_s2s_sequence["input_ids"], generation_config=gen_config)
predicted_sequence = sequence[:half_length] + tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)[0]
```
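As a rough sanity check (continuing from the snippet above; exact per-residue identity is our choice of metric, not part of the model card), you can compare the generated second half with the original one:

```python
# Compare the generated completion against the true second half.
actual_half = sequence[half_length:]
generated_half = predicted_sequence[half_length:half_length + len(actual_half)]
matches = sum(a == b for a, b in zip(actual_half, generated_half))
print(f"Sequence identity: {matches / len(actual_half):.1%}")
```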