# Prototype / pages/SVM-Model.py
# Last commit by Dokkone: "Update pages/SVM-Model.py" (4b1b1d1, verified)
import streamlit as st
import joblib
import cv2
import numpy as np
from PIL import Image, ImageOps
import warnings
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
# Suppress UserWarning/FutureWarning noise so it does not clutter the app output.
# NOTE(review): warnings filters are interpreter-wide, so these presumably also
# hide warnings from the other pages served by the same Streamlit process.
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Streamlit app header and sidebar instructions shown to the user.
st.title("Cervical Cancer Cell Classification - SVM Model")
st.sidebar.info("Feel free to select other models from the pages above 🙂")
st.sidebar.write("""
### Instructions 🔧
Make sure you have predicted Cytoplasm and Nuclei Mask from U-Net-Model before proceeding. Please upload Cell Image, Cytoplasm Mask, and Nuclei Mask for accurate and precise classification of Cervical Cancer Cell Type
""")
# Function to load the pre-trained model
def load_model(model_path='pages/svm_model.pkl'):
    """Load the pre-trained SVM classifier and build its label encoder.

    Args:
        model_path: Path to the joblib-serialized model. The default is a
            relative path, resolved against the current working directory.

    Returns:
        A ``(model, label_encoder)`` tuple. The encoder is re-fitted here on
        the hard-coded class names below rather than loaded from training;
        it must match the encoder used at training time for
        ``inverse_transform`` to return the correct names.

    NOTE(review): Streamlit re-runs the page script on every interaction, so
    this reloads the model each time; consider ``st.cache_resource`` if the
    deployed Streamlit version provides it.
    """
    # joblib.load unpickles the file, so model_path must point at a trusted file.
    model = joblib.load(model_path)
    # Label encoder (should be the one used during training), recreated from
    # the class folder names. LabelEncoder sorts its classes, so encoded id i
    # maps to the i-th name in sorted order (this list is already sorted).
    class_folders = [
        "Dyskeratotic",
        "Koilocytotic",
        "Metaplastic",
        "Parabasal",
        "Superficial-Intermediate",
    ]
    label_encoder = LabelEncoder()
    label_encoder.fit(class_folders)
    return model, label_encoder
# Load model with spinner. A missing model file is reported in the page
# instead of surfacing as a raw traceback, and the rest of the page is skipped.
with st.spinner('Loading model...'):
    try:
        svm_model, label_encoder = load_model()
    except FileNotFoundError as err:
        st.error(f"Could not load the SVM model file: {err}")
        st.stop()
# Decode an uploaded image file into a grayscale pixel array.
def process_uploaded_image(uploaded_file):
    """Return the uploaded image as a grayscale NumPy array.

    Args:
        uploaded_file: Any path or file-like object ``PIL.Image.open`` accepts
            (here, a Streamlit upload).

    Returns:
        ``numpy.ndarray`` of the image converted to single-channel ("L") mode.
    """
    gray_img = ImageOps.grayscale(Image.open(uploaded_file))
    return np.array(gray_img)
# Function to make predictions
def predict_cell_type(svm_model, label_encoder, cell_image, cytoplasm_mask, nuclei_mask):
    """Classify a cell from its grayscale image and its two segmentation masks.

    Each mask is divided by 255 and used to weight the cell image, giving a
    cytoplasm region and a nuclei region. The two regions are stacked
    depth-wise, flattened into a single feature row and passed to the model.

    Args:
        svm_model: Fitted classifier exposing ``predict``.
        label_encoder: Encoder whose ``inverse_transform`` maps the model's
            predicted class ids back to class names.
        cell_image, cytoplasm_mask, nuclei_mask: Arrays that must share one
            shape (assumed 8-bit, as produced by ``process_uploaded_image``).

    Returns:
        The predicted class name.

    Raises:
        ValueError: If the three arrays do not share the same shape. The
            model itself presumably also raises ValueError when the feature
            count differs from training (scikit-learn estimators do).
    """
    if not (cell_image.shape == cytoplasm_mask.shape == nuclei_mask.shape):
        raise ValueError(
            "cell image and masks must have the same dimensions, got "
            f"{cell_image.shape}, {cytoplasm_mask.shape} and {nuclei_mask.shape}"
        )
    cytoplasm_region = cell_image * (cytoplasm_mask / 255)
    nuclei_region = cell_image * (nuclei_mask / 255)
    concatenated_image = np.dstack((cytoplasm_region, nuclei_region))
    flattened_image = concatenated_image.reshape(1, -1)
    predicted_class = svm_model.predict(flattened_image)
    return label_encoder.inverse_transform(predicted_class)[0]


# Function to plot the uploaded image with the prediction result
def plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label):
    """Render the masks as a colour overlay titled with the predicted label.

    The cytoplasm mask fills the green channel and the nuclei mask the red
    channel of an RGB image sized like ``cell_image``.

    A dedicated figure is created and always closed afterwards. Drawing on
    pyplot's global current figure would keep that figure alive across reruns
    (Streamlit does not clear a figure that is passed in explicitly) and share
    it between concurrent sessions.
    """
    color_coded_image = np.zeros((cell_image.shape[0], cell_image.shape[1], 3), dtype=np.uint8)
    color_coded_image[:, :, 1] = cytoplasm_mask  # Green for cytoplasm
    color_coded_image[:, :, 0] = nuclei_mask  # Red for nuclei
    fig, ax = plt.subplots()
    try:
        ax.imshow(color_coded_image)
        ax.set_title(f"Predicted: {predicted_label}")
        ax.axis('off')
        st.pyplot(fig)
    finally:
        plt.close(fig)


# Upload images through Streamlit
st.subheader("Upload Images")
cell_image_file = st.file_uploader("Upload Cell Image", type=['png', 'jpg', 'jpeg'])
cytoplasm_mask_file = st.file_uploader("Upload Cytoplasm Mask", type=['png', 'jpg', 'jpeg'])
nuclei_mask_file = st.file_uploader("Upload Nuclei Mask", type=['png', 'jpg', 'jpeg'])

# Only classify once all three images are uploaded
if cell_image_file and cytoplasm_mask_file and nuclei_mask_file:
    # Display the uploaded images
    st.subheader("Uploaded Images")
    col1, col2, col3 = st.columns(3)
    with col1:
        st.image(cell_image_file, caption="Cell Image", use_column_width=True)
    with col2:
        st.image(cytoplasm_mask_file, caption="Cytoplasm Mask", use_column_width=True)
    with col3:
        st.image(nuclei_mask_file, caption="Nuclei Mask", use_column_width=True)

    # Convert the uploaded files to numpy arrays
    cell_image = process_uploaded_image(cell_image_file)
    cytoplasm_mask = process_uploaded_image(cytoplasm_mask_file)
    nuclei_mask = process_uploaded_image(nuclei_mask_file)

    # Make prediction. Mismatched image sizes (or a size the model was not
    # trained on) are reported to the user instead of crashing the page.
    try:
        predicted_label = predict_cell_type(svm_model, label_encoder, cell_image, cytoplasm_mask, nuclei_mask)
    except ValueError as err:
        st.error(f"Could not classify the uploaded images: {err}")
        st.stop()

    # Display prediction results
    st.subheader("Prediction Results")
    st.write(f"**Predicted Cell Type:** {predicted_label}")

    # Plot the results
    plot_uploaded_image(cell_image, cytoplasm_mask, nuclei_mask, predicted_label)
else:
    st.write("Please upload all required images.")