Rodiyah commited on
Commit
6a60023
Β·
verified Β·
1 Parent(s): 59466c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -70
app.py CHANGED
@@ -11,19 +11,15 @@ from pytorch_grad_cam.utils.image import show_cam_on_image
11
  import os
12
  import datetime
13
 
14
- # -----------------------
15
  # Setup
16
- # -----------------------
17
  device = torch.device("cpu")
18
  save_dir = "/home/user/app/saved_predictions"
 
 
 
19
  os.makedirs(save_dir, exist_ok=True)
20
 
21
- # Placeholder image for invalid / non-fundus uploads
22
- invalid_img = Image.new("RGB", (224, 224), color=(200, 200, 200))
23
-
24
- # -----------------------
25
  # Load model
26
- # -----------------------
27
  model = models.resnet50(weights=None)
28
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
29
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
@@ -42,56 +38,8 @@ transform = transforms.Compose([
42
  [0.229, 0.224, 0.225])
43
  ])
44
 
45
- # -----------------------
46
- # Helper: soft fundus check
47
- # -----------------------
48
- def looks_like_fundus(image):
49
- """
50
- Very lightweight heuristic to guess if an image looks like a retinal fundus scan.
51
- NOT a medical-grade classifier – only used to suppress predictions on
52
- obviously wrong images.
53
- """
54
- img = np.array(image.convert("L").resize((224, 224)))
55
-
56
- # Central square vs border
57
- center = img[40:184, 40:184]
58
-
59
- border_mask = np.ones_like(img, dtype=bool)
60
- border_mask[40:184, 40:184] = False
61
- border_pixels = img[border_mask]
62
-
63
- if border_pixels.size == 0:
64
- return True
65
-
66
- center_mean = center.mean()
67
- border_mean = border_pixels.mean()
68
-
69
- border_dark_ratio = np.mean(border_pixels < 40)
70
- center_bright_ratio = np.mean(center > 80)
71
-
72
- cond_contrast = center_mean - border_mean > 15
73
- cond_border_dark = border_dark_ratio > 0.3
74
- cond_center_bright = center_bright_ratio > 0.25
75
-
76
- return cond_contrast and cond_border_dark and cond_center_bright
77
-
78
-
79
- # -----------------------
80
  # Predict and save
81
- # -----------------------
82
  def predict_retinopathy(image):
83
- is_fundus = looks_like_fundus(image)
84
-
85
- # If it's clearly not a fundus image: DO NOT give a DR/NoDR decision
86
- if not is_fundus:
87
- return (
88
- invalid_img,
89
- "⚠️ Image not recognised as a retinal fundus scan.\n\n"
90
- "No Diabetic Retinopathy assessment has been generated. "
91
- "Please upload a valid ophthalmic retinal fundus image."
92
- )
93
-
94
- # Normal pipeline for likely fundus images
95
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
96
  img = image.convert("RGB").resize((224, 224))
97
  img_tensor = transform(img).unsqueeze(0).to(device)
@@ -107,10 +55,7 @@ def predict_retinopathy(image):
107
  # Grad-CAM
108
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
109
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
110
- grayscale_cam = cam(
111
- input_tensor=img_tensor,
112
- targets=[ClassifierOutputTarget(pred)]
113
- )[0]
114
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
115
  cam_pil = Image.fromarray(cam_image)
116
 
@@ -120,15 +65,12 @@ def predict_retinopathy(image):
120
 
121
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
122
 
123
-
124
- # -----------------------
125
  # Gradio app
126
- # -----------------------
127
- demo = gr.Interface(
128
  fn=predict_retinopathy,
129
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
130
  outputs=[
131
- gr.Image(type="pil", label="Grad-CAM / Status"),
132
  gr.Text(label="Prediction")
133
  ],
134
  title="OpthaDetect – AI Retinal Screening",
@@ -138,10 +80,6 @@ demo = gr.Interface(
138
  ),
139
  article=(
140
  "βš•οΈ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
141
- "It highlights retinal risk regions using Grad-CAM for better clinical interpretability. "
142
- "Outputs should always be reviewed alongside clinical judgement."
143
  )
144
- )
145
-
146
- if __name__ == "__main__":
147
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
11
  import os
12
  import datetime
13
 
 
14
  # Setup
 
15
  device = torch.device("cpu")
16
  save_dir = "/home/user/app/saved_predictions"
17
+ if not os.path.exists(save_dir):
18
+ os.makedirs(save_dir)
19
+ print("πŸ“ Folder created:", save_dir)
20
  os.makedirs(save_dir, exist_ok=True)
21
 
 
 
 
 
22
  # Load model
 
23
  model = models.resnet50(weights=None)
24
  model.fc = torch.nn.Linear(model.fc.in_features, 2)
25
  model.load_state_dict(torch.load("resnet50_dr_classifier.pth", map_location=device))
 
38
  [0.229, 0.224, 0.225])
39
  ])
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  # Predict and save
 
42
  def predict_retinopathy(image):
 
 
 
 
 
 
 
 
 
 
 
 
43
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
44
  img = image.convert("RGB").resize((224, 224))
45
  img_tensor = transform(img).unsqueeze(0).to(device)
 
55
  # Grad-CAM
56
  rgb_img_np = np.array(img).astype(np.float32) / 255.0
57
  rgb_img_np = np.ascontiguousarray(rgb_img_np)
58
+ grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(pred)])[0]
 
 
 
59
  cam_image = show_cam_on_image(rgb_img_np, grayscale_cam, use_rgb=True)
60
  cam_pil = Image.fromarray(cam_image)
61
 
 
65
 
66
  return cam_pil, f"{label} (Confidence: {confidence:.2f})"
67
 
 
 
68
  # Gradio app
69
+ gr.Interface(
 
70
  fn=predict_retinopathy,
71
  inputs=gr.Image(type="pil", label="Upload Retinal Image"),
72
  outputs=[
73
+ gr.Image(type="pil", label="Grad-CAM Heatmap"),
74
  gr.Text(label="Prediction")
75
  ],
76
  title="OpthaDetect – AI Retinal Screening",
 
80
  ),
81
  article=(
82
  "βš•οΈ **OpthaDetect** is an AI-powered ophthalmic decision-support tool. "
83
+ "It highlights retinal risk regions using Grad-CAM for better clinical interpretability."
 
84
  )
85
+ ).launch()