Update README.md
Browse files
README.md
CHANGED
|
@@ -104,8 +104,10 @@ model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
|
|
| 104 |
# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
|
| 105 |
model.full() if device=='cpu' else model.half()
|
| 106 |
|
| 107 |
-
# prepare your protein sequences/structures as a list.
|
| 108 |
-
|
|
|
|
|
|
|
| 109 |
min_len = min([ len(s) for s in folding_example])
|
| 110 |
max_len = max([ len(s) for s in folding_example])
|
| 111 |
|
|
@@ -116,9 +118,12 @@ sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequen
|
|
| 116 |
sequence_examples = [ "<AA2fold>" + " " + s for s in sequence_examples]
|
| 117 |
|
| 118 |
# tokenize sequences and pad up to the longest sequence in the batch
|
| 119 |
-
ids = tokenizer.batch_encode_plus(sequence_examples,
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
-
# Generation configuration
|
| 122 |
gen_kwargs_aa2fold = {
|
| 123 |
"do_sample": True,
|
| 124 |
"num_beams": 3,
|
|
@@ -128,11 +133,11 @@ gen_kwargs_aa2fold = {
|
|
| 128 |
"repetition_penalty" : 1.2,
|
| 129 |
}
|
| 130 |
|
| 131 |
-
# translate from AA to 3Di
|
| 132 |
with torch.no_grad():
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
attention_mask=
|
| 136 |
max_length=max_len, # max length of generated text
|
| 137 |
min_length=min_len, # minimum length of the generated text
|
| 138 |
early_stopping=True, # stop early if end-of-text token is generated
|
|
@@ -140,40 +145,22 @@ with torch.no_grad():
|
|
| 140 |
**gen_kwargs_aa2fold
|
| 141 |
)
|
| 142 |
# Decode and remove white-spaces between tokens
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
```
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
|
| 150 |
-
import torch
|
| 151 |
-
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
# Load the model
|
| 157 |
-
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5").to(device)
|
| 158 |
-
|
| 159 |
-
# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
|
| 160 |
-
model.full() if device=='cpu' else model.half()
|
| 161 |
-
|
| 162 |
-
# prepare your protein sequences/structures as a list. Amino acid sequences are expected to be upper-case ("PRTEINO" below)
|
| 163 |
-
folding_example = ["prtein", "strctr"]
|
| 164 |
-
min_len = min([ len(s) for s in folding_example])
|
| 165 |
-
max_len = max([ len(s) for s in folding_example])
|
| 166 |
-
|
| 167 |
-
# replace all rare/ambiguous amino acids by X (3Di sequences does not have those) and introduce white-space between all sequences (AAs and 3Di)
|
| 168 |
-
sequence_examples = [" ".join(list(sequence)) for sequence in sequence_examples]
|
| 169 |
-
|
| 170 |
-
# add pre-fixes accordingly. For the translation from 3Di to AAs, you need to prepend "<fold2AA>"
|
| 171 |
-
sequence_examples = [ "<fold2AA>" + " " + s for s in sequence_examples]
|
| 172 |
|
| 173 |
# tokenize sequences and pad up to the longest sequence in the batch
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
#
|
| 177 |
gen_kwargs_fold2AA = {
|
| 178 |
"do_sample": True,
|
| 179 |
"top_p" : 0.90,
|
|
@@ -182,11 +169,11 @@ gen_kwargs_fold2AA = {
|
|
| 182 |
"repetition_penalty" : 1.2,
|
| 183 |
}
|
| 184 |
|
| 185 |
-
# translate from 3Di to AA
|
| 186 |
with torch.no_grad():
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
attention_mask=
|
| 190 |
max_length=max_len, # max length of generated text
|
| 191 |
min_length=min_len, # minimum length of the generated text
|
| 192 |
early_stopping=True, # stop early if end-of-text token is generated
|
|
@@ -194,8 +181,9 @@ with torch.no_grad():
|
|
| 194 |
**gen_kwargs_fold2AA
|
| 195 |
)
|
| 196 |
# Decode and remove white-spaces between tokens
|
| 197 |
-
|
| 198 |
-
|
|
|
|
| 199 |
```
|
| 200 |
|
| 201 |
|
|
|
|
| 104 |
# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
|
| 105 |
model.full() if device=='cpu' else model.half()
|
| 106 |
|
| 107 |
+
# prepare your protein sequences/structures as a list.
|
| 108 |
+
# Amino acid sequences are expected to be upper-case ("PRTEINO" below)
|
| 109 |
+
# while 3Di-sequences need to be lower-case.
|
| 110 |
+
sequence_examples = ["PRTEINO", "SEQWENCE"]
|
| 111 |
min_len = min([ len(s) for s in sequence_examples])
|
| 112 |
max_len = max([ len(s) for s in sequence_examples])
|
| 113 |
|
|
|
|
| 118 |
sequence_examples = [ "<AA2fold>" + " " + s for s in sequence_examples]
|
| 119 |
|
| 120 |
# tokenize sequences and pad up to the longest sequence in the batch
|
| 121 |
+
ids = tokenizer.batch_encode_plus(sequence_examples,
|
| 122 |
+
add_special_tokens=True,
|
| 123 |
+
padding="longest",
|
| 124 |
+
return_tensors='pt').to(device)
|
| 125 |
|
| 126 |
+
# Generation configuration for "folding" (AA-->3Di)
|
| 127 |
gen_kwargs_aa2fold = {
|
| 128 |
"do_sample": True,
|
| 129 |
"num_beams": 3,
|
|
|
|
| 133 |
"repetition_penalty" : 1.2,
|
| 134 |
}
|
| 135 |
|
| 136 |
+
# translate from AA to 3Di (AA-->3Di)
|
| 137 |
with torch.no_grad():
|
| 138 |
+
translations = model.generate(
|
| 139 |
+
ids.input_ids,
|
| 140 |
+
attention_mask=ids.attention_mask,
|
| 141 |
max_length=max_len, # max length of generated text
|
| 142 |
min_length=min_len, # minimum length of the generated text
|
| 143 |
early_stopping=True, # stop early if end-of-text token is generated
|
|
|
|
| 145 |
**gen_kwargs_aa2fold
|
| 146 |
)
|
| 147 |
# Decode and remove white-spaces between tokens
|
| 148 |
+
decoded_translations = tokenizer.batch_decode( translations, skip_special_tokens=True )
|
| 149 |
+
structure_sequences = [ "".join(ts.split(" ")) for ts in decoded_translations ] # predicted 3Di strings
|
|
|
|
| 150 |
|
| 151 |
+
# Now we can use the same model and invert the translation logic
|
| 152 |
+
# to generate an amino acid sequence from the predicted 3Di-sequence (3Di-->AA)
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
+
# add pre-fixes accordingly. For the translation from 3Di to AA (3Di-->AA), you need to prepend "<fold2AA>"
|
| 155 |
+
sequence_examples_backtranslation = [ "<fold2AA>" + " " + s for s in decoded_translations]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# tokenize sequences and pad up to the longest sequence in the batch
|
| 158 |
+
ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
|
| 159 |
+
add_special_tokens=True,
|
| 160 |
+
padding="longest",
|
| 161 |
+
return_tensors='pt').to(device)
|
| 162 |
|
| 163 |
+
# Example generation configuration for "inverse folding" (3Di-->AA)
|
| 164 |
gen_kwargs_fold2AA = {
|
| 165 |
"do_sample": True,
|
| 166 |
"top_p" : 0.90,
|
|
|
|
| 169 |
"repetition_penalty" : 1.2,
|
| 170 |
}
|
| 171 |
|
| 172 |
+
# translate from 3Di to AA (3Di-->AA)
|
| 173 |
with torch.no_grad():
|
| 174 |
+
backtranslations = model.generate(
|
| 175 |
+
ids_backtranslation.input_ids,
|
| 176 |
+
attention_mask=ids_backtranslation.attention_mask,
|
| 177 |
max_length=max_len, # max length of generated text
|
| 178 |
min_length=min_len, # minimum length of the generated text
|
| 179 |
early_stopping=True, # stop early if end-of-text token is generated
|
|
|
|
| 181 |
**gen_kwargs_fold2AA
|
| 182 |
)
|
| 183 |
# Decode and remove white-spaces between tokens
|
| 184 |
+
decoded_backtranslations = tokenizer.batch_decode( backtranslations, skip_special_tokens=True )
|
| 185 |
+
aminoAcid_sequences = [ "".join(ts.split(" ")) for ts in decoded_backtranslations ] # predicted amino acid strings
|
| 186 |
+
|
| 187 |
```
|
| 188 |
|
| 189 |
|