"""Gradio app that segments brain tumors in MRI images with a MaskFormer model."""

from pathlib import Path
from typing import List, Optional, Sequence, Tuple

import gradio as gr
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

# Model and processor are loaded from the Hugging Face Hub repository
# "elliemci/maskformer_tumor_segmentation".
MODEL_ID = "elliemci/maskformer_tumor_segmentation"

model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_ID)
image_processor = MaskFormerImageProcessor.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class ImageDataset(Dataset):
    """Dataset that opens images as RGB and applies an optional transform."""

    def __init__(self, image_paths, transform=None):
        """Store the image sources and the transform applied on access.

        Args:
            image_paths: Sequence of items accepted by ``PIL.Image.open``.
            transform: Optional callable applied to each opened RGB image.
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open item ``idx`` as RGB and return it, transformed if configured."""
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


def segment(image_files: Sequence[str]) -> List[torch.Tensor]:
    """Run the segmentation model on each image and return one mask per image.

    Args:
        image_files: Image file paths (anything ``PIL.Image.open`` accepts).

    Returns:
        A list with one post-processed semantic-segmentation map per input,
        in input order. Empty when ``image_files`` is empty or ``None``.
    """
    # DataLoader rejects batch_size=0, so short-circuit on empty input.
    if not image_files:
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One image per batch: the default collate function stacks tensors and
    # raises when the images have different sizes, so a single batch holding
    # every upload only works if all images share the same dimensions.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predicted_masks: List[torch.Tensor] = []
    with torch.no_grad():
        for batch in dataloader:
            # NOTE(review): inputs come straight from ToTensor() and are not
            # passed through image_processor for preprocessing; kept as in the
            # original to avoid changing predictions -- confirm this matches
            # how the model was trained.
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)

            # Resize each predicted map back to the size of its input tensor.
            target_sizes = [(image.shape[-2], image.shape[-1]) for image in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes
                )
            )
    return predicted_masks


def update_gallery(images: Optional[Sequence[str]]) -> List[Tuple[Image.Image, str]]:
    """Build the gallery entries (original + segmentation) for the given images.

    Args:
        images: Image file paths, or ``None``/empty when nothing is uploaded.

    Returns:
        A list of ``(PIL image, caption)`` pairs: for each input, the original
        image followed by its segmentation. If processing an image fails after
        it was loaded, the original is shown twice, the second time captioned
        with the error. Images that cannot be opened at all are only logged.
    """
    print(f"Images received: {images}")
    gallery_data: List[Tuple[Image.Image, str]] = []
    if not images:
        return gallery_data

    to_pil = transforms.ToPILImage()
    for i, image_path in enumerate(images):
        # Reset per iteration so the error branch never reuses a previous
        # image (or hits an unbound name) when opening this one fails.
        image = None
        try:
            image = Image.open(image_path).convert("RGB")  # Load original image
            # Segment inside the try so one bad file does not abort the rest.
            segmented_mask = segment([image_path])[0].to(dtype=torch.float32, device="cpu")
            segmented_image_pil = to_pil(segmented_mask)  # Convert to PIL Image
        except Exception as e:  # UI boundary: report the failure, keep going
            print(f"Error processing image {i}: {e}")
            if image is not None:
                gallery_data.extend([(image, "Original Image"), (image, f"Error: {e}")])
        else:
            gallery_data.extend(
                [(image, "Original Image"), (segmented_image_pil, "Segmented Image")]
            )
    return gallery_data


# Components for the Gradio interface.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source -- confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("\n\nMRI Brain Tumor Segmentation App\n\n")

    with gr.Column():
        with gr.Column():
            image_files = gr.Files(
                label="Upload MRI files", file_count="multiple", type="filepath"
            )
            with gr.Row():
                gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
                image_files.change(
                    fn=update_gallery,
                    inputs=[image_files],
                    outputs=[gallery])

        with gr.Column():
            example_image = gr.Image(type="filepath", label="MRI Image", visible=False)
            examples = gr.Examples(
                examples=[
                    "Te-me_0194.jpg",
                    "Te-me_0111.jpg",
                    "Te-me_0295.jpg",
                    "Te-me_0228.jpg",
                    "Te-me_0218.jpg",
                    "Te-me_0164.jpg",
                ],
                inputs=[example_image],
            )

        with gr.Column(scale=0):
            example_button = gr.Button("Process Example Image", variant="secondary")
            example_button.click(
                fn=lambda img: update_gallery([img]) if img else [],
                inputs=[example_image],
                outputs=[gallery]
            )

demo.launch()