Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor | |
| from PIL import Image | |
| from pathlib import Path | |
# Model and image processor are hosted on the Hugging Face Hub at
# "elliemci/maskformer_tumor_segmentation".
model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation")
image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation")

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: put the model in eval mode so dropout is disabled and
# batch-norm uses its running statistics (the original never called this).
model.eval()
# Minimal dataset wrapping a list of image files for batched inference.
class ImageDataset(Dataset):
    """Lazily loads RGB images from a list of file paths / file objects."""

    def __init__(self, image_paths, transform=None):
        """Store the file list and an optional per-image transform."""
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Open image *idx*, force RGB, and apply the transform when set."""
        img = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(img) if self.transform else img
def segment(image_files):
    """Run tumor segmentation over a list of uploaded image files.

    Parameters
    ----------
    image_files : list
        File paths (or file-like objects accepted by PIL) of MRI images.

    Returns
    -------
    list[torch.Tensor]
        One 2-D semantic-segmentation map per input image, upsampled to the
        image's original (height, width) by
        ``MaskFormerImageProcessor.post_process_semantic_segmentation``.
    """
    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # batch_size=1 so images of different resolutions can be processed in one
    # call: the default collate function cannot stack mixed-size tensors, so
    # the original batch_size=len(image_files) crashed on heterogeneous
    # uploads. Also drops the bogus outputs.get("org_images", ...) lookup —
    # the model output has no such key, so it always fell back to the batch.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # ToTensor preserves spatial size, so the trailing two dims of
            # each input tensor are the original (height, width) to restore.
            target_sizes = [(img.shape[-2], img.shape[-1]) for img in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes))
    return predicted_masks
# components for the Gradio interface
def update_gallery(images):
    """Build gallery entries of (PIL image, caption) pairs for gr.Gallery.

    For each uploaded file the gallery shows the original image followed by
    its predicted segmentation mask. Returns a flat list of 2-tuples.
    """
    print(f"Images received: {images}")
    gallery_data = []
    if images:
        segmented_images = segment(images)  # one mask tensor per input image
        for i, image_path in enumerate(images):
            image = None  # so the except branch can tell whether open() succeeded
            try:
                image = Image.open(image_path).convert("RGB")  # original image
                mask = segmented_images[i].to(dtype=torch.float32, device="cpu")
                segmented_image_pil = transforms.ToPILImage()(mask)
                gallery_data.extend([(image, "Original Image"),
                                     (segmented_image_pil, "Segmented Image")])
            except Exception as e:
                # The original referenced `image` here even when Image.open
                # itself raised, turning the real error into a NameError.
                # Fall back to the raw path so an entry is still rendered.
                print(f"Error processing image {i}: {e}")
                fallback = image if image is not None else image_path
                gallery_data.extend([(fallback, "Original Image"),
                                     (fallback, f"Error: {e}")])
    return gallery_data
import gradio as gr

# Gradio UI: a multi-file uploader feeding a gallery of original/segmented
# image pairs, plus bundled example images with a manual "process" button.
# NOTE(review): indentation below is reconstructed from a whitespace-mangled
# paste — confirm nesting against the deployed Space.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
    with gr.Column():
        with gr.Column():
            # type="filepath" hands update_gallery a list of file paths.
            image_files = gr.Files(label="Upload MRI files",
                                   file_count="multiple",
                                   type="filepath")
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
        # Re-run segmentation whenever the uploaded file list changes.
        image_files.change(
            fn=update_gallery,
            inputs=[image_files],
            outputs=[gallery])
        with gr.Column():
            # Hidden staging component that receives the clicked example path.
            example_image = gr.Image(type="filepath", label="MRI Image", visible=False)
            examples = gr.Examples(examples=["Te-me_0194.jpg", "Te-me_0111.jpg",
                                             "Te-me_0295.jpg", "Te-me_0228.jpg",
                                             "Te-me_0218.jpg", "Te-me_0164.jpg"],
                                   inputs=[example_image])
        with gr.Column(scale=0):
            example_button = gr.Button("Process Example Image", variant="secondary")
            # Wrap the single example path in a list so update_gallery's
            # list-of-files contract is reused; no-op when nothing selected.
            example_button.click(
                fn=lambda img: update_gallery([img]) if img else [],
                inputs=[example_image],
                outputs=[gallery]
            )

demo.launch()