mbrq13 commited on
Commit
5cf329d
·
1 Parent(s): 3efc88d

Add pneumonia detection app with Grad-CAM

Browse files
Files changed (2) hide show
  1. app.py +87 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np, torch
3
+ import torchvision.transforms as T
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import gradio as gr
7
+ from pytorch_grad_cam import GradCAM
8
+ from pytorch_grad_cam.utils.image import show_cam_on_image
9
+ from PIL import Image
10
+
11
+
12
+ # Define CNN
13
+ class Net(nn.Module):
14
+ """Simple CNN with Batch Normalization and Dropout regularisation."""
15
+
16
+ def __init__(self) -> None:
17
+ super().__init__()
18
+ # Convolutional block 1
19
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
20
+ self.bn1 = nn.BatchNorm2d(16)
21
+
22
+ # Convolutional block 2
23
+ self.pool = nn.MaxPool2d(2, 2)
24
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
25
+ self.bn2 = nn.BatchNorm2d(32)
26
+
27
+ # Fully - connected head
28
+ self.fc1 = nn.Linear(32 * 56 * 56, 112)
29
+ self.dropout1 = nn.Dropout(0.5)
30
+
31
+ self.fc2 = nn.Linear(112, 84)
32
+ self.dropout2 = nn.Dropout(0.2)
33
+
34
+ self.fc3 = nn.Linear(84, 2)
35
+
36
+ def forward(self, x) -> torch.Tensor: # N,C,H,W
37
+ """Forward pass returning raw logits (no softmax)."""
38
+ c1 = self.pool(F.relu(self.bn1(self.conv1(x)))) # N,16,112,112
39
+ c2 = self.pool(F.relu(self.bn2(self.conv2(c1)))) # N,32,56,56
40
+ c2 = torch.flatten(c2, 1) # N,32*56*56
41
+ f3 = self.dropout1(F.relu(self.fc1(c2))) # N,112
42
+ f4 = self.dropout2(F.relu(self.fc2(f3))) # N,84
43
+ out = self.fc3(f4) # N,2
44
+ return out
45
+
46
+ # Load pre-trained model
47
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
+ model = Net().to(device)
49
+ model.load_state_dict(torch.load("best_model.pt", map_location=device))
50
+ model.eval()
51
+
52
+ transform = T.Compose([T.Resize((224,224)),
53
+ T.ToTensor(),
54
+ T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
55
+
56
+ # Upload and visualize an image
57
+ def predict_gradcam(image):
58
+ # prediction
59
+ img = image.convert("RGB")
60
+ plt.imshow(image); plt.axis('off'); plt.show()
61
+ tensor = transform(img).unsqueeze(0).to(device)
62
+ with torch.no_grad():
63
+ p = torch.softmax(model(tensor), dim=1)[0,1].item()
64
+ prob= f"{p:.3f}"
65
+ label= f"{'PNEUMONIA' if p>0.5 else 'NORMAL'}"
66
+
67
+ # Grad-CAM
68
+ target_layer = model.conv2
69
+ input_tensor = transform(img).unsqueeze(0).to(device)
70
+ cam = GradCAM(model=model, target_layers=[target_layer])
71
+ grayscale_cam = cam(input_tensor=input_tensor)[0]
72
+ img_np = np.array(img.resize((224,224)), dtype=np.float32)/255.0
73
+ heatmap = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
74
+ heatmap_pil = Image.fromarray(heatmap)
75
+
76
+ return prob, label, heatmap_pil
77
+
78
+ demo = gr.Interface(
79
+ fn=predict_gradcam,
80
+ inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
81
+ outputs=[gr.Textbox(label="Probability of Pneumonia"), gr.Label(label="Prediction"), gr.Image(label="Grad-CAM")],
82
+ title="🫁 Pneumonia Detection from Chest X-rays",
83
+ description="Upload a chest X-ray to see whether it shows signs of pneumonia. The model will predict the probability and show a Grad-CAM visualization of the most important regions.",
84
+ flagging_mode="never"
85
+ )
86
+
87
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ torchvision>=0.10.0
3
+ pytorch-grad-cam>=1.4.0
4
+ matplotlib>=3.5.0
5
+ numpy>=1.21.0
6
+ Pillow>=8.3.0