Rodiyah commited on
Commit
94fd547
·
verified ·
1 Parent(s): d7b38d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -6
app.py CHANGED
@@ -11,15 +11,16 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
11
  import os
12
  import datetime
13
 
 
14
  # Setup
 
15
  device = torch.device("cpu")
16
  save_dir = "/home/user/app/saved_predictions"
17
- if not os.path.exists(save_dir):
18
- os.makedirs(save_dir)
19
- print("📁 Folder created:", save_dir)
20
  os.makedirs(save_dir, exist_ok=True)
21
 
 
22
  # Load model
 
23
  model = models.resnet50(weights=None)
24
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
25
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
@@ -38,8 +39,48 @@ transform = transforms.Compose([
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Predict and save
 
42
  def predict_retinopathy(image):
 
 
 
 
 
 
 
 
43
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
44
  img = image.convert("RGB").resize((224, 224))
45
  img_tensor = transform(img).unsqueeze(0).to(device)
@@ -55,7 +96,10 @@ def predict_retinopathy(image):
55
  # Grad-CAM
56
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
57
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
58
- grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
 
 
 
59
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
60
  cam_pil = Image.fromarray(cam_image)
61
 
@@ -65,8 +109,11 @@ def predict_retinopathy(image):
65
 
66
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
67
 
 
 
68
  # Gradio app
69
- gr.Interface(
 
70
  fn=predict_retinopathy,
71
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
72
  outputs=[
@@ -82,4 +129,7 @@ gr.Interface(
82
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
83
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
84
  )
85
- ).launch()
 
 
 
 
11
  import os
12
  import datetime
13
 
14
+ # -----------------------
15
  # Setup
16
+ # -----------------------
17
  device = torch.device("cpu")
18
  save_dir = "/home/user/app/saved_predictions"
 
 
 
19
  os.makedirs(save_dir, exist_ok=True)
20
 
21
+ # -----------------------
22
  # Load model
23
+ # -----------------------
24
  model = models.resnet50(weights=None)
25
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
26
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
 
39
  [0.229, 0.224, 0.225])
40
  ])
41
 
42
+ # -----------------------
43
+ # Helper: basic fundus check
44
+ # -----------------------
45
+ def looks_like_fundus(image):
46
+ """
47
+ Heuristic to check if an image is likely a retinal fundus scan.
48
+ Fundus images typically have a brighter central region (retina)
49
+ and a darker outer border (background).
50
+ """
51
+ img = np.array(image.convert("L").resize((224, 224)))
52
+
53
+ # Central crop
54
+ center = img[40:184, 40:184]
55
+ center_mean = center.mean()
56
+
57
+ # Border = everything outside the central crop
58
+ border = img.copy()
59
+ border[40:184, 40:184] = 0
60
+ border_pixels = border[border > 0]
61
+
62
+ # If no real border pixels, don't block it
63
+ if border_pixels.size == 0:
64
+ return True
65
+
66
+ border_mean = border_pixels.mean()
67
+
68
+ # Fundus: center clearly brighter than border
69
+ return center_mean > border_mean + 8
70
+
71
+
72
+ # -----------------------
73
  # Predict and save
74
+ # -----------------------
75
  def predict_retinopathy(image):
76
+ # 1. Block non-retina images BEFORE model / Grad-CAM
77
+ if not looks_like_fundus(image):
78
+ raise gr.Error(
79
+ "The uploaded image does not appear to be a retinal fundus scan. "
80
+ "Please upload a valid ophthalmic retinal image for Diabetic Retinopathy assessment."
81
+ )
82
+
83
+ # 2. Normal pipeline for valid retinal images
84
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
85
  img = image.convert("RGB").resize((224, 224))
86
  img_tensor = transform(img).unsqueeze(0).to(device)
 
96
  # Grad-CAM
97
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
98
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
99
+ grayscale_cam = cam(
100
+ input_tensor=img_tensor,
101
+ targets=[ClassifierOutputTarget(pred)]
102
+ )[0]
103
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
104
  cam_pil = Image.fromarray(cam_image)
105
 
 
109
 
110
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
111
 
112
+
113
+ # -----------------------
114
  # Gradio app
115
+ # -----------------------
116
+ demo = gr.Interface(
117
  fn=predict_retinopathy,
118
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
119
  outputs=[
 
129
  "⚕️ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
130
  "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
131
  )
132
+ )
133
+
134
+ if __name__ == "__main__":
135
+ demo.launch()