# Spaces: Sleeping
from collections import defaultdict
from io import BytesIO

import gradio as gr
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import requests
import torch
from matplotlib import cm
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
# Single checkpoint name shared by the image processor and the model so the
# two can never drift apart.
_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = Mask2FormerForUniversalSegmentation.from_pretrained(_CHECKPOINT)
def replace(text):
    """Run panoptic segmentation on an image and return the rendered result.

    Args:
        text: The input image as a PIL image (the Gradio input component is
            declared with ``type="pil"``). The parameter name is historical
            and kept for compatibility.

    Returns:
        The PIL image produced by ``draw_panoptic_segmentation``.

    Raises:
        gr.Error: If no image was supplied (Gradio passes ``None`` when the
            form is submitted empty).
    """
    if text is None:
        raise gr.Error("Please upload an image before submitting.")

    # The model expects three-channel input; normalise the mode in case the
    # function is called with a grayscale/RGBA image outside the Gradio UI.
    image = text.convert("RGB")

    inputs = processor(image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); target_sizes wants (height, width).
    prediction = processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return draw_panoptic_segmentation(**prediction)
def draw_panoptic_segmentation(segmentation, segments_info):
    """Render a segmentation map with Matplotlib and return it as a PIL image.

    Args:
        segmentation: The segmentation map to display; it is handed straight
            to ``Axes.imshow`` (presumably a 2-D map of per-pixel segment ids
            on the CPU -- confirm against the post-processing output).
        segments_info: Per-segment metadata from the same post-processing
            call. Accepted so the function can be called with
            ``**prediction``; it is currently not rendered (no legend).

    Returns:
        PIL.Image.Image: The plotted figure, decoded from an in-memory PNG.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps
    # global "current figure" state that is not safe to share between the
    # worker threads a web UI uses, and figures it creates must be closed
    # explicitly. A bare Figure is garbage-collected like any other object.
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(segmentation)

    # Serialise to PNG in memory and hand the result back as a PIL image.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: one uploaded image in (delivered to the handler as a PIL image),
# one rendered image out.
_image_input = gr.Image(type="pil")
_ui_options = {
    "fn": replace,
    "inputs": _image_input,
    "outputs": "image",
    "title": "Image Segmentation with Mask Overlay",
    "description": "Upload an image to see the segmentation mask applied.",
}
interface = gr.Interface(**_ui_options)

# Start serving the app with debug output enabled.
interface.launch(debug=True)