mri_segment / app.py
elliemci's picture
change example images file paths
562a24e verified
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image
from pathlib import Path
# Load the MaskFormer segmentation model and its image processor from the
# Hugging Face Hub; both come from the same repository id.
_CHECKPOINT = "elliemci/maskformer_tumor_segmentation"
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
image_processor = MaskFormerImageProcessor.from_pretrained(_CHECKPOINT)

# Prefer a CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
# Define a custom dataset class to handle images
class ImageDataset(Dataset):
    """Map-style dataset that loads image files from disk as RGB PIL images.

    Args:
        image_paths: Sequence of image file paths (anything ``PIL.Image.open``
            accepts).
        transform: Optional callable applied to each loaded RGB image, e.g.
            ``transforms.ToTensor()``. When falsy, the PIL image is returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file explicitly: Pillow leaves the handle open after
        # loading multi-frame formats unless a context manager is used.
        # convert() returns a new, fully loaded image independent of the file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
def segment(image_files):
    """Run tumor segmentation on a list of image files.

    Args:
        image_files: Sequence of image file paths (the Gradio components in
            this app use ``type="filepath"``, so they pass path strings).

    Returns:
        A list with one predicted segmentation tensor per input, in input
        order, each resized to that image's original (height, width) by
        ``image_processor.post_process_semantic_segmentation``. Returns an
        empty list when ``image_files`` is empty or None.
    """
    if not image_files:
        # Nothing to segment; also DataLoader rejects batch_size=0.
        return []
    # NOTE(review): inputs are only ToTensor()-scaled to [0, 1]; they are not
    # resized/normalized by image_processor. Confirm this matches how the
    # checkpoint was trained.
    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One image per batch: uploaded MRIs can have different sizes, and the
    # default collate function cannot stack differently-shaped tensors.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # Upsample each prediction back to its input's (H, W).
            target_sizes = [(image.shape[-2], image.shape[-1]) for image in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes
                )
            )
    return predicted_masks
# components for Gradio interface
def update_gallery(images):
    """Build gallery entries pairing each input image with its predicted mask.

    Args:
        images: List of image file paths from the Gradio upload/example
            components, or None/empty when nothing is selected.

    Returns:
        A list of ``(image, caption)`` tuples for ``gr.Gallery``. For each
        input, the original RGB image ("Original Image") is followed by its
        mask rendered as a PIL image ("Segmented Image"). If rendering the
        mask fails, the original is shown again with an ``"Error: ..."``
        caption. Files that cannot be opened are logged and skipped.
    """
    print(f"Images received: {images}")
    if not images:
        return []
    segmented_images = segment(images)  # one mask per input, same order
    to_pil = transforms.ToPILImage()
    gallery_data = []
    for i, image_path in enumerate(images):
        try:
            with Image.open(image_path) as original:
                image = original.convert("RGB")  # Load original image
        except Exception as e:
            # Without a loaded image there is nothing to show for this entry.
            print(f"Error processing image {i}: {e}")
            continue
        try:
            segmented_mask = segmented_images[i].to(dtype=torch.float32, device="cpu")
            # NOTE(review): ToPILImage maps float tensors through x*255 -> uint8,
            # so 0/1 class ids become black/white; presumably the model is
            # binary (background/tumor) - larger ids would wrap. Confirm.
            segmented_image_pil = to_pil(segmented_mask)  # Convert to PIL Image
            gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")])
        except Exception as e:
            print(f"Error processing image {i}: {e}")
            gallery_data.extend([(image, "Original Image"), (image, f"Error: {e}")])
    return gallery_data
import gradio as gr

# Gradio UI: upload MRI files (or pick an example) and show each original
# image next to its predicted tumor mask in a single gallery.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
    with gr.Column():
        with gr.Column():
            # type="filepath": callbacks receive local file paths, not file objects.
            image_files = gr.Files(label="Upload MRI files",
                                   file_count="multiple",
                                   type="filepath")
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
        # Re-run segmentation whenever the set of uploaded files changes.
        image_files.change(
            fn=update_gallery,
            inputs=[image_files],
            outputs=[gallery])
        with gr.Column():
            # Hidden component that receives the path of a clicked example;
            # the button below sends it through update_gallery.
            example_image = gr.Image(type="filepath", label="MRI Image", visible=False)
            # Example file names are relative paths - presumably the images sit
            # next to this script in the Space; verify if examples fail to load.
            examples = gr.Examples(examples=["Te-me_0194.jpg", "Te-me_0111.jpg",
                                             "Te-me_0295.jpg", "Te-me_0228.jpg",
                                             "Te-me_0218.jpg", "Te-me_0164.jpg"],
                                   inputs=[example_image])
        with gr.Column(scale=0):
            example_button = gr.Button("Process Example Image", variant="secondary")
            # update_gallery expects a list, so wrap the single selected path;
            # with no example selected, return [] to clear the gallery.
            example_button.click(
                fn=lambda img: update_gallery([img]) if img else [],
                inputs=[example_image],
                outputs=[gallery]
            )
demo.launch()