itsanmolgupta's picture
Update app.py
d277480 verified
raw
history blame
2.47 kB
import streamlit as st
import numpy as np
import tensorflow as tf
from PIL import Image
import cv2 as cv
import io
# Load the trained model once per server process. Streamlit re-executes this
# whole script on every widget interaction, so loading at plain module level
# would deserialize the Keras model from disk on every rerun; cache_resource
# keeps a single shared instance instead.
@st.cache_resource
def _load_model(path="model.keras"):
    """Load and return the Keras model saved at *path*."""
    return tf.keras.models.load_model(path)


model = _load_model()
# Define the image preprocessing function
def preprocess_image(image, target_size=(224, 224)):
    """Turn an uploaded chest X-ray into a batch ready for ``model.predict``.

    Pipeline: single luminance band -> 9x9 Gaussian blur -> CLAHE
    (clipLimit=3, 10x10 tiles) -> replicate to 3 channels -> min-max scale
    to [0, 1] -> resize.

    Args:
        image: A PIL image in any mode, or a 2-D uint8/uint16 numpy array.
        target_size: Output ``(width, height)`` passed to ``cv.resize``;
            defaults to 224x224.

    Returns:
        float32 array of shape ``(1, height, width, 3)`` with values in
        [0, 1].
    """
    # CLAHE only accepts single-channel 8/16-bit input, so colour, alpha,
    # palette and bilevel PIL images (e.g. RGB JPEGs/PNGs) are reduced to one
    # luminance band first. Single-band images are left as-is so 16-bit
    # X-rays keep their full precision.
    if hasattr(image, "getbands") and (
        image.mode in ("1", "P", "PA") or len(image.getbands()) > 1
    ):
        image = image.convert("L")
    image_array = np.asarray(image)

    # Blur before contrast enhancement (presumably so CLAHE does not amplify
    # pixel noise), then apply CLAHE.
    blurred = cv.GaussianBlur(image_array, (9, 9), 0)
    clahe = cv.createCLAHE(clipLimit=3, tileGridSize=(10, 10))
    enhanced = clahe.apply(blurred)

    # Replicate the grayscale band into 3 channels (presumably the model was
    # trained on 3-channel input).
    rgb = cv.cvtColor(enhanced, cv.COLOR_GRAY2RGB)

    # Min-max scale to [0, 1]. A constant image has zero range; map it to all
    # zeros instead of dividing by zero (which would produce NaNs).
    lo, hi = rgb.min(), rgb.max()
    value_range = float(hi) - float(lo)
    if value_range > 0:
        normalized = (rgb - lo) / value_range
    else:
        normalized = np.zeros(rgb.shape, dtype=np.float64)

    # cv.resize takes (width, height).
    resized = cv.resize(normalized, target_size)

    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(resized, axis=0).astype(np.float32)
# Human-readable disease names, one per model output unit. The UI zips this
# list with the prediction vector, so the order is assumed to match the
# model's output order -- confirm against the training setup.
class_labels = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pleural_Thickening",
    "Pneumothorax",
]
# Streamlit app
st.title("Chest X-ray Classification")

# Only labels scoring above this probability are shown, at most TOP_K of them.
PROBABILITY_THRESHOLD = 0.5
TOP_K = 3

# Upload image
uploaded_file = st.file_uploader("Upload a Chest X-ray image...", type=["jpg", "jpeg", "png"])

# Create two columns: the uploaded image on the left, predictions on the right.
col1, col2 = st.columns(2)

if uploaded_file is not None:
    # Image.open() is lazy, so call load() here as well: unrecognised or
    # truncated uploads then fail inside this try instead of deep inside
    # preprocessing. PIL's UnidentifiedImageError is an OSError subclass.
    try:
        image = Image.open(uploaded_file)
        image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        st.stop()

    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Preprocess the image into a model-ready batch.
    preprocessed_image = preprocess_image(image)

    # verbose=0 keeps Keras from printing a progress bar for every upload.
    predictions = model.predict(preprocessed_image, verbose=0)[0]

    # Keep labels above the threshold, highest probability first, top TOP_K only.
    top_predictions = sorted(
        (
            (label, prob)
            for label, prob in zip(class_labels, predictions)
            if prob > PROBABILITY_THRESHOLD
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )[:TOP_K]

    with col2:
        # Display results
        if not top_predictions:
            st.write(f"No diseases found with probability greater than {PROBABILITY_THRESHOLD:.0%}.")
        else:
            st.write("Predicted Disease(s):")
            for label, prob in top_predictions:
                st.write(f"{label}: {prob*100:.2f}%")
                # st.progress takes an integer percentage in 0..100.
                st.progress(int(prob * 100))