Vinit710 commited on
Commit
18f663d
·
verified ·
1 Parent(s): b447859

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py CHANGED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from tensorflow.keras.models import load_model
5
+
6
+ # Custom object scope to handle InputLayer configuration
7
+ from tensorflow.keras.layers import InputLayer
8
+ custom_objects = {'InputLayer': InputLayer}
9
+
10
+ # Load the model
11
+ def load_ocular_model():
12
+ return load_model('odir_cnn_model.h5', custom_objects=custom_objects)
13
+
14
+ model = load_ocular_model()
15
+
16
+ # Define the labels
17
+ LABELS = ['Normal (N)', 'Diabetes (D)', 'Glaucoma (G)', 'Cataract (C)',
18
+ 'Age related Macular Degeneration (A)', 'Hypertension (H)',
19
+ 'Pathological Myopia (M)', 'Other diseases/abnormalities (O)']
20
+
21
+ # Preprocess the image
22
+ def preprocess_image(image, img_size=128):
23
+ img = cv2.resize(image, (img_size, img_size))
24
+ img = img / 255.0 # Normalize
25
+ img = np.expand_dims(img, axis=0) # Add batch dimension
26
+ return img
27
+
28
+ # Prediction function
29
+ def predict_diseases(left_image, right_image):
30
+ left_img = preprocess_image(left_image)
31
+ right_img = preprocess_image(right_image)
32
+
33
+ left_predictions = model.predict(left_img)
34
+ right_predictions = model.predict(right_img)
35
+
36
+ combined_predictions = (left_predictions + right_predictions) / 2
37
+
38
+ pred_labels = {label: float(pred) for label, pred in zip(LABELS, combined_predictions[0])}
39
+
40
+ return pred_labels
41
+
42
+ # Gradio interface
43
+ iface = gr.Interface(
44
+ fn=predict_diseases,
45
+ inputs=[gr.inputs.Image(type="numpy"), gr.inputs.Image(type="numpy")],
46
+ outputs="json",
47
+ title="Ocular Disease Prediction",
48
+ description="Upload left and right eye images to predict ocular diseases."
49
+ )
50
+
51
+ # Launch the interface
52
+ iface.launch()