Dokkone commited on
Commit
c3e587d
·
verified ·
1 Parent(s): 23eb173

Create U-Net-Model.py

Browse files
Files changed (1) hide show
  1. pages/U-Net-Model.py +118 -0
pages/U-Net-Model.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import cv2
4
+ import tensorflow as tf
5
+ import matplotlib.pyplot as plt
6
+ from tensorflow.keras.preprocessing.image import img_to_array
7
+ from PIL import Image
8
+ import io
9
+ import zipfile
10
+
11
+ # Load multiple trained U-Net models
12
+ model_1 = tf.keras.models.load_model("pages/dyskera_unet_model.keras")
13
+ model_2 = tf.keras.models.load_model("pages/koilo_unet_model.keras")
14
+ model_3 = tf.keras.models.load_model("pages/meta_unet_model.keras")
15
+ model_4 = tf.keras.models.load_model("pages/para_unet_model.keras")
16
+ model_5 = tf.keras.models.load_model("pages/SI_unet_model.keras")
17
+ # Add more models as needed
18
+
19
+ # Function to preprocess the uploaded image
20
+ def preprocess_image(uploaded_image, target_size=(128, 128)):
21
+ image = Image.open(uploaded_image).convert('L') # Convert to grayscale
22
+ image = np.array(image) # Convert to numpy array
23
+ image = cv2.resize(image, target_size) # Resize the image
24
+ image = image.astype('float32') / 255.0 # Normalize
25
+ image = np.expand_dims(image, axis=-1) # Add channel dimension
26
+ image = np.expand_dims(image, axis=0) # Add batch dimension
27
+ return image
28
+
29
+ # Function to predict masks using the selected model
30
+ def predict_masks(image, model):
31
+ predictions = model.predict(image)
32
+ return predictions[0]
33
+
34
+ # Function to save masks as downloadable images
35
+ def save_image(mask, filename):
36
+ mask = (mask * 255).astype(np.uint8)
37
+ im = Image.fromarray(mask)
38
+ buf = io.BytesIO()
39
+ im.save(buf, format="PNG")
40
+ buf.seek(0)
41
+ return buf
42
+
43
+ # Streamlit Interface
44
+ st.title("U-Net Cell Segmentation with Multiple Models")
45
+
46
+ st.sidebar.info("Feel free to select other models from the pages above 🙂")
47
+
48
+ st.sidebar.write("""
49
+ ### Instructions 🔧
50
+ Please upload Cell Image to get Predicted Cytoplasm Mask and Nuclei Mask. After getting and downloading Predicted Cytoplasm and Nuclei Mask image, head over to CNN or SVM model to get accurate prediction of Cancer Cell Type.
51
+ """)
52
+
53
+ # Upload an image
54
+ uploaded_image = st.file_uploader("Choose an image...", type=["bmp", "png", "jpg", "jpeg"])
55
+
56
+ # Model selection dropdown
57
+ model_option = st.selectbox(
58
+ "Select a Model:",
59
+ ("U-Net Dyskeratotic", "U-Net Koilocyctotic", "U-Net Metaplastic", "U-Net Parabasal", "U-Net Superficial-Intermediate")
60
+ )
61
+
62
+ # Choose the corresponding model based on selection
63
+ if model_option == "U-Net Dyskeratotic":
64
+ selected_model = model_1
65
+ elif model_option == "U-Net Koilocyctotic":
66
+ selected_model = model_2
67
+ elif model_option == "U-Net Metaplastic":
68
+ selected_model = model_3
69
+ elif model_option == "U-Net Parabasal":
70
+ selected_model = model_4
71
+ elif model_option == "U-Net Superficial-Intermediate":
72
+ selected_model = model_5
73
+
74
+ if uploaded_image is not None:
75
+ st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
76
+
77
+ # Preprocess the image
78
+ image = preprocess_image(uploaded_image)
79
+
80
+ # Predict the masks using the selected model
81
+ predictions = predict_masks(image, selected_model)
82
+
83
+ # Display results using Matplotlib
84
+ fig, ax = plt.subplots(1, 3, figsize=(12, 8))
85
+ ax[0].imshow(image.squeeze(), cmap='gray')
86
+ ax[0].set_title('Original Image')
87
+ ax[1].imshow(predictions[:, :, 0], cmap='gray')
88
+ ax[1].set_title('Predicted Cytoplasm Mask')
89
+ ax[2].imshow(predictions[:, :, 1], cmap='gray')
90
+ ax[2].set_title('Predicted Nuclei Mask')
91
+ st.pyplot(fig)
92
+
93
+ # Convert the predicted masks to images for display
94
+ cytoplasm_mask_img = (predictions[:, :, 0] * 255).astype(np.uint8)
95
+ nuclei_mask_img = (predictions[:, :, 1] * 255).astype(np.uint8)
96
+
97
+ st.image(cytoplasm_mask_img, caption="Predicted Cytoplasm Mask", use_column_width=True)
98
+ st.image(nuclei_mask_img, caption="Predicted Nuclei Mask", use_column_width=True)
99
+
100
+ # Save the predicted masks for download
101
+ cytoplasm_mask_download = save_image(predictions[:, :, 0], "predicted_cytoplasm_mask.png")
102
+ nuclei_mask_download = save_image(predictions[:, :, 1], "predicted_nuclei_mask.png")
103
+
104
+ # Combine the two mask files into a zip
105
+ zip_buffer = io.BytesIO()
106
+ with zipfile.ZipFile(zip_buffer, "w") as zip_file:
107
+ zip_file.writestr("predicted_cytoplasm_mask.png", cytoplasm_mask_download.getvalue())
108
+ zip_file.writestr("predicted_nuclei_mask.png", nuclei_mask_download.getvalue())
109
+
110
+ zip_buffer.seek(0)
111
+
112
+ # Provide a download button for the zipped masks
113
+ st.download_button(
114
+ label="Download Predicted Cytoplasm and Nuclei Masks",
115
+ data=zip_buffer,
116
+ file_name="predicted_masks.zip",
117
+ mime="application/zip"
118
+ )