Prototype / pages /U-Net-Model.py
Dokkone's picture
Update pages/U-Net-Model.py
0e28d84 verified
import streamlit as st
import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import img_to_array
from PIL import Image
import io
import zipfile
# Load multiple trained U-Net models.
# Streamlit re-executes this whole script on every widget interaction, so the
# models go through a cached loader: each .keras file is read from disk once per
# server process instead of on every rerun. The cached model object is shared by
# all sessions of the app.
@st.cache_resource
def _load_unet_model(path):
    """Load and return the Keras model stored at *path* (cached across reruns)."""
    return tf.keras.models.load_model(path)


model_1 = _load_unet_model("pages/dyskera_unet_model.keras")
model_2 = _load_unet_model("pages/koilo_unet_model.keras")
model_3 = _load_unet_model("pages/meta_unet_model.keras")
model_4 = _load_unet_model("pages/para_unet_model.keras")
model_5 = _load_unet_model("pages/SI_unet_model.keras")
# Add more models as needed
# Function to preprocess the uploaded image
def preprocess_image(uploaded_image, target_size=(128, 128)):
    """Turn an uploaded image into a single-sample grayscale batch for the U-Net.

    Args:
        uploaded_image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file-like object such as a Streamlit upload).
        target_size: Output size passed straight to ``cv2.resize``, which reads
            it as ``(width, height)``.

    Returns:
        float32 array of shape ``(1, height, width, 1)`` with values in [0, 1].
    """
    # Decode and collapse to a single 8-bit luminance channel.
    grayscale = np.array(Image.open(uploaded_image).convert('L'))
    resized = cv2.resize(grayscale, target_size)
    # Scale 8-bit intensities into [0, 1].
    normalized = resized.astype('float32') / 255.0
    # Add the batch axis in front and the single channel axis at the end.
    return normalized[np.newaxis, ..., np.newaxis]
# Function to predict masks using the selected model
def predict_masks(image, model):
    """Run *model* on a batch and return the output for its first sample.

    Args:
        image: Batch to feed to ``model.predict`` (``preprocess_image`` builds
            a batch containing exactly one image).
        model: Object exposing ``predict``; in this page, a loaded Keras U-Net.

    Returns:
        Element 0 of the batch prediction, i.e. the output for the single
        input image.
    """
    batch_output = model.predict(image)
    first_sample = batch_output[0]
    return first_sample
# Function to save masks as downloadable images
def save_image(mask, filename):
    """Encode a predicted mask as an in-memory 8-bit PNG.

    Args:
        mask: Float array of per-pixel scores (callers pass one channel of the
            U-Net prediction); presumably in [0, 1], since it is scaled by 255.
        filename: Accepted for backward compatibility but not used; the caller
            chooses the name under which the returned bytes are stored.

    Returns:
        ``io.BytesIO`` rewound to offset 0, containing the PNG bytes.
    """
    # Clip before scaling: casting a float outside [0, 255] to uint8 is
    # undefined in NumPy (values typically wrap), which would turn slightly
    # out-of-range predictions into wildly wrong pixel intensities.
    mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    im = Image.fromarray(mask)
    buf = io.BytesIO()
    im.save(buf, format="PNG")
    # Rewind so readers (getvalue/read/download) see the whole PNG.
    buf.seek(0)
    return buf
# ---------------------------------------------------------------------------
# Streamlit interface: page title, sidebar help text and the image uploader.
# ---------------------------------------------------------------------------
PAGE_TITLE = "U-Net Cell Segmentation with Multiple Models"
SIDEBAR_HINT = "Feel free to select other models from the pages above 🙂"
INSTRUCTIONS_MD = """
### Instructions 🔧
Please upload Cell Image and choose appropriate U-net Model from Drop-down menu to get Predicted Cytoplasm Mask and Nuclei Mask. After getting and downloading Predicted Cytoplasm and Nuclei Mask image, head over to CNN or SVM model to get accurate prediction of Cancer Cell Type.
"""
ACCEPTED_IMAGE_TYPES = ["bmp", "png", "jpg", "jpeg"]

st.title(PAGE_TITLE)
sidebar = st.sidebar
sidebar.info(SIDEBAR_HINT)
sidebar.write(INSTRUCTIONS_MD)

# Image to segment; None until the user uploads a file (checked further down).
uploaded_image = st.file_uploader("Choose an image...", type=ACCEPTED_IMAGE_TYPES)
# Model selection dropdown.
# A single dict maps every dropdown label to its loaded model, so the labels
# shown to the user and the lookup below cannot drift apart. The former
# if/elif chain repeated each label and had no else branch, so a label edited
# in one place but not the other would leave `selected_model` unbound.
# NOTE(review): "Koilocyctotic" looks like a typo for "Koilocytotic"; it is
# user-visible text, so it is left unchanged here.
MODEL_OPTIONS = {
    "U-Net Dyskeratotic": model_1,
    "U-Net Koilocyctotic": model_2,
    "U-Net Metaplastic": model_3,
    "U-Net Parabasal": model_4,
    "U-Net Superficial-Intermediate": model_5,
}
# Dict insertion order preserves the original dropdown order.
model_option = st.selectbox(
    "Select a Model:",
    tuple(MODEL_OPTIONS),
)
# Choose the corresponding model based on selection
selected_model = MODEL_OPTIONS[model_option]
if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
    # Preprocess the image into a one-image grayscale batch.
    image = preprocess_image(uploaded_image)
    # Predict the masks using the selected model
    predictions = predict_masks(image, selected_model)

    # Everything below reads channel 0 (cytoplasm) and channel 1 (nuclei) of an
    # (H, W, C) prediction. Show a readable error instead of an IndexError
    # traceback if the model output does not have that layout.
    prediction_shape = np.shape(predictions)
    if len(prediction_shape) != 3 or prediction_shape[-1] < 2:
        st.error(
            "The selected model returned predictions of shape "
            f"{prediction_shape}; expected (height, width, channels) with at "
            "least 2 channels (cytoplasm, nuclei)."
        )
        st.stop()

    cytoplasm_mask = predictions[:, :, 0]
    nuclei_mask = predictions[:, :, 1]

    # Display results using Matplotlib
    fig, ax = plt.subplots(1, 3, figsize=(12, 8))
    ax[0].imshow(image.squeeze(), cmap='gray')
    ax[0].set_title('Original Image')
    ax[1].imshow(cytoplasm_mask, cmap='gray')
    ax[1].set_title('Predicted Cytoplasm Mask')
    ax[2].imshow(nuclei_mask, cmap='gray')
    ax[2].set_title('Predicted Nuclei Mask')
    st.pyplot(fig)
    # st.pyplot does not close the figure it renders, and Streamlit reruns this
    # script on every interaction, so without this each rerun leaks a figure.
    plt.close(fig)

    # Convert the predicted masks to 8-bit images for display. Clip first:
    # casting out-of-range floats to uint8 is undefined (values wrap).
    cytoplasm_mask_img = (np.clip(cytoplasm_mask, 0.0, 1.0) * 255).astype(np.uint8)
    nuclei_mask_img = (np.clip(nuclei_mask, 0.0, 1.0) * 255).astype(np.uint8)
    st.image(cytoplasm_mask_img, caption="Predicted Cytoplasm Mask", use_column_width=True)
    st.image(nuclei_mask_img, caption="Predicted Nuclei Mask", use_column_width=True)

    # Save the predicted masks for download
    cytoplasm_mask_download = save_image(cytoplasm_mask, "predicted_cytoplasm_mask.png")
    nuclei_mask_download = save_image(nuclei_mask, "predicted_nuclei_mask.png")
    # Combine the two mask files into a zip
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w") as zip_file:
        zip_file.writestr("predicted_cytoplasm_mask.png", cytoplasm_mask_download.getvalue())
        zip_file.writestr("predicted_nuclei_mask.png", nuclei_mask_download.getvalue())
    # Rewind so the download button sends the whole archive.
    zip_buffer.seek(0)
    # Provide a download button for the zipped masks
    st.download_button(
        label="Download Predicted Cytoplasm and Nuclei Masks",
        data=zip_buffer,
        file_name="predicted_masks.zip",
        mime="application/zip"
    )