# NRRD_Viewer / app.py
# Author: vincentgao95
# Last commit: 06e3013 (verified) - "Added all 3 views, along with correct rotations for viewing"
# File size at that commit: 2.24 kB
import functools
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import nrrd
import numpy as np
from PIL import Image
# Render with the non-interactive Agg backend: figures are only ever written
# to in-memory PNG buffers, never displayed in a window.
plt.switch_backend("Agg")
@functools.lru_cache(maxsize=2)
def _read_nrrd_cached(path):
    """Read *path* with ``nrrd.read`` and return the data array, memoizing recent results.

    Every slider movement triggers a re-render, so without this cache the whole
    volume would be re-read from disk on each event. The returned array is shared
    between calls and must not be modified in place.
    NOTE(review): assumes the contents at a given upload path do not change while
    cached (each upload gets its own temporary path) - confirm for the deployed
    Gradio version.
    """
    data, _ = nrrd.read(path)
    return data


def load_nrrd(file_path):
    """Read an uploaded NRRD file and return its voxel data array.

    Args:
        file_path: The value delivered by the ``gr.File`` component. Depending on
            the Gradio version this is either a path (``str`` / ``os.PathLike``) or
            a temporary-file object exposing its path as ``.name``; both work.

    Returns:
        The data array produced by ``nrrd.read``; the header is discarded.
    """
    # Check for path-like values first: ``pathlib.Path`` also has a ``.name``
    # attribute, but there it is only the base name, not the full path.
    if isinstance(file_path, (str, os.PathLike)):
        path = os.fspath(file_path)
    else:
        path = file_path.name
    return _read_nrrd_cached(path)
def _extract_slice(data, slice_index, view_axis):
    """Return the slice of *data* at *slice_index* along *view_axis*, oriented for display.

    The index is clamped into ``[0, data.shape[view_axis] - 1]``. The rotations
    and the sagittal flip are the viewer's chosen display orientation.
    NOTE(review): whether axes 0/1/2 really are axial/coronal/sagittal depends on
    the axis order stored in the file - confirm against the expected data.
    """
    slice_index = min(max(0, slice_index), data.shape[view_axis] - 1)
    if view_axis == 0:  # Axial
        return np.rot90(data[slice_index, :, :], k=1)
    if view_axis == 1:  # Coronal
        return np.rot90(data[:, slice_index, :], k=1)
    # view_axis == 2: Sagittal
    return np.flipud(np.rot90(data[:, :, slice_index], k=1))


def _render_slice(slice_image):
    """Draw *slice_image* in grayscale without axes and return it as a PIL image (PNG)."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(slice_image, cmap='gray')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # pyplot keeps every figure alive until it is closed, so close it even
        # when drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    pil_img.load()  # decode eagerly so the returned image is fully materialised
    return pil_img


def visualize_slice(file_path, slice_index, view_axis):
    """Render one slice of an uploaded NRRD volume.

    Args:
        file_path: The ``gr.File`` value (see ``load_nrrd``), or ``None`` when no
            file is uploaded.
        slice_index: Requested slice; converted to ``int`` and clamped to the
            valid range for the chosen axis.
        view_axis: 0, 1 or 2 (labelled Axial, Coronal, Sagittal in the UI).

    Returns:
        A PIL image of the slice, or ``None`` (clears the output) when no file
        is uploaded.

    Raises:
        ValueError: If *view_axis* is not 0, 1 or 2.
    """
    if file_path is None:
        # No file (e.g. the upload was cleared): show nothing instead of crashing.
        return None
    view_axis = int(view_axis)
    if view_axis not in (0, 1, 2):
        raise ValueError(f"view_axis must be 0, 1 or 2, got {view_axis!r}")
    data = load_nrrd(file_path)
    # The slider value may arrive as a float; NumPy only accepts integer indices.
    slice_image = _extract_slice(data, int(slice_index), view_axis)
    return _render_slice(slice_image)
def update_slider(file_path, view_axis):
    """Fit the slice slider to the selected axis of the uploaded volume and centre it.

    Args:
        file_path: The ``gr.File`` value (see ``load_nrrd``), or ``None`` when no
            file is uploaded.
        view_axis: 0, 1 or 2; the axis whose length sets the slider range.

    Returns:
        A ``gr.update`` setting the slider's ``maximum`` to the last valid index
        and its ``value`` to the middle slice. With no file, the slider returns to
        its initial range (maximum 1, value 0).
    """
    if file_path is None:
        return gr.update(maximum=1, value=0)
    depth = load_nrrd(file_path).shape[int(view_axis)]
    return gr.update(maximum=depth - 1, value=depth // 2)
with gr.Blocks() as app:
    gr.Markdown("## NRRD Slice Visualizer")
    gr.Markdown("Upload an NRRD file and use the slider to select and visualize slices.")
    file_input = gr.File(label="Upload NRRD File")
    view_axis_selector = gr.Radio(choices=[0, 1, 2], value=2, label="View Axis", info="0: Axial, 1: Coronal, 2: Sagittal")
    slider = gr.Slider(minimum=0, maximum=1, step=1, value=0, label="Slice Selector")
    image_output = gr.Image(type="pil", label="Selected Slice")

    slider_inputs = [file_input, view_axis_selector]
    render_inputs = [file_input, slider, view_axis_selector]

    # Resize/centre the slider first, THEN render with its new value. Rendering
    # only via slider.change would leave a stale image whenever the centred value
    # equals the previous one (the slider value does not change, so its change
    # event may not fire), e.g. after switching to an axis of the same length.
    file_input.change(
        fn=update_slider, inputs=slider_inputs, outputs=slider
    ).then(fn=visualize_slice, inputs=render_inputs, outputs=image_output)
    view_axis_selector.change(
        fn=update_slider, inputs=slider_inputs, outputs=slider
    ).then(fn=visualize_slice, inputs=render_inputs, outputs=image_output)
    slider.change(fn=visualize_slice, inputs=render_inputs, outputs=image_output)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    app.launch()