"""Streamlit page: segment cytoplasm and nuclei in a cell image with U-Net.

The user uploads a cell image and picks one of several pre-trained U-Net
models, each labelled with a cell type in the drop-down. The page shows the
predicted cytoplasm and nuclei masks and offers both as PNG files inside a
single zip download.
"""

import io
import zipfile

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array  # noqa: F401  (unused here; kept for compatibility)

# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource (Streamlit >= 1.18) keeps one loaded model per path across
# reruns instead of re-reading every .keras file each time; on older Streamlit
# versions fall back to plain (uncached) loading.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def load_unet_model(model_path):
    """Load the Keras model stored at *model_path* (cached across reruns when possible)."""
    return tf.keras.models.load_model(model_path)


# Load multiple trained U-Net models (relative paths resolve against the
# working directory the app is started from).
model_1 = load_unet_model("pages/dyskera_unet_model.keras")
model_2 = load_unet_model("pages/koilo_unet_model.keras")
model_3 = load_unet_model("pages/meta_unet_model.keras")
model_4 = load_unet_model("pages/para_unet_model.keras")
model_5 = load_unet_model("pages/SI_unet_model.keras")
# Add more models as needed (and register them in MODEL_OPTIONS below).

# Drop-down label -> model. Insertion order defines the order of the options,
# and a single mapping guarantees every label resolves to a model.
MODEL_OPTIONS = {
    "U-Net Dyskeratotic": model_1,
    "U-Net Koilocytotic": model_2,
    "U-Net Metaplastic": model_3,
    "U-Net Parabasal": model_4,
    "U-Net Superficial-Intermediate": model_5,
}


def preprocess_image(uploaded_image, target_size=(128, 128)):
    """Turn an uploaded image into a single-image batch for the U-Net models.

    Args:
        uploaded_image: Anything ``PIL.Image.open`` accepts: a path, or a
            binary file-like object such as Streamlit's ``UploadedFile``.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(1, height, width, 1)`` with values in [0, 1].
    """
    # The same file object may already have been consumed elsewhere (e.g. by
    # st.image), so rewind it before decoding.
    if hasattr(uploaded_image, "seek"):
        uploaded_image.seek(0)
    gray = np.array(Image.open(uploaded_image).convert("L"))  # 8-bit grayscale
    gray = cv2.resize(gray, target_size)
    normalized = gray.astype("float32") / 255.0
    # Add the channel and batch axes: (H, W) -> (1, H, W, 1).
    return normalized[np.newaxis, :, :, np.newaxis]


def predict_masks(image, model):
    """Run *model* on the batch *image* and return the prediction for its first item."""
    predictions = model.predict(image)
    return predictions[0]


def _mask_to_uint8(mask):
    """Scale a float mask expected in [0, 1] to uint8 in [0, 255].

    Values are clipped first: casting a float outside [0, 255] to uint8 does
    not saturate, so slightly out-of-range model outputs would otherwise wrap
    to the wrong intensity.
    """
    return (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)


def save_image(mask, filename):
    """Encode a 2-D mask as an in-memory PNG.

    Args:
        mask: 2-D float array, expected in [0, 1] (clipped otherwise).
        filename: Stored as the returned buffer's ``name`` attribute; nothing
            is written to disk.

    Returns:
        ``io.BytesIO`` holding the PNG data, positioned at the start.
    """
    buf = io.BytesIO()
    Image.fromarray(_mask_to_uint8(mask)).save(buf, format="PNG")
    buf.name = filename
    buf.seek(0)
    return buf


# Streamlit Interface
st.title("U-Net Cell Segmentation with Multiple Models")
st.sidebar.info("Feel free to select other models from the pages above 🙂")
st.sidebar.write("""
### Instructions 🔧
Please upload Cell Image and choose appropriate U-net Model from Drop-down menu to get Predicted Cytoplasm Mask and Nuclei Mask.

After getting and downloading Predicted Cytoplasm and Nuclei Mask image, head over to CNN or SVM model to get accurate prediction of Cancer Cell Type.
""")

# Upload an image
uploaded_image = st.file_uploader("Choose an image...", type=["bmp", "png", "jpg", "jpeg"])

# Model selection dropdown; the options come straight from MODEL_OPTIONS.
model_option = st.selectbox("Select a Model:", tuple(MODEL_OPTIONS))
selected_model = MODEL_OPTIONS[model_option]

if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

    image = preprocess_image(uploaded_image)
    predictions = predict_masks(image, selected_model)

    # Channel 0 is shown as the cytoplasm mask and channel 1 as the nuclei
    # mask, so the model output must be (H, W, C) with at least two channels.
    if predictions.ndim != 3 or predictions.shape[-1] < 2:
        st.error(
            "Expected the model to output a (height, width, channels >= 2) mask, "
            f"but got shape {predictions.shape}."
        )
        st.stop()

    cytoplasm_mask = predictions[:, :, 0]
    nuclei_mask = predictions[:, :, 1]

    # Display results using Matplotlib
    fig, ax = plt.subplots(1, 3, figsize=(12, 8))
    ax[0].imshow(image.squeeze(), cmap='gray')
    ax[0].set_title('Original Image')
    ax[1].imshow(cytoplasm_mask, cmap='gray')
    ax[1].set_title('Predicted Cytoplasm Mask')
    ax[2].imshow(nuclei_mask, cmap='gray')
    ax[2].set_title('Predicted Nuclei Mask')
    st.pyplot(fig)
    # pyplot keeps every figure it creates alive until closed; without this,
    # each rerun of the script would leak another figure.
    plt.close(fig)

    # Show the masks as 8-bit images (same conversion as the PNG download).
    st.image(_mask_to_uint8(cytoplasm_mask), caption="Predicted Cytoplasm Mask", use_column_width=True)
    st.image(_mask_to_uint8(nuclei_mask), caption="Predicted Nuclei Mask", use_column_width=True)

    # Encode the masks as PNGs and bundle them into one zip for download.
    cytoplasm_mask_download = save_image(cytoplasm_mask, "predicted_cytoplasm_mask.png")
    nuclei_mask_download = save_image(nuclei_mask, "predicted_nuclei_mask.png")

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zip_file:
        for mask_png in (cytoplasm_mask_download, nuclei_mask_download):
            zip_file.writestr(mask_png.name, mask_png.getvalue())
    zip_buffer.seek(0)

    # Provide a download button for the zipped masks
    st.download_button(
        label="Download Predicted Cytoplasm and Nuclei Masks",
        data=zip_buffer,
        file_name="predicted_masks.zip",
        mime="application/zip",
    )