AlienChen committed on
Commit
e2e01f3
·
verified ·
1 Parent(s): 4a4d8d9

Update classifier_code/nonfouling_wt.py

Browse files
Files changed (1) hide show
  1. classifier_code/nonfouling_wt.py +65 -85
classifier_code/nonfouling_wt.py CHANGED
@@ -1,98 +1,78 @@
1
- import sys
2
- import os
3
  import xgboost as xgb
4
  import torch
5
- import numpy as np
6
- import warnings
7
- import numpy as np
8
- from rdkit import Chem, rdBase, DataStructs
9
- from transformers import AutoTokenizer, EsmModel
10
 
11
- rdBase.DisableLog('rdApp.error')
12
- warnings.filterwarnings("ignore", category=DeprecationWarning)
13
- warnings.filterwarnings("ignore", category=UserWarning)
14
- warnings.filterwarnings("ignore", category=FutureWarning)
 
 
 
15
 
16
- class Nonfouling:
17
- def __init__(self):
18
- # change model path
19
- self.predictor = xgb.Booster(model_file='../classifier_ckpt/best_model_nonfouling.json')
20
-
21
- # Load ESM model and tokenizer
22
- self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
23
- self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
24
- self.model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- def generate_embeddings(self, sequences):
27
  """Generate ESM embeddings for protein sequences"""
28
- embeddings = []
29
-
30
- # Process sequences in batches to avoid memory issues
31
- batch_size = 8
32
- for i in range(0, len(sequences), batch_size):
33
- batch_sequences = sequences[i:i + batch_size]
34
-
35
- inputs = self.tokenizer(
36
- batch_sequences,
37
- padding=True,
38
- truncation=True,
39
- return_tensors="pt"
40
- )
41
-
42
- if torch.cuda.is_available():
43
- inputs = {k: v.cuda() for k, v in inputs.items()}
44
- self.model = self.model.cuda()
45
-
46
- # Generate embeddings
47
- with torch.no_grad():
48
- outputs = self.model(**inputs)
49
-
50
- # Get last hidden states
51
- last_hidden_states = outputs.last_hidden_state
52
-
53
- # Compute mean pooling (excluding padding tokens)
54
- attention_mask = inputs['attention_mask'].unsqueeze(-1)
55
- masked_hidden_states = last_hidden_states * attention_mask
56
- sum_hidden_states = masked_hidden_states.sum(dim=1)
57
- seq_lengths = attention_mask.sum(dim=1)
58
- batch_embeddings = sum_hidden_states / seq_lengths
59
-
60
- batch_embeddings = batch_embeddings.cpu().numpy()
61
- embeddings.append(batch_embeddings)
62
-
63
- if embeddings:
64
- return np.vstack(embeddings)
65
- else:
66
- return np.array([])
67
-
68
- def get_scores(self, input_seqs: list):
69
- scores = np.zeros(len(input_seqs))
70
- features = self.generate_embeddings(input_seqs)
71
 
72
- if len(features) == 0:
73
- return scores
74
-
75
- features = np.nan_to_num(features, nan=0.)
76
- features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)
77
-
78
- features = xgb.DMatrix(features)
79
-
80
- scores = self.predictor.predict(features)
81
- return scores
82
 
83
- def __call__(self, input_seqs: list):
84
- scores = self.get_scores(input_seqs)
85
- return scores
 
 
 
 
86
 
 
 
 
 
 
87
  def unittest():
88
- nonfouling = Nonfouling()
89
- sequences = [
90
- "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
91
- "MSEGIRQAFVLAKSIWPARVARFTVDNRIRSLVKTYEAIKVDPYNPAFLEVLD"
92
- ]
93
-
94
- scores = nonfouling(input_seqs=sequences)
95
  print(scores)
96
-
 
97
  if __name__ == '__main__':
98
  unittest()
 
1
+ import numpy as np
 
2
  import xgboost as xgb
3
  import torch
4
+ from transformers import EsmModel, AutoTokenizer
5
+ import torch.nn as nn
6
+ import pdb
 
 
7
 
8
+ # ======================== MLP =========================================
9
+ # Still need mean pooling along lengths
10
class MaskedMeanPool(nn.Module):
    """Average token embeddings over the sequence axis, counting only unmasked positions."""

    def forward(self, X, M):
        # X: (B, L, H) hidden states; M: (B, L) 0/1 validity mask.
        weights = M.float().unsqueeze(-1)                  # (B, L, 1)
        total = torch.sum(X * weights, dim=1)              # (B, H)
        # Clamp the divisor so a fully-masked row yields zeros instead of NaN.
        count = torch.clamp(weights.sum(dim=1), min=1.0)   # (B, 1)
        return total / count
15
 
16
class MLPClassifier(nn.Module):
    """Binary classifier head: masked mean pooling followed by a two-layer MLP emitting one logit."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        # Collapse variable-length token embeddings to one vector per sequence.
        self.pool = MaskedMeanPool()
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X, M):
        # X: (B, L, H) embeddings; M: (B, L) attention mask.
        pooled = self.pool(X, M)
        logits = self.net(pooled)
        return logits.squeeze(-1)  # (B,) raw logits — no sigmoid applied here
29
+ # ======================== MLP =========================================
30
+
31
+
32
class NonfoulingModel:
    """Scores protein sequences for nonfouling propensity.

    Pipeline: ESM-2 token embeddings -> masked mean pooling with special
    tokens excluded -> MLP head -> sigmoid probability in [0, 1].
    """

    def __init__(self, device):
        # Restore the tuned MLP head from the training checkpoint.
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from a trusted source.
        ckpt = torch.load('../classifier_ckpt/wt_nonfouling.pt', weights_only=False, map_location=device)
        best_params = ckpt["best_params"]
        self.predictor = MLPClassifier(
            in_dim=1280,  # hidden size of esm2_t33_650M_UR50D
            hidden=int(best_params["hidden"]),
            dropout=float(best_params.get("dropout", 0.1)),
        )
        self.predictor.load_state_dict(ckpt["state_dict"])
        self.predictor = self.predictor.to(device)
        self.predictor.eval()

        # Frozen ESM-2 encoder used purely as a feature extractor.
        self.model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)
        # Fix: eval() was commented out, which left dropout active and made
        # embeddings nondeterministic at inference time.
        self.model.eval()

        self.device = device

    def generate_embeddings(self, input_ids, attention_mask):
        """Return ESM last-hidden-state embeddings, shape (B, L, 1280)."""
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

    def get_scores(self, input_ids, attention_mask):
        """Return raw classifier logits as a numpy array of shape (B,)."""
        features = self.generate_embeddings(input_ids, attention_mask)

        # Exclude special tokens from pooling (ids 0/1/2 — presumably
        # <cls>/<pad>/<eos> in the ESM-2 vocabulary).
        # Fix: combine the masks out-of-place instead of mutating the
        # caller's attention_mask tensor with `attention_mask[keep==False]=0`.
        keep = (input_ids != 0) & (input_ids != 1) & (input_ids != 2)
        pool_mask = attention_mask * keep.to(attention_mask.dtype)

        # Fix: no gradients are needed for scoring, so skip graph building.
        with torch.no_grad():
            scores = self.predictor(features, pool_mask)
        return scores.detach().cpu().numpy()

    def __call__(self, input_ids, attention_mask):
        """Return nonfouling probabilities (sigmoid of the logits)."""
        scores = self.get_scores(input_ids, attention_mask)
        return 1.0 / (1.0 + np.exp(-scores))
64
+
65
+
66
def unittest():
    """Smoke test: score two example peptide sequences end-to-end."""
    # Fix: fall back to CPU so the smoke test also runs on machines
    # without a GPU instead of crashing on the hard-coded 'cuda:0'.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    nf = NonfoulingModel(device=device)
    seq = ["HAIYPRH", "HAEGTFTSDVSSYLEGQAAKEFIAWLVKGR"]

    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
    seq_tokens = tokenizer(seq, padding=True, return_tensors='pt').to(device)
    scores = nf(**seq_tokens)
    print(scores)
75
+
76
+
77
# Run the end-to-end smoke test when executed as a script.
if __name__ == '__main__':
    unittest()