sreedeepEK commited on
Commit
ac91785
·
1 Parent(s): d0d8b3d

Add application file

Browse files
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+
5
+ from src.model import resnet_model
6
+ from timeit import default_timer as timer
7
+ from typing import Tuple,Dict
8
+
9
+ class_names = ["CNV","DME","DRUSEN","NORMAL"]
10
+
11
+ resnet, resnet_transforms = resnet_model(num_classes=4)
12
+
13
+
14
+ state_dict = torch.load(f="models/model.pth", map_location=torch.device("cpu"))
15
+ resnet.load_state_dict(state_dict, strict=False)
16
+
17
+ def predict(img) -> Tuple[Dict,float]:
18
+ start_time = timer()
19
+
20
+ img = resnet_transforms(img).unsqueeze(0)
21
+ resnet.eval()
22
+ with torch.inference_mode():
23
+ pred_probs = torch.softmax(resnet(img),dim=1)
24
+
25
+ pred_label_and_probs = {class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))}
26
+ pred_time = round(timer() - start_time,5)
27
+ return pred_label_and_probs, pred_time
28
+
29
+ example_paths = [["examples/" + example] for example in os.listdir("examples")]
30
+
31
+ title = "Retinal disease detection using Optical Tomography Images 👁️"
32
+ description = " This application uses Optical Coherence Tomography (OCT) images to assist in the identification of retinal conditions such as CNV, DME, DRUSEN, and NORMAL. The tool provides predictions based on the uploaded image and displays the processing time for the analysis. Please note that this tool is intended for educational and research purposes only. It is not a substitute for professional medical advice or diagnosis. For any medical concerns, please consult a healthcare professional."
33
+
34
+ gradio_interface = gr.Interface(fn=predict,
35
+ inputs=gr.Image(type="pil"),
36
+ outputs=[gr.Label(num_top_classes=4,label="Predictions"),
37
+ gr.Number(label="prediction time: ")],
38
+ title=title,
39
+ examples=example_paths,
40
+ description=description)
41
+
42
+ gradio_interface.launch()
docs/metrics.png ADDED
examples/CNV.jpg ADDED
examples/DME.jpg ADDED
examples/DRUSEN.jpg ADDED
examples/NORMAL.jpg ADDED
file.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e011dfb570e65f9d2cb87c42b51b326e0093f2d9b034d7067b90ee34f46b9e3
3
+ size 94364842
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ loguru
4
+ gradio
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ 4
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (133 Bytes). View file
 
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (131 Bytes). View file
 
src/__pycache__/data_setup.cpython-310.pyc ADDED
Binary file (1.78 kB). View file
 
src/__pycache__/engine.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
src/__pycache__/helper_function.cpython-310.pyc ADDED
Binary file (6.85 kB). View file
 
src/__pycache__/logger.cpython-310.pyc ADDED
Binary file (1.03 kB). View file
 
src/__pycache__/logger.cpython-39.pyc ADDED
Binary file (1.03 kB). View file
 
src/__pycache__/model.cpython-39.pyc ADDED
Binary file (856 Bytes). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (951 Bytes). View file
 
src/data_setup.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torchvision
3
+
4
+ from torchvision import datasets, transforms
5
+ from torch.utils.data import DataLoader
6
+
7
+ NUM_WORKERS = os.cpu_count()
8
+
9
+ def create_dataloaders(
10
+ train_dir: str,
11
+ test_dir: str,
12
+ transform: transforms.Compose,
13
+ batch_size: int,
14
+ num_workers: int=NUM_WORKERS
15
+ ):
16
+ """Creates training and testing DataLoaders.
17
+
18
+ Takes in a training directory and testing directory path and turns
19
+ them into PyTorch Datasets and then into PyTorch DataLoaders.
20
+
21
+ Args:
22
+ train_dir: Path to training directory.
23
+ test_dir: Path to testing directory.
24
+ transform: torchvision transforms to perform on training and testing data.
25
+ batch_size: Number of samples per batch in each of the DataLoaders.
26
+ num_workers: An integer for number of workers per DataLoader.
27
+
28
+ Returns:
29
+ A tuple of (train_dataloader, test_dataloader, class_names).
30
+ Where class_names is a list of the target classes.
31
+ Example usage:
32
+ train_dataloader, test_dataloader, class_names = \
33
+ = create_dataloaders(train_dir=path/to/train_dir,
34
+ test_dir=path/to/test_dir,
35
+ transform=some_transform,
36
+ batch_size=32,
37
+ num_workers=4)
38
+ """
39
+ # Use ImageFolder to create dataset(s)
40
+ train_data = datasets.ImageFolder(train_dir, transform=transform)
41
+ test_data = datasets.ImageFolder(test_dir, transform=transform)
42
+
43
+ # Get class names
44
+ class_names = train_data.classes
45
+
46
+ # Turn images into data loaders
47
+ train_dataloader = DataLoader(
48
+ train_data,
49
+ batch_size=batch_size,
50
+ shuffle=True,
51
+ num_workers=num_workers,
52
+ pin_memory=True,
53
+ )
54
+ test_dataloader = DataLoader(
55
+ test_data,
56
+ batch_size=batch_size,
57
+ shuffle=False,
58
+ num_workers=num_workers,
59
+ pin_memory=True,
60
+ )
61
+
62
+ return train_dataloader, test_dataloader, class_names
src/engine.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import mlflow
3
+ from tqdm.auto import tqdm
4
+ from src.logger import global_logger as logger
5
+ from typing import Dict, List, Tuple
6
+
7
+ def train_step(model: torch.nn.Module,
8
+ dataloader: torch.utils.data.DataLoader,
9
+ loss_fn: torch.nn.Module,
10
+ optimizer: torch.optim.Optimizer,
11
+ device: torch.device) -> Tuple[float, float]:
12
+ """Trains a PyTorch model for a single epoch."""
13
+
14
+
15
+ model.train()
16
+ train_loss, train_acc = 0, 0
17
+
18
+ for batch, (X, y) in enumerate(dataloader):
19
+ X, y = X.to(device), y.to(device)
20
+ y_pred = model(X)
21
+ loss = loss_fn(y_pred, y)
22
+ train_loss += loss.item()
23
+ optimizer.zero_grad()
24
+ loss.backward()
25
+ optimizer.step()
26
+ y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
27
+ train_acc += (y_pred_class == y).sum().item() / len(y_pred)
28
+
29
+ train_loss /= len(dataloader)
30
+ train_acc /= len(dataloader)
31
+ return train_loss, train_acc
32
+
33
+
34
+ def test_step(model: torch.nn.Module,
35
+ dataloader: torch.utils.data.DataLoader,
36
+ loss_fn: torch.nn.Module,
37
+ device: torch.device) -> Tuple[float, float]:
38
+ """Tests a PyTorch model for a single epoch."""
39
+
40
+ model.eval()
41
+ test_loss, test_acc = 0, 0
42
+
43
+ with torch.inference_mode():
44
+ for batch, (X, y) in enumerate(dataloader):
45
+ X, y = X.to(device), y.to(device)
46
+ test_pred_logits = model(X)
47
+ loss = loss_fn(test_pred_logits, y)
48
+ test_loss += loss.item()
49
+ test_pred_labels = test_pred_logits.argmax(dim=1)
50
+ test_acc += (test_pred_labels == y).sum().item() / len(test_pred_labels)
51
+
52
+ test_loss /= len(dataloader)
53
+ test_acc /= len(dataloader)
54
+ return test_loss, test_acc
55
+
56
+ def train(model: torch.nn.Module,
57
+ train_dataloader: torch.utils.data.DataLoader,
58
+ test_dataloader: torch.utils.data.DataLoader,
59
+ optimizer: torch.optim.Optimizer,
60
+ loss_fn: torch.nn.Module,
61
+ epochs: int,
62
+ device: torch.device) -> Dict[str, List[float]]:
63
+ """Trains and tests a PyTorch model."""
64
+
65
+ results = {
66
+ "train_loss": [],
67
+ "train_acc": [],
68
+ "test_loss": [],
69
+ "test_acc": []
70
+ }
71
+
72
+ for epoch in tqdm(range(epochs)):
73
+ with mlflow.start_run() as run:
74
+ mlflow.log_param("epoch", epoch)
75
+ mlflow.log_param("optimizer", optimizer.__class__.__name__)
76
+ mlflow.log_param("loss_fn", loss_fn.__class__.__name__)
77
+
78
+ train_loss, train_acc = train_step(model=model,
79
+ dataloader=train_dataloader,
80
+ loss_fn=loss_fn,
81
+ optimizer=optimizer,
82
+ device=device)
83
+ test_loss, test_acc = test_step(model=model,
84
+ dataloader=test_dataloader,
85
+ loss_fn=loss_fn,
86
+ device=device)
87
+
88
+ print(
89
+ f"Epoch: {epoch+1} | "
90
+ f"train_loss: {train_loss:.3f} | "
91
+ f"train_acc: {train_acc:.3f} | "
92
+ f"test_loss: {test_loss:.3f} | "
93
+ f"test_acc: {test_acc:.3f}"
94
+ )
95
+
96
+ results["train_loss"].append(train_loss)
97
+ results["train_acc"].append(train_acc)
98
+ results["test_loss"].append(test_loss)
99
+ results["test_acc"].append(test_acc)
100
+
101
+ mlflow.log_metric("train_loss", train_loss, step=epoch)
102
+ mlflow.log_metric("train_acc", train_acc, step=epoch)
103
+ mlflow.log_metric("test_loss", test_loss, step=epoch)
104
+ mlflow.log_metric("test_acc", test_acc, step=epoch)
105
+ mlflow.pytorch.log_model(model, "model")
106
+
107
+ return results
src/helper_function.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A series of helper functions used throughout the course.
3
+
4
+ If a function gets defined once and could be used over and over, it'll go in here.
5
+ """
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+
10
+ from torch import nn
11
+
12
+ import os
13
+ import zipfile
14
+
15
+ from pathlib import Path
16
+
17
+ import requests
18
+
19
+ # Walk through an image classification directory and find out how many files (images)
20
+ # are in each subdirectory.
21
+ import os
22
+
23
+ def walk_through_dir(dir_path):
24
+ """
25
+ Walks through dir_path returning its contents.
26
+ Args:
27
+ dir_path (str): target directory
28
+
29
+ Returns:
30
+ A print out of:
31
+ number of subdiretories in dir_path
32
+ number of images (files) in each subdirectory
33
+ name of each subdirectory
34
+ """
35
+ for dirpath, dirnames, filenames in os.walk(dir_path):
36
+ print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpath}'.")
37
+
38
+ def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
39
+ """Plots decision boundaries of model predicting on X in comparison to y.
40
+
41
+ Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)
42
+ """
43
+ # Put everything to CPU (works better with NumPy + Matplotlib)
44
+ model.to("cpu")
45
+ X, y = X.to("cpu"), y.to("cpu")
46
+
47
+ # Setup prediction boundaries and grid
48
+ x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
49
+ y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
50
+ xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))
51
+
52
+ # Make features
53
+ X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()
54
+
55
+ # Make predictions
56
+ model.eval()
57
+ with torch.inference_mode():
58
+ y_logits = model(X_to_pred_on)
59
+
60
+ # Test for multi-class or binary and adjust logits to prediction labels
61
+ if len(torch.unique(y)) > 2:
62
+ y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1) # mutli-class
63
+ else:
64
+ y_pred = torch.round(torch.sigmoid(y_logits)) # binary
65
+
66
+ # Reshape preds and plot
67
+ y_pred = y_pred.reshape(xx.shape).detach().numpy()
68
+ plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
69
+ plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
70
+ plt.xlim(xx.min(), xx.max())
71
+ plt.ylim(yy.min(), yy.max())
72
+
73
+
74
+ # Plot linear data or training and test and predictions (optional)
75
+ def plot_predictions(
76
+ train_data, train_labels, test_data, test_labels, predictions=None
77
+ ):
78
+ """
79
+ Plots linear training data and test data and compares predictions.
80
+ """
81
+ plt.figure(figsize=(10, 7))
82
+
83
+ # Plot training data in blue
84
+ plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
85
+
86
+ # Plot test data in green
87
+ plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
88
+
89
+ if predictions is not None:
90
+ # Plot the predictions in red (predictions were made on the test data)
91
+ plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
92
+
93
+ # Show the legend
94
+ plt.legend(prop={"size": 14})
95
+
96
+
97
+ # Calculate accuracy (a classification metric)
98
+ def accuracy_fn(y_true, y_pred):
99
+ """Calculates accuracy between truth labels and predictions.
100
+
101
+ Args:
102
+ y_true (torch.Tensor): Truth labels for predictions.
103
+ y_pred (torch.Tensor): Predictions to be compared to predictions.
104
+
105
+ Returns:
106
+ [torch.float]: Accuracy value between y_true and y_pred, e.g. 78.45
107
+ """
108
+ correct = torch.eq(y_true, y_pred).sum().item()
109
+ acc = (correct / len(y_pred)) * 100
110
+ return acc
111
+
112
+
113
+ def print_train_time(start, end, device=None):
114
+ """Prints difference between start and end time.
115
+
116
+ Args:
117
+ start (float): Start time of computation (preferred in timeit format).
118
+ end (float): End time of computation.
119
+ device ([type], optional): Device that compute is running on. Defaults to None.
120
+
121
+ Returns:
122
+ float: time between start and end in seconds (higher is longer).
123
+ """
124
+ total_time = end - start
125
+ print(f"\nTrain time on {device}: {total_time:.3f} seconds")
126
+ return total_time
127
+
128
+
129
+ # Plot loss curves of a model
130
+ def plot_loss_curves(results):
131
+ """Plots training curves of a results dictionary.
132
+
133
+ Args:
134
+ results (dict): dictionary containing list of values, e.g.
135
+ {"train_loss": [...],
136
+ "train_acc": [...],
137
+ "test_loss": [...],
138
+ "test_acc": [...]}
139
+ """
140
+ loss = results["train_loss"]
141
+ test_loss = results["test_loss"]
142
+
143
+ accuracy = results["train_acc"]
144
+ test_accuracy = results["test_acc"]
145
+
146
+ epochs = range(len(results["train_loss"]))
147
+
148
+ plt.figure(figsize=(15, 7))
149
+
150
+ # Plot loss
151
+ plt.subplot(1, 2, 1)
152
+ plt.plot(epochs, loss, label="train_loss")
153
+ plt.plot(epochs, test_loss, label="test_loss")
154
+ plt.title("Loss")
155
+ plt.xlabel("Epochs")
156
+ plt.legend()
157
+
158
+ # Plot accuracy
159
+ plt.subplot(1, 2, 2)
160
+ plt.plot(epochs, accuracy, label="train_accuracy")
161
+ plt.plot(epochs, test_accuracy, label="test_accuracy")
162
+ plt.title("Accuracy")
163
+ plt.xlabel("Epochs")
164
+ plt.legend()
165
+
166
+
167
+ # Pred and plot image function from notebook 04
168
+ # See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
169
+ from typing import List
170
+ import torchvision
171
+
172
+
173
+ def pred_and_plot_image(
174
+ model: torch.nn.Module,
175
+ image_path: str,
176
+ class_names: List[str] = None,
177
+ transform=None,
178
+ device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
179
+ ):
180
+ """Makes a prediction on a target image with a trained model and plots the image.
181
+
182
+ Args:
183
+ model (torch.nn.Module): trained PyTorch image classification model.
184
+ image_path (str): filepath to target image.
185
+ class_names (List[str], optional): different class names for target image. Defaults to None.
186
+ transform (_type_, optional): transform of target image. Defaults to None.
187
+ device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".
188
+
189
+ Returns:
190
+ Matplotlib plot of target image and model prediction as title.
191
+
192
+ Example usage:
193
+ pred_and_plot_image(model=model,
194
+ image="some_image.jpeg",
195
+ class_names=["class_1", "class_2", "class_3"],
196
+ transform=torchvision.transforms.ToTensor(),
197
+ device=device)
198
+ """
199
+
200
+ # 1. Load in image and convert the tensor values to float32
201
+ target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
202
+
203
+ # 2. Divide the image pixel values by 255 to get them between [0, 1]
204
+ target_image = target_image / 255.0
205
+
206
+ # 3. Transform if necessary
207
+ if transform:
208
+ target_image = transform(target_image)
209
+
210
+ # 4. Make sure the model is on the target device
211
+ model.to(device)
212
+
213
+ # 5. Turn on model evaluation mode and inference mode
214
+ model.eval()
215
+ with torch.inference_mode():
216
+ # Add an extra dimension to the image
217
+ target_image = target_image.unsqueeze(dim=0)
218
+
219
+ # Make a prediction on image with an extra dimension and send it to the target device
220
+ target_image_pred = model(target_image.to(device))
221
+
222
+ # 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
223
+ target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
224
+
225
+ # 7. Convert prediction probabilities -> prediction labels
226
+ target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
227
+
228
+ # 8. Plot the image alongside the prediction and prediction probability
229
+ plt.imshow(
230
+ target_image.squeeze().permute(1, 2, 0)
231
+ ) # make sure it's the right size for matplotlib
232
+ if class_names:
233
+ title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
234
+ else:
235
+ title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
236
+ plt.title(title)
237
+ plt.axis(False)
238
+
239
+ def set_seeds(seed: int=42):
240
+ """Sets random sets for torch operations.
241
+
242
+ Args:
243
+ seed (int, optional): Random seed to set. Defaults to 42.
244
+ """
245
+ # Set the seed for general torch operations
246
+ torch.manual_seed(seed)
247
+ # Set the seed for CUDA torch operations (ones that happen on the GPU)
248
+ torch.cuda.manual_seed(seed)
249
+
src/logger.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+ import os
3
+
4
+ class LoguruLogger:
5
+ def __init__(self, log_folder="logs", log_file="training.log"):
6
+ os.makedirs(log_folder, exist_ok=True)
7
+ log_path = os.path.join(log_folder, log_file)
8
+
9
+ logger.remove() # Remove default logger
10
+ logger.add(log_path, level="INFO", format="{time} - {name} - {level} - {message}")
11
+ logger.add(lambda msg: print(msg, end=""), level="INFO", format="{time} - {name} - {level} - {message}")
12
+
13
+ def get_logger(self):
14
+ return logger
15
+
16
+ # Create a global logger instance
17
+ global_logger = LoguruLogger().get_logger()
src/model.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ from src.logger import global_logger as logger
5
+ from torchvision.models import resnet50, ResNet50_Weights
6
+
7
+ def resnet_model(num_classes: int = 4, seed: int = 42):
8
+ # Load pretrained ResNet18 model
9
+ weights = ResNet50_Weights.DEFAULT
10
+ model = resnet50(weights=weights)
11
+
12
+ # Freeze the parameters of the pretrained model
13
+ for param in model.parameters():
14
+ param.requires_grad = False
15
+
16
+ #logger.info("Model initialized with frozen ResNet18 backbone and new fully connected layers.")
17
+
18
+ # Replace the final fully connected layer with a new one
19
+ torch.manual_seed(seed)
20
+ model.fc = nn.Sequential(
21
+ nn.Dropout(p=0.3, inplace=True),
22
+ nn.Linear(in_features=model.fc.in_features, out_features=num_classes),
23
+ )
24
+
25
+ # Define the transforms using the predefined transforms from weights
26
+ transforms = weights.transforms()
27
+
28
+ return model, transforms
29
+
30
+ # Example usage
31
+ model, transforms = resnet_model(num_classes=4)
src/train.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import src.data_setup as data_setup
4
+ import src.engine as engine
5
+ import src.utils as utils
6
+ from src.logger import global_logger as logger
7
+ from torchvision import transforms
8
+ import src.model as model_module
9
+
10
+ def main():
11
+
12
+ NUM_EPOCHS = 20
13
+ BATCH_SIZE = 32
14
+ LEARNING_RATE = 0.001
15
+
16
+ train_dir = "data\\retinal_oct\\train"
17
+ test_dir = "data\\retinal_oct\\test"
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # Use the transformations required by ResNet50
22
+ data_transform = transforms.Compose([
23
+ transforms.Resize((224, 224)),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
+ ])
27
+
28
+ train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
29
+ train_dir=train_dir,
30
+ test_dir=test_dir,
31
+ transform=data_transform,
32
+ batch_size=BATCH_SIZE
33
+ )
34
+ logger.info("Data transformed successfully.")
35
+
36
+ # Initialize the ResNet50 model
37
+ model, _ = model_module.resnet_model(num_classes=len(class_names))
38
+ model = model.to(device)
39
+
40
+ loss_fn = torch.nn.CrossEntropyLoss()
41
+ optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
42
+
43
+ engine.train(
44
+ model=model,
45
+ train_dataloader=train_dataloader,
46
+ test_dataloader=test_dataloader,
47
+ loss_fn=loss_fn,
48
+ optimizer=optimizer,
49
+ epochs=NUM_EPOCHS,
50
+ device=device
51
+ )
52
+
53
+ utils.save_model(
54
+ model=model,
55
+ target_dir="models",
56
+ model_name="model.pth"
57
+ )
58
+ logger.info("Model trained successfully.")
59
+ logger.info("Model saved to models folder.")
60
+
61
+ if __name__ == '__main__':
62
+ main()
src/utils.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pathlib import Path
3
+
4
+ def save_model(model: torch.nn.Module,
5
+ target_dir: str,
6
+ model_name: str):
7
+ """Saves a PyTorch model to a target directory.
8
+
9
+ Args:
10
+ model: A target PyTorch model to save.
11
+ target_dir: A directory for saving the model to.
12
+ model_name: A filename for the saved model. Should include
13
+ either ".pth" or ".pt" as the file extension.
14
+
15
+ """
16
+ # Create target directory
17
+ target_dir_path = Path(target_dir)
18
+ target_dir_path.mkdir(parents=True,
19
+ exist_ok=True)
20
+
21
+ # Create model save path
22
+ assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
23
+ model_save_path = target_dir_path / model_name
24
+
25
+ # Save the model state_dict()
26
+ print(f"[INFO] Saving model to: {model_save_path}")
27
+ torch.save(obj=model.state_dict(),
28
+ f=model_save_path)