Clocksp committed on
Commit
234c39e
·
verified ·
1 Parent(s): 9dfe5bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -211
app.py CHANGED
@@ -7,220 +7,220 @@ import numpy as np
7
  import gradio as gr
8
  import os
9
 
10
- # class DoubleConv(nn.Module):
11
- # def __init__(self, in_channels, out_channels):
12
- # super().__init__()
13
- # self.conv_op = nn.Sequential(
14
- # nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
15
- # nn.BatchNorm2d(out_channels),
16
- # nn.ReLU(inplace=True),
17
- # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
18
- # nn.BatchNorm2d(out_channels),
19
- # nn.ReLU(inplace=True)
20
- # )
21
- # def forward(self, x):
22
- # return self.conv_op(x)
23
-
24
- # class Downsample(nn.Module):
25
- # def __init__(self, in_channels, out_channels):
26
- # super().__init__()
27
- # self.conv = DoubleConv(in_channels, out_channels)
28
- # self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
29
- # def forward(self, x):
30
- # down = self.conv(x)
31
- # p = self.pool(down)
32
- # return down, p
33
-
34
- # class UpSample(nn.Module):
35
- # def __init__(self, in_channels, out_channels):
36
- # super().__init__()
37
- # self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
38
- # self.conv = DoubleConv(in_channels, out_channels)
39
- # def forward(self, x1, x2):
40
- # x1 = self.up(x1)
41
- # # handle spatial mismatches
42
- # diffY = x2.size()[2] - x1.size()[2]
43
- # diffX = x2.size()[3] - x1.size()[3]
44
- # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
45
- # diffY // 2, diffY - diffY // 2])
46
- # x = torch.cat([x2, x1], dim=1)
47
- # return self.conv(x)
48
-
49
- # class UNet(nn.Module):
50
- # def __init__(self, in_channels=3, num_classes=1):
51
- # super().__init__()
52
- # self.down1 = Downsample(in_channels, 64)
53
- # self.down2 = Downsample(64, 128)
54
- # self.down3 = Downsample(128, 256)
55
- # self.down4 = Downsample(256, 512)
56
- # self.bottleneck = DoubleConv(512, 1024)
57
- # self.up1 = UpSample(1024, 512)
58
- # self.up2 = UpSample(512, 256)
59
- # self.up3 = UpSample(256, 128)
60
- # self.up4 = UpSample(128, 64)
61
- # self.out = nn.Conv2d(64, num_classes, kernel_size=1)
62
- # def forward(self, x):
63
- # d1, p1 = self.down1(x)
64
- # d2, p2 = self.down2(p1)
65
- # d3, p3 = self.down3(p2)
66
- # d4, p4 = self.down4(p3)
67
- # b = self.bottleneck(p4)
68
- # u1 = self.up1(b, d4)
69
- # u2 = self.up2(u1, d3)
70
- # u3 = self.up3(u2, d2)
71
- # u4 = self.up4(u3, d1)
72
- # return self.out(u4)
73
-
74
-
75
- # def build_efficientnet_b3(num_output=2, pretrained=False):
76
- # # torchvision efficientnet_b3; weights=None or pretrained control
77
- # model = models.efficientnet_b3(weights=None if not pretrained else models.EfficientNet_B3_Weights.IMAGENET1K_V1)
78
- # in_features = model.classifier[1].in_features
79
- # model.classifier[1] = nn.Linear(in_features, num_output)
80
- # return model
81
-
82
-
83
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
- # print("Using device:", device)
85
-
86
- # UNET_PATH = "models/unet.pth"
87
- # MODEL_BACT_PATH = "models/model_bacterial.pt"
88
- # MODEL_VIRAL_PATH = "models/model_viral.pt"
89
-
90
-
91
- # unet = UNet(in_channels=3, num_classes=1).to(device)
92
- # unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
93
- # unet.eval()
94
-
95
-
96
- # model_bact = build_efficientnet_b3(num_output=2).to(device)
97
- # model_viral = build_efficientnet_b3(num_output=2).to(device)
98
-
99
- # model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
100
- # model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
101
-
102
- # model_bact.eval()
103
- # model_viral.eval()
104
-
105
-
106
- # preprocess_unet = transforms.Compose([
107
- # transforms.Resize((300, 300)),
108
- # transforms.ToTensor(),
109
- # ])
110
-
111
- # preprocess_classifier = transforms.Compose([
112
- # transforms.Resize((300, 300)),
113
- # transforms.ToTensor(),
114
- # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
115
- # ])
116
-
117
- # def infer_mask_and_mask_image(pil_img, threshold=0.5):
118
- # """
119
- # Returns: masked_image_tensor_for_classifier (C,H,W), mask_numpy (H,W), masked_pil (PIL)
120
- # """
121
- # # Ensure RGB
122
- # if pil_img.mode != "RGB":
123
- # pil_img = pil_img.convert("RGB")
124
- # # UNet input: tensor
125
- # inp = preprocess_unet(pil_img).unsqueeze(0).to(device)
126
- # with torch.no_grad():
127
- # logits = unet(inp)
128
- # mask_prob = torch.sigmoid(logits)[0,0]
129
- # mask_np = mask_prob.cpu().numpy()
130
- # # binary mask
131
- # bin_mask = (mask_np >= threshold).astype(np.uint8)
132
- # # apply mask to original image (resized to 300x300) for classifier
133
- # img_tensor = preprocess_classifier(pil_img).to(device) # normalized
134
- # # the mask corresponds to preprocess_unet size (300,300) same as classifier
135
- # mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
136
- # masked_img_tensor = img_tensor * mask_tensor
137
- # # convert masked tensor back to PIL for display (unnormalize)
138
- # img_for_display = preprocess_unet(pil_img).cpu().numpy().transpose(1,2,0)
139
- # masked_display = (img_for_display * bin_mask[...,None])
140
- # masked_display = np.clip(masked_display*255, 0, 255).astype(np.uint8)
141
- # masked_pil = Image.fromarray(masked_display)
142
- # return masked_img_tensor, mask_np, masked_pil
143
-
144
- # def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
- # """
146
- # masked_img_tensor: C,H,W on device, normalized for classifier
147
- # Returns (pb, pv, label)
148
- # pb = probability pneumonia in bacterial model
149
- # pv = probability pneumonia in viral model
150
- # """
151
- # x = masked_img_tensor.unsqueeze(0).to(device)
152
-
153
- # with torch.no_grad():
154
- # out_b = model_bact(x)
155
- # out_v = model_viral(x)
156
-
157
- # pb = torch.softmax(out_b, dim=1)[0,1].item()
158
- # pv = torch.softmax(out_v, dim=1)[0,1].item()
159
-
160
- # # ----------- DECISION LOGIC -----------
161
- # # Case 1: Both low → NORMAL
162
- # if pb < thresh_b and pv < thresh_v:
163
- # label = "NORMAL"
164
-
165
- # # Case 2: Only bacterial high → BACTERIAL
166
- # elif pb >= thresh_b and pv < thresh_v:
167
- # label = "BACTERIAL PNEUMONIA"
168
-
169
- # # Case 3: Only viral high → VIRAL
170
- # elif pv >= thresh_v and pb < thresh_b:
171
- # label = "VIRAL PNEUMONIA"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- # # Case 4: Both high → pick the dominant type
174
- # else:
175
- # label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
176
- # label += " (fallback-high-confidence-overlap)"
177
-
178
- # return pb, pv, label
179
 
 
 
 
 
180
 
 
 
181
 
182
- # def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
183
- # """
184
- # Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask_overlay (PIL)
185
- # """
186
-
187
- # pil = Image.fromarray(img.astype('uint8'), 'RGB')
188
-
189
- # masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
190
- # pil, threshold=seg_thresh
191
- # )
192
-
193
- # pb, pv, pred_label = classify_masked_tensor(
194
- # masked_tensor,
195
- # thresh_b=thresh_b,
196
- # thresh_v=thresh_v
197
- # )
198
-
199
- # # Convert mask to PIL
200
- # mask_vis = (mask_np * 255).astype(np.uint8)
201
- # mask_pil = Image.fromarray(mask_vis).convert("L")
202
-
203
- # # Resize original for overlay
204
- # display_orig = pil.resize((300, 300))
205
-
206
- # # Create red mask overlay
207
- # red_mask = np.zeros((300, 300, 3), dtype=np.uint8)
208
- # red_mask = np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2)
209
- # red_mask = Image.fromarray(red_mask).convert("RGBA")
210
-
211
- # alpha = (mask_np * 120).astype(np.uint8)
212
- # red_mask.putalpha(Image.fromarray(alpha))
213
-
214
- # blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
215
-
216
-
217
- # return (
218
- # pred_label,
219
- # float(pb),
220
- # float(pv),
221
- # masked_pil,
222
- # blended
223
- # )
224
 
225
 
226
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
@@ -258,7 +258,7 @@ with gr.Blocks() as demo:
258
  gr.Dataset(
259
  samples=example_samples,
260
  components=[
261
- gr.Image(type="filepath", label="Image", interactive=False),
262
  gr.Markdown(label="True Label")
263
  ],
264
  headers=["Image", "Label"],
 
7
  import gradio as gr
8
  import os
9
 
10
+ class DoubleConv(nn.Module):
11
+ def __init__(self, in_channels, out_channels):
12
+ super().__init__()
13
+ self.conv_op = nn.Sequential(
14
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
15
+ nn.BatchNorm2d(out_channels),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
18
+ nn.BatchNorm2d(out_channels),
19
+ nn.ReLU(inplace=True)
20
+ )
21
+ def forward(self, x):
22
+ return self.conv_op(x)
23
+
24
+ class Downsample(nn.Module):
25
+ def __init__(self, in_channels, out_channels):
26
+ super().__init__()
27
+ self.conv = DoubleConv(in_channels, out_channels)
28
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
29
+ def forward(self, x):
30
+ down = self.conv(x)
31
+ p = self.pool(down)
32
+ return down, p
33
+
34
+ class UpSample(nn.Module):
35
+ def __init__(self, in_channels, out_channels):
36
+ super().__init__()
37
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
38
+ self.conv = DoubleConv(in_channels, out_channels)
39
+ def forward(self, x1, x2):
40
+ x1 = self.up(x1)
41
+ # handle spatial mismatches
42
+ diffY = x2.size()[2] - x1.size()[2]
43
+ diffX = x2.size()[3] - x1.size()[3]
44
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
45
+ diffY // 2, diffY - diffY // 2])
46
+ x = torch.cat([x2, x1], dim=1)
47
+ return self.conv(x)
48
+
49
+ class UNet(nn.Module):
50
+ def __init__(self, in_channels=3, num_classes=1):
51
+ super().__init__()
52
+ self.down1 = Downsample(in_channels, 64)
53
+ self.down2 = Downsample(64, 128)
54
+ self.down3 = Downsample(128, 256)
55
+ self.down4 = Downsample(256, 512)
56
+ self.bottleneck = DoubleConv(512, 1024)
57
+ self.up1 = UpSample(1024, 512)
58
+ self.up2 = UpSample(512, 256)
59
+ self.up3 = UpSample(256, 128)
60
+ self.up4 = UpSample(128, 64)
61
+ self.out = nn.Conv2d(64, num_classes, kernel_size=1)
62
+ def forward(self, x):
63
+ d1, p1 = self.down1(x)
64
+ d2, p2 = self.down2(p1)
65
+ d3, p3 = self.down3(p2)
66
+ d4, p4 = self.down4(p3)
67
+ b = self.bottleneck(p4)
68
+ u1 = self.up1(b, d4)
69
+ u2 = self.up2(u1, d3)
70
+ u3 = self.up3(u2, d2)
71
+ u4 = self.up4(u3, d1)
72
+ return self.out(u4)
73
+
74
+
75
+ def build_efficientnet_b3(num_output=2, pretrained=False):
76
+ # torchvision efficientnet_b3; weights=None or pretrained control
77
+ model = models.efficientnet_b3(weights=None if not pretrained else models.EfficientNet_B3_Weights.IMAGENET1K_V1)
78
+ in_features = model.classifier[1].in_features
79
+ model.classifier[1] = nn.Linear(in_features, num_output)
80
+ return model
81
+
82
+
83
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ print("Using device:", device)
85
+
86
+ UNET_PATH = "models/unet.pth"
87
+ MODEL_BACT_PATH = "models/model_bacterial.pt"
88
+ MODEL_VIRAL_PATH = "models/model_viral.pt"
89
+
90
+
91
+ unet = UNet(in_channels=3, num_classes=1).to(device)
92
+ unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
93
+ unet.eval()
94
+
95
+
96
+ model_bact = build_efficientnet_b3(num_output=2).to(device)
97
+ model_viral = build_efficientnet_b3(num_output=2).to(device)
98
+
99
+ model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
100
+ model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
101
+
102
+ model_bact.eval()
103
+ model_viral.eval()
104
+
105
+
106
+ preprocess_unet = transforms.Compose([
107
+ transforms.Resize((300, 300)),
108
+ transforms.ToTensor(),
109
+ ])
110
+
111
+ preprocess_classifier = transforms.Compose([
112
+ transforms.Resize((300, 300)),
113
+ transforms.ToTensor(),
114
+ transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
115
+ ])
116
+
117
+ def infer_mask_and_mask_image(pil_img, threshold=0.5):
118
+ """
119
+ Returns: masked_image_tensor_for_classifier (C,H,W), mask_numpy (H,W), masked_pil (PIL)
120
+ """
121
+ # Ensure RGB
122
+ if pil_img.mode != "RGB":
123
+ pil_img = pil_img.convert("RGB")
124
+ # UNet input: tensor
125
+ inp = preprocess_unet(pil_img).unsqueeze(0).to(device)
126
+ with torch.no_grad():
127
+ logits = unet(inp)
128
+ mask_prob = torch.sigmoid(logits)[0,0]
129
+ mask_np = mask_prob.cpu().numpy()
130
+ # binary mask
131
+ bin_mask = (mask_np >= threshold).astype(np.uint8)
132
+ # apply mask to original image (resized to 300x300) for classifier
133
+ img_tensor = preprocess_classifier(pil_img).to(device) # normalized
134
+ # the mask corresponds to preprocess_unet size (300,300) same as classifier
135
+ mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
136
+ masked_img_tensor = img_tensor * mask_tensor
137
+ # convert masked tensor back to PIL for display (unnormalize)
138
+ img_for_display = preprocess_unet(pil_img).cpu().numpy().transpose(1,2,0)
139
+ masked_display = (img_for_display * bin_mask[...,None])
140
+ masked_display = np.clip(masked_display*255, 0, 255).astype(np.uint8)
141
+ masked_pil = Image.fromarray(masked_display)
142
+ return masked_img_tensor, mask_np, masked_pil
143
+
144
+ def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
+ """
146
+ masked_img_tensor: C,H,W on device, normalized for classifier
147
+ Returns (pb, pv, label)
148
+ pb = probability pneumonia in bacterial model
149
+ pv = probability pneumonia in viral model
150
+ """
151
+ x = masked_img_tensor.unsqueeze(0).to(device)
152
+
153
+ with torch.no_grad():
154
+ out_b = model_bact(x)
155
+ out_v = model_viral(x)
156
+
157
+ pb = torch.softmax(out_b, dim=1)[0,1].item()
158
+ pv = torch.softmax(out_v, dim=1)[0,1].item()
159
+
160
+ # ----------- DECISION LOGIC -----------
161
+ # Case 1: Both low → NORMAL
162
+ if pb < thresh_b and pv < thresh_v:
163
+ label = "NORMAL"
164
+
165
+ # Case 2: Only bacterial high → BACTERIAL
166
+ elif pb >= thresh_b and pv < thresh_v:
167
+ label = "BACTERIAL PNEUMONIA"
168
+
169
+ # Case 3: Only viral high → VIRAL
170
+ elif pv >= thresh_v and pb < thresh_b:
171
+ label = "VIRAL PNEUMONIA"
172
+
173
+ # Case 4: Both high → pick the dominant type
174
+ else:
175
+ label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
176
+ label += " (fallback-high-confidence-overlap)"
177
+
178
+ return pb, pv, label
179
+
180
+
181
+
182
+ def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
183
+ """
184
+ Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask_overlay (PIL)
185
+ """
186
+
187
+ pil = Image.fromarray(img.astype('uint8'), 'RGB')
188
+
189
+ masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
190
+ pil, threshold=seg_thresh
191
+ )
192
+
193
+ pb, pv, pred_label = classify_masked_tensor(
194
+ masked_tensor,
195
+ thresh_b=thresh_b,
196
+ thresh_v=thresh_v
197
+ )
198
+
199
+ # Convert mask to PIL
200
+ mask_vis = (mask_np * 255).astype(np.uint8)
201
+ mask_pil = Image.fromarray(mask_vis).convert("L")
202
 
203
+ # Resize original for overlay
204
+ display_orig = pil.resize((300, 300))
 
 
 
 
205
 
206
+ # Create red mask overlay
207
+ red_mask = np.zeros((300, 300, 3), dtype=np.uint8)
208
+ red_mask = np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2)
209
+ red_mask = Image.fromarray(red_mask).convert("RGBA")
210
 
211
+ alpha = (mask_np * 120).astype(np.uint8)
212
+ red_mask.putalpha(Image.fromarray(alpha))
213
 
214
+ blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
215
+
216
+
217
+ return (
218
+ pred_label,
219
+ float(pb),
220
+ float(pv),
221
+ masked_pil,
222
+ blended
223
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
 
258
  gr.Dataset(
259
  samples=example_samples,
260
  components=[
261
+ gr.Image(type="filepath", label="Image"),
262
  gr.Markdown(label="True Label")
263
  ],
264
  headers=["Image", "Label"],