arthurpendragon commited on
Commit
1ee2406
·
verified ·
1 Parent(s): 2a74a99

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -42
app.py CHANGED
@@ -1,98 +1,104 @@
1
  import streamlit as st
2
  from PIL import Image, ImageOps
3
  import numpy as np
4
- import pandas as pd
5
  import matplotlib.pyplot as plt
6
- from keras import tensorflow as tf
 
7
 
8
- # Load the TensorFlow model
9
  model = tf.keras.models.load_model('gastrointestinal_model.h5', compile=False)
10
 
11
  # Load class names
12
- class_names = ['Normal', 'Ulcerative Colitis', 'Polyp', 'Esophagitis']
 
13
 
14
  # Function to create plot
15
- def create_plot(prediction, class_names):
16
- df = pd.DataFrame(prediction, index=class_names, columns=['Confidence'])
17
- df = df.sort_values(by='Confidence', ascending=False)
18
- plt.figure(figsize=(8, 5))
19
- plt.bar(df.index, df['Confidence'], color='blue')
20
- plt.xlabel('Class')
21
- plt.ylabel('Confidence Score')
22
- plt.title('Classification Confidence Scores')
23
- plt.xticks(rotation=45)
24
- plt.tight_layout()
25
- return plt
26
 
27
  # Function to predict gastrointestinal conditions
28
  def predict_gastrointestinal(img):
 
 
29
  size = (224, 224)
30
  image_PIL = ImageOps.fit(img, size, Image.LANCZOS)
31
  image_array = np.asarray(image_PIL)
32
  normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
33
- data = np.expand_dims(normalized_image_array, axis=0)
 
 
 
 
 
34
 
35
- prediction = model.predict(data)[0]
36
- class_index = np.argmax(prediction)
37
- predicted_class = class_names[class_index]
38
- confidence_scores = prediction * 100
 
 
 
39
 
40
- # Create plot
41
- plot = create_plot(confidence_scores, class_names)
 
 
42
 
43
- return predicted_class, plot
44
 
45
  # Streamlit app
46
  st.title("Gastrointestinal Classification Web App")
47
 
48
- uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
49
 
50
  if uploaded_file is not None:
51
  image = Image.open(uploaded_file)
52
  st.image(image, caption='Uploaded Image', use_column_width=True)
53
  st.write("Classifying...")
54
 
55
- predicted_class, plot = predict_gastrointestinal(image)
56
- st.write(f"Prediction: {predicted_class}")
57
  st.pyplot(plot)
58
 
59
  # Sample images
60
  st.markdown("### Sample Images")
61
  if st.button('Normal Sample'):
62
- image = Image.open('normal_sample.jpg')
63
  st.image(image, caption='Normal Sample Image', use_column_width=True)
64
  st.write("Classifying...")
65
- predicted_class, plot = predict_gastrointestinal(image)
66
- st.write(f"Prediction: {predicted_class}")
67
  st.pyplot(plot)
68
 
69
  if st.button('Ulcerative Colitis Sample'):
70
  image = Image.open('ulcerative_colitis_sample.jpg')
71
  st.image(image, caption='Ulcerative Colitis Sample Image', use_column_width=True)
72
  st.write("Classifying...")
73
- predicted_class, plot = predict_gastrointestinal(image)
74
- st.write(f"Prediction: {predicted_class}")
75
  st.pyplot(plot)
76
 
77
  if st.button('Polyp Sample'):
78
  image = Image.open('polyp_sample.jpg')
79
  st.image(image, caption='Polyp Sample Image', use_column_width=True)
80
  st.write("Classifying...")
81
- predicted_class, plot = predict_gastrointestinal(image)
82
- st.write(f"Prediction: {predicted_class}")
83
  st.pyplot(plot)
84
 
85
  if st.button('Esophagitis Sample'):
86
  image = Image.open('esophagitis_sample.jpg')
87
  st.image(image, caption='Esophagitis Sample Image', use_column_width=True)
88
  st.write("Classifying...")
89
- predicted_class, plot = predict_gastrointestinal(image)
90
- st.write(f"Prediction: {predicted_class}")
91
  st.pyplot(plot)
92
-
93
- # Educational content
94
- st.markdown("### Learn More About Gastrointestinal Conditions")
95
- st.markdown("""
96
- - [Gastrointestinal Disorders Overview](https://www.mayoclinic.org/diseases-conditions/gastrointestinal-disorders/symptoms-causes/syc-20375441)
97
- - [Preventing Gastrointestinal Conditions](https://www.niddk.nih.gov/health-information/digestive-diseases)
98
- """)
 
1
  import streamlit as st
2
  from PIL import Image, ImageOps
3
  import numpy as np
4
+ import seaborn as sns
5
  import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ import tensorflow as tf # Use TensorFlow's Keras API
8
 
9
+ # Load the TensorFlow Keras model
10
  model = tf.keras.models.load_model('gastrointestinal_model.h5', compile=False)
11
 
12
  # Load class names
13
+ with open('labels.txt', 'r') as f:
14
+ class_names = f.readlines()
15
 
16
  # Function to create plot
17
+ def create_plot(data):
18
+ sns.set_theme(style="whitegrid")
19
+ f, ax = plt.subplots(figsize=(5, 5))
20
+ sns.set_color_codes("pastel")
21
+ sns.barplot(x="Total", y="Labels", data=data, label="Total", color="b")
22
+ sns.set_color_codes("muted")
23
+ sns.barplot(x="Confidence Score", y="Labels", data=data, label="Confidence Score", color="b")
24
+ ax.legend(ncol=2, loc="lower right", frameon=True)
25
+ sns.despine(left=True, bottom=True)
26
+ return f
 
27
 
28
  # Function to predict gastrointestinal conditions
29
  def predict_gastrointestinal(img):
30
+ np.set_printoptions(suppress=True)
31
+ data = np.ndarray(shape=(1, 224, 224, 3), dtype=np.float32)
32
  size = (224, 224)
33
  image_PIL = ImageOps.fit(img, size, Image.LANCZOS)
34
  image_array = np.asarray(image_PIL)
35
  normalized_image_array = (image_array.astype(np.float32) / 127.0) - 1
36
+ data[0] = normalized_image_array
37
+ prediction = model.predict(data)
38
+ index = np.argmax(prediction)
39
+ class_name = class_names[index].strip()
40
+ confidence_score = prediction[0][index]
41
+ other_class = [name for i, name in enumerate(class_names) if i != index][0].strip()
42
 
43
+ result = {
44
+ "Labels": [class_name, other_class],
45
+ "Confidence Score": [confidence_score * 100, (1 - confidence_score) * 100],
46
+ "Total": 100
47
+ }
48
+ data_for_plot = pd.DataFrame.from_dict(result)
49
+ plot = create_plot(data_for_plot)
50
 
51
+ if class_name == "Normal":
52
+ prediction_text = "The image is classified as Normal."
53
+ else:
54
+ prediction_text = f"The image shows signs of {class_name}."
55
 
56
+ return prediction_text, plot
57
 
58
  # Streamlit app
59
  st.title("Gastrointestinal Classification Web App")
60
 
61
+ uploaded_file = st.file_uploader("Upload a gastrointestinal image...", type=["jpg", "jpeg", "png"])
62
 
63
  if uploaded_file is not None:
64
  image = Image.open(uploaded_file)
65
  st.image(image, caption='Uploaded Image', use_column_width=True)
66
  st.write("Classifying...")
67
 
68
+ prediction, plot = predict_gastrointestinal(image)
69
+ st.write(prediction)
70
  st.pyplot(plot)
71
 
72
  # Sample images
73
  st.markdown("### Sample Images")
74
  if st.button('Normal Sample'):
75
+ image = Image.open('normal_Sample.jpg')
76
  st.image(image, caption='Normal Sample Image', use_column_width=True)
77
  st.write("Classifying...")
78
+ prediction, plot = predict_gastrointestinal(image)
79
+ st.write(prediction)
80
  st.pyplot(plot)
81
 
82
  if st.button('Ulcerative Colitis Sample'):
83
  image = Image.open('ulcerative_colitis_sample.jpg')
84
  st.image(image, caption='Ulcerative Colitis Sample Image', use_column_width=True)
85
  st.write("Classifying...")
86
+ prediction, plot = predict_gastrointestinal(image)
87
+ st.write(prediction)
88
  st.pyplot(plot)
89
 
90
  if st.button('Polyp Sample'):
91
  image = Image.open('polyp_sample.jpg')
92
  st.image(image, caption='Polyp Sample Image', use_column_width=True)
93
  st.write("Classifying...")
94
+ prediction, plot = predict_gastrointestinal(image)
95
+ st.write(prediction)
96
  st.pyplot(plot)
97
 
98
  if st.button('Esophagitis Sample'):
99
  image = Image.open('esophagitis_sample.jpg')
100
  st.image(image, caption='Esophagitis Sample Image', use_column_width=True)
101
  st.write("Classifying...")
102
+ prediction, plot = predict_gastrointestinal(image)
103
+ st.write(prediction)
104
  st.pyplot(plot)