import torch from torchvision import transforms from torch.utils.data import DataLoader, Dataset from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor from PIL import Image from pathlib import Path # can upload from Huggingface Space "elliemci/maskformer_tumor_segmentation" model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation") image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # Define a custom dataset class to handle images class ImageDataset(Dataset): def __init__(self, image_paths, transform=None): self.image_paths = image_paths self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx): image = Image.open(self.image_paths[idx]).convert('RGB') if self.transform: image = self.transform(image) return image def segment(image_files): """Takes a list of UploadedFile objects and returns a list of segmented images.""" dataset = ImageDataset(image_files, transform=transforms.ToTensor()) dataloader = DataLoader(dataset, batch_size=len(image_files), shuffle=False) # Batch size is the number of images # process a batch with torch.no_grad(): for batch in dataloader: # Only one iteration since batch_size = len(image_files) pixel_values = batch.to(device, dtype=torch.float32) outputs = model(pixel_values=pixel_values) # Post-processing original_images = outputs.get("org_images", batch) target_sizes = [(image.shape[-2], image.shape[-1]) for image in original_images] predicted_masks = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes) return predicted_masks # Return the list of segmented images # components for Gradion interface def update_gallery(images): print(f"Images received: {images}") gallery_data = [] if images: segmented_images = segment(images) # Process images for i, image_path in enumerate(images): try: image = Image.open(image_path).convert("RGB") # Load original image segmented_mask = segmented_images[i].to(dtype=torch.float32, device="cpu") segmented_image_pil = transforms.ToPILImage()(segmented_mask) # Convert to PIL Image gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")]) except Exception as e: print(f"Error processing image {i}: {e}") gallery_data.extend([(image, "Original Image"), (image, f"Error: {e}")]) return gallery_data import gradio as gr with gr.Blocks() as demo: gr.Markdown("