Rodiyah commited on
Commit
52fef39
·
verified ·
1 Parent(s): 023e61d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -18,6 +18,9 @@ device = torch.device("cpu")
18
  save_dir = "/home/user/app/saved_predictions"
19
  os.makedirs(save_dir, exist_ok=True)
20
 
 
 
 
21
  # -----------------------
22
  # Load model
23
  # -----------------------
@@ -45,28 +48,26 @@ transform = transforms.Compose([
45
  def looks_like_fundus(image):
46
  """
47
  Very lightweight heuristic to guess if an image looks like a retinal fundus scan.
48
-
49
- This is NOT a medical-grade classifier – it is only used to show a warning
50
- if the image is very unlikely to be a retina.
51
  """
52
  img = np.array(image.convert("L").resize((224, 224)))
53
 
54
- # Central square (potential retina) vs border
55
  center = img[40:184, 40:184]
56
 
57
  border_mask = np.ones_like(img, dtype=bool)
58
  border_mask[40:184, 40:184] = False
59
  border_pixels = img[border_mask]
60
 
61
- # Safety fallback
62
  if border_pixels.size == 0:
63
  return True
64
 
65
  center_mean = center.mean()
66
  border_mean = border_pixels.mean()
67
 
68
- border_dark_ratio = np.mean(border_pixels < 40) # dark background
69
- center_bright_ratio = np.mean(center > 80) # bright retina
70
 
71
  cond_contrast = center_mean - border_mean > 15
72
  cond_border_dark = border_dark_ratio > 0.3
@@ -79,16 +80,18 @@ def looks_like_fundus(image):
79
  # Predict and save
80
  # -----------------------
81
  def predict_retinopathy(image):
82
- # 1. Soft validation (no blocking – just warning text)
83
- if looks_like_fundus(image):
84
- warning = ""
85
- else:
86
- warning = (
87
- "⚠️ The uploaded image may not be a retinal fundus scan. "
88
- "This system is intended for use with ophthalmic retinal images only.\n\n"
 
 
89
  )
90
 
91
- # 2. Normal pipeline
92
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
93
  img = image.convert("RGB").resize((224, 224))
94
  img_tensor = transform(img).unsqueeze(0).to(device)
@@ -115,9 +118,7 @@ def predict_retinopathy(image):
115
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
116
  cam_pil.save(os.path.join(save_dir, filename))
117
 
118
- prediction_text = f"{label} (Confidence: {confidence:.2f})"
119
-
120
- return cam_pil, warning + prediction_text
121
 
122
 
123
  # -----------------------
@@ -127,7 +128,7 @@ demo = gr.Interface(
127
  fn=predict_retinopathy,
128
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
129
  outputs=[
130
- gr.Image(type="pil", label="Grad-CAM Heatmap"),
131
  gr.Text(label="Prediction")
132
  ],
133
  title="OpthaDetect – AI Retinal Screening",
@@ -138,7 +139,7 @@ demo = gr.Interface(
138
  article=(
139
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
140
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability. "
141
- "This tool does not replace clinical judgement and should be used alongside professional assessment."
142
  )
143
  )
144
 
 
18
  save_dir = "/home/user/app/saved_predictions"
19
  os.makedirs(save_dir, exist_ok=True)
20
 
21
+ # Placeholder image for invalid / non-fundus uploads
22
+ invalid_img = Image.new("RGB", (224, 224), color=(200, 200, 200))
23
+
24
  # -----------------------
25
  # Load model
26
  # -----------------------
 
48
  def looks_like_fundus(image):
49
  """
50
  Very lightweight heuristic to guess if an image looks like a retinal fundus scan.
51
+ NOT a medical-grade classifier – only used to suppress predictions on
52
+ obviously wrong images.
 
53
  """
54
  img = np.array(image.convert("L").resize((224, 224)))
55
 
56
+ # Central square vs border
57
  center = img[40:184, 40:184]
58
 
59
  border_mask = np.ones_like(img, dtype=bool)
60
  border_mask[40:184, 40:184] = False
61
  border_pixels = img[border_mask]
62
 
 
63
  if border_pixels.size == 0:
64
  return True
65
 
66
  center_mean = center.mean()
67
  border_mean = border_pixels.mean()
68
 
69
+ border_dark_ratio = np.mean(border_pixels < 40)
70
+ center_bright_ratio = np.mean(center > 80)
71
 
72
  cond_contrast = center_mean - border_mean > 15
73
  cond_border_dark = border_dark_ratio > 0.3
 
80
  # Predict and save
81
  # -----------------------
82
  def predict_retinopathy(image):
83
+ is_fundus = looks_like_fundus(image)
84
+
85
+ # If it's clearly not a fundus image: DO NOT give a DR/NoDR decision
86
+ if not is_fundus:
87
+ return (
88
+ invalid_img,
89
+ "⚠️ Image not recognised as a retinal fundus scan.\n\n"
90
+ "No Diabetic Retinopathy assessment has been generated. "
91
+ "Please upload a valid ophthalmic retinal fundus image."
92
  )
93
 
94
+ # Normal pipeline for likely fundus images
95
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
96
  img = image.convert("RGB").resize((224, 224))
97
  img_tensor = transform(img).unsqueeze(0).to(device)
 
118
  filename = f"{timestamp}_{label}_{confidence:.2f}.png"
119
  cam_pil.save(os.path.join(save_dir, filename))
120
 
121
+ return cam_pil, f"{label} (Confidence: {confidence:.2f})"
 
 
122
 
123
 
124
  # -----------------------
 
128
  fn=predict_retinopathy,
129
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
130
  outputs=[
131
+ gr.Image(type="pil", label="Grad-CAM / Status"),
132
  gr.Text(label="Prediction")
133
  ],
134
  title="OpthaDetect – AI Retinal Screening",
 
139
  article=(
140
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
141
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability. "
142
+ "Outputs should always be reviewed alongside clinical judgement."
143
  )
144
  )
145