Tong Chen committed
Commit · d2693e0
Parent(s): f9d1b81

add files
This view is limited to 50 files because it contains too many changes.
- peptide/ckpt/PepReDi_base.pt +3 -0
- peptide/ckpt/PepReDi_v1.pt +3 -0
- peptide/ckpt/PepReDi_v2.pt +3 -0
- peptide/ckpt/PepReDi_v3.pt +3 -0
- peptide/classifier_ckpt/best_model_half_life.pth +3 -0
- peptide/classifier_ckpt/best_model_hemolysis.json +0 -0
- peptide/classifier_ckpt/best_model_nonfouling.json +0 -0
- peptide/classifier_ckpt/best_model_solubility.json +0 -0
- peptide/classifier_ckpt/binding_affinity_pooled.pt +3 -0
- peptide/classifier_ckpt/binding_affinity_unpooled.pt +3 -0
- peptide/data/test/data-00000-of-00001.arrow +3 -0
- peptide/data/test/dataset_info.json +15 -0
- peptide/data/test/state.json +13 -0
- peptide/data/train/data-00000-of-00001.arrow +3 -0
- peptide/data/train/dataset_info.json +15 -0
- peptide/data/train/state.json +13 -0
- peptide/data/val/data-00000-of-00001.arrow +3 -0
- peptide/data/val/dataset_info.json +15 -0
- peptide/data/val/state.json +13 -0
- peptide/generation.py +213 -0
- peptide/moo.py +284 -0
- peptide/new_coupling.py +226 -0
- peptide/peptide_classifiers.py +568 -0
- peptide/rectified_datasets/v1/dataset_dict.json +1 -0
- peptide/rectified_datasets/v1/test/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v1/test/dataset_info.json +28 -0
- peptide/rectified_datasets/v1/test/state.json +13 -0
- peptide/rectified_datasets/v1/train/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v1/train/dataset_info.json +28 -0
- peptide/rectified_datasets/v1/train/state.json +13 -0
- peptide/rectified_datasets/v1/validation/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v1/validation/dataset_info.json +28 -0
- peptide/rectified_datasets/v1/validation/state.json +13 -0
- peptide/rectified_datasets/v2/dataset_dict.json +1 -0
- peptide/rectified_datasets/v2/test/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v2/test/dataset_info.json +28 -0
- peptide/rectified_datasets/v2/test/state.json +13 -0
- peptide/rectified_datasets/v2/train/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v2/train/dataset_info.json +28 -0
- peptide/rectified_datasets/v2/train/state.json +13 -0
- peptide/rectified_datasets/v2/validation/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v2/validation/dataset_info.json +28 -0
- peptide/rectified_datasets/v2/validation/state.json +13 -0
- peptide/rectified_datasets/v3/dataset_dict.json +1 -0
- peptide/rectified_datasets/v3/test/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v3/test/dataset_info.json +28 -0
- peptide/rectified_datasets/v3/test/state.json +13 -0
- peptide/rectified_datasets/v3/train/data-00000-of-00001.arrow +3 -0
- peptide/rectified_datasets/v3/train/dataset_info.json +28 -0
- peptide/rectified_datasets/v3/train/state.json +13 -0
peptide/ckpt/PepReDi_base.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3b6437edee514bd7adb9aacea776f3dd97c59ebc7b4928b390eafdd87eaeb8c9
+size 344474053
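
These checkpoint entries are Git LFS pointers: the repository commits only the three-line pointer (spec version, SHA-256 oid, byte size), while the ~344 MB weights live in LFS storage. As a minimal sketch (not part of this commit), a pulled file can be checked against its pointer; the repo-root-relative path is an assumption:

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so multi-GB checkpoints don't have to fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk_size):
            h.update(block)
    return h.hexdigest()

# Assumed path; run after `git lfs pull` from the repository root.
digest = sha256_of("peptide/ckpt/PepReDi_base.pt")
assert digest == "3b6437edee514bd7adb9aacea776f3dd97c59ebc7b4928b390eafdd87eaeb8c9"
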
peptide/ckpt/PepReDi_v1.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b72e98ce0752e6c4805179b37f09c2e3824d22db0113e165988ee4017d6f3d39
+size 344457840
peptide/ckpt/PepReDi_v2.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6de709e0c6ce9356d37dfc9b6238249c129ee2fd30ad6d03da3b35d5f5a5f8ad
+size 344457840
peptide/ckpt/PepReDi_v3.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7456b4d156e922a4a499573df4a0cf228315c5540bc67b842af274f605561f79
+size 344457840
peptide/classifier_ckpt/best_model_half_life.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f80f1b20e90ba30503804c738aad4b3bb253424ff2e6e8a86c8e13a2fa1669f9
+size 2623795199
peptide/classifier_ckpt/best_model_hemolysis.json
ADDED
The diff for this file is too large to render.

peptide/classifier_ckpt/best_model_nonfouling.json
ADDED
The diff for this file is too large to render.

peptide/classifier_ckpt/best_model_solubility.json
ADDED
The diff for this file is too large to render.
peptide/classifier_ckpt/binding_affinity_pooled.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:91f60e417dfa64277e433b5bc841060d295b43f2d9c19b277b954ce447b44949
+size 211324073
peptide/classifier_ckpt/binding_affinity_unpooled.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fc28ae9f09b981b07547a773ca2e07f241cb08b3b8aa901e66627ff153f3aa8b
+size 2731670995
peptide/data/test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f3a4dde7bb38d2ae4aae44265f1beef5df22f36d21f17e6813641e090cc679c
+size 82440
peptide/data/test/dataset_info.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "List"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
peptide/data/test/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "ae4d0541bd157aeb",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
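
The data-*.arrow / dataset_info.json / state.json triple above is the on-disk layout that the `datasets` library writes via `Dataset.save_to_disk`; per dataset_info.json, each row holds `input_ids`, a list of int32 token IDs. A minimal loading sketch (assuming the repository root is the working directory):

from datasets import load_from_disk

# Each split directory (test/train/val) loads as a standalone Dataset.
test_split = load_from_disk("peptide/data/test")   # assumed path
print(test_split)                    # Dataset({features: ['input_ids'], num_rows: ...})
print(test_split[0]["input_ids"])    # one tokenized peptide as a list of int32 IDs
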
peptide/data/train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:646801034bcc9683b219f5d3195c4f4ce6551c5ecbc1b1bcbdfd8a7027d8a49e
+size 641784
peptide/data/train/dataset_info.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "List"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
peptide/data/train/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "7b63856b107c2d5c",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
peptide/data/val/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f3a4dde7bb38d2ae4aae44265f1beef5df22f36d21f17e6813641e090cc679c
+size 82440
peptide/data/val/dataset_info.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "input_ids": {
+      "feature": {
+        "dtype": "int32",
+        "_type": "Value"
+      },
+      "_type": "List"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
peptide/data/val/state.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "ae4d0541bd157aeb",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}
peptide/generation.py
ADDED
@@ -0,0 +1,213 @@
+import argparse
+import math
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+# --- Model Architecture ---
+def modulate(x, shift, scale):
+    """
+    Modulates the input tensor x with a shift and scale.
+    """
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds a continuous scalar timestep t in [0, 1] into a vector representation.
+    """
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(1, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+
+    def forward(self, t):
+        # t is shape (batch_size,), needs to be (batch_size, 1) for the Linear layer.
+        return self.mlp(t.unsqueeze(-1))
+
+class DiTBlock(nn.Module):
+    """
+    A single block of the Diffusion Transformer.
+    """
+    def __init__(self, hidden_size, n_heads):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_size, 4 * hidden_size),
+            nn.GELU(),
+            nn.Linear(4 * hidden_size, hidden_size)
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+        x = x + gate_msa.unsqueeze(1) * attn_output
+        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        mlp_output = self.mlp(x_norm2)
+        x = x + gate_mlp.unsqueeze(1) * mlp_output
+        return x
+
+class MDLM(nn.Module):
+    """
+    Masked Diffusion Language Model (MDLM) using a DiT backbone.
+    """
+    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.seq_len = seq_len
+        self.model_dim = model_dim
+        self.mask_token_id = vocab_size  # Use vocab_size as the ID for the mask token
+
+        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)  # +1 for the mask token
+        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+        self.time_embedder = TimestepEmbedder(model_dim)
+
+        self.transformer_blocks = nn.ModuleList([
+            DiTBlock(model_dim, n_heads) for _ in range(n_layers)
+        ])
+
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.lm_head = nn.Linear(model_dim, vocab_size)
+
+    def forward(self, x, t):
+        seq_len = x.shape[1]
+        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+        t_embed = self.time_embedder(t)
+        for block in self.transformer_blocks:
+            x_embed = block(x_embed, t_embed)
+        x_embed = self.final_norm(x_embed)
+        logits = self.lm_head(x_embed)
+        return logits
+
+# --- Generation Function ---
+
+def generate_samples(model, device, num_samples, seq_len, steps, temperature):
+    """
+    Generates samples by starting from a random sequence and progressively refining it.
+    """
+    model.eval()
+
+    # Start with a completely random sequence of tokens
+    shape = (num_samples, seq_len)
+    x = torch.randint(0, model.vocab_size, shape, dtype=torch.long, device=device)
+
+    # Cosine schedule determines how many tokens we *keep* from the previous step.
+    # It goes from 0 (keep none) to seq_len (keep all).
+    keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
+    keep_schedule = torch.round(keep_schedule).long()
+
+    with torch.no_grad():
+        progress_bar = tqdm(range(steps), desc="Generating Samples")
+        for i in progress_bar:
+            # Time `t` should go from 0 (pure noise) up to 1 (pure data)
+            t_continuous = torch.full((num_samples,), i / steps, device=device)
+
+            logits = model(x, t_continuous)
+
+            # Apply temperature scaling to control diversity
+            scaled_logits = logits / temperature
+            probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
+
+            # Sample a full new sequence from the model's prediction
+            sampled_tokens = torch.multinomial(probs.view(-1, model.vocab_size), 1).view(shape)
+
+            # For the last step, the new sample is our final result
+            if i == steps - 1:
+                x = sampled_tokens
+                break
+
+            # Determine which tokens from the *newly sampled sequence* to keep, based on confidence
+            confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+
+            # Find the indices of the most confident tokens to keep
+            num_to_keep = keep_schedule[i]
+            _, indices_to_keep = torch.topk(confidence, num_to_keep, largest=True, dim=-1)
+
+            # Create a mask for the tokens we are keeping
+            keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
+
+            # The next sequence `x` is a mix:
+            # - Where keep_mask is True, we use the new, confident sampled_tokens.
+            # - Where keep_mask is False, we keep the tokens from the previous step `x`.
+            x = torch.where(keep_mask, sampled_tokens, x)
+
+    return x
+
+# --- Main Execution ---
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    print(f"Loading checkpoint from {args.checkpoint}...")
+    try:
+        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+        model_args = checkpoint['args']
+    except FileNotFoundError:
+        print(f"Error: Checkpoint file not found at {args.checkpoint}")
+        return
+    except Exception as e:
+        print(f"Error loading checkpoint: {e}")
+        return
+
+    print("Initializing model...")
+    model = MDLM(
+        vocab_size=model_args.vocab_size,
+        seq_len=model_args.seq_len,
+        model_dim=model_args.model_dim,
+        n_heads=model_args.n_heads,
+        n_layers=model_args.n_layers
+    ).to(device)
+
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print("Model loaded successfully.")
+
+    gen_len = args.gen_len if args.gen_len is not None else model_args.seq_len
+    if gen_len > model_args.seq_len:
+        raise ValueError(f"Requested generation length ({gen_len}) is greater than the model's max length ({model_args.seq_len}).")
+    print(f"Generating sequences of length {gen_len}.")
+
+    generated_tokens = generate_samples(
+        model=model,
+        device=device,
+        num_samples=args.num_samples,
+        seq_len=gen_len,
+        steps=args.gen_steps,
+        temperature=args.temperature
+    )
+
+    print("Decoding and saving samples...")
+    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+    with open(args.output_file, 'w') as f:
+        for sample_tokens in generated_tokens:
+            sequence = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False)
+            clean_sequence = sequence.replace(" ", "")[5:-5]
+            f.write(clean_sequence + "\n")
+            print(clean_sequence)
+
+    print(f"Generation complete. {args.num_samples} sequences saved to {args.output_file}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate samples from a trained ReDi (MDLM) model starting from random noise.")
+
+    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
+    parser.add_argument("--num_samples", type=int, default=128, help="Number of samples to generate.")
+    parser.add_argument("--output_file", type=str, default="./generated_peptides.txt", help="File to save the generated peptide sequences.")
+    parser.add_argument("--gen_steps", type=int, default=16, help="Number of steps for the progressive refinement process.")
+    parser.add_argument("--gen_len", type=int, default=None, help="Desired length of the generated sequences. Defaults to the model's maximum trained length.")
+    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. >1 increases diversity, <1 decreases it.")
+
+    args = parser.parse_args()
+    main(args)
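
generation.py starts from uniformly random tokens and, over `--gen_steps` rounds, resamples the full sequence and keeps only the most confident positions, with the keep count following the cosine schedule. Worked numerically (the same arithmetic as `keep_schedule`, not script output): with steps=4 and seq_len=10, cos(linspace(pi/2, 0, 4)) * 10 = [0, 5, 8.66, 10], so successive rounds keep 0, 5, 9, and finally all 10 tokens. A typical invocation, assuming the script is run from the peptide/ directory against one of the committed checkpoints:

python generation.py \
    --checkpoint ckpt/PepReDi_base.pt \
    --num_samples 128 \
    --gen_steps 16 \
    --temperature 1.0 \
    --output_file ./generated_peptides.txt
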
peptide/moo.py
ADDED
@@ -0,0 +1,284 @@
+import argparse
+import math
+import random
+from collections import Counter
+import csv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from peptide_classifiers import *
+
+
+# --- Model Architecture (Must match the trained model) ---
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+class TimestepEmbedder(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(1, hidden_size, bias=True), nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+    def forward(self, t):
+        return self.mlp(t.unsqueeze(-1))
+
+class DiTBlock(nn.Module):
+    def __init__(self, hidden_size, n_heads):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
+            nn.Linear(4 * hidden_size, hidden_size)
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+        )
+    def forward(self, x, c):
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+        x = x + gate_msa.unsqueeze(1) * attn_output
+        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        mlp_output = self.mlp(x_norm2)
+        x = x + gate_mlp.unsqueeze(1) * mlp_output
+        return x
+
+class MDLM(nn.Module):
+    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.seq_len = seq_len
+        self.model_dim = model_dim
+        self.mask_token_id = vocab_size
+        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
+        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+        self.time_embedder = TimestepEmbedder(model_dim)
+        self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.lm_head = nn.Linear(model_dim, vocab_size)
+    def forward(self, x, t):
+        seq_len = x.shape[1]
+        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+        t_embed = self.time_embedder(t)
+        for block in self.transformer_blocks:
+            x_embed = block(x_embed, t_embed)
+        x_embed = self.final_norm(x_embed)
+        logits = self.lm_head(x_embed)
+        return logits
+
+class MOGGenerator:
+    def __init__(self, model, device, objectives, args):
+        self.model = model
+        self.device = device
+        self.objectives = objectives
+        self.args = args
+        self.num_objectives = len(objectives)
+
+    def _get_scores(self, x_batch):
+        """Calculates the normalized scores for a batch of sequences."""
+        scores = []
+        for obj_func in self.objectives:
+            scores.append(obj_func(x_batch.to(self.device)))
+        return torch.stack(scores, dim=0)
+
+    def _barker_g(self, u):
+        """Barker balancing function."""
+        return u / (1 + u)
+
+    def generate(self):
+        """Main generation loop."""
+        shape = (self.args.num_samples, self.args.gen_len + 2)
+        x = torch.randint(5, self.model.vocab_size, shape, dtype=torch.long, device=self.device)
+        x[:, 0] = 0
+        x[:, -1] = 2
+
+        if self.args.weights is None:
+            weights = torch.full((self.num_objectives,), 1 / self.num_objectives, device=self.device).view(-1, 1)
+        else:
+            weights = torch.tensor(self.args.weights, device=self.device).view(-1, 1)
+            if len(weights) != self.num_objectives:
+                raise ValueError("Number of weights must match number of objectives.")
+        print(f"Weights: {weights}")
+
+        if self.args.min_threshold is not None:
+            min_threshold = torch.tensor(self.args.min_threshold, device=self.device)
+        else:
+            min_threshold = None
+
+        total_optimization_steps = self.args.optimization_steps * self.args.gen_len
+
+        with torch.no_grad():
+            for t in tqdm(range(total_optimization_steps), desc="MOG Generation"):
+                # Anneal guidance strength
+                eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * (t / (total_optimization_steps - 1))
+                # eta_t = 0.5 * (self.args.eta_min + self.args.eta_max)
+                # Choose a random position to mutate
+                mut_idx = random.randint(1, self.args.gen_len)
+
+                # Determine the generation timestep
+                # We cycle through the timesteps to ensure all are visited
+                generation_step = t % self.args.optimization_steps
+                time_t = torch.full((self.args.num_samples,), (generation_step / self.args.optimization_steps), device=self.device)
+
+                # Get proposal distribution from ReDi model for the chosen position
+                logits = self.model(x, time_t)
+                probs = F.softmax(logits, dim=-1)
+                pos_probs = probs[:, mut_idx, :]
+                pos_probs[:, x[:, mut_idx]] = 0  # We don't evaluate the same token
+
+                # Prune candidate vocabulary using top-p sampling
+                sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True)
+                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+                remove_mask = cumulative_probs > self.args.top_p
+                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
+                remove_mask[..., 0] = 0
+
+                # Get the set of candidate tokens for each sample in the batch
+                candidate_tokens_list = []
+                for i in range(self.args.num_samples):
+                    sample_mask = remove_mask[i]
+                    candidates = sorted_indices[i, ~sample_mask]
+                    candidate_tokens_list.append(candidates)
+
+                # Get current scores
+                current_scores = self._get_scores(x)
+                w_current = torch.exp(eta_t * torch.min(weights * current_scores, dim=0).values)
+
+                # Evaluate all candidate tokens for each sample
+                final_proposal_tokens = []
+                for i in range(self.args.num_samples):
+                    candidates = candidate_tokens_list[i]
+                    candidates = torch.tensor([token for token in candidates if token not in [0, 1, 2, 3]], device=candidates.device)
+                    num_candidates = len(candidates)
+
+                    # Create a batch of proposed sequences for the current sample
+                    x_prop_batch = x[i].repeat(num_candidates, 1)
+                    x_prop_batch[:, mut_idx] = candidates
+
+                    # Evaluate all proposals
+                    proposal_scores = self._get_scores(x_prop_batch)
+                    proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values
+                    w_proposal = torch.exp(eta_t * proposal_s_omega)
+
+                    # Get ReDi probabilities for the candidates
+                    redi_probs = pos_probs[i, candidates]
+
+                    # Calculate unnormalized guided probabilities
+                    tilde_q = redi_probs * self._barker_g(w_proposal / w_current[i])
+
+                    # Normalize and sample the final token
+                    final_probs = tilde_q / (torch.sum(tilde_q) + 1e-9)
+
+                    index = torch.multinomial(final_probs, 1).item()
+                    if torch.sum(weights.squeeze(1) * proposal_scores[:, index]) >= torch.sum(weights.squeeze(1) * current_scores[:, i]):
+                        final_token = candidates[index]
+                        print(f"Previous Weighted Sum: {torch.sum(weights.squeeze(1) * current_scores[:, i])}")
+                        print(f"Previous Scores: {current_scores[:, i]}")
+
+                        print(f"New Weighted Sum: {torch.sum(weights.squeeze(1) * proposal_scores[:, index])}")
+                        print(f"New Scores: {proposal_scores[:, index]}")
+                    else:
+                        final_token = x[i][mut_idx]
+                        # final_token = candidates[index]
+
+                    final_proposal_tokens.append(final_token)
+
+                # Update the sequences with the chosen tokens
+                x[torch.arange(self.args.num_samples), mut_idx] = torch.stack(final_proposal_tokens)
+
+                scores = self._get_scores(x)
+
+        return x
+
+# --- Main Execution ---
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    target = args.target
+    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+    target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
+
+    affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/ReDi_discrete/peptides/classifier_ckpt/binding_affinity_unpooled.pt', device)
+    affinity_model = AffinityModel(affinity_predictor, target_sequence)
+    hemolysis_model = HemolysisModel(device=device)
+    nonfouling_model = NonfoulingModel(device=device)
+    solubility_model = SolubilityModel(device=device)
+    halflife_model = HalfLifeModel(device=device)
+
+    print(f"Loading checkpoint from {args.checkpoint}...")
+    try:
+        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+        model_args = checkpoint['args']
+    except Exception as e:
+        print(f"Error loading checkpoint: {e}")
+        return
+
+    print("Initializing model...")
+    model = MDLM(
+        vocab_size=model_args.vocab_size,
+        seq_len=model_args.seq_len,
+        model_dim=model_args.model_dim,
+        n_heads=model_args.n_heads,
+        n_layers=model_args.n_layers
+    ).to(device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print("Model loaded successfully.")
+
+    # List of all objective functions
+    OBJECTIVE_FUNCTIONS = [hemolysis_model, nonfouling_model, solubility_model, halflife_model, affinity_model]
+
+    mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)
+
+    hemolysis = []
+    nonfouling = []
+    solubility = []
+    halflife = []
+    affinity = []
+
+    for _ in range(args.num_batches):
+        generated_tokens = mog_generator.generate()
+        final_scores = mog_generator._get_scores(generated_tokens).detach().cpu().numpy()
+
+        with open(args.output_file, 'a', newline='') as f:
+            writer = csv.writer(f)
+
+            for i in range(args.num_samples):
+                sample_tokens = generated_tokens[i]
+                print(sample_tokens)
+                sequence_str = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False).replace(" ", "")[5:-5]
+
+                scores = final_scores[:, i]
+
+                writer.writerow([sequence_str] + scores.tolist())
+
+                print([sequence_str] + scores.tolist())
+
+    print("Generation complete.")
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).")
+
+    parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.")
+    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate.")
+    parser.add_argument("--num_batches", type=int, default=10, help="Number of batches to generate.")
+    parser.add_argument("--output_file", type=str, default="./mog_peptides.txt", help="File to save the generated sequences.")
+    parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.")
+    parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.")
+    parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).")
+    parser.add_argument("--min_threshold", type=float, nargs='+', required=False, help="Minimum threshold for the objectives (e.g., 0.2 0.2).")
+    parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.")
+    parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.")
+
+    parser.add_argument("--target", type=str, required=True)
+    args = parser.parse_args()
+    main(args)
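
moo.py runs a single-site guided search: each step picks one position, takes the model's top-p candidate tokens, scores every candidate sequence with the five property classifiers, and reweights the model's proposal by the Barker balancing function g(u) = u / (1 + u) applied to u = w_proposal / w_current, where w = exp(eta_t * min_k(weight_k * score_k)). Since g(1) = 0.5 while g(u) approaches 1 as u grows and 0 as u vanishes, improving candidates are upweighted without fully overriding the ReDi probabilities. A sketch of an invocation, assuming five weights in the order of OBJECTIVE_FUNCTIONS (hemolysis, nonfouling, solubility, half-life, affinity), with <TARGET_SEQUENCE> standing in for the target protein; note that main() also expects the hard-coded binding_affinity_unpooled.pt path to exist:

python moo.py \
    --checkpoint ckpt/PepReDi_base.pt \
    --target <TARGET_SEQUENCE> \
    --num_samples 10 \
    --gen_len 50 \
    --weights 0.2 0.2 0.2 0.2 0.2 \
    --output_file ./mog_peptides.txt
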
peptide/new_coupling.py
ADDED
@@ -0,0 +1,226 @@
+import argparse
+import math
+import os
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from datasets import Dataset, DatasetDict
+
+# --- Model Architecture (Must match the trained model) ---
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+class TimestepEmbedder(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(1, hidden_size, bias=True), nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+    def forward(self, t):
+        return self.mlp(t.unsqueeze(-1))
+
+class DiTBlock(nn.Module):
+    def __init__(self, hidden_size, n_heads):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
+            nn.Linear(4 * hidden_size, hidden_size)
+        )
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+        )
+    def forward(self, x, c):
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+        x = x + gate_msa.unsqueeze(1) * attn_output
+        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+        mlp_output = self.mlp(x_norm2)
+        x = x + gate_mlp.unsqueeze(1) * mlp_output
+        return x
+
+class MDLM(nn.Module):
+    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.seq_len = seq_len
+        self.model_dim = model_dim
+        self.mask_token_id = vocab_size
+        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
+        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+        self.time_embedder = TimestepEmbedder(model_dim)
+        self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.lm_head = nn.Linear(model_dim, vocab_size)
+    def forward(self, x, t):
+        seq_len = x.shape[1]
+        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+        t_embed = self.time_embedder(t)
+        for block in self.transformer_blocks:
+            x_embed = block(x_embed, t_embed)
+        x_embed = self.final_norm(x_embed)
+        logits = self.lm_head(x_embed)
+        return logits
+
+# --- Generation & Utility Functions ---
+
+def generate_x1_from_x0(model, device, x0_batch, steps, temperature):
+    model.eval()
+    x = x0_batch.clone()
+    num_samples, seq_len = x.shape
+    keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
+    keep_schedule = torch.round(keep_schedule).long()
+    with torch.no_grad():
+        for i in range(steps):
+            t_continuous = torch.full((num_samples,), 1.0 - (i / steps), device=device)
+            logits = model(x, t_continuous)
+            scaled_logits = logits / temperature
+            probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
+            sampled_tokens = torch.multinomial(probs.view(-1, model.vocab_size), 1).view(x.shape)
+            if i == steps - 1:
+                x = sampled_tokens
+                break
+            confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+            num_to_keep = keep_schedule[i]
+            _, indices_to_keep = torch.topk(confidence, num_to_keep, largest=True, dim=-1)
+            keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
+            x = torch.where(keep_mask, sampled_tokens, x)
+    return x
+
+def is_sample_valid(sample_x1):
+    """
+    Checks if special tokens [0, 1, 2, 3] appear in the middle of the sequence.
+    """
+    middle_sequence = sample_x1[1:-1]
+    invalid_tokens = {0, 1, 2, 3}
+    for token in middle_sequence:
+        if token in invalid_tokens:
+            return False
+    return True
+
+def create_prebatched_dataset(dataset, max_tokens_per_batch=500):
+    """
+    Groups samples into batches and restructures the dataset.
+    Each row in the new dataset is a complete batch.
+    """
+    # Group samples by their length
+    data_by_length = defaultdict(list)
+    for sample in dataset:
+        length = len(sample['input_ids_x1'])
+        data_by_length[length].append(sample)
+
+    # Create the actual batches
+    batched_data = {'input_ids_x0': [], 'input_ids_x1': []}
+    for length, samples in data_by_length.items():
+        samples_per_batch = max(1, max_tokens_per_batch // length)
+        for i in range(0, len(samples), samples_per_batch):
+            batch_samples = samples[i:i + samples_per_batch]
+
+            batch_x0 = [s['input_ids_x0'] for s in batch_samples]
+            batch_x1 = [s['input_ids_x1'] for s in batch_samples]
+
+            batched_data['input_ids_x0'].append(batch_x0)
+            batched_data['input_ids_x1'].append(batch_x1)
+
+    return Dataset.from_dict(batched_data)
+
+# --- Main Execution ---
+
+def main(args):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print(f"Using device: {device}")
+
+    print(f"Loading checkpoint from {args.checkpoint}...")
+    try:
+        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+        model_args = checkpoint['args']
+    except Exception as e:
+        print(f"Error loading checkpoint: {e}")
+        return
+
+    print("Initializing model...")
+    model = MDLM(
+        vocab_size=model_args.vocab_size,
+        seq_len=model_args.seq_len,
+        model_dim=model_args.model_dim,
+        n_heads=model_args.n_heads,
+        n_layers=model_args.n_layers
+    ).to(device)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print("Model loaded successfully.")
+
+    all_x0 = []
+    all_x1 = []
+
+    # 1. Generate samples for each length
+    for length in range(args.min_len, args.max_len + 1):
+        print(f"Generating {args.samples_per_len} valid samples for length {length}...")
+        valid_samples_count = 0
+        pbar = tqdm(total=args.samples_per_len)
+        while valid_samples_count < args.samples_per_len:
+            remaining = args.samples_per_len - valid_samples_count
+            batch_size = min(args.batch_size, remaining)
+
+            shape = (batch_size, length)
+            x0_batch = torch.randint(0, model.vocab_size, shape, dtype=torch.long, device=device)
+            x1_batch = generate_x1_from_x0(model, device, x0_batch, args.gen_steps, args.temperature)
+
+            # 2. Perform sanity check on each sample
+            for x0, x1 in zip(x0_batch, x1_batch):
+                if is_sample_valid(x1.tolist()):
+                    all_x0.append(x0.cpu().tolist())
+                    all_x1.append(x1.cpu().tolist())
+                    valid_samples_count += 1
+                    pbar.update(1)
+                    if valid_samples_count >= args.samples_per_len:
+                        break
+        pbar.close()
+
+    # 3. Create dataset and split
+    print("Splitting dataset...")
+    rectified_data = {'input_ids_x0': all_x0, 'input_ids_x1': all_x1}
+    dataset = Dataset.from_dict(rectified_data)
+    train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
+    valid_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)
+    final_dataset_dict = DatasetDict({
+        'train': train_test_split['train'],
+        'validation': valid_test_split['train'],
+        'test': valid_test_split['test']
+    })
+
+    # 4. Pre-batch each split
+    print("Pre-batching splits...")
+    batched_dataset_dict = DatasetDict()
+    for split_name, split_dataset in final_dataset_dict.items():
+        print(f"Processing {split_name} split...")
+        batched_dataset_dict[split_name] = create_prebatched_dataset(split_dataset)
+
+    # 5. Save the final dataset
+    output_path = f"{args.output_path}/v{args.version}"
+    print(f"Saving new batched dataset to {output_path}...")
+    batched_dataset_dict.save_to_disk(output_path)
+
+    print("Rectification complete.")
+    print(f"Train on this by updating your training script's dataset path to '{output_path}'.")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a rectified dataset with variable lengths and pre-batching.")
+
+    parser.add_argument("--checkpoint", type=str, required=True)
+    parser.add_argument("--output_path", type=str, default="./rectified_datasets")
+    parser.add_argument("--version", type=str, default='1')
+    parser.add_argument("--samples_per_len", type=int, default=10000)
+    parser.add_argument("--min_len", type=int, default=6)
+    parser.add_argument("--max_len", type=int, default=49)
+    parser.add_argument("--gen_steps", type=int, default=16)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--batch_size", type=int, default=128)
+
+    args = parser.parse_args()
+    main(args)
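
new_coupling.py builds the (x0, x1) coupling behind the rectified_datasets/v* folders in this commit: it draws random x0 noise at each length from min_len to max_len, runs the sampler to obtain x1, keeps only pairs whose x1 has no special tokens (IDs 0-3) in the interior, splits 80/10/10 into train/validation/test, and pre-batches rows of equal length (~500 tokens per batch). A sketch of the kind of invocation that would produce rectified_datasets/v1, assuming it runs from the peptide/ directory:

python new_coupling.py \
    --checkpoint ckpt/PepReDi_base.pt \
    --output_path ./rectified_datasets \
    --version 1 \
    --samples_per_len 10000 \
    --min_len 6 \
    --max_len 49
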
peptide/peptide_classifiers.py
ADDED
@@ -0,0 +1,568 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+import pytorch_lightning as pl
+import time
+from transformers import AutoModel, AutoConfig, AutoTokenizer
+import xgboost as xgb
+import esm
+
+class UnpooledBindingPredictor(nn.Module):
+    def __init__(self,
+                 esm_model_name="facebook/esm2_t33_650M_UR50D",
+                 hidden_dim=512,
+                 kernel_sizes=[3, 5, 7],
+                 n_heads=8,
+                 n_layers=3,
+                 dropout=0.1,
+                 freeze_esm=True):
+        super().__init__()
+
+        # Define binding thresholds
+        self.tight_threshold = 7.5  # Kd/Ki/IC50 ≤ ~30nM
+        self.weak_threshold = 6.0   # Kd/Ki/IC50 > 1μM
+
+        # Load ESM model for computing embeddings on the fly
+        self.esm_model = AutoModel.from_pretrained(esm_model_name)
+        self.config = AutoConfig.from_pretrained(esm_model_name)
+
+        # Freeze ESM parameters if needed
+        if freeze_esm:
+            for param in self.esm_model.parameters():
+                param.requires_grad = False
+
+        # Get ESM hidden size
+        esm_dim = self.config.hidden_size
+
+        # Output channels for CNN layers
+        output_channels_per_kernel = 64
+
+        # CNN layers for handling variable length sequences
+        self.protein_conv_layers = nn.ModuleList([
+            nn.Conv1d(
+                in_channels=esm_dim,
+                out_channels=output_channels_per_kernel,
+                kernel_size=k,
+                padding='same'
+            ) for k in kernel_sizes
+        ])
+
+        self.binder_conv_layers = nn.ModuleList([
+            nn.Conv1d(
+                in_channels=esm_dim,
+                out_channels=output_channels_per_kernel,
+                kernel_size=k,
+                padding='same'
+            ) for k in kernel_sizes
+        ])
+
+        # Calculate total features after convolution and pooling
+        total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2
+
+        # Project to same dimension after CNN processing
+        self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
+        self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)
+
+        self.protein_norm = nn.LayerNorm(hidden_dim)
+        self.binder_norm = nn.LayerNorm(hidden_dim)
+
+        # Cross attention blocks with layer norm
+        self.cross_attention_layers = nn.ModuleList([
+            nn.ModuleDict({
+                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                'norm1': nn.LayerNorm(hidden_dim),
+                'ffn': nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim * 4),
+                    nn.ReLU(),
+                    nn.Dropout(dropout),
+                    nn.Linear(hidden_dim * 4, hidden_dim)
+                ),
+                'norm2': nn.LayerNorm(hidden_dim)
+            }) for _ in range(n_layers)
+        ])
+
+        # Prediction heads
+        self.shared_head = nn.Sequential(
+            nn.Linear(hidden_dim * 2, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
+
+        # Regression head
+        self.regression_head = nn.Linear(hidden_dim, 1)
+
+        # Classification head (3 classes: tight, medium, loose binding)
+        self.classification_head = nn.Linear(hidden_dim, 3)
+
+    def get_binding_class(self, affinity):
+        """Convert affinity values to class indices
+        0: tight binding (>= 7.5)
+        1: medium binding (6.0-7.5)
+        2: weak binding (< 6.0)
+        """
+        if isinstance(affinity, torch.Tensor):
+            tight_mask = affinity >= self.tight_threshold
+            weak_mask = affinity < self.weak_threshold
+            medium_mask = ~(tight_mask | weak_mask)
+
+            classes = torch.zeros_like(affinity, dtype=torch.long)
+            classes[medium_mask] = 1
+            classes[weak_mask] = 2
+            return classes
+        else:
+            if affinity >= self.tight_threshold:
+                return 0  # tight binding
+            elif affinity < self.weak_threshold:
+                return 2  # weak binding
+            else:
+                return 1  # medium binding
+
+    def compute_embeddings(self, input_ids, attention_mask=None):
+        """Compute ESM embeddings on the fly"""
+        esm_outputs = self.esm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            return_dict=True
+        )
+
+        # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
+        return esm_outputs.last_hidden_state
+
+    def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
+        """Process a sequence through CNN layers and pooling"""
+        # Transpose for CNN: [batch_size, hidden_size, seq_length]
+        x = unpooled_emb.transpose(1, 2)
+
+        # Apply CNN layers and collect outputs
+        conv_outputs = []
+        for conv in conv_layers:
+            conv_out = F.relu(conv(x))
+            conv_outputs.append(conv_out)
+
+        # Concatenate along channel dimension
+        conv_output = torch.cat(conv_outputs, dim=1)
+
+        # Global pooling (both max and average)
+        # If attention mask is provided, use it to create a proper mask for pooling
+        if attention_mask is not None:
+            # Create a mask for pooling (1 for valid positions, 0 for padding)
+            # Expand mask to match conv_output channels
+            expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
+
+            # Apply mask (set padding to large negative value for max pooling)
+            masked_output = conv_output.clone()
+            masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
+
+            # Max pooling along sequence dimension
+            max_pooled = torch.max(masked_output, dim=2)[0]
+
+            # Average pooling (sum divided by number of valid positions)
+            sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
+            valid_positions = torch.sum(expanded_mask, dim=2)
+            valid_positions = torch.clamp(valid_positions, min=1.0)  # Avoid division by zero
+            avg_pooled = sum_pooled / valid_positions
+        else:
+            # If no mask, use standard pooling
+            max_pooled = torch.max(conv_output, dim=2)[0]
+            avg_pooled = torch.mean(conv_output, dim=2)
+
+        # Concatenate the pooled features
+        pooled = torch.cat([max_pooled, avg_pooled], dim=1)
+
+        return pooled
+
+    def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
+        # Compute embeddings on the fly using the ESM model
+        protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
+        binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
+
+        # Process protein and binder sequences through CNN layers
+        protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
+        binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
+
+        # Project to same dimension
+        protein = self.protein_norm(self.protein_projection(protein_features))
+        binder = self.binder_norm(self.binder_projection(binder_features))
+
+        # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
+        protein = protein.unsqueeze(0)
+        binder = binder.unsqueeze(0)
+
+        # Cross attention layers
+        for layer in self.cross_attention_layers:
+            # Protein attending to binder
+            attended_protein = layer['attention'](
+                protein, binder, binder
+            )[0]
+            protein = layer['norm1'](protein + attended_protein)
+            protein = layer['norm2'](protein + layer['ffn'](protein))
+
+            # Binder attending to protein
+            attended_binder = layer['attention'](
+                binder, protein, protein
+            )[0]
+            binder = layer['norm1'](binder + attended_binder)
+            binder = layer['norm2'](binder + layer['ffn'](binder))
+
+        # Remove sequence dimension
+        protein_pool = protein.squeeze(0)
+        binder_pool = binder.squeeze(0)
+
+        # Concatenate both representations
+        combined = torch.cat([protein_pool, binder_pool], dim=-1)
+
+        # Shared features
+        shared_features = self.shared_head(combined)
+
+        regression_output = self.regression_head(shared_features)
+        # classification_logits = self.classification_head(shared_features)
+
+        # return regression_output, classification_logits
+        return regression_output
+
+class ImprovedBindingPredictor(nn.Module):
+    def __init__(self,
+                 esm_dim=1280,
|
| 226 |
+
smiles_dim=1280,
|
| 227 |
+
hidden_dim=512,
|
| 228 |
+
n_heads=8,
|
| 229 |
+
n_layers=5,
|
| 230 |
+
dropout=0.1):
|
| 231 |
+
super().__init__()
|
| 232 |
+
|
| 233 |
+
# Define binding thresholds
|
| 234 |
+
self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
|
| 235 |
+
self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
|
| 236 |
+
|
| 237 |
+
# Project to same dimension
|
| 238 |
+
self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
|
| 239 |
+
self.protein_projection = nn.Linear(esm_dim, hidden_dim)
|
| 240 |
+
self.protein_norm = nn.LayerNorm(hidden_dim)
|
| 241 |
+
self.smiles_norm = nn.LayerNorm(hidden_dim)
|
| 242 |
+
|
| 243 |
+
# Cross attention blocks with layer norm
|
| 244 |
+
self.cross_attention_layers = nn.ModuleList([
|
| 245 |
+
nn.ModuleDict({
|
| 246 |
+
'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
|
| 247 |
+
'norm1': nn.LayerNorm(hidden_dim),
|
| 248 |
+
'ffn': nn.Sequential(
|
| 249 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
| 250 |
+
nn.ReLU(),
|
| 251 |
+
nn.Dropout(dropout),
|
| 252 |
+
nn.Linear(hidden_dim * 4, hidden_dim)
|
| 253 |
+
),
|
| 254 |
+
'norm2': nn.LayerNorm(hidden_dim)
|
| 255 |
+
}) for _ in range(n_layers)
|
| 256 |
+
])
|
| 257 |
+
|
| 258 |
+
# Prediction heads
|
| 259 |
+
self.shared_head = nn.Sequential(
|
| 260 |
+
nn.Linear(hidden_dim * 2, hidden_dim),
|
| 261 |
+
nn.ReLU(),
|
| 262 |
+
nn.Dropout(dropout),
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Regression head
|
| 266 |
+
self.regression_head = nn.Linear(hidden_dim, 1)
|
| 267 |
+
|
| 268 |
+
# Classification head (3 classes: tight, medium, loose binding)
|
| 269 |
+
self.classification_head = nn.Linear(hidden_dim, 3)
|
| 270 |
+
|
| 271 |
+
def get_binding_class(self, affinity):
|
| 272 |
+
"""Convert affinity values to class indices
|
| 273 |
+
0: tight binding (>= 7.5)
|
| 274 |
+
1: medium binding (6.0-7.5)
|
| 275 |
+
2: weak binding (< 6.0)
|
| 276 |
+
"""
|
| 277 |
+
if isinstance(affinity, torch.Tensor):
|
| 278 |
+
tight_mask = affinity >= self.tight_threshold
|
| 279 |
+
weak_mask = affinity < self.weak_threshold
|
| 280 |
+
medium_mask = ~(tight_mask | weak_mask)
|
| 281 |
+
|
| 282 |
+
classes = torch.zeros_like(affinity, dtype=torch.long)
|
| 283 |
+
classes[medium_mask] = 1
|
| 284 |
+
classes[weak_mask] = 2
|
| 285 |
+
return classes
|
| 286 |
+
else:
|
| 287 |
+
if affinity >= self.tight_threshold:
|
| 288 |
+
return 0 # tight binding
|
| 289 |
+
elif affinity < self.weak_threshold:
|
| 290 |
+
return 2 # weak binding
|
| 291 |
+
else:
|
| 292 |
+
return 1 # medium binding
|
| 293 |
+
|
| 294 |
+
def forward(self, protein_emb, binder_emb):
|
| 295 |
+
|
| 296 |
+
protein = self.protein_norm(self.protein_projection(protein_emb))
|
| 297 |
+
smiles = self.smiles_norm(self.smiles_projection(binder_emb))
|
| 298 |
+
|
| 299 |
+
protein = protein.transpose(0, 1)
|
| 300 |
+
smiles = smiles.transpose(0, 1)
|
| 301 |
+
|
| 302 |
+
# Cross attention layers
|
| 303 |
+
for layer in self.cross_attention_layers:
|
| 304 |
+
# Protein attending to SMILES
|
| 305 |
+
attended_protein = layer['attention'](
|
| 306 |
+
protein, smiles, smiles
|
| 307 |
+
)[0]
|
| 308 |
+
protein = layer['norm1'](protein + attended_protein)
|
| 309 |
+
protein = layer['norm2'](protein + layer['ffn'](protein))
|
| 310 |
+
|
| 311 |
+
# SMILES attending to protein
|
| 312 |
+
attended_smiles = layer['attention'](
|
| 313 |
+
smiles, protein, protein
|
| 314 |
+
)[0]
|
| 315 |
+
smiles = layer['norm1'](smiles + attended_smiles)
|
| 316 |
+
smiles = layer['norm2'](smiles + layer['ffn'](smiles))
|
| 317 |
+
|
| 318 |
+
# Get sequence-level representations
|
| 319 |
+
protein_pool = torch.mean(protein, dim=0)
|
| 320 |
+
smiles_pool = torch.mean(smiles, dim=0)
|
| 321 |
+
|
| 322 |
+
# Concatenate both representations
|
| 323 |
+
combined = torch.cat([protein_pool, smiles_pool], dim=-1)
|
| 324 |
+
|
| 325 |
+
# Shared features
|
| 326 |
+
shared_features = self.shared_head(combined)
|
| 327 |
+
|
| 328 |
+
regression_output = self.regression_head(shared_features)
|
| 329 |
+
|
| 330 |
+
return regression_output
|
| 331 |
+
|
| 332 |
+
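# Shape-convention note (sketch, not part of the original file): the
# cross-attention blocks above use nn.MultiheadAttention with its default
# (seq_len, batch, hidden) layout, which is why inputs are unsqueeze(0)'d
# or transpose(0, 1)'d before the attention loop. Minimal shape check:
#
#     attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
#     protein = torch.randn(1, 4, 512)  # (seq=1, batch=4, hidden)
#     out, _ = attn(protein, protein, protein)
#     assert out.shape == (1, 4, 512)
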
class PooledAffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(PooledAffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence
        self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
        for param in self.esm_model.parameters():
            param.requires_grad = False

    def compute_embeddings(self, input_ids, attention_mask=None):
        """Compute ESM embeddings on the fly"""
        esm_outputs = self.esm_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
        return esm_outputs.last_hidden_state

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)

        protein_emb = self.compute_embeddings(input_ids=target_sequence)
        binder_emb = self.compute_embeddings(input_ids=x)
        return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)


class AffinityModel(nn.Module):
    def __init__(self, affinity_predictor, target_sequence):
        super(AffinityModel, self).__init__()
        self.affinity_predictor = affinity_predictor
        self.target_sequence = target_sequence

    def forward(self, x):
        target_sequence = self.target_sequence.repeat(x.shape[0], 1)
        affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
        return affinity / 10


class HemolysisModel:
    def __init__(self, device):
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences"""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.ones(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        probs = self.predictor.predict(features)
        # return the probability of it being not hemolytic
        return torch.from_numpy(scores - probs).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class NonfoulingModel:
    def __init__(self, device):
        # change model path
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences"""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class SolubilityModel:
    def __init__(self, device):
        # change model path
        self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')

        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        self.model.eval()

        self.device = device

    def generate_embeddings(self, sequences):
        """Generate ESM embeddings for protein sequences"""
        with torch.no_grad():
            embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
            embeddings = embeddings.cpu().numpy()

        return embeddings

    def get_scores(self, input_seqs: list):
        scores = np.zeros(len(input_seqs))
        features = self.generate_embeddings(input_seqs)

        if len(features) == 0:
            return scores

        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        features = xgb.DMatrix(features)

        scores = self.predictor.predict(features)
        return torch.from_numpy(scores).to(self.device)

    def __call__(self, input_seqs: list):
        scores = self.get_scores(input_seqs)
        return scores


class PeptideCNN(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)  # For regression/classification

        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state
        # x shape: (B, L, input_dim)
        x = x.permute(0, 2, 1)  # Reshape to (B, input_dim, L) for Conv1d
        x = nn.functional.relu(self.conv1(x))
        x = self.dropout(x)
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)  # Reshape back to (B, L, hidden_dims[1])

        # Global average pooling over the sequence dimension (L)
        x = x.mean(dim=1)  # Shape: (B, hidden_dims[1])

        features = self.fc(x)  # features shape: (B, output_dim)
        if return_features:
            return features
        return self.predictor(features)  # Output shape: (B, 1)


class HalfLifeModel:
    def __init__(self, device):
        input_dim = 1280
        hidden_dims = [input_dim // 2, input_dim // 4]
        output_dim = input_dim // 8
        dropout_rate = 0.3
        self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
        self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
        self.model.eval()

    def __call__(self, x):
        prediction = self.model(x, return_features=False)
        half_life = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)

        return half_life / 2


def load_bindevaluator(checkpoint_path, device):
    bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
    bindevaluator.eval()
    for param in bindevaluator.parameters():
        param.requires_grad = False

    return bindevaluator


def load_pooled_affinity_predictor(checkpoint_path, device):
    """Load trained model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = ImprovedBindingPredictor().to(device)

    # Load the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set to evaluation mode

    return model


def load_affinity_predictor(checkpoint_path, device):
    """Load trained model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = UnpooledBindingPredictor(
        esm_model_name="facebook/esm2_t33_650M_UR50D",
        hidden_dim=384,
        kernel_sizes=[3, 5, 7],
        n_heads=8,
        n_layers=4,
        dropout=0.14561457009902096,
        freeze_esm=True
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
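For orientation, a minimal, hypothetical usage sketch for the scorers defined above (the sequences here are made up, and passing pre-tokenized ESM-2 input_ids is an assumption based on how the wrappers call the ESM model):

    import torch
    from transformers import AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    # Hypothetical target protein and candidate peptide binder
    target_ids = tokenizer("MKTAYIAKQRQISFVK", return_tensors="pt").input_ids.to(device)
    peptide_ids = tokenizer("GLFDIIKKIAESF", return_tensors="pt").input_ids.to(device)

    # Affinity: embeddings are computed inside UnpooledBindingPredictor
    predictor = load_affinity_predictor("./classifier_ckpt/binding_affinity_unpooled.pt", device)
    affinity = AffinityModel(predictor, target_ids)
    print(affinity(peptide_ids))  # predicted affinity, scaled by 1/10

    # Auxiliary property scores, each roughly in [0, 1]
    print(HemolysisModel(device)(peptide_ids))   # probability of being non-hemolytic
    print(SolubilityModel(device)(peptide_ids))  # probability of being soluble
    print(HalfLifeModel(device)(peptide_ids))    # clamped half-life prediction / 2
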
peptide/rectified_datasets/v1/dataset_dict.json
ADDED
@@ -0,0 +1 @@
{"splits": ["train", "validation", "test"]}

peptide/rectified_datasets/v1/test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:087edd095dcd714192f1d4ef341b1894bcee3fb03d5453c2b04d8ce031589318
size 19749472

peptide/rectified_datasets/v1/test/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v1/test/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "118d550fe7101754",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}

peptide/rectified_datasets/v1/train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b4af46986df1e635f0dff8e3f00c4d4c06aac26161836ca694299d1cc0bd20f
size 157859216

peptide/rectified_datasets/v1/train/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v1/train/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "a5ddb0c42fb68c3f",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}

peptide/rectified_datasets/v1/validation/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:409939a66eebfac56a884440065ec2e6cd1b81632fede0d4e2156cd56600a2b8
size 19725216

peptide/rectified_datasets/v1/validation/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v1/validation/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "3a37666e1156a9e6",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}
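Each rectified_datasets version is a standard Hugging Face `datasets` folder: a DatasetDict with train/validation/test splits, where every row holds two nested token-id lists, input_ids_x0 and input_ids_x1 (presumably the coupled source/target pairs used to train the corresponding rectified PepReDi checkpoint). A minimal loading sketch, assuming only the `datasets` library:

    from datasets import load_from_disk

    ds = load_from_disk("peptide/rectified_datasets/v1")
    print(ds)  # DatasetDict({'train': ..., 'validation': ..., 'test': ...})

    row = ds["test"][0]
    # Both columns are List(List(int64)) per dataset_info.json
    print(len(row["input_ids_x0"]), len(row["input_ids_x1"]))
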
peptide/rectified_datasets/v2/dataset_dict.json
ADDED
@@ -0,0 +1 @@
{"splits": ["train", "validation", "test"]}

peptide/rectified_datasets/v2/test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd85592cab50715eb0268f061670b7a98d3871a7aa1baefd40c6a23055f489d6
size 19749472

peptide/rectified_datasets/v2/test/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v2/test/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "03f0e67bb58fcf47",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}

peptide/rectified_datasets/v2/train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:26bea1ec7f7db609a4949a6209bc5536003f7ab169f1e03d5a256db81af434bd
size 157859216

peptide/rectified_datasets/v2/train/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v2/train/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "c41975ecd76982be",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}

peptide/rectified_datasets/v2/validation/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff40262e105e2c748e8e5ca8f89b5882af643f099227d65e6c4b13da8d328094
size 19725216

peptide/rectified_datasets/v2/validation/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v2/validation/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "39ddf61d20fce77a",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}
peptide/rectified_datasets/v3/dataset_dict.json
ADDED
@@ -0,0 +1 @@
{"splits": ["train", "validation", "test"]}

peptide/rectified_datasets/v3/test/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d6169a4b613ec0ac0c7fdcda92ec031f6eab8a2d0cf451dd744b780ea096825
size 19749472

peptide/rectified_datasets/v3/test/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v3/test/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "f6aed185a066dd98",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}

peptide/rectified_datasets/v3/train/data-00000-of-00001.arrow
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a7828b0d21a1d8103347e7bdc4438caf3dd9f1d1bf1158fcec31e1de0d9bcea
size 157859216

peptide/rectified_datasets/v3/train/dataset_info.json
ADDED
@@ -0,0 +1,28 @@
{
  "citation": "",
  "description": "",
  "features": {
    "input_ids_x0": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "input_ids_x1": {
      "feature": {
        "feature": {
          "dtype": "int64",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": ""
}

peptide/rectified_datasets/v3/train/state.json
ADDED
@@ -0,0 +1,13 @@
{
  "_data_files": [
    {
      "filename": "data-00000-of-00001.arrow"
    }
  ],
  "_fingerprint": "448b61e862d72291",
  "_format_columns": null,
  "_format_kwargs": {},
  "_format_type": null,
  "_output_all_columns": false,
  "_split": null
}