File size: 5,092 Bytes
7f0865b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c34df6b
7f0865b
 
 
 
 
 
 
c34df6b
7f0865b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c34df6b
7f0865b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
import torch
import segmentation_models_pytorch as smp
import numpy as np
import cv2
from PIL import Image
import os
import matplotlib.pyplot as plt
from io import BytesIO

# Streamlit has no built-in localization API; the Thai UI labels used below
# are hard-coded strings.

# Checkpoint holding a state dict for the U-Net architecture built in _load_model.
model_path = 'unet_model_statedict_resnet34andimagenet_best.pth'

# Set the device to GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Streamlit re-executes this whole script on every widget interaction.
# st.cache_resource (Streamlit >= 1.18) keeps one model per server process so
# the weights are not rebuilt and re-read from disk on each rerun. Trade-off:
# a changed checkpoint file is only picked up after a server restart.
# On older Streamlit versions this falls back to loading on every run.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_model(path, device_name):
    """Build the ResNet34 U-Net, load the state dict at *path*, move it to *device_name*.

    Args:
        path: File path of the saved ``state_dict``.
        device_name: ``str`` form of the target torch device (a plain string
            keeps the cache key simple and hashable).

    Returns:
        The model in eval mode on the requested device.
    """
    net = smp.Unet(
        encoder_name="resnet34",
        encoder_weights=None,  # We will load our own weights
        in_channels=1,
        classes=1,
    )
    # Load onto CPU first so a GPU-saved checkpoint also works on CPU-only hosts.
    net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
    net = net.to(torch.device(device_name))
    net.eval()
    return net


model = _load_model(model_path, str(device))
st.sidebar.image("smte_logo.png")

def preprocess_image(image):
    """Turn a PIL image into a model-ready input tensor.

    The picture is collapsed to one luminance channel, resized to 256x256
    pixels with OpenCV, scaled from [0, 255] to [0.0, 1.0] and given leading
    batch and channel axes.

    Args:
        image: A PIL ``Image`` in any mode accepted by ``convert("L")``.

    Returns:
        A ``float32`` tensor of shape ``(1, 1, 256, 256)`` on the module-level
        ``device``.
    """
    gray = np.array(image.convert("L"))
    resized = cv2.resize(gray, (256, 256))
    # Scale in float64 first and narrow to float32 afterwards (same rounding
    # as building a float32 tensor from the float64 array).
    scaled = torch.from_numpy(resized / 255.0)
    batched = scaled.unsqueeze(0).unsqueeze(0)
    return batched.to(device=device, dtype=torch.float32)

def get_segmentation_mask(image, threshold, net=None):
    """Run the segmentation model and binarize its output.

    Args:
        image: Input tensor, e.g. as returned by ``preprocess_image``
            (presumably a single-image batch -- ``squeeze`` below collapses
            all size-1 axes).
        threshold: Probability cut-off in [0, 1]; pixels whose predicted
            probability is strictly greater are marked as foreground.
        net: Optional model to use instead of the module-level ``model``.

    Returns:
        A ``uint8`` NumPy array with 255 for foreground and 0 for background.
    """
    net = model if net is None else net
    with torch.no_grad():
        logits = net(image)
        # The U-Net is built without an ``activation``, so it emits raw logits.
        # Convert to probabilities so ``threshold`` (a 0-1 "confidence"
        # slider) is applied on the scale it is presented on.
        probs = torch.sigmoid(logits).squeeze().cpu().numpy()
    return (probs > threshold).astype(np.uint8) * 255

def list_sample_images(directory):
    """Return the names of image files directly inside *directory*, sorted.

    Only regular files whose name ends (case-insensitively) in ``.jpg``,
    ``.jpeg`` or ``.png`` are returned. Sub-directories are skipped even if
    their name looks like an image.

    Args:
        directory: Path of the folder to scan.

    Returns:
        A sorted list of bare file names. It is empty if *directory* does not
        exist, so the app still starts without a sample folder.
    """
    valid_extensions = ('.jpg', '.jpeg', '.png')
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        return []
    return sorted(
        name
        for name in entries
        if name.lower().endswith(valid_extensions)
        and os.path.isfile(os.path.join(directory, name))
    )

# Streamlit app
st.title("Brainstroke segmentation from CT-SCAN")

# Sidebar controls ("ตัวควบคุม" = "Controls")
st.sidebar.title("ตัวควบคุม")

# Sample image selection in sidebar; "ไม่มี" ("None") means no sample chosen.
sample_images_directory = 'sample'  # Directory containing sample images
sample_images = list_sample_images(sample_images_directory)
selected_sample_image = st.sidebar.selectbox("เลือกรูปภาพตัวอย่าง", ["ไม่มี"] + sample_images)

# Image upload in sidebar ("or upload a brain CT image")
uploaded_file = st.sidebar.file_uploader("หรืออัปโหลดรูปภาพ CT สมอง", type=["jpg", "jpeg", "png"])


def _render_overlay_png(brain_image, mask, colormap, alpha):
    """Draw *mask* over the grayscale *brain_image* and return the plot as PNG bytes.

    Args:
        brain_image: 2-D grayscale array used as the background.
        mask: 2-D ``uint8`` array with 0 (background) / 255 (foreground).
        colormap: Matplotlib colormap name used to tint the foreground.
        alpha: Opacity of the tint in [0, 1].

    Returns:
        PNG-encoded bytes of the figure with axes hidden.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(brain_image, cmap='gray')
        # Mask out background pixels so only the segmented region is tinted
        # (otherwise the colormap's lowest colour covers the whole scan).
        # vmin/vmax pin 255 to the top of the colormap even when every
        # unmasked pixel has the same value.
        ax.imshow(np.ma.masked_equal(mask, 0), cmap=colormap, alpha=alpha, vmin=0, vmax=255)
        ax.axis('off')  # Hide axes
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # Streamlit rerun would leak one figure.
        plt.close(fig)
    return buf.getvalue()


# Main content area: an uploaded file takes precedence over a sample selection.
original_image = None
if uploaded_file is not None:
    original_image = Image.open(uploaded_file)
elif selected_sample_image != "ไม่มี":
    image_path = os.path.join(sample_images_directory, selected_sample_image)
    original_image = Image.open(image_path)

if original_image is not None:
    # Preprocess the image
    input_image = preprocess_image(original_image)

    # Threshold slider in sidebar ("ความมั่นใจของโมเดล" = "model confidence")
    threshold = st.sidebar.slider('ความมั่นใจของโมเดล', 0.0, 1.0, 0.5, 0.01)

    # Display original and predicted images side by side
    col1, col2 = st.columns(2)
    with col1:
        st.image(original_image, caption='รูปภาพต้นฉบับ', use_column_width=True)
    with col2:
        # Get segmentation mask
        mask = get_segmentation_mask(input_image, threshold)

        # Grayscale background resized to the mask's 256x256 resolution
        brain_image = cv2.resize(np.array(original_image.convert("L")), (256, 256))

        # Colormap selector and alpha slider in sidebar
        colormap = st.sidebar.selectbox("เลือกสีการแยกส่วน", ["Blues", "viridis", "plasma", "inferno", "magma", "cividis"])
        alpha = st.sidebar.slider("ค่าความโปร่งใสของการแยกส่วน", 0.0, 1.0, 0.7, 0.01)

        overlay_png = _render_overlay_png(brain_image, mask, colormap, alpha)
        st.image(overlay_png, caption='ภาพซ้อนกันของการแยกส่วน', use_column_width=True)

        # Download button for the segmented image ("download diagnosis result")
        st.sidebar.download_button(
            label="ดาวน์โหลดผลวินิจฉัย",
            data=overlay_png,
            file_name='segmented_image.png',
            mime='image/png',
        )