# NOTE: the next three lines are web-page residue captured with this file, not code.
# marahim20-2's picture
# Updated UI
# f52b2fb verified
import tensorflow as tf
from PIL import Image
import streamlit as st
import numpy as np
import io
import pandas as pd
from lime import lime_image
import time
# Side length, in pixels, of the square image that predict() resizes uploads to
# before calling the model. preprocess_image() hard-codes the same value (256).
IMG_SIZE = 256
@st.cache_resource
def load_model():
    """Load the trained Keras classifier from ``final_gluacoma2.h5``.

    Streamlit re-executes the script on every widget interaction, so an
    uncached loader would read and rebuild the model from disk on each rerun.
    ``st.cache_resource`` keeps the loaded model and returns that same object
    on subsequent calls.

    Returns:
        The object returned by ``tf.keras.models.load_model``. It is loaded
        with ``compile=False`` because this app only calls ``predict`` on it,
        so no optimizer, loss or metrics need to be attached.
    """
    return tf.keras.models.load_model("final_gluacoma2.h5", compile=False)
def preprocess_image(image):
    """Turn an uploaded image into a normalised single-image batch.

    Args:
        image: A PIL image (any object exposing ``convert`` and ``resize``
            whose result ``np.asarray`` can convert).

    Returns:
        A ``float32`` array with a leading batch axis of length 1. For a PIL
        image the shape is ``(1, 256, 256, 3)`` and pixel values are mapped
        from ``[0, 255]`` to ``[-1, 1]``.
    """
    # Force three channels: PNG uploads can be RGBA, palette or greyscale,
    # which would otherwise yield an array with a different channel layout.
    rgb_image = image.convert("RGB")
    # The target size is the same value as the module-level IMG_SIZE (256).
    resized = rgb_image.resize((256, 256))
    pixels = np.asarray(resized, dtype=np.float32)
    # x / 127.5 - 1 maps 0 -> -1 and 255 -> +1.
    normalized = pixels / 127.5 - 1.0
    return np.expand_dims(normalized, axis=0)
def predict(model, img):
    """Classify an eye image and report the winning class.

    Args:
        model: Object exposing ``predict(batch)`` that returns one row of
            per-class scores for each image in the batch.
        img: PIL image to classify.

    Returns:
        A ``(predicted_class, confidence)`` tuple. ``predicted_class`` is one
        of ``"normal"``, ``"cataract"``, ``"retina disease"`` or
        ``"glaucoma"``; ``confidence`` is the highest score expressed as a
        percentage, rounded to two decimals, as a plain ``float``.
    """
    # Force three channels first: RGBA, palette or greyscale uploads would
    # otherwise reach the model with a different channel count.
    rgb_img = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))

    # NOTE(review): pixels are passed to the model unscaled (0-255) here,
    # while preprocess_image() rescales to [-1, 1] for the LIME explanation.
    # Only one of the two can match the scale the model was trained on --
    # confirm which and align the two code paths.
    img_array = tf.keras.preprocessing.image.img_to_array(rgb_img)
    batch = tf.expand_dims(img_array, 0)  # add the batch axis
    scores = model.predict(batch)[0]

    # Index i of the score row corresponds to class_labels[i].
    class_labels = ["normal", "cataract", "retina disease", "glaucoma"]
    best_index = int(np.argmax(scores))
    # Convert to a built-in float so callers get a plain number, not a NumPy
    # scalar, when formatting or serialising the confidence.
    confidence = round(100 * float(scores[best_index]), 2)
    return class_labels[best_index], confidence
def explain_image(image, model):
    """Build a LIME explanation picture for the model's top-ranked class.

    Args:
        image: Batch holding one image scaled to ``[-1, 1]``, as returned by
            ``preprocess_image``; only ``image[0]`` is explained.
        model: Object whose ``predict`` method LIME calls on batches of
            perturbed copies of the image.

    Returns:
        A ``uint8`` array with values in ``[0, 255]``. Pixels inside the
        superpixels LIME selected as positive evidence keep their brightness;
        all other pixels are dimmed.
    """
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image[0],
        model.predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000,
    )

    top_label = explanation.top_labels[0]
    explained, mask = explanation.get_image_and_mask(
        top_label,
        positive_only=True,
        num_features=5,
        hide_rest=False,
    )

    # Undo the x / 127.5 - 1 scaling applied by preprocess_image(). The earlier
    # `temp * 255` mapped [-1, 1] onto [-255, 255], and the negative half
    # wrapped around when cast to uint8, corrupting the colours.
    pixels = (np.asarray(explained, dtype=np.float32) + 1.0) * 127.5
    pixels = np.clip(pixels, 0.0, 255.0)

    # The mask marks the selected superpixels (non-zero entries). It used to
    # be discarded, so nothing in the returned picture showed the explanation.
    # Dim everything outside the mask so the highlighted regions stand out.
    selected = np.asarray(mask).astype(bool)[..., np.newaxis]
    pixels = np.where(selected, pixels, pixels * 0.35)

    return np.rint(pixels).astype(np.uint8)
def main():
    """Render the page: upload an image, classify it and show an explanation.

    Flow:
        1. Show the title and a sidebar uploader for jpg/jpeg/png files.
        2. Once a file is uploaded, show a full-width "Predict" button.
        3. When the button is pressed, load the model, run ``predict`` and
           ``explain_image``, then show the uploaded image next to the
           explanation image and a one-row table with class and confidence.
    """
    st.title("DL based Glaucoma Image Classifier")

    st.sidebar.title("Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    try:
        image = Image.open(io.BytesIO(uploaded_file.read()))
        # Image.open is lazy; decode now so a truncated or corrupt upload is
        # reported here instead of failing later inside the model code.
        image.load()
    except OSError:
        st.sidebar.error("The uploaded file could not be read as an image.")
        return

    predict_button = st.sidebar.button("Predict", key="predict_button")
    # Stretch the sidebar button to the full sidebar width.
    button_css = (
        "<style>\n"
        'div[data-testid="stSidebar"][aria-expanded="true"] button {width: 100%;}\n'
        "</style>"
    )
    st.sidebar.write(button_css, unsafe_allow_html=True)
    if not predict_button:
        return

    with st.spinner(" Please wait... Processing the image and predicting..."):
        # Load the model only once a prediction is requested, so merely
        # uploading a file does not pay the model-loading cost.
        model = load_model()
        processed_image = preprocess_image(image)
        predicted_class, confidence_score = predict(model, image)
        explanation_image = explain_image(processed_image, model)

    # Uploaded image and explanation side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    with col2:
        st.image(explanation_image, caption="Explanation Image", use_column_width=True)

    st.subheader("Prediction")
    prediction_table = pd.DataFrame(
        {
            "Predicted Class": [predicted_class],
            "Confidence": [f"{confidence_score}%"],
        }
    )
    st.table(prediction_table)
# Build the UI only when this file is executed as the main script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()