| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader, Dataset |
| | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score |
| |
|
| | from tqdm import tqdm |
| | from datetime import datetime |
| | import pandas as pd |
| | import numpy as np |
| | import pickle |
| | import os |
| |
|
| | |
# Project root on the training machine.
path = "/home/a03-sgoel/MDpLM"

# Run configuration. batch_size=1 presumably because per-sequence embeddings
# have variable length and no padding/collate function is defined — TODO confirm.
hyperparams = {
    "batch_size": 1,
    "learning_rate": 4e-5,
    "num_epochs": 5,
    "max_length": 2000,
    # NOTE(review): the doubled ".csv.csv" extension looks like a typo —
    # verify the actual filename on disk before changing it.
    "train_data": path + "/benchmarks/DeepLoc/cell_localization_train_val.csv.csv",
    "test_data" : path + "/benchmarks/DeepLoc/cell_localization_test.csv",
    # NOTE(review): left empty — dataset construction opens embeddings_pkl,
    # so this must be filled in before the script can run.
    "val_data": "",
    "embeddings_pkl": "",
}
| |
|
| | |
class LocalizationDataset(Dataset):
    """Pairs precomputed per-sequence embeddings with 8-way localization labels.

    Expects the CSV to have a 'Sequence' column at position 0 and the eight
    label columns at positions 1..8. ``embeddings_pkl`` is a pickled
    ``{sequence -> embedding}`` mapping produced offline.
    """

    def __init__(self, csv_file, embeddings_pkl, max_length=2000):
        self.data = pd.read_csv(csv_file)
        self.max_length = max_length  # NOTE: currently unused — no truncation is applied

        # Attach each row's precomputed embedding by exact sequence match.
        with open(embeddings_pkl, 'rb') as f:
            self.embeddings_dict = pickle.load(f)
        self.data['embedding'] = self.data['Sequence'].map(self.embeddings_dict)

        # The original assert compared the frame's length with its own column's
        # length (always true). What can actually go wrong is a sequence with
        # no entry in the pickle, which map() leaves as NaN — check for that.
        missing = int(self.data['embedding'].isna().sum())
        assert missing == 0, f"{missing} sequences have no embedding in {embeddings_pkl}"

        # Columns 1..8 hold the per-class indicators; collect them into one
        # list per row. (Bug fix: DataFrame has no `.value` — use `.values`.)
        self.data['label'] = self.data.iloc[:, 1:9].values.tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Embeddings stay float for the linear head; labels are integer indicators.
        embeddings = torch.tensor(self.data['embedding'][idx], dtype=torch.float)
        labels = torch.tensor(self.data['label'][idx], dtype=torch.long)
        return embeddings, labels
| |
|
| | |
class LocalizationPredictor(nn.Module):
    """Mean-pools token embeddings and applies a linear head over classes."""

    def __init__(self, input_dim, num_classes):
        super(LocalizationPredictor, self).__init__()
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, embeddings):
        # Pool over the sequence axis (second-to-last). The original used
        # dim=0, which collapses the *batch* axis when the DataLoader supplies
        # (B, L, D) input, yielding per-position logits of shape (L, C).
        # dim=-2 pools the length axis for batched (B, L, D) -> (B, C) and is
        # identical to the original behavior for unbatched (L, D) -> (C,).
        avg_embedding = torch.mean(embeddings, dim=-2)
        logits = self.classifier(avg_embedding)
        return logits
| |
|
| | |
def train(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean loss across batches."""
    model.train()
    running_loss = 0.0
    for batch_embeddings, batch_labels in tqdm(dataloader):
        batch_embeddings = batch_embeddings.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch_embeddings), batch_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(dataloader)
| |
|
| | |
def evaluate(model, dataloader, device):
    """Collect model outputs and ground-truth labels over a dataloader.

    Returns ``(preds, true_labels)``: two lists of numpy arrays, one entry per
    batch. NOTE(review): preds are raw logits (no sigmoid/softmax applied);
    calculate_metrics thresholds them directly.
    """
    model.eval()
    all_outputs, all_targets = [], []
    with torch.no_grad():
        for batch_embeddings, batch_labels in tqdm(dataloader):
            logits = model(batch_embeddings.to(device))
            all_outputs.append(logits.cpu().numpy())
            all_targets.append(batch_labels.cpu().numpy())
    return all_outputs, all_targets
| |
|
| | |
def calculate_metrics(preds, labels, threshold=0.5):
    """Binarize predictions at ``threshold`` and compute flat multi-label metrics.

    ``preds``/``labels`` are lists of per-batch arrays; every element is
    flattened into one long binary vector before scoring, and macro-averaged
    precision/recall/F1 are reported alongside plain accuracy.
    NOTE(review): preds appear to be raw logits, so ``> 0.5`` is not the same
    as sigmoid(p) > 0.5 — confirm the threshold is intended on logits.
    """
    binary_chunks = [(p > threshold).astype(int).flatten() for p in preds]
    label_chunks = [y.flatten() for y in labels]

    y_pred = np.array([c for chunk in binary_chunks for c in chunk])
    y_true = np.array([c for chunk in label_chunks for c in chunk])

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')

    return accuracy, precision, recall, f1
| |
|
| |
|
# Use a GPU when available; model and batches are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build train/test datasets from precomputed embeddings.
# NOTE(review): hyperparams["embeddings_pkl"] is an empty string in the config
# above — dataset construction opens it, so this fails until it is filled in.
train_dataset = LocalizationDataset(hyperparams["train_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])
test_dataset = LocalizationDataset(hyperparams["test_data"], hyperparams["embeddings_pkl"], max_length=hyperparams["max_length"])

train_dataloader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False)

# input_dim=1280 presumably matches the embedding width in the pickle — TODO
# confirm. Eight output classes for the DeepLoc localization labels.
model = LocalizationPredictor(input_dim=1280, num_classes=8).to(device)
optimizer = optim.Adam(model.parameters(), lr=hyperparams["learning_rate"])
# NOTE(review): the dataset yields 8-element integer label vectors, but
# CrossEntropyLoss expects class indices (or float probabilities) —
# BCEWithLogitsLoss may be what was intended; confirm against the labels.
criterion = nn.CrossEntropyLoss()

# Training loop: reports only training loss per epoch — there is no validation
# pass despite the "val_data" hyperparameter.
for epoch in range(hyperparams["num_epochs"]):
    train_loss = train(model, train_dataloader, optimizer, criterion, device)
    print(f"EPOCH {epoch+1}/{hyperparams['num_epochs']}")
    print(f"TRAIN LOSS: {train_loss:.4f}")
    print("\n")

# Final evaluation on the held-out test set.
print("Test set")
test_preds, test_labels = evaluate(model, test_dataloader, device)
test_metrics = calculate_metrics(test_preds, test_labels)
print("TEST METRICS:")
print(f"Accuracy: {test_metrics[0]:.4f}")
print(f"Precision: {test_metrics[1]:.4f}")
print(f"Recall: {test_metrics[2]:.4f}")
print(f"F1 Score: {test_metrics[3]:.4f}")