Clocksp commited on
Commit
dffb6c9
·
verified ·
1 Parent(s): b83fefe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -0
app.py CHANGED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms, models
5
+ from PIL import Image, ImageOps
6
+ import numpy as np
7
+ import gradio as gr
8
+ import os
9
+
10
+ class DoubleConv(nn.Module):
11
+ def __init__(self, in_channels, out_channels):
12
+ super().__init__()
13
+ self.conv_op = nn.Sequential(
14
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
15
+ nn.BatchNorm2d(out_channels),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
18
+ nn.BatchNorm2d(out_channels),
19
+ nn.ReLU(inplace=True)
20
+ )
21
+ def forward(self, x):
22
+ return self.conv_op(x)
23
+
24
+ class Downsample(nn.Module):
25
+ def __init__(self, in_channels, out_channels):
26
+ super().__init__()
27
+ self.conv = DoubleConv(in_channels, out_channels)
28
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
29
+ def forward(self, x):
30
+ down = self.conv(x)
31
+ p = self.pool(down)
32
+ return down, p
33
+
34
+ class UpSample(nn.Module):
35
+ def __init__(self, in_channels, out_channels):
36
+ super().__init__()
37
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
38
+ self.conv = DoubleConv(in_channels, out_channels)
39
+ def forward(self, x1, x2):
40
+ x1 = self.up(x1)
41
+ # handle spatial mismatches
42
+ diffY = x2.size()[2] - x1.size()[2]
43
+ diffX = x2.size()[3] - x1.size()[3]
44
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
45
+ diffY // 2, diffY - diffY // 2])
46
+ x = torch.cat([x2, x1], dim=1)
47
+ return self.conv(x)
48
+
49
+ class UNet(nn.Module):
50
+ def __init__(self, in_channels=3, num_classes=1):
51
+ super().__init__()
52
+ self.down1 = Downsample(in_channels, 64)
53
+ self.down2 = Downsample(64, 128)
54
+ self.down3 = Downsample(128, 256)
55
+ self.down4 = Downsample(256, 512)
56
+ self.bottleneck = DoubleConv(512, 1024)
57
+ self.up1 = UpSample(1024, 512)
58
+ self.up2 = UpSample(512, 256)
59
+ self.up3 = UpSample(256, 128)
60
+ self.up4 = UpSample(128, 64)
61
+ self.out = nn.Conv2d(64, num_classes, kernel_size=1)
62
+ def forward(self, x):
63
+ d1, p1 = self.down1(x)
64
+ d2, p2 = self.down2(p1)
65
+ d3, p3 = self.down3(p2)
66
+ d4, p4 = self.down4(p3)
67
+ b = self.bottleneck(p4)
68
+ u1 = self.up1(b, d4)
69
+ u2 = self.up2(u1, d3)
70
+ u3 = self.up3(u2, d2)
71
+ u4 = self.up4(u3, d1)
72
+ return self.out(u4)
73
+
74
+
75
+ def build_efficientnet_b3(num_output=2, pretrained=False):
76
+ # torchvision efficientnet_b3; weights=None or pretrained control
77
+ model = models.efficientnet_b3(weights=None if not pretrained else models.EfficientNet_B3_Weights.IMAGENET1K_V1)
78
+ in_features = model.classifier[1].in_features
79
+ model.classifier[1] = nn.Linear(in_features, num_output)
80
+ return model
81
+
82
+
83
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ print("Using device:", device)
85
+
86
+ UNET_PATH = "models/unet.pth"
87
+ MODEL_BACT_PATH = "models/models_bacterial.pt"
88
+ MODEL_VIRAL_PATH = "models/models_viral.pt"
89
+
90
+
91
+ unet = UNet(in_channels=3, num_classes=1).to(device)
92
+ unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
93
+ unet.eval()
94
+
95
+
96
+ model_bact = build_efficientnet_b3(num_output=2).to(device)
97
+ model_viral = build_efficientnet_b3(num_output=2).to(device)
98
+
99
+ model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
100
+ model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
101
+
102
+ model_bact.eval()
103
+ model_viral.eval()
104
+
105
+
106
+ preprocess_unet = transforms.Compose([
107
+ transforms.Resize((300, 300)),
108
+ transforms.ToTensor(),
109
+ ])
110
+
111
+ preprocess_classifier = transforms.Compose([
112
+ transforms.Resize((300, 300)),
113
+ transforms.ToTensor(),
114
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
115
+ ])
116
+
117
+ def infer_mask_and_mask_image(pil_img, threshold=0.5):
118
+ """
119
+ Returns: masked_image_tensor_for_classifier (C,H,W), mask_numpy (H,W), masked_pil (PIL)
120
+ """
121
+ # Ensure RGB
122
+ if pil_img.mode != "RGB":
123
+ pil_img = pil_img.convert("RGB")
124
+ # UNet input: tensor
125
+ inp = preprocess_unet(pil_img).unsqueeze(0).to(device)
126
+ with torch.no_grad():
127
+ logits = unet(inp)
128
+ mask_prob = torch.sigmoid(logits)[0,0]
129
+ mask_np = mask_prob.cpu().numpy()
130
+ # binary mask
131
+ bin_mask = (mask_np >= threshold).astype(np.uint8)
132
+ # apply mask to original image (resized to 300x300) for classifier
133
+ img_tensor = preprocess_classifier(pil_img).to(device) # normalized
134
+ # the mask corresponds to preprocess_unet size (300,300) same as classifier
135
+ mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
136
+ masked_img_tensor = img_tensor * mask_tensor
137
+ # convert masked tensor back to PIL for display (unnormalize)
138
+ img_for_display = preprocess_unet(pil_img).cpu().numpy().transpose(1,2,0)
139
+ masked_display = (img_for_display * bin_mask[...,None])
140
+ masked_display = np.clip(masked_display*255, 0, 255).astype(np.uint8)
141
+ masked_pil = Image.fromarray(masked_display)
142
+ return masked_img_tensor, mask_np, masked_pil
143
+
144
+ def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
+ """
146
+ masked_img_tensor: C,H,W on device, normalized for classifier
147
+ returns (pb, pv, label)
148
+ pb = probability of pneumonia class from model_bact (index 1)
149
+ pv = probability of pneumonia class from model_viral (index 1)
150
+ """
151
+ x = masked_img_tensor.unsqueeze(0).to(device)
152
+ with torch.no_grad():
153
+ out_b = model_bact(x)
154
+ out_v = model_viral(x)
155
+ prob_b = torch.softmax(out_b, dim=1)[0,1].item()
156
+ prob_v = torch.softmax(out_v, dim=1)[0,1].item()
157
+
158
+ # Decision logic: thresholds + fallback to higher prob when both triggered
159
+ if prob_b < thresh_b and prob_v < thresh_v:
160
+ label = "NORMAL"
161
+ elif prob_b >= thresh_b and prob_v < thresh_v:
162
+ label = "BACTERIAL PNEUMONIA"
163
+ elif prob_v >= thresh_v and prob_b < thresh_b:
164
+ label = "VIRAL PNEUMONIA"
165
+ else:
166
+ # both triggered -> pick the stronger probability (fallback)
167
+ label = "BACTERIAL PNEUMONIA" if prob_b > prob_v else "VIRAL PNEUMONIA"
168
+ label = label + " (fallback)"
169
+ return prob_b, prob_v, label
170
+
171
+
172
+ def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
173
+ """
174
+ Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask (PIL)
175
+ """
176
+ pil = Image.fromarray(img.astype('uint8'), 'RGB')
177
+ masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(pil, threshold=seg_thresh)
178
+ pb, pv, label = classify_masked_tensor(masked_tensor, thresh_b=thresh_b, thresh_v=thresh_v)
179
+ mask_vis = (mask_np * 255).astype(np.uint8)
180
+ mask_pil = Image.fromarray(mask_vis).convert("L")
181
+ display_orig = pil.resize((300,300))
182
+ overlay = Image.new("RGBA", display_orig.size)
183
+ overlay.paste(display_orig.convert("RGBA"))
184
+ # red mask with alpha
185
+ red_mask = Image.fromarray(np.zeros((300,300,3), dtype=np.uint8))
186
+ red_mask = Image.fromarray(np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2))
187
+ red_mask = red_mask.convert("RGBA")
188
+ # apply alpha where mask is 1
189
+ alpha = (mask_np * 120).astype(np.uint8)
190
+ red_mask.putalpha(Image.fromarray(alpha))
191
+ blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
192
+ # return values
193
+ return {
194
+ "Prediction": label,
195
+ "Bacterial Probability": float(pb),
196
+ "Viral Probability": float(pv),
197
+ "Masked Image": masked_pil,
198
+ "Segmentation Overlay": blended
199
+ }
200
+
201
+ title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
202
+ desc = "Pipeline: UNet -> mask lungs -> two binary classifiers (Normal vs Bacterial, Normal vs Viral). " \
203
+ "If both classifiers fire, the stronger probability is chosen (fallback). Thresholds adjustable."
204
+
205
+ iface = gr.Interface(
206
+ fn=inference_pipeline,
207
+ inputs=[
208
+ gr.Image(type="numpy", label="Upload chest X-ray (RGB or grayscale)"),
209
+ gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Bacterial threshold (thresh_b)"),
210
+ gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Viral threshold (thresh_v)"),
211
+ gr.Slider(minimum=0.1, maximum=0.9, step=0.01, value=0.5, label="Segmentation mask threshold (seg_thresh)")
212
+ ],
213
+ outputs=[
214
+ gr.Label(num_top_classes=1, label="Prediction"),
215
+ gr.Number(label="Bacterial Probability"),
216
+ gr.Number(label="Viral Probability"),
217
+ gr.Image(type="pil", label="Masked Image (input × mask)"),
218
+ gr.Image(type="pil", label="Segmentation Overlay (red mask)")
219
+ ],
220
+ title=title,
221
+ description=desc,
222
+ allow_flagging="never"
223
+ )
224
+
225
+ if __name__ == "__main__":
226
+ iface.launch()