Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms as T | |
| import os | |
| import gradio as gr | |
| ################################# | |
| # Define problem parameters | |
| ################################# | |
class config:
    """Static settings shared by the preprocessing pipeline and ``predict``."""

    # Side length (pixels) of the square image fed to the network; the five
    # 2x2 max-pools reduce it to 224 / 2**5 == 7, matching the 128*7*7 head.
    img_size = 224

    # Normalisation statistics of the pneumonia training data. All three
    # channels share one value, consistent with grayscale scans being
    # replicated to RGB before normalisation.
    pn_mean = [0.4752] * 3
    pn_std = [0.2234] * 3

    # Output index i of the network is reported under class_names[i].
    class_names = ["Normal", "Pneumonia"]


# Inference runs on the CPU; the trained weights are mapped onto this device.
device = torch.device("cpu")
print(f"device: {device}")
| ####################################### | |
| # Define image transformation pipeline | |
| ####################################### | |
class Gray2RGB:
    """Coerce a ``(C, H, W)`` image tensor to exactly three channels.

    The pipeline applies this right after ``ToTensor``, which emits one
    channel per PIL band: 1 for grayscale (L), 2 for grayscale+alpha (LA),
    3 for RGB and 4 for RGBA. Alpha is discarded (as PIL does for
    RGBA -> RGB) and a single intensity channel is replicated across
    R, G and B so the result always matches the 3-channel ``Normalize``.
    """

    def __call__(self, image):
        """Return *image* as a 3-channel tensor.

        Args:
            image: tensor whose first dimension is the channel axis.

        Returns:
            The input unchanged if it already has 3 channels, otherwise a
            3-channel tensor derived from it.

        Raises:
            ValueError: if the channel count is not 1, 2, 3 or 4.
        """
        channels = image.shape[0]
        if channels == 3:
            return image
        if channels == 4:
            # RGBA: keep the colour channels, drop alpha.
            return image[:3]
        if channels in (1, 2):
            # L or LA: replicate the intensity channel; any alpha is dropped.
            # Previously LA was repeated into 6 channels and broke Normalize.
            return image[:1].repeat(3, 1, 1)
        raise ValueError(f"Gray2RGB expects 1-4 channels, got {channels}")
# Inference-time preprocessing, applied in order:
#   1. resize the PIL image to the square model input size,
#   2. convert it to a (C, H, W) tensor,
#   3. force three channels (grayscale scans are replicated),
#   4. standardise with the training-set mean/std.
test_transform_custom = T.Compose(
    [
        T.Resize(size=(config.img_size,) * 2),
        T.ToTensor(),
        Gray2RGB(),
        T.Normalize(mean=config.pn_mean, std=config.pn_std),
    ]
)
| ################################# | |
| # Define model architecture | |
| ################################# | |
class ConvolutionalNetwork(nn.Module):
    """Binary image classifier: five conv stages followed by an MLP head.

    Every stage is Conv3x3 (stride 1, padding 1) -> ReLU -> BatchNorm ->
    MaxPool(2, 2), so each one halves the spatial size. For a 224x224 input
    the final feature map is 128 x 7 x 7, which is what ``fc`` expects.
    Attribute names (``conv1``..``conv5``, ``fc``) and the layer order inside
    each ``nn.Sequential`` define the state-dict keys of the saved weights
    and must not change.
    """

    @staticmethod
    def _stage(in_ch, out_ch):
        """Build one conv -> relu -> batchnorm -> pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_ch),
            nn.MaxPool2d(2, 2),
        )

    def __init__(self):
        super().__init__()
        self.conv1 = self._stage(3, 8)
        self.conv2 = self._stage(8, 16)
        self.conv3 = self._stage(16, 32)
        self.conv4 = self._stage(32, 64)
        self.conv5 = self._stage(64, 128)
        # Classifier head over the flattened 128x7x7 feature map; 2 logits out.
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        """Return raw logits of shape ``(batch, 2)`` for image batch *x*."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        return self.fc(x)
# Instantiate the network on the target device (Module.to returns the module
# itself) and restore the trained weights. weights_only=True limits
# torch.load to plain tensor/container data instead of arbitrary pickles.
cnn_model = ConvolutionalNetwork().to(device)
status = cnn_model.load_state_dict(
    torch.load("pneumonia_cnn_model.pt", map_location=device, weights_only=True)
)
print(f"Status: {status}")
| ################################# | |
| # Define the prediction function | |
| ################################# | |
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each name in ``config.class_names`` to its softmax
        probability, the format consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty input; show a readable message
        # instead of a traceback from the transform pipeline.
        raise gr.Error("Please upload an image.")
    # Palette and alpha/CMYK modes would reach ToTensor with a channel layout
    # (or palette indices instead of intensities) the pipeline can't use.
    if image.mode in ("RGBA", "LA", "P", "PA", "CMYK"):
        image = image.convert("RGB")
    # Add a batch dimension and place the input on the same device as the model.
    batch = test_transform_custom(image).unsqueeze(0).to(device)
    # eval(): Dropout off and BatchNorm uses its running statistics.
    cnn_model.eval()
    with torch.no_grad():
        probs = torch.softmax(cnn_model(batch), dim=1)[0]
    # Output index i corresponds to config.class_names[i].
    return {name: float(p) for name, p in zip(config.class_names, probs)}
| ########################## | |
| # Create the Gradio demo | |
| ########################## | |
title = "Pneumonia Detection"
description = """This is a pneumonia detection model that uses a custom convolutional neural network to predict whether an image contains pneumonia or not. \
GitHub project can be accessed [here](https://github.com/mma666/Pneumonia-Detection-Computer-Vision).
"""


def _list_examples(directory="examples"):
    """Return Gradio example rows (one file path per row) found in *directory*.

    Entries are sorted so the example gallery order is stable (os.listdir
    order is arbitrary). Hidden files such as ``.gitkeep`` or ``.DS_Store``
    and sub-directories are skipped. A missing directory yields an empty
    list instead of crashing the app at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


# Create examples list from "examples/" directory
example_list = _list_examples("examples")

# Create the Gradio demo
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Image(label="Upload image", type="pil", height=320, width=320)],
    outputs=[gr.Label(num_top_classes=2, label="Predictions")],
    # Pass None rather than an empty list when no examples are available.
    examples=example_list or None,
    title=title,
    description=description,
    cache_examples=False,
)
demo.launch()