Clocksp commited on
Commit
2160150
·
verified ·
1 Parent(s): 872bf8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -14
app.py CHANGED
@@ -144,29 +144,39 @@ def infer_mask_and_mask_image(pil_img, threshold=0.5):
144
  def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
  """
146
  masked_img_tensor: C,H,W on device, normalized for classifier
147
- returns (pb, pv, label)
148
- pb = probability of pneumonia class from model_bact (index 1)
149
- pv = probability of pneumonia class from model_viral (index 1)
150
  """
151
  x = masked_img_tensor.unsqueeze(0).to(device)
 
152
  with torch.no_grad():
153
- out_b = model_bact(x)
154
  out_v = model_viral(x)
155
- prob_b = torch.softmax(out_b, dim=1)[0,1].item()
156
- prob_v = torch.softmax(out_v, dim=1)[0,1].item()
157
 
158
- # Decision logic: thresholds + fallback to higher prob when both triggered
159
- if prob_b < thresh_b and prob_v < thresh_v:
 
 
 
 
160
  label = "NORMAL"
161
- elif prob_b >= thresh_b and prob_v < thresh_v:
 
 
162
  label = "BACTERIAL PNEUMONIA"
163
- elif prob_v >= thresh_v and prob_b < thresh_b:
 
 
164
  label = "VIRAL PNEUMONIA"
 
 
165
  else:
166
- # both triggered -> pick the stronger probability (fallback)
167
- label = "BACTERIAL PNEUMONIA" if prob_b > prob_v else "VIRAL PNEUMONIA"
168
- label = label + " (fallback)"
169
- return prob_b, prob_v, label
 
170
 
171
 
172
  def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):
 
144
  def classify_masked_tensor(masked_img_tensor, thresh_b=0.5, thresh_v=0.5):
145
  """
146
  masked_img_tensor: C,H,W on device, normalized for classifier
147
+ Returns (pb, pv, label)
148
+ pb = probability pneumonia in bacterial model
149
+ pv = probability pneumonia in viral model
150
  """
151
  x = masked_img_tensor.unsqueeze(0).to(device)
152
+
153
  with torch.no_grad():
154
+ out_b = model_bact(x)
155
  out_v = model_viral(x)
 
 
156
 
157
+ pb = torch.softmax(out_b, dim=1)[0,1].item()
158
+ pv = torch.softmax(out_v, dim=1)[0,1].item()
159
+
160
+ # ----------- DECISION LOGIC -----------
161
+ # Case 1: Both low → NORMAL
162
+ if pb < thresh_b and pv < thresh_v:
163
  label = "NORMAL"
164
+
165
+ # Case 2: Only bacterial high → BACTERIAL
166
+ elif pb >= thresh_b and pv < thresh_v:
167
  label = "BACTERIAL PNEUMONIA"
168
+
169
+ # Case 3: Only viral high → VIRAL
170
+ elif pv >= thresh_v and pb < thresh_b:
171
  label = "VIRAL PNEUMONIA"
172
+
173
+ # Case 4: Both high → pick the dominant type
174
  else:
175
+ label = "BACTERIAL PNEUMONIA" if pb > pv else "VIRAL PNEUMONIA"
176
+ label += " (fallback-high-confidence-overlap)"
177
+
178
+ return pb, pv, label
179
+
180
 
181
 
182
  def inference_pipeline(img, thresh_b=0.5, thresh_v=0.5, seg_thresh=0.5):