Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -29,6 +29,11 @@ NUM_CLASSES = 2
29
  CLASS_NAMES = ["Elliptical", "Spiral"]
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
 
 
 
 
 
32
  # Image preprocessing
33
  preprocess = transforms.Compose([
34
  transforms.Resize((224, 224)),
@@ -47,7 +52,7 @@ def load_model():
47
  if os.path.exists(MODEL_PATH):
48
  try:
49
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
50
- model.load_state_dict(state_dict)
51
  print(f"Model loaded successfully from {MODEL_PATH}")
52
  except Exception as e:
53
  print(f"Error loading model: {e}")
@@ -140,10 +145,24 @@ def predict_galaxy(image):
140
  img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
141
  img_tensor.requires_grad = True
142
 
143
- outputs = model(img_tensor)
144
- probs = F.softmax(outputs, dim=1)
145
- pred_class = outputs.argmax(dim=1).item()
146
- confidence = probs[0][pred_class].item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  gradcam = GradCAM(model, model.layer4)
149
  cam = gradcam.generate_cam(img_tensor, pred_class)
@@ -153,7 +172,11 @@ def predict_galaxy(image):
153
  overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
154
  overlay_pil = Image.fromarray(overlay_rgb)
155
 
156
- result_text = f"Predicted Class: {CLASS_NAMES[pred_class]}\nConfidence: {confidence:.2%}"
 
 
 
 
157
 
158
  return overlay_pil, result_text
159
 
 
29
  CLASS_NAMES = ["Elliptical", "Spiral"]
30
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
+ # πŸ”΄ Calibration + OOD thresholds
33
+ TEMPERATURE = 2.5 # softens overconfidence
34
+ CONF_THRESHOLD = 0.60 # below this β†’ uncertain
35
+ ENTROPY_THRESHOLD = 0.85 # high entropy β†’ uncertain
36
+
37
  # Image preprocessing
38
  preprocess = transforms.Compose([
39
  transforms.Resize((224, 224)),
 
52
  if os.path.exists(MODEL_PATH):
53
  try:
54
  state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
55
+ model.load_state_dict((torch.load(MODEL_PATH, map_location=DEVICE))
56
  print(f"Model loaded successfully from {MODEL_PATH}")
57
  except Exception as e:
58
  print(f"Error loading model: {e}")
 
145
  img_tensor = preprocess(image).unsqueeze(0).to(DEVICE)
146
  img_tensor.requires_grad = True
147
 
148
+ # πŸ”΄ Temperature scaling
149
+ scaled_logits = logits / TEMPERATURE
150
+ probs = F.softmax(scaled_logits, dim=1)[0]
151
+
152
+ confidence, pred_class = torch.max(probs, dim=0)
153
+
154
+ # πŸ”΄ Entropy-based uncertainty
155
+ entropy = -(probs * torch.log(probs + 1e-8)).sum().item()
156
+
157
+ if confidence.item() < CONF_THRESHOLD or entropy > ENTROPY_THRESHOLD:
158
+ result_text = (
159
+ "**Prediction:** Uncertain / Not a Galaxy\n"
160
+ "**Confidence:** Low"
161
+ )
162
+ overlay_img = image.resize((224, 224))
163
+ return overlay_img, result_text
164
+
165
+
166
 
167
  gradcam = GradCAM(model, model.layer4)
168
  cam = gradcam.generate_cam(img_tensor, pred_class)
 
172
  overlay_rgb = cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)
173
  overlay_pil = Image.fromarray(overlay_rgb)
174
 
175
+ # πŸ”΄ Separate lines (as requested)
176
+ result_text = (
177
+ f"**Prediction:** {CLASS_NAMES[pred_class.item()]}\n"
178
+ f"**Confidence:** {confidence.item() * 100:.2f}%"
179
+ )
180
 
181
  return overlay_pil, result_text
182