Prototype / pages /CNN-Model.py
Dokkone's picture
Update pages/CNN-Model.py
a2897bd verified
import streamlit as st
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from PIL import Image
import io
# Streamlit app header: page title plus sidebar usage instructions.
st.title("Cervical Cancer Cell Classification - CNN Model")
st.sidebar.info("Feel free to select other models from the pages above 🙂")
st.sidebar.write("""
### Instructions 🔧
Make sure you have predicted Cytoplasm and Nuclei Mask from U-Net-Model before proceeding. Please upload Cell Image, Cytoplasm Mask, and Nuclei Mask for accurate and precise classification of Cervical Cancer Cell Type
""")
# Function to process uploaded images
def process_uploaded_image(uploaded_file):
    """Decode an uploaded image into a 2-D grayscale numpy array.

    Args:
        uploaded_file: A file-like object containing PNG/JPEG bytes (e.g. the
            object returned by ``st.file_uploader``), or anything else that
            ``PIL.Image.open`` accepts.

    Returns:
        numpy.ndarray of shape (height, width) and dtype uint8 -- PIL's "L"
        mode stores a single 8-bit luminance value per pixel.
    """
    # The same file object is handed to st.image before it reaches this
    # function and may already have been read (e.g. on a previous rerun).
    # Rewind so PIL always starts decoding from the first byte.
    if hasattr(uploaded_file, "seek"):
        uploaded_file.seek(0)
    image = Image.open(uploaded_file).convert('L')  # Convert to grayscale
    return np.array(image)
# --- Image upload widgets ---------------------------------------------------
# File extensions accepted by every uploader on this page.
_ACCEPTED_IMAGE_TYPES = ['png', 'jpg', 'jpeg']

st.subheader("Upload Images")

# Each uploader returns None until the user has provided a file; the
# prediction branch below only runs once all three are present.
uploaded_cell_image = st.file_uploader(
    "Upload Cell Image",
    type=_ACCEPTED_IMAGE_TYPES,
)
uploaded_cytoplasm_mask = st.file_uploader(
    "Upload Cytoplasm Mask",
    type=_ACCEPTED_IMAGE_TYPES,
)
uploaded_nuclei_mask = st.file_uploader(
    "Upload Nuclei Mask",
    type=_ACCEPTED_IMAGE_TYPES,
)
# Path of the trained CNN weights, relative to the directory the Streamlit
# app is launched from.
CNN_MODEL_PATH = 'pages/cnn_model.h5'

# Side length of the square, 2-channel input the CNN is fed (see reshape below).
CNN_INPUT_SIZE = 128


@st.cache_resource
def _load_cnn_model(model_path=CNN_MODEL_PATH):
    """Load the pre-trained Keras CNN once and reuse it.

    Streamlit re-executes this whole script on every widget interaction;
    caching the model as a resource avoids re-reading the .h5 file from disk
    on each rerun.
    """
    return load_model(model_path)


def plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label):
    """Render the cytoplasm (green) and nuclei (red) masks as an RGB image.

    ``cell_image`` is only used for its height and width; both masks must be
    2-D arrays of that same shape.
    """
    color_coded_image = np.zeros((cell_image.shape[0], cell_image.shape[1], 3), dtype=np.uint8)
    color_coded_image[:, :, 1] = cytoplasm_mask  # Green for cytoplasm
    color_coded_image[:, :, 0] = nuclei_mask  # Red for nuclei
    # Draw on a dedicated figure and close it afterwards. Passing the global
    # ``plt`` state to st.pyplot does not clear the figure, so every rerun
    # would stack another image onto the same axes and keep it in memory.
    fig, ax = plt.subplots()
    ax.imshow(color_coded_image)
    ax.set_title(f"Predicted: {predicted_label}")
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)


# Check if all images are uploaded
if uploaded_cell_image and uploaded_cytoplasm_mask and uploaded_nuclei_mask:
    # Display the uploaded images side by side
    st.subheader("Uploaded Images")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.image(uploaded_cell_image, caption="Cell Image", use_column_width=True)
    with col2:
        st.image(uploaded_cytoplasm_mask, caption="Cytoplasm Mask", use_column_width=True)
    with col3:
        st.image(uploaded_nuclei_mask, caption="Nuclei Mask", use_column_width=True)

    # Convert the uploaded files to 2-D grayscale numpy arrays
    cell_image = process_uploaded_image(uploaded_cell_image)
    cytoplasm_mask = process_uploaded_image(uploaded_cytoplasm_mask)
    nuclei_mask = process_uploaded_image(uploaded_nuclei_mask)

    # The masks are multiplied element-wise with the cell image below, so all
    # three must share the same height and width. Report a mismatch to the
    # user instead of crashing with a numpy broadcasting error.
    if cytoplasm_mask.shape != cell_image.shape or nuclei_mask.shape != cell_image.shape:
        st.error(
            "Cell image and masks must have the same dimensions "
            f"(cell image: {cell_image.shape}, cytoplasm mask: {cytoplasm_mask.shape}, "
            f"nuclei mask: {nuclei_mask.shape})."
        )
        st.stop()

    # Load the pre-trained CNN (cached across reruns)
    cnn_model = _load_cnn_model()

    # Label encoder (should be the one used during training), recreated from
    # the class folder names. LabelEncoder sorts its classes, so model output
    # index i maps to the i-th name in sorted order.
    class_folders = [
        "Dyskeratotic",
        "Koilocytotic",
        "Metaplastic",
        "Parabasal",
        "Superficial-Intermediate"
    ]
    label_encoder = LabelEncoder()
    label_encoder.fit(class_folders)

    # Apply masks to the cell image: dividing an 8-bit mask by 255 maps it to
    # [0, 1], so white mask pixels keep the cell intensity and black ones zero it.
    cytoplasm_region = cell_image * (cytoplasm_mask / 255)
    nuclei_region = cell_image * (nuclei_mask / 255)

    # Stack cytoplasm and nuclei regions into a 2-channel (H, W, 2) array
    concatenated_image = np.dstack((cytoplasm_region, nuclei_region))

    # Resize to the CNN input size, then add a batch dimension -> (1, S, S, 2)
    concatenated_image_resized = cv2.resize(concatenated_image, (CNN_INPUT_SIZE, CNN_INPUT_SIZE))
    concatenated_image_resized = concatenated_image_resized.reshape(1, CNN_INPUT_SIZE, CNN_INPUT_SIZE, 2)

    # Predict and pick the most probable class for the single batch item
    predicted_probs = cnn_model.predict(concatenated_image_resized)
    predicted_class = np.argmax(predicted_probs, axis=1)[0]

    # Convert the predicted class index back to its label
    predicted_label = label_encoder.inverse_transform([predicted_class])[0]

    # Display the prediction result
    st.subheader("Prediction Results")
    st.write(f"**Predicted Cell Type:** {predicted_label}")

    # Plot the mask overlay with the predicted label as title
    plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label)
else:
    st.write("Please upload all required images.")