Sefat33 commited on
Commit
02efb1f
Β·
verified Β·
1 Parent(s): 2d33367

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -31
app.py CHANGED
@@ -1,41 +1,40 @@
1
  import os
2
- os.environ["TF_USE_LEGACY_KERAS"] = "1"
3
- os.environ["KERAS_BACKEND"] = "tensorflow"
4
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
5
-
6
- import gdown
7
- import cv2
8
  import numpy as np
 
9
  import tensorflow as tf
10
  import streamlit as st
 
11
  from PIL import Image
12
- from keras.models import load_model
13
  from keras_cv_attention_models.coatnet import CoAtNet0
14
 
15
- # ------------------ Settings ------------------
 
 
 
 
 
 
 
16
  IMG_SIZE = (224, 224)
17
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
18
 
19
- # ------------------ Load Model ------------------
20
- @st.cache_resource
21
- def load_coatnet_model():
22
  model_path = "model.keras"
23
-
24
  if not os.path.exists(model_path):
25
  st.info("πŸ“₯ Downloading model from Google Drive...")
26
  url = "https://drive.google.com/uc?id=1Gm2O77uWSUnajL0iFlFJtVk_UEN_wrTN"
27
  gdown.download(url, model_path, quiet=False, fuzzy=True)
28
-
29
- if os.path.getsize(model_path) < 1_000_000:
30
- raise ValueError("❌ Model file is too small. Download may have failed.")
31
-
32
  try:
33
- model = load_model(
34
  model_path,
35
  compile=False,
36
  custom_objects={
 
37
  "CoAtNet0": CoAtNet0,
38
- "Functional": tf.keras.models.Model,
39
  "gelu": tf.keras.activations.gelu
40
  }
41
  )
@@ -44,9 +43,6 @@ def load_coatnet_model():
44
  st.error(f"❌ Failed to load model: {e}")
45
  raise
46
 
47
- model = load_coatnet_model()
48
-
49
- # ------------------ Image Preprocessing ------------------
50
  def crop_circle(img):
51
  h, w = img.shape[:2]
52
  center = (w // 2, h // 2)
@@ -55,20 +51,20 @@ def crop_circle(img):
55
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
56
  mask = dist <= radius
57
  if img.ndim == 3:
58
- mask = np.stack([mask] * 3, axis=-1)
59
  img[~mask] = 0
60
  return img
61
 
62
  def apply_clahe(img):
63
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
64
  l, a, b = cv2.split(lab)
65
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
66
  cl = clahe.apply(l)
67
  merged = cv2.merge((cl, a, b))
68
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
69
 
70
  def sharpen_image(img, sigma=10):
71
- blur = cv2.GaussianBlur(img, (0, 0), sigma)
72
  return cv2.addWeighted(img, 4, blur, -4, 128)
73
 
74
  def resize_normalize(img):
@@ -83,19 +79,19 @@ def preprocess_image(img):
83
  img = resize_normalize(img)
84
  return img
85
 
86
- # ------------------ Streamlit UI ------------------
87
  st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
88
-
89
  st.title("🧠 Retina Disease Classifier")
90
- st.markdown("Upload a retinal image to detect possible diseases using a CoAtNet model.")
 
 
91
 
92
- uploaded_file = st.file_uploader("πŸ“€ Upload Retinal Image", type=["jpg", "jpeg", "png"])
93
 
94
  if uploaded_file is not None:
95
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
96
  bgr_img = cv2.imdecode(file_bytes, 1)
97
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
98
-
99
  st.image(rgb_img, caption="Original Image", use_column_width=True)
100
 
101
  preprocessed = preprocess_image(rgb_img)
@@ -109,5 +105,5 @@ if uploaded_file is not None:
109
  st.success(f"βœ… **Prediction:** `{pred_label}`")
110
  st.info(f"πŸ” Confidence: **{confidence:.2f}%**")
111
 
112
- st.subheader("πŸ§ͺ Preprocessed Input")
113
- st.image((preprocessed * 255).astype(np.uint8), caption="Model Input", use_column_width=True)
 
1
  import os
 
 
 
 
 
 
2
  import numpy as np
3
+ import cv2
4
  import tensorflow as tf
5
  import streamlit as st
6
+ import gdown
7
  from PIL import Image
8
+ from keras.layers import BatchNormalization as KBatchNormalization
9
  from keras_cv_attention_models.coatnet import CoAtNet0
10
 
11
+ # Patch BatchNormalization to fix axis deserialization issue
12
+ class PatchedBatchNormalization(KBatchNormalization):
13
+ @classmethod
14
+ def from_config(cls, config):
15
+ if isinstance(config.get("axis"), list):
16
+ config["axis"] = config["axis"][0]
17
+ return super().from_config(config)
18
+
19
  IMG_SIZE = (224, 224)
20
  CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
21
 
22
+ @st.cache_resource(show_spinner=True)
23
+ def load_model():
 
24
  model_path = "model.keras"
 
25
  if not os.path.exists(model_path):
26
  st.info("πŸ“₯ Downloading model from Google Drive...")
27
  url = "https://drive.google.com/uc?id=1Gm2O77uWSUnajL0iFlFJtVk_UEN_wrTN"
28
  gdown.download(url, model_path, quiet=False, fuzzy=True)
29
+ if os.path.getsize(model_path) < 1_000_000:
30
+ raise ValueError("❌ Downloaded model is too small. Download might have failed!")
 
 
31
  try:
32
+ model = tf.keras.models.load_model(
33
  model_path,
34
  compile=False,
35
  custom_objects={
36
+ "BatchNormalization": PatchedBatchNormalization,
37
  "CoAtNet0": CoAtNet0,
 
38
  "gelu": tf.keras.activations.gelu
39
  }
40
  )
 
43
  st.error(f"❌ Failed to load model: {e}")
44
  raise
45
 
 
 
 
46
  def crop_circle(img):
47
  h, w = img.shape[:2]
48
  center = (w // 2, h // 2)
 
51
  dist = np.sqrt((X - center[0])**2 + (Y - center[1])**2)
52
  mask = dist <= radius
53
  if img.ndim == 3:
54
+ mask = np.stack([mask]*3, axis=-1)
55
  img[~mask] = 0
56
  return img
57
 
58
  def apply_clahe(img):
59
  lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
60
  l, a, b = cv2.split(lab)
61
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
62
  cl = clahe.apply(l)
63
  merged = cv2.merge((cl, a, b))
64
  return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
65
 
66
  def sharpen_image(img, sigma=10):
67
+ blur = cv2.GaussianBlur(img, (0,0), sigma)
68
  return cv2.addWeighted(img, 4, blur, -4, 128)
69
 
70
  def resize_normalize(img):
 
79
  img = resize_normalize(img)
80
  return img
81
 
82
+ # Streamlit UI
83
  st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
 
84
  st.title("🧠 Retina Disease Classifier")
85
+ st.markdown("Upload a retinal image and get the predicted disease class using the CoAtNet model.")
86
+
87
+ model = load_model()
88
 
89
+ uploaded_file = st.file_uploader("πŸ“€ Upload Image", type=["jpg", "jpeg", "png"])
90
 
91
  if uploaded_file is not None:
92
  file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
93
  bgr_img = cv2.imdecode(file_bytes, 1)
94
  rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
 
95
  st.image(rgb_img, caption="Original Image", use_column_width=True)
96
 
97
  preprocessed = preprocess_image(rgb_img)
 
105
  st.success(f"βœ… **Prediction:** `{pred_label}`")
106
  st.info(f"πŸ” Confidence: **{confidence:.2f}%**")
107
 
108
+ st.subheader("πŸ§ͺ Preprocessed Input to Model")
109
+ st.image((preprocessed * 255).astype(np.uint8), caption="Preprocessed Image", use_column_width=True)