"""Gradio app that predicts ocular diseases from a left/right pair of eye images.

Each uploaded image is converted to a 3-channel BGR array, resized and
normalized, then scored by the Keras model stored in ``odir_cnn_model.h5``.
The left- and right-eye score vectors are averaged into one score per label
and shown in a ``gr.Label`` output.
"""
import gradio as gr
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import InputLayer

# Custom object scope to handle InputLayer configuration when deserializing
# the saved model.
custom_objects = {'InputLayer': InputLayer}


def load_ocular_model():
    """Load and return the Keras model saved in ``odir_cnn_model.h5``."""
    return load_model('odir_cnn_model.h5', custom_objects=custom_objects)


# Loaded once at import time so every request reuses the same model.
model = load_ocular_model()

# Display labels, one per model output unit.
# NOTE(review): order is assumed to match the model's training label order — confirm.
LABELS = ['Normal (N)', 'Diabetes (D)', 'Glaucoma (G)', 'Cataract (C)',
          'Age related Macular Degeneration (A)', 'Hypertension (H)',
          'Pathological Myopia (M)', 'Other diseases/abnormalities (O)']


def preprocess_image(image, img_size=128):
    """Resize and normalize one image into a single-item model batch.

    Args:
        image: Image array; scaling by 255 assumes 8-bit pixel values.
        img_size: Side length of the square resize target. 128 presumably
            matches the model's training input size — confirm before changing.

    Returns:
        float32 array resized to ``img_size`` x ``img_size``, scaled to [0, 1],
        with a leading batch axis of length 1.
    """
    img = cv2.resize(image, (img_size, img_size))
    # float32 is Keras' default float type; avoids an extra float64 copy.
    img = img.astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)


def predict_diseases(left_image, right_image):
    """Score a left/right eye pair and average the two score vectors.

    Both images are run through the model in one ``predict`` call (a batch of
    two) instead of two separate calls.

    Args:
        left_image: Left-eye image array (BGR, as produced by ``_to_bgr``).
        right_image: Right-eye image array (BGR, as produced by ``_to_bgr``).

    Returns:
        dict mapping each entry of ``LABELS`` to its averaged score (float).

    Raises:
        ValueError: if the model's output width differs from ``len(LABELS)``;
            ``zip`` would otherwise silently truncate and mislabel the scores.
    """
    batch = np.concatenate(
        [preprocess_image(left_image), preprocess_image(right_image)], axis=0
    )
    left_scores, right_scores = model.predict(batch, verbose=0)
    combined_scores = (left_scores + right_scores) / 2
    if len(combined_scores) != len(LABELS):
        raise ValueError(
            f"Model produced {len(combined_scores)} scores but "
            f"{len(LABELS)} labels are defined."
        )
    return {label: float(score) for label, score in zip(LABELS, combined_scores)}


def _to_bgr(image, side):
    """Convert a Gradio-supplied image (PIL or array) to a 3-channel BGR array.

    Raises:
        gr.Error: if no image was supplied for ``side`` (shown in the UI).
    """
    if image is None:
        # Gradio passes None for an empty image input.
        raise gr.Error(f"Please upload the {side} eye image.")
    if hasattr(image, "convert"):
        # PIL uploads may be RGBA, palette or grayscale; normalize to RGB so
        # the channel conversion below always receives 3 channels.
        image = image.convert("RGB")
    rgb = np.asarray(image)
    if rgb.ndim == 2 or rgb.shape[-1] == 1:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
    elif rgb.shape[-1] == 4:
        rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)
    # NOTE(review): BGR presumably matches how training images were loaded
    # (e.g. cv2.imread) — confirm before removing this conversion.
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)


def predict_from_images(left_image, right_image):
    """Gradio callback: convert both uploads to BGR arrays and predict.

    Args:
        left_image: Left-eye image from Gradio (PIL image, or None if empty).
        right_image: Right-eye image from Gradio (PIL image, or None if empty).

    Returns:
        dict of label -> score, as expected by ``gr.Label``.
    """
    return predict_diseases(_to_bgr(left_image, "left"),
                            _to_bgr(right_image, "right"))


# Gradio interface
input_images = [
    gr.Image(type="pil", label="Left Eye Image"),
    gr.Image(type="pil", label="Right Eye Image"),
]
output_labels = gr.Label(num_top_classes=len(LABELS),
                         label="Ocular Disease Predictions")
gr_interface = gr.Interface(
    fn=predict_from_images,
    inputs=input_images,
    outputs=output_labels,
    title="Ocular Disease Prediction",
    description="Upload left and right eye images to predict ocular diseases.",
    live=False,
)

# Launch the Gradio app (at import time, as in the original script).
gr_interface.launch()