| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, Dataset |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
|
|
| from tqdm import tqdm |
| from datetime import datetime |
| import pandas as pd |
| import numpy as np |
| import pickle |
| import os |
|
|
| |
# Project root; all data paths below are resolved relative to it.
path = "/home/a03-sgoel/MDpLM"


# Training configuration for the localization probe.
hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,  # max sequence length (stored by the dataset but never applied)
    # NOTE(review): double ".csv.csv" extension below — confirm it matches the
    # actual file name on disk before changing it.
    "train_data": path + "/benchmarks/DeepLoc/cell_localization_train_val.csv.csv",
    "test_data" : path + "/benchmarks/DeepLoc/cell_localization_test.csv",
    "val_data": "",  # NOTE(review): empty and unused — no validation split is wired up
    "embeddings_pkl": "",  # NOTE(review): empty — must point at a real pickle before running
}
|
|
| |
class LocalizationDataset(Dataset):
    """Dataset pairing per-sequence embeddings with multi-label targets.

    Rows come from a CSV with a 'Sequence' column plus eight label columns
    at positions 1-8 (assumed from the iloc slice below — confirm schema);
    embeddings come from a pickle mapping sequence string -> embedding array.
    """

    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        # NOTE(review): max_length is stored but never used to truncate/pad.
        self.max_length = max_length

        # Attach each row's embedding by looking its sequence up in the pickle.
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # Fail fast on sequences with no embedding: Series.map leaves NaN for
        # missing keys. (The original assert compared len(df) with
        # len(df['embedding']), which is always true and caught nothing.)
        missing = self.data['embedding'].isna().sum()
        assert missing == 0, f"{missing} sequences have no embedding in {embeddings_pkl}"

        # Collect the eight label columns into one list-valued column.
        # Fix: `.values`, not `.value` — the latter raised AttributeError.
        self.data['label'] = self.data.iloc[:, 1:9].values.tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Per-item tensors: float embedding matrix, long multi-hot label vector.
        embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
        labels = torch.tensor(self.data['label'][idx], dtype=torch.long)
        return embeddings, labels
|
|
| |
class LocalizationPredictor(nn.Module):
    """Linear probe over mean-pooled embeddings.

    Pools the input along its leading axis and maps the pooled vector
    through a single linear layer to per-class logits.
    NOTE(review): dim=0 averages the *first* axis — if inputs arrive
    batched as (batch, seq, dim) this pools over the batch; confirm the
    intended input layout.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, embeddings):
        # Average-pool over the leading axis, then score each class.
        pooled = embeddings.mean(dim=0)
        return self.classifier(pooled)
|
|
| |
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimization pass over `dataloader`.

    Performs a standard step per batch (zero grads, forward, loss,
    backward, step) and returns the mean per-batch loss for the epoch.
    """
    model.train()
    running_loss = 0.0
    for batch_embeddings, batch_labels in tqdm(dataloader):
        batch_embeddings = batch_embeddings.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch_embeddings), batch_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)
|
|
| |
def evaluate(model, dataloader, device):
    """Collect raw model outputs and ground-truth labels over a dataloader.

    Runs in eval mode with gradients disabled and returns two parallel
    lists of numpy arrays, one entry per batch: (outputs, labels).
    """
    model.eval()
    all_outputs, all_targets = [], []
    with torch.no_grad():
        for batch_embeddings, batch_labels in tqdm(dataloader):
            batch_labels = batch_labels.to(device)
            outputs = model(batch_embeddings.to(device))
            all_outputs.append(outputs.cpu().numpy())
            all_targets.append(batch_labels.cpu().numpy())
    return all_outputs, all_targets
|
|
| |
def calculate_metrics(preds, labels, threshold=0.5):
    """Flatten batched predictions/labels and compute macro-averaged metrics.

    Each entry of `preds` is binarized elementwise at `threshold`, every
    batch is flattened, and metrics are computed over individual label
    slots. Returns (accuracy, precision, recall, f1).

    NOTE(review): upstream `evaluate` yields raw logits, so thresholding
    at 0.5 assumes probability-scale scores — confirm intended scale.
    """
    y_pred, y_true = [], []
    for batch_scores, batch_labels in zip(preds, labels):
        y_pred.extend((batch_scores > threshold).astype(int).flatten())
        y_true.extend(batch_labels.flatten())

    y_pred = np.array(y_pred)
    y_true = np.array(y_true)

    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, average='macro'),
        recall_score(y_true, y_pred, average='macro'),
        f1_score(y_true, y_pred, average='macro'),
    )
|
|
|
|
# Select GPU when available; model and batches are moved there explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# NOTE(review): hyperparams["embeddings_pkl"] is "" — these constructors will
# fail to open the pickle until a real path is configured.
train_dataset = LocalizationDataset(hyperparams["train_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
test_dataset = LocalizationDataset(hyperparams["test_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])


train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)


# Linear probe over the embeddings; input_dim=1280 is assumed to match the
# stored embedding width — TODO confirm against the pickle contents.
model = LocalizationPredictor(input_dim=1280, num_classes=8).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
# NOTE(review): the dataset builds 8-wide indicator label vectors, but
# CrossEntropyLoss expects class indices (or class probabilities) —
# BCEWithLogitsLoss may be the intended criterion for multi-label; confirm.
criterion = nn.CrossEntropyLoss()


# Train for the configured number of epochs; no validation loop is run.
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")


# Final held-out evaluation; calculate_metrics binarizes raw outputs at 0.5.
print("Test set")
test_preds, test_labels = evaluate(model, test_dataloader, device)
test_metrics = calculate_metrics(test_preds, test_labels)
print("TEST METRICS:")
print(f"Accuracy: {test_metrics[0]:.4f}")
print(f"Precision: {test_metrics[1]:.4f}")
print(f"Recall: {test_metrics[2]:.4f}")
print(f"F1 Score: {test_metrics[3]:.4f}")