fifadotjs's picture
Update app.py
c34df6b verified
import streamlit as st
import torch
import segmentation_models_pytorch as smp
import numpy as np
import cv2
from PIL import Image
import os
import matplotlib.pyplot as plt
from io import BytesIO
# Set locale to Thai
# (There is no direct i18n module in Streamlit for localization, manual translation is required)
# st.set_locale("th")
# Define the model architecture
model = smp.Unet(
encoder_name="resnet34",
encoder_weights=None, # We will load our own weights
in_channels=1,
classes=1,
)
# Load the model weights
model_path = 'unet_model_statedict_resnet34andimagenet_best.pth'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Set the device to GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
st.sidebar.image("smte_logo.png")
# Function to preprocess the uploaded image
def preprocess_image(image):
image = image.convert("L") # Convert to grayscale
image = np.array(image)
image = cv2.resize(image, (256, 256)) # Resize to 256x256
image = image / 255.0 # Normalize to [0,1]
image = image[np.newaxis, np.newaxis, :, :] # Add batch and channel dimensions
image = torch.tensor(image, dtype=torch.float32)
return image.to(device)
# Function to apply the model and get the segmentation mask
def get_segmentation_mask(image, threshold):
with torch.no_grad():
output = model(image)
output = output.squeeze().cpu().numpy()
output = (output > threshold).astype(np.uint8) * 255
return output
# Function to list sample images in a directory
def list_sample_images(directory):
valid_extensions = ['jpg', 'jpeg', 'png']
return [f for f in os.listdir(directory) if any(f.lower().endswith(ext) for ext in valid_extensions)]
# Streamlit app
st.title("Brainstroke segmentation from CT-SCAN")
# Sidebar controls
st.sidebar.title("ตัวควบคุม")
# Sample image selection in sidebar
sample_images_directory = 'sample' # Directory containing sample images
sample_images = list_sample_images(sample_images_directory)
selected_sample_image = st.sidebar.selectbox("เลือกรูปภาพตัวอย่าง", ["ไม่มี"] + sample_images)
# Image upload in sidebar
uploaded_file = st.sidebar.file_uploader("หรืออัปโหลดรูปภาพ MRI สมอง", type=["jpg", "jpeg", "png"])
# Main content area
if uploaded_file is not None:
original_image = Image.open(uploaded_file)
# st.image(original_image, caption='รูปภาพที่อัปโหลด', use_column_width=True)
elif selected_sample_image != "ไม่มี":
image_path = os.path.join(sample_images_directory, selected_sample_image)
original_image = Image.open(image_path)
# st.image(original_image, caption=f'รูปภาพตัวอย่างที่เลือก: {selected_sample_image}', use_column_width=True)
if 'original_image' in locals():
# Preprocess the image
input_image = preprocess_image(original_image)
# Threshold slider in sidebar
threshold = st.sidebar.slider('ความมั่นใจของโมเดล', 0.0, 1.0, 0.5, 0.01)
# Display original and predicted images side by side
col1, col2 = st.columns(2)
with col1:
st.image(original_image, caption='รูปภาพต้นฉบับ', use_column_width=True)
with col2:
# Get segmentation mask
mask = get_segmentation_mask(input_image, threshold)
# Convert the original image to grayscale and resize it
brain_image = original_image.convert("L")
brain_image = np.array(brain_image)
brain_image = cv2.resize(brain_image, (256, 256))
# Colormap selector and alpha slider in sidebar
colormap = st.sidebar.selectbox("เลือกสีการแยกส่วน", ["Blues", "viridis", "plasma", "inferno", "magma", "cividis"])
alpha = st.sidebar.slider("ค่าความโปร่งใสของการแยกส่วน", 0.0, 1.0, 0.7, 0.01)
# Create a plot of the brain image and overlay the mask
fig, ax = plt.subplots()
ax.imshow(brain_image, cmap='gray')
ax.imshow(mask, cmap=colormap, alpha=alpha)
ax.axis('off') # Hide axes
# Save the figure to a BytesIO object
buf = BytesIO()
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
buf.seek(0)
# Convert the BytesIO object to an image and display it
overlay_image = Image.open(buf)
st.image(overlay_image, caption='ภาพซ้อนกันของการแยกส่วน', use_column_width=True)
# Download button for the segmented image
st.sidebar.download_button(
label="ดาวน์โหลดผลวินิฉัย",
data=buf,
file_name='segmented_image.png',
mime='image/png',
)