| import numpy as np |
| import torch |
| import xgboost as xgb |
| from transformers import EsmModel, EsmTokenizer |
| import torch.nn as nn |
| import pdb |
|
|
class PeptideCNN(nn.Module):
    """Regressor on top of a frozen ESM-2 backbone.

    Pipeline: ESM-2 last hidden state -> Conv1d -> ReLU -> dropout ->
    Conv1d -> ReLU -> dropout -> mean over positions -> ``fc`` (features)
    -> ``predictor`` (single output).

    Args:
        input_dim: Channel count fed to ``conv1``; must equal the hidden size
            of the ESM backbone's ``last_hidden_state``.
        hidden_dims: Output channels of ``conv1`` and ``conv2`` (only the
            first two entries are used).
        output_dim: Size of the feature vector produced by ``fc``.
        dropout_rate: Dropout probability applied after each conv layer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dims[0], kernel_size=3, padding=1)
        # NOTE(review): kernel_size=5 with padding=1 shortens the sequence by
        # 2 positions (not "same" padding). Left unchanged because altering it
        # changes the outputs of already-trained weights — confirm intent.
        self.conv2 = nn.Conv1d(hidden_dims[0], hidden_dims[1], kernel_size=5, padding=1)
        self.fc = nn.Linear(hidden_dims[1], output_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.predictor = nn.Linear(output_dim, 1)

        # Pretrained ESM-2 backbone used only as a feature extractor (it runs
        # under no_grad in forward). It is deliberately NOT moved to a device
        # here: every submodule stays on the default device, and callers move
        # the whole model with `.to(device)`. Moving only the backbone would
        # leave it on a different device than the conv/linear layers.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.eval()

    def forward(self, input_ids, attention_mask=None, return_features=False):
        """Score tokenized sequences.

        Args:
            input_ids: Token ids from the ESM tokenizer; assumed
                ``(batch, seq_len)`` — TODO confirm against callers.
            attention_mask: Optional mask, passed through to the ESM backbone
                only.
            return_features: If True, return the ``fc`` output instead of the
                final prediction.

        Returns:
            The ``fc`` features (last dim ``output_dim``) when
            ``return_features`` is True; otherwise the ``predictor`` output
            (last dim 1).
        """
        # Backbone is frozen: no autograd graph is built through ESM.
        with torch.no_grad():
            x = self.esm_model(input_ids, attention_mask).last_hidden_state

        # Conv1d consumes (batch, channels, length): move the embedding axis
        # into the channel slot, convolve, then move it back.
        x = x.permute(0, 2, 1)
        x = self.dropout(nn.functional.relu(self.conv1(x)))
        x = self.dropout(nn.functional.relu(self.conv2(x)))
        x = x.permute(0, 2, 1)

        # Average over positions. NOTE(review): attention_mask is not applied
        # here, so padding positions in a padded batch are included in the mean.
        x = x.mean(dim=1)

        features = self.fc(x)
        if return_features:
            return features
        return self.predictor(features)
|
|
# Use the CUDA device when one is visible to torch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Architecture hyperparameters for PeptideCNN. input_dim must match the hidden
# size of the ESM-2 backbone (esm2_t33_650M_UR50D); the conv widths shrink it by
# 2x and 4x, and the fc feature size by 8x.
input_dim = 1280
hidden_dims = [input_dim // factor for factor in (2, 4)]
output_dim = input_dim // 8
dropout_rate = 0.3
|
|
# Build the model on the target device and restore the trained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU run still loads on a CPU-only machine (and vice versa).
# load_state_dict is strict by default: the checkpoint must contain keys for
# every submodule, including esm_model.
nn_model = PeptideCNN(input_dim, hidden_dims, output_dim, dropout_rate).to(device)
nn_model.load_state_dict(
    torch.load(
        '/scratch/pranamlab/tong/checkpoints/MOG-DFM/classifier_ckpt/best_model_half_life.pth',
        map_location=device,
    )
)
# Inference mode: disables the dropout layers.
nn_model.eval()
|
|
def predict(inputs, model=None):
    """Return the model's scalar prediction for one tokenized sequence.

    Args:
        inputs: Mapping of keyword arguments for the model call, e.g. the
            ``input_ids`` / ``attention_mask`` returned by the ESM tokenizer.
        model: Model to evaluate. Defaults to the module-level ``nn_model``;
            pass another instance (e.g. a different checkpoint) to reuse this
            helper.

    Returns:
        The prediction as a Python float.

    Raises:
        RuntimeError: from ``Tensor.item()`` if the model output has more than
            one element (e.g. a batch of several sequences).
    """
    if model is None:
        model = nn_model
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = model(**inputs, return_features=False)
    return prediction.item()
|
|
if __name__ == '__main__':
    # Demo: score one hard-coded peptide with the globally loaded model.
    sequence = 'RGLSDGFLKLKMGISGSLGC'

    esm_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    tokenizer_kwargs = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    encoded = esm_tokenizer(sequence, **tokenizer_kwargs)
    encoded = encoded.to(device)

    prediction = predict(encoded)
    print(prediction)
    # NOTE(review): the raw output is exponentiated base 10 and reported in
    # hours, i.e. it is treated as log10(half-life) — confirm against the
    # training labels.
    print(f"Predicted half life of {sequence} is {(10**prediction):.4f} h")
|
|