# Btsegmentation / app.py
# Hugging Face Space by chrisaldikaraharja — brain-tumor segmentation demo
# (commit 561ae97, "Update app.py").
import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
import requests
import streamlit as st
import tensorflow as tf
from skimage import measure
# Load the pre-trained U-Net from the Hugging Face Hub.
# NOTE(review): Streamlit re-executes this script on every user interaction;
# caching the loaded model avoids re-reading the .h5 on each rerun.
@st.cache_resource
def _load_model():
    """Download (cached by Keras) and load the tumor-segmentation U-Net.

    Returns a compiled tf.keras model; compilation is done manually so the
    MeanIoU metric is attached even if the saved model carried no
    optimizer/metric configuration.
    """
    model_url = (
        "https://huggingface.co/chrisaldikaraharja/TumorSegmentationU-Net/"
        "resolve/main/unet_model_full.h5"
    )
    model_path = tf.keras.utils.get_file("unet_model_full.h5", model_url)
    loaded = tf.keras.models.load_model(model_path)
    loaded.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
        loss='binary_crossentropy',
        metrics=[tf.keras.metrics.MeanIoU(num_classes=2)],
    )
    return loaded


model = _load_model()
st.title("MRI Tumor Segmentation and Analysis")

# Step 1: let the user upload a 2D MRI slice as an ordinary image file.
uploaded_file = st.file_uploader("Upload an MRI image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Decode the uploaded bytes; OpenCV decodes as BGR, so convert to RGB
    # for correct on-screen colors.
    test_img = cv2.imdecode(np.frombuffer(uploaded_file.read(), np.uint8), cv2.IMREAD_COLOR)
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
    st.image(test_img, caption='Uploaded MRI Image', use_column_width=True)

    if st.button("Analyze"):
        # Resize to the model's 256x256 input and add a batch dimension.
        test_img_resized = cv2.resize(test_img, (256, 256))
        test_img_input = np.expand_dims(test_img_resized, axis=0)  # (1, 256, 256, 3)

        # Step 2: run the U-Net on the single-image batch.
        test_pred1 = model.predict(test_img_input)
        # NOTE(review): argmax over the channel axis assumes a multi-channel
        # softmax-style output; with a single-channel sigmoid head this is
        # always zero — confirm against the model's actual output shape.
        test_prediction1 = np.argmax(test_pred1, axis=3)[0, :, :]  # (256, 256)

        # Step 3: coordinates of predicted tumor pixels. np.argwhere yields
        # (row, col) pairs, i.e. (y, x) in image terms.
        tumor_coordinates = np.argwhere(test_prediction1 > 0)
        has_tumor = len(tumor_coordinates) > 0
        centroid = tumor_coordinates.mean(axis=0) if has_tumor else None

        # Bounding box of the predicted region (rows first, then columns).
        if has_tumor:
            min_row, min_col = tumor_coordinates.min(axis=0)
            max_row, max_col = tumor_coordinates.max(axis=0)
        else:
            min_row = min_col = max_row = max_col = 0

        # Step 4: crude ellipsoid approximation of tumor size.
        # NOTE(review): the semi-axes are measured in *pixels*, not mm — the
        # mm³→cm³ conversion below is only meaningful if 1 px ≈ 1 mm; verify.
        if has_tumor:
            a = (max_row - min_row) / 2  # semi-axis along image rows
            b = (max_col - min_col) / 2  # semi-axis along image columns
            c = 1.0                      # assumed through-plane depth
            tumor_volume = (4 / 3) * np.pi * a * b * c
            tumor_volume_cm3 = tumor_volume / 1000  # mm³ -> cm³
        else:
            tumor_volume_cm3 = 0

        # Step 5: JET heatmap of the binary mask blended over the input.
        heatmap = cv2.applyColorMap(np.uint8(test_prediction1 * 255), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        overlay_img = cv2.addWeighted(test_img_resized, 0.5, heatmap, 0.5, 0)

        # Step 6: original | overlay | raw mask, side by side.
        fig, (ax1, ax2, ax3) = plt.subplots(figsize=(15, 10), ncols=3)
        ax1.set_title('Original MRI Image')
        ax1.axis('off')
        ax1.imshow(test_img_resized)
        ax2.set_title('MRI with Predicted Tumor Region Highlighted')
        ax2.axis('off')
        ax2.imshow(overlay_img)
        ax3.set_title('Predicted Tumor Region')
        ax3.axis('off')
        pred = ax3.imshow(test_prediction1, cmap='magma')
        fig.colorbar(pred, ax=ax3, fraction=0.046, pad=0.04)
        st.pyplot(fig)

        # Report the numbers. centroid is (row, col); the components are
        # swapped here so the "(x, y)" label matches the values shown
        # (the original printed row as x and column as y).
        if centroid is not None:
            st.write(f'Tumor centroid coordinates (x, y): ({centroid[1]:.2f}, {centroid[0]:.2f})')
        else:
            st.write('No tumor detected.')
        st.write(f'Estimated tumor volume: {tumor_volume_cm3:.2f} cm³')

        # Step 7: interactive 3D view of a sample BraTS volume.
        st.subheader("3D Interactive Visualization of Tumor")

        def download_file(url):
            """Download *url* to a temporary .nii file and return its path.

            The caller is responsible for removing the file when done.
            """
            response = requests.get(url, timeout=60)
            # Fail loudly instead of silently writing an HTML error page
            # into the .nii file.
            response.raise_for_status()
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".nii")
            temp_file.write(response.content)
            temp_file.close()
            return temp_file.name

        # Sample FLAIR volume + segmentation hosted on the Hugging Face Hub.
        brain_url = "https://huggingface.co/chrisaldikaraharja/Brain3D/resolve/main/BraTS20_Training_001_flair.nii"
        seg_url = "https://huggingface.co/chrisaldikaraharja/Brain3D/resolve/main/BraTS20_Training_001_seg.nii"

        brain_path = download_file(brain_url)
        seg_path = download_file(seg_url)
        try:
            # get_fdata() reads the whole array into memory, so the temp
            # files can be deleted immediately afterwards.
            im = nib.load(brain_path).get_fdata()
            seg = nib.load(seg_path).get_fdata()
        finally:
            # The original leaked one pair of temp files per click.
            os.remove(brain_path)
            os.remove(seg_path)

        # Isosurface of the whole brain at intensity level 0.
        verts, faces, normals, values = measure.marching_cubes(im, 0)
        x, y, z = verts.T
        i, j, k = faces.T
        mesh1 = go.Mesh3d(x=x, y=y, z=z, color='pink', opacity=0.5, i=i, j=j, k=k)

        # Isosurface of the segmentation at label level 2 (tumor region).
        verts, faces, normals, values = measure.marching_cubes(seg, 2)
        x, y, z = verts.T
        i, j, k = faces.T
        mesh2 = go.Mesh3d(x=x, y=y, z=z, color='brown', opacity=0.5, i=i, j=j, k=k)

        bfig = go.Figure(data=[mesh1, mesh2])
        # Light-blue axis planes with black labels on a white canvas.
        # NOTE(review): `titlefont` is deprecated in recent Plotly releases
        # (newer API: xaxis=dict(title=dict(font=...))); kept for
        # compatibility with the Plotly version this Space pins.
        bfig.update_layout(
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                xaxis=dict(
                    backgroundcolor="lightblue",
                    titlefont=dict(color='black'),
                    tickfont=dict(color='black')
                ),
                yaxis=dict(
                    backgroundcolor="lightblue",
                    titlefont=dict(color='black'),
                    tickfont=dict(color='black')
                ),
                zaxis=dict(
                    backgroundcolor="lightblue",
                    titlefont=dict(color='black'),
                    tickfont=dict(color='black')
                ),
            ),
            paper_bgcolor="white",
            plot_bgcolor="white"
        )
        st.plotly_chart(bfig, use_container_width=True)