"""Streamlit app: classify an uploaded eye image and show a LIME explanation."""

import io
import time

import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from lime import lime_image
from PIL import Image

# Side length (pixels) of the square image handed to the model.
IMG_SIZE = 256

# Labels looked up by the index of the model's highest output.
CLASS_LABELS = ["normal", "cataract", "retina disease", "glaucoma"]

# Brightness multiplier applied to pixels outside the LIME explanation mask.
_BACKGROUND_DIM = 0.35


def _cache_across_reruns(func):
    """Wrap *func* with ``st.cache_resource`` when Streamlit provides it.

    Streamlit re-executes the script on every interaction, so an uncached
    loader would run again each time. On Streamlit versions without
    ``cache_resource`` the function is returned unchanged, so the app still
    works (just without caching).
    """
    decorator = getattr(st, "cache_resource", None)
    return func if decorator is None else decorator(func)


@_cache_across_reruns
def load_model():
    """Load the Keras classifier stored in ``final_gluacoma2.h5``.

    The model is loaded with ``compile=False``: it is only used for
    inference here, so no optimizer or loss is attached.

    Returns:
        The loaded ``tf.keras`` model.
    """
    return tf.keras.models.load_model("final_gluacoma2.h5", compile=False)


def preprocess_image(image):
    """Turn a PIL image into a normalised single-image batch.

    Args:
        image: PIL image in any mode; it is converted to RGB so that
            greyscale or RGBA uploads still yield three channels.

    Returns:
        ``float32`` array of shape ``(1, IMG_SIZE, IMG_SIZE, 3)`` with pixel
        values scaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    rgb = image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    pixels = np.asarray(rgb, dtype=np.float32)
    return np.expand_dims(pixels / 127.5 - 1.0, axis=0)


def predict(model, img):
    """Classify a PIL image with *model*.

    Args:
        model: Object exposing ``predict(batch)`` that returns one row of
            per-class scores per input image.
        img: PIL image in any mode; it is converted to RGB and resized to
            ``IMG_SIZE`` x ``IMG_SIZE``.

    Returns:
        ``(predicted_class, confidence)`` where ``predicted_class`` is the
        entry of ``CLASS_LABELS`` with the highest score and ``confidence``
        is that score times 100, rounded to two decimals.
    """
    # NOTE(review): this path feeds raw 0-255 pixel values to the model,
    # whereas preprocess_image() (used for the LIME explanation) feeds
    # [-1, 1] values. Confirm which scaling the model was trained with and
    # unify the two paths.
    rgb = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    batch = tf.expand_dims(tf.keras.preprocessing.image.img_to_array(rgb), 0)
    scores = model.predict(batch)[0]
    best = int(np.argmax(scores))
    confidence = round(100.0 * float(scores[best]), 2)
    return CLASS_LABELS[best], confidence


def explain_image(image, model):
    """Build a LIME explanation image for the first image in a batch.

    Args:
        image: Batch as produced by ``preprocess_image`` (values in
            ``[-1, 1]``); only ``image[0]`` is explained.
        model: Object whose ``predict`` is used as LIME's classifier
            function.

    Returns:
        ``uint8`` array with the same height, width and channels as
        ``image[0]``: the input image with every pixel outside the regions
        that support the top predicted label dimmed.
    """
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image[0],
        model.predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000,
    )
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=True,
        num_features=5,
        hide_rest=False,
    )

    # Undo the [-1, 1] scaling. Multiplying by 255 directly would produce
    # negative values that wrap around when cast to uint8.
    pixels = np.clip((np.asarray(temp, dtype=np.float32) + 1.0) * 127.5, 0.0, 255.0)

    # With positive_only=True and hide_rest=False LIME hands back the input
    # image untouched; the explanation itself is in `mask` (1 = supporting
    # region). Dim everything else so the regions are actually visible.
    pixels[mask == 0] *= _BACKGROUND_DIM
    return pixels.astype(np.uint8)


def main():
    """Run the app: upload an image, classify it, and display the results."""
    st.title("DL based Glaucoma Image Classifier")

    st.sidebar.title("Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    try:
        # getvalue() returns the full upload regardless of the current
        # stream position, unlike read().
        image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        # Image.open is lazy; force decoding so a corrupt file fails here.
        image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    model = load_model()

    predict_button = st.sidebar.button("Predict", key="predict_button")
    # Writes an empty string; presumably a placeholder for custom markup.
    st.sidebar.write("""""", unsafe_allow_html=True)
    if not predict_button:
        return

    with st.spinner(" Please wait... Processing the image and predicting..."):
        processed_image = preprocess_image(image)
        predicted_class, confidence_score = predict(model, image)
        explanation_image = explain_image(processed_image, model)

    # Original upload and explanation side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    with col2:
        st.image(
            explanation_image,
            caption="Explanation Image",
            use_column_width=True,
        )

    st.subheader("Prediction")
    prediction_table = pd.DataFrame(
        {
            "Predicted Class": [predicted_class],
            "Confidence": [f"{confidence_score}%"],
        }
    )
    st.table(prediction_table)


if __name__ == "__main__":
    main()