File size: 4,135 Bytes
cf28cda
c5f86e6
7011182
e269365
 
 
 
 
 
 
25aa730
 
 
 
 
 
56c0430
25aa730
e269365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1e6f72
 
 
 
 
e269365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1e6f72
e269365
 
 
 
 
a2897bd
e269365
 
 
 
f1e6f72
4c5ed91
e269365
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import numpy as np
import cv2
from tensorflow.keras.models import load_model
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from PIL import Image
import io

# Streamlit app header: page title plus sidebar guidance shown on every rerun.
st.title("Cervical Cancer Cell Classification - CNN Model")

st.sidebar.info("Feel free to select other models from the pages above 🙂")

# Sidebar instructions (Markdown). The classifier below needs the cell image
# together with its cytoplasm and nuclei masks.
st.sidebar.write("""
### Instructions 🔧
Make sure you have predicted the Cytoplasm and Nuclei Masks with the U-Net model before proceeding. Please upload the Cell Image, Cytoplasm Mask, and Nuclei Mask for accurate and precise classification of the Cervical Cancer Cell Type.
""")

# Function to process uploaded images
def process_uploaded_image(uploaded_file):
    """Decode an uploaded image file into a 2-D grayscale numpy array.

    Args:
        uploaded_file: A file-like object (e.g. a Streamlit ``UploadedFile``)
            or anything else ``PIL.Image.open`` accepts.

    Returns:
        numpy.ndarray: The image converted to Pillow mode ``'L'`` (8-bit
        grayscale), with shape ``(rows, cols)``.
    """
    # Rewind first. If anything has already read from this stream, Image.open
    # would otherwise start mid-file and fail to identify the image format.
    if hasattr(uploaded_file, "seek"):
        uploaded_file.seek(0)
    with Image.open(uploaded_file) as image:
        # convert() forces the lazy decode and returns a new image object,
        # so the result stays valid after the source image is closed.
        grayscale = image.convert('L')
    return np.array(grayscale)

# Upload images through Streamlit
st.subheader("Upload Images")
uploaded_cell_image = st.file_uploader("Upload Cell Image", type=['png', 'jpg', 'jpeg'])
uploaded_cytoplasm_mask = st.file_uploader("Upload Cytoplasm Mask", type=['png', 'jpg', 'jpeg'])
uploaded_nuclei_mask = st.file_uploader("Upload Nuclei Mask", type=['png', 'jpg', 'jpeg'])


def _match_mask_to_image(mask, image_shape):
    """Return ``mask`` resized to ``image_shape`` ``(rows, cols)`` if it differs.

    Nearest-neighbour interpolation is used so that resizing does not invent
    intermediate values between the mask's existing levels.
    """
    if mask.shape == image_shape:
        return mask
    rows, cols = image_shape
    # cv2.resize takes dsize as (width, height), the reverse of numpy's shape order.
    return cv2.resize(mask, (cols, rows), interpolation=cv2.INTER_NEAREST)


def plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label):
    """Render the masks as an RGB overlay titled with the predicted label.

    The cytoplasm mask goes in the green channel and the nuclei mask in the red
    channel. ``cell_image`` is used only for the output size, so both masks
    must have the same ``(rows, cols)`` shape as it.
    """
    color_coded_image = np.zeros((cell_image.shape[0], cell_image.shape[1], 3), dtype=np.uint8)
    color_coded_image[:, :, 1] = cytoplasm_mask  # Green for cytoplasm
    color_coded_image[:, :, 0] = nuclei_mask     # Red for nuclei

    # Draw on a dedicated figure instead of pyplot's global state. Passing the
    # `plt` module to st.pyplot is deprecated. Closing the figure afterwards
    # stops figures from piling up across Streamlit reruns.
    fig, ax = plt.subplots()
    ax.imshow(color_coded_image)
    ax.set_title(f"Predicted: {predicted_label}")
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)


# Check if all images are uploaded
if uploaded_cell_image and uploaded_cytoplasm_mask and uploaded_nuclei_mask:
    # Display the uploaded images
    st.subheader("Uploaded Images")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.image(uploaded_cell_image, caption="Cell Image", use_column_width=True)
    with col2:
        st.image(uploaded_cytoplasm_mask, caption="Cytoplasm Mask", use_column_width=True)
    with col3:
        st.image(uploaded_nuclei_mask, caption="Nuclei Mask", use_column_width=True)

    # Convert the uploaded files to 2-D grayscale numpy arrays
    cell_image = process_uploaded_image(uploaded_cell_image)
    cytoplasm_mask = process_uploaded_image(uploaded_cytoplasm_mask)
    nuclei_mask = process_uploaded_image(uploaded_nuclei_mask)

    # The element-wise masking and the overlay plot both need masks with the
    # same (rows, cols) as the cell image. Otherwise numpy raises a
    # broadcasting error, so align mismatched masks and tell the user.
    if cytoplasm_mask.shape != cell_image.shape or nuclei_mask.shape != cell_image.shape:
        st.warning("Mask dimensions differ from the cell image; masks were resized to match.")
        cytoplasm_mask = _match_mask_to_image(cytoplasm_mask, cell_image.shape)
        nuclei_mask = _match_mask_to_image(nuclei_mask, cell_image.shape)

    # Load your pre-trained CNN model
    # Replace the path with your actual saved model
    # NOTE(review): the model is reloaded on every rerun; consider caching it
    # (e.g. st.cache_resource) if the deployed Streamlit version supports it.
    cnn_model = load_model('pages/cnn_model.h5')

    # Label encoder (should be the one used during training)
    # If saved, load it or recreate it from class folders
    class_folders = [
        "Dyskeratotic",
        "Koilocytotic",
        "Metaplastic",
        "Parabasal",
        "Superficial-Intermediate"
    ]

    label_encoder = LabelEncoder()
    label_encoder.fit(class_folders)

    # Apply masks to the cell image (mask values are scaled by 1/255)
    cytoplasm_region = cell_image * (cytoplasm_mask / 255)
    nuclei_region = cell_image * (nuclei_mask / 255)

    # Stack the cytoplasm and nuclei regions as a 2-channel input
    concatenated_image = np.dstack((cytoplasm_region, nuclei_region))

    # Resize to match the input shape of the CNN model (128x128)
    concatenated_image_resized = cv2.resize(concatenated_image, (128, 128))
    concatenated_image_resized = concatenated_image_resized.reshape(1, 128, 128, 2)  # Add batch dimension

    # Make a prediction
    predicted_probs = cnn_model.predict(concatenated_image_resized)
    predicted_class = np.argmax(predicted_probs, axis=1)[0]

    # Convert predicted class index back to its label
    predicted_label = label_encoder.inverse_transform([predicted_class])[0]

    # Display the prediction result
    st.subheader("Prediction Results")
    st.write(f"**Predicted Cell Type:** {predicted_label}")

    # Plot the results
    plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label)
else:
    st.write("Please upload all required images.")