| |
| from batch_sampler import BatchSampler |
| from image_dataset import ImageDataset |
| from net import Net, ResNetModel, EfficientNetModel, EfficientNetModel_b7 |
| from train_test import train_model, test_model |
| from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay,classification_report |
| from visualise_performance_metrics import create_confusion_matrix, ROC_multiclass |
| from image_dataset_BINARY import ImageDatasetBINARY |
| from net_BINARY import Net_BINARY |
| |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torchsummary import summary |
|
|
| |
| import matplotlib.pyplot as plt |
| from matplotlib.pyplot import figure |
| import os |
| import argparse |
| import plotext |
| from datetime import datetime |
| from pathlib import Path |
| from typing import List |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve |
| import numpy as np |
|
|
def _select_device(model: nn.Module, debug: bool = False) -> str:
    """Pick the compute device, move *model* onto it and return the device name.

    Preference order is CUDA, then Apple MPS, then CPU. When ``debug`` is true
    the CPU is always used. A torchsummary layer summary is printed on the
    CUDA and CPU paths, using an input size of ``(1, 128, 128)``.
    """
    if torch.cuda.is_available() and not debug:
        print("@@@ CUDA device found, enabling CUDA training...")
        device = "cuda"
        model.to(device)
        summary(model, (1, 128, 128), device=device)
    elif torch.backends.mps.is_available() and not debug:
        print("@@@ Apple silicon device enabled, training with Metal backend...")
        device = "mps"
        model.to(device)
        # NOTE(review): no summary is printed on the MPS path, as in the
        # original code -- presumably torchsummary cannot target "mps"; confirm.
    else:
        print("@@@ No GPU boosting device found, training on CPU...")
        device = "cpu"
        summary(model, (1, 128, 128), device=device)
    return device


def _plot_roc_curves(fpr: dict, tpr: dict, auc: dict, class_names: List[str]) -> None:
    """Show one ROC curve per class, keyed by class index in fpr/tpr/auc."""
    plt.figure(figsize=(8, 6))

    # plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap offers the same (name, lut) call and still exists.
    colors = plt.get_cmap('viridis', len(class_names)).colors

    for i, (class_name, color) in enumerate(zip(class_names, colors)):
        plt.plot(
            fpr[i],
            tpr[i],
            color=color,
            lw=2,
            label='{} (AUC = {:.2f})'.format(class_name, auc[i]),
        )

    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves for 6 Classes')
    plt.legend(loc="lower right")
    plt.show()


def _save_loss_curves(
    train_losses: List[torch.Tensor],
    test_losses: List[torch.Tensor],
    out_path: Path,
) -> None:
    """Plot per-epoch mean train/test losses on two stacked axes and save them."""
    # Pass the size to subplots directly: a separate figure(...) call would
    # only create an empty extra figure and leave this one at default size.
    fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(9, 10), dpi=80)

    # Derive the x-axis from the recorded losses so x and y always have the
    # same length, even when no epoch was actually run.
    ax1.plot(
        range(1, 1 + len(train_losses)),
        [x.detach().cpu() for x in train_losses],
        label="Train",
        color="blue",
    )
    ax2.plot(
        range(1, 1 + len(test_losses)),
        [x.detach().cpu() for x in test_losses],
        label="Test",
        color="red",
    )
    fig.legend()

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)


def _predict_labels(model: nn.Module, dataset, device: str) -> list:
    """Run *model* in eval mode over *dataset* one sample at a time; return argmax labels."""
    model.eval()
    predicted_labels = []
    with torch.no_grad():
        for inputs, _ in dataset:
            # unsqueeze(0) adds a leading dimension so a single sample can be
            # fed to the model.
            outputs = model(inputs.unsqueeze(0).to(device))
            _, predicted = torch.max(outputs, 1)
            predicted_labels.extend(predicted.cpu().numpy())
    return predicted_labels


def main(args: argparse.Namespace, activeloop: bool = True) -> None:
    """Train an EfficientNetModel, evaluate it, and write weights and plots.

    Steps:
      1. Load the train/test ``ImageDataset`` objects from ``dc1/data/*.npy``.
      2. Build the model, an SGD optimizer and a cross-entropy loss.
      3. For every epoch (when ``activeloop`` is true) run ``train_model`` and
         ``test_model``, print the mean losses and draw them with plotext.
      4. Show ROC curves built from the fpr/tpr/auc dictionaries that were
         handed to ``test_model`` in the last epoch.
      5. Save the model state dict under ``model_weights/`` and the loss
         curves under ``artifacts/``.
      6. Print a confusion matrix and classification report for the test set.

    Args:
        args: Parsed CLI arguments; ``nb_epochs``, ``batch_size`` and
            ``balanced_batches`` are read.
        activeloop: When false, the per-epoch train/test work is skipped.
    """
    class_names = ['Class 0 (Atelactasis)','Class 1 (Effusion)', 'Class 2 (Infiltration)', 'Class 3 (No Finding)', 'Class 4 (Nodule)', 'Class 5 (Pneumonia)']
    n_classes = len(class_names)

    train_dataset = ImageDataset(Path('dc1/data/X_train.npy'), Path('dc1/data/Y_train.npy'))
    test_dataset = ImageDataset(Path('dc1/data/X_test.npy'), Path('dc1/data/Y_test.npy'))

    model = EfficientNetModel(n_classes=n_classes)

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    loss_function = nn.CrossEntropyLoss()

    n_epochs = args.nb_epochs
    batch_size = args.batch_size

    # Set to True to force CPU execution regardless of available accelerators.
    DEBUG = False
    device = _select_device(model, debug=DEBUG)

    train_sampler = BatchSampler(
        batch_size=batch_size, dataset=train_dataset, balanced=args.balanced_batches
    )
    test_sampler = BatchSampler(
        batch_size=100, dataset=test_dataset, balanced=args.balanced_batches
    )

    mean_losses_train: List[torch.Tensor] = []
    mean_losses_test: List[torch.Tensor] = []

    # Bound up-front so the ROC step below never sees undefined names when no
    # epoch is executed (n_epochs == 0 or activeloop is False).
    fpr: dict = {}
    tpr: dict = {}
    auc: dict = {}

    for e in range(n_epochs):
        if not activeloop:
            continue

        # Training
        losses = train_model(model, train_sampler, optimizer, loss_function, device)
        mean_loss = sum(losses) / len(losses)
        mean_losses_train.append(mean_loss)
        print(f"\nEpoch {e + 1} training done, loss on train set: {mean_loss}\n")

        # Testing; fresh containers per epoch are passed to test_model.
        # NOTE(review): test_model is assumed to fill fpr/tpr/auc in place,
        # keyed by class index -- confirm against train_test.py.
        fpr = {x: [] for x in range(n_classes)}
        tpr = {x: [] for x in range(n_classes)}
        auc = {}
        losses, _ = test_model(model, test_sampler, loss_function, device, fpr, tpr, auc)

        mean_loss = sum(losses) / len(losses)
        mean_losses_test.append(mean_loss)
        print(f"\nEpoch {e + 1} testing done, loss on test set: {mean_loss}\n")

        print(auc)

        # Live terminal plot of the losses so far.
        plotext.clf()
        plotext.scatter(mean_losses_train, label="train")
        plotext.scatter(mean_losses_test, label="test")
        plotext.title("Train and test loss")
        plotext.xticks([i for i in range(len(mean_losses_train) + 1)])
        plotext.show()

    # ROC curves are drawn once, from the data of the final epoch, and only
    # when an AUC value is available for every class.
    if all(i in auc for i in range(n_classes)):
        _plot_roc_curves(fpr, tpr, auc, class_names)

    # Timestamp used to label the artifacts of this session.
    now = datetime.now()
    stamp = f"{now.month:02}{now.day:02}{now.hour}_{now.minute:02}"

    # exist_ok avoids the check-then-create race of exists() + os.mkdir().
    weights_dir = Path("model_weights")
    weights_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), weights_dir / f"model_{stamp}.txt")

    # The loss figure is saved once, right after it is drawn.
    _save_loss_curves(
        mean_losses_train,
        mean_losses_test,
        Path("artifacts") / f"session_{stamp}.png",
    )

    # Final evaluation on the test set.
    true_labels = test_dataset.get_labels()
    predicted_labels = _predict_labels(model, test_dataset, device)

    print("Confusion Matrix:")
    print(confusion_matrix(true_labels, predicted_labels))

    create_confusion_matrix(true_labels, predicted_labels)

    print("\nClassification Report:")
    print(classification_report(true_labels, predicted_labels))
|
|
|
|
def str_to_bool(value: str) -> bool:
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including "False" -- becomes True. This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--nb_epochs", help="number of training iterations", default=1, type=int
    )
    parser.add_argument("--batch_size", help="batch_size", default=25, type=int)
    parser.add_argument(
        "--balanced_batches",
        help="whether to balance batches for class labels",
        default=True,
        type=str_to_bool,
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(args)