Create U-Net-Model.py
Browse files- pages/U-Net-Model.py +118 -0
pages/U-Net-Model.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from tensorflow.keras.preprocessing.image import img_to_array
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import io
|
| 9 |
+
import zipfile
|
| 10 |
+
|
| 11 |
+
# Load multiple trained U-Net models
|
| 12 |
+
model_1 = tf.keras.models.load_model("pages/dyskera_unet_model.keras")
|
| 13 |
+
model_2 = tf.keras.models.load_model("pages/koilo_unet_model.keras")
|
| 14 |
+
model_3 = tf.keras.models.load_model("pages/meta_unet_model.keras")
|
| 15 |
+
model_4 = tf.keras.models.load_model("pages/para_unet_model.keras")
|
| 16 |
+
model_5 = tf.keras.models.load_model("pages/SI_unet_model.keras")
|
| 17 |
+
# Add more models as needed
|
| 18 |
+
|
| 19 |
+
# Function to preprocess the uploaded image
|
| 20 |
+
def preprocess_image(uploaded_image, target_size=(128, 128)):
|
| 21 |
+
image = Image.open(uploaded_image).convert('L') # Convert to grayscale
|
| 22 |
+
image = np.array(image) # Convert to numpy array
|
| 23 |
+
image = cv2.resize(image, target_size) # Resize the image
|
| 24 |
+
image = image.astype('float32') / 255.0 # Normalize
|
| 25 |
+
image = np.expand_dims(image, axis=-1) # Add channel dimension
|
| 26 |
+
image = np.expand_dims(image, axis=0) # Add batch dimension
|
| 27 |
+
return image
|
| 28 |
+
|
| 29 |
+
# Function to predict masks using the selected model
|
| 30 |
+
def predict_masks(image, model):
|
| 31 |
+
predictions = model.predict(image)
|
| 32 |
+
return predictions[0]
|
| 33 |
+
|
| 34 |
+
# Function to save masks as downloadable images
|
| 35 |
+
def save_image(mask, filename):
|
| 36 |
+
mask = (mask * 255).astype(np.uint8)
|
| 37 |
+
im = Image.fromarray(mask)
|
| 38 |
+
buf = io.BytesIO()
|
| 39 |
+
im.save(buf, format="PNG")
|
| 40 |
+
buf.seek(0)
|
| 41 |
+
return buf
|
| 42 |
+
|
| 43 |
+
# Streamlit Interface
|
| 44 |
+
st.title("U-Net Cell Segmentation with Multiple Models")
|
| 45 |
+
|
| 46 |
+
st.sidebar.info("Feel free to select other models from the pages above 🙂")
|
| 47 |
+
|
| 48 |
+
st.sidebar.write("""
|
| 49 |
+
### Instructions 🔧
|
| 50 |
+
Please upload Cell Image to get Predicted Cytoplasm Mask and Nuclei Mask. After getting and downloading Predicted Cytoplasm and Nuclei Mask image, head over to CNN or SVM model to get accurate prediction of Cancer Cell Type.
|
| 51 |
+
""")
|
| 52 |
+
|
| 53 |
+
# Upload an image
|
| 54 |
+
uploaded_image = st.file_uploader("Choose an image...", type=["bmp", "png", "jpg", "jpeg"])
|
| 55 |
+
|
| 56 |
+
# Model selection dropdown
|
| 57 |
+
model_option = st.selectbox(
|
| 58 |
+
"Select a Model:",
|
| 59 |
+
("U-Net Dyskeratotic", "U-Net Koilocyctotic", "U-Net Metaplastic", "U-Net Parabasal", "U-Net Superficial-Intermediate")
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Choose the corresponding model based on selection
|
| 63 |
+
if model_option == "U-Net Dyskeratotic":
|
| 64 |
+
selected_model = model_1
|
| 65 |
+
elif model_option == "U-Net Koilocyctotic":
|
| 66 |
+
selected_model = model_2
|
| 67 |
+
elif model_option == "U-Net Metaplastic":
|
| 68 |
+
selected_model = model_3
|
| 69 |
+
elif model_option == "U-Net Parabasal":
|
| 70 |
+
selected_model = model_4
|
| 71 |
+
elif model_option == "U-Net Superficial-Intermediate":
|
| 72 |
+
selected_model = model_5
|
| 73 |
+
|
| 74 |
+
if uploaded_image is not None:
|
| 75 |
+
st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
|
| 76 |
+
|
| 77 |
+
# Preprocess the image
|
| 78 |
+
image = preprocess_image(uploaded_image)
|
| 79 |
+
|
| 80 |
+
# Predict the masks using the selected model
|
| 81 |
+
predictions = predict_masks(image, selected_model)
|
| 82 |
+
|
| 83 |
+
# Display results using Matplotlib
|
| 84 |
+
fig, ax = plt.subplots(1, 3, figsize=(12, 8))
|
| 85 |
+
ax[0].imshow(image.squeeze(), cmap='gray')
|
| 86 |
+
ax[0].set_title('Original Image')
|
| 87 |
+
ax[1].imshow(predictions[:, :, 0], cmap='gray')
|
| 88 |
+
ax[1].set_title('Predicted Cytoplasm Mask')
|
| 89 |
+
ax[2].imshow(predictions[:, :, 1], cmap='gray')
|
| 90 |
+
ax[2].set_title('Predicted Nuclei Mask')
|
| 91 |
+
st.pyplot(fig)
|
| 92 |
+
|
| 93 |
+
# Convert the predicted masks to images for display
|
| 94 |
+
cytoplasm_mask_img = (predictions[:, :, 0] * 255).astype(np.uint8)
|
| 95 |
+
nuclei_mask_img = (predictions[:, :, 1] * 255).astype(np.uint8)
|
| 96 |
+
|
| 97 |
+
st.image(cytoplasm_mask_img, caption="Predicted Cytoplasm Mask", use_column_width=True)
|
| 98 |
+
st.image(nuclei_mask_img, caption="Predicted Nuclei Mask", use_column_width=True)
|
| 99 |
+
|
| 100 |
+
# Save the predicted masks for download
|
| 101 |
+
cytoplasm_mask_download = save_image(predictions[:, :, 0], "predicted_cytoplasm_mask.png")
|
| 102 |
+
nuclei_mask_download = save_image(predictions[:, :, 1], "predicted_nuclei_mask.png")
|
| 103 |
+
|
| 104 |
+
# Combine the two mask files into a zip
|
| 105 |
+
zip_buffer = io.BytesIO()
|
| 106 |
+
with zipfile.ZipFile(zip_buffer, "w") as zip_file:
|
| 107 |
+
zip_file.writestr("predicted_cytoplasm_mask.png", cytoplasm_mask_download.getvalue())
|
| 108 |
+
zip_file.writestr("predicted_nuclei_mask.png", nuclei_mask_download.getvalue())
|
| 109 |
+
|
| 110 |
+
zip_buffer.seek(0)
|
| 111 |
+
|
| 112 |
+
# Provide a download button for the zipped masks
|
| 113 |
+
st.download_button(
|
| 114 |
+
label="Download Predicted Cytoplasm and Nuclei Masks",
|
| 115 |
+
data=zip_buffer,
|
| 116 |
+
file_name="predicted_masks.zip",
|
| 117 |
+
mime="application/zip"
|
| 118 |
+
)
|