suhanii23 commited on
Commit
4369d2b
·
verified ·
1 Parent(s): b17a76a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
 
 
4
  from tensorflow.keras.models import load_model
5
  from PIL import Image
6
 
@@ -16,9 +18,22 @@ CLASS_DESCRIPTIONS = {
16
  'Proliferative': 'Proliferative diabetic retinopathy. Advanced stage with new abnormal blood vessel growth.'
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
19
  # Load model
20
- model = load_model('diabetic_retinopathy_full_model.h5')
 
21
 
 
22
  def crop_image_from_gray(img, tol=7):
23
  """Ben Graham's preprocessing: crop black borders"""
24
  if img.ndim == 2:
@@ -27,13 +42,13 @@ def crop_image_from_gray(img, tol=7):
27
  elif img.ndim == 3:
28
  gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
29
  mask = gray_img > tol
30
- check_shape = img[:,:,0][np.ix_(mask.any(1), mask.any(0))].shape[0]
31
  if check_shape == 0:
32
  return img
33
  else:
34
- img1 = img[:,:,0][np.ix_(mask.any(1), mask.any(0))]
35
- img2 = img[:,:,1][np.ix_(mask.any(1), mask.any(0))]
36
- img3 = img[:,:,2][np.ix_(mask.any(1), mask.any(0))]
37
  img = np.stack([img1, img2, img3], axis=-1)
38
  return img
39
 
@@ -45,11 +60,12 @@ def preprocess_image(image, sigmaX=10):
45
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
46
  image = crop_image_from_gray(image)
47
  image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
48
- image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0,0), sigmaX), -4, 128)
49
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50
 
51
  return image
52
 
 
53
  def predict_dr(image):
54
  """Main prediction function"""
55
  try:
@@ -60,11 +76,7 @@ def predict_dr(image):
60
  binary_pred = (predictions > 0.5).astype(int)
61
  final_class = binary_pred.sum(axis=1)[0] - 1
62
 
63
- confidences = {
64
- CLASSES[i]: float(predictions[0][i])
65
- for i in range(len(CLASSES))
66
- }
67
-
68
  result_class = CLASSES[final_class]
69
  description = CLASS_DESCRIPTIONS[result_class]
70
 
@@ -77,9 +89,10 @@ def predict_dr(image):
77
  except Exception as e:
78
  return None, f"Error: {str(e)}", {}
79
 
 
80
  with gr.Blocks(title="Diabetic Retinopathy Detector") as demo:
81
  gr.Markdown("""
82
- # 🏥 Diabetic Retinopathy Detection
83
  Upload a retinal fundus image to detect diabetic retinopathy severity.
84
 
85
  **Classes:** Normal → Mild → Moderate → Severe → Proliferative
@@ -102,9 +115,9 @@ with gr.Blocks(title="Diabetic Retinopathy Detector") as demo:
102
  )
103
 
104
  gr.Markdown("""
105
- ### ⚠️ Medical Disclaimer
106
  This tool is for educational purposes only. Always consult a qualified healthcare provider.
107
  """)
108
 
109
  if __name__ == "__main__":
110
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import cv2
4
+ import os
5
+ import requests
6
  from tensorflow.keras.models import load_model
7
  from PIL import Image
8
 
 
18
  'Proliferative': 'Proliferative diabetic retinopathy. Advanced stage with new abnormal blood vessel growth.'
19
  }
20
 
21
+ # --- 🔽 Ensure model is available locally ---
22
+ MODEL_PATH = "diabetic_retinopathy_full_model.h5"
23
+ MODEL_URL = "https://github.com/suhanii-23/retinopathy-detector/releases/download/v1.0-model/diabetic_retinopathy_full_model.h5"
24
+
25
+ if not os.path.exists(MODEL_PATH):
26
+ print("Downloading model from GitHub release...")
27
+ r = requests.get(MODEL_URL, allow_redirects=True)
28
+ with open(MODEL_PATH, "wb") as f:
29
+ f.write(r.content)
30
+ print("✅ Model downloaded successfully!")
31
+
32
  # Load model
33
+ model = load_model(MODEL_PATH)
34
+ print("✅ Model loaded successfully!")
35
 
36
+ # --- Image Preprocessing ---
37
  def crop_image_from_gray(img, tol=7):
38
  """Ben Graham's preprocessing: crop black borders"""
39
  if img.ndim == 2:
 
42
  elif img.ndim == 3:
43
  gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
44
  mask = gray_img > tol
45
+ check_shape = img[:, :, 0][np.ix_(mask.any(1), mask.any(0))].shape[0]
46
  if check_shape == 0:
47
  return img
48
  else:
49
+ img1 = img[:, :, 0][np.ix_(mask.any(1), mask.any(0))]
50
+ img2 = img[:, :, 1][np.ix_(mask.any(1), mask.any(0))]
51
+ img3 = img[:, :, 2][np.ix_(mask.any(1), mask.any(0))]
52
  img = np.stack([img1, img2, img3], axis=-1)
53
  return img
54
 
 
60
  image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
61
  image = crop_image_from_gray(image)
62
  image = cv2.resize(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
63
+ image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), sigmaX), -4, 128)
64
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
65
 
66
  return image
67
 
68
+ # --- Prediction ---
69
  def predict_dr(image):
70
  """Main prediction function"""
71
  try:
 
76
  binary_pred = (predictions > 0.5).astype(int)
77
  final_class = binary_pred.sum(axis=1)[0] - 1
78
 
79
+ confidences = {CLASSES[i]: float(predictions[0][i]) for i in range(len(CLASSES))}
 
 
 
 
80
  result_class = CLASSES[final_class]
81
  description = CLASS_DESCRIPTIONS[result_class]
82
 
 
89
  except Exception as e:
90
  return None, f"Error: {str(e)}", {}
91
 
92
+ # --- Gradio UI ---
93
  with gr.Blocks(title="Diabetic Retinopathy Detector") as demo:
94
  gr.Markdown("""
95
+ # 🏥 Diabetic Retinopathy Detection
96
  Upload a retinal fundus image to detect diabetic retinopathy severity.
97
 
98
  **Classes:** Normal → Mild → Moderate → Severe → Proliferative
 
115
  )
116
 
117
  gr.Markdown("""
118
+ ### ⚠️ Medical Disclaimer
119
  This tool is for educational purposes only. Always consult a qualified healthcare provider.
120
  """)
121
 
122
  if __name__ == "__main__":
123
+ demo.launch()