Spaces:
Sleeping
Sleeping
Commit ·
ac91785
1
Parent(s): d0d8b3d
Add application file
Browse files- app.py +42 -0
- docs/metrics.png +0 -0
- examples/CNV.jpg +0 -0
- examples/DME.jpg +0 -0
- examples/DRUSEN.jpg +0 -0
- examples/NORMAL.jpg +0 -0
- file.ipynb +0 -0
- models/model.pth +3 -0
- requirements.txt +4 -0
- src/__init__.py +1 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/__init__.cpython-39.pyc +0 -0
- src/__pycache__/data_setup.cpython-310.pyc +0 -0
- src/__pycache__/engine.cpython-310.pyc +0 -0
- src/__pycache__/helper_function.cpython-310.pyc +0 -0
- src/__pycache__/logger.cpython-310.pyc +0 -0
- src/__pycache__/logger.cpython-39.pyc +0 -0
- src/__pycache__/model.cpython-39.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/data_setup.py +62 -0
- src/engine.py +107 -0
- src/helper_function.py +249 -0
- src/logger.py +17 -0
- src/model.py +31 -0
- src/train.py +62 -0
- src/utils.py +28 -0
app.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from src.model import resnet_model
|
| 6 |
+
from timeit import default_timer as timer
|
| 7 |
+
from typing import Tuple,Dict
|
| 8 |
+
|
| 9 |
+
class_names = ["CNV","DME","DRUSEN","NORMAL"]
|
| 10 |
+
|
| 11 |
+
resnet, resnet_transforms = resnet_model(num_classes=4)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
state_dict = torch.load(f="models/model.pth", map_location=torch.device("cpu"))
|
| 15 |
+
resnet.load_state_dict(state_dict, strict=False)
|
| 16 |
+
|
| 17 |
+
def predict(img) -> Tuple[Dict,float]:
|
| 18 |
+
start_time = timer()
|
| 19 |
+
|
| 20 |
+
img = resnet_transforms(img).unsqueeze(0)
|
| 21 |
+
resnet.eval()
|
| 22 |
+
with torch.inference_mode():
|
| 23 |
+
pred_probs = torch.softmax(resnet(img),dim=1)
|
| 24 |
+
|
| 25 |
+
pred_label_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
|
| 26 |
+
pred_time = round(timer() - start_time,5)
|
| 27 |
+
return pred_label_and_probs, pred_time
|
| 28 |
+
|
| 29 |
+
example_paths = [["examples/" + example] for example in os.listdir("examples")]
|
| 30 |
+
|
| 31 |
+
title = "Retinal disease detection using Optical Tomography Images 👁️"
|
| 32 |
+
description = " This application uses Optical Coherence Tomography (OCT) images to assist in the identification of retinal conditions such as CNV, DME, DRUSEN, and NORMAL. The tool provides predictions based on the uploaded image and displays the processing time for the analysis. Please note that this tool is intended for educational and research purposes only. It is not a substitute for professional medical advice or diagnosis. For any medical concerns, please consult a healthcare professional."
|
| 33 |
+
|
| 34 |
+
gradio_interface = gr.Interface(fn=predict,
|
| 35 |
+
inputs=gr.Image(type="pil"),
|
| 36 |
+
outputs=[gr.Label(num_top_classes=4,label="Predictions"),
|
| 37 |
+
gr.Number(label="prediction time: ")],
|
| 38 |
+
title=title,
|
| 39 |
+
examples=example_paths,
|
| 40 |
+
description=description)
|
| 41 |
+
|
| 42 |
+
gradio_interface.launch()
|
docs/metrics.png
ADDED
|
examples/CNV.jpg
ADDED
|
examples/DME.jpg
ADDED
|
examples/DRUSEN.jpg
ADDED
|
examples/NORMAL.jpg
ADDED
|
file.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e011dfb570e65f9d2cb87c42b51b326e0093f2d9b034d7067b90ee34f46b9e3
|
| 3 |
+
size 94364842
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
loguru
|
| 4 |
+
gradio
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
4
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (133 Bytes). View file
|
|
|
src/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (131 Bytes). View file
|
|
|
src/__pycache__/data_setup.cpython-310.pyc
ADDED
|
Binary file (1.78 kB). View file
|
|
|
src/__pycache__/engine.cpython-310.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
src/__pycache__/helper_function.cpython-310.pyc
ADDED
|
Binary file (6.85 kB). View file
|
|
|
src/__pycache__/logger.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
src/__pycache__/logger.cpython-39.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
src/__pycache__/model.cpython-39.pyc
ADDED
|
Binary file (856 Bytes). View file
|
|
|
src/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (951 Bytes). View file
|
|
|
src/data_setup.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torchvision
|
| 3 |
+
|
| 4 |
+
from torchvision import datasets, transforms
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
|
| 7 |
+
NUM_WORKERS = os.cpu_count()
|
| 8 |
+
|
| 9 |
+
def create_dataloaders(
|
| 10 |
+
train_dir: str,
|
| 11 |
+
test_dir: str,
|
| 12 |
+
transform: transforms.Compose,
|
| 13 |
+
batch_size: int,
|
| 14 |
+
num_workers: int=NUM_WORKERS
|
| 15 |
+
):
|
| 16 |
+
"""Creates training and testing DataLoaders.
|
| 17 |
+
|
| 18 |
+
Takes in a training directory and testing directory path and turns
|
| 19 |
+
them into PyTorch Datasets and then into PyTorch DataLoaders.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
train_dir: Path to training directory.
|
| 23 |
+
test_dir: Path to testing directory.
|
| 24 |
+
transform: torchvision transforms to perform on training and testing data.
|
| 25 |
+
batch_size: Number of samples per batch in each of the DataLoaders.
|
| 26 |
+
num_workers: An integer for number of workers per DataLoader.
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
A tuple of (train_dataloader, test_dataloader, class_names).
|
| 30 |
+
Where class_names is a list of the target classes.
|
| 31 |
+
Example usage:
|
| 32 |
+
train_dataloader, test_dataloader, class_names = \
|
| 33 |
+
= create_dataloaders(train_dir=path/to/train_dir,
|
| 34 |
+
test_dir=path/to/test_dir,
|
| 35 |
+
transform=some_transform,
|
| 36 |
+
batch_size=32,
|
| 37 |
+
num_workers=4)
|
| 38 |
+
"""
|
| 39 |
+
# Use ImageFolder to create dataset(s)
|
| 40 |
+
train_data = datasets.ImageFolder(train_dir, transform=transform)
|
| 41 |
+
test_data = datasets.ImageFolder(test_dir, transform=transform)
|
| 42 |
+
|
| 43 |
+
# Get class names
|
| 44 |
+
class_names = train_data.classes
|
| 45 |
+
|
| 46 |
+
# Turn images into data loaders
|
| 47 |
+
train_dataloader = DataLoader(
|
| 48 |
+
train_data,
|
| 49 |
+
batch_size=batch_size,
|
| 50 |
+
shuffle=True,
|
| 51 |
+
num_workers=num_workers,
|
| 52 |
+
pin_memory=True,
|
| 53 |
+
)
|
| 54 |
+
test_dataloader = DataLoader(
|
| 55 |
+
test_data,
|
| 56 |
+
batch_size=batch_size,
|
| 57 |
+
shuffle=False,
|
| 58 |
+
num_workers=num_workers,
|
| 59 |
+
pin_memory=True,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
return train_dataloader, test_dataloader, class_names
|
src/engine.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import mlflow
|
| 3 |
+
from tqdm.auto import tqdm
|
| 4 |
+
from src.logger import global_logger as logger
|
| 5 |
+
from typing import Dict, List, Tuple
|
| 6 |
+
|
| 7 |
+
def train_step(model: torch.nn.Module,
|
| 8 |
+
dataloader: torch.utils.data.DataLoader,
|
| 9 |
+
loss_fn: torch.nn.Module,
|
| 10 |
+
optimizer: torch.optim.Optimizer,
|
| 11 |
+
device: torch.device) -> Tuple[float, float]:
|
| 12 |
+
"""Trains a PyTorch model for a single epoch."""
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
model.train()
|
| 16 |
+
train_loss, train_acc = 0, 0
|
| 17 |
+
|
| 18 |
+
for batch, (X, y) in enumerate(dataloader):
|
| 19 |
+
X, y = X.to(device), y.to(device)
|
| 20 |
+
y_pred = model(X)
|
| 21 |
+
loss = loss_fn(y_pred, y)
|
| 22 |
+
train_loss += loss.item()
|
| 23 |
+
optimizer.zero_grad()
|
| 24 |
+
loss.backward()
|
| 25 |
+
optimizer.step()
|
| 26 |
+
y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
|
| 27 |
+
train_acc += (y_pred_class == y).sum().item() / len(y_pred)
|
| 28 |
+
|
| 29 |
+
train_loss /= len(dataloader)
|
| 30 |
+
train_acc /= len(dataloader)
|
| 31 |
+
return train_loss, train_acc
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_step(model: torch.nn.Module,
|
| 35 |
+
dataloader: torch.utils.data.DataLoader,
|
| 36 |
+
loss_fn: torch.nn.Module,
|
| 37 |
+
device: torch.device) -> Tuple[float, float]:
|
| 38 |
+
"""Tests a PyTorch model for a single epoch."""
|
| 39 |
+
|
| 40 |
+
model.eval()
|
| 41 |
+
test_loss, test_acc = 0, 0
|
| 42 |
+
|
| 43 |
+
with torch.inference_mode():
|
| 44 |
+
for batch, (X, y) in enumerate(dataloader):
|
| 45 |
+
X, y = X.to(device), y.to(device)
|
| 46 |
+
test_pred_logits = model(X)
|
| 47 |
+
loss = loss_fn(test_pred_logits, y)
|
| 48 |
+
test_loss += loss.item()
|
| 49 |
+
test_pred_labels = test_pred_logits.argmax(dim=1)
|
| 50 |
+
test_acc += (test_pred_labels == y).sum().item() / len(test_pred_labels)
|
| 51 |
+
|
| 52 |
+
test_loss /= len(dataloader)
|
| 53 |
+
test_acc /= len(dataloader)
|
| 54 |
+
return test_loss, test_acc
|
| 55 |
+
|
| 56 |
+
def train(model: torch.nn.Module,
|
| 57 |
+
train_dataloader: torch.utils.data.DataLoader,
|
| 58 |
+
test_dataloader: torch.utils.data.DataLoader,
|
| 59 |
+
optimizer: torch.optim.Optimizer,
|
| 60 |
+
loss_fn: torch.nn.Module,
|
| 61 |
+
epochs: int,
|
| 62 |
+
device: torch.device) -> Dict[str, List[float]]:
|
| 63 |
+
"""Trains and tests a PyTorch model."""
|
| 64 |
+
|
| 65 |
+
results = {
|
| 66 |
+
"train_loss": [],
|
| 67 |
+
"train_acc": [],
|
| 68 |
+
"test_loss": [],
|
| 69 |
+
"test_acc": []
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
for epoch in tqdm(range(epochs)):
|
| 73 |
+
with mlflow.start_run() as run:
|
| 74 |
+
mlflow.log_param("epoch", epoch)
|
| 75 |
+
mlflow.log_param("optimizer", optimizer.__class__.__name__)
|
| 76 |
+
mlflow.log_param("loss_fn", loss_fn.__class__.__name__)
|
| 77 |
+
|
| 78 |
+
train_loss, train_acc = train_step(model=model,
|
| 79 |
+
dataloader=train_dataloader,
|
| 80 |
+
loss_fn=loss_fn,
|
| 81 |
+
optimizer=optimizer,
|
| 82 |
+
device=device)
|
| 83 |
+
test_loss, test_acc = test_step(model=model,
|
| 84 |
+
dataloader=test_dataloader,
|
| 85 |
+
loss_fn=loss_fn,
|
| 86 |
+
device=device)
|
| 87 |
+
|
| 88 |
+
print(
|
| 89 |
+
f"Epoch: {epoch+1} | "
|
| 90 |
+
f"train_loss: {train_loss:.3f} | "
|
| 91 |
+
f"train_acc: {train_acc:.3f} | "
|
| 92 |
+
f"test_loss: {test_loss:.3f} | "
|
| 93 |
+
f"test_acc: {test_acc:.3f}"
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
results["train_loss"].append(train_loss)
|
| 97 |
+
results["train_acc"].append(train_acc)
|
| 98 |
+
results["test_loss"].append(test_loss)
|
| 99 |
+
results["test_acc"].append(test_acc)
|
| 100 |
+
|
| 101 |
+
mlflow.log_metric("train_loss", train_loss, step=epoch)
|
| 102 |
+
mlflow.log_metric("train_acc", train_acc, step=epoch)
|
| 103 |
+
mlflow.log_metric("test_loss", test_loss, step=epoch)
|
| 104 |
+
mlflow.log_metric("test_acc", test_acc, step=epoch)
|
| 105 |
+
mlflow.pytorch.log_model(model, "model")
|
| 106 |
+
|
| 107 |
+
return results
|
src/helper_function.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A series of helper functions used throughout the course.
|
| 3 |
+
|
| 4 |
+
If a function gets defined once and could be used over and over, it'll go in here.
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import zipfile
|
| 14 |
+
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import requests
|
| 18 |
+
|
| 19 |
+
# Walk through an image classification directory and find out how many files (images)
|
| 20 |
+
# are in each subdirectory.
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
def walk_through_dir(dir_path):
|
| 24 |
+
"""
|
| 25 |
+
Walks through dir_path returning its contents.
|
| 26 |
+
Args:
|
| 27 |
+
dir_path (str): target directory
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
A print out of:
|
| 31 |
+
number of subdiretories in dir_path
|
| 32 |
+
number of images (files) in each subdirectory
|
| 33 |
+
name of each subdirectory
|
| 34 |
+
"""
|
| 35 |
+
for dirpath, dirnames, filenames in os.walk(dir_path):
|
| 36 |
+
print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
|
| 37 |
+
|
| 38 |
+
def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
|
| 39 |
+
"""Plots decision boundaries of model predicting on X in comparison to y.
|
| 40 |
+
|
| 41 |
+
Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)
|
| 42 |
+
"""
|
| 43 |
+
# Put everything to CPU (works better with NumPy + Matplotlib)
|
| 44 |
+
model.to("cpu")
|
| 45 |
+
X, y = X.to("cpu"), y.to("cpu")
|
| 46 |
+
|
| 47 |
+
# Setup prediction boundaries and grid
|
| 48 |
+
x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
|
| 49 |
+
y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
|
| 50 |
+
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))
|
| 51 |
+
|
| 52 |
+
# Make features
|
| 53 |
+
X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()
|
| 54 |
+
|
| 55 |
+
# Make predictions
|
| 56 |
+
model.eval()
|
| 57 |
+
with torch.inference_mode():
|
| 58 |
+
y_logits = model(X_to_pred_on)
|
| 59 |
+
|
| 60 |
+
# Test for multi-class or binary and adjust logits to prediction labels
|
| 61 |
+
if len(torch.unique(y)) > 2:
|
| 62 |
+
y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1) # mutli-class
|
| 63 |
+
else:
|
| 64 |
+
y_pred = torch.round(torch.sigmoid(y_logits)) # binary
|
| 65 |
+
|
| 66 |
+
# Reshape preds and plot
|
| 67 |
+
y_pred = y_pred.reshape(xx.shape).detach().numpy()
|
| 68 |
+
plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
|
| 69 |
+
plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
|
| 70 |
+
plt.xlim(xx.min(), xx.max())
|
| 71 |
+
plt.ylim(yy.min(), yy.max())
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Plot linear data or training and test and predictions (optional)
|
| 75 |
+
def plot_predictions(
|
| 76 |
+
train_data, train_labels, test_data, test_labels, predictions=None
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
Plots linear training data and test data and compares predictions.
|
| 80 |
+
"""
|
| 81 |
+
plt.figure(figsize=(10, 7))
|
| 82 |
+
|
| 83 |
+
# Plot training data in blue
|
| 84 |
+
plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
|
| 85 |
+
|
| 86 |
+
# Plot test data in green
|
| 87 |
+
plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
|
| 88 |
+
|
| 89 |
+
if predictions is not None:
|
| 90 |
+
# Plot the predictions in red (predictions were made on the test data)
|
| 91 |
+
plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
|
| 92 |
+
|
| 93 |
+
# Show the legend
|
| 94 |
+
plt.legend(prop={"size": 14})
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# Calculate accuracy (a classification metric)
|
| 98 |
+
def accuracy_fn(y_true, y_pred):
|
| 99 |
+
"""Calculates accuracy between truth labels and predictions.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
y_true (torch.Tensor): Truth labels for predictions.
|
| 103 |
+
y_pred (torch.Tensor): Predictions to be compared to predictions.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
[torch.float]: Accuracy value between y_true and y_pred, e.g. 78.45
|
| 107 |
+
"""
|
| 108 |
+
correct = torch.eq(y_true, y_pred).sum().item()
|
| 109 |
+
acc = (correct / len(y_pred)) * 100
|
| 110 |
+
return acc
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def print_train_time(start, end, device=None):
|
| 114 |
+
"""Prints difference between start and end time.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
start (float): Start time of computation (preferred in timeit format).
|
| 118 |
+
end (float): End time of computation.
|
| 119 |
+
device ([type], optional): Device that compute is running on. Defaults to None.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
float: time between start and end in seconds (higher is longer).
|
| 123 |
+
"""
|
| 124 |
+
total_time = end - start
|
| 125 |
+
print(f"\nTrain time on {device}: {total_time:.3f} seconds")
|
| 126 |
+
return total_time
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# Plot loss curves of a model
|
| 130 |
+
def plot_loss_curves(results):
|
| 131 |
+
"""Plots training curves of a results dictionary.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
results (dict): dictionary containing list of values, e.g.
|
| 135 |
+
{"train_loss": [...],
|
| 136 |
+
"train_acc": [...],
|
| 137 |
+
"test_loss": [...],
|
| 138 |
+
"test_acc": [...]}
|
| 139 |
+
"""
|
| 140 |
+
loss = results["train_loss"]
|
| 141 |
+
test_loss = results["test_loss"]
|
| 142 |
+
|
| 143 |
+
accuracy = results["train_acc"]
|
| 144 |
+
test_accuracy = results["test_acc"]
|
| 145 |
+
|
| 146 |
+
epochs = range(len(results["train_loss"]))
|
| 147 |
+
|
| 148 |
+
plt.figure(figsize=(15, 7))
|
| 149 |
+
|
| 150 |
+
# Plot loss
|
| 151 |
+
plt.subplot(1, 2, 1)
|
| 152 |
+
plt.plot(epochs, loss, label="train_loss")
|
| 153 |
+
plt.plot(epochs, test_loss, label="test_loss")
|
| 154 |
+
plt.title("Loss")
|
| 155 |
+
plt.xlabel("Epochs")
|
| 156 |
+
plt.legend()
|
| 157 |
+
|
| 158 |
+
# Plot accuracy
|
| 159 |
+
plt.subplot(1, 2, 2)
|
| 160 |
+
plt.plot(epochs, accuracy, label="train_accuracy")
|
| 161 |
+
plt.plot(epochs, test_accuracy, label="test_accuracy")
|
| 162 |
+
plt.title("Accuracy")
|
| 163 |
+
plt.xlabel("Epochs")
|
| 164 |
+
plt.legend()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Pred and plot image function from notebook 04
|
| 168 |
+
# See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
|
| 169 |
+
from typing import List
|
| 170 |
+
import torchvision
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def pred_and_plot_image(
|
| 174 |
+
model: torch.nn.Module,
|
| 175 |
+
image_path: str,
|
| 176 |
+
class_names: List[str] = None,
|
| 177 |
+
transform=None,
|
| 178 |
+
device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
| 179 |
+
):
|
| 180 |
+
"""Makes a prediction on a target image with a trained model and plots the image.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
model (torch.nn.Module): trained PyTorch image classification model.
|
| 184 |
+
image_path (str): filepath to target image.
|
| 185 |
+
class_names (List[str], optional): different class names for target image. Defaults to None.
|
| 186 |
+
transform (_type_, optional): transform of target image. Defaults to None.
|
| 187 |
+
device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Matplotlib plot of target image and model prediction as title.
|
| 191 |
+
|
| 192 |
+
Example usage:
|
| 193 |
+
pred_and_plot_image(model=model,
|
| 194 |
+
image="some_image.jpeg",
|
| 195 |
+
class_names=["class_1", "class_2", "class_3"],
|
| 196 |
+
transform=torchvision.transforms.ToTensor(),
|
| 197 |
+
device=device)
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
# 1. Load in image and convert the tensor values to float32
|
| 201 |
+
target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
|
| 202 |
+
|
| 203 |
+
# 2. Divide the image pixel values by 255 to get them between [0, 1]
|
| 204 |
+
target_image = target_image / 255.0
|
| 205 |
+
|
| 206 |
+
# 3. Transform if necessary
|
| 207 |
+
if transform:
|
| 208 |
+
target_image = transform(target_image)
|
| 209 |
+
|
| 210 |
+
# 4. Make sure the model is on the target device
|
| 211 |
+
model.to(device)
|
| 212 |
+
|
| 213 |
+
# 5. Turn on model evaluation mode and inference mode
|
| 214 |
+
model.eval()
|
| 215 |
+
with torch.inference_mode():
|
| 216 |
+
# Add an extra dimension to the image
|
| 217 |
+
target_image = target_image.unsqueeze(dim=0)
|
| 218 |
+
|
| 219 |
+
# Make a prediction on image with an extra dimension and send it to the target device
|
| 220 |
+
target_image_pred = model(target_image.to(device))
|
| 221 |
+
|
| 222 |
+
# 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
|
| 223 |
+
target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
|
| 224 |
+
|
| 225 |
+
# 7. Convert prediction probabilities -> prediction labels
|
| 226 |
+
target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
|
| 227 |
+
|
| 228 |
+
# 8. Plot the image alongside the prediction and prediction probability
|
| 229 |
+
plt.imshow(
|
| 230 |
+
target_image.squeeze().permute(1, 2, 0)
|
| 231 |
+
) # make sure it's the right size for matplotlib
|
| 232 |
+
if class_names:
|
| 233 |
+
title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
|
| 234 |
+
else:
|
| 235 |
+
title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
|
| 236 |
+
plt.title(title)
|
| 237 |
+
plt.axis(False)
|
| 238 |
+
|
| 239 |
+
def set_seeds(seed: int=42):
|
| 240 |
+
"""Sets random sets for torch operations.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
seed (int, optional): Random seed to set. Defaults to 42.
|
| 244 |
+
"""
|
| 245 |
+
# Set the seed for general torch operations
|
| 246 |
+
torch.manual_seed(seed)
|
| 247 |
+
# Set the seed for CUDA torch operations (ones that happen on the GPU)
|
| 248 |
+
torch.cuda.manual_seed(seed)
|
| 249 |
+
|
src/logger.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from loguru import logger
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
class LoguruLogger:
|
| 5 |
+
def __init__(self, log_folder="logs", log_file="training.log"):
|
| 6 |
+
os.makedirs(log_folder, exist_ok=True)
|
| 7 |
+
log_path = os.path.join(log_folder, log_file)
|
| 8 |
+
|
| 9 |
+
logger.remove() # Remove default logger
|
| 10 |
+
logger.add(log_path, level="INFO", format="{time} - {name} - {level} - {message}")
|
| 11 |
+
logger.add(lambda msg: print(msg, end=""), level="INFO", format="{time} - {name} - {level} - {message}")
|
| 12 |
+
|
| 13 |
+
def get_logger(self):
|
| 14 |
+
return logger
|
| 15 |
+
|
| 16 |
+
# Create a global logger instance
|
| 17 |
+
global_logger = LoguruLogger().get_logger()
|
src/model.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision
|
| 4 |
+
from src.logger import global_logger as logger
|
| 5 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
| 6 |
+
|
| 7 |
+
def resnet_model(num_classes: int = 4, seed: int = 42):
|
| 8 |
+
# Load pretrained ResNet18 model
|
| 9 |
+
weights = ResNet50_Weights.DEFAULT
|
| 10 |
+
model = resnet50(weights=weights)
|
| 11 |
+
|
| 12 |
+
# Freeze the parameters of the pretrained model
|
| 13 |
+
for param in model.parameters():
|
| 14 |
+
param.requires_grad = False
|
| 15 |
+
|
| 16 |
+
#logger.info("Model initialized with frozen ResNet18 backbone and new fully connected layers.")
|
| 17 |
+
|
| 18 |
+
# Replace the final fully connected layer with a new one
|
| 19 |
+
torch.manual_seed(seed)
|
| 20 |
+
model.fc = nn.Sequential(
|
| 21 |
+
nn.Dropout(p=0.3, inplace=True),
|
| 22 |
+
nn.Linear(in_features=model.fc.in_features, out_features=num_classes),
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Define the transforms using the predefined transforms from weights
|
| 26 |
+
transforms = weights.transforms()
|
| 27 |
+
|
| 28 |
+
return model, transforms
|
| 29 |
+
|
| 30 |
+
# Example usage
|
| 31 |
+
model, transforms = resnet_model(num_classes=4)
|
src/train.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import src.data_setup as data_setup
|
| 4 |
+
import src.engine as engine
|
| 5 |
+
import src.utils as utils
|
| 6 |
+
from src.logger import global_logger as logger
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
import src.model as model_module
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
|
| 12 |
+
NUM_EPOCHS = 20
|
| 13 |
+
BATCH_SIZE = 32
|
| 14 |
+
LEARNING_RATE = 0.001
|
| 15 |
+
|
| 16 |
+
train_dir = "data\\retinal_oct\\train"
|
| 17 |
+
test_dir = "data\\retinal_oct\\test"
|
| 18 |
+
|
| 19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
|
| 21 |
+
# Use the transformations required by ResNet50
|
| 22 |
+
data_transform = transforms.Compose([
|
| 23 |
+
transforms.Resize((224, 224)),
|
| 24 |
+
transforms.ToTensor(),
|
| 25 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 26 |
+
])
|
| 27 |
+
|
| 28 |
+
train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
|
| 29 |
+
train_dir=train_dir,
|
| 30 |
+
test_dir=test_dir,
|
| 31 |
+
transform=data_transform,
|
| 32 |
+
batch_size=BATCH_SIZE
|
| 33 |
+
)
|
| 34 |
+
logger.info("Data transformed successfully.")
|
| 35 |
+
|
| 36 |
+
# Initialize the ResNet50 model
|
| 37 |
+
model, _ = model_module.resnet_model(num_classes=len(class_names))
|
| 38 |
+
model = model.to(device)
|
| 39 |
+
|
| 40 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
| 41 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
| 42 |
+
|
| 43 |
+
engine.train(
|
| 44 |
+
model=model,
|
| 45 |
+
train_dataloader=train_dataloader,
|
| 46 |
+
test_dataloader=test_dataloader,
|
| 47 |
+
loss_fn=loss_fn,
|
| 48 |
+
optimizer=optimizer,
|
| 49 |
+
epochs=NUM_EPOCHS,
|
| 50 |
+
device=device
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
utils.save_model(
|
| 54 |
+
model=model,
|
| 55 |
+
target_dir="models",
|
| 56 |
+
model_name="model.pth"
|
| 57 |
+
)
|
| 58 |
+
logger.info("Model trained successfully.")
|
| 59 |
+
logger.info("Model saved to models folder.")
|
| 60 |
+
|
| 61 |
+
if __name__ == '__main__':
|
| 62 |
+
main()
|
src/utils.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
def save_model(model: torch.nn.Module,
|
| 5 |
+
target_dir: str,
|
| 6 |
+
model_name: str):
|
| 7 |
+
"""Saves a PyTorch model to a target directory.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
model: A target PyTorch model to save.
|
| 11 |
+
target_dir: A directory for saving the model to.
|
| 12 |
+
model_name: A filename for the saved model. Should include
|
| 13 |
+
either ".pth" or ".pt" as the file extension.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
# Create target directory
|
| 17 |
+
target_dir_path = Path(target_dir)
|
| 18 |
+
target_dir_path.mkdir(parents=True,
|
| 19 |
+
exist_ok=True)
|
| 20 |
+
|
| 21 |
+
# Create model save path
|
| 22 |
+
assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
|
| 23 |
+
model_save_path = target_dir_path / model_name
|
| 24 |
+
|
| 25 |
+
# Save the model state_dict()
|
| 26 |
+
print(f"[INFO] Saving model to: {model_save_path}")
|
| 27 |
+
torch.save(obj=model.state_dict(),
|
| 28 |
+
f=model_save_path)
|