Update app.py
Browse files
app.py
CHANGED
|
@@ -144,29 +144,39 @@ def infer_mask_and_mask_image(pil_img, threshold=0.5):
|
|
| 144 |
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
|
| 145 |
"""
|
| 146 |
masked_img_tensor: C,H,W on device, normalized for classifier
|
| 147 |
-
|
| 148 |
-
pb = probability
|
| 149 |
-
pv = probability
|
| 150 |
"""
|
| 151 |
x = masked_img_tensor.unsqueeze(0).to(device)
|
|
|
|
| 152 |
with torch.no_grad():
|
| 153 |
-
out_b = model_bact(x)
|
| 154 |
out_v = model_viral(x)
|
| 155 |
-
prob_b = torch.softmax(out_b, dim=1)[0,1].item()
|
| 156 |
-
prob_v = torch.softmax(out_v, dim=1)[0,1].item()
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
label = "NORMAL"
|
| 161 |
-
|
|
|
|
|
|
|
| 162 |
label = "BACTERIAL PNEUMONIA"
|
| 163 |
-
|
|
|
|
|
|
|
| 164 |
label = "VIRAL PNEUMONIA"
|
|
|
|
|
|
|
| 165 |
else:
|
| 166 |
-
|
| 167 |
-
label
|
| 168 |
-
|
| 169 |
-
return
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
|
|
|
|
| 144 |
def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
|
| 145 |
"""
|
| 146 |
masked_img_tensor: C,H,W on device, normalized for classifier
|
| 147 |
+
Returns (pb, pv, label)
|
| 148 |
+
pb = probability pneumonia in bacterial model
|
| 149 |
+
pv = probability pneumonia in viral model
|
| 150 |
"""
|
| 151 |
x = masked_img_tensor.unsqueeze(0).to(device)
|
| 152 |
+
|
| 153 |
with torch.no_grad():
|
| 154 |
+
out_b = model_bact(x)
|
| 155 |
out_v = model_viral(x)
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
pb = torch.softmax(out_b, dim=1)[0,1].item()
|
| 158 |
+
pv = torch.softmax(out_v, dim=1)[0,1].item()
|
| 159 |
+
|
| 160 |
+
# ----------- DECISION LOGIC -----------
|
| 161 |
+
# Case 1: Both low → NORMAL
|
| 162 |
+
if pb < thresh_b and pv < thresh_v:
|
| 163 |
label = "NORMAL"
|
| 164 |
+
|
| 165 |
+
# Case 2: Only bacterial high → BACTERIAL
|
| 166 |
+
elif pb >= thresh_b and pv < thresh_v:
|
| 167 |
label = "BACTERIAL PNEUMONIA"
|
| 168 |
+
|
| 169 |
+
# Case 3: Only viral high → VIRAL
|
| 170 |
+
elif pv >= thresh_v and pb < thresh_b:
|
| 171 |
label = "VIRAL PNEUMONIA"
|
| 172 |
+
|
| 173 |
+
# Case 4: Both high → pick the dominant type
|
| 174 |
else:
|
| 175 |
+
label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
|
| 176 |
+
label += " (fallback-high-confidence-overlap)"
|
| 177 |
+
|
| 178 |
+
return pb, pv, label
|
| 179 |
+
|
| 180 |
|
| 181 |
|
| 182 |
def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
|