protgpt3 commited on
Commit
cc73d30
·
verified ·
1 Parent(s): 4145707

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +46 -7
README.md CHANGED
@@ -69,7 +69,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
69
 
70
  model_id = "protgpt3/ProtGPT3-1OB" # Replace with the final checkpoint name
71
 
72
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
73
  model = AutoModelForCausalLM.from_pretrained(
74
  model_id,
75
  torch_dtype=torch.bfloat16,
@@ -80,12 +82,12 @@ model = AutoModelForCausalLM.from_pretrained(
80
  model.eval()
81
  ```
82
 
83
- Generate a protein sequence:
84
 
85
  ```python
86
  import torch
87
 
88
- prompt = "" # Optionally provide an amino-acid prefix or model-specific direction token
89
 
90
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
91
 
@@ -97,17 +99,20 @@ with torch.no_grad():
97
  temperature=0.8,
98
  top_p=0.9,
99
  eos_token_id=tokenizer.eos_token_id,
100
- pad_token_id=tokenizer.eos_token_id,
101
  )
102
 
103
  sequence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
104
- print(sequence)
105
  ```
106
 
107
- Generate from an amino-acid prefix:
108
 
109
  ```python
110
- prefix = "MKT"
 
 
 
111
 
112
  inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
113
 
@@ -126,6 +131,40 @@ sequence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
126
  print(sequence)
127
  ```
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  ## Training Details
130
 
131
  ### Training Data
 
69
 
70
  model_id = "protgpt3/ProtGPT3-1OB" # Replace with the final checkpoint name
71
 
72
+ # Load tokenizer for generation
73
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True,add_bos_token=True, add_eos_token=False)
74
+
75
  model = AutoModelForCausalLM.from_pretrained(
76
  model_id,
77
  torch_dtype=torch.bfloat16,
 
82
  model.eval()
83
  ```
84
 
85
+ ### Generate a protein sequence
86
 
87
  ```python
88
  import torch
89
 
90
+ prompt = "" # Optionally provide an amino-acid prefix or model-specific direction
91
 
92
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
93
 
 
99
  temperature=0.8,
100
  top_p=0.9,
101
  eos_token_id=tokenizer.eos_token_id,
102
+ pad_token_id=tokenizer.pad_token_id,
103
  )
104
 
105
  sequence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
106
+ print(sequence) # output includes directional token "1" or "2" to denote if sequence was generated N-to-C or C-to-N
107
  ```
108
 
109
+ ### Generate from an amino-acid prefix
110
 
111
  ```python
112
+ import torch
113
+
114
+ # forward N-to-C generation with special token "1"
115
+ prefix = "1MKT" # use special token "2" instead of "1" for reverse C-to-N generation
116
 
117
  inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
118
 
 
131
  print(sequence)
132
  ```
133
 
134
+ ### Batch generation
135
+
136
+ ```python
137
+ import torch
138
+
139
+ prompts = [
140
+ "",
141
+ "1MKT", # N-to-C generation
142
+ "2MAV", # C-to-N generation
143
+ ]
144
+
145
+ inputs = tokenizer(
146
+ prompts,
147
+ return_tensors="pt",
148
+ padding=True,
149
+ ).to(model.device)
150
+
151
+ with torch.no_grad():
152
+ output_ids = model.generate(
153
+ **inputs,
154
+ max_new_tokens=256,
155
+ do_sample=True,
156
+ temperature=0.8,
157
+ top_p=0.9,
158
+ eos_token_id=tokenizer.eos_token_id,
159
+ pad_token_id=tokenizer.bos_token_id,
160
+ )
161
+
162
+ sequences = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
163
+
164
+ for sequence in sequences:
165
+ print(sequence)
166
+ ```
167
+
168
  ## Training Details
169
 
170
  ### Training Data