| | import torch |
| | import numpy as np |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | import gradio as gr |
| | from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation |
| |
|
| | |
# Hugging Face checkpoint: Mask2Former with a Swin-Small backbone trained for
# COCO instance segmentation.
MODEL_ID = "facebook/mask2former-swin-small-coco-instance"

# Load processor and weights once at import time; every request reuses them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_ID)
model.eval()  # inference mode

# Class-id -> class-name mapping taken from the model config (transformers'
# PretrainedConfig stores ``id2label`` with int keys). Always a plain dict so
# consumers can use ``.get``; an empty dict means lookups fall back to the
# numeric id.
COCO_INSTANCE_CATEGORY_NAMES = dict(getattr(model.config, "id2label", None) or {})
| |
|
def _label_name(label_id):
    """Return a human-readable class name for ``label_id``.

    ``COCO_INSTANCE_CATEGORY_NAMES`` may be a dict (int keys, as stored by
    transformers' ``PretrainedConfig.id2label``; string keys are also tried)
    or a sequence of names. Unknown ids fall back to the numeric id.
    """
    names = COCO_INSTANCE_CATEGORY_NAMES
    label_id = int(label_id)
    if isinstance(names, dict):
        return str(names.get(label_id, names.get(str(label_id), label_id)))
    if 0 <= label_id < len(names):
        return str(names[label_id])
    return str(label_id)


def segment_image(image, threshold=0.5):
    """Run Mask2Former instance segmentation on ``image`` and render the result.

    Every detected instance whose score is at least ``threshold`` is drawn as a
    semi-transparent randomly colored mask, a bounding box in the same color,
    and a ``"<label>: <score>"`` caption.

    Args:
        image: Input ``PIL.Image.Image``; converted to RGB before processing.
            ``None`` (e.g. submit pressed with no upload) returns ``None``.
        threshold: Minimum instance score to keep.

    Returns:
        Path of the saved PNG visualization (``"mask2former_output.png"`` in
        the current working directory), or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    # Grayscale/RGBA/palette inputs would break the 3-channel color blend below.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # The post-processor filters by score itself (default 0.5), so forward the
    # user's threshold or slider values below 0.5 would have no effect.
    # PIL's size is (width, height); target_sizes expects (height, width).
    results = processor.post_process_instance_segmentation(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]]
    )[0]

    segmentation = results.get("segmentation")
    if segmentation is None:
        # Defensive: nothing to draw, render the plain image.
        segmentation_map = None
        segments_info = []
    else:
        segmentation_map = segmentation.cpu().numpy()
        segments_info = results.get("segments_info") or []

    overlay = np.array(image)  # np.array copies, so the PIL image is untouched
    fig, ax = plt.subplots(1, figsize=(10, 10))
    try:
        for segment in segments_info:
            score = segment.get("score", 1.0)
            if score < threshold:
                continue

            mask = segmentation_map == segment["id"]
            y_indices, x_indices = np.nonzero(mask)
            if x_indices.size == 0:
                continue  # no pixels assigned to this segment id

            # 50/50 blend of the original pixels with a random RGB color.
            color = np.random.rand(3)
            overlay[mask] = (overlay[mask] * 0.5 + color * 255 * 0.5).astype(np.uint8)

            x1, x2 = x_indices.min(), x_indices.max()
            y1, y2 = y_indices.min(), y_indices.max()
            # imshow centres pixel (x, y) on integer coordinates: shift by half a
            # pixel and add 1 so the box encloses the outermost mask pixels.
            ax.add_patch(plt.Rectangle((x1 - 0.5, y1 - 0.5), x2 - x1 + 1, y2 - y1 + 1,
                                       fill=False, color=color, linewidth=2))
            ax.text(x1, y1, f"{_label_name(segment['label_id'])}: {score:.2f}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

        ax.imshow(overlay)
        ax.axis('off')
        output_path = "mask2former_output.png"
        # Save/close this specific figure rather than pyplot's implicit
        # "current" figure.
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    return output_path
| |
|
| |
|
| |
|
| | |
# Gradio UI: an uploaded image and a confidence slider are passed to
# segment_image; the PNG path it returns is displayed as the output image.
_input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
]
_output_component = gr.Image(type="filepath", label="Segmented Output")

interface = gr.Interface(
    fn=segment_image,
    inputs=_input_components,
    outputs=_output_component,
    title="Mask2Former Instance Segmentation (Transformer)",
    description=(
        "Upload an image to segment objects using Facebook's transformer-based "
        "Mask2Former model (Swin-Small backbone)."
    ),
)

if __name__ == "__main__":
    # NOTE(review): share=True asks Gradio for a publicly reachable link, so
    # anyone with the URL can run the model — confirm that is intended.
    interface.launch(share=True, debug=True)
| |
|