|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from tensorflow.keras.models import load_model |
|
|
from tensorflow.keras.layers import InputLayer |
|
|
|
|
|
|
|
|
# Extra class names handed to load_model when deserializing the saved model:
# the serialized name 'InputLayer' resolves to Keras's InputLayer class.
custom_objects = dict(InputLayer=InputLayer)
|
|
|
|
|
|
|
|
def load_ocular_model(model_path='odir_cnn_model.h5'):
    """Load the trained Keras classifier from a saved model file.

    Args:
        model_path: Path to the saved model. Defaults to
            'odir_cnn_model.h5'; a relative path is resolved against the
            current working directory.

    Returns:
        The deserialized Keras model, loaded with the module-level
        ``custom_objects`` mapping.
    """
    return load_model(model_path, custom_objects=custom_objects)


# Loaded once at import time; predict_diseases reuses this global instance
# for every request instead of reloading the file.
model = load_ocular_model()
|
|
|
|
|
|
|
|
# Human-readable class names, listed in the same order as the model's output
# scores (predict_diseases pairs them with the scores positionally).
LABELS = [
    'Normal (N)',
    'Diabetes (D)',
    'Glaucoma (G)',
    'Cataract (C)',
    'Age related Macular Degeneration (A)',
    'Hypertension (H)',
    'Pathological Myopia (M)',
    'Other diseases/abnormalities (O)',
]
|
|
|
|
|
|
|
|
def preprocess_image(image, img_size=128):
    """Resize an image array, scale it to [0, 1], and add a batch axis.

    Args:
        image: Image array accepted by ``cv2.resize``.
        img_size: Target width and height in pixels (square output).

    Returns:
        A float array with a leading batch dimension of size 1, i.e. the
        resized image divided by 255.
    """
    resized = cv2.resize(image, (img_size, img_size))
    scaled = resized / 255.0
    # Prepend an axis so the single image forms a batch of one.
    return scaled[np.newaxis, ...]
|
|
|
|
|
|
|
|
def predict_diseases(left_image, right_image):
    """Score both eye images and average the model's per-class outputs.

    Args:
        left_image: Left-eye image array, as accepted by ``preprocess_image``.
        right_image: Right-eye image array, as accepted by ``preprocess_image``.

    Returns:
        dict mapping each entry of ``LABELS`` to the mean of the left-eye and
        right-eye model scores for that class, as a Python float.

    Raises:
        ValueError: if the model's output width differs from ``len(LABELS)``
            (zip would otherwise silently drop or misalign labels).
    """
    left_batch = preprocess_image(left_image)
    right_batch = preprocess_image(right_image)

    # Each preprocessed image already has a leading batch axis of size 1, so
    # concatenating along axis 0 yields a batch of two. One predict() call
    # avoids paying its per-call setup cost twice; verbose=0 keeps the
    # progress bar out of the server logs on every request.
    stacked = np.concatenate([left_batch, right_batch], axis=0)
    scores = model.predict(stacked, verbose=0)

    # Row 0 is the left eye, row 1 the right eye.
    combined = (scores[0] + scores[1]) / 2

    if len(combined) != len(LABELS):
        raise ValueError(
            f"Model produced {len(combined)} class scores but "
            f"{len(LABELS)} labels are defined."
        )

    return {label: float(score) for label, score in zip(LABELS, combined)}
|
|
|
|
|
|
|
|
def predict_from_images(left_image, right_image):
    """Gradio callback: convert the two uploaded images and run prediction.

    Args:
        left_image: Left-eye upload (a PIL image, or None when the field
            was left empty).
        right_image: Right-eye upload (a PIL image, or None when the field
            was left empty).

    Returns:
        The label -> score dict produced by ``predict_diseases``.

    Raises:
        gr.Error: if either image is missing, so the UI shows a readable
            message instead of an OpenCV conversion failure.
    """
    if left_image is None or right_image is None:
        raise gr.Error("Please upload both a left eye and a right eye image.")

    # The uploads are treated as RGB and reordered to BGR before
    # preprocessing; presumably this matches how the training images were
    # loaded (e.g. with cv2.imread) — confirm against the training pipeline.
    left_bgr = cv2.cvtColor(np.array(left_image), cv2.COLOR_RGB2BGR)
    right_bgr = cv2.cvtColor(np.array(right_image), cv2.COLOR_RGB2BGR)

    return predict_diseases(left_bgr, right_bgr)
|
|
|
|
|
|
|
|
# Two PIL image uploads, left eye first, passed positionally to
# predict_from_images.
input_images = [
    gr.Image(type="pil", label=eye_label)
    for eye_label in ("Left Eye Image", "Right Eye Image")
]

# Display a score for every class rather than only the top few.
output_labels = gr.Label(
    num_top_classes=len(LABELS),
    label="Ocular Disease Predictions",
)

gr_interface = gr.Interface(
    fn=predict_from_images,
    inputs=input_images,
    outputs=output_labels,
    title="Ocular Disease Prediction",
    description="Upload left and right eye images to predict ocular diseases.",
    live=False,
)

# Start the web UI as soon as this module runs.
gr_interface.launch()
|
|
|