yinuozhang committed
Commit c1bbdd6 · 1 Parent(s): 9c11751
app.py ADDED
@@ -0,0 +1,1398 @@
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import xgboost as xgb
+ from transformers import AutoTokenizer, AutoModel, AutoConfig, EsmModel, EsmTokenizer
+ import plotly.graph_objects as go
+ from pathlib import Path
+ import json
+ import re
+ import time
+ from typing import List, Dict, Any, Tuple, Optional
+
+ # Try to import RDKit for SMILES support
+ try:
+     from rdkit import Chem
+     from rdkit.Chem import Descriptors, AllChem
+     RDKIT_AVAILABLE = True
+ except ImportError:
+     RDKIT_AVAILABLE = False
+     print("RDKit not available. SMILES input will be disabled.")
+
+ AA_RE = re.compile(r'^[ACDEFGHIKLMNPQRSTVWYBXZJUO\-]+$', re.IGNORECASE)
+
+ def is_aa_sequence_like(s: str) -> bool:
+     s = s.strip().replace(" ", "")
+     if not s:
+         return False
+     # Very lenient: allow AA letters + optional '-' for readability
+     return bool(AA_RE.fullmatch(s)) and any(c.isalpha() for c in s)
+
+ def is_smiles_like(s: str) -> bool:
+     s = s.strip()
+     if not s:
+         return False
+     # Heuristic: SMILES often contains these symbols; also reject if it looks like pure AA
+     maybe_smiles_chars = set("=#()[]+\\/-@1234567890")
+     return (any(ch in maybe_smiles_chars for ch in s) or not is_aa_sequence_like(s)) and len(s) >= 2
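+ # A quick illustration of the two heuristics above (hypothetical inputs):
+ #   is_aa_sequence_like("GIVEQCCTSICSLYQLENYCN")  -> True  (AA letters only)
+ #   is_smiles_like("CC(=O)NC1=CC=C(O)C=C1")       -> True  (contains ring digits/brackets)
+ #   is_smiles_like("GIVEQ")                       -> False (looks like a pure AA string)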
+
+ # ==================== Model Classes ====================
+
+ # Checkpoint-loading utility used by UnifiedPeptidePredictor.load_all_models below.
+ def load_cnn_weights_safely(model: nn.Module, ckpt_path: Path, device: torch.device):
+     """
+     Load a CNN checkpoint that might include old ESM weights, DDP prefixes, or different wrappers.
+     Strips unknown prefixes and ignores non-matching keys gracefully.
+     """
+     ckpt = torch.load(ckpt_path, map_location=device)
+
+     # 1) Extract a state dict from various formats
+     if isinstance(ckpt, dict) and any(k in ckpt for k in ["state_dict", "model_state_dict", "weights"]):
+         sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("weights")
+     elif isinstance(ckpt, dict):
+         # Probably already a state_dict
+         sd = ckpt
+     else:
+         # Possibly a full pickled model; try to read its state_dict
+         try:
+             sd = ckpt.state_dict()
+         except Exception as e:
+             raise RuntimeError(f"Unsupported checkpoint format at {ckpt_path}: {type(ckpt)}") from e
+
+     # 2) Normalize keys: strip the DDP 'module.' prefix and drop old ESM-containing parameters
+     cleaned = {}
+     for k, v in sd.items():
+         k2 = k
+         if k2.startswith("module."):
+             k2 = k2[len("module."):]
+         # drop anything from the embedded ESM or other now-missing submodules
+         if k2.startswith("esm_model.") or k2.startswith("esm.") or k2.startswith("backbone.esm."):
+             continue
+         cleaned[k2] = v
+
+     # 3) Load non-strictly so extra/missing heads don't crash
+     missing, unexpected = model.load_state_dict(cleaned, strict=False)
+
+     # Log what happened so the load can be verified
+     if unexpected:
+         print(f"[load_cnn_weights_safely] Unexpected keys ignored: {sorted(unexpected)[:6]}{'...' if len(unexpected) > 6 else ''}")
+     if missing:
+         print(f"[load_cnn_weights_safely] Missing keys not found in checkpoint: {sorted(missing)[:6]}{'...' if len(missing) > 6 else ''}")
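+
+ # Typical call site (mirrors UnifiedPeptidePredictor.load_all_models further below):
+ #   model = PeptideCNN().to(device)
+ #   load_cnn_weights_safely(model, Path("models/best_model_half_life.pth"), device)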
+
+
+ # ====== PeptideCLM SMILES featurizer ======
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ from transformers import AutoModelForMaskedLM
+
+ class PeptideCLMFeaturizer:
+     """
+     Mean-pool hidden states from PeptideCLM-23M-all for SMILES tokens produced by SMILES_SPE_Tokenizer.
+     Use the SAME tokenizer files, max_length, and pooling that were used when training the XGBoost models.
+     """
+     def __init__(self, vocab_path: str, splits_path: str, device: torch.device, max_length: int = 256):
+         self.device = device
+         self.max_length = max_length
+         self.tok = SMILES_SPE_Tokenizer(vocab_path, splits_path)
+         self.model = AutoModelForMaskedLM.from_pretrained("aaronfeller/PeptideCLM-23M-all").roformer.to(device).eval()
+
+     @torch.no_grad()
+     def embed_list(self, smiles_list: list[str]) -> np.ndarray:
+         feats = []
+         for s in smiles_list:
+             toks = self.tok(s, return_tensors="pt", truncation=True, padding=True)
+             toks = {k: v.to(self.device) for k, v in toks.items()}
+             out = self.model(**toks).last_hidden_state   # [1, L, H]
+             mask = toks["attention_mask"].unsqueeze(-1)  # [1, L, 1]
+             pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+             feats.append(pooled.squeeze(0).float().cpu().numpy())
+         return np.stack(feats, axis=0)  # [N, H]
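+
+ # Sketch of expected usage (paths are the ones UnifiedPeptidePredictor passes in below;
+ # the 768-dim hidden size is an assumption based on the smiles_dim=768 default used later):
+ #   featurizer = PeptideCLMFeaturizer("tokenizer/new_vocab.txt", "tokenizer/new_splits.txt", device)
+ #   feats = featurizer.embed_list(["CC(=O)O"])  # -> np.ndarray of shape (1, H), H = 768 here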
+
+
+ class UnpooledBindingPredictor(nn.Module):
+     """Binding affinity predictor with cross-attention mechanism"""
+     def __init__(self,
+                  esm_model_name="facebook/esm2_t33_650M_UR50D",
+                  hidden_dim=512,
+                  kernel_sizes=[3, 5, 7],
+                  n_heads=8,
+                  n_layers=3,
+                  dropout=0.1,
+                  freeze_esm=True):
+         super().__init__()
+
+         # Use these everywhere for consistency
+         self.tight_threshold = 7.5
+         self.weak_threshold = 6.0
+
+         self.esm_model = AutoModel.from_pretrained(esm_model_name)
+         self.config = AutoConfig.from_pretrained(esm_model_name)
+         if freeze_esm:
+             for p in self.esm_model.parameters():
+                 p.requires_grad = False
+
+         esm_dim = self.config.hidden_size
+         out_ch = 64
+         self.protein_conv_layers = nn.ModuleList([
+             nn.Conv1d(esm_dim, out_ch, k, padding='same') for k in kernel_sizes
+         ])
+         self.binder_conv_layers = nn.ModuleList([
+             nn.Conv1d(esm_dim, out_ch, k, padding='same') for k in kernel_sizes
+         ])
+         total = out_ch * len(kernel_sizes) * 2
+
+         self.protein_projection = nn.Linear(total, hidden_dim)
+         self.binder_projection = nn.Linear(total, hidden_dim)
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.binder_norm = nn.LayerNorm(hidden_dim)
+
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim),
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim),
+             }) for _ in range(n_layers)
+         ])
+
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+         self.regression_head = nn.Linear(hidden_dim, 1)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+
+     def get_binding_class(self, affinity: torch.Tensor | float) -> torch.LongTensor | int:
+         """
+         0: tight  (>= tight_threshold)
+         1: medium [weak_threshold, tight_threshold)
+         2: weak   (< weak_threshold)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight = affinity >= self.tight_threshold
+             weak = affinity < self.weak_threshold
+             medium = ~(tight | weak)
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium] = 1
+             classes[weak] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0
+             elif affinity < self.weak_threshold:
+                 return 2
+             else:
+                 return 1
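+
+     # Worked example with the thresholds above:
+     #   get_binding_class(8.1) -> 0 (tight),  get_binding_class(6.8) -> 1 (medium),
+     #   get_binding_class(5.2) -> 2 (weak)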
+
+     def compute_embeddings(self, input_ids, attention_mask=None):
+         out = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+         return out.last_hidden_state
+
+     def process_sequence(self, unpooled_emb, conv_layers, attention_mask=None):
+         x = unpooled_emb.transpose(1, 2)  # [B, C_in=E, L]
+         conv_outputs = [F.relu(conv(x)) for conv in conv_layers]  # list of [B, C_out, L]
+         conv_output = torch.cat(conv_outputs, dim=1)  # [B, sumC, L]
+         if attention_mask is not None:
+             mask = attention_mask.unsqueeze(1).expand(-1, conv_output.size(1), -1)
+             masked = conv_output.masked_fill(mask == 0, float('-inf'))
+             max_pooled = masked.max(dim=2)[0]
+             sum_pooled = (conv_output * mask).sum(dim=2)
+             denom = mask.sum(dim=2).clamp(min=1.0)
+             avg_pooled = sum_pooled / denom
+         else:
+             max_pooled = conv_output.max(dim=2)[0]
+             avg_pooled = conv_output.mean(dim=2)
+         return torch.cat([max_pooled, avg_pooled], dim=1)  # [B, 2*sumC]
+
+     def forward(self, protein_input_ids, binder_input_ids, protein_mask=None, binder_mask=None):
+         protein_unpooled = self.compute_embeddings(protein_input_ids, protein_mask)
+         binder_unpooled = self.compute_embeddings(binder_input_ids, binder_mask)
+         protein_features = self.process_sequence(protein_unpooled, self.protein_conv_layers, protein_mask)
+         binder_features = self.process_sequence(binder_unpooled, self.binder_conv_layers, binder_mask)
+         protein = self.protein_norm(self.protein_projection(protein_features))
+         binder = self.binder_norm(self.binder_projection(binder_features))
+
+         # make them "sequence length 1" for MHA (L, B, D)
+         protein = protein.unsqueeze(0).transpose(0, 1)
+         binder = binder.unsqueeze(0).transpose(0, 1)
+         for layer in self.cross_attention_layers:
+             p_attn = layer['attention'](protein, binder, binder)[0]
+             protein = layer['norm1'](protein + p_attn)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+             b_attn = layer['attention'](binder, protein, protein)[0]
+             binder = layer['norm1'](binder + b_attn)
+             binder = layer['norm2'](binder + layer['ffn'](binder))
+
+         protein_pool = protein.mean(dim=0).squeeze(0)
+         binder_pool = binder.mean(dim=0).squeeze(0)
+         shared = self.shared_head(torch.cat([protein_pool, binder_pool], dim=-1))
+         reg = self.regression_head(shared)          # [1]
+         logits = self.classification_head(shared)   # [3]
+         return reg, logits
+
+
+ # ------- SMILES + Protein binding model (reg + 3-class) -------
+ class ImprovedBindingPredictor(nn.Module):
+     def __init__(self, esm_dim=1280, smiles_dim=768, hidden_dim=512, n_heads=8, n_layers=3, dropout=0.1):
+         super().__init__()
+         self.tight_threshold = 7.5
+         self.weak_threshold = 6.0
+
+         self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
+         self.protein_projection = nn.Linear(esm_dim, hidden_dim)
+         self.protein_norm = nn.LayerNorm(hidden_dim)
+         self.smiles_norm = nn.LayerNorm(hidden_dim)
+
+         self.cross_attention_layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
+                 'norm1': nn.LayerNorm(hidden_dim),
+                 'ffn': nn.Sequential(
+                     nn.Linear(hidden_dim, hidden_dim * 4),
+                     nn.ReLU(),
+                     nn.Dropout(dropout),
+                     nn.Linear(hidden_dim * 4, hidden_dim),
+                 ),
+                 'norm2': nn.LayerNorm(hidden_dim),
+             }) for _ in range(n_layers)
+         ])
+
+         self.shared_head = nn.Sequential(
+             nn.Linear(hidden_dim * 2, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+         self.regression_head = nn.Linear(hidden_dim, 1)
+         self.classification_head = nn.Linear(hidden_dim, 3)
+
+     def get_binding_class(self, affinity):
+         """Convert affinity values to class indices
+         0: tight binding (>= 7.5)
+         1: medium binding (6.0-7.5)
+         2: weak binding (< 6.0)
+         """
+         if isinstance(affinity, torch.Tensor):
+             tight_mask = affinity >= self.tight_threshold
+             weak_mask = affinity < self.weak_threshold
+             medium_mask = ~(tight_mask | weak_mask)
+
+             classes = torch.zeros_like(affinity, dtype=torch.long)
+             classes[medium_mask] = 1
+             classes[weak_mask] = 2
+             return classes
+         else:
+             if affinity >= self.tight_threshold:
+                 return 0  # tight binding
+             elif affinity < self.weak_threshold:
+                 return 2  # weak binding
+             else:
+                 return 1  # medium binding
+
+     def forward(self, protein_emb: torch.Tensor, smiles_emb: torch.Tensor):
+         # protein_emb: [1, E], smiles_emb: [1, H]
+         protein = self.protein_norm(self.protein_projection(protein_emb))  # [1, D]
+         smiles = self.smiles_norm(self.smiles_projection(smiles_emb))      # [1, D]
+
+         # Treat each embedding as a length-1 token sequence; MHA still works (QKV dims match)
+         protein = protein.unsqueeze(0)     # [1, 1, D]
+         smiles = smiles.unsqueeze(0)       # [1, 1, D]
+         protein = protein.transpose(0, 1)  # MHA wants [L, B, D]
+         smiles = smiles.transpose(0, 1)
+
+         for layer in self.cross_attention_layers:
+             attn_p = layer['attention'](protein, smiles, smiles)[0]
+             protein = layer['norm1'](protein + attn_p)
+             protein = layer['norm2'](protein + layer['ffn'](protein))
+
+             attn_s = layer['attention'](smiles, protein, protein)[0]
+             smiles = layer['norm1'](smiles + attn_s)
+             smiles = layer['norm2'](smiles + layer['ffn'](smiles))
+
+         # pool over L (it's 1, so mean == squeeze)
+         protein_pool = protein.mean(dim=0).squeeze(0)  # [D]
+         smiles_pool = smiles.mean(dim=0).squeeze(0)    # [D]
+
+         combined = torch.cat([protein_pool, smiles_pool], dim=-1)  # [2D]
+         shared = self.shared_head(combined)
+         reg = self.regression_head(shared)          # scalar pKd/pKi
+         logits = self.classification_head(shared)   # 3-class
+         return reg, logits
+
+
+ class PeptideCNN(nn.Module):
+     """CNN model for single peptide property prediction"""
+     def __init__(self, input_dim=1280, hidden_dims=None, output_dim=160, dropout_rate=0.3):
+         super().__init__()
+         if hidden_dims is None:
+             hidden_dims = [input_dim // 2, input_dim // 4]
+
+         self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
+         self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
+         self.fc = nn.Linear(hidden_dims[1], output_dim)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.predictor = nn.Linear(output_dim, 1)
+
+     def forward(self, esm_embeddings, return_features=False):
+         x = esm_embeddings.permute(0, 2, 1)
+         x = F.relu(self.conv1(x))
+         x = self.dropout(x)
+         x = F.relu(self.conv2(x))
+         x = self.dropout(x)
+         x = x.permute(0, 2, 1)
+         x = x.mean(dim=1)
+         features = self.fc(x)
+         if return_features:
+             return features
+         return self.predictor(features)
+
+
+ # ==================== Data Management ====================
+
+ class TrainingDataManager:
+     """Manage training data statistics and distributions"""
+     def __init__(self, data_dir="training_data"):
+         self.data_dir = Path(__file__).resolve().parent / data_dir
+         self.data_dir.mkdir(exist_ok=True)
+         self.statistics = self.load_statistics()
+
+     def _load_half_life_csv(self):
+         csv_path = self.data_dir / "half_life_smiles.csv"
+         if not csv_path.exists():
+             return None
+         try:
+             df = pd.read_csv(csv_path)
+             if "log_hour" in df.columns:
+                 vals = pd.to_numeric(df["log_hour"], errors="coerce").dropna().to_numpy()
+             else:
+                 if "half_life_hours" not in df.columns:
+                     if "half_life" in df.columns:
+                         df["half_life_hours"] = pd.to_numeric(df["half_life"], errors="coerce") / 3600.0
+                     else:
+                         raise ValueError("CSV must contain 'log_hour' or 'half_life_hours' (or 'half_life').")
+                 hh = pd.to_numeric(df["half_life_hours"], errors="coerce")
+                 vals = np.log10(hh.replace(0, np.nan)).dropna().to_numpy()
+             if len(vals) == 0:
+                 return None
+             return {
+                 "values": vals,
+                 "unit": "log10(hours)",
+                 "threshold": float(np.median(vals)),  # median on log scale
+                 "kind": "continuous",
+             }
+         except Exception as e:
+             print(f"[TrainingDataManager] half-life load error: {e}")
+             return None
+
+     def _load_binary_pair(self, prefix: str):
+         """
+         Load binary labels from <prefix>-positive.npz and <prefix>-negative.npz
+         Returns: {'values': y, 'unit': 'Class (0=neg, 1=pos)', 'kind': 'binary', 'n_pos': int, 'n_neg': int}
+         or None if missing.
+         """
+         pos_path = self.data_dir / f"{prefix}-positive.npz"
+         neg_path = self.data_dir / f"{prefix}-negative.npz"
+         if not pos_path.exists() or not neg_path.exists():
+             return None
+         try:
+             with np.load(pos_path) as pos:
+                 pos_data = pos["arr_0"]
+             with np.load(neg_path) as neg:
+                 neg_data = neg["arr_0"]
+             y = np.concatenate(
+                 [np.ones(len(pos_data), dtype=int), np.zeros(len(neg_data), dtype=int)],
+                 axis=0
+             )
+             return {
+                 "values": y,
+                 "unit": "Class (0=neg, 1=pos)",
+                 "kind": "binary",
+                 "n_pos": int(len(pos_data)),
+                 "n_neg": int(len(neg_data)),
+             }
+         except Exception as e:
+             print(f"[TrainingDataManager] binary load error for '{prefix}': {e}")
+             return None
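+
+     # Assumed on-disk layout for the binary overlays used in load_statistics below:
+     #   training_data/sol-positive.npz  + training_data/sol-negative.npz   (solubility)
+     #   training_data/nf-positive.npz   + training_data/nf-negative.npz    (non-fouling)
+     #   training_data/hemo-positive.npz + training_data/hemo-negative.npz  (hemolysis)
+     # each holding a single positional array ("arr_0", the np.savez default).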
+
+     def load_statistics(self):
+         """Load pre-computed statistics for each property"""
+         stats = {
+             'hemolysis': {
+                 'values': np.random.beta(2, 5, 1000),
+                 'description': 'Probability of peptide disrupting red blood cell membranes.',
+                 'unit': 'Probability',
+                 'threshold': 0.5,
+                 'download_link': '#'
+             },
+             'solubility': {
+                 'values': np.random.normal(5, 2, 1000),
+                 'description': 'Probability of peptide remaining dissolved in aqueous conditions.',
+                 'unit': 'Probability',
+                 'threshold': 0.5,
+                 'download_link': '#'
+             },
+             'binding_affinity': {
+                 'values': np.random.normal(7, 1.5, 1000),
+                 'description': 'Protein-peptide binding affinity',
+                 'unit': 'pKd',
+                 'threshold': 7.5,
+                 'download_link': '#'
+             },
+             # Keyed as 'half_life' so _base_stat_key('half_life_seq') and the
+             # overlay below both resolve; the underlying CSV is SMILES-based.
+             'half_life': {
+                 # will be overwritten below if the CSV exists
+                 'values': np.random.lognormal(2, 1, 1000),
+                 'description': 'Serum half-life from clinical and preclinical studies',
+                 'unit': 'Hours',
+                 'threshold': 2.0,  # hours (default fallback)
+                 'download_link': '#'
+             },
+             'nonfouling': {
+                 'values': np.random.lognormal(4, 1, 1000),
+                 'description': 'A nonfouling peptide resists nonspecific interactions and protein adsorption.',
+                 'unit': 'Probability',
+                 'threshold': 0.5,
+                 'download_link': '#'
+             },
+             'permeability': {
+                 'values': np.random.normal(-4, 1, 1000),
+                 'description': 'Probability of peptide penetrating the cell membrane.',
+                 'unit': 'Probability',
+                 'threshold': 0.5,
+                 'download_link': '#'
+             }
+         }
+
+         # Overlay real half-life
+         hl = self._load_half_life_csv()
+         if hl is not None:
+             stats["half_life"].update(hl)
+
+         # Overlay real solubility from sol-* (binary)
+         sol = self._load_binary_pair("sol")
+         if sol is not None:
+             stats["solubility"].update(sol)
+
+         # Overlay real non-fouling from nf-* (binary)
+         nf = self._load_binary_pair("nf")
+         if nf is not None:
+             stats["nonfouling"].update(nf)
+
+         hemo = self._load_binary_pair("hemo")
+         if hemo is not None:
+             stats["hemolysis"].update(hemo)
+
+         return stats
+
+     def get_distribution_plot(self, property_name, current_value=None):
+         if property_name not in self.statistics:
+             return None
+         s = self.statistics[property_name]
+         vals = np.asarray(s["values"])
+         kind = s.get("kind", "continuous")
+
+         if kind == "binary":
+             n0 = int((vals == 0).sum())
+             n1 = int((vals == 1).sum())
+             total = max(n0 + n1, 1)
+             fig = go.Figure()
+             fig.add_trace(go.Bar(x=["Negative (0)", "Positive (1)"], y=[n0, n1]))
+             fig.update_layout(
+                 title=f"{property_name.replace('_', ' ').title()} — Class Balance",
+                 xaxis_title="Class",
+                 yaxis_title="Count",
+                 height=400,
+                 showlegend=False,
+                 annotations=[
+                     dict(x="Negative (0)", y=n0, text=f"{n0} ({n0/total:.1%})", showarrow=False, yshift=8),
+                     dict(x="Positive (1)", y=n1, text=f"{n1} ({n1/total:.1%})", showarrow=False, yshift=8),
+                 ],
+             )
+             return fig
+
+         # continuous
+         fig = go.Figure()
+         fig.add_trace(go.Histogram(x=vals, nbinsx=50, name="Training Data"))
+         if "threshold" in s and s["threshold"] is not None:
+             fig.add_vline(
+                 x=s["threshold"], line_dash="dash", line_color="red",
+                 annotation_text=f"Threshold: {s['threshold']:.3f}"
+             )
+         if current_value is not None:
+             fig.add_vline(
+                 x=current_value, line_dash="solid", line_color="green", line_width=3,
+                 annotation_text=f"Your Result: {current_value:.3f}"
+             )
+         fig.update_layout(
+             title=f"{property_name.replace('_', ' ').title()} Distribution",
+             xaxis_title=s.get("unit", ""),
+             yaxis_title="Count",
+             height=400,
+             showlegend=False,
+         )
+         return fig
+
+     def get_property_info(self, property_name):
+         if property_name not in self.statistics:
+             return None
+         s = self.statistics[property_name]
+         vals = np.asarray(s["values"])
+         kind = s.get("kind", "continuous")
+
+         info = {
+             "description": s.get("description", ""),
+             "unit": s.get("unit", ""),
+             "n_samples": int(len(vals)),
+             "mean": float(np.mean(vals)),
+             "std": float(np.std(vals)),
+             "min": float(np.min(vals)),
+             "max": float(np.max(vals)),
+             "percentiles": {},
+         }
+
+         if kind == "binary":
+             info["n_neg"] = int((vals == 0).sum())
+             info["n_pos"] = int((vals == 1).sum())
+         else:
+             pct = np.percentile(vals, [10, 25, 50, 75, 90])
+             info["percentiles"] = {
+                 "10%": float(pct[0]),
+                 "25%": float(pct[1]),
+                 "50% (median)": float(pct[2]),
+                 "75%": float(pct[3]),
+                 "90%": float(pct[4]),
+             }
+         return info
+
+
+
+ def _base_stat_key(model_key: str) -> str:
+     # strip modality suffixes to find stats in TrainingDataManager
+     for suf in ("_seq", "_smiles"):
+         if model_key.endswith(suf):
+             return model_key[:-len(suf)]
+     return model_key
+
+ # ==================== Unified Predictor ====================
+
+ class UnifiedPeptidePredictor:
+     """Main predictor handling all model types"""
+
+     def __init__(self, model_dir="models"):
+         self.model_dir = Path(model_dir)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Initialize tokenizer and ESM model
+         print("Loading ESM model...")
+         self.tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model.to(self.device)
+         self.esm_model.eval()
+
+         self.tokenizer_dir = Path("tokenizer")
+         self.smiles_featurizer = PeptideCLMFeaturizer(
+             vocab_path=f"{self.tokenizer_dir}/new_vocab.txt",
+             splits_path=f"{self.tokenizer_dir}/new_splits.txt",
+             device=self.device,
+         )
+         # Model registry
+         self.models = {}
+         self.model_configs = self.get_model_configs()
+
+         # Data manager
+         self.data_manager = TrainingDataManager()
+         self._protein_cache = {}
+         # Load models
+         self.load_all_models()
+
+     def get_model_configs(self):
+         """Define model configurations"""
+         return {
+             'hemolysis_seq': {
+                 'type': 'xgboost',
+                 'input': 'sequence',
+                 'path': 'best_model_hemolysis.json',
+                 'inverse_score': False,
+                 'unit': 'Probability',
+                 'display_name': '🩸 Hemolysis',
+                 'positive_label': 'Non-hemolytic',
+                 'negative_label': 'Hemolytic'
+             },
+             'hemolysis_smiles': {
+                 'type': 'xgboost',
+                 'input': 'smiles',
+                 'path': 'hemolysis-xgboost_smiles.json',
+                 'inverse_score': False,
+                 'unit': 'Probability',
+                 'display_name': '🩸 Hemolysis',
+                 'positive_label': 'Non-hemolytic',
+                 'negative_label': 'Hemolytic'
+             },
+             'solubility_seq': {
+                 'type': 'xgboost',
+                 'input': 'sequence',
+                 'path': 'best_model_solubility.json',
+                 'unit': 'Probability',
+                 'display_name': '💧 Solubility',
+                 'positive_label': 'Soluble',
+                 'negative_label': 'Insoluble'
+             },
+             'solubility_smiles': {
+                 'type': 'xgboost',
+                 'input': 'smiles',
+                 'path': 'solubility-xgboost_smiles.json',
+                 'unit': 'Probability',
+                 'display_name': '💧 Solubility',
+                 'positive_label': 'Soluble',
+                 'negative_label': 'Insoluble'
+             },
+             'permeability_smiles': {
+                 'type': 'xgboost',
+                 'input': 'smiles',
+                 'path': 'permeability-xgboost_smiles.json',
+                 'unit': 'Probability',
+                 'display_name': '🪣 Permeability',
+                 'positive_label': 'Permeable',
+                 'negative_label': 'Impermeable'
+             },
+             'half_life_seq': {
+                 'type': 'pytorch_cnn',
+                 'input': 'sequence',
+                 'path': 'best_model_half_life.pth',
+                 'transform': lambda x: 10**x,  # model predicts log10(hours)
+                 'unit': 'hours',
+                 'display_name': '⏱️ Half-life',
+                 'positive_label': 'Stable',
+                 'negative_label': 'Unstable'
+             },
+             'nonfouling_seq': {
+                 'type': 'xgboost',
+                 'input': 'sequence',
+                 'path': 'best_model_nonfouling.json',
+                 'unit': 'Probability',
+                 'display_name': '👯 Non-Fouling',
+                 'positive_label': 'Non-fouling',
+                 'negative_label': 'Fouling'
+             },
+             'nonfouling_smiles': {
+                 'type': 'xgboost',
+                 'input': 'smiles',
+                 'path': 'nonfouling-xgboost_smiles.json',
+                 'unit': 'Probability',
+                 'display_name': '👯 Non-Fouling',
+                 'positive_label': 'Non-fouling',
+                 'negative_label': 'Fouling'
+             },
+             'binding_affinity': {
+                 'type': 'binding',
+                 'input': 'dual_sequence',
+                 'path': 'binding_affinity_unpooled.pt',
+                 'unit': 'pKd',
+                 'display_name': '🔗 Binding Affinity'
+             },
+             'binding_affinity_smiles': {
+                 'type': 'binding_smiles',
+                 'input': 'sequence+smiles',
+                 'path': 'binding-affinity_smiles.pt',
+                 'unit': 'pKd',
+                 'display_name': '🔗 Binding Affinity (SMILES)'
+             },
+         }
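+
+     # To register an additional property, add an entry to the dict above and drop the
+     # matching checkpoint into models/. A hypothetical example:
+     #   'toxicity_seq': {'type': 'xgboost', 'input': 'sequence',
+     #                    'path': 'best_model_toxicity.json', 'unit': 'Probability',
+     #                    'display_name': 'Toxicity',
+     #                    'positive_label': 'Non-toxic', 'negative_label': 'Toxic'},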
+
+     def load_all_models(self):
+         """Load all available models"""
+         for name, config in self.model_configs.items():
+             model_path = self.model_dir / config['path']
+
+             if not model_path.exists():
+                 print(f"Warning: Model {name} not found at {model_path}")
+                 continue
+
+             try:
+                 if config['type'] == 'xgboost':
+                     self.models[name] = xgb.Booster(model_file=str(model_path))
+
+                 elif config['type'] == 'pytorch_cnn':
+                     model = PeptideCNN().to(self.device)
+                     ckpt_path = model_path  # Path from config
+                     load_cnn_weights_safely(model, ckpt_path, self.device)
+                     model.eval()
+                     self.models[name] = model
+
+                 elif config['type'] == 'binding':
+                     checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+                     model = UnpooledBindingPredictor(
+                         hidden_dim=384,
+                         kernel_sizes=[3, 5, 7],
+                         n_heads=8,
+                         n_layers=4,
+                         dropout=0.14561457009902096,
+                         freeze_esm=True
+                     ).to(self.device)
+                     model.load_state_dict(checkpoint['model_state_dict'])
+                     model.eval()
+                     self.models[name] = model
+
+                 elif config['type'] == 'binding_smiles':
+                     ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
+                     model = ImprovedBindingPredictor(
+                         esm_dim=1280, smiles_dim=768, hidden_dim=512, n_heads=8, n_layers=3, dropout=0.1
+                     ).to(self.device)
+                     model.load_state_dict(ckpt['model_state_dict'])
+                     model.eval()
+                     self.models[name] = model
+
+                 print(f"✓ Loaded {name}")
+
+             except Exception as e:
+                 print(f"Error loading {name}: {e}")
+
+     def _protein_embed_mean(self, protein_seq: str) -> torch.Tensor:
+         """Mean-pool ESM last_hidden_state -> [1, 1280]"""
+         toks = self.tokenizer(protein_seq, return_tensors="pt", padding=True, truncation=True, max_length=1024)
+         toks = {k: v.to(self.device) for k, v in toks.items()}
+         with torch.no_grad():
+             out = self.esm_model(**toks).last_hidden_state               # [1, L, E]
+         mask = toks['attention_mask'].unsqueeze(-1)                      # [1, L, 1]
+         pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # [1, E]
+         return pooled
+
+     def _get_protein_vec(self, protein_seq: str) -> torch.Tensor:
+         key = protein_seq.strip()
+         if key in self._protein_cache:
+             return self._protein_cache[key]
+         vec = self._protein_embed_mean(key)
+         self._protein_cache[key] = vec
+         return vec
+
+     def _smiles_embed_mean(self, smiles: str) -> torch.Tensor:
+         vec = self.smiles_featurizer.embed_list([smiles])[0]       # np [H]
+         return torch.from_numpy(vec).to(self.device).unsqueeze(0)  # [1, H]
+
+     def predict_property(self, model, config, value: str, input_type: str):
+         """
+         value: either an AA sequence (Sequence mode) or a SMILES string (SMILES mode)
+         """
+         if config['type'] == 'xgboost':
+             if input_type == 'SMILES':
+                 if config.get('input') != 'smiles':
+                     raise RuntimeError(f"Model {config['display_name']} expects sequence, not SMILES.")
+                 feats = self._features_from_smiles_peptclm(value)[None, ...]  # [1, D]
+             else:
+                 if config.get('input') == 'smiles':
+                     raise RuntimeError(f"Model {config['display_name']} expects SMILES, not sequence.")
+                 # ESM mean-pooled features
+                 toks = self.tokenizer(value, return_tensors="pt", padding=True, truncation=True, max_length=512)
+                 toks = {k: v.to(self.device) for k, v in toks.items()}
+                 with torch.no_grad():
+                     out = self.esm_model(**toks).last_hidden_state
+                 mask = toks["attention_mask"].unsqueeze(-1)
+                 pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
+                 feats = pooled.float().cpu().numpy()  # [1, 1280]
+             # Optional safety check
+             expected = model.num_features()
+             if feats.shape[1] != expected:
+                 raise RuntimeError(f"Feature dim mismatch: got {feats.shape[1]}, booster expects {expected}")
+             dmat = xgb.DMatrix(feats)
+             pred = model.predict(dmat)[0]
+             if config.get('inverse_score', False):
+                 pred = 1 - pred
+             return float(pred)
+
+         elif config['type'] == 'pytorch_cnn':
+             if input_type == 'SMILES':
+                 raise RuntimeError(f"{config['display_name']} (CNN) expects AA sequence, not SMILES.")
+             toks = self.tokenizer(value, return_tensors="pt", padding=True, truncation=True, max_length=512)
+             toks = {k: v.to(self.device) for k, v in toks.items()}
+             with torch.no_grad():
+                 out = self.esm_model(**toks).last_hidden_state
+                 y = model(out).squeeze().item()
+             if 'transform' in config:
+                 y = config['transform'](y)
+             return float(y)
+
+         else:
+             raise NotImplementedError(config['type'])
+
+     def predict_sequence_property(self, model, config, sequence):
+         """Predict property from sequence"""
+         inputs = self.tokenizer(
+             sequence,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=512
+         )
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = self.esm_model(**inputs)
+             embeddings = outputs.last_hidden_state
+
+         if config['type'] == 'xgboost':
+             attention_mask = inputs['attention_mask']
+             masked_embeddings = embeddings * attention_mask.unsqueeze(-1)
+             sum_embeddings = masked_embeddings.sum(dim=1)
+             seq_lengths = attention_mask.sum(dim=1, keepdim=True)
+             mean_embeddings = sum_embeddings / seq_lengths
+             features = mean_embeddings.cpu().numpy()
+
+             dmatrix = xgb.DMatrix(features)
+             prediction = model.predict(dmatrix)[0]
+
+             if config.get('inverse_score', False):
+                 prediction = 1 - prediction
+
+         elif config['type'] == 'pytorch_cnn':
+             prediction = model(embeddings).squeeze().item()
+
+             if 'transform' in config:
+                 prediction = config['transform'](prediction)
+
+         return prediction
+
+     def predict_binding(self, model, protein_seq, binder_seq, prefer_thresholds: bool = True):
+         """Predict (affinity, class_label). If prefer_thresholds=True, the label is derived from model.tight/weak thresholds."""
+         protein_tokens = self.tokenizer(
+             protein_seq, return_tensors="pt",
+             padding="max_length", max_length=1024, truncation=True
+         )
+         binder_tokens = self.tokenizer(
+             binder_seq, return_tensors="pt",
+             padding="max_length", max_length=1024, truncation=True
+         )
+         protein_ids = protein_tokens['input_ids'].to(self.device)
+         protein_mask = protein_tokens['attention_mask'].to(self.device)
+         binder_ids = binder_tokens['input_ids'].to(self.device)
+         binder_mask = binder_tokens['attention_mask'].to(self.device)
+
+         with torch.no_grad():
+             reg, logits = model(protein_ids, binder_ids, protein_mask, binder_mask)
+         affinity = float(reg.squeeze().item())
+         # 1) threshold-based class:
+         cls_by_thr = int(model.get_binding_class(affinity))
+         # 2) logits-based class:
+         cls_by_logit = int(torch.argmax(logits, dim=-1).item())
+
+         # choose which one to show
+         cls_idx = cls_by_thr if prefer_thresholds else cls_by_logit
+
+         # decorate with explicit cutoffs for UI clarity
+         if cls_idx == 0:
+             label = f"Tight (≥ {model.tight_threshold:.1f})"
+         elif cls_idx == 1:
+             label = f"Medium ({model.weak_threshold:.1f}–{model.tight_threshold:.1f})"
+         else:
+             label = f"Weak (< {model.weak_threshold:.1f})"
+
+         return affinity, label
+
+     def predict_binding_smiles(self, model, protein_seq: str, smiles_str: str, prefer_thresholds: bool = True) -> tuple[float, str]:
+         prot_vec = self._get_protein_vec(protein_seq)     # [1, 1280]
+         smiles_vec = self._smiles_embed_mean(smiles_str)  # [1, 768]
+         with torch.no_grad():
+             reg, logits = model(prot_vec, smiles_vec)
+         affinity = float(reg.squeeze().item())
+         cls_by_thr = int(model.get_binding_class(affinity))
+         cls_by_logit = int(torch.argmax(logits, dim=-1).item())
+
+         cls_idx = cls_by_thr if prefer_thresholds else cls_by_logit
+
+         if cls_idx == 0:
+             label = f"Tight (≥ {model.tight_threshold:.1f})"
+         elif cls_idx == 1:
+             label = f"Medium ({model.weak_threshold:.1f}–{model.tight_threshold:.1f})"
+         else:
+             label = f"Weak (< {model.weak_threshold:.1f})"
+         return affinity, label
+
+     def _features_from_smiles_peptclm(self, s: str) -> np.ndarray:
+         return self.smiles_featurizer.embed_list([s])[0]
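+
+     # Usage sketch (assumes the checkpoints in models/ loaded successfully; initialize()
+     # is defined below):
+     #   p = initialize()
+     #   aff, label = p.predict_binding(p.models['binding_affinity'], protein_seq, binder_seq)
+     #   aff, label = p.predict_binding_smiles(p.models['binding_affinity_smiles'], protein_seq, smiles)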
+
+
+ # ==================== Gradio Interface ====================
+
+ # Global predictor
+ predictor = None
+
+ def initialize():
+     """Initialize the predictor"""
+     global predictor
+     if predictor is None:
+         predictor = UnifiedPeptidePredictor(model_dir="models")
+     return predictor
+
+
+ def predict_properties(
+     input_text: str,
+     input_type: str,
+     protein_seq: str,
+     # Individual property checkboxes
+     hemolysis: bool,
+     solubility: bool,
+     permeability: bool,
+     half_life: bool,
+     nonfouling: bool,
+     binding_affinity: bool,
+     progress=gr.Progress()
+ ):
+     """Main prediction function"""
+
+     if not input_text or not input_text.strip():
+         return None, "⚠️ Please provide an input sequence"
+
+     lines = [s.strip() for s in input_text.split("\n") if s.strip()]
+
+     if input_type == "SMILES":
+         bad = [s for s in lines if not is_smiles_like(s)]
+         if bad:
+             return None, f"⚠️ You selected SMILES but {len(bad)} input line(s) don't look like SMILES. Example bad line: {bad[0][:60]}"
+         if binding_affinity and not protein_seq:
+             return None, "⚠️ For SMILES binding, please provide a protein sequence in the 'Protein Sequence' box."
+     else:
+         bad = [s for s in lines if not is_aa_sequence_like(s)]
+         if bad:
+             return None, f"⚠️ You selected Sequence but {len(bad)} input line(s) don't look like amino-acid sequences. Example bad line: {bad[0][:60]}"
+     pred = initialize()
+     results = []
+
+     # Map UI checkboxes to internal model keys (binding is handled separately below)
+     checkbox_to_keys = {
+         'hemolysis': ['hemolysis_seq', 'hemolysis_smiles'],
+         'solubility': ['solubility_seq', 'solubility_smiles'],
+         'permeability': ['permeability_smiles'],  # only a SMILES model in the current config
+         'half_life': ['half_life_seq'],           # no SMILES half-life model is registered yet
+         'nonfouling': ['nonfouling_seq', 'nonfouling_smiles'],
+     }
+     selected_properties = []
+     for ui_name, is_selected in {
+         'hemolysis': hemolysis,
+         'solubility': solubility,
+         'permeability': permeability,
+         'half_life': half_life,
+         'nonfouling': nonfouling,
+     }.items():
+         if not is_selected:
+             continue
+         # choose the variant that matches the current input type
+         keys = checkbox_to_keys.get(ui_name, [])
+         for key in keys:
+             if key in pred.model_configs:
+                 expected_input = pred.model_configs[key].get('input', 'sequence')
+                 if (input_type == 'SMILES' and expected_input == 'smiles') or \
+                    (input_type == 'Sequence' and expected_input == 'sequence'):
+                     if key in pred.models:
+                         selected_properties.append(key)
+
+     # Process sequences for regular properties
+     if selected_properties:
+         sequences = lines  # already parsed and validated above
+
+         for seq_idx, seq in enumerate(sequences):
+             progress((seq_idx + 1) / len(sequences), f"Processing sequence {seq_idx + 1}/{len(sequences)}")
+
+             for prop in selected_properties:
+                 config = pred.model_configs[prop]
+                 model = pred.models[prop]
+
+                 try:
+                     value = pred.predict_property(model, config, seq, input_type)
+
+                     stat_key = _base_stat_key(prop)
+                     threshold = pred.data_manager.statistics.get(stat_key, {}).get('threshold')
+                     if threshold is not None:
+                         # which direction?
+                         if stat_key in ['hemolysis']:  # lower is better
+                             label = config['positive_label'] if value < threshold else config['negative_label']
+                         else:  # higher is better by default for these properties
+                             label = config['positive_label'] if value > threshold else config['negative_label']
+                     else:
+                         label = ""
+
+                     # Clickable property name (not yet wired into the dataframe; see the JS stub below)
+                     prop_display = f'<a href="#" onclick="show_distribution(\'{prop}\', {value})">{config["display_name"]}</a>'
+
+                     results.append({
+                         'Sequence': seq[:30] + '...' if len(seq) > 30 else seq,
+                         'Property': config["display_name"],
+                         'Prediction': label,
+                         'Value': f"{value:.3f}",
+                         'Unit': config['unit']
+                     })
+                 except Exception as e:
+                     print(f"Error predicting {prop}: {e}")
+
+     # Handle binding affinity separately
+     if binding_affinity and input_text:
+         # Sequence–Sequence binding
+         if input_type == "Sequence":
+             if 'binding_affinity' in pred.models:
+                 progress(0.9, "Predicting binding affinity (sequence) ...")
+                 if not protein_seq:
+                     return None, "⚠️ Please provide a protein sequence for binding prediction."
+                 try:
+                     for binder_seq in lines:
+                         affinity, binding_class = pred.predict_binding(
+                             pred.models['binding_affinity'],
+                             protein_seq,
+                             binder_seq
+                         )
+                         results.append({
+                             'Sequence': f"Protein–{binder_seq[:20]}...",
+                             'Property': pred.model_configs['binding_affinity']['display_name'],
+                             'Prediction': binding_class,  # e.g., Tight/Medium/Weak
+                             'Value': f"{affinity:.3f}",
+                             'Unit': pred.model_configs['binding_affinity']['unit']
+                         })
+                 except Exception as e:
+                     print(f"Error in sequence binding prediction: {e}")
+
+         # Sequence + SMILES binding
+         else:  # input_type == "SMILES"
+             if 'binding_affinity_smiles' not in pred.models:
+                 return None, "⚠️ SMILES binding model not loaded. Please add the checkpoint to models/ and restart."
+             if not protein_seq:
+                 return None, "⚠️ For SMILES binding, please provide a protein sequence."
+             # quick AA check for protein_seq
+             if not is_aa_sequence_like(protein_seq):
+                 return None, "⚠️ The provided protein sequence does not look like an amino-acid sequence."
+             progress(0.9, "Predicting binding affinity (SMILES) ...")
+             try:
+                 for smi in lines:
+                     affinity, label = pred.predict_binding_smiles(
+                         pred.models['binding_affinity_smiles'],
+                         protein_seq,
+                         smi
+                     )
+                     results.append({
+                         'Sequence': f"Protein–{smi[:20]}...",
+                         'Property': pred.model_configs['binding_affinity_smiles']['display_name'],
+                         'Prediction': label,  # Tight (≥7.5) / Medium (6.0–7.5) / Weak (<6.0)
+                         'Value': f"{affinity:.3f}",
+                         'Unit': pred.model_configs['binding_affinity_smiles']['unit'],
+                     })
+             except Exception as e:
+                 print(f"Error in SMILES binding prediction: {e}")
+
+     if not results:
+         return None, "⚠️ Please select at least one property to predict"
+
+     # Create summary
+     n_sequences = len(set(r['Sequence'] for r in results))
+     n_properties = len(set(r['Property'] for r in results))
+
+     status = f"✅ Completed {len(results)} predictions ({n_sequences} sequence(s), {n_properties} properties)"
+     if binding_affinity:
+         status += " \n**Binding class cutoffs:** Tight ≥ 7.5, Medium 6.0–7.5, Weak < 6.0"
+
+     return pd.DataFrame(results), status
+
+
+ def show_distribution(property_name, predicted_value=None):
+     """Show distribution plot + info for selected property."""
+     pred = initialize()
+     if not property_name:
+         return None, "Select a property to view its distribution."
+
+     # Get the first property if a list was passed
+     prop = property_name[0] if isinstance(property_name, list) else property_name
+
+     # Generate the plot (works for both binary & continuous)
+     fig = pred.data_manager.get_distribution_plot(prop, predicted_value)
+
+     # Build info panel with correct fields per kind
+     stats = pred.data_manager.statistics.get(prop, {})
+     kind = stats.get("kind", "continuous")
+     info = pred.data_manager.get_property_info(prop)
+
+     if not info:
+         return fig, "No information available for this property."
+
+     title = prop.replace('_', ' ').title()
+
+     if kind == "binary":
+         n_pos = info.get("n_pos", int((stats.get("values") == 1).sum() if "values" in stats else 0))
+         n_neg = info.get("n_neg", int((stats.get("values") == 0).sum() if "values" in stats else 0))
+         total = max(n_pos + n_neg, 1)
+         info_text = f"""
+ ### {title} Information
+
+ **Description:** {info.get('description', '')}
+
+ **Statistics (Binary):**
+ - Samples: {info['n_samples']:,}
+ - Positives (1): {n_pos:,} ({n_pos/total:.1%})
+ - Negatives (0): {n_neg:,} ({n_neg/total:.1%})
+ """
+     else:
+         p = info.get("percentiles", {})
+         info_text = f"""
+ ### {title} Information
+
+ **Description:** {info.get('description', '')}
+
+ **Statistics:**
+ - Samples: {info['n_samples']:,}
+ - Mean: {info['mean']:.3f} {info['unit']}
+ - Std Dev: {info['std']:.3f}
+ - Range: [{info['min']:.3f}, {info['max']:.3f}]
+
+ **Percentiles:**
+ - 10%: {p.get('10%', float('nan')):.3f}
+ - 25%: {p.get('25%', float('nan')):.3f}
+ - 50% (median): {p.get('50% (median)', float('nan')):.3f}
+ - 75%: {p.get('75%', float('nan')):.3f}
+ - 90%: {p.get('90%', float('nan')):.3f}
+ """
+
+     return fig, info_text
+
+
+ def load_example(example_name):
+     """Load example sequences"""
+     examples = {
+         "T7": ("HAIYPRH", ""),
+         "Protein-Peptide": ("MVHLTPEEKSAVTALWGKVNVDEVGGEALGRLLVVYPWTQRFFESFGDLST", "GIVEQCCTSICSLYQLENYCN")
+     }
+
+     if example_name in examples:
+         if example_name == "Protein-Peptide":
+             return examples[example_name][1], examples[example_name][0]  # Binder, Protein
+         else:
+             return examples[example_name][0], ""
+     return "", ""
+
+
+ # ==================== Gradio App ====================
+
+ custom_css = """
+ .gradio-container {
+     font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+ }
+
+ .gr-button-primary {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+     border: none !important;
+     color: white !important;
+ }
+
+ .gr-button-primary:hover {
+     transform: translateY(-1px);
+     box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
+ }
+
+ h1 {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     font-size: 2.5em !important;
+     text-align: center;
+     margin-bottom: 10px !important;
+ }
+
+ table {
+     font-size: 14px !important;
+ }
+
+ .property-result:hover {
+     background: #f0f0f0;
+     cursor: pointer;
+ }
+ """
+
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="indigo")) as demo:
+
+     # Header
+     gr.Markdown(
+         """
+         # ☄️ PeptiVerse
+         ### Peptide Property Predictions
+         """
+     )
+
+     with gr.Tabs():
+         # Main Prediction Tab
+         with gr.TabItem("🔬 Predict"):
+             with gr.Row():
+                 # Input Section
+                 with gr.Column(scale=1):
+                     with gr.Group():
+                         gr.Markdown("### 📝 Input")
+
+                         input_type = gr.Radio(
+                             ["Sequence", "SMILES"],
+                             label="Input Type",
+                             value="Sequence"
+                         )
+
+                         input_text = gr.Textbox(
+                             label="Peptide Sequence(s) / Binder",
+                             placeholder="Enter amino acid sequence(s), one per line",
+                             lines=6
+                         )
+
+                         protein_seq = gr.Textbox(
+                             label="Protein Sequence (for binding prediction)",
+                             placeholder="Enter protein sequence for binding affinity prediction",
+                             lines=3,
+                             visible=False
+                         )
+
+                         gr.Markdown("**Examples:**")
+                         example_dropdown = gr.Dropdown(
+                             choices=["T7", "Protein-Peptide"],
+                             label="Load Example",
+                             interactive=True
+                         )
+
+                         file_input = gr.File(
+                             label="Or Upload File",
+                             file_types=[".txt", ".fasta", ".fa"],
+                             visible=False
+                         )
+
+                 # Property Selection
+                 with gr.Column(scale=1):
+                     with gr.Group():
+                         gr.Markdown("### ⚙️ Select Properties")
+
+                         with gr.Accordion("Sequence Properties", open=True):
+                             hemolysis = gr.Checkbox(label="🩸 Hemolysis ↓", value=True)
+                             solubility = gr.Checkbox(label="💧 Solubility ↑", value=True)
+                             permeability = gr.Checkbox(label="🪣 Permeability ↑", value=False)
+                             half_life = gr.Checkbox(label="⏱️ Half-life ↑", value=False)
+                             nonfouling = gr.Checkbox(label="👯 Non-Fouling ↑", value=False)
+                         with gr.Accordion("Binding Prediction", open=False):
+                             binding_affinity = gr.Checkbox(label="🔗 Binding Affinity ↑", value=False)
+                             gr.Markdown("*Requires protein sequence input*")
+
+         # Distribution Analysis Tab
+         with gr.TabItem("📊 Distributions"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     property_selector = gr.Dropdown(
+                         choices=["hemolysis", "solubility", "permeability", "half_life",
+                                  "nonfouling", "binding_affinity"],
+                         label="Select Property",
+                         value="hemolysis"
+                     )
+                     test_value = gr.Number(label="Test Value among Distribution", value=None)
+                     show_dist_btn = gr.Button("Show Distribution")
+
+                 with gr.Column(scale=2):
+                     dist_plot_tab = gr.Plot(label="Score Distribution")
+                     dist_info_tab = gr.Markdown()
+
+         # Data Documentation Tab
+         with gr.TabItem("📚 Documentation"):
+             file_path = "description.md"
+             markdown_content = "*Documentation file not found.*"  # fallback so gr.Markdown below never sees an undefined name
+             try:
+                 with open(file_path, "r", encoding="utf-8") as f:
+                     markdown_content = f.read()
+             except FileNotFoundError:
+                 print(f"Error: The file '{file_path}' was not found.")
+             except Exception as e:
+                 print(f"An error occurred: {e}")
+             gr.Markdown(markdown_content)
+
+     # Action Buttons
+     with gr.Row():
+         clear_btn = gr.Button("🗑️ Clear", variant="secondary")
+         predict_btn = gr.Button("🚀 Predict Properties", variant="primary", scale=2)
+
+     # Status
+     status_output = gr.Markdown("")
+
+     # Results Section
+     with gr.Group():
+         gr.Markdown("### 📊 Results")
+         gr.Markdown("*Click on property names to view distribution plots*")
+
+         results_df = gr.Dataframe(
+             headers=["Sequence", "Property", "Prediction", "Value", "Unit"],
+             datatype=["str", "str", "str", "str", "str"],
+             interactive=False
+         )
+
+     # Hidden components for distribution modal
+     with gr.Row(visible=False) as distribution_row:
+         with gr.Column():
+             selected_property = gr.Textbox(visible=False)
+             dist_plot_modal = gr.Plot()      # renamed to avoid clashing with the tab-level plot
+             dist_info_modal = gr.Markdown()  # renamed to avoid clashing with the tab-level info
+             close_btn = gr.Button("Close")
+
+     # Footer
+     gr.Markdown(
+         """
+         ---
+         <div style='text-align: center; color: #6b7280;'>
+             <p>Models: ESM2-650M embeddings + XGBoost/CNN classifiers</p>
+             <p style='font-size: 0.9em;'>Click on property names in results to view training data distributions</p>
+         </div>
+         """
+     )
+
+     # Event Handlers
+     def update_visibility(binding_checked):
+         return gr.update(visible=binding_checked)
+
+     binding_affinity.change(
+         update_visibility,
+         inputs=[binding_affinity],
+         outputs=[protein_seq]
+     )
+
+     example_dropdown.change(
+         load_example,
+         inputs=[example_dropdown],
+         outputs=[input_text, protein_seq]
+     )
+
+     predict_btn.click(
+         predict_properties,
+         inputs=[
+             input_text, input_type, protein_seq,
+             hemolysis, solubility, permeability,
+             half_life, nonfouling,
+             binding_affinity
+         ],
+         outputs=[results_df, status_output]
+     )
+
+     clear_btn.click(
+         lambda: ("", "", None, ""),
+         outputs=[input_text, protein_seq, results_df, status_output]
+     )
+
+     # Add JavaScript for clickable property names
+     demo.load(js="""
+     function show_distribution(property, value) {
+         // This would open a modal with the distribution
+         console.log('Show distribution for', property, 'with value', value);
+     }
+     """)
+     show_dist_btn.click(
+         show_distribution,
+         inputs=[property_selector, test_value],
+         outputs=[dist_plot_tab, dist_info_tab]
+     )
+
+ if __name__ == "__main__":
+     print("Initializing models...")
+     initialize()
+     print("Ready!")
+     demo.launch(share=True)
description.md ADDED
@@ -0,0 +1,70 @@
1
+ ## Data Sources and Methods
2
+
3
+ ### Training Data Collection
4
+
5
+ Our models are trained on curated datasets from multiple sources:
6
+
7
+ #### Hemolysis Dataset
8
+ - **Primary Source:** [peptideBERT](https://pubs.acs.org/doi/abs/10.1021/acs.jpclett.3c02398)
9
+ - **Secondary Source:** the Database of Antimicrobial Activity and Structure of Peptides (DBAASPv3)
10
+ - **Size:** 9,316 peptides, with 19.6% being positive (hemolytic) and 80.4% being negative (nonhemolytic)
11
+ - **Description:** Probability of peptide disrupting red blood cell membranes.
12
+ - **Download:** [hemolysis_training_data.csv](#)
13
+
14
+ #### Solubility Dataset
15
+ - **Primary Source:** [peptideBERT](https://pubs.acs.org/doi/abs/10.1021/acs.jpclett.3c02398)
16
+ - **Secondary Source:** PROSO-II
17
+ - **Size:** 18,453 sequences, with 47.6% being labeled as positives and 52.4% being labeled as negatives
18
+ - **Description:** Probability of peptide remaining dissolved in aqueous conditions.
19
+ - **Download:** [solubility_training_data.csv](#)
20
+
21
+ #### Non-Fouling Dataset
22
+ - **Primary Source:** [peptideBERT](https://pubs.acs.org/doi/abs/10.1021/acs.jpclett.3c02398)
23
+ - **Secondary Source:** [Classifying antimicrobial and multifunctional peptides with Bayesian network models](https://doi.org/10.1002/pep2.24079)
24
+ - **Size:** 3,600 positive, 13,585 negative
25
+ - **Description:** A nonfouling peptide resists nonspecific interactions and protein adsorption.
26
+ - **Download:** [solubility_training_data.csv](#)
27
+
28
+ #### Permeability Dataset
29
+ - **Primary Source:** [PepLand](https://arxiv.org/abs/2311.04419)
30
+ - **Secondary Source:** CycPeptMPDB
31
+ - **Size:** 1162 positive and negative for nanonical samples each (22 relevant cell-penetrating peptide databases by compiling literature on existing cell-penetrating peptide prediction models ); CycPeptMPDB provides extra 7334 cyclic peptides
32
+ - **Description:** Probability of peptide penetrating the cell membrane.
33
+ - **Download:** [binding_affinity_training_data.csv](#)
34
+
35
+ #### Half-life Dataset
36
+ - **Primary Source:** [Thpdb2](https://doi.org/10.1016/j.drudis.2024.104047), [PepTherDia](https://doi.org/10.1016/j.drudis.2021.02.019), [peplife](https://www.nature.com/articles/srep36617)
37
+ - **Size:** 105 wild-type peptides (275 when noncanonical peptides are included); human data only
38
+ - **Clean-ups:** All half-life values are transformed to log(hours); see the sketch below
39
+ - **Download:** [half_life_training_data.csv](#)
40
+
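+ 
+ As a reference for reproducing this step, here is a minimal sketch of the transform; the logarithm base and the `half_life_hours` column name are assumptions, not taken from the source notes:
+ 
+ ```python
+ import numpy as np
+ import pandas as pd
+ 
+ df = pd.read_csv("training_data/half_life_smiles.csv")  # file shipped in this repo via LFS
+ # Half-lives span several orders of magnitude, so models train on the log scale.
+ df["log_half_life"] = np.log10(df["half_life_hours"])  # base 10 assumed; column name illustrative
+ ```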
41
+
42
+ #### Binding Affinity Dataset
43
+ - **Primary Source:** [PepLand](https://arxiv.org/abs/2311.04419)
44
+ - **Size:** 1,781 protein-peptide complexes, covering canonical and non-canonical peptides
45
+ - **Description:** Probability of peptide-protein binding; affinity values come pre-normalized from PepLand, combining IC50/EC50 measurements.
46
+ - **Quality:** Binding class cutoffs on the normalized affinity scale: Tight ≥ 7.5, Medium 6.0–7.5, Weak < 6.0 (see the sketch below)
47
+ - **Download:** [binding_affinity_training_data.csv](#)
48
+
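+ 
+ The cutoffs above map directly to a labeling function. A minimal sketch, assuming the normalized affinity behaves like a pIC50/pEC50-style value (the source does not state the exact scale):
+ 
+ ```python
+ def binding_class(affinity: float) -> str:
+     """Assign a binding class from a PepLand-normalized affinity value."""
+     if affinity >= 7.5:   # Tight >= 7.5
+         return "Tight"
+     if affinity >= 6.0:   # Medium 6.0-7.5
+         return "Medium"
+     return "Weak"         # Weak < 6.0
+ 
+ assert binding_class(7.5) == "Tight" and binding_class(6.0) == "Medium" and binding_class(5.9) == "Weak"
+ ```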
49
+ ### Model Architecture
50
+
51
+ - **Sequence Embeddings:** ESM-2 650M parameter model
52
+ - **XGBoost Models:** Gradient boosting on pooled ESM embeddings (see the sketch after this list)
53
+ - **CNN Models:** 1D convolutional networks with attention mechanisms
54
+ - **Binding Model:** Cross-attention between protein and peptide representations
55
+
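+ 
+ A minimal sketch of the embedding-plus-XGBoost path, assuming the public `facebook/esm2_t33_650M_UR50D` checkpoint and mean pooling over residues; the app's actual pooling and hyperparameters may differ:
+ 
+ ```python
+ import torch
+ import xgboost as xgb
+ from transformers import EsmModel, EsmTokenizer
+ 
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
+ 
+ def embed(sequence: str) -> torch.Tensor:
+     """Mean-pool the last hidden state into one 1280-d vector per peptide."""
+     inputs = tokenizer(sequence, return_tensors="pt")
+     with torch.no_grad():
+         hidden = model(**inputs).last_hidden_state  # (1, seq_len, 1280)
+     return hidden.mean(dim=1).squeeze(0)
+ 
+ # A property head is then gradient boosting over the pooled vectors, e.g.:
+ # clf = xgb.XGBClassifier(n_estimators=500)
+ # clf.fit(train_embeddings, train_labels)
+ ```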
56
+ ### Citation
57
+
58
+ If you use this tool, please cite:
59
+ ```
60
+ @article{peptiprop2024,
61
+ title={PeptiProp: Unified Platform for Peptide Property Prediction},
62
+ author={Your Name et al.},
63
+ journal={Journal Name},
64
+ year={2024}
65
+ }
66
+ ```
67
+
68
+ ### Contact
69
+
70
+ For questions or collaborations: [contact@example.com](mailto:contact@example.com)
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ gradio>=4.0.0
2
+ pandas>=2.0.0
3
+ numpy>=1.24.0
4
+ plotly>=5.14.0
5
+ torch>=2.0.0
6
+ transformers==4.46.0
7
+ scikit-learn>=1.3.0
8
+ biopython>=1.81
9
+ rdkit>=2023.3.1
10
+ seaborn
11
+ SmilesPE
tokenizer/__init__.py ADDED
File without changes
tokenizer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (136 Bytes).
 
tokenizer/__pycache__/my_tokenizers.cpython-310.pyc ADDED
Binary file (16.2 kB). View file
 
tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,424 @@
1
+ import collections
2
+ import os
3
+ import re
4
+ from typing import List, Optional
5
+ from transformers import PreTrainedTokenizer
6
+ from SmilesPE.tokenizer import SPE_Tokenizer
7
+ import torch
8
+
9
+ def load_vocab(vocab_file):
10
+ """Loads a vocabulary file into a dictionary."""
11
+ vocab = collections.OrderedDict()
12
+ with open(vocab_file, "r", encoding="utf-8") as reader:
13
+ tokens = reader.readlines()
14
+ for index, token in enumerate(tokens):
15
+ token = token.rstrip("\n")
16
+ vocab[token] = index
17
+ return vocab
18
+
19
+ class Atomwise_Tokenizer(object):
20
+ """Run atom-level SMILES tokenization"""
21
+
22
+ def __init__(self):
23
+ """ Constructs a atom-level Tokenizer.
24
+ """
25
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
26
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
27
+
28
+ self.regex = re.compile(self.regex_pattern)
29
+
30
+ def tokenize(self, text):
31
+ """ Basic Tokenization of a SMILES.
32
+ """
33
+ tokens = [token for token in self.regex.findall(text)]
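+ # e.g. "CC(=O)O" -> ['C', 'C', '(=O)', 'O']; short parenthesized groups match as single tokens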
34
+ return tokens
35
+
36
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
37
+ r"""
38
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
39
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
40
+ should refer to the superclass for more information regarding methods.
41
+ Args:
42
+ vocab_file (:obj:`string`):
43
+ File containing the vocabulary.
44
+ spe_file (:obj:`string`):
45
+ File containing the trained SMILES Pair Encoding vocabulary.
46
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
50
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
51
+ for sequence classification or for a text and a question for question answering.
52
+ It is also used as the last token of a sequence built with special tokens.
53
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
54
+ The token used for padding, for example when batching sequences of different lengths.
55
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
56
+ The classifier token which is used when doing sequence classification (classification of the whole
57
+ sequence instead of per-token classification). It is the first token of the sequence when built with
58
+ special tokens.
59
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
60
+ The token used for masking values. This is the token used when training this model with masked language
61
+ modeling. This is the token which the model will try to predict.
62
+ """
63
+
64
+ def __init__(self, vocab_file, spe_file,
65
+ unk_token="[UNK]",
66
+ sep_token="[SEP]",
67
+ pad_token="[PAD]",
68
+ cls_token="[CLS]",
69
+ mask_token="[MASK]",
70
+ **kwargs):
71
+ if not os.path.isfile(vocab_file):
72
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
73
+ if not os.path.isfile(spe_file):
74
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
75
+
76
+ self.vocab = load_vocab(vocab_file)
77
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
78
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
79
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
80
+
81
+ super().__init__(
82
+ unk_token=unk_token,
83
+ sep_token=sep_token,
84
+ pad_token=pad_token,
85
+ cls_token=cls_token,
86
+ mask_token=mask_token,
87
+ **kwargs)
88
+
89
+ @property
90
+ def vocab_size(self):
91
+ return len(self.vocab)
92
+
93
+ def get_vocab(self):
94
+ return dict(self.vocab, **self.added_tokens_encoder)
95
+
96
+ def _tokenize(self, text):
97
+ return self.spe_tokenizer.tokenize(text).split(' ')
98
+
99
+ def _convert_token_to_id(self, token):
100
+ """ Converts a token (str) in an id using the vocab. """
101
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
102
+
103
+ # changed encode and decode functions
104
+ def encode(self, token_array):
105
+ token_ids = []
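+ # ids follow new_vocab.txt ordering: 0=[PAD], 1=[UNK], 2=[CLS], 3=[SEP], 4=[MASK]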
106
+ token_ids.append(2)
107
+ for token in token_array:
108
+ id = self._convert_token_to_id(token)
109
+ token_ids.append(id)
110
+ token_ids.append(3)
111
+ token_ids = torch.tensor([token_ids])
112
+ attn_mask = torch.ones_like(token_ids)
113
+ return {'input_ids': token_ids, 'attention_mask': attn_mask}
114
+
115
+ def decode(self, token_ids, skip_special_tokens=True):
116
+ token_ids = token_ids.squeeze(0).cpu().tolist()
117
+ token_array = []
118
+ for idx in token_ids:
119
+ if idx == 3: # Stop decoding when token ID 3 is encountered
120
+ break
121
+ if skip_special_tokens and idx in self.all_special_ids:
122
+ continue
123
+ token = self._convert_id_to_token(idx)
124
+ token_array.append(token)
125
+ sequence = "".join(token_array)
126
+ return sequence
127
+
128
+ def batch_decode(self, batch_token_ids, skip_special_tokens=True):
129
+ sequences = []
130
+ for token_ids in batch_token_ids:
131
+ sequences.append(self.decode(token_ids))
132
+ return sequences
133
+
134
+ def get_token_split(self, token_ids):
135
+ if isinstance(token_ids, torch.Tensor):
136
+ token_ids = token_ids.cpu().tolist()
137
+
138
+ token_array = []
139
+ for seq_ids in token_ids:
140
+ seq_array = []
141
+ for id in seq_ids:
142
+ token = self._convert_id_to_token(id)
143
+ seq_array.append(token)
144
+ token_array.append(seq_array)
145
+
146
+ return token_array
147
+
148
+ def _convert_id_to_token(self, index):
149
+ """Converts an index (integer) in a token (str) using the vocab."""
150
+ return self.ids_to_tokens.get(index, self.unk_token)
151
+
152
+ def convert_tokens_to_string(self, tokens):
153
+ """ Converts a sequence of tokens (string) in a single string. """
154
+ out_string = " ".join(tokens).replace(" ##", "").strip()
155
+ return out_string
156
+
157
+ def build_inputs_with_special_tokens(
158
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
159
+ ) -> List[int]:
160
+ """
161
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
162
+ by concatenating and adding special tokens.
163
+ A BERT sequence has the following format:
164
+ - single sequence: ``[CLS] X [SEP]``
165
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
166
+ Args:
167
+ token_ids_0 (:obj:`List[int]`):
168
+ List of IDs to which the special tokens will be added
169
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
170
+ Optional second list of IDs for sequence pairs.
171
+ Returns:
172
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
173
+ """
174
+ if token_ids_1 is None:
175
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
176
+ cls = [self.cls_token_id]
177
+ sep = [self.sep_token_id]
178
+ return cls + token_ids_0 + sep + token_ids_1 + sep
179
+
180
+ def get_special_tokens_mask(
181
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
182
+ ) -> List[int]:
183
+ """
184
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
185
+ special tokens using the tokenizer ``prepare_for_model`` method.
186
+ Args:
187
+ token_ids_0 (:obj:`List[int]`):
188
+ List of ids.
189
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
190
+ Optional second list of IDs for sequence pairs.
191
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
192
+ Set to True if the token list is already formatted with special tokens for the model
193
+ Returns:
194
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
195
+ """
196
+
197
+ if already_has_special_tokens:
198
+ if token_ids_1 is not None:
199
+ raise ValueError(
200
+ "You should not supply a second sequence if the provided sequence of "
201
+ "ids is already formated with special tokens for the model."
202
+ )
203
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
204
+
205
+ if token_ids_1 is not None:
206
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
207
+ return [1] + ([0] * len(token_ids_0)) + [1]
208
+
209
+ def create_token_type_ids_from_sequences(
210
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
211
+ ) -> List[int]:
212
+ """
213
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
214
+ A BERT sequence pair mask has the following format:
215
+ ::
216
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
217
+ | first sequence | second sequence |
218
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
219
+ Args:
220
+ token_ids_0 (:obj:`List[int]`):
221
+ List of ids.
222
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
223
+ Optional second list of IDs for sequence pairs.
224
+ Returns:
225
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
226
+ sequence(s).
227
+ """
228
+ sep = [self.sep_token_id]
229
+ cls = [self.cls_token_id]
230
+ if token_ids_1 is None:
231
+ return len(cls + token_ids_0 + sep) * [0]
232
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
233
+
234
+ def save_vocabulary(self, vocab_path):
235
+ """
236
+ Save the vocabulary (a copy of the original file, one token per line in id order) to the given path.
237
+ Args:
238
+ vocab_path (:obj:`str`):
239
+ The directory in which to save the vocabulary.
240
+ Returns:
241
+ :obj:`Tuple(str)`: Paths to the files saved.
242
+ """
243
+ index = 0
244
+ vocab_file = vocab_path
245
+ with open(vocab_file, "w", encoding="utf-8") as writer:
246
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
247
+ if index != token_index:
248
+ index = token_index
249
+ writer.write(token + "\n")
250
+ index += 1
251
+ return (vocab_file,)
252
+
253
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
254
+ r"""
255
+ Constructs an atom-level SMILES tokenizer that splits a SMILES string into per-atom and per-bond tokens.
256
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
257
+ should refer to the superclass for more information regarding methods.
258
+ Args:
259
+ vocab_file (:obj:`string`):
260
+ File containing the vocabulary.
261
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
262
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
263
+ token instead.
264
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
265
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
266
+ for sequence classification or for a text and a question for question answering.
267
+ It is also used as the last token of a sequence built with special tokens.
268
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
269
+ The token used for padding, for example when batching sequences of different lengths.
270
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
271
+ The classifier token which is used when doing sequence classification (classification of the whole
272
+ sequence instead of per-token classification). It is the first token of the sequence when built with
273
+ special tokens.
274
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
275
+ The token used for masking values. This is the token used when training this model with masked language
276
+ modeling. This is the token which the model will try to predict.
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ vocab_file,
282
+ unk_token="[UNK]",
283
+ sep_token="[SEP]",
284
+ pad_token="[PAD]",
285
+ cls_token="[CLS]",
286
+ mask_token="[MASK]",
287
+ **kwargs
288
+ ):
289
+ super().__init__(
290
+ unk_token=unk_token,
291
+ sep_token=sep_token,
292
+ pad_token=pad_token,
293
+ cls_token=cls_token,
294
+ mask_token=mask_token,
295
+ **kwargs,
296
+ )
297
+
298
+ if not os.path.isfile(vocab_file):
299
+ raise ValueError(
300
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
301
+ )
302
+ self.vocab = load_vocab(vocab_file)
303
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
304
+ self.tokenizer = Atomwise_Tokenizer()
305
+
306
+ @property
307
+ def vocab_size(self):
308
+ return len(self.vocab)
309
+
310
+ def get_vocab(self):
311
+ return dict(self.vocab, **self.added_tokens_encoder)
312
+
313
+
314
+ def _tokenize(self, text):
315
+ return self.tokenizer.tokenize(text)
316
+
317
+ def _convert_token_to_id(self, token):
318
+ """ Converts a token (str) in an id using the vocab. """
319
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
320
+
321
+ def _convert_id_to_token(self, index):
322
+ """Converts an index (integer) in a token (str) using the vocab."""
323
+ return self.ids_to_tokens.get(index, self.unk_token)
324
+
325
+ def convert_tokens_to_string(self, tokens):
326
+ """ Converts a sequence of tokens (string) in a single string. """
327
+ out_string = " ".join(tokens).replace(" ##", "").strip()
328
+ return out_string
329
+
330
+ def build_inputs_with_special_tokens(
331
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
332
+ ) -> List[int]:
333
+ """
334
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
335
+ by concatenating and adding special tokens.
336
+ A BERT sequence has the following format:
337
+ - single sequence: ``[CLS] X [SEP]``
338
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
339
+ Args:
340
+ token_ids_0 (:obj:`List[int]`):
341
+ List of IDs to which the special tokens will be added
342
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
343
+ Optional second list of IDs for sequence pairs.
344
+ Returns:
345
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
346
+ """
347
+ if token_ids_1 is None:
348
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
349
+ cls = [self.cls_token_id]
350
+ sep = [self.sep_token_id]
351
+ return cls + token_ids_0 + sep + token_ids_1 + sep
352
+
353
+ def get_special_tokens_mask(
354
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
355
+ ) -> List[int]:
356
+ """
357
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
358
+ special tokens using the tokenizer ``prepare_for_model`` method.
359
+ Args:
360
+ token_ids_0 (:obj:`List[int]`):
361
+ List of ids.
362
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
363
+ Optional second list of IDs for sequence pairs.
364
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
365
+ Set to True if the token list is already formatted with special tokens for the model
366
+ Returns:
367
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
368
+ """
369
+
370
+ if already_has_special_tokens:
371
+ if token_ids_1 is not None:
372
+ raise ValueError(
373
+ "You should not supply a second sequence if the provided sequence of "
374
+ "ids is already formated with special tokens for the model."
375
+ )
376
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
377
+
378
+ if token_ids_1 is not None:
379
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
380
+ return [1] + ([0] * len(token_ids_0)) + [1]
381
+
382
+ def create_token_type_ids_from_sequences(
383
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
384
+ ) -> List[int]:
385
+ """
386
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
387
+ A BERT sequence pair mask has the following format:
388
+ ::
389
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
390
+ | first sequence | second sequence |
391
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
392
+ Args:
393
+ token_ids_0 (:obj:`List[int]`):
394
+ List of ids.
395
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
396
+ Optional second list of IDs for sequence pairs.
397
+ Returns:
398
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
399
+ sequence(s).
400
+ """
401
+ sep = [self.sep_token_id]
402
+ cls = [self.cls_token_id]
403
+ if token_ids_1 is None:
404
+ return len(cls + token_ids_0 + sep) * [0]
405
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
406
+
407
+ def save_vocabulary(self, vocab_path):
408
+ """
409
+ Save the vocabulary (a copy of the original file, one token per line in id order) to the given path.
410
+ Args:
411
+ vocab_path (:obj:`str`):
412
+ The directory in which to save the vocabulary.
413
+ Returns:
414
+ :obj:`Tuple(str)`: Paths to the files saved.
415
+ """
416
+ index = 0
417
+ vocab_file = vocab_path
418
+ with open(vocab_file, "w", encoding="utf-8") as writer:
419
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
420
+ if index != token_index:
421
+ index = token_index
422
+ writer.write(token + "\n")
423
+ index += 1
424
+ return (vocab_file,)
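+ 
+ # Minimal usage sketch (paths assumed; new_vocab.txt and new_splits.txt live in this folder):
+ # tok = SMILES_SPE_Tokenizer("tokenizer/new_vocab.txt", "tokenizer/new_splits.txt")
+ # batch = tok.encode(tok._tokenize("CC(=O)Nc1ccccc1"))   # {'input_ids': tensor, 'attention_mask': tensor}
+ # smiles = tok.decode(batch["input_ids"])                # decodes back, stopping at the [SEP] id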
tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,587 @@
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ cccc
96
+ ccn
97
+ cco
98
+ ccs
99
+ cn
100
+ cnc
101
+ cnn
102
+ co
103
+ coc
104
+ cs
105
+ csc
106
+ csn
107
+ e
108
+ g
109
+ i
110
+ l
111
+ n
112
+ nc
113
+ ncc
114
+ ncn
115
+ nco
116
+ ncs
117
+ nn
118
+ nnc
119
+ nnn
120
+ no
121
+ noc
122
+ non
123
+ ns
124
+ nsc
125
+ nsn
126
+ o
127
+ oc
128
+ occ
129
+ on
130
+ p
131
+ r
132
+ s
133
+ sc
134
+ scc
135
+ scn
136
+ sn
137
+ t
138
+ c1
139
+ c2
140
+ c3
141
+ c4
142
+ c5
143
+ c6
144
+ c7
145
+ c8
146
+ c9
147
+ n1
148
+ n2
149
+ n3
150
+ n4
151
+ n5
152
+ n6
153
+ n7
154
+ n8
155
+ n9
156
+ O1
157
+ O2
158
+ O3
159
+ O4
160
+ O5
161
+ O6
162
+ O7
163
+ O8
164
+ O9
165
+ (c1
166
+ (c2
167
+ c1)
168
+ c2)
169
+ (n1
170
+ (n2
171
+ n1)
172
+ n2)
173
+ (O1
174
+ (O2
175
+ O2)
176
+ =O
177
+ =C
178
+ =c
179
+ =N
180
+ =n
181
+ =CC
182
+ =CN
183
+ =Cc
184
+ =cc
185
+ =NC
186
+ =Nc
187
+ =nC
188
+ =nc
189
+ #C
190
+ #CC
191
+ #CN
192
+ #N
193
+ #NC
194
+ #NN
195
+ (C
196
+ C)
197
+ (O
198
+ O)
199
+ (N
200
+ N)
201
+ NP
202
+ PN
203
+ CP
204
+ PC
205
+ NS
206
+ SN
207
+ SP
208
+ PS
209
+ C(=O)
210
+ (/Br)
211
+ (/C#N)
212
+ (/C)
213
+ (/C=N)
214
+ (/C=O)
215
+ (/CBr)
216
+ (/CC)
217
+ (/CCC)
218
+ (/CCF)
219
+ (/CCN)
220
+ (/CCO)
221
+ (/CCl)
222
+ (/CI)
223
+ (/CN)
224
+ (/CO)
225
+ (/CS)
226
+ (/Cl)
227
+ (/F)
228
+ (/I)
229
+ (/N)
230
+ (/NC)
231
+ (/NCC)
232
+ (/NO)
233
+ (/O)
234
+ (/OC)
235
+ (/OCC)
236
+ (/S)
237
+ (/SC)
238
+ (=C)
239
+ (=C/C)
240
+ (=C/F)
241
+ (=C/I)
242
+ (=C/N)
243
+ (=C/O)
244
+ (=CBr)
245
+ (=CC)
246
+ (=CCF)
247
+ (=CCN)
248
+ (=CCO)
249
+ (=CCl)
250
+ (=CF)
251
+ (=CI)
252
+ (=CN)
253
+ (=CO)
254
+ (=C\\C)
255
+ (=C\\F)
256
+ (=C\\I)
257
+ (=C\\N)
258
+ (=C\\O)
259
+ (=N)
260
+ (=N/C)
261
+ (=N/N)
262
+ (=N/O)
263
+ (=NBr)
264
+ (=NC)
265
+ (=NCC)
266
+ (=NCl)
267
+ (=NN)
268
+ (=NO)
269
+ (=NOC)
270
+ (=N\\C)
271
+ (=N\\N)
272
+ (=N\\O)
273
+ (=O)
274
+ (=S)
275
+ (B)
276
+ (Br)
277
+ (C#C)
278
+ (C#CC)
279
+ (C#CI)
280
+ (C#CO)
281
+ (C#N)
282
+ (C#SN)
283
+ (C)
284
+ (C=C)
285
+ (C=CF)
286
+ (C=CI)
287
+ (C=N)
288
+ (C=NN)
289
+ (C=NO)
290
+ (C=O)
291
+ (C=S)
292
+ (CBr)
293
+ (CC#C)
294
+ (CC#N)
295
+ (CC)
296
+ (CC=C)
297
+ (CC=O)
298
+ (CCBr)
299
+ (CCC)
300
+ (CCCC)
301
+ (CCCF)
302
+ (CCCI)
303
+ (CCCN)
304
+ (CCCO)
305
+ (CCCS)
306
+ (CCCl)
307
+ (CCF)
308
+ (CCI)
309
+ (CCN)
310
+ (CCNC)
311
+ (CCNN)
312
+ (CCNO)
313
+ (CCO)
314
+ (CCOC)
315
+ (CCON)
316
+ (CCS)
317
+ (CCSC)
318
+ (CCl)
319
+ (CF)
320
+ (CI)
321
+ (CN)
322
+ (CN=O)
323
+ (CNC)
324
+ (CNCC)
325
+ (CNCO)
326
+ (CNN)
327
+ (CNNC)
328
+ (CNO)
329
+ (CNOC)
330
+ (CO)
331
+ (COC)
332
+ (COCC)
333
+ (COCI)
334
+ (COCN)
335
+ (COCO)
336
+ (COF)
337
+ (CON)
338
+ (COO)
339
+ (CS)
340
+ (CSC)
341
+ (CSCC)
342
+ (CSCF)
343
+ (CSO)
344
+ (Cl)
345
+ (F)
346
+ (I)
347
+ (N)
348
+ (N=N)
349
+ (N=NO)
350
+ (N=O)
351
+ (N=S)
352
+ (NBr)
353
+ (NC#N)
354
+ (NC)
355
+ (NC=N)
356
+ (NC=O)
357
+ (NC=S)
358
+ (NCBr)
359
+ (NCC)
360
+ (NCCC)
361
+ (NCCF)
362
+ (NCCN)
363
+ (NCCO)
364
+ (NCCS)
365
+ (NCCl)
366
+ (NCNC)
367
+ (NCO)
368
+ (NCS)
369
+ (NCl)
370
+ (NN)
371
+ (NN=O)
372
+ (NNC)
373
+ (NO)
374
+ (NOC)
375
+ (O)
376
+ (OC#N)
377
+ (OC)
378
+ (OC=C)
379
+ (OC=O)
380
+ (OC=S)
381
+ (OCBr)
382
+ (OCC)
383
+ (OCCC)
384
+ (OCCF)
385
+ (OCCI)
386
+ (OCCN)
387
+ (OCCO)
388
+ (OCCS)
389
+ (OCCl)
390
+ (OCF)
391
+ (OCI)
392
+ (OCO)
393
+ (OCOC)
394
+ (OCON)
395
+ (OCSC)
396
+ (OCl)
397
+ (OI)
398
+ (ON)
399
+ (OO)
400
+ (OOC)
401
+ (OOCC)
402
+ (OOSN)
403
+ (OSC)
404
+ (P)
405
+ (S)
406
+ (SC#N)
407
+ (SC)
408
+ (SCC)
409
+ (SCCC)
410
+ (SCCF)
411
+ (SCCN)
412
+ (SCCO)
413
+ (SCCS)
414
+ (SCCl)
415
+ (SCF)
416
+ (SCN)
417
+ (SCOC)
418
+ (SCSC)
419
+ (SCl)
420
+ (SI)
421
+ (SN)
422
+ (SN=O)
423
+ (SO)
424
+ (SOC)
425
+ (SOOO)
426
+ (SS)
427
+ (SSC)
428
+ (SSCC)
429
+ ([At])
430
+ ([O-])
431
+ ([O])
432
+ ([S-])
433
+ (\\Br)
434
+ (\\C#N)
435
+ (\\C)
436
+ (\\C=N)
437
+ (\\C=O)
438
+ (\\CBr)
439
+ (\\CC)
440
+ (\\CCC)
441
+ (\\CCO)
442
+ (\\CCl)
443
+ (\\CF)
444
+ (\\CN)
445
+ (\\CNC)
446
+ (\\CO)
447
+ (\\COC)
448
+ (\\Cl)
449
+ (\\F)
450
+ (\\I)
451
+ (\\N)
452
+ (\\NC)
453
+ (\\NCC)
454
+ (\\NN)
455
+ (\\NO)
456
+ (\\NOC)
457
+ (\\O)
458
+ (\\OC)
459
+ (\\OCC)
460
+ (\\ON)
461
+ (\\S)
462
+ (\\SC)
463
+ (\\SCC)
464
+ [Ag+]
465
+ [Ag-4]
466
+ [Ag]
467
+ [Al-3]
468
+ [Al]
469
+ [As+]
470
+ [AsH3]
471
+ [AsH]
472
+ [As]
473
+ [At]
474
+ [B-]
475
+ [B@-]
476
+ [B@@-]
477
+ [BH-]
478
+ [BH2-]
479
+ [BH3-]
480
+ [B]
481
+ [Ba]
482
+ [Br+2]
483
+ [BrH]
484
+ [Br]
485
+ [C+]
486
+ [C-]
487
+ [C@@H]
488
+ [C@@]
489
+ [C@H]
490
+ [C@]
491
+ [CH-]
492
+ [CH2]
493
+ [CH3]
494
+ [CH]
495
+ [C]
496
+ [CaH2]
497
+ [Ca]
498
+ [Cl+2]
499
+ [Cl+3]
500
+ [Cl+]
501
+ [Cs]
502
+ [FH]
503
+ [F]
504
+ [H]
505
+ [He]
506
+ [I+2]
507
+ [I+3]
508
+ [I+]
509
+ [IH]
510
+ [I]
511
+ [K]
512
+ [Kr]
513
+ [Li+]
514
+ [LiH]
515
+ [MgH2]
516
+ [Mg]
517
+ [N+]
518
+ [N-]
519
+ [N@+]
520
+ [N@@+]
521
+ [N@@]
522
+ [N@]
523
+ [NH+]
524
+ [NH-]
525
+ [NH2+]
526
+ [NH3]
527
+ [NH]
528
+ [N]
529
+ [Na]
530
+ [O+]
531
+ [O-]
532
+ [OH+]
533
+ [OH2]
534
+ [OH]
535
+ [O]
536
+ [P+]
537
+ [P@+]
538
+ [P@@+]
539
+ [P@@]
540
+ [P@]
541
+ [PH2]
542
+ [PH]
543
+ [P]
544
+ [Ra]
545
+ [Rb]
546
+ [S+]
547
+ [S-]
548
+ [S@+]
549
+ [S@@+]
550
+ [S@@]
551
+ [S@]
552
+ [SH+]
553
+ [SH2]
554
+ [SH]
555
+ [S]
556
+ [Se+]
557
+ [Se-2]
558
+ [SeH2]
559
+ [SeH]
560
+ [Se]
561
+ [Si@]
562
+ [SiH2]
563
+ [SiH]
564
+ [Si]
565
+ [SrH2]
566
+ [TeH]
567
+ [Te]
568
+ [Xe]
569
+ [Zn+2]
570
+ [Zn-2]
571
+ [Zn]
572
+ [b-]
573
+ [c+]
574
+ [c-]
575
+ [cH-]
576
+ [cH]
577
+ [c]
578
+ [n+]
579
+ [n-]
580
+ [nH]
581
+ [n]
582
+ [o+]
583
+ [s+]
584
+ [se+]
585
+ [se]
586
+ [te+]
587
+ [te]
training_data/half_life_smiles.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d90293170442bc81af2cf9f64656c40bf884733947ca52b2f9308f42220680c3
3
+ size 174323
training_data/hemo-negative.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f83aad41f160deb6401bc0801bddc931488da6e1785749e6f72de6d0f154a37f
3
+ size 109451
training_data/hemo-positive.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96cb24d5a7617f7e211cd48d2b0b424a46affa95716b96058058902068068d27
3
+ size 27840
training_data/nf-negative.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e506e52e101308dd3882ca6bd45833a6e0837f9f240aa85d575c2a41e305b854
3
+ size 21845190
training_data/nf-positive.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78caae183fe840b145275d9486a3f94a963989deb9d55a57995653bf1d497bf2
3
+ size 41326
training_data/sol-negative.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3b6d380024e0483e15e3e219a7cbf23f4d178d823287cef24bc1bd918a817b6
3
+ size 15469064
training_data/sol-positive.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46169267fd0d37d8a063a4e9fc1cdd9b701a9211b1f16515e3d569fcf2d4d859
3
+ size 14056264