Mishab commited on
Commit
cf85da5
·
verified ·
1 Parent(s): 024e71a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -38
app.py CHANGED
@@ -5,21 +5,18 @@ import numpy as np
5
  import io
6
  import pandas as pd
7
  from lime import lime_image
 
 
 
 
8
 
9
  # Load the model
10
  def load_model():
11
  model = tf.keras.models.load_model("custom_model_final.h5", compile=False)
12
- # Explicitly compile the model with the required loss function
13
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
14
- model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
15
  return model
16
 
17
- # Load the labels
18
- def load_labels():
19
- with open("labels.txt", "r") as file:
20
- class_names = [line.strip() for line in file.readlines()]
21
- return class_names
22
-
23
  # Preprocess image
24
  def preprocess_image(image):
25
  image = image.resize((256, 256))
@@ -28,12 +25,18 @@ def preprocess_image(image):
28
  data = np.expand_dims(normalized_image_array, axis=0)
29
  return data
30
 
31
- # Classify image
32
- def classify_image(model, image, class_names):
33
- prediction = model.predict(image)
34
- predicted_class = class_names[np.argmax(prediction[0])]
35
- confidence_score = round(100 * np.max(prediction[0]), 2)
36
- return predicted_class, confidence_score
 
 
 
 
 
 
37
 
38
  # Explain image
39
  def explain_image(image, model):
@@ -56,15 +59,14 @@ def main():
56
  st.sidebar.title("Upload Image")
57
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
58
 
59
-
60
  if uploaded_file is not None:
61
- # Load model and labels
62
  model = load_model()
63
- class_names = load_labels()
64
 
65
  # Display uploaded image
66
  image = Image.open(io.BytesIO(uploaded_file.read()))
67
  st.image(image, caption="Uploaded Image", use_column_width=True)
 
68
 
69
  # Predict button
70
  predict_button = st.sidebar.button("Predict", key="predict_button")
@@ -74,26 +76,26 @@ def main():
74
  </style>""", unsafe_allow_html=True
75
  )
76
  if predict_button:
77
- # Preprocess image
78
- processed_image = preprocess_image(image)
79
-
80
- # Classify image
81
- predicted_class, confidence_score = classify_image(model, processed_image, class_names)
82
-
83
- # Explain image classification
84
- explanation_image = explain_image(processed_image, model)
85
-
86
- # Display explanation image
87
- st.image(explanation_image, caption="Explanation Image", use_column_width=True)
88
 
89
- # Display prediction
90
- st.subheader("Prediction:")
91
- # Create a table for prediction results
92
- prediction_table = pd.DataFrame({
93
- "Predicted Class": [predicted_class],
94
- "Confidence": [f"{confidence_score}%"]
95
- })
96
- st.table(prediction_table)
 
 
 
 
 
 
97
 
98
  if __name__ == "__main__":
99
- main()
 
5
  import io
6
  import pandas as pd
7
  from lime import lime_image
8
+ import time
9
+
10
+ # Define your image size
11
+ IMG_SIZE = 256
12
 
13
  # Load the model
14
  def load_model():
15
  model = tf.keras.models.load_model("custom_model_final.h5", compile=False)
16
+ # Compile the model if necessary
17
+ # model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
 
18
  return model
19
 
 
 
 
 
 
 
20
  # Preprocess image
21
  def preprocess_image(image):
22
  image = image.resize((256, 256))
 
25
  data = np.expand_dims(normalized_image_array, axis=0)
26
  return data
27
 
28
+ # Define the predict function
29
+ def predict(model, img):
30
+ img = img.resize((IMG_SIZE, IMG_SIZE)) # Resize the image
31
+ img_array = tf.keras.preprocessing.image.img_to_array(img)
32
+ img_array = tf.expand_dims(img_array, 0)
33
+
34
+ predictions = model.predict(img_array)
35
+
36
+ class_labels = ["normal", "cataract", "retina disease", "glaucoma"]
37
+ predicted_class = class_labels[np.argmax(predictions[0])]
38
+ confidence = round(100 * (np.max(predictions[0])), 2)
39
+ return predicted_class, confidence
40
 
41
  # Explain image
42
  def explain_image(image, model):
 
59
  st.sidebar.title("Upload Image")
60
  uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
61
 
 
62
  if uploaded_file is not None:
63
+ # Load model
64
  model = load_model()
 
65
 
66
  # Display uploaded image
67
  image = Image.open(io.BytesIO(uploaded_file.read()))
68
  st.image(image, caption="Uploaded Image", use_column_width=True)
69
+ processed_image = preprocess_image(image)
70
 
71
  # Predict button
72
  predict_button = st.sidebar.button("Predict", key="predict_button")
 
76
  </style>""", unsafe_allow_html=True
77
  )
78
  if predict_button:
79
+ # Display processing message with spinner
80
+ with st.spinner("Please wait...Processing the image and predicting..."):
81
+
82
+ # Classify image
83
+ predicted_class, confidence_score = predict(model, image)
 
 
 
 
 
 
84
 
85
+ # Explain image classification
86
+ explanation_image = explain_image(processed_image, model)
87
+
88
+ # Display explanation image
89
+ st.image(explanation_image, caption="Explanation Image", use_column_width=True)
90
+
91
+ # Display prediction
92
+ st.subheader("Prediction:")
93
+ # Create a table for prediction results
94
+ prediction_table = pd.DataFrame({
95
+ "Predicted Class": [predicted_class],
96
+ "Confidence": [f"{confidence_score}%"]
97
+ })
98
+ st.table(prediction_table)
99
 
100
  if __name__ == "__main__":
101
+ main()