# mma666's picture
# Files added to the repo
# 5fcd7bf
import torch
import torch.nn as nn
from torchvision import transforms as T
import os
import gradio as gr
#################################
# Define problem parameters
#################################
class config:
    """Fixed hyper-parameters and dataset constants used at inference time.

    The normalisation statistics come from the pneumonia dataset. All three
    channels share one value, presumably because grayscale inputs are
    replicated into three identical channels (see ``Gray2RGB``).
    """

    # Side length, in pixels, of the square image fed to the network.
    img_size = 224
    # Per-channel mean / std handed to ``T.Normalize``.
    pn_mean = [0.4752] * 3
    pn_std = [0.2234] * 3
    # Output index -> label; ``predict`` maps logit i to class_names[i].
    class_names = ["Normal", "Pneumonia"]
# Inference device; the model, its loaded weights and inputs live here.
device = torch.device("cpu")
print(f"device: {device}")
#######################################
# Define image transformation pipeline
#######################################
class Gray2RGB:
    """Coerce a channel-first ``(C, H, W)`` image tensor to exactly 3 channels.

    Used right after ``T.ToTensor()`` in the preprocessing pipeline, so the
    channel count reflects the PIL image mode:

    * 1 channel (``L``): the single channel is repeated three times.
    * 2 channels (``LA``): alpha is dropped, the gray channel is repeated.
    * 3 channels (``RGB``): returned unchanged.
    * 4 channels (``RGBA``): alpha is dropped.

    Raises:
        ValueError: for any other channel count. Without this check the
            tensor would be silently tiled to 3*C channels, and the
            3-channel ``Normalize`` and first conv layer would fail.
    """

    def __call__(self, image):
        channels = image.shape[0]
        if channels == 3:
            return image
        if channels == 1:
            # Repeat the single channel across 3 channels to convert to RGB.
            return image.repeat(3, 1, 1)
        if channels == 2:
            # Gray + alpha: keep only the gray channel, then replicate it.
            return image[:1].repeat(3, 1, 1)
        if channels == 4:
            # RGBA: discard the alpha channel.
            return image[:3]
        raise ValueError(
            f"Gray2RGB expects an image with 1, 2, 3 or 4 channels, got {channels}"
        )
# Inference preprocessing: resize to the model's square input, convert the
# PIL image to a float tensor, force three channels, then normalise with
# the dataset statistics from ``config``.
test_transform_custom = T.Compose(
    [
        T.Resize(size=(config.img_size,) * 2),
        T.ToTensor(),
        Gray2RGB(),
        T.Normalize(mean=config.pn_mean, std=config.pn_std),
    ]
)
#################################
# Define model architecture
#################################
class ConvolutionalNetwork(nn.Module):
    """Five-stage CNN mapping a 3x224x224 image to two class logits.

    Each ``convN`` stage is Conv2d(3x3, stride 1, padding 1) -> ReLU ->
    BatchNorm2d -> MaxPool2d(2, 2). The five poolings shrink 224x224 to
    7x7, which is why the head starts with ``Linear(128 * 7 * 7, ...)``.
    The attribute names and the layer positions inside every
    ``nn.Sequential`` determine the state-dict keys of the saved checkpoint,
    so they must stay exactly as they are.
    """

    def __init__(self):
        super().__init__()

        def make_stage(in_ch, out_ch):
            # Positions 0..3 = conv, relu, bn, pool -> keys like "conv1.0.weight".
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch),
                nn.MaxPool2d(2, 2),
            )

        self.conv1 = make_stage(3, 8)
        self.conv2 = make_stage(8, 16)
        self.conv3 = make_stage(16, 32)
        self.conv4 = make_stage(32, 64)
        self.conv5 = make_stage(64, 128)
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Run the conv stages in order, flatten per sample, apply the head."""
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = block(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
# Build the network on the inference device and restore the trained weights.
# ``weights_only=True`` limits torch.load to tensors and plain containers
# instead of unpickling arbitrary objects; ``map_location`` remaps storages
# onto ``device``. ``load_state_dict`` is strict by default, so the printed
# status reports no missing/unexpected keys when loading succeeds.
cnn_model = ConvolutionalNetwork().to(device)
status = cnn_model.load_state_dict(
    torch.load("pneumonia_cnn_model.pt", map_location=device, weights_only=True)
)
print(f"Status: {status}")
#################################
# Define the prediction function
#################################
def predict(image):
    """Transform an image, run the classifier and return class probabilities.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        dict mapping every name in ``config.class_names`` to its softmax
        probability as a Python float (the format ``gr.Label`` displays).

    Raises:
        gr.Error: when no image was provided. Gradio shows the message to
            the user instead of a traceback from the transform pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Add the batch dimension and move the tensor to the model's device;
    # without the move, any non-CPU ``device`` would cause a device mismatch.
    batch = test_transform_custom(image).unsqueeze(0).to(device)
    cnn_model.eval()  # BatchNorm / Dropout must use inference behaviour.
    with torch.no_grad():
        probs = torch.softmax(cnn_model(batch), dim=1)[0]
    # Pair each class label with its probability (index i <-> class_names[i]).
    return {name: float(p) for name, p in zip(config.class_names, probs)}
##########################
# Create the Gradio demo
##########################
title = "Pneumonia Detection"
description = """This is a pneumonia detection model that uses a custom convolutional neural network to predict whether an image contains pneumonia or not. \
GitHub project can be accessed [here](https://github.com/mma666/Pneumonia-Detection-Computer-Vision).
"""


def _list_examples(directory="examples"):
    """Return Gradio example rows, one ``[image_path]`` per file in *directory*.

    Names are sorted because ``os.listdir`` order is arbitrary, and the
    gallery should look the same on every start. Hidden files (e.g.
    ``.gitkeep``) and sub-directories are skipped because they are not
    usable example images. A missing directory yields an empty list instead
    of crashing the app at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


# Create examples list from "examples/" directory
example_list = _list_examples("examples")
# Create the Gradio demo
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Upload image", type="pil", height=320, width=320)],
    # Show every class, kept in sync with the label list used by ``predict``.
    outputs=[gr.Label(num_top_classes=len(config.class_names), label="Predictions")],
    # Pass None rather than an empty list when no examples are available.
    examples=example_list or None,
    title=title,
    description=description,
    cache_examples=False,
)
demo.launch()