File size: 4,172 Bytes
0ebdb2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8407bd
0ebdb2c
e8407bd
0ebdb2c
e8407bd
 
 
 
 
 
 
0ebdb2c
e8407bd
 
 
 
 
 
0ebdb2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8407bd
 
 
562a24e
 
 
e8407bd
 
 
 
 
 
 
 
 
 
0ebdb2c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image
from pathlib import Path

# Load the fine-tuned MaskFormer checkpoint and its matching image processor
# from the Hugging Face Hub ("elliemci/maskformer_tumor_segmentation").
# NOTE(review): this downloads weights at import time — requires network access.
model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation")
image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation")

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Map-style dataset that lazily loads images from disk by path.
class ImageDataset(Dataset):
    """Dataset over a list of image file paths.

    Each item is opened with PIL, converted to RGB, and passed through the
    optional ``transform`` callable before being returned.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        # Apply the transform only when one was supplied.
        return self.transform(sample) if self.transform else sample

def segment(image_files):
    """Segment a batch of MRI images with the MaskFormer model.

    Args:
        image_files: list of image file paths (or any objects PIL can open).

    Returns:
        A list with one semantic-segmentation mask tensor per input image,
        or an empty list when ``image_files`` is empty.
    """
    if not image_files:
        # DataLoader rejects batch_size=0, and the original implicitly
        # returned None here; an empty list is the safe, consistent result.
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One batch containing every image; shuffle disabled so output order
    # matches input order.
    dataloader = DataLoader(dataset, batch_size=len(image_files), shuffle=False)

    with torch.no_grad():
        for batch in dataloader:  # single iteration: batch_size == len(image_files)
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)

            # Resize each predicted mask back to its input's (H, W).
            # The original probed `outputs` for an "org_images" key with a
            # fallback to `batch`; the model output never carries the input
            # images, so use the batch tensors directly.
            target_sizes = [(img.shape[-2], img.shape[-1]) for img in batch]
            return image_processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes)

# Callback for the Gradio interface: builds (image, caption) pairs for a Gallery.
def update_gallery(images):
    """Pair each uploaded image with its segmentation for gallery display.

    Args:
        images: list of image file paths from the upload widget (may be
            None or empty when nothing is uploaded).

    Returns:
        List of (image, caption) tuples — original and segmented side by
        side; on a per-image failure the error text appears in the caption.
    """
    print(f"Images received: {images}")

    gallery_data = []
    if images:
        segmented_images = segment(images)  # Process images

        for i, image_path in enumerate(images):
            image = None  # lets the except block know whether loading succeeded
            try:
                image = Image.open(image_path).convert("RGB")  # Load original image

                # Masks come back as class-id tensors; cast to float on CPU
                # so ToPILImage can render them.
                segmented_mask = segmented_images[i].to(dtype=torch.float32, device="cpu")
                segmented_image_pil = transforms.ToPILImage()(segmented_mask)

                gallery_data.extend([(image, "Original Image"),
                                     (segmented_image_pil, "Segmented Image")])
            except Exception as e:
                print(f"Error processing image {i}: {e}")
                # BUG FIX: if Image.open itself raised, `image` was unbound in
                # the original and this handler crashed with NameError. Fall
                # back to the raw path so the gallery still shows the entry.
                fallback = image if image is not None else image_path
                gallery_data.extend([(fallback, "Original Image"),
                                     (fallback, f"Error: {e}")])

    return gallery_data

# Gradio is imported here, just before the UI is built (kept mid-file to
# preserve the original script layout).
import gradio as gr

# --- UI layout: multi-file upload -> gallery of (original, segmented) pairs ---
with gr.Blocks() as demo:
  gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")

  with gr.Column():
    with gr.Column():
      # "filepath" makes the widget hand update_gallery plain path strings.
      image_files = gr.Files(label="Upload MRI files",
                                     file_count="multiple",
                                     type="filepath")
      with gr.Row():
        gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")

        # Re-run segmentation whenever the uploaded file list changes.
        image_files.change(
             fn=update_gallery,
             inputs=[image_files],
             outputs=[gallery])
        
      with gr.Column():
        # Hidden image component that holds the selected example's file path.
        example_image = gr.Image(type="filepath", label="MRI Image", visible=False)
        # NOTE(review): these example files are assumed to sit next to this
        # script — confirm they ship alongside it.
        examples = gr.Examples(examples=["Te-me_0194.jpg", "Te-me_0111.jpg", 
                                         "Te-me_0295.jpg", "Te-me_0228.jpg",
                                         "Te-me_0218.jpg", "Te-me_0164.jpg"], 
                               inputs=[example_image])

      with gr.Column(scale=0):

        # Run the selected example through the same pipeline; no-op (empty
        # gallery) when no example has been chosen yet.
        example_button = gr.Button("Process Example Image", variant="secondary")
        example_button.click(
            fn=lambda img: update_gallery([img]) if img else [],
            inputs=[example_image],  
            outputs=[gallery]
        )

demo.launch()