# Hugging Face Spaces page header captured along with this listing (not program code):
# Spaces: Sleeping
# Sleeping
import os
import tempfile

import cv2
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
import requests
import streamlit as st
import tensorflow as tf
from skimage import measure
# Load the pre-trained model from Hugging Face.
model_url = "https://huggingface.co/chrisaldikaraharja/TumorSegmentationU-Net/resolve/main/unet_model_full.h5"


@st.cache_resource
def _load_segmentation_model(weights_path):
    """Load and compile the U-Net stored at *weights_path*, once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the .h5 model would be deserialised and recompiled on
    each click.

    Args:
        weights_path: Local path of the ``.h5`` model file.

    Returns:
        The compiled ``tf.keras`` model.
    """
    loaded = tf.keras.models.load_model(weights_path)
    # Compile the model manually (only needed for fit/evaluate; predict() works without it).
    loaded.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                   loss='binary_crossentropy',
                   metrics=[tf.keras.metrics.MeanIoU(num_classes=2)])
    return loaded


# get_file downloads on first use and reuses the locally cached copy afterwards.
model_path = tf.keras.utils.get_file("unet_model_full.h5", model_url)
model = _load_segmentation_model(model_path)
# (width, height) given to cv2.resize; the U-Net is fed 256x256 RGB images.
MODEL_INPUT_SIZE = (256, 256)

# URLs for the sample NIfTI volumes (BraTS20_Training_001) shown in the 3D view.
brain_url = "https://huggingface.co/chrisaldikaraharja/Brain3D/resolve/main/BraTS20_Training_001_flair.nii"
seg_url = "https://huggingface.co/chrisaldikaraharja/Brain3D/resolve/main/BraTS20_Training_001_seg.nii"


def _decode_rgb_image(data):
    """Decode encoded image bytes (JPEG/PNG) into an RGB ``uint8`` array.

    Returns ``None`` when the bytes are empty or OpenCV cannot decode them,
    so the caller can report the problem instead of crashing in cvtColor.
    """
    if not data:
        return None
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _predict_mask(segmentation_model, image_rgb):
    """Run the U-Net on *image_rgb*.

    Returns:
        ``(resized_image, label_mask)``: ``resized_image`` is the model-sized
        RGB input; ``label_mask`` is a 2-D integer array of per-pixel labels
        where values > 0 are treated as tumor.
    """
    resized = cv2.resize(image_rgb, MODEL_INPUT_SIZE)
    batch = np.expand_dims(resized, axis=0)  # (1, H, W, 3)
    # NOTE(review): raw 0-255 pixel values are fed unchanged; confirm this
    # matches the preprocessing used when the model was trained.
    prediction = segmentation_model.predict(batch)
    if prediction.shape[-1] == 1:
        # A single-channel (sigmoid) head makes argmax always 0, which would
        # hide every detection; threshold the probability instead.
        mask = (prediction[0, ..., 0] > 0.5).astype(np.int64)
    else:
        mask = np.argmax(prediction, axis=-1)[0]
    return resized, mask


def compute_tumor_stats(mask, pixel_spacing_mm=1.0, depth_semi_axis_mm=1.0):
    """Summarise the tumor pixels (``mask > 0``) of a 2-D label mask.

    Args:
        mask: 2-D array of per-pixel labels.
        pixel_spacing_mm: Assumed in-plane size of one pixel, in mm.
        depth_semi_axis_mm: Assumed third semi-axis of the ellipsoid, in mm.

    Returns:
        ``(centroid, bbox, volume_cm3)`` where

        * ``centroid`` is the mean ``(x, y)`` pixel position (x = column,
          y = row), or ``None`` when no tumor pixel is present;
        * ``bbox`` is the inclusive ``(min_x, min_y, max_x, max_y)`` pixel
          box, all zeros when empty;
        * ``volume_cm3`` is a rough ellipsoid estimate whose in-plane
          semi-axes are half the bounding-box extents. The true pixel
          spacing of the uploaded image is unknown, so the defaults
          (1 mm/pixel, 1 mm depth) make this indicative only.
    """
    rows, cols = np.nonzero(np.asarray(mask) > 0)
    if rows.size == 0:
        return None, (0, 0, 0, 0), 0.0

    centroid = (float(cols.mean()), float(rows.mean()))
    min_x, max_x = int(cols.min()), int(cols.max())
    min_y, max_y = int(rows.min()), int(rows.max())

    # An inclusive pixel range [lo, hi] spans hi - lo + 1 pixels.
    semi_x_mm = (max_x - min_x + 1) / 2 * pixel_spacing_mm
    semi_y_mm = (max_y - min_y + 1) / 2 * pixel_spacing_mm
    volume_mm3 = (4 / 3) * np.pi * semi_x_mm * semi_y_mm * depth_semi_axis_mm
    return centroid, (min_x, min_y, max_x, max_y), float(volume_mm3 / 1000)  # mm³ -> cm³


def _make_overlay(image_rgb, mask):
    """Blend a JET heat-map of *mask* 50/50 over *image_rgb* (same height/width)."""
    # Scale labels into 0-255 by the largest label present; a plain
    # `mask * 255` wraps around in uint8 for any label above 1.
    peak = max(int(mask.max()), 1)
    intensity = (mask.astype(np.float64) * (255.0 / peak)).astype(np.uint8)
    heatmap = cv2.applyColorMap(intensity, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # applyColorMap yields BGR
    return cv2.addWeighted(image_rgb, 0.5, heatmap, 0.5, 0)


def _plot_results(image_rgb, overlay_rgb, mask):
    """Return a 1x3 matplotlib figure: input image, overlay, and predicted mask."""
    fig, (ax_input, ax_overlay, ax_mask) = plt.subplots(figsize=(15, 10), ncols=3)

    ax_input.set_title('Original MRI Image')
    ax_input.axis('off')
    ax_input.imshow(image_rgb)

    ax_overlay.set_title('MRI with Predicted Tumor Region Highlighted')
    ax_overlay.axis('off')
    ax_overlay.imshow(overlay_rgb)

    ax_mask.set_title('Predicted Tumor Region')
    ax_mask.axis('off')
    mask_artist = ax_mask.imshow(mask, cmap='magma')
    fig.colorbar(mask_artist, ax=ax_mask, fraction=0.046, pad=0.04)
    return fig


def download_file(url, timeout=60):
    """Download *url* into a new temporary ``.nii`` file and return its path.

    The file is not deleted automatically; the caller must remove it.

    Raises:
        requests.HTTPError: the server answered with an error status (rather
            than silently saving an HTML error page as a ``.nii`` file).
        requests.Timeout: no response within *timeout* seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with tempfile.NamedTemporaryFile(delete=False, suffix=".nii") as temp_file:
        temp_file.write(response.content)
    return temp_file.name


def _load_nifti_volume(url):
    """Download a NIfTI file, read its voxel data into memory, then delete the file."""
    path = download_file(url)
    try:
        # mmap=False reads the voxels now, so the temp file can be removed safely.
        return nib.load(path, mmap=False).get_fdata()
    finally:
        os.remove(path)


@st.cache_data
def _compute_demo_meshes(brain_source_url, seg_source_url):
    """Return ``((brain_verts, brain_faces), (tumor_verts, tumor_faces))``.

    Cached by Streamlit, so the two downloads and marching-cubes passes run
    once per URL pair instead of on every rerun.
    """
    brain_volume = _load_nifti_volume(brain_source_url)
    seg_volume = _load_nifti_volume(seg_source_url)
    brain_verts, brain_faces, _, _ = measure.marching_cubes(brain_volume, 0)
    tumor_verts, tumor_faces, _, _ = measure.marching_cubes(seg_volume, 2)
    return (brain_verts, brain_faces), (tumor_verts, tumor_faces)


def _scene_axis(title):
    """Style for one 3-D scene axis: light-blue back-plane, black title and ticks."""
    # `title.font` replaces the deprecated `titlefont` attribute, which newer
    # Plotly releases no longer accept.
    return dict(
        title=dict(text=title, font=dict(color='black')),
        backgroundcolor="lightblue",
        tickfont=dict(color='black'),
    )


def _build_3d_figure():
    """Build the Plotly figure: semi-transparent brain (pink) and tumor (brown) meshes."""
    (brain_verts, brain_faces), (tumor_verts, tumor_faces) = _compute_demo_meshes(brain_url, seg_url)

    def as_mesh(verts, faces, color):
        x, y, z = verts.T
        i, j, k = faces.T
        return go.Mesh3d(x=x, y=y, z=z, color=color, opacity=0.5, i=i, j=j, k=k)

    bfig = go.Figure(data=[as_mesh(brain_verts, brain_faces, 'pink'),
                           as_mesh(tumor_verts, tumor_faces, 'brown')])
    bfig.update_layout(
        scene=dict(xaxis=_scene_axis('X'), yaxis=_scene_axis('Y'), zaxis=_scene_axis('Z')),
        paper_bgcolor="white",  # outer background
        plot_bgcolor="white",
    )
    return bfig


def main():
    """Streamlit page: upload an MRI slice, segment it, and show the results."""
    st.title("MRI Tumor Segmentation and Analysis")

    # Step 1: Upload MRI image
    uploaded_file = st.file_uploader("Upload an MRI image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the whole upload regardless of the stream position.
    test_img = _decode_rgb_image(uploaded_file.getvalue())
    if test_img is None:
        st.error("Could not decode the uploaded file as an image.")
        return
    st.image(test_img, caption='Uploaded MRI Image', use_column_width=True)

    if not st.button("Analyze"):
        return

    # Steps 2-4: predict the mask, then derive centroid and size estimate.
    resized_img, prediction_mask = _predict_mask(model, test_img)
    centroid, _bbox, tumor_volume_cm3 = compute_tumor_stats(prediction_mask)

    # Steps 5-6: heat-map overlay and side-by-side plots.
    fig = _plot_results(resized_img, _make_overlay(resized_img, prediction_mask), prediction_mask)
    st.pyplot(fig)
    plt.close(fig)  # release the figure; reruns would otherwise accumulate them

    if centroid is not None:
        st.write(f'Tumor centroid coordinates (x, y): ({centroid[0]:.2f}, {centroid[1]:.2f})')
    else:
        st.write('No tumor detected.')
    st.write(f'Estimated tumor volume: {tumor_volume_cm3:.2f} cm³')

    # Step 7: 3D interactive visualization with Plotly.
    # NOTE(review): the source listing lost its indentation; this section is
    # assumed to belong to the Analyze flow (it continues the step numbering).
    st.subheader("3D Interactive Visualization of Tumor")
    st.plotly_chart(_build_3d_figure(), use_container_width=True)


# Streamlit executes the script with __name__ == "__main__".
if __name__ == "__main__":
    main()