Clocksp commited on
Commit
9dfe5bd
·
verified ·
1 Parent(s): 57bf351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -211
app.py CHANGED
@@ -7,220 +7,220 @@ import numpy as np
7
  import gradio as gr
8
  import os
9
 
10
class DoubleConv(nn.Module):
    """Two successive (Conv3x3 -> BatchNorm -> ReLU) stages.

    The 3x3 convolutions use padding=1, so spatial size is preserved;
    only the channel count changes (in_channels -> out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv stages to x (N, C, H, W) -> (N, out_channels, H, W)."""
        return self.conv_op(x)
23
-
24
- class Downsample(nn.Module):
25
- def __init__(self, in_channels, out_channels):
26
- super().__init__()
27
- self.conv = DoubleConv(in_channels, out_channels)
28
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
29
- def forward(self, x):
30
- down = self.conv(x)
31
- p = self.pool(down)
32
- return down, p
33
-
34
- class UpSample(nn.Module):
35
- def __init__(self, in_channels, out_channels):
36
- super().__init__()
37
- self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
38
- self.conv = DoubleConv(in_channels, out_channels)
39
- def forward(self, x1, x2):
40
- x1 = self.up(x1)
41
- # handle spatial mismatches
42
- diffY = x2.size()[2] - x1.size()[2]
43
- diffX = x2.size()[3] - x1.size()[3]
44
- x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
45
- diffY // 2, diffY - diffY // 2])
46
- x = torch.cat([x2, x1], dim=1)
47
- return self.conv(x)
48
-
49
class UNet(nn.Module):
    """Standard 4-level U-Net built from Downsample / UpSample blocks.

    Encoder widths 64-128-256-512 with a 1024-channel bottleneck; the
    decoder mirrors them and a final 1x1 conv maps to num_classes logits.
    Output spatial size matches the input (padding-preserving convs).
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.bottleneck = DoubleConv(512, 1024)
        self.up1 = UpSample(1024, 512)
        self.up2 = UpSample(512, 256)
        self.up3 = UpSample(256, 128)
        self.up4 = UpSample(128, 64)
        self.out = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw segmentation logits of shape (N, num_classes, H, W)."""
        # Encoder: each step returns (skip_features, pooled).
        d1, p1 = self.down1(x)
        d2, p2 = self.down2(p1)
        d3, p3 = self.down3(p2)
        d4, p4 = self.down4(p3)
        b = self.bottleneck(p4)
        # Decoder: deepest skip connection first.
        u1 = self.up1(b, d4)
        u2 = self.up2(u1, d3)
        u3 = self.up3(u2, d2)
        u4 = self.up4(u3, d1)
        return self.out(u4)
73
-
74
-
75
def build_efficientnet_b3(num_output=2, pretrained=False):
    """Build a torchvision EfficientNet-B3 with a replaced final head.

    Args:
        num_output: number of output classes for the new linear head.
        pretrained: when True, load ImageNet-1k weights; otherwise start
            from random initialization (weights=None).

    Returns:
        The model with classifier[1] swapped for a fresh Linear layer.
    """
    weights = models.EfficientNet_B3_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.efficientnet_b3(weights=weights)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_output)
    return model
81
-
82
-
83
# ---------------------------------------------------------------------------
# Module-level setup: device selection, model loading, and preprocessing.
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Checkpoint locations, relative to the app's working directory.
UNET_PATH = "models/unet.pth"
MODEL_BACT_PATH = "models/model_bacterial.pt"
MODEL_VIRAL_PATH = "models/model_viral.pt"

# Lung segmentation model (single-channel mask logits).
unet = UNet(in_channels=3, num_classes=1).to(device)
unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
unet.eval()

# Two independent binary classifiers: bacterial-vs-not and viral-vs-not.
model_bact = build_efficientnet_b3(num_output=2).to(device)
model_viral = build_efficientnet_b3(num_output=2).to(device)

model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))

model_bact.eval()
model_viral.eval()

# UNet input: plain resize + tensor (no normalization; matches training —
# TODO confirm against the segmentation training pipeline).
preprocess_unet = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
])

# Classifier input: same 300x300 size as the UNet so the predicted mask
# can be applied pixel-for-pixel, plus ImageNet normalization.
preprocess_classifier = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
116
-
117
def infer_mask_and_mask_image(pil_img, threshold=0.5):
    """Segment the lungs and apply the mask to the image.

    Args:
        pil_img: input PIL image (converted to RGB if needed).
        threshold: probability cutoff for binarizing the UNet mask.

    Returns:
        masked_img_tensor: (C, H, W) normalized tensor on `device`, with
            non-lung pixels zeroed — ready for the classifiers.
        mask_np: (H, W) float numpy array of raw mask probabilities.
        masked_pil: PIL image of the masked (un-normalized) input, for display.
    """
    # Ensure RGB — the models expect 3 channels.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Run the UNet on the resized, un-normalized image.
    inp = preprocess_unet(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = unet(inp)
        mask_prob = torch.sigmoid(logits)[0, 0]
    mask_np = mask_prob.cpu().numpy()

    # Binarize the probability map.
    bin_mask = (mask_np >= threshold).astype(np.uint8)

    # Apply the mask to the classifier-normalized image. Both pipelines
    # resize to 300x300, so the mask aligns pixel-for-pixel.
    img_tensor = preprocess_classifier(pil_img).to(device)
    mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
    masked_img_tensor = img_tensor * mask_tensor

    # Build a display version from the UN-normalized tensor so colors
    # look natural (the normalized tensor would render with a color cast).
    img_for_display = preprocess_unet(pil_img).cpu().numpy().transpose(1, 2, 0)
    masked_display = img_for_display * bin_mask[..., None]
    masked_display = np.clip(masked_display * 255, 0, 255).astype(np.uint8)
    masked_pil = Image.fromarray(masked_display)

    return masked_img_tensor, mask_np, masked_pil
143
-
144
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
    """Run both binary classifiers on a masked image and combine them.

    Args:
        masked_img_tensor: (C, H, W) tensor on `device`, already normalized
            for the classifiers (output of infer_mask_and_mask_image).
        thresh_b: decision threshold for the bacterial model.
        thresh_v: decision threshold for the viral model.

    Returns:
        (pb, pv, label) where
        pb = probability of pneumonia from the bacterial model,
        pv = probability of pneumonia from the viral model,
        label = combined decision string.
    """
    x = masked_img_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        out_b = model_bact(x)
        out_v = model_viral(x)

    # Index 1 is assumed to be the "pneumonia" class in both models —
    # TODO confirm against training label order.
    pb = torch.softmax(out_b, dim=1)[0, 1].item()
    pv = torch.softmax(out_v, dim=1)[0, 1].item()

    # ----------- DECISION LOGIC -----------
    if pb < thresh_b and pv < thresh_v:
        # Case 1: both models below threshold -> NORMAL.
        label = "NORMAL"
    elif pb >= thresh_b and pv < thresh_v:
        # Case 2: only bacterial model fires.
        label = "BACTERIAL PNEUMONIA"
    elif pv >= thresh_v and pb < thresh_b:
        # Case 3: only viral model fires.
        label = "VIRAL PNEUMONIA"
    else:
        # Case 4: both fire -> pick the dominant type and flag the overlap.
        label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
        label += " (fallback-high-confidence-overlap)"

    return pb, pv, label
179
-
180
-
181
-
182
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
    """End-to-end pipeline: segment, mask, classify, and build visualizations.

    Args:
        img: input image as a numpy array (H, W, 3), e.g. from gradio.
        thresh_b: bacterial classifier decision threshold.
        thresh_v: viral classifier decision threshold.
        seg_thresh: segmentation mask binarization threshold.

    Returns:
        (label, bacterial_prob, viral_prob, masked_image PIL, mask_overlay PIL)
    """
    pil = Image.fromarray(img.astype('uint8'), 'RGB')

    masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
        pil, threshold=seg_thresh
    )

    pb, pv, pred_label = classify_masked_tensor(
        masked_tensor,
        thresh_b=thresh_b,
        thresh_v=thresh_v,
    )

    # Mask probabilities as an 8-bit grayscale image (red channel below).
    mask_vis = (mask_np * 255).astype(np.uint8)

    # Resize the original to the mask's 300x300 resolution for overlay.
    display_orig = pil.resize((300, 300))

    # Red overlay: mask probabilities in the red channel, zero green/blue,
    # with a semi-transparent alpha proportional to the mask probability.
    red_mask = np.stack(
        [mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2
    )
    red_mask = Image.fromarray(red_mask).convert("RGBA")
    alpha = (mask_np * 120).astype(np.uint8)
    red_mask.putalpha(Image.fromarray(alpha))

    blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)

    return (
        pred_label,
        float(pb),
        float(pv),
        masked_pil,
        blended,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
@@ -248,7 +248,7 @@ with gr.Blocks() as demo:
248
  title=title,
249
  description=desc,
250
  allow_flagging="never"
251
- ).render()
252
 
253
  example_samples = [
254
  ["images/NORMAL.jpeg", "NORMAL"],
 
7
  import gradio as gr
8
  import os
9
 
10
+ # class DoubleConv(nn.Module):
11
+ # def __init__(self, in_channels, out_channels):
12
+ # super().__init__()
13
+ # self.conv_op = nn.Sequential(
14
+ # nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
15
+ # nn.BatchNorm2d(out_channels),
16
+ # nn.ReLU(inplace=True),
17
+ # nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
18
+ # nn.BatchNorm2d(out_channels),
19
+ # nn.ReLU(inplace=True)
20
+ # )
21
+ # def forward(self, x):
22
+ # return self.conv_op(x)
23
+
24
+ # class Downsample(nn.Module):
25
+ # def __init__(self, in_channels, out_channels):
26
+ # super().__init__()
27
+ # self.conv = DoubleConv(in_channels, out_channels)
28
+ # self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
29
+ # def forward(self, x):
30
+ # down = self.conv(x)
31
+ # p = self.pool(down)
32
+ # return down, p
33
+
34
+ # class UpSample(nn.Module):
35
+ # def __init__(self, in_channels, out_channels):
36
+ # super().__init__()
37
+ # self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
38
+ # self.conv = DoubleConv(in_channels, out_channels)
39
+ # def forward(self, x1, x2):
40
+ # x1 = self.up(x1)
41
+ # # handle spatial mismatches
42
+ # diffY = x2.size()[2] - x1.size()[2]
43
+ # diffX = x2.size()[3] - x1.size()[3]
44
+ # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
45
+ # diffY // 2, diffY - diffY // 2])
46
+ # x = torch.cat([x2, x1], dim=1)
47
+ # return self.conv(x)
48
+
49
+ # class UNet(nn.Module):
50
+ # def __init__(self, in_channels=3, num_classes=1):
51
+ # super().__init__()
52
+ # self.down1 = Downsample(in_channels, 64)
53
+ # self.down2 = Downsample(64, 128)
54
+ # self.down3 = Downsample(128, 256)
55
+ # self.down4 = Downsample(256, 512)
56
+ # self.bottleneck = DoubleConv(512, 1024)
57
+ # self.up1 = UpSample(1024, 512)
58
+ # self.up2 = UpSample(512, 256)
59
+ # self.up3 = UpSample(256, 128)
60
+ # self.up4 = UpSample(128, 64)
61
+ # self.out = nn.Conv2d(64, num_classes, kernel_size=1)
62
+ # def forward(self, x):
63
+ # d1, p1 = self.down1(x)
64
+ # d2, p2 = self.down2(p1)
65
+ # d3, p3 = self.down3(p2)
66
+ # d4, p4 = self.down4(p3)
67
+ # b = self.bottleneck(p4)
68
+ # u1 = self.up1(b, d4)
69
+ # u2 = self.up2(u1, d3)
70
+ # u3 = self.up3(u2, d2)
71
+ # u4 = self.up4(u3, d1)
72
+ # return self.out(u4)
73
+
74
+
75
+ # def build_efficientnet_b3(num_output=2, pretrained=False):
76
+ # # torchvision efficientnet_b3; weights=None or pretrained control
77
+ # model = models.efficientnet_b3(weights=None if not pretrained else models.EfficientNet_B3_Weights.IMAGENET1K_V1)
78
+ # in_features = model.classifier[1].in_features
79
+ # model.classifier[1] = nn.Linear(in_features, num_output)
80
+ # return model
81
+
82
+
83
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84
+ # print("Using device:", device)
85
+
86
+ # UNET_PATH = "models/unet.pth"
87
+ # MODEL_BACT_PATH = "models/model_bacterial.pt"
88
+ # MODEL_VIRAL_PATH = "models/model_viral.pt"
89
+
90
+
91
+ # unet = UNet(in_channels=3, num_classes=1).to(device)
92
+ # unet.load_state_dict(torch.load(UNET_PATH, map_location=device))
93
+ # unet.eval()
94
+
95
+
96
+ # model_bact = build_efficientnet_b3(num_output=2).to(device)
97
+ # model_viral = build_efficientnet_b3(num_output=2).to(device)
98
+
99
+ # model_bact.load_state_dict(torch.load(MODEL_BACT_PATH, map_location=device))
100
+ # model_viral.load_state_dict(torch.load(MODEL_VIRAL_PATH, map_location=device))
101
+
102
+ # model_bact.eval()
103
+ # model_viral.eval()
104
+
105
+
106
+ # preprocess_unet = transforms.Compose([
107
+ # transforms.Resize((300, 300)),
108
+ # transforms.ToTensor(),
109
+ # ])
110
+
111
+ # preprocess_classifier = transforms.Compose([
112
+ # transforms.Resize((300, 300)),
113
+ # transforms.ToTensor(),
114
+ # transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
115
+ # ])
116
+
117
+ # def infer_mask_and_mask_image(pil_img, threshold=0.5):
118
+ # """
119
+ # Returns: masked_image_tensor_for_classifier (C,H,W), mask_numpy (H,W), masked_pil (PIL)
120
+ # """
121
+ # # Ensure RGB
122
+ # if pil_img.mode != "RGB":
123
+ # pil_img = pil_img.convert("RGB")
124
+ # # UNet input: tensor
125
+ # inp = preprocess_unet(pil_img).unsqueeze(0).to(device)
126
+ # with torch.no_grad():
127
+ # logits = unet(inp)
128
+ # mask_prob = torch.sigmoid(logits)[0,0]
129
+ # mask_np = mask_prob.cpu().numpy()
130
+ # # binary mask
131
+ # bin_mask = (mask_np >= threshold).astype(np.uint8)
132
+ # # apply mask to original image (resized to 300x300) for classifier
133
+ # img_tensor = preprocess_classifier(pil_img).to(device) # normalized
134
+ # # the mask corresponds to preprocess_unet size (300,300) same as classifier
135
+ # mask_tensor = torch.from_numpy(bin_mask).unsqueeze(0).to(device).float()
136
+ # masked_img_tensor = img_tensor * mask_tensor
137
+ # # convert masked tensor back to PIL for display (unnormalize)
138
+ # img_for_display = preprocess_unet(pil_img).cpu().numpy().transpose(1,2,0)
139
+ # masked_display = (img_for_display * bin_mask[...,None])
140
+ # masked_display = np.clip(masked_display*255, 0, 255).astype(np.uint8)
141
+ # masked_pil = Image.fromarray(masked_display)
142
+ # return masked_img_tensor, mask_np, masked_pil
143
+
144
+ # def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
+ # """
146
+ # masked_img_tensor: C,H,W on device, normalized for classifier
147
+ # Returns (pb, pv, label)
148
+ # pb = probability pneumonia in bacterial model
149
+ # pv = probability pneumonia in viral model
150
+ # """
151
+ # x = masked_img_tensor.unsqueeze(0).to(device)
152
+
153
+ # with torch.no_grad():
154
+ # out_b = model_bact(x)
155
+ # out_v = model_viral(x)
156
+
157
+ # pb = torch.softmax(out_b, dim=1)[0,1].item()
158
+ # pv = torch.softmax(out_v, dim=1)[0,1].item()
159
+
160
+ # # ----------- DECISION LOGIC -----------
161
+ # # Case 1: Both low → NORMAL
162
+ # if pb < thresh_b and pv < thresh_v:
163
+ # label = "NORMAL"
164
+
165
+ # # Case 2: Only bacterial high → BACTERIAL
166
+ # elif pb >= thresh_b and pv < thresh_v:
167
+ # label = "BACTERIAL PNEUMONIA"
168
+
169
+ # # Case 3: Only viral high → VIRAL
170
+ # elif pv >= thresh_v and pb < thresh_b:
171
+ # label = "VIRAL PNEUMONIA"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ # # Case 4: Both high → pick the dominant type
174
+ # else:
175
+ # label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
176
+ # label += " (fallback-high-confidence-overlap)"
177
+
178
+ # return pb, pv, label
179
 
 
 
 
 
180
 
 
 
181
 
182
+ # def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
183
+ # """
184
+ # Returns: label, bacterial_prob, viral_prob, masked_image (PIL), mask_overlay (PIL)
185
+ # """
186
+
187
+ # pil = Image.fromarray(img.astype('uint8'), 'RGB')
188
+
189
+ # masked_tensor, mask_np, masked_pil = infer_mask_and_mask_image(
190
+ # pil, threshold=seg_thresh
191
+ # )
192
+
193
+ # pb, pv, pred_label = classify_masked_tensor(
194
+ # masked_tensor,
195
+ # thresh_b=thresh_b,
196
+ # thresh_v=thresh_v
197
+ # )
198
+
199
+ # # Convert mask to PIL
200
+ # mask_vis = (mask_np * 255).astype(np.uint8)
201
+ # mask_pil = Image.fromarray(mask_vis).convert("L")
202
+
203
+ # # Resize original for overlay
204
+ # display_orig = pil.resize((300, 300))
205
+
206
+ # # Create red mask overlay
207
+ # red_mask = np.zeros((300, 300, 3), dtype=np.uint8)
208
+ # red_mask = np.stack([mask_vis, np.zeros_like(mask_vis), np.zeros_like(mask_vis)], axis=2)
209
+ # red_mask = Image.fromarray(red_mask).convert("RGBA")
210
+
211
+ # alpha = (mask_np * 120).astype(np.uint8)
212
+ # red_mask.putalpha(Image.fromarray(alpha))
213
+
214
+ # blended = Image.alpha_composite(display_orig.convert("RGBA"), red_mask)
215
+
216
+
217
+ # return (
218
+ # pred_label,
219
+ # float(pb),
220
+ # float(pv),
221
+ # masked_pil,
222
+ # blended
223
+ # )
224
 
225
 
226
  title = "Chest X-ray: UNet segmentation + 2 binary classifiers"
 
248
  title=title,
249
  description=desc,
250
  allow_flagging="never"
251
+ )
252
 
253
  example_samples = [
254
  ["images/NORMAL.jpeg", "NORMAL"],