Sefat33 commited on
Commit
4573cc5
·
verified ·
1 Parent(s): 06627e4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +111 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##subprocess.run(["pip", "install", "--force-reinstall", "protobuf==3.19.6"])
2
+ import os
3
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
4
+
5
+ import os
6
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
7
+
8
+ import os
9
+ import gdown
10
+ import cv2
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ import streamlit as st
14
+ from PIL import Image
15
+
16
+ IMG_SIZE = (224, 224)
17
+ NUM_CLASSES = 8
18
+
19
+ from keras_cv_attention_models.coatnet import CoAtNet0
20
+
21
+ @st.cache_resource
22
+ def load_model():
23
+ import keras
24
+ model_path = "model.keras"
25
+ saved_model_path = "model_saved"
26
+ if not os.path.exists(model_path):
27
+ st.info("📥 Downloading model from Google Drive...")
28
+ url = "https://drive.google.com/uc?id=1Gm2O77uWSUnajL0iFlFJtVk_UEN_wrTN"
29
+ gdown.download(url, model_path, quiet=False, fuzzy=True)
30
+ if os.path.getsize(model_path) < 1_000_000:
31
+ raise ValueError("❌ Downloaded model is too small. Likely failed or incomplete download!")
32
+ if not os.path.exists(saved_model_path):
33
+ try:
34
+ model = keras.models.load_model(
35
+ model_path,
36
+ compile=False,
37
+ custom_objects={
38
+ "CoAtNet0": CoAtNet0,
39
+ "Functional": tf.keras.models.Model,
40
+ "gelu": tf.keras.activations.gelu
41
+ }
42
+ )
43
+ model.save(saved_model_path, save_format="tf")
44
+ st.success("✅ Converted `.keras` model to SavedModel format.")
45
+ except Exception as e:
46
+ st.error(f"❌ Failed to convert .keras model: {e}")
47
+ raise
48
+ return tf.keras.models.load_model(saved_model_path, compile=False)
49
+
50
+ model = load_model()
51
+
52
+ def crop_circle(img):
53
+ h, w = img.shape[:2]
54
+ center = (w // 2, h // 2)
55
+ radius = min(center[0], center[1])
56
+ Y, X = np.ogrid[:h, :w]
57
+ dist = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
58
+ mask = dist <= radius
59
+ if img.ndim == 3:
60
+ mask = np.stack([mask] * 3, axis=-1)
61
+ img[~mask] = 0
62
+ return img
63
+
64
+ def apply_clahe(img):
65
+ lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
66
+ l, a, b = cv2.split(lab)
67
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
68
+ cl = clahe.apply(l)
69
+ merged = cv2.merge((cl, a, b))
70
+ return cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
71
+
72
+ def sharpen_image(img, sigma=10):
73
+ blur = cv2.GaussianBlur(img, (0, 0), sigma)
74
+ return cv2.addWeighted(img, 4, blur, -4, 128)
75
+
76
+ def resize_normalize(img):
77
+ img = cv2.resize(img, IMG_SIZE)
78
+ img = img / 255.0
79
+ return img
80
+
81
+ def preprocess_image(img):
82
+ img = crop_circle(img)
83
+ img = apply_clahe(img)
84
+ img = sharpen_image(img)
85
+ img = resize_normalize(img)
86
+ return img
87
+
88
+ CLASS_NAMES = ['Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', 'Myopia', 'Others']
89
+
90
+ st.set_page_config(page_title="🧠 Retina Disease Classifier", layout="centered")
91
+
92
+ st.title("🧠 Retina Disease Classifier")
93
+ st.markdown("Upload a retinal image and get predicted disease class using CoAtNet model.")
94
+
95
+ uploaded_file = st.file_uploader("📤 Upload Image", type=["jpg", "jpeg", "png"])
96
+
97
+ if uploaded_file is not None:
98
+ file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
99
+ bgr_img = cv2.imdecode(file_bytes, 1)
100
+ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
101
+ st.image(rgb_img, caption="Original Image", use_column_width=True)
102
+ preprocessed = preprocess_image(rgb_img)
103
+ input_tensor = np.expand_dims(preprocessed, axis=0)
104
+ preds = model.predict(input_tensor)
105
+ pred_idx = np.argmax(preds)
106
+ pred_label = CLASS_NAMES[pred_idx]
107
+ confidence = np.max(preds) * 100
108
+ st.success(f"✅ **Prediction:** `{pred_label}`")
109
+ st.info(f"🔍 Confidence: **{confidence:.2f}%**")
110
+ st.subheader("🧪 Preprocessed Input")
111
+ st.image((preprocessed * 255).astype(np.uint8), caption="Model Input", use_column_width=True)
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow-cpu>=2.13
2
+ keras>=3.0.0
3
+ protobuf>=3.20,<5
4
+ streamlit==1.46.1
5
+ gdown
6
+ opencv-python-headless
7
+ Pillow
8
+ scikit-image
9
+ lime
10
+ matplotlib
11
+ numpy==1.23.5
12
+ keras-cv-attention-models