| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader, Dataset |
| | from transformers import AutoModel, AutoTokenizer |
| | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
| |
|
| | from tqdm import tqdm |
| | from datetime import datetime |
| | import pandas as pd |
| | import numpy as np |
| | import pickle |
| | import os |
| |
|
| | |
# Root directory of the project; all data and checkpoint paths below are built from it.
path = "/workspace/sg666/MDpLM"

# Run configuration. Entries are written to hyperparams.txt in this key order,
# so keep the order stable.
hyperparams = dict(
    batch_size=1,
    learning_rate=5e-4,
    num_epochs=5,
    esm_model_path="facebook/esm2_t33_650M_UR50D",
    mlm_model_path=f"{path}/benchmarks/MLM/model_ckpts/best_model_epoch",
    mdlm_model_path=(
        f"{path}/checkpoints/membrane_automodel/"
        "epochs30_lr3e-4_bsz16_gradclip1_beta-one0.9_beta-two0.999_bf16_all-params"
    ),
    train_data=f"{path}/benchmarks/Supervised/Localization/true_deeploc2.0_cell-local_train-val.csv",
    test_data=f"{path}/benchmarks/Supervised/Localization/true_deeploc2.0_cell-local_test.csv",
)
| |
|
| | |
def load_models(esm_model_path, mlm_model_path, mdlm_model_path, device=None):
    """Load the ESM tokenizer and the three backbone models onto a device.

    Args:
        esm_model_path: Model id or path used for both the tokenizer and the
            ESM model.
        mlm_model_path: Model id or path of the MLM backbone.
        mdlm_model_path: Model id or path of the MDLM backbone.
        device: Device the models are moved to. When omitted, CUDA is used if
            available and the CPU otherwise.

    Returns:
        Tuple ``(esm_tokenizer, esm_model, mlm_model, mdlm_model)``.
    """
    # Resolve the device here rather than reading a module-level ``device``,
    # which is only defined when this file is run as a script.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
    esm_model = AutoModel.from_pretrained(esm_model_path).to(device)
    mlm_model = AutoModel.from_pretrained(mlm_model_path).to(device)
    mdlm_model = AutoModel.from_pretrained(mdlm_model_path).to(device)
    return esm_tokenizer, esm_model, mlm_model, mdlm_model
| |
|
def get_latents(embedding_type, tokenizer, esm_model, mlm_model, mdlm_model, sequence, device):
    """Embed one sequence with the backbone selected by ``embedding_type``.

    Args:
        embedding_type: ``"esm"``, ``"mlm"`` or ``"mdlm"``; picks which of the
            three models produces the embeddings.
        tokenizer: Tokenizer applied to ``sequence`` for every embedding type.
        esm_model: Model used when ``embedding_type == "esm"``.
        mlm_model: Model used when ``embedding_type == "mlm"``.
        mdlm_model: Model used when ``embedding_type == "mdlm"``.
        sequence: The sequence string to embed.
        device: Device the tokenized inputs are moved to.

    Returns:
        The chosen model's ``last_hidden_state`` with the leading (batch)
        dimension squeezed out, computed under ``torch.no_grad()``.

    Raises:
        ValueError: If ``embedding_type`` is not one of the supported values.
    """
    if embedding_type == "esm":
        # ESM receives the full tokenizer output as keyword arguments.
        inputs = tokenizer(sequence, return_tensors='pt').to(device)
        with torch.no_grad():
            return esm_model(**inputs).last_hidden_state.squeeze(0)

    if embedding_type in ("mlm", "mdlm"):
        model = mlm_model if embedding_type == "mlm" else mdlm_model
        # The MLM and MDLM models are called positionally with token ids only.
        input_ids = tokenizer(sequence, return_tensors='pt')['input_ids'].to(device)
        with torch.no_grad():
            return model(input_ids).last_hidden_state.squeeze(0)

    # Previously an unknown type fell through to an UnboundLocalError.
    raise ValueError(
        f"Unknown embedding_type {embedding_type!r}; expected 'esm', 'mlm' or 'mdlm'"
    )
| |
|
| |
|
| | |
class LocalizationDataset(Dataset):
    """Sequences from a CSV paired with a binary 'Cell membrane' label.

    Only rows whose 'Sequence' is shorter than 1024 characters are kept.
    Embeddings are computed lazily in ``__getitem__`` by the backbone chosen
    with ``embedding_type`` (``"esm"``, ``"mlm"`` or ``"mdlm"``).
    """

    def __init__(self, embedding_type, csv_file, esm_model_path, mlm_model_path, mdlm_model_path, device):
        data = pd.read_csv(csv_file)
        # Drop sequences of 1024+ characters (presumably to respect the
        # backbones' input length limit -- confirm).
        self.data = data[data['Sequence'].apply(len) < 1024].reset_index(drop=True)
        self.embedding_type = embedding_type
        self.device = device

        tokenizer, esm_model, mlm_model, mdlm_model = load_models(
            esm_model_path, mlm_model_path, mdlm_model_path
        )
        self.tokenizer = tokenizer
        # Keep the backbones on this dataset's device, whichever device
        # load_models placed them on (no-op when they already match).
        self.esm_model = esm_model.to(device)
        self.mlm_model = mlm_model.to(device)
        self.mdlm_model = mdlm_model.to(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(embeddings, labels)`` for row ``idx``; ``labels`` has shape (1,)."""
        row = self.data.iloc[idx]
        # Keyword arguments: the former positional call swapped esm_model and
        # mlm_model, so "esm" runs were embedded by the MLM model and vice versa.
        embeddings = get_latents(
            self.embedding_type,
            self.tokenizer,
            esm_model=self.esm_model,
            mlm_model=self.mlm_model,
            mdlm_model=self.mdlm_model,
            sequence=row['Sequence'],
            device=self.device,
        )

        # Any non-zero 'Cell membrane' value counts as the positive class.
        label = 0 if row['Cell membrane'] == 0 else 1
        labels = torch.tensor([label], dtype=torch.float32)
        return embeddings, labels
| |
|
| | |
class LocalizationPredictor(nn.Module):
    """Two-layer MLP that maps per-token embeddings to one probability per sequence.

    Args:
        input_dim: Size of the last dimension of the incoming embeddings.
        hidden_dim: Width of the hidden layer (640 by default).
    """

    def __init__(self, input_dim, hidden_dim=640):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, embeddings):
        """Return per-sequence probabilities in (0, 1).

        Args:
            embeddings: Tensor whose last dimension is ``input_dim``. Dim 1 is
                averaged over -- assumed to be the token axis of a
                ``(batch, seq_len, input_dim)`` tensor; confirm for other callers.

        Returns:
            Probabilities of shape ``(batch, 1)`` for a 3-D input.
        """
        logits = self.classifier(embeddings)  # one logit per token
        logits = torch.mean(logits, dim=1)    # pool over dim 1
        # Sigmoid, not softmax: softmax over a size-1 dimension is always 1.0,
        # which froze the output and zeroed every gradient. nn.BCELoss expects
        # exactly this kind of independent probability.
        return torch.sigmoid(logits)
| |
|
| | |
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, scored by ``model`` and back-propagated
    through ``criterion``, then the optimizer takes one step.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    running = 0
    for batch_embeddings, batch_labels in tqdm(dataloader):
        batch_embeddings = batch_embeddings.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_embeddings), batch_labels)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    n_batches = len(dataloader)
    return running / n_batches
| |
|
| | |
def evaluate(model, dataloader, device):
    """Collect model outputs and labels over ``dataloader`` without gradients.

    Returns:
        ``(preds, true_labels)``: two lists holding one NumPy array per batch,
        model outputs and labels respectively.
    """
    model.eval()
    pred_batches = []
    label_batches = []
    with torch.no_grad():
        for emb, lab in tqdm(dataloader):
            emb = emb.to(device)
            lab = lab.to(device)
            pred_batches.append(model(emb).cpu().numpy())
            label_batches.append(lab.cpu().numpy())
    return pred_batches, label_batches
| |
|
| | |
def calculate_metrics(preds, labels, threshold=0.5):
    """Score thresholded predictions against labels over the whole set.

    Args:
        preds: Sequence of per-batch probability arrays (e.g. the first value
            returned by ``evaluate``).
        labels: Sequence of per-batch 0/1 label arrays aligned with ``preds``.
        threshold: Probabilities strictly greater than this are predicted as 1.

    Returns:
        ``np.ndarray`` ``[accuracy, precision, recall, f1_macro, f1_micro]``;
        precision, recall and ``f1_macro`` use macro averaging. The array is
        also printed.

    Raises:
        ValueError: If there are no predictions, or predictions and labels
            differ in total length (raised by scikit-learn).
    """
    pred_batches = [np.ravel(batch) for batch in preds]
    label_batches = [np.ravel(batch) for batch in labels]
    if not pred_batches or not label_batches:
        raise ValueError("calculate_metrics() needs at least one batch of predictions and labels")

    # Pool every batch and score once. Averaging per-batch scores is wrong for
    # macro metrics and, with batch_size=1, reduced each per-sample "macro"
    # score to 0 or 1.
    y_prob = np.concatenate(pred_batches)
    y_true = np.concatenate(label_batches).astype(int)
    y_pred = (y_prob > threshold).astype(int)

    # zero_division=0 keeps scikit-learn's default value without the warning.
    metrics = np.array([
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='macro', zero_division=0),
        recall_score(y_true, y_pred, average='macro', zero_division=0),
        f1_score(y_true, y_pred, average='macro', zero_division=0),
        f1_score(y_true, y_pred, average='micro', zero_division=0),
    ])
    print(metrics)
    return metrics
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Bound at module level on purpose: other code in this file may look up
    # the global name `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_epochs = hyperparams["num_epochs"]
    batch_size = hyperparams["batch_size"]
    backbone_paths = (
        hyperparams['esm_model_path'],
        hyperparams['mlm_model_path'],
        hyperparams['mdlm_model_path'],
    )
    # Console and results-file labels differ only in the final entry.
    console_names = ("Accuracy", "Precision", "Recall", "F1 Macro Score", "F1 Micro Score")
    file_names = ("Accuracy", "Precision", "Recall", "F1 Macro Score", "F1 Micro")

    for embedding_type in ('mdlm', 'esm', 'mlm'):
        # Train split first, then test split; each dataset loads its own backbones.
        train_set, test_set = (
            LocalizationDataset(embedding_type, hyperparams[split], *backbone_paths, device)
            for split in ('train_data', 'test_data')
        )
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

        # Head input width: 640 for the MDLM embeddings, 1280 for ESM and MLM.
        head = LocalizationPredictor(input_dim=640 if embedding_type == "mdlm" else 1280).to(device)
        optimizer = optim.Adam(head.parameters(), lr=hyperparams["learning_rate"])
        criterion = nn.BCELoss()

        run_tag = (
            f"batch_{hyperparams['batch_size']}"
            f"_lr_{hyperparams['learning_rate']}"
            f"_epochs_{hyperparams['num_epochs']}"
        )
        model_checkpoint_dir = os.path.join(
            f"{path}/benchmarks/Supervised/Localization/model_checkpoints/{embedding_type}",
            run_tag,
        )
        os.makedirs(model_checkpoint_dir, exist_ok=True)

        for epoch_no in range(1, num_epochs + 1):
            train_loss = train(head, train_loader, optimizer, criterion, device)
            print(f"EPOCH {epoch_no}/{num_epochs}")
            print(f"TRAIN LOSS: {train_loss:.4f}")
            print("\n")

            # One checkpoint per epoch: head weights, optimizer state and loss.
            checkpoint_path = os.path.join(model_checkpoint_dir, f"epoch{epoch_no}.pth")
            torch.save(
                {
                    'epoch': epoch_no,
                    'model_state_dict': head.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                },
                checkpoint_path,
            )
            print(f"Checkpoint saved at {checkpoint_path}\n")

            # Record the configuration once per run, after the first epoch.
            if epoch_no == 1:
                hyperparams_file = os.path.join(model_checkpoint_dir, "hyperparams.txt")
                with open(hyperparams_file, 'w') as fh:
                    fh.writelines(f"{key}: {value}\n" for key, value in hyperparams.items())
                print(f"Hyperparameters saved at {hyperparams_file}\n")

        # Evaluate the final head on the test split.
        print("Test set")
        test_preds, test_labels = evaluate(head, test_loader, device)
        test_metrics = calculate_metrics(test_preds, test_labels)
        print(test_metrics)
        print("TEST METRICS:")
        for idx, name in enumerate(console_names):
            print(f"{name}: {test_metrics[idx]:.4f}")

        test_results_file = os.path.join(model_checkpoint_dir, "test_results.txt")
        with open(test_results_file, 'w') as fh:
            fh.write("TEST METRICS:\n")
            for idx, name in enumerate(file_names):
                fh.write(f"{name}: {test_metrics[idx]:.4f}\n")
        print(f"Test results saved at {test_results_file}\n")