Spaces:
Sleeping
Sleeping
Updated to allow for nrrd uploads
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import pandas as pd
|
|
| 3 |
import numpy as np
|
| 4 |
import pydicom
|
| 5 |
import os
|
|
|
|
| 6 |
from skimage import transform
|
| 7 |
import torch
|
| 8 |
from segment_anything import sam_model_registry
|
|
@@ -12,12 +13,18 @@ import torch.nn.functional as F
|
|
| 12 |
import io
|
| 13 |
from gradio_image_prompter import ImagePrompter
|
| 14 |
|
| 15 |
-
def
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
else:
|
| 20 |
-
img =
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# Convert grayscale to 3-channel RGB by replicating channels
|
| 23 |
if len(img.shape) == 2: # Grayscale image (height, width)
|
|
@@ -45,7 +52,7 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
|
|
| 45 |
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
|
| 46 |
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
|
| 47 |
multimask_output=False,
|
| 48 |
-
|
| 49 |
|
| 50 |
low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
|
| 51 |
|
|
@@ -59,7 +66,6 @@ def medsam_inference(medsam_model, img_embed, box_1024, H, W):
|
|
| 59 |
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
| 60 |
return medsam_seg
|
| 61 |
|
| 62 |
-
# Function for visualizing images with masks
|
| 63 |
def visualize(image, mask, box):
|
| 64 |
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
|
| 65 |
ax[0].imshow(image, cmap='gray')
|
|
@@ -68,30 +74,24 @@ def visualize(image, mask, box):
|
|
| 68 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
| 69 |
plt.tight_layout()
|
| 70 |
|
| 71 |
-
# Convert matplotlib figure to a PIL Image
|
| 72 |
buf = io.BytesIO()
|
| 73 |
fig.savefig(buf, format='png')
|
| 74 |
-
plt.close(fig)
|
| 75 |
buf.seek(0)
|
| 76 |
pil_img = Image.open(buf)
|
| 77 |
|
| 78 |
return pil_img
|
| 79 |
|
| 80 |
-
|
| 81 |
-
def process_images(img_dict):
|
| 82 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 83 |
|
| 84 |
-
# Load and preprocess
|
| 85 |
-
|
| 86 |
-
|
| 87 |
if len(points) >= 6:
|
| 88 |
x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
|
| 89 |
else:
|
| 90 |
raise ValueError("Insufficient data for bounding box coordinates.")
|
| 91 |
-
image, H, W = img, img.shape[0], img.shape[1] #
|
| 92 |
-
if len(image.shape) == 2:
|
| 93 |
-
image = np.repeat(image[:, :, None], 3, axis=-1)
|
| 94 |
-
H, W, _ = image.shape
|
| 95 |
|
| 96 |
image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
|
| 97 |
image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
|
|
@@ -120,16 +120,17 @@ def process_images(img_dict):
|
|
| 120 |
|
| 121 |
# Set up Gradio interface
|
| 122 |
iface = gr.Interface(
|
| 123 |
-
fn=
|
| 124 |
inputs=[
|
| 125 |
-
|
|
|
|
| 126 |
],
|
| 127 |
outputs=[
|
| 128 |
gr.Image(type="pil", label="Processed Image")
|
| 129 |
],
|
| 130 |
-
title="ROI Selection with MEDSAM",
|
| 131 |
-
description="Upload an
|
| 132 |
)
|
| 133 |
|
| 134 |
# Launch the interface
|
| 135 |
-
iface.launch()
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
import pydicom
|
| 5 |
import os
|
| 6 |
+
import nrrd
|
| 7 |
from skimage import transform
|
| 8 |
import torch
|
| 9 |
from segment_anything import sam_model_registry
|
|
|
|
| 13 |
import io
|
| 14 |
from gradio_image_prompter import ImagePrompter
|
| 15 |
|
| 16 |
+
def load_nrrd(file_path):
|
| 17 |
+
data, header = nrrd.read(file_path)
|
| 18 |
+
|
| 19 |
+
# If the data is 3D, take the middle slice
|
| 20 |
+
if len(data.shape) == 3:
|
| 21 |
+
middle_slice = data.shape[2] // 2
|
| 22 |
+
img = data[:, :, middle_slice]
|
| 23 |
else:
|
| 24 |
+
img = data
|
| 25 |
+
|
| 26 |
+
# Normalize the image to 0-255 range
|
| 27 |
+
img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
|
| 28 |
|
| 29 |
# Convert grayscale to 3-channel RGB by replicating channels
|
| 30 |
if len(img.shape) == 2: # Grayscale image (height, width)
|
|
|
|
| 52 |
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
|
| 53 |
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
|
| 54 |
multimask_output=False,
|
| 55 |
+
)
|
| 56 |
|
| 57 |
low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256)
|
| 58 |
|
|
|
|
| 66 |
medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
|
| 67 |
return medsam_seg
|
| 68 |
|
|
|
|
| 69 |
def visualize(image, mask, box):
|
| 70 |
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
|
| 71 |
ax[0].imshow(image, cmap='gray')
|
|
|
|
| 74 |
ax[1].imshow(mask, alpha=0.5, cmap="jet")
|
| 75 |
plt.tight_layout()
|
| 76 |
|
|
|
|
| 77 |
buf = io.BytesIO()
|
| 78 |
fig.savefig(buf, format='png')
|
| 79 |
+
plt.close(fig)
|
| 80 |
buf.seek(0)
|
| 81 |
pil_img = Image.open(buf)
|
| 82 |
|
| 83 |
return pil_img
|
| 84 |
|
| 85 |
+
def process_nrrd(nrrd_file, points):
|
|
|
|
| 86 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 87 |
|
| 88 |
+
# Load and preprocess NRRD file
|
| 89 |
+
image, H, W = load_nrrd(nrrd_file.name)
|
| 90 |
+
|
| 91 |
if len(points) >= 6:
|
| 92 |
x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
|
| 93 |
else:
|
| 94 |
raise ValueError("Insufficient data for bounding box coordinates.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
image_resized = transform.resize(image, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)
|
| 97 |
image_resized = (image_resized - image_resized.min()) / np.clip(image_resized.max() - image_resized.min(), a_min=1e-8, a_max=None)
|
|
|
|
| 120 |
|
| 121 |
# Set up Gradio interface
|
| 122 |
iface = gr.Interface(
|
| 123 |
+
fn=process_nrrd,
|
| 124 |
inputs=[
|
| 125 |
+
gr.File(label="NRRD File"),
|
| 126 |
+
gr.JSON(label="Bounding Box Coordinates")
|
| 127 |
],
|
| 128 |
outputs=[
|
| 129 |
gr.Image(type="pil", label="Processed Image")
|
| 130 |
],
|
| 131 |
+
title="ROI Selection with MEDSAM for NRRD Files",
|
| 132 |
+
description="Upload an NRRD file and provide bounding box coordinates for processing."
|
| 133 |
)
|
| 134 |
|
| 135 |
# Launch the interface
|
| 136 |
+
iface.launch()
|