Tong Chen committed
Commit d2693e0 · 1 Parent(s): f9d1b81
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. peptide/ckpt/PepReDi_base.pt +3 -0
  2. peptide/ckpt/PepReDi_v1.pt +3 -0
  3. peptide/ckpt/PepReDi_v2.pt +3 -0
  4. peptide/ckpt/PepReDi_v3.pt +3 -0
  5. peptide/classifier_ckpt/best_model_half_life.pth +3 -0
  6. peptide/classifier_ckpt/best_model_hemolysis.json +0 -0
  7. peptide/classifier_ckpt/best_model_nonfouling.json +0 -0
  8. peptide/classifier_ckpt/best_model_solubility.json +0 -0
  9. peptide/classifier_ckpt/binding_affinity_pooled.pt +3 -0
  10. peptide/classifier_ckpt/binding_affinity_unpooled.pt +3 -0
  11. peptide/data/test/data-00000-of-00001.arrow +3 -0
  12. peptide/data/test/dataset_info.json +15 -0
  13. peptide/data/test/state.json +13 -0
  14. peptide/data/train/data-00000-of-00001.arrow +3 -0
  15. peptide/data/train/dataset_info.json +15 -0
  16. peptide/data/train/state.json +13 -0
  17. peptide/data/val/data-00000-of-00001.arrow +3 -0
  18. peptide/data/val/dataset_info.json +15 -0
  19. peptide/data/val/state.json +13 -0
  20. peptide/generation.py +213 -0
  21. peptide/moo.py +284 -0
  22. peptide/new_coupling.py +226 -0
  23. peptide/peptide_classifiers.py +568 -0
  24. peptide/rectified_datasets/v1/dataset_dict.json +1 -0
  25. peptide/rectified_datasets/v1/test/data-00000-of-00001.arrow +3 -0
  26. peptide/rectified_datasets/v1/test/dataset_info.json +28 -0
  27. peptide/rectified_datasets/v1/test/state.json +13 -0
  28. peptide/rectified_datasets/v1/train/data-00000-of-00001.arrow +3 -0
  29. peptide/rectified_datasets/v1/train/dataset_info.json +28 -0
  30. peptide/rectified_datasets/v1/train/state.json +13 -0
  31. peptide/rectified_datasets/v1/validation/data-00000-of-00001.arrow +3 -0
  32. peptide/rectified_datasets/v1/validation/dataset_info.json +28 -0
  33. peptide/rectified_datasets/v1/validation/state.json +13 -0
  34. peptide/rectified_datasets/v2/dataset_dict.json +1 -0
  35. peptide/rectified_datasets/v2/test/data-00000-of-00001.arrow +3 -0
  36. peptide/rectified_datasets/v2/test/dataset_info.json +28 -0
  37. peptide/rectified_datasets/v2/test/state.json +13 -0
  38. peptide/rectified_datasets/v2/train/data-00000-of-00001.arrow +3 -0
  39. peptide/rectified_datasets/v2/train/dataset_info.json +28 -0
  40. peptide/rectified_datasets/v2/train/state.json +13 -0
  41. peptide/rectified_datasets/v2/validation/data-00000-of-00001.arrow +3 -0
  42. peptide/rectified_datasets/v2/validation/dataset_info.json +28 -0
  43. peptide/rectified_datasets/v2/validation/state.json +13 -0
  44. peptide/rectified_datasets/v3/dataset_dict.json +1 -0
  45. peptide/rectified_datasets/v3/test/data-00000-of-00001.arrow +3 -0
  46. peptide/rectified_datasets/v3/test/dataset_info.json +28 -0
  47. peptide/rectified_datasets/v3/test/state.json +13 -0
  48. peptide/rectified_datasets/v3/train/data-00000-of-00001.arrow +3 -0
  49. peptide/rectified_datasets/v3/train/dataset_info.json +28 -0
  50. peptide/rectified_datasets/v3/train/state.json +13 -0
peptide/ckpt/PepReDi_base.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b6437edee514bd7adb9aacea776f3dd97c59ebc7b4928b390eafdd87eaeb8c9
+ size 344474053
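
Each checkpoint in this commit is stored as a Git LFS pointer rather than the raw weights: three plain-text lines giving the spec version, the sha256 object ID, and the payload size in bytes (the actual .pt payload is fetched with `git lfs pull`). A minimal Python sketch for reading that pointer metadata; the helper name is ours, not part of the repo:

def parse_lfs_pointer(path):
    # A Git LFS pointer is plain text: "version ...", "oid sha256:...", "size ...".
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

# parse_lfs_pointer("peptide/ckpt/PepReDi_base.pt")
# -> {'version': 'https://git-lfs.github.com/spec/v1',
#     'oid': 'sha256:3b6437ed...', 'size': '344474053'}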
peptide/ckpt/PepReDi_v1.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b72e98ce0752e6c4805179b37f09c2e3824d22db0113e165988ee4017d6f3d39
+ size 344457840
peptide/ckpt/PepReDi_v2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6de709e0c6ce9356d37dfc9b6238249c129ee2fd30ad6d03da3b35d5f5a5f8ad
+ size 344457840
peptide/ckpt/PepReDi_v3.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7456b4d156e922a4a499573df4a0cf228315c5540bc67b842af274f605561f79
+ size 344457840
peptide/classifier_ckpt/best_model_half_life.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f80f1b20e90ba30503804c738aad4b3bb253424ff2e6e8a86c8e13a2fa1669f9
+ size 2623795199
peptide/classifier_ckpt/best_model_hemolysis.json ADDED
The diff for this file is too large to render. See raw diff
 
peptide/classifier_ckpt/best_model_nonfouling.json ADDED
The diff for this file is too large to render. See raw diff
 
peptide/classifier_ckpt/best_model_solubility.json ADDED
The diff for this file is too large to render. See raw diff
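
The three best_model_*.json files above are XGBoost boosters saved in XGBoost's native JSON format; peptide_classifiers.py (further down in this diff) loads them exactly this way. A minimal loading sketch, assuming the repo-relative path:

import xgboost as xgb

booster = xgb.Booster(model_file="peptide/classifier_ckpt/best_model_hemolysis.json")
# booster.predict(xgb.DMatrix(features)) then yields per-sequence probabilities.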
 
peptide/classifier_ckpt/binding_affinity_pooled.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91f60e417dfa64277e433b5bc841060d295b43f2d9c19b277b954ce447b44949
+ size 211324073
peptide/classifier_ckpt/binding_affinity_unpooled.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fc28ae9f09b981b07547a773ca2e07f241cb08b3b8aa901e66627ff153f3aa8b
+ size 2731670995
peptide/data/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f3a4dde7bb38d2ae4aae44265f1beef5df22f36d21f17e6813641e090cc679c
+ size 82440
peptide/data/test/dataset_info.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids": {
+       "feature": {
+         "dtype": "int32",
+         "_type": "Value"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/data/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "ae4d0541bd157aeb",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/data/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:646801034bcc9683b219f5d3195c4f4ce6551c5ecbc1b1bcbdfd8a7027d8a49e
+ size 641784
peptide/data/train/dataset_info.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids": {
+       "feature": {
+         "dtype": "int32",
+         "_type": "Value"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/data/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "7b63856b107c2d5c",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/data/val/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f3a4dde7bb38d2ae4aae44265f1beef5df22f36d21f17e6813641e090cc679c
+ size 82440
peptide/data/val/dataset_info.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids": {
+       "feature": {
+         "dtype": "int32",
+         "_type": "Value"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/data/val/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "ae4d0541bd157aeb",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
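
The three split directories above follow the layout written by the `datasets` library's save_to_disk (one .arrow shard plus dataset_info.json and state.json, with a single int32 "input_ids" list feature), so a split should load back directly. A minimal sketch, assuming the repo-relative path:

from datasets import load_from_disk

train = load_from_disk("peptide/data/train")
print(train)                          # Dataset({features: ['input_ids'], num_rows: ...})
print(train[0]["input_ids"][:10])     # tokenized peptide prefix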
peptide/generation.py ADDED
@@ -0,0 +1,213 @@
+ import argparse
+ import math
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ # --- Model Architecture ---
+ def modulate(x, shift, scale):
+     """
+     Modulates the input tensor x with a shift and scale.
+     """
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds a continuous scalar timestep t in [0, 1] into a vector representation.
+     """
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(1, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+
+     def forward(self, t):
+         # t is shape (batch_size,); it needs to be (batch_size, 1) for the Linear layer.
+         return self.mlp(t.unsqueeze(-1))
+
+ class DiTBlock(nn.Module):
+     """
+     A single block of the Diffusion Transformer.
+     """
+     def __init__(self, hidden_size, n_heads):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_size, 4 * hidden_size),
+             nn.GELU(),
+             nn.Linear(4 * hidden_size, hidden_size)
+         )
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+         attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+         x = x + gate_msa.unsqueeze(1) * attn_output
+         x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+         mlp_output = self.mlp(x_norm2)
+         x = x + gate_mlp.unsqueeze(1) * mlp_output
+         return x
+
+ class MDLM(nn.Module):
+     """
+     Masked Diffusion Language Model (MDLM) using a DiT backbone.
+     """
+     def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.seq_len = seq_len
+         self.model_dim = model_dim
+         self.mask_token_id = vocab_size  # Use vocab_size as the ID for the mask token
+
+         self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)  # +1 for the mask token
+         self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+         self.time_embedder = TimestepEmbedder(model_dim)
+
+         self.transformer_blocks = nn.ModuleList([
+             DiTBlock(model_dim, n_heads) for _ in range(n_layers)
+         ])
+
+         self.final_norm = nn.LayerNorm(model_dim)
+         self.lm_head = nn.Linear(model_dim, vocab_size)
+
+     def forward(self, x, t):
+         seq_len = x.shape[1]
+         x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+         t_embed = self.time_embedder(t)
+         for block in self.transformer_blocks:
+             x_embed = block(x_embed, t_embed)
+         x_embed = self.final_norm(x_embed)
+         logits = self.lm_head(x_embed)
+         return logits
+
+ # --- Generation Function ---
+
+ def generate_samples(model, device, num_samples, seq_len, steps, temperature):
+     """
+     Generates samples by starting from a random sequence and progressively refining it.
+     """
+     model.eval()
+
+     # Start with a completely random sequence of tokens
+     shape = (num_samples, seq_len)
+     x = torch.randint(0, model.vocab_size, shape, dtype=torch.long, device=device)
+
+     # Cosine schedule determines how many tokens we *keep* from the previous step.
+     # It goes from 0 (keep none) to seq_len (keep all).
+     keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
+     keep_schedule = torch.round(keep_schedule).long()
+
+     with torch.no_grad():
+         progress_bar = tqdm(range(steps), desc="Generating Samples")
+         for i in progress_bar:
+             # Time `t` should go from 0 (pure noise) up to 1 (pure data)
+             t_continuous = torch.full((num_samples,), i / steps, device=device)
+
+             logits = model(x, t_continuous)
+
+             # Apply temperature scaling to control diversity
+             scaled_logits = logits / temperature
+             probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
+
+             # Sample a full new sequence from the model's prediction
+             sampled_tokens = torch.multinomial(probs.view(-1, model.vocab_size), 1).view(shape)
+
+             # For the last step, the new sample is our final result
+             if i == steps - 1:
+                 x = sampled_tokens
+                 break
+
+             # Determine which tokens from the *newly sampled sequence* to keep, based on confidence
+             confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+
+             # Find the indices of the most confident tokens to keep
+             num_to_keep = keep_schedule[i]
+             _, indices_to_keep = torch.topk(confidence, num_to_keep, largest=True, dim=-1)
+
+             # Create a mask for the tokens we are keeping
+             keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
+
+             # The next sequence `x` is a mix:
+             # - Where keep_mask is True, we use the new, confident sampled_tokens.
+             # - Where keep_mask is False, we keep the tokens from the previous step `x`.
+             x = torch.where(keep_mask, sampled_tokens, x)
+
+     return x
+
+ # --- Main Execution ---
+
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     print(f"Loading checkpoint from {args.checkpoint}...")
+     try:
+         checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+         model_args = checkpoint['args']
+     except FileNotFoundError:
+         print(f"Error: Checkpoint file not found at {args.checkpoint}")
+         return
+     except Exception as e:
+         print(f"Error loading checkpoint: {e}")
+         return
+
+     print("Initializing model...")
+     model = MDLM(
+         vocab_size=model_args.vocab_size,
+         seq_len=model_args.seq_len,
+         model_dim=model_args.model_dim,
+         n_heads=model_args.n_heads,
+         n_layers=model_args.n_layers
+     ).to(device)
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     print("Model loaded successfully.")
+
+     gen_len = args.gen_len if args.gen_len is not None else model_args.seq_len
+     if gen_len > model_args.seq_len:
+         raise ValueError(f"Requested generation length ({gen_len}) is greater than the model's max length ({model_args.seq_len}).")
+     print(f"Generating sequences of length {gen_len}.")
+
+     generated_tokens = generate_samples(
+         model=model,
+         device=device,
+         num_samples=args.num_samples,
+         seq_len=gen_len,
+         steps=args.gen_steps,
+         temperature=args.temperature
+     )
+
+     print("Decoding and saving samples...")
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+     with open(args.output_file, 'w') as f:
+         for sample_tokens in generated_tokens:
+             sequence = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False)
+             # Strip whitespace plus the leading "<cls>" and trailing "<eos>" markers (5 characters each).
+             clean_sequence = sequence.replace(" ", "")[5:-5]
+             f.write(clean_sequence + "\n")
+             print(clean_sequence)
+
+     print(f"Generation complete. {args.num_samples} sequences saved to {args.output_file}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate samples from a trained ReDi (MDLM) model starting from random noise.")
+
+     parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
+     parser.add_argument("--num_samples", type=int, default=128, help="Number of samples to generate.")
+     parser.add_argument("--output_file", type=str, default="./generated_peptides.txt", help="File to save the generated peptide sequences.")
+     parser.add_argument("--gen_steps", type=int, default=16, help="Number of steps for the progressive refinement process.")
+     parser.add_argument("--gen_len", type=int, default=None, help="Desired length of the generated sequences. Defaults to the model's maximum trained length.")
+     parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. >1 increases diversity, <1 decreases it.")
+
+     args = parser.parse_args()
+     main(args)
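
The sampler above freezes a growing set of high-confidence tokens per step, and the cosine keep-schedule controls how many. A quick way to see the schedule's shape, using the script's default of 16 steps and an assumed sequence length of 50:

import math
import torch

steps, seq_len = 16, 50
keep = torch.round(torch.cos(torch.linspace(math.pi / 2, 0, steps)) * seq_len).long()
print(keep.tolist())
# Starts near 0 (almost every position is resampled) and ends at seq_len
# (the whole sequence is frozen): early steps explore, late steps refine.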
peptide/moo.py ADDED
@@ -0,0 +1,284 @@
+ import argparse
+ import math
+ import random
+ from collections import Counter
+ import csv
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+ from peptide_classifiers import *
+
+
+ # --- Model Architecture (Must match the trained model) ---
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(1, hidden_size, bias=True), nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+     def forward(self, t):
+         return self.mlp(t.unsqueeze(-1))
+
+ class DiTBlock(nn.Module):
+     def __init__(self, hidden_size, n_heads):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
+             nn.Linear(4 * hidden_size, hidden_size)
+         )
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+         attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+         x = x + gate_msa.unsqueeze(1) * attn_output
+         x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+         mlp_output = self.mlp(x_norm2)
+         x = x + gate_mlp.unsqueeze(1) * mlp_output
+         return x
+
+ class MDLM(nn.Module):
+     def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.seq_len = seq_len
+         self.model_dim = model_dim
+         self.mask_token_id = vocab_size
+         self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
+         self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+         self.time_embedder = TimestepEmbedder(model_dim)
+         self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
+         self.final_norm = nn.LayerNorm(model_dim)
+         self.lm_head = nn.Linear(model_dim, vocab_size)
+     def forward(self, x, t):
+         seq_len = x.shape[1]
+         x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+         t_embed = self.time_embedder(t)
+         for block in self.transformer_blocks:
+             x_embed = block(x_embed, t_embed)
+         x_embed = self.final_norm(x_embed)
+         logits = self.lm_head(x_embed)
+         return logits
+
+ class MOGGenerator:
+     def __init__(self, model, device, objectives, args):
+         self.model = model
+         self.device = device
+         self.objectives = objectives
+         self.args = args
+         self.num_objectives = len(objectives)
+
+     def _get_scores(self, x_batch):
+         """Calculates the normalized scores for a batch of sequences."""
+         scores = []
+         for obj_func in self.objectives:
+             scores.append(obj_func(x_batch.to(self.device)))
+         return torch.stack(scores, dim=0)
+
+     def _barker_g(self, u):
+         """Barker balancing function."""
+         return u / (1 + u)
+
+     def generate(self):
+         """Main generation loop."""
+         shape = (self.args.num_samples, self.args.gen_len + 2)
+         x = torch.randint(5, self.model.vocab_size, shape, dtype=torch.long, device=self.device)
+         x[:, 0] = 0
+         x[:, -1] = 2
+
+         if self.args.weights is None:
+             weights = torch.full((self.num_objectives,), 1 / self.num_objectives, device=self.device).view(-1, 1)
+         else:
+             weights = torch.tensor(self.args.weights, device=self.device).view(-1, 1)
+             if len(weights) != self.num_objectives:
+                 raise ValueError("Number of weights must match number of objectives.")
+         print(f"Weights: {weights}")
+
+         if self.args.min_threshold is not None:
+             min_threshold = torch.tensor(self.args.min_threshold, device=self.device)
+         else:
+             min_threshold = None
+
+         total_optimization_steps = self.args.optimization_steps * self.args.gen_len
+
+         with torch.no_grad():
+             for t in tqdm(range(total_optimization_steps), desc="MOG Generation"):
+                 # Anneal guidance strength
+                 eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * (t / (total_optimization_steps - 1))
+                 # eta_t = 0.5 * (self.args.eta_min + self.args.eta_max)
+                 # Choose a random position to mutate
+                 mut_idx = random.randint(1, self.args.gen_len)
+
+                 # Determine the generation timestep
+                 # We cycle through the timesteps to ensure all are visited
+                 generation_step = t % self.args.optimization_steps
+                 time_t = torch.full((self.args.num_samples,), (generation_step / self.args.optimization_steps), device=self.device)
+
+                 # Get proposal distribution from ReDi model for the chosen position
+                 logits = self.model(x, time_t)
+                 probs = F.softmax(logits, dim=-1)
+                 pos_probs = probs[:, mut_idx, :]
+                 # Zero out each sample's current token so we don't evaluate the same token again
+                 pos_probs[torch.arange(self.args.num_samples, device=self.device), x[:, mut_idx]] = 0
+
+                 # Prune candidate vocabulary using top-p sampling
+                 sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True)
+                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+                 remove_mask = cumulative_probs > self.args.top_p
+                 remove_mask[..., 1:] = remove_mask[..., :-1].clone()
+                 remove_mask[..., 0] = 0
+
+                 # Get the set of candidate tokens for each sample in the batch
+                 candidate_tokens_list = []
+                 for i in range(self.args.num_samples):
+                     sample_mask = remove_mask[i]
+                     candidates = sorted_indices[i, ~sample_mask]
+                     candidate_tokens_list.append(candidates)
+
+                 # Get current scores
+                 current_scores = self._get_scores(x)
+                 w_current = torch.exp(eta_t * torch.min(weights * current_scores, dim=0).values)
+
+                 # Evaluate all candidate tokens for each sample
+                 final_proposal_tokens = []
+                 for i in range(self.args.num_samples):
+                     candidates = candidate_tokens_list[i]
+                     candidates = torch.tensor([token for token in candidates if token not in [0, 1, 2, 3]], device=candidates.device)
+                     num_candidates = len(candidates)
+
+                     # Create a batch of proposed sequences for the current sample
+                     x_prop_batch = x[i].repeat(num_candidates, 1)
+                     x_prop_batch[:, mut_idx] = candidates
+
+                     # Evaluate all proposals
+                     proposal_scores = self._get_scores(x_prop_batch)
+                     proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values
+                     w_proposal = torch.exp(eta_t * proposal_s_omega)
+
+                     # Get ReDi probabilities for the candidates
+                     redi_probs = pos_probs[i, candidates]
+
+                     # Calculate unnormalized guided probabilities
+                     tilde_q = redi_probs * self._barker_g(w_proposal / w_current[i])
+
+                     # Normalize and sample the final token
+                     final_probs = tilde_q / (torch.sum(tilde_q) + 1e-9)
+
+                     index = torch.multinomial(final_probs, 1).item()
+                     if torch.sum(weights.squeeze(1) * proposal_scores[:, index]) >= torch.sum(weights.squeeze(1) * current_scores[:, i]):
+                         final_token = candidates[index]
+                         print(f"Previous Weighted Sum: {torch.sum(weights.squeeze(1) * current_scores[:, i])}")
+                         print(f"Previous Scores: {current_scores[:, i]}")
+
+                         print(f"New Weighted Sum: {torch.sum(weights.squeeze(1) * proposal_scores[:, index])}")
+                         print(f"New Scores: {proposal_scores[:, index]}")
+                     else:
+                         final_token = x[i][mut_idx]
+                         # final_token = candidates[index]
+
+                     final_proposal_tokens.append(final_token)
+
+                 # Update the sequences with the chosen tokens
+                 x[torch.arange(self.args.num_samples), mut_idx] = torch.stack(final_proposal_tokens)
+
+             scores = self._get_scores(x)
+
+         return x
+
+ # --- Main Execution ---
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     target = args.target
+     tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+     target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)
+
+     affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/ReDi_discrete/peptides/classifier_ckpt/binding_affinity_unpooled.pt', device)
+     affinity_model = AffinityModel(affinity_predictor, target_sequence)
+     hemolysis_model = HemolysisModel(device=device)
+     nonfouling_model = NonfoulingModel(device=device)
+     solubility_model = SolubilityModel(device=device)
+     halflife_model = HalfLifeModel(device=device)
+
+     print(f"Loading checkpoint from {args.checkpoint}...")
+     try:
+         checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+         model_args = checkpoint['args']
+     except Exception as e:
+         print(f"Error loading checkpoint: {e}")
+         return
+
+     print("Initializing model...")
+     model = MDLM(
+         vocab_size=model_args.vocab_size,
+         seq_len=model_args.seq_len,
+         model_dim=model_args.model_dim,
+         n_heads=model_args.n_heads,
+         n_layers=model_args.n_layers
+     ).to(device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     print("Model loaded successfully.")
+
+     # List of all objective functions
+     OBJECTIVE_FUNCTIONS = [hemolysis_model, nonfouling_model, solubility_model, halflife_model, affinity_model]
+
+     mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)
+
+     hemolysis = []
+     nonfouling = []
+     solubility = []
+     halflife = []
+     affinity = []
+
+     for _ in range(args.num_batches):
+         generated_tokens = mog_generator.generate()
+         final_scores = mog_generator._get_scores(generated_tokens).detach().cpu().numpy()
+
+         with open(args.output_file, 'a', newline='') as f:
+             writer = csv.writer(f)
+
+             for i in range(args.num_samples):
+                 sample_tokens = generated_tokens[i]
+                 print(sample_tokens)
+                 sequence_str = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False).replace(" ", "")[5:-5]
+
+                 scores = final_scores[:, i]
+
+                 writer.writerow([sequence_str] + scores.tolist())
+
+                 print([sequence_str] + scores.tolist())
+
+     print("Generation complete.")
+
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).")
+
+     parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.")
+     parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate.")
+     parser.add_argument("--num_batches", type=int, default=10, help="Number of batches to generate.")
+     parser.add_argument("--output_file", type=str, default="./mog_peptides.txt", help="File to save the generated sequences.")
+     parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.")
+     parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.")
+     parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).")
+     parser.add_argument("--min_threshold", type=float, nargs='+', required=False, help="Minimum threshold for the objectives (e.g., 0.2 0.2).")
+     parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.")
+     parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.")
+     parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.")
+
+     parser.add_argument("--target", type=str, required=True)
+     args = parser.parse_args()
+     main(args)
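
The guided proposal in MOGGenerator.generate multiplies each candidate's denoiser probability by the Barker factor g(w_proposal / w_current) with g(u) = u / (1 + u), so candidates that raise the exponentiated weighted-min objective gain mass without ever zeroing out the model prior. A toy numeric sketch (all values hypothetical):

import torch

def barker_g(u):
    # Barker balancing function: g(u) = u / (1 + u).
    return u / (1 + u)

redi_probs = torch.tensor([0.5, 0.3, 0.2])  # denoiser proposals for 3 candidates
w_proposal = torch.tensor([1.2, 0.8, 2.0])  # exp(eta * weighted-min score) per candidate
w_current = torch.tensor(1.0)               # same quantity for the current sequence

tilde_q = redi_probs * barker_g(w_proposal / w_current)
final_probs = tilde_q / (tilde_q.sum() + 1e-9)
print(final_probs)  # the best-scoring candidate gains probability mass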
peptide/new_coupling.py ADDED
@@ -0,0 +1,226 @@
+ import argparse
+ import math
+ import os
+ from collections import defaultdict
+
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+ from datasets import Dataset, DatasetDict
+
+ # --- Model Architecture (Must match the trained model) ---
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(1, hidden_size, bias=True), nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+     def forward(self, t):
+         return self.mlp(t.unsqueeze(-1))
+
+ class DiTBlock(nn.Module):
+     def __init__(self, hidden_size, n_heads):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.mlp = nn.Sequential(
+             nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(),
+             nn.Linear(4 * hidden_size, hidden_size)
+         )
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+         x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
+         attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
+         x = x + gate_msa.unsqueeze(1) * attn_output
+         x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
+         mlp_output = self.mlp(x_norm2)
+         x = x + gate_mlp.unsqueeze(1) * mlp_output
+         return x
+
+ class MDLM(nn.Module):
+     def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
+         super().__init__()
+         self.vocab_size = vocab_size
+         self.seq_len = seq_len
+         self.model_dim = model_dim
+         self.mask_token_id = vocab_size
+         self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
+         self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
+         self.time_embedder = TimestepEmbedder(model_dim)
+         self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
+         self.final_norm = nn.LayerNorm(model_dim)
+         self.lm_head = nn.Linear(model_dim, vocab_size)
+     def forward(self, x, t):
+         seq_len = x.shape[1]
+         x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
+         t_embed = self.time_embedder(t)
+         for block in self.transformer_blocks:
+             x_embed = block(x_embed, t_embed)
+         x_embed = self.final_norm(x_embed)
+         logits = self.lm_head(x_embed)
+         return logits
+
+ # --- Generation & Utility Functions ---
+
+ def generate_x1_from_x0(model, device, x0_batch, steps, temperature):
+     model.eval()
+     x = x0_batch.clone()
+     num_samples, seq_len = x.shape
+     keep_schedule = torch.cos(torch.linspace(math.pi / 2, 0, steps, device=device)) * seq_len
+     keep_schedule = torch.round(keep_schedule).long()
+     with torch.no_grad():
+         for i in range(steps):
+             t_continuous = torch.full((num_samples,), 1.0 - (i / steps), device=device)
+             logits = model(x, t_continuous)
+             scaled_logits = logits / temperature
+             probs = torch.nn.functional.softmax(scaled_logits, dim=-1)
+             sampled_tokens = torch.multinomial(probs.view(-1, model.vocab_size), 1).view(x.shape)
+             if i == steps - 1:
+                 x = sampled_tokens
+                 break
+             confidence = torch.gather(probs, 2, sampled_tokens.unsqueeze(-1)).squeeze(-1)
+             num_to_keep = keep_schedule[i]
+             _, indices_to_keep = torch.topk(confidence, num_to_keep, largest=True, dim=-1)
+             keep_mask = torch.zeros_like(x, dtype=torch.bool).scatter_(1, indices_to_keep, True)
+             x = torch.where(keep_mask, sampled_tokens, x)
+     return x
+
+ def is_sample_valid(sample_x1):
+     """
+     Returns False if any special token [0, 1, 2, 3] appears in the middle of the sequence.
+     """
+     middle_sequence = sample_x1[1:-1]
+     invalid_tokens = {0, 1, 2, 3}
+     for token in middle_sequence:
+         if token in invalid_tokens:
+             return False
+     return True
+
+ def create_prebatched_dataset(dataset, max_tokens_per_batch=500):
+     """
+     Groups samples into batches and restructures the dataset.
+     Each row in the new dataset is a complete batch.
+     """
+     # Group samples by their length
+     data_by_length = defaultdict(list)
+     for sample in dataset:
+         length = len(sample['input_ids_x1'])
+         data_by_length[length].append(sample)
+
+     # Create the actual batches
+     batched_data = {'input_ids_x0': [], 'input_ids_x1': []}
+     for length, samples in data_by_length.items():
+         samples_per_batch = max(1, max_tokens_per_batch // length)
+         for i in range(0, len(samples), samples_per_batch):
+             batch_samples = samples[i:i + samples_per_batch]
+
+             batch_x0 = [s['input_ids_x0'] for s in batch_samples]
+             batch_x1 = [s['input_ids_x1'] for s in batch_samples]
+
+             batched_data['input_ids_x0'].append(batch_x0)
+             batched_data['input_ids_x1'].append(batch_x1)
+
+     return Dataset.from_dict(batched_data)
+
+ # --- Main Execution ---
+
+ def main(args):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     print(f"Loading checkpoint from {args.checkpoint}...")
+     try:
+         checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
+         model_args = checkpoint['args']
+     except Exception as e:
+         print(f"Error loading checkpoint: {e}")
+         return
+
+     print("Initializing model...")
+     model = MDLM(
+         vocab_size=model_args.vocab_size,
+         seq_len=model_args.seq_len,
+         model_dim=model_args.model_dim,
+         n_heads=model_args.n_heads,
+         n_layers=model_args.n_layers
+     ).to(device)
+     model.load_state_dict(checkpoint['model_state_dict'])
+     print("Model loaded successfully.")
+
+     all_x0 = []
+     all_x1 = []
+
+     # 1. Generate samples for each length
+     for length in range(args.min_len, args.max_len + 1):
+         print(f"Generating {args.samples_per_len} valid samples for length {length}...")
+         valid_samples_count = 0
+         pbar = tqdm(total=args.samples_per_len)
+         while valid_samples_count < args.samples_per_len:
+             remaining = args.samples_per_len - valid_samples_count
+             batch_size = min(args.batch_size, remaining)
+
+             shape = (batch_size, length)
+             x0_batch = torch.randint(0, model.vocab_size, shape, dtype=torch.long, device=device)
+             x1_batch = generate_x1_from_x0(model, device, x0_batch, args.gen_steps, args.temperature)
+
+             # 2. Perform sanity check on each sample
+             for x0, x1 in zip(x0_batch, x1_batch):
+                 if is_sample_valid(x1.tolist()):
+                     all_x0.append(x0.cpu().tolist())
+                     all_x1.append(x1.cpu().tolist())
+                     valid_samples_count += 1
+                     pbar.update(1)
+                     if valid_samples_count >= args.samples_per_len:
+                         break
+         pbar.close()
+
+     # 3. Create dataset and split
+     print("Splitting dataset...")
+     rectified_data = {'input_ids_x0': all_x0, 'input_ids_x1': all_x1}
+     dataset = Dataset.from_dict(rectified_data)
+     train_test_split = dataset.train_test_split(test_size=0.2, seed=42)
+     valid_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)
+     final_dataset_dict = DatasetDict({
+         'train': train_test_split['train'],
+         'validation': valid_test_split['train'],
+         'test': valid_test_split['test']
+     })
+
+     # 4. Pre-batch each split
+     print("Pre-batching splits...")
+     batched_dataset_dict = DatasetDict()
+     for split_name, split_dataset in final_dataset_dict.items():
+         print(f"Processing {split_name} split...")
+         batched_dataset_dict[split_name] = create_prebatched_dataset(split_dataset)
+
+     # 5. Save the final dataset
+     output_path = f"{args.output_path}/v{args.version}"
+     print(f"Saving new batched dataset to {output_path}...")
+     batched_dataset_dict.save_to_disk(output_path)
+
+     print("Rectification complete.")
+     print(f"Train on this by updating your training script's dataset path to '{output_path}'.")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate a rectified dataset with variable lengths and pre-batching.")
+
+     parser.add_argument("--checkpoint", type=str, required=True)
+     parser.add_argument("--output_path", type=str, default="./rectified_datasets")
+     parser.add_argument("--version", type=str, default='1')
+     parser.add_argument("--samples_per_len", type=int, default=10000)
+     parser.add_argument("--min_len", type=int, default=6)
+     parser.add_argument("--max_len", type=int, default=49)
+     parser.add_argument("--gen_steps", type=int, default=16)
+     parser.add_argument("--temperature", type=float, default=1.0)
+     parser.add_argument("--batch_size", type=int, default=128)
+
+     args = parser.parse_args()
+     main(args)
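
Because create_prebatched_dataset stores one ready-made batch per row (grouped by length, at most max_tokens_per_batch tokens each), a consumer iterates rows directly instead of re-collating; a minimal sketch, assuming the default output path from the arguments above:

import torch
from datasets import load_from_disk

ds = load_from_disk("./rectified_datasets/v1")["train"]
row = ds[0]                             # one pre-built batch
x0 = torch.tensor(row["input_ids_x0"])  # shape: (batch, length) — noise side
x1 = torch.tensor(row["input_ids_x1"])  # shape: (batch, length) — denoised side
print(x0.shape, x1.shape)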
peptide/peptide_classifiers.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ import pytorch_lightning as pl
5
+ import time
6
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
7
+ import xgboost as xgb
8
+ import esm
9
+
10
+ class UnpooledBindingPredictor(nn.Module):
11
+ def __init__(self,
12
+ esm_model_name="facebook/esm2_t33_650M_UR50D",
13
+ hidden_dim=512,
14
+ kernel_sizes=[3, 5, 7],
15
+ n_heads=8,
16
+ n_layers=3,
17
+ dropout=0.1,
18
+ freeze_esm=True):
19
+ super().__init__()
20
+
21
+ # Define binding thresholds
22
+ self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
23
+ self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
24
+
25
+ # Load ESM model for computing embeddings on the fly
26
+ self.esm_model = AutoModel.from_pretrained(esm_model_name)
27
+ self.config = AutoConfig.from_pretrained(esm_model_name)
28
+
29
+ # Freeze ESM parameters if needed
30
+ if freeze_esm:
31
+ for param in self.esm_model.parameters():
32
+ param.requires_grad = False
33
+
34
+ # Get ESM hidden size
35
+ esm_dim = self.config.hidden_size
36
+
37
+ # Output channels for CNN layers
38
+ output_channels_per_kernel = 64
39
+
40
+ # CNN layers for handling variable length sequences
41
+ self.protein_conv_layers = nn.ModuleList([
42
+ nn.Conv1d(
43
+ in_channels=esm_dim,
44
+ out_channels=output_channels_per_kernel,
45
+ kernel_size=k,
46
+ padding='same'
47
+ ) for k in kernel_sizes
48
+ ])
49
+
50
+ self.binder_conv_layers = nn.ModuleList([
51
+ nn.Conv1d(
52
+ in_channels=esm_dim,
53
+ out_channels=output_channels_per_kernel,
54
+ kernel_size=k,
55
+ padding='same'
56
+ ) for k in kernel_sizes
57
+ ])
58
+
59
+ # Calculate total features after convolution and pooling
60
+ total_features_per_seq = output_channels_per_kernel * len(kernel_sizes) * 2
61
+
62
+ # Project to same dimension after CNN processing
63
+ self.protein_projection = nn.Linear(total_features_per_seq, hidden_dim)
64
+ self.binder_projection = nn.Linear(total_features_per_seq, hidden_dim)
65
+
66
+ self.protein_norm = nn.LayerNorm(hidden_dim)
67
+ self.binder_norm = nn.LayerNorm(hidden_dim)
68
+
69
+ # Cross attention blocks with layer norm
70
+ self.cross_attention_layers = nn.ModuleList([
71
+ nn.ModuleDict({
72
+ 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
73
+ 'norm1': nn.LayerNorm(hidden_dim),
74
+ 'ffn': nn.Sequential(
75
+ nn.Linear(hidden_dim, hidden_dim * 4),
76
+ nn.ReLU(),
77
+ nn.Dropout(dropout),
78
+ nn.Linear(hidden_dim * 4, hidden_dim)
79
+ ),
80
+ 'norm2': nn.LayerNorm(hidden_dim)
81
+ }) for _ in range(n_layers)
82
+ ])
83
+
84
+ # Prediction heads
85
+ self.shared_head = nn.Sequential(
86
+ nn.Linear(hidden_dim * 2, hidden_dim),
87
+ nn.ReLU(),
88
+ nn.Dropout(dropout),
89
+ )
90
+
91
+ # Regression head
92
+ self.regression_head = nn.Linear(hidden_dim, 1)
93
+
94
+ # Classification head (3 classes: tight, medium, loose binding)
95
+ self.classification_head = nn.Linear(hidden_dim, 3)
96
+
97
+ def get_binding_class(self, affinity):
98
+ """Convert affinity values to class indices
99
+ 0: tight binding (>= 7.5)
100
+ 1: medium binding (6.0-7.5)
101
+ 2: weak binding (< 6.0)
102
+ """
103
+ if isinstance(affinity, torch.Tensor):
104
+ tight_mask = affinity >= self.tight_threshold
105
+ weak_mask = affinity < self.weak_threshold
106
+ medium_mask = ~(tight_mask | weak_mask)
107
+
108
+ classes = torch.zeros_like(affinity, dtype=torch.long)
109
+ classes[medium_mask] = 1
110
+ classes[weak_mask] = 2
111
+ return classes
112
+ else:
113
+ if affinity >= self.tight_threshold:
114
+ return 0 # tight binding
115
+ elif affinity < self.weak_threshold:
116
+ return 2 # weak binding
117
+ else:
118
+ return 1 # medium binding
119
+
120
+ def compute_embeddings(self, input_ids, attention_mask=None):
121
+ """Compute ESM embeddings on the fly"""
122
+ esm_outputs = self.esm_model(
123
+ input_ids=input_ids,
124
+ attention_mask=attention_mask,
125
+ return_dict=True
126
+ )
127
+
128
+ # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
129
+ return esm_outputs.last_hidden_state
130
+
131
+ def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
132
+ """Process a sequence through CNN layers and pooling"""
133
+ # Transpose for CNN: [batch_size, hidden_size, seq_length]
134
+ x = unpooled_emb.transpose(1, 2)
135
+
136
+ # Apply CNN layers and collect outputs
137
+ conv_outputs = []
138
+ for conv in conv_layers:
139
+ conv_out = F.relu(conv(x))
140
+ conv_outputs.append(conv_out)
141
+
142
+ # Concatenate along channel dimension
143
+ conv_output = torch.cat(conv_outputs, dim=1)
144
+
145
+ # Global pooling (both max and average)
146
+ # If attention mask is provided, use it to create a proper mask for pooling
147
+ if attention_mask is not None:
148
+ # Create a mask for pooling (1 for valid positions, 0 for padding)
149
+ # Expand mask to match conv_output channels
150
+ expanded_mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
151
+
152
+ # Apply mask (set padding to large negative value for max pooling)
153
+ masked_output = conv_output.clone()
154
+ masked_output = masked_output.masked_fill(expanded_mask == 0, float('-inf'))
155
+
156
+ # Max pooling along sequence dimension
157
+ max_pooled = torch.max(masked_output, dim=2)[0]
158
+
159
+ # Average pooling (sum divided by number of valid positions)
160
+ sum_pooled = torch.sum(conv_output * expanded_mask, dim=2)
161
+ valid_positions = torch.sum(expanded_mask, dim=2)
162
+ valid_positions = torch.clamp(valid_positions, min=1.0) # Avoid division by zero
163
+ avg_pooled = sum_pooled / valid_positions
164
+ else:
165
+ # If no mask, use standard pooling
166
+ max_pooled = torch.max(conv_output, dim=2)[0]
167
+ avg_pooled = torch.mean(conv_output, dim=2)
168
+
169
+ # Concatenate the pooled features
170
+ pooled = torch.cat([max_pooled, avg_pooled], dim=1)
171
+
172
+ return pooled
173
+
174
+ def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
175
+ # Compute embeddings on the fly using the ESM model
176
+ protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
177
+ binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
178
+
179
+ # Process protein and binder sequences through CNN layers
180
+ protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
181
+ binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
182
+
183
+ # Project to same dimension
184
+ protein = self.protein_norm(self.protein_projection(protein_features))
185
+ binder = self.binder_norm(self.binder_projection(binder_features))
186
+
187
+ # Reshape for attention: from [batch_size, hidden_dim] to [1, batch_size, hidden_dim]
188
+ protein = protein.unsqueeze(0)
189
+ binder = binder.unsqueeze(0)
190
+
191
+ # Cross attention layers
192
+ for layer in self.cross_attention_layers:
193
+ # Protein attending to binder
194
+ attended_protein = layer['attention'](
195
+ protein, binder, binder
196
+ )[0]
197
+ protein = layer['norm1'](protein + attended_protein)
198
+ protein = layer['norm2'](protein + layer['ffn'](protein))
199
+
200
+ # Binder attending to protein
201
+ attended_binder = layer['attention'](
202
+ binder, protein, protein
203
+ )[0]
204
+ binder = layer['norm1'](binder + attended_binder)
205
+ binder = layer['norm2'](binder + layer['ffn'](binder))
206
+
207
+ # Remove sequence dimension
208
+ protein_pool = protein.squeeze(0)
209
+ binder_pool = binder.squeeze(0)
210
+
211
+ # Concatenate both representations
212
+ combined = torch.cat([protein_pool, binder_pool], dim=-1)
213
+
214
+ # Shared features
215
+ shared_features = self.shared_head(combined)
216
+
217
+ regression_output = self.regression_head(shared_features)
218
+ # classification_logits = self.classification_head(shared_features)
219
+
220
+ # return regression_output, classification_logits
221
+ return regression_output
222
+
223
+ class ImprovedBindingPredictor(nn.Module):
224
+ def __init__(self,
225
+ esm_dim=1280,
226
+ smiles_dim=1280,
227
+ hidden_dim=512,
228
+ n_heads=8,
229
+ n_layers=5,
230
+ dropout=0.1):
231
+ super().__init__()
232
+
233
+ # Define binding thresholds
234
+ self.tight_threshold = 7.5 # Kd/Ki/IC50 ≤ ~30nM
235
+ self.weak_threshold = 6.0 # Kd/Ki/IC50 > 1μM
236
+
237
+ # Project to same dimension
238
+ self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
239
+ self.protein_projection = nn.Linear(esm_dim, hidden_dim)
240
+ self.protein_norm = nn.LayerNorm(hidden_dim)
241
+ self.smiles_norm = nn.LayerNorm(hidden_dim)
242
+
243
+ # Cross attention blocks with layer norm
244
+ self.cross_attention_layers = nn.ModuleList([
245
+ nn.ModuleDict({
246
+ 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
247
+ 'norm1': nn.LayerNorm(hidden_dim),
248
+ 'ffn': nn.Sequential(
249
+ nn.Linear(hidden_dim, hidden_dim * 4),
250
+ nn.ReLU(),
251
+ nn.Dropout(dropout),
252
+ nn.Linear(hidden_dim * 4, hidden_dim)
253
+ ),
254
+ 'norm2': nn.LayerNorm(hidden_dim)
255
+ }) for _ in range(n_layers)
256
+ ])
257
+
258
+ # Prediction heads
259
+ self.shared_head = nn.Sequential(
260
+ nn.Linear(hidden_dim * 2, hidden_dim),
261
+ nn.ReLU(),
262
+ nn.Dropout(dropout),
263
+ )
264
+
265
+ # Regression head
266
+ self.regression_head = nn.Linear(hidden_dim, 1)
267
+
268
+ # Classification head (3 classes: tight, medium, loose binding)
269
+ self.classification_head = nn.Linear(hidden_dim, 3)
270
+
271
+ def get_binding_class(self, affinity):
272
+ """Convert affinity values to class indices
273
+ 0: tight binding (>= 7.5)
274
+ 1: medium binding (6.0-7.5)
275
+ 2: weak binding (< 6.0)
276
+ """
277
+ if isinstance(affinity, torch.Tensor):
278
+ tight_mask = affinity >= self.tight_threshold
279
+ weak_mask = affinity < self.weak_threshold
280
+ medium_mask = ~(tight_mask | weak_mask)
281
+
282
+ classes = torch.zeros_like(affinity, dtype=torch.long)
283
+ classes[medium_mask] = 1
284
+ classes[weak_mask] = 2
285
+ return classes
286
+ else:
287
+ if affinity >= self.tight_threshold:
288
+ return 0 # tight binding
289
+ elif affinity < self.weak_threshold:
290
+ return 2 # weak binding
291
+ else:
292
+ return 1 # medium binding
293
+
294
+ def forward(self, protein_emb, binder_emb):
295
+
296
+ protein = self.protein_norm(self.protein_projection(protein_emb))
297
+ smiles = self.smiles_norm(self.smiles_projection(binder_emb))
298
+
299
+ protein = protein.transpose(0, 1)
300
+ smiles = smiles.transpose(0, 1)
301
+
302
+ # Cross attention layers
303
+ for layer in self.cross_attention_layers:
304
+ # Protein attending to SMILES
305
+ attended_protein = layer['attention'](
306
+ protein, smiles, smiles
307
+ )[0]
308
+ protein = layer['norm1'](protein + attended_protein)
309
+ protein = layer['norm2'](protein + layer['ffn'](protein))
310
+
311
+ # SMILES attending to protein
312
+ attended_smiles = layer['attention'](
313
+ smiles, protein, protein
314
+ )[0]
315
+ smiles = layer['norm1'](smiles + attended_smiles)
316
+ smiles = layer['norm2'](smiles + layer['ffn'](smiles))
317
+
318
+ # Get sequence-level representations
319
+ protein_pool = torch.mean(protein, dim=0)
320
+ smiles_pool = torch.mean(smiles, dim=0)
321
+
322
+ # Concatenate both representations
323
+ combined = torch.cat([protein_pool, smiles_pool], dim=-1)
324
+
325
+ # Shared features
326
+ shared_features = self.shared_head(combined)
327
+
328
+ regression_output = self.regression_head(shared_features)
329
+
330
+ return regression_output
331
+
332
+ class PooledAffinityModel(nn.Module):
333
+ def __init__(self, affinity_predictor, target_sequence):
334
+ super(PooledAffinityModel, self).__init__()
335
+ self.affinity_predictor = affinity_predictor
336
+ self.target_sequence = target_sequence
337
+ self.esm_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(self.target_sequence.device)
338
+ for param in self.esm_model.parameters():
339
+ param.requires_grad = False
340
+
341
+ def compute_embeddings(self, input_ids, attention_mask=None):
342
+ """Compute ESM embeddings on the fly"""
343
+ esm_outputs = self.esm_model(
344
+ input_ids=input_ids,
345
+ attention_mask=attention_mask,
346
+ return_dict=True
347
+ )
348
+
349
+ # Get the unpooled last hidden states (batch_size x seq_length x hidden_size)
350
+ return esm_outputs.last_hidden_state
351
+
352
+ def forward(self, x):
353
+ target_sequence = self.target_sequence.repeat(x.shape[0], 1)
354
+
355
+ protein_emb = self.compute_embeddings(input_ids=target_sequence)
356
+ binder_emb = self.compute_embeddings(input_ids=x)
357
+ return self.affinity_predictor(protein_emb=protein_emb, binder_emb=binder_emb).squeeze(-1)
358
+
359
+ class AffinityModel(nn.Module):
360
+ def __init__(self, affinity_predictor, target_sequence):
361
+ super(AffinityModel, self).__init__()
362
+ self.affinity_predictor = affinity_predictor
363
+ self.target_sequence = target_sequence
364
+
365
+ def forward(self, x):
366
+ target_sequence = self.target_sequence.repeat(x.shape[0], 1)
367
+ affinity = self.affinity_predictor(protein_input_ids=target_sequence, binder_input_ids=x).squeeze(-1)
368
+ return affinity / 10
369
+
370
+ class HemolysisModel:
371
+ def __init__(self, device):
372
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_hemolysis.json')
373
+
374
+ self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
375
+ self.model.eval()
376
+
377
+ self.device = device
378
+
379
+ def generate_embeddings(self, sequences):
380
+ """Generate ESM embeddings for protein sequences"""
381
+ with torch.no_grad():
382
+ embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
383
+ embeddings = embeddings.cpu().numpy()
384
+
385
+ return embeddings
386
+
387
+ def get_scores(self, input_seqs):
388
+ scores = np.ones(len(input_seqs))
389
+ features = self.generate_embeddings(input_seqs)
390
+
391
+ if len(features) == 0:
392
+ return scores
393
+
394
+ features = np.nan_to_num(features, nan=0.)
395
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
396
+
397
+ features = xgb.DMatrix(features)
398
+
399
+ probs = self.predictor.predict(features)
400
+ # return the probability of it being not hemolytic
401
+ return torch.from_numpy(scores - probs).to(self.device)
402
+
403
+ def __call__(self, input_seqs: list):
404
+ scores = self.get_scores(input_seqs)
405
+ return scores
406
+
407
+ class NonfoulingModel:
408
+ def __init__(self, device):
409
+ # change model path
410
+ self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_nonfouling.json')
411
+
412
+ self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
413
+ self.model.eval()
414
+
415
+ self.device = device
416
+
417
+ def generate_embeddings(self, sequences):
418
+ """Generate ESM embeddings for protein sequences"""
419
+ with torch.no_grad():
420
+ embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
421
+ embeddings = embeddings.cpu().numpy()
422
+
423
+ return embeddings
424
+
425
+ def get_scores(self, input_seqs):
426
+ scores = np.zeros(len(input_seqs))
427
+ features = self.generate_embeddings(input_seqs)
428
+
429
+ if len(features) == 0:
430
+ return scores
431
+
432
+ features = np.nan_to_num(features, nan=0.)
433
+ features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
434
+
435
+ features = xgb.DMatrix(features)
436
+
437
+ scores = self.predictor.predict(features)
438
+ return torch.from_numpy(scores).to(self.device)
439
+
440
+ def __call__(self, input_seqs: list):
441
+ scores = self.get_scores(input_seqs)
442
+ return scores
443
+
444
+ class SolubilityModel:
+     def __init__(self, device):
+         # XGBoost classifier over mean-pooled ESM-2 embeddings.
+         self.predictor = xgb.Booster(model_file='./classifier_ckpt/best_model_solubility.json')
+ 
+         self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
+         self.model.eval()
+ 
+         self.device = device
+ 
+     def generate_embeddings(self, sequences):
+         """Generate mean-pooled ESM embeddings for a batch of token ids."""
+         with torch.no_grad():
+             embeddings = self.model(input_ids=sequences).last_hidden_state.mean(dim=1)
+         return embeddings.cpu().numpy()
+ 
+     def get_scores(self, input_seqs):
+         # Default score of 0 when no embeddings are available.
+         scores = np.zeros(len(input_seqs))
+         features = self.generate_embeddings(input_seqs)
+ 
+         if len(features) == 0:
+             return torch.from_numpy(scores).to(self.device)
+ 
+         features = np.nan_to_num(features, nan=0.)
+         features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
+         features = xgb.DMatrix(features)
+ 
+         # P(soluble) straight from the classifier.
+         scores = self.predictor.predict(features)
+         return torch.from_numpy(scores).to(self.device)
+ 
+     def __call__(self, input_seqs):
+         return self.get_scores(input_seqs)
+ 
+ class PeptideCNN(nn.Module):
+     def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
+         super().__init__()
+         self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
+         # Note: kernel_size=5 with padding=1 shrinks the sequence length by 2.
+         # This is harmless here because the output is mean-pooled over length.
+         self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
+         self.fc = nn.Linear(hidden_dims[1], output_dim)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.predictor = nn.Linear(output_dim, 1)  # regression/classification head
+ 
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model.eval()
+ 
+     def forward(self, input_ids, attention_mask=None, return_features=False):
+         # ESM-2 acts as a frozen feature extractor.
+         with torch.no_grad():
+             x = self.esm_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
+         # x shape: (B, L, input_dim)
+         x = x.permute(0, 2, 1)  # reshape to (B, input_dim, L) for Conv1d
+         x = nn.functional.relu(self.conv1(x))
+         x = self.dropout(x)
+         x = nn.functional.relu(self.conv2(x))
+         x = self.dropout(x)
+         x = x.permute(0, 2, 1)  # reshape back to (B, L', hidden_dims[1])
+ 
+         # Global average pooling over the sequence dimension
+         x = x.mean(dim=1)  # shape: (B, hidden_dims[1])
+ 
+         features = self.fc(x)  # shape: (B, output_dim)
+         if return_features:
+             return features
+         return self.predictor(features)  # shape: (B, 1)
+ 
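A quick shape check of the convolutional path, shown as a sketch rather than part of the commit; the ESM embedding is replaced by a random tensor so the 650M model need not be loaded, and the dimensions mirror HalfLifeModel below (input_dim=1280, hidden_dims=[640, 320]):

    import torch
    import torch.nn as nn

    conv1 = nn.Conv1d(1280, 640, kernel_size=3, padding=1)
    conv2 = nn.Conv1d(640, 320, kernel_size=5, padding=1)

    x = torch.randn(2, 50, 1280).permute(0, 2, 1)  # (B, input_dim, L)
    x = conv2(conv1(x))
    print(x.shape)  # torch.Size([2, 320, 48]): L shrinks 50 -> 48, then pooling removes L anyway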
+ class HalfLifeModel:
+     def __init__(self, device):
+         input_dim = 1280  # hidden size of esm2_t33_650M_UR50D
+         hidden_dims = [input_dim // 2, input_dim // 4]
+         output_dim = input_dim // 8
+         dropout_rate = 0.3
+         self.model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
+         self.model.load_state_dict(torch.load('./classifier_ckpt/best_model_half_life.pth', map_location=device, weights_only=False))
+         self.model.eval()
+ 
+     def __call__(self, x):
+         prediction = self.model(x, return_features=False)
+         # Clamp the raw prediction to [0, 2], then rescale to [0, 1].
+         half_life = torch.clamp(prediction.squeeze(-1), max=2.0, min=0.0)
+         return half_life / 2
+ 
+ 
+ def load_bindevaluator(checkpoint_path, device):
+     bindevaluator = BindEvaluator.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
+     bindevaluator.eval()
+     # Freeze all parameters: the evaluator is used purely as a scoring function.
+     for param in bindevaluator.parameters():
+         param.requires_grad = False
+ 
+     return bindevaluator
+ 
+ 
+ def load_pooled_affinity_predictor(checkpoint_path, device):
+     """Load the trained pooled binding predictor from a checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+ 
+     model = ImprovedBindingPredictor().to(device)
+ 
+     # Load the trained weights and switch to evaluation mode.
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+ 
+     return model
+ 
+ def load_affinity_predictor(checkpoint_path, device):
+     """Load the trained unpooled binding predictor from a checkpoint."""
+     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+ 
+     model = UnpooledBindingPredictor(
+         esm_model_name="facebook/esm2_t33_650M_UR50D",
+         hidden_dim=384,
+         kernel_sizes=[3, 5, 7],
+         n_heads=8,
+         n_layers=4,
+         dropout=0.14561457009902096,
+         freeze_esm=True
+     ).to(device)
+ 
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+ 
+     return model
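A sketch of how the loaders and scorers above might be wired together at startup; the checkpoint paths and variable names here are placeholders, not files referenced by this code:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical paths -- substitute the real checkpoints from classifier_ckpt/.
    pooled_predictor = load_pooled_affinity_predictor("path/to/pooled_checkpoint.pt", device)
    unpooled_predictor = load_affinity_predictor("path/to/unpooled_checkpoint.pt", device)
    half_life_scorer = HalfLifeModel(device)
    hemolysis_scorer = HemolysisModel(device)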
peptide/rectified_datasets/v1/dataset_dict.json ADDED
@@ -0,0 +1 @@
+ {"splits": ["train", "validation", "test"]}
peptide/rectified_datasets/v1/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:087edd095dcd714192f1d4ef341b1894bcee3fb03d5453c2b04d8ce031589318
+ size 19749472
peptide/rectified_datasets/v1/test/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v1/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "118d550fe7101754",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v1/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3b4af46986df1e635f0dff8e3f00c4d4c06aac26161836ca694299d1cc0bd20f
+ size 157859216
peptide/rectified_datasets/v1/train/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v1/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "a5ddb0c42fb68c3f",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v1/validation/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:409939a66eebfac56a884440065ec2e6cd1b81632fede0d4e2156cd56600a2b8
+ size 19725216
peptide/rectified_datasets/v1/validation/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v1/validation/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "3a37666e1156a9e6",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
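These rectified dataset directories follow the standard Hugging Face `datasets` on-disk layout (`dataset_dict.json` plus per-split Arrow shards and `state.json`), so each version can be reloaded with `load_from_disk`. A sketch, assuming the repo root as the working directory:

    from datasets import load_from_disk

    ds = load_from_disk("peptide/rectified_datasets/v1")
    print(ds)                    # DatasetDict with train / validation / test splits
    print(ds["train"].features)  # input_ids_x0, input_ids_x1: nested lists of int64 token ids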
peptide/rectified_datasets/v2/dataset_dict.json ADDED
@@ -0,0 +1 @@
+ {"splits": ["train", "validation", "test"]}
peptide/rectified_datasets/v2/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fd85592cab50715eb0268f061670b7a98d3871a7aa1baefd40c6a23055f489d6
+ size 19749472
peptide/rectified_datasets/v2/test/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v2/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "03f0e67bb58fcf47",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v2/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26bea1ec7f7db609a4949a6209bc5536003f7ab169f1e03d5a256db81af434bd
+ size 157859216
peptide/rectified_datasets/v2/train/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v2/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "c41975ecd76982be",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v2/validation/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff40262e105e2c748e8e5ca8f89b5882af643f099227d65e6c4b13da8d328094
+ size 19725216
peptide/rectified_datasets/v2/validation/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v2/validation/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "39ddf61d20fce77a",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v3/dataset_dict.json ADDED
@@ -0,0 +1 @@
+ {"splits": ["train", "validation", "test"]}
peptide/rectified_datasets/v3/test/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d6169a4b613ec0ac0c7fdcda92ec031f6eab8a2d0cf451dd744b780ea096825
+ size 19749472
peptide/rectified_datasets/v3/test/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v3/test/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "f6aed185a066dd98",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }
peptide/rectified_datasets/v3/train/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a7828b0d21a1d8103347e7bdc4438caf3dd9f1d1bf1158fcec31e1de0d9bcea
+ size 157859216
peptide/rectified_datasets/v3/train/dataset_info.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "citation": "",
+   "description": "",
+   "features": {
+     "input_ids_x0": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     },
+     "input_ids_x1": {
+       "feature": {
+         "feature": {
+           "dtype": "int64",
+           "_type": "Value"
+         },
+         "_type": "List"
+       },
+       "_type": "List"
+     }
+   },
+   "homepage": "",
+   "license": ""
+ }
peptide/rectified_datasets/v3/train/state.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "_data_files": [
+     {
+       "filename": "data-00000-of-00001.arrow"
+     }
+   ],
+   "_fingerprint": "448b61e862d72291",
+   "_format_columns": null,
+   "_format_kwargs": {},
+   "_format_type": null,
+   "_output_all_columns": false,
+   "_split": null
+ }