Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import ListedColormap | |
# Load the trained model
@st.cache_resource
def load_trained_model():
    """Load the Keras model stored at ``segnet_model.keras``.

    Streamlit re-executes the script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``: the model file is
    deserialized on the first call and the same object is reused on later
    reruns instead of being read from disk again.

    Returns:
        The model object returned by ``load_model``.
    """
    # Relative path: resolved against the process's working directory.
    model_path = "segnet_model.keras"
    return load_model(model_path)
def predict_segmentation(model, image, target_size=(256, 256)):
    """Run the segmentation model on a single PIL image.

    Args:
        model: Object exposing ``predict(batch)``; it is called with a batch
            containing exactly one image.
        image: PIL image to segment. Any mode is accepted; it is converted
            to RGB before being fed to the model.
        target_size: Size tuple handed to ``Image.resize``.

    Returns:
        The first (and only) element of the batch returned by
        ``model.predict``.
    """
    # Uploaded files are not guaranteed to be three-channel: PNGs are often
    # RGBA or palette-based and JPEGs can be grayscale or CMYK. Feeding those
    # straight to the network yields a wrong channel count.
    # NOTE(review): assumes the model expects 3-channel RGB input - confirm.
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(target_size)
    # Scale 8-bit pixel values into [0, 1].
    pixels = img_to_array(resized) / 255.0
    # Add the leading batch axis expected by ``predict``.
    batch = np.expand_dims(pixels, axis=0)
    return model.predict(batch)[0]
def create_mask_plot(mask, colormap, labels):
    """Render a segmentation mask with a class legend as a PIL image.

    Args:
        mask: Array whose ``squeeze()`` result is drawn with ``imshow``.
        colormap: ``ListedColormap``; ``colormap.colors[i]`` is used as the
            legend colour for ``labels[i]``.
        labels: Sequence of class names. Its length also fixes the colour
            scale, which spans ``0`` to ``len(labels) - 1``.

    Returns:
        A PIL image built from the figure's RGBA pixel buffer.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(mask.squeeze(), cmap=colormap, vmin=0, vmax=len(labels) - 1)
        ax.axis("off")

        # Add a legend: one coloured line handle per class label.
        legend_handles = [
            plt.Line2D([0], [0], color=colormap.colors[i], lw=4, label=label)
            for i, label in enumerate(labels)
        ]
        ax.legend(
            handles=legend_handles,
            loc="upper right",
            bbox_to_anchor=(1.2, 1.0),
        )

        # Rasterize the figure, then copy the pixels out of the canvas
        # buffer so they stay valid after the figure is closed.
        fig.canvas.draw()
        rgba = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if plotting raised; otherwise it
        # stays registered with pyplot and accumulates across reruns.
        plt.close(fig)
    return Image.fromarray(rgba)
# Streamlit App
def main():
    """Render the flood-segmentation page.

    Shows a title and an image uploader. Once a file is uploaded, it is
    opened with PIL, passed through ``predict_segmentation``, and the
    original image and the rendered mask are displayed in two columns.
    """
    st.title("Flood Area Segmentation")
    st.write("Upload an image to predict its segmentation mask.")

    # Obtain the model before any upload so it is ready when needed.
    segmentation_model = load_trained_model()

    upload = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])
    if upload is None:
        # Nothing uploaded yet: leave only the title and uploader on screen.
        return

    source_image = Image.open(upload)

    with st.spinner("Predicting..."):
        raw_mask = predict_segmentation(segmentation_model, source_image)

    # Class index 0 -> green (non-flooded), class index 1 -> blue (flooded).
    class_labels = ["Non-Flooded Area", "Flooded Area"]
    class_colors = ListedColormap(["green", "blue"])
    rendered_mask = create_mask_plot(raw_mask, class_colors, class_labels)

    # Original on the left, prediction on the right.
    st.subheader("Results")
    left_col, right_col = st.columns(2)
    with left_col:
        st.write("### Original Image")
        st.image(source_image, caption="Original Image", use_container_width=True)
    with right_col:
        st.write("### Predicted Mask")
        st.image(rendered_mask, caption="Predicted Mask", use_container_width=True)
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()