Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
|
@@ -78,14 +78,17 @@ To generate a novel sequence of a specific length. DSM uses a progressive denois
|
|
| 78 |
length = 100
|
| 79 |
mask_token = tokenizer.mask_token
|
| 80 |
# optionally, enforce starting with methionine
|
| 81 |
-
|
| 82 |
output = model.mask_diffusion_generate(
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
generated_sequences = model.decode_output(output)
|
| 91 |
print(f"Generated sequence: {generated_sequences[0]}")
|
|
@@ -101,15 +104,18 @@ To fill in masked regions of a template sequence:
|
|
| 101 |
```python
|
| 102 |
# Mask Filling / Inpainting
|
| 103 |
template_sequence = "MA<mask><mask><mask>KEG<mask><mask>STL"
|
| 104 |
-
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
generated_sequences = model.decode_output(output)
|
| 115 |
print(f"Generated sequence: {generated_sequences[0]}")
|
|
@@ -125,9 +131,8 @@ Generated sequence: MAVKFKEGGISTL
|
|
| 125 |
# model_binder = DSM_ppi.from_pretrained("GleghornLab/DSM_650_ppi_lora").to(device).eval()
|
| 126 |
# The lora version from the paper leads to unreliable outputs
|
| 127 |
# Synthyra has generously trained a version through full fine tuning
|
| 128 |
-
from models.modeling_dsm import DSM
|
| 129 |
|
| 130 |
-
|
| 131 |
|
| 132 |
# BBF-14
|
| 133 |
target_seq = "MGTPLWALLGGPWRGTATYEDGTKVTLDYRYTRVSPDRLRADVTYTTPDGTTLEATVDLWKDANGVIRYHATYPDGTSADGTLTQLDADTLLATGTYDDGTKYTVTLTRVAPGSGWHHHHHH"
|
|
@@ -140,10 +145,10 @@ combined_input_str = target_seq + '<eos>' + interactor_template
|
|
| 140 |
|
| 141 |
input_tokens = tokenizer.encode(combined_input_str, add_special_tokens=True, return_tensors='pt').to(device)
|
| 142 |
|
| 143 |
-
output =
|
| 144 |
tokenizer=tokenizer,
|
| 145 |
input_tokens=input_tokens,
|
| 146 |
-
step_divisor=
|
| 147 |
temperature=1.0, # sampling temperature
|
| 148 |
remasking="random", # strategy for remasking tokens not kept
|
| 149 |
preview=False, # set this to True to watch the mask tokens get filled in real time
|
|
@@ -192,7 +197,7 @@ output = model.mask_diffusion_generate(
|
|
| 192 |
seqa, seqb = model.decode_dual_input(output, seperator='<eos>')
|
| 193 |
# Parse out the generated interactor part based on EOS tokens.
|
| 194 |
# Example: generated_full_seq_str.split(model_binder.tokenizer.eos_token)[1]
|
| 195 |
-
print(f"SeqA: {seqa[0][
|
| 196 |
print(f"SeqB: {seqb[0]}")
|
| 197 |
```
|
| 198 |
|
|
|
|
| 78 |
length = 100
|
| 79 |
mask_token = tokenizer.mask_token
|
| 80 |
# optionally, enforce starting with methionine
|
| 81 |
+
input_tokens = tokenizer.encode('M' + ''.join([mask_token] * (length - 1)), add_special_tokens=True, return_tensors='pt').to(device)
|
| 82 |
output = model.mask_diffusion_generate(
|
| 83 |
+
tokenizer=tokenizer,
|
| 84 |
+
input_tokens=input_tokens,
|
| 85 |
+
step_divisor=100, # lower is slower but better
|
| 86 |
+
temperature=1.0, # sampling temperature
|
| 87 |
+
remasking="random", # strategy for remasking tokens not kept
|
| 88 |
+
preview=False, # set this to True to watch the mask tokens get filled in real time
|
| 89 |
+
slow=False, # adds a small delay to the real time filling (because it is usually very fast and watching carefully is hard!)
|
| 90 |
+
return_trajectory=False # set this to True to return the trajectory of the generation (what you watch in the preview)
|
| 91 |
+
) # Note: output will be a tuple if return_trajectory is True
|
| 92 |
|
| 93 |
generated_sequences = model.decode_output(output)
|
| 94 |
print(f"Generated sequence: {generated_sequences[0]}")
|
|
|
|
| 104 |
```python
|
| 105 |
# Mask Filling / Inpainting
|
| 106 |
template_sequence = "MA<mask><mask><mask>KEG<mask><mask>STL"
|
| 107 |
+
input_tokens = tokenizer.encode(template_sequence, add_special_tokens=True, return_tensors='pt').to(device)
|
| 108 |
|
| 109 |
+
output = model.mask_diffusion_generate(
|
| 110 |
+
tokenizer=tokenizer,
|
| 111 |
+
input_tokens=input_tokens,
|
| 112 |
+
step_divisor=100, # lower is slower but better
|
| 113 |
+
temperature=1.0, # sampling temperature
|
| 114 |
+
remasking="random", # strategy for remasking tokens not kept
|
| 115 |
+
preview=False, # set this to True to watch the mask tokens get filled in real time
|
| 116 |
+
slow=False, # adds a small delay to the real time filling (because it is usually very fast and watching carefully is hard!)
|
| 117 |
+
return_trajectory=False # set this to True to return the trajectory of the generation (what you watch in the preview)
|
| 118 |
+
) # Note: output will be a tuple if return_trajectory is True
|
| 119 |
|
| 120 |
generated_sequences = model.decode_output(output)
|
| 121 |
print(f"Generated sequence: {generated_sequences[0]}")
|
|
|
|
| 131 |
# model_binder = DSM_ppi.from_pretrained("GleghornLab/DSM_650_ppi_lora").to(device).eval()
|
| 132 |
# The lora version from the paper leads to unreliable outputs
|
| 133 |
# Synthyra has generously trained a version through full fine tuning
|
|
|
|
| 134 |
|
| 135 |
+
model = DSM.from_pretrained("Synthyra/DSM_ppi_full").to(device).eval()
|
| 136 |
|
| 137 |
# BBF-14
|
| 138 |
target_seq = "MGTPLWALLGGPWRGTATYEDGTKVTLDYRYTRVSPDRLRADVTYTTPDGTTLEATVDLWKDANGVIRYHATYPDGTSADGTLTQLDADTLLATGTYDDGTKYTVTLTRVAPGSGWHHHHHH"
|
|
|
|
| 145 |
|
| 146 |
input_tokens = tokenizer.encode(combined_input_str, add_special_tokens=True, return_tensors='pt').to(device)
|
| 147 |
|
| 148 |
+
output = model.mask_diffusion_generate(
|
| 149 |
tokenizer=tokenizer,
|
| 150 |
input_tokens=input_tokens,
|
| 151 |
+
step_divisor=100, # lower is slower but better
|
| 152 |
temperature=1.0, # sampling temperature
|
| 153 |
remasking="random", # strategy for remasking tokens not kept
|
| 154 |
preview=False, # set this to True to watch the mask tokens get filled in real time
|
|
|
|
| 197 |
seqa, seqb = model.decode_dual_input(output, seperator='<eos>')
|
| 198 |
# Parse out the generated interactor part based on EOS tokens.
|
| 199 |
# Example: generated_full_seq_str.split(model_binder.tokenizer.eos_token)[1]
|
| 200 |
+
print(f"SeqA: {seqa[0][5:]}") # remove cls token
|
| 201 |
print(f"SeqB: {seqb[0]}")
|
| 202 |
```
|
| 203 |
|