protgpt3 committed on
Commit
376da1e
·
verified ·
1 Parent(s): 32fba5f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +45 -3
README.md CHANGED
@@ -76,10 +76,50 @@ Load the model and tokenizer:
76
  ```python
77
  import torch
78
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  model_id = "protgpt3/ProtGPT3-MSA" # Replace with the final checkpoint name
81
 
82
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
83
  model = AutoModelForCausalLM.from_pretrained(
84
  model_id,
85
  torch_dtype=torch.bfloat16,
@@ -97,6 +137,7 @@ Use the `<no_gap>` modality token for unaligned sequences. Separate homologous s
97
  ```python
98
  import torch
99
 
 
100
  homologs = [
101
  "MKTAYIAKQRQISFVKSHFSRQDILD",
102
  "MKTVYIAKQRQISFVKSHFSRQDILD",
@@ -104,7 +145,7 @@ homologs = [
104
  # Add up to 15 homologous protein sequences
105
  ]
106
 
107
- prompt = "<no_gap>" + "<s>".join(homologs) + "<s>"
108
 
109
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
110
 
@@ -130,13 +171,14 @@ Use the `<gap>` modality token for aligned sequences. Gap characters may be incl
130
  ```python
131
  import torch
132
 
 
133
  aligned_homologs = [
134
  "MKTAYIAKQRQI--SFVKSHFSRQDILD",
135
  "MKTVYIAKQRQI--SFVKSHFSRQDILD",
136
  "MKTAYIAKQRQINNSFVKSHFSRQNILD",
137
  ]
138
 
139
- prompt = "<gap>" + "<s>".join(aligned_homologs) + "<s>"
140
 
141
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
142
 
 
76
  ```python
77
  import torch
78
  from transformers import AutoTokenizer, AutoModelForCausalLM
79
+ import random
80
+ import re
81
+
82
+ # ---- Initialise useful methods to prompt ProtGPT3-MSA ----
83
def process_style(seq: str, gap: bool):
    """Normalise a single sequence before it is placed in the prompt.

    The sequence is upper-cased and every ``X`` is then removed; because the
    upper-casing happens first, a lower-case ``x`` is removed as well.

    Args:
        seq: Raw sequence string, optionally containing ``-`` characters.
        gap: When true the ``-`` characters are kept; when false they are
            stripped before upper-casing.

    Returns:
        The cleaned, upper-case sequence string.
    """
    # Only the gap handling differs between the two modes.
    cleaned = seq if gap else seq.replace("-", "")
    return re.sub(r"[X]", "", cleaned.upper())
91
+
92
def build_prompt(
    sequences: list[str],
    gap: bool = False,
) -> str:
    """Build the space-separated prompt string for ProtGPT3-MSA.

    The prompt is ``<|bos|>``, a direction token, a modality token
    (``<gap>`` or ``<no_gap>``), then for every sequence a ``<s>`` separator
    followed by one token per character of the cleaned sequence, and finally
    a trailing ``<s>``. The sequences are placed in random order.

    Args:
        sequences: Sequences to put in the prompt. The list is not modified.
        gap: When true, the ``<gap>`` token is used and all sequences must
            have the same length; when false, ``<no_gap>`` is used.

    Returns:
        All tokens joined with single spaces.

    Raises:
        AssertionError: If ``gap`` is true and the sequences differ in length.
    """
    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(sequences)
    random.shuffle(shuffled)

    direction = "1"  # change this to "2" for reversed C-to-N generation

    if gap:
        gap_token = "<gap>"
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if any(len(s) != len(shuffled[0]) for s in shuffled):
            raise AssertionError(
                "Sequences in the prompt have different len(), but should be aligned, either align them or use no_gap mode"
            )
    else:
        gap_token = "<no_gap>"

    tokens: list[str] = ["<|bos|>", direction, gap_token]
    for seq in shuffled:
        tokens.append("<s>")
        # Extending with a str adds one token per character.
        tokens.extend(process_style(seq, gap=gap))

    # Match train-time separator before continuation
    tokens.append("<s>")
    return " ".join(tokens)
116
+ ## --------------------------------------
117
 
118
  model_id = "protgpt3/ProtGPT3-MSA" # Replace with the final checkpoint name
119
 
120
+ # Load tokenizer for generation
121
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,add_bos_token=True, add_eos_token=False)
122
+
123
  model = AutoModelForCausalLM.from_pretrained(
124
  model_id,
125
  torch_dtype=torch.bfloat16,
 
137
  ```python
138
  import torch
139
 
140
+
141
  homologs = [
142
  "MKTAYIAKQRQISFVKSHFSRQDILD",
143
  "MKTVYIAKQRQISFVKSHFSRQDILD",
 
145
  # Add up to 15 homologous protein sequences
146
  ]
147
 
148
+ prompt = build_prompt(sequences=homologs)
149
 
150
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
151
 
 
171
  ```python
172
  import torch
173
 
174
+ # must have the same length and be aligned
175
  aligned_homologs = [
176
  "MKTAYIAKQRQI--SFVKSHFSRQDILD",
177
  "MKTVYIAKQRQI--SFVKSHFSRQDILD",
178
  "MKTAYIAKQRQINNSFVKSHFSRQNILD",
179
  ]
180
 
181
# Use the aligned list defined in this snippet, not the unaligned `homologs`.
prompt = build_prompt(sequences=aligned_homologs, gap=True)
182
 
183
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
184