# Hugging Face Space: pneumonia detection demo
import torch
import torch.nn as nn
from torchvision import transforms as T
import os
import gradio as gr
#################################
# Define problem parameters
#################################
class config:
    """Namespace of fixed demo hyper-parameters (read-only at runtime)."""
    img_size = 224                          # training input resolution (square)
    # Per-channel normalisation statistics of the pneumonia training set
    pn_mean = [0.4752, 0.4752, 0.4752]      # Pneumonia dataset mean
    pn_std = [0.2234, 0.2234, 0.2234]       # Pneumonia dataset std
    class_names = ["Normal", "Pneumonia"]   # index order matches model logits

# Inference always runs on CPU in this Space.
device = torch.device('cpu')
print(f"device: {device}")
#######################################
# Define image transformation pipeline
#######################################
class Gray2RGB:
    """Callable transform that forces a tensor image to 3 channels.

    An image that already has 3 channels is returned unchanged (same
    object); otherwise the channel dimension is tiled three times, which
    turns a single-channel grayscale X-ray into an RGB-shaped tensor.
    """

    def __call__(self, image):
        # NOTE(review): assumes a CHW tensor (applied after ToTensor);
        # inputs with C not in {1, 3} would also be tiled — confirm data.
        if image.shape[0] != 3:
            return image.repeat(3, 1, 1)
        return image
# Preprocessing for every uploaded image: resize to the training
# resolution, convert to a tensor, force 3 channels, then normalise
# with the pneumonia dataset statistics.
_pipeline_steps = [
    T.Resize(size=(config.img_size, config.img_size)),  # match training size
    T.ToTensor(),                                       # PIL -> CHW float tensor in [0, 1]
    Gray2RGB(),                                         # grayscale X-ray -> 3 channels
    T.Normalize(config.pn_mean, config.pn_std),         # dataset mean/std
]
test_transform_custom = T.Compose(_pipeline_steps)
#################################
# Define model architecture
#################################
class ConvolutionalNetwork(nn.Module):
    """Five-stage CNN for binary pneumonia classification.

    Each stage halves the spatial resolution with a 2x2 max-pool
    (224 -> 7 after five stages) while widening the channels
    (3 -> 8 -> 16 -> 32 -> 64 -> 128); a two-layer classifier head maps
    the flattened 128*7*7 feature map to 2 class logits.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        # The repeating unit: Conv -> ReLU -> BatchNorm -> MaxPool.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(2, 2))

    def __init__(self):
        super().__init__()
        # Attribute names (conv1..conv5, fc) and the Sequential layer
        # ordering must stay stable so saved state_dict keys keep matching.
        self.conv1 = self._conv_stage(3, 8)
        self.conv2 = self._conv_stage(8, 16)
        self.conv3 = self._conv_stage(16, 32)
        self.conv4 = self._conv_stage(32, 64)
        self.conv5 = self._conv_stage(64, 128)
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 2))

    def forward(self, x):
        """Return raw logits of shape (batch, 2) for a (batch, 3, 224, 224) input."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch, 128*7*7)
        return self.fc(x)
# Build the network on CPU and restore the trained weights.
# weights_only=True restricts torch.load to tensor data (no pickled code).
cnn_model = ConvolutionalNetwork()
cnn_model.to(device)
state = torch.load('pneumonia_cnn_model.pt', map_location=device, weights_only=True)
status = cnn_model.load_state_dict(state)
print(f"Status: {status}")
#################################
# Define the prediction function
#################################
def predict(image):
    """Run the CNN on one image and return {class_name: probability}.

    The input is pushed through the same preprocessing pipeline used at
    evaluation time, batched, and classified with gradients disabled.
    """
    batch = test_transform_custom(image).unsqueeze(0)  # add batch dimension
    cnn_model.eval()  # dropout off, BatchNorm uses running statistics
    with torch.no_grad():
        probs = torch.softmax(cnn_model(batch), dim=1)[0]
    # Pair each class name with its predicted probability as a plain float
    # (the order of class_names matches the model's logit order).
    return {name: float(p) for name, p in zip(config.class_names, probs)}
##########################
# Create the Gradio demo
##########################
title = "Pneumonia Detection"
description = """This is a pneumonia detection model that uses a custom convolutional neural network to predict whether an image contains pneumonia or not. \
GitHub project can be accessed [here](https://github.com/mma666/Pneumonia-Detection-Computer-Vision).
"""
# Each example is a one-element list (one input component) pointing at a
# file inside the "examples/" directory.
examples = [["examples/" + filename] for filename in os.listdir("examples")]
# Assemble the interface and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Upload image", type="pil", height=320, width=320)],
    outputs=[gr.Label(num_top_classes=2, label="Predictions")],
    examples=examples,
    title=title,
    description=description,
    cache_examples=False,
)
demo.launch()