AlienChen committed on
Commit
f52b5d3
·
verified ·
1 Parent(s): e2e01f3

Update classifier_code/half_life.py

Browse files
Files changed (1) hide show
  1. classifier_code/half_life.py +205 -60
classifier_code/half_life.py CHANGED
@@ -1,65 +1,210 @@
 
 
 
1
  import numpy as np
2
  import torch
3
- import xgboost as xgb
4
- from transformers import EsmModel, EsmTokenizer
5
  import torch.nn as nn
6
- import pdb
 
7
 
8
- class PeptideCNN(nn.Module):
9
- def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
 
 
 
10
  super().__init__()
11
- self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
12
- self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
13
- self.fc = nn.Linear(hidden_dims[1], output_dim)
14
- self.dropout = nn.Dropout(dropout_rate)
15
- self.predictor = nn.Linear(output_dim, 1) # For regression/classification
16
-
17
- self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
18
- self.esm_model.eval()
19
-
20
- def forward(self, input_ids, attention_mask=None, return_features=False):
21
- with torch.no_grad():
22
- x = self.esm_model(input_ids, attention_mask).last_hidden_state
23
- # pdb.set_trace()
24
- # x shape: (B, L, input_dim)
25
- x = x.permute(0, 2, 1) # Reshape to (B, input_dim, L) for Conv1d
26
- x = nn.functional.relu(self.conv1(x))
27
- x = self.dropout(x)
28
- x = nn.functional.relu(self.conv2(x))
29
- x = self.dropout(x)
30
- x = x.permute(0, 2, 1) # Reshape back to (B, L, hidden_dims[1])
31
-
32
- # Global average pooling over the sequence dimension (L)
33
- x = x.mean(dim=1) # Shape: (B, hidden_dims[1])
34
-
35
- features = self.fc(x) # features shape: (B, output_dim)
36
- if return_features:
37
- return features
38
- return self.predictor(features) # Output shape: (B, 1)
39
-
40
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
-
42
- input_dim = 1280
43
- hidden_dims = [input_dim // 2, input_dim // 4]
44
- output_dim = input_dim // 8
45
- dropout_rate = 0.3
46
-
47
- nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
48
- nn_model.load_state_dict(torch.load('/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth'))
49
- nn_model.eval()
50
-
51
- def predict(inputs):
52
- with torch.no_grad():
53
- prediction = nn_model(**inputs, return_features=False)
54
-
55
- return prediction.item()
56
-
57
- if __name__ == '__main__':
58
- sequence = 'RGLSDGFLKLKMGISGSLGC'
59
-
60
- tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
61
- inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
62
-
63
- prediction = predict(inputs)
64
- print(prediction)
65
- print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional, Union
3
+
4
  import numpy as np
5
  import torch
 
 
6
  import torch.nn as nn
7
+ from transformers import EsmModel, AutoTokenizer
8
+
9
 
10
+ # -----------------------------
11
+ # Model definition (must match training)
12
+ # -----------------------------
13
class TransformerRegressor(nn.Module):
    """Transformer encoder over per-token embeddings with a masked-mean regression head.

    Pipeline: ``proj`` (Linear in_dim -> d_model) -> ``enc`` (batch-first
    TransformerEncoder with GELU activation) -> mean over the valid positions
    of each sequence -> ``head`` (Linear d_model -> 1).

    The attribute names ``proj``, ``enc`` and ``head`` determine the
    state_dict keys, so they must stay as they are for checkpoints to load.

    Args:
        in_dim: Size of each input embedding vector.
        d_model: Width of the encoder.
        nhead: Number of attention heads.
        layers: Number of stacked encoder layers.
        ff: Hidden size of the feed-forward sub-block of each layer.
        dropout: Dropout probability used inside the encoder layers.
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        """Predict one scalar per sequence.

        Args:
            X: Embeddings of shape (B, L, in_dim).
            M: Mask of shape (B, L); truthy = real token, falsy = padding.
                Integer masks are accepted and converted to bool.

        Returns:
            Tensor of shape (B,). A sequence without any valid token pools to
            a zero vector, so its output equals the bias of ``head``.
        """
        # `~` on an integer mask is a bitwise NOT, not a logical one.
        M = M.bool()

        if M.shape[1] == 0:
            # Nothing to encode: pool to zeros instead of feeding the encoder
            # a zero-length sequence.
            pooled = self.head.weight.new_zeros((X.shape[0], self.head.in_features))
            return self.head(pooled).squeeze(-1)

        # A row whose keys are ALL masked makes the attention softmax
        # degenerate (NaN). Such rows are left unmasked inside the encoder;
        # their output is discarded by the pooling mask below anyway.
        has_tokens = M.any(dim=1, keepdim=True)
        pad_mask = (~M) & has_tokens

        Z = self.proj(X)
        Z = self.enc(Z, src_key_padding_mask=pad_mask)

        # masked_fill (rather than multiplying by the mask) also clears any
        # non-finite values sitting at padded positions.
        Z = Z.masked_fill(~M.unsqueeze(-1), 0.0)
        denom = M.sum(dim=1, keepdim=True).clamp(min=1).to(Z.dtype)
        pooled = Z.sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)
37
+
38
+
39
def build_model(model_name: str, in_dim: int, params: dict) -> nn.Module:
    """Construct the regressor used at inference time.

    Only ``model_name == "transformer"`` is supported; anything else raises
    ``ValueError``. The architecture hyperparameters are fixed constants in
    this function.

    NOTE(review): ``params`` is accepted but never read here; the values
    below are presumably the tuned ones from training. Confirm before
    loading a checkpoint trained with different settings.
    """
    if model_name == "transformer":
        fixed_hparams = dict(
            d_model=384,
            nhead=4,
            layers=1,
            ff=512,
            dropout=0.1521676463658988,
        )
        return TransformerRegressor(in_dim=in_dim, **fixed_hparams)

    raise ValueError(f"This inference file currently supports model_name='transformer', got: {model_name}")
50
+
51
+
52
def _clean_state_dict(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` with wrapper prefixes removed from keys.

    For every key, a leading ``"module."`` is stripped first, then a leading
    ``"model."`` (each at most once, in that order). Values are passed
    through untouched. If two keys collapse to the same name, the later one
    wins.
    """
    stripped = {}
    for name, tensor in state_dict.items():
        for prefix in ("module.", "model."):
            if name.startswith(prefix):
                name = name[len(prefix):]
        stripped[name] = tensor
    return stripped
61
+
62
+
63
+ # -----------------------------
64
+ # Predictor
65
+ # -----------------------------
66
class HalflifeTransformer:
    """Half-life predictor: ESM per-residue embeddings fed to a transformer regressor.

    Loading a checkpoint builds the regressor via ``build_model`` and loads an
    ESM embedding model plus its tokenizer. Calling the instance on a list of
    sequences returns one prediction per sequence (see ``predict_hours``).

    Args:
        ckpt_path: Path to a ``torch.save``-d dict holding at least
            ``"state_dict"`` and ``"in_dim"``; ``"best_params"`` and
            ``"target_col"`` are optional.
        esm_name: Hugging Face identifier of the ESM embedding model.
        device: Torch device string; defaults to CUDA when available, else CPU.
        model_name: Regressor architecture name forwarded to ``build_model``.

    Raises:
        ValueError: If the checkpoint is not a dict with a ``"state_dict"``
            key, lacks ``"in_dim"``, or if ``in_dim`` differs from the hidden
            size of the ESM model.
    """

    def __init__(
        self,
        ckpt_path: str = "/scratch/pranamlab/tong/PeptiVerse/src/halflife/FINETUNED_TRANSFORMER_DIR/final_model.pt",
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        device: Optional[str] = None,
        model_name: str = "transformer",
    ):
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        if not isinstance(ckpt, dict) or "state_dict" not in ckpt:
            raise ValueError(f"Checkpoint at {ckpt_path} is not the expected dict with a 'state_dict' key.")
        if ckpt.get("in_dim") is None:
            # Without this check int(None) would raise an opaque TypeError.
            raise ValueError(f"Checkpoint at {ckpt_path} has no 'in_dim' entry; cannot build the regressor.")

        self.best_params = ckpt.get("best_params", {})
        self.in_dim = int(ckpt["in_dim"])
        self.target_col = ckpt.get("target_col", "label")  # 'log_label' or 'label'
        self.model_name = model_name

        # --- build + load regressor ---
        self.regressor = build_model(model_name=self.model_name, in_dim=self.in_dim, params=self.best_params)
        self.regressor.load_state_dict(_clean_state_dict(ckpt["state_dict"]), strict=True)
        self.regressor.to(self.device)
        self.regressor.eval()

        # --- ESM embedding model ---
        emb_model = EsmModel.from_pretrained(esm_name)

        # Sanity check before the (potentially expensive) device transfer:
        # the ESM hidden size must match the in_dim the regressor was built with.
        esm_hidden = int(emb_model.config.hidden_size)
        if esm_hidden != self.in_dim:
            raise ValueError(
                f"Mismatch: ESM hidden_size={esm_hidden}, but checkpoint in_dim={self.in_dim}.\n"
                f"Did you train on a different embedding model/dimension than {esm_name}?"
            )

        self.emb_model = emb_model.to(self.device)
        self.emb_model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(esm_name)

    @torch.no_grad()
    def _embed_unpooled_batch(
        self,
        sequences: List[str],
        max_length: int = 1024,
    ):
        """Embed a batch of sequences without pooling.

        The first and last attended token of every sequence (special tokens
        added by the tokenizer) are dropped; the remaining per-token vectors
        are right-padded with zeros to the longest sequence in the batch.

        Args:
            sequences: Sequences to embed.
            max_length: Tokenizer truncation length (special tokens included).

        Returns:
            X: (B, Lmax, H) float32
            M: (B, Lmax) bool, True for real residues, False for padding
        """
        if len(sequences) == 0:
            X = torch.zeros((0, 1, self.in_dim), dtype=torch.float32, device=self.device)
            M = torch.zeros((0, 1), dtype=torch.bool, device=self.device)
            return X, M

        toks = self.tokenizer(
            sequences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=True,
        )
        toks = {k: v.to(self.device) for k, v in toks.items()}

        hs = self.emb_model(**toks).last_hidden_state  # (B, T, H)
        attn = toks["attention_mask"].bool()  # (B, T)
        H = hs.shape[-1]

        per_seq = []
        for i in range(hs.shape[0]):
            valid_idx = torch.nonzero(attn[i], as_tuple=False).squeeze(-1)
            if valid_idx.numel() <= 2:
                # Only special tokens present: no residues to keep.
                emb = hs.new_zeros((0, H))
            else:
                # Drop the leading and trailing special token.
                emb = hs[i, valid_idx[1:-1], :]  # (L, H)
            per_seq.append(emb)

        Lmax = max((int(emb.shape[0]) for emb in per_seq), default=0)
        X = hs.new_zeros((len(sequences), Lmax, H), dtype=torch.float32)
        M = torch.zeros((len(sequences), Lmax), dtype=torch.bool, device=self.device)

        for i, emb in enumerate(per_seq):
            n_res = emb.shape[0]
            if n_res == 0:
                continue
            X[i, :n_res, :] = emb.to(torch.float32)
            M[i, :n_res] = True

        return X, M

    @torch.no_grad()
    def predict_raw(
        self,
        input_seqs: Union[str, List[str]],
        batch_size: int = 16,
    ) -> np.ndarray:
        """Return the regressor output, untransformed, one value per sequence.

        The values are in whatever space the regressor was trained on; per the
        checkpoint's ``target_col`` that is either ``log_label`` (inverted by
        ``predict_hours`` via ``expm1``) or ``label``.

        Args:
            input_seqs: Sequences to score. A single string is treated as one
                sequence rather than being sliced character-wise.
            batch_size: Number of sequences embedded and scored per step.

        Returns:
            float32 array of shape (N,); empty for empty input.
        """
        if isinstance(input_seqs, str):
            # Slicing a bare string below would hand substrings to the tokenizer.
            input_seqs = [input_seqs]
        if len(input_seqs) == 0:
            return np.array([], dtype=np.float32)

        preds = []
        for start in range(0, len(input_seqs), batch_size):
            batch = input_seqs[start : start + batch_size]
            X, M = self._embed_unpooled_batch(batch)
            yhat = self.regressor(X, M)  # (B,)
            preds.append(yhat.detach().cpu().numpy().astype(np.float32))

        return np.concatenate(preds, axis=0)

    def predict_hours(self, input_seqs: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Return predictions with the log transform undone where applicable.

        When the checkpoint's ``target_col`` is ``"log_label"`` the raw output
        is mapped through ``expm1``; otherwise the raw output is returned
        as-is. The result is always float32.
        """
        raw = self.predict_raw(input_seqs, batch_size=batch_size)
        if self.target_col == "log_label":
            return np.expm1(raw).astype(np.float32)
        return raw.astype(np.float32)

    def __call__(self, input_seqs: Union[str, List[str]], batch_size: int = 16) -> np.ndarray:
        """Alias for ``predict_hours``."""
        return self.predict_hours(input_seqs, batch_size=batch_size)
198
+
199
+
200
def unittest():
    """Smoke test: load the predictor from a relative checkpoint path and print one prediction.

    The checkpoint path is relative, so it is resolved against the current
    working directory of the process.
    """
    predictor = HalflifeTransformer(ckpt_path="../classifier_ckpt/wt_halflife.pt")
    hours = predictor(["MWQRPSSWIEGRFPHSDAVFTDQYTRLRKQLAAKKYLQSLKQKRY"])
    print("pred_hours:", hours)
207
+
208
+
209
# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    unittest()