# GMED / app.py
# Author: Vinit710
# Last commit: d5400ea ("Update app.py", verified)
import gradio as gr
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import InputLayer
# Extra classes passed to `load_model` so the saved model's config can be
# deserialised. Maps the layer name 'InputLayer' to Keras' own InputLayer.
# NOTE(review): presumably a workaround for an InputLayer config mismatch
# between the Keras version that saved the model and the one loading it — confirm.
custom_objects = dict(InputLayer=InputLayer)
# Load model
def load_ocular_model():
return load_model('odir_cnn_model.h5', custom_objects=custom_objects)
model = load_ocular_model()
# Human-readable class names. Position i names the i-th score in the model's
# output vector (predict_diseases zips this list with that vector), so the
# order here must match the order the model was trained with.
LABELS = [
    'Normal (N)',
    'Diabetes (D)',
    'Glaucoma (G)',
    'Cataract (C)',
    'Age related Macular Degeneration (A)',
    'Hypertension (H)',
    'Pathological Myopia (M)',
    'Other diseases/abnormalities (O)',
]
# Resize and rescale one image into a single-item model input batch.
def preprocess_image(image, img_size=128):
    """Turn one image array into a model-ready batch of size one.

    Args:
        image: Image as a NumPy array accepted by ``cv2.resize``.
        img_size: Side length, in pixels, of the square the image is
            resized to. Defaults to 128.

    Returns:
        ``np.ndarray`` of shape ``(1, img_size, img_size)`` followed by the
        input's channel axis (if any), with every pixel value divided by 255.
    """
    resized = cv2.resize(image, (img_size, img_size))
    # Presumably uint8 input, so this maps 0..255 to floats in 0..1.
    scaled = resized / 255.0
    # Prepend the batch axis expected by model.predict.
    return scaled[np.newaxis, ...]
# Predict diseases from a left/right pair of eye images.
def predict_diseases(left_image, right_image):
    """Return per-label scores for a patient, averaged over both eyes.

    Both images are preprocessed and scored together in ONE ``model.predict``
    call (a batch of two) rather than two separate calls, avoiding the fixed
    per-call overhead of ``predict`` a second time. The two score vectors are
    then averaged element-wise.

    Args:
        left_image: Left-eye image array, in the form ``preprocess_image``
            accepts.
        right_image: Right-eye image array, same form as ``left_image``.

    Returns:
        dict mapping each entry of ``LABELS`` to a Python ``float`` score.
        NOTE(review): the scores are whatever the model's final layer emits
        (presumably sigmoid/softmax probabilities) — confirm.
    """
    batch = np.concatenate(
        [preprocess_image(left_image), preprocess_image(right_image)], axis=0
    )
    # verbose=0: don't print a Keras progress bar on every web request.
    left_scores, right_scores = model.predict(batch, verbose=0)
    combined = (left_scores + right_scores) / 2
    return {label: float(score) for label, score in zip(LABELS, combined)}
# Gradio callback: validate and convert the uploaded images, then predict.
def predict_from_images(left_image, right_image):
    """Gradio entry point: check both uploads, convert them, and predict.

    Args:
        left_image: Left-eye upload from the Gradio ``Image`` component
            (``type="pil"``), or ``None`` when nothing was uploaded.
        right_image: Right-eye upload, same form as ``left_image``.

    Returns:
        The label -> score dict produced by ``predict_diseases``.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an OpenCV traceback from ``cvtColor``.
    """
    if left_image is None or right_image is None:
        raise gr.Error("Please upload both a left eye image and a right eye image.")
    # Convert Gradio's RGB images to OpenCV's BGR channel order.
    # NOTE(review): presumably the model was trained on BGR arrays (e.g. read
    # with cv2.imread) — confirm before removing this conversion.
    left_bgr = cv2.cvtColor(np.array(left_image), cv2.COLOR_RGB2BGR)
    right_bgr = cv2.cvtColor(np.array(right_image), cv2.COLOR_RGB2BGR)
    return predict_diseases(left_bgr, right_bgr)
# Gradio UI: two image inputs (left and right eye) feeding predict_from_images,
# and a label output that lists a score for every class in LABELS.
input_images = [
    gr.Image(type="pil", label="Left Eye Image"),
    gr.Image(type="pil", label="Right Eye Image"),
]
output_labels = gr.Label(num_top_classes=len(LABELS), label="Ocular Disease Predictions")
gr_interface = gr.Interface(
    fn=predict_from_images,
    inputs=input_images,
    outputs=output_labels,
    title="Ocular Disease Prediction",
    description="Upload left and right eye images to predict ocular diseases.",
    live=False,  # run only when the user submits, not on every input change
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module for tests or reuse does not start a web server.
if __name__ == "__main__":
    gr_interface.launch()