# Spaces: Runtime error
from io import BytesIO
from urllib.parse import quote

import numpy as np
import requests
import streamlit as st
from PIL import Image
from tensorflow.keras.models import load_model
# Load the pre-trained chest X-ray classification model.
# The app cannot do anything useful without it, so on failure the cause is
# shown to the user and the script run is halted.
try:
    model = load_model('chest_xray_model.h5')
except Exception as e:  # UI boundary: any load failure is reported, not raised
    st.error(f"Error loading model: {e}")
    st.stop()

# Class names for prediction output; the label at position i names the
# model output at index i.
class_names = ['Normal', 'Pneumonia']
# Prediction class for encapsulating the prediction logic
class Prediction:
    """Pair a classifier with the image preprocessing applied before inference.

    Args:
        model: Object exposing ``predict(batch)``; the first element of its
            result is used as the per-class score vector.
        target_size: ``(width, height)`` every image is resized to before
            being passed to the model. Defaults to ``(512, 512)``.
    """

    def __init__(self, model, target_size=(512, 512)):
        self.model = model
        self.target_size = tuple(target_size)

    def classify_image(self, image):
        """Classify a single image given as a NumPy pixel array.

        The array is converted to RGB, resized to ``self.target_size``,
        scaled to the ``[0, 1]`` range and sent to the model as a batch of
        one.

        Args:
            image: Pixel array accepted by ``PIL.Image.fromarray``.

        Returns:
            ``(predicted_class, predicted_confidence, predictions)`` where
            ``predicted_class`` is the entry of ``class_names`` with the
            highest score, ``predicted_confidence`` is that score times 100,
            and ``predictions`` is the score vector for the image.
            On any failure an error is shown in the app and
            ``(None, None, None)`` is returned.
        """
        try:
            rgb_image = Image.fromarray(image).convert('RGB')
            rgb_image = rgb_image.resize(self.target_size)
            image_array = np.array(rgb_image).astype(np.float32) / 255.0
            # The model is fed a batch, so add a leading axis of length one.
            image_array = np.expand_dims(image_array, axis=0)
            predictions = self.model.predict(image_array)[0]
            predicted_class_idx = np.argmax(predictions)
            predicted_class = class_names[predicted_class_idx]
            predicted_confidence = predictions[predicted_class_idx] * 100
            return predicted_class, predicted_confidence, predictions
        except Exception as e:  # UI boundary: report the cause instead of crashing
            st.error(f"Error during classification: {e}")
            return None, None, None
# Classifier wrapper shared by every image processed in this run.
predictor = Prediction(model)

# Page heading and short usage blurb.
st.title("📊 Chest X-Ray Classification")
st.markdown("""
Upload one or more chest X-ray images or provide an image URL, and the model will classify each as either **Normal** or **Pneumonia**.
""")

# Let the user pick between file upload and a remote URL.
input_option = st.radio(
    "Choose how to upload the image(s):",
    ("Upload Image(s)", "Image URL"),
)

# Images collected for this run.
images = []

# Free-text patient name shown alongside each result.
patient_name = st.text_input("### Patient Name")
# Collect the images to classify as (pixel array, display name) tuples,
# either from uploaded files or from a single remote URL.
if input_option == "Upload Image(s)":
    uploaded_images = st.file_uploader(
        "### Step 1: Upload Your Chest X-Ray Image(s)",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    if uploaded_images:
        for uploaded_image in uploaded_images:
            try:
                image = np.array(Image.open(uploaded_image))
                images.append((image, uploaded_image.name))
            except Exception as e:  # one unreadable file must not stop the rest
                st.error(f"Error loading image {uploaded_image.name}: {e}")
elif input_option == "Image URL":
    image_url = st.text_input("### Step 1: Enter the Image URL")
    if image_url:
        try:
            # Without a timeout an unresponsive host would block this script
            # run indefinitely.
            response = requests.get(image_url, timeout=10)
            if response.status_code == 200:
                images.append((np.array(Image.open(BytesIO(response.content))), image_url))
                st.markdown(f"[Image URL]({image_url})")
            else:
                st.error(
                    "Error fetching image from URL: Unable to retrieve the image "
                    f"(HTTP {response.status_code})."
                )
        except requests.RequestException as e:
            # Network-level problems: bad URL, DNS failure, timeout, ...
            st.error(f"Error fetching image from URL: {e}")
        except Exception as e:
            # The download succeeded but the payload could not be read as an image.
            st.error(f"Error decoding image from URL: {e}")
# Classification results gathered during this run, one tuple per image:
# (1-based index, patient name, predicted class, confidence %, score vector).
results = []

if images:
    submit_button = st.button("Submit", key="submit")
    if submit_button:
        st.write("### Step 2: Review the Uploaded Image(s) and Results")
        # An empty name would produce a Markdown link with no text, i.e. an
        # invisible heading, so fall back to a placeholder label.
        patient_label = patient_name.strip() or "Unknown"
        for idx, (image, image_name) in enumerate(images):
            # Percent-encode the anchor: spaces or parentheses in a file name
            # or URL would otherwise break the Markdown link destination.
            anchor = quote(image_name, safe="")
            patient_display_name = f"[{patient_label}](#{anchor})"
            st.write(f"#### Patient: {patient_display_name}")

            # Image on the left (wider column), prediction summary on the right.
            col1, col2 = st.columns([2, 1])
            with col1:
                st.image(image, caption=image_name, use_column_width=True, clamp=True)
            with col2:
                st.subheader("Prediction Results")
                with st.spinner("Processing..."):
                    predicted_class, predicted_confidence, predictions = predictor.classify_image(image)
                # classify_image returns None values when it failed and has
                # already reported the error itself.
                if predicted_class is not None:
                    st.markdown(f"""
                    <div style="border: 2px solid #2196F3; padding: 10px; border-radius: 5px;">
                    <p style="font-size: 16px; font-weight: bold;">Predicted Class: {predicted_class}</p>
                    <p style="font-size: 16px;">Confidence: {predicted_confidence:.2f}%</p>
                    <p style="font-size: 16px; font-weight: bold;">Class Confidence Levels:</p>
                    <ul style="list-style-type: none; padding: 0;">
                    <li style="color: #4CAF50;">Normal: {predictions[0] * 100:.1f}%</li>
                    <li style="color: #F44336;">Pneumonia: {predictions[1] * 100:.1f}%</li>
                    </ul>
                    </div>
                    """, unsafe_allow_html=True)
                    results.append((idx + 1, patient_name, predicted_class, predicted_confidence, predictions))
            # Horizontal rule between consecutive images.
            st.markdown("---")
# Button to manually start a new session
if st.button("Start New Session"):
    images.clear()
    # Prefer st.rerun where the installed Streamlit provides it and fall back
    # to the legacy st.experimental_rerun name otherwise, so the button works
    # on both older and newer releases.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()  # Rerun the app to refresh the state

# Static footer describing what the model output means.
st.write("### Additional Information")
st.markdown("""
- This model is trained to differentiate between normal and pneumonia-affected chest X-rays.
- Confidence levels are displayed as a percentage for each class.
""")