| | import io |
| |
|
| | import cv2 |
| | import gradio as gr |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import requests |
| | import torch |
| | from PIL import Image |
| | from transformers import AutoImageProcessor, EomtForUniversalSegmentation |
| |
|
| | |
# Load the EoMT panoptic-segmentation checkpoint once at import time so that
# every request served by the app reuses the same processor and model.
print("Loading model...")
model_id = "tue-mps/coco_panoptic_eomt_large_640"
processor = AutoImageProcessor.from_pretrained(model_id)

# Use the GPU when torch reports one as available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EomtForUniversalSegmentation.from_pretrained(model_id).to(device)
print(f"Model loaded successfully on {device}!")
| |
|
| |
|
def run_inference(image):
    """Segment ``image`` with the module-level model and processor.

    The image is preprocessed by ``processor``, each resulting tensor is moved
    to ``device``, and the forward pass runs under ``torch.inference_mode()``.
    The raw outputs are then post-processed at the input image's resolution.

    Args:
        image: Input image exposing ``height`` and ``width`` attributes
            (the app passes PIL images).

    Returns:
        The first entry of the list returned by
        ``processor.post_process_panoptic_segmentation``.
    """
    batch = processor(images=image, return_tensors="pt")
    batch_on_device = {name: value.to(device) for name, value in batch.items()}

    with torch.inference_mode():
        raw_outputs = model(**batch_on_device)

    # Post-processing takes one (height, width) pair per input image.
    original_size = (image.height, image.width)
    results = processor.post_process_panoptic_segmentation(
        raw_outputs, target_sizes=[original_size]
    )
    return results[0]
| |
|
| |
|
def tensor_to_numpy(tensor):
    """Return ``tensor`` as a NumPy array when it is a ``torch.Tensor``.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion, so CUDA tensors and tensors that require grad both convert
    without raising. Any other object is returned unchanged.

    Args:
        tensor: A ``torch.Tensor`` or any other object.

    Returns:
        A NumPy array for tensor input; otherwise ``tensor`` itself.
    """
    if isinstance(tensor, torch.Tensor):
        # ``.numpy()`` refuses tensors that require grad, hence ``detach()``.
        return tensor.detach().cpu().numpy()
    return tensor
| |
|
| |
|
def visualize_mask(image, segmentation_mask):
    """Render the segmentation mask on its own, colour-coded with ``tab20``.

    When the mask has no value above zero, a blank grey canvas with a
    "No segments detected" message is drawn instead.

    Args:
        image: Original input image. Not used by this renderer; accepted so the
            call signature is ``(image, segmentation_mask)``.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    # Convert before creating the figure so a conversion error cannot leave an
    # open figure behind.
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        if mask_np.max() > 0:
            ax.imshow(mask_np, cmap="tab20")
        else:
            ax.imshow(np.zeros_like(mask_np), cmap="gray")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_overlay(image, segmentation_mask):
    """Draw the segmentation mask semi-transparently over the original image.

    Args:
        image: Original input image, shown as the background layer.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy`` and drawn with the ``tab20`` colormap at
            60% opacity.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    # Convert before creating the figure so a conversion error cannot leave an
    # open figure behind.
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        ax.imshow(image)
        ax.imshow(mask_np, cmap="tab20", alpha=0.6)
        ax.axis("off")
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_contours(image, segmentation_mask):
    """Outline every non-zero segment on top of the original image.

    Each segment id other than 0 is turned into a binary mask, its external
    contours are extracted with OpenCV, and every contour with more than two
    points is drawn as a closed line in a per-segment ``tab20`` colour.

    Args:
        image: Original input image, shown as the background layer.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    # Keep the label mask in its native dtype: casting ids to uint8 would wrap
    # any id outside 0..255 onto a different id. Only the per-segment binary
    # mask handed to OpenCV needs to be uint8.
    mask_np = tensor_to_numpy(segmentation_mask)
    segment_ids = np.unique(mask_np)
    colors = plt.cm.tab20(np.linspace(0, 1, len(segment_ids)))

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        ax.imshow(image)

        for color, segment_id in zip(colors, segment_ids):
            if segment_id == 0:
                continue

            binary_mask = (mask_np == segment_id).astype(np.uint8)
            contours, _ = cv2.findContours(
                binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )

            for contour in contours:
                if len(contour) <= 2:
                    continue
                points = contour.reshape(-1, 2)
                # OpenCV contours are closed polygons but do not repeat the
                # first vertex; append it so the final edge is drawn too.
                closed = np.vstack([points, points[:1]])
                ax.plot(
                    closed[:, 0],
                    closed[:, 1],
                    color=color,
                    linewidth=2,
                    alpha=0.8,
                )

        ax.axis("off")
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_instance_masks(image, segmentation_mask):
    """Show up to nine of the largest non-zero segments in a 3x3 grid.

    Segments are ranked by pixel count; the nine largest are drawn one per
    panel, ordered from the smallest of those nine to the largest. Unused
    panels are blanked, and when there is no non-zero segment at all the first
    panel carries a "No segments detected" message.

    Args:
        image: Original input image. Not used by this renderer; accepted so the
            call signature is ``(image, segmentation_mask)``.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    mask_np = tensor_to_numpy(segmentation_mask)
    segment_ids, pixel_counts = np.unique(mask_np, return_counts=True)

    # Drop id 0, then rank the remainder by size. The ranking is computed once
    # and reused for both ids and counts; with no foreground all three arrays
    # are simply empty.
    foreground = segment_ids != 0
    fg_ids = segment_ids[foreground]
    fg_counts = pixel_counts[foreground]
    largest = np.argsort(fg_counts)[-9:]
    top_segments = fg_ids[largest]
    top_counts = fg_counts[largest]

    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    try:
        axes = axes.flatten()

        for panel, segment_id, count in zip(axes, top_segments, top_counts):
            binary_mask = (mask_np == segment_id).astype(float)
            panel.imshow(binary_mask, cmap="Blues")
            panel.set_title(
                f"Segment {segment_id}\nPixels: {count}", fontsize=10, weight="bold"
            )
            panel.axis("off")

        # Blank the panels that received no segment.
        for i in range(len(top_segments), 9):
            axes[i].axis("off")
            if i == 0:
                # Panel 0 is only unused when nothing was detected at all.
                axes[i].text(
                    0.5,
                    0.5,
                    "No segments\ndetected",
                    transform=axes[i].transAxes,
                    ha="center",
                    va="center",
                    fontsize=12,
                    color="red",
                    weight="bold",
                )

        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_edges(image, segmentation_mask):
    """Highlight segment boundaries in yellow on top of the original image.

    A pixel is marked as a boundary when its segment id differs from the pixel
    to its left or the pixel above it. Comparing labels directly finds every
    boundary; running an intensity edge detector on rescaled ids can miss
    boundaries between segments whose ids are numerically close.

    When the mask has no value above zero, a "No segments detected" message is
    drawn over the image instead.

    Args:
        image: Original input image, shown as the background layer.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``. Indexed as a 2-D (rows, columns) array.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        ax.imshow(image)

        if mask_np.max() > 0:
            edges = np.zeros(mask_np.shape, dtype=bool)
            # Horizontal neighbours: label changes between adjacent columns.
            edges[:, 1:] |= mask_np[:, 1:] != mask_np[:, :-1]
            # Vertical neighbours: label changes between adjacent rows.
            edges[1:, :] |= mask_np[1:, :] != mask_np[:-1, :]

            # RGBA layer: opaque yellow on boundaries, transparent elsewhere.
            edge_overlay = np.zeros((*edges.shape, 4))
            edge_overlay[edges] = [1, 1, 0, 1]
            ax.imshow(edge_overlay)
        else:
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_segment_isolation(image, segmentation_mask):
    """Show only the largest non-zero segment, in red on a light background.

    The largest segment is the non-zero id with the highest pixel count. When
    the mask contains no non-zero id, a blank grey canvas with a
    "No segments detected" message is drawn instead of presenting id 0 as a
    segment.

    Args:
        image: Original input image. Not used by this renderer; accepted so the
            call signature is ``(image, segmentation_mask)``.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    mask_np = tensor_to_numpy(segmentation_mask)
    segment_ids, pixel_counts = np.unique(mask_np, return_counts=True)

    foreground = segment_ids != 0
    # Decide the "nothing detected" case from the ids themselves. Testing the
    # isolated mask cannot work: the chosen id always occurs in the mask, so
    # that mask is never empty.
    has_segments = bool(np.any(foreground))
    if has_segments:
        winner = np.argmax(pixel_counts[foreground])
        largest_segment = segment_ids[foreground][winner]
        largest_count = pixel_counts[foreground][winner]

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        if has_segments:
            isolated_mask = (mask_np == largest_segment).astype(float)
            ax.imshow(isolated_mask, cmap="Reds")
            ax.set_title(
                f"Largest Segment (ID: {largest_segment}, Pixels: {largest_count})",
                fontsize=14,
                weight="bold",
                pad=20,
            )
        else:
            ax.imshow(np.zeros(mask_np.shape), cmap="gray")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )

        ax.axis("off")
        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def visualize_heatmap(image, segmentation_mask):
    """Render the gradient magnitude of the segment-id map as a heatmap.

    The mask is cast to float, ``np.gradient`` is taken along both axes, and
    the Euclidean norm of the two components is shown with the ``hot``
    colormap and a colorbar. When the mask has no value above zero, a blank
    canvas with a "No segments detected" message is drawn instead.

    Args:
        image: Original input image. Not used by this renderer; accepted so the
            call signature is ``(image, segmentation_mask)``.
        segmentation_mask: Array or tensor of per-pixel segment ids; converted
            with ``tensor_to_numpy``.

    Returns:
        A PIL image opened from the rendered PNG.
    """
    mask_np = tensor_to_numpy(segmentation_mask)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    try:
        if mask_np.max() > 0:
            grad_rows, grad_cols = np.gradient(mask_np.astype(float))
            gradient_magnitude = np.sqrt(grad_rows**2 + grad_cols**2)

            im = ax.imshow(gradient_magnitude, cmap="hot")
            ax.set_title(
                "Boundary Density Heatmap", fontsize=14, weight="bold", pad=20
            )
            ax.axis("off")
            fig.colorbar(
                im, ax=ax, fraction=0.046, pad=0.04, label="Gradient Magnitude"
            )
        else:
            ax.imshow(np.zeros_like(mask_np), cmap="hot")
            ax.text(
                0.5,
                0.5,
                "No segments detected",
                transform=ax.transAxes,
                ha="center",
                va="center",
                fontsize=16,
                color="red",
                weight="bold",
            )
            ax.axis("off")

        fig.tight_layout()

        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is shared global state.
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when rendering fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
| |
|
| |
|
def create_visualization(image, viz_type):
    """Run segmentation on ``image`` and render it in the requested style.

    Args:
        image: Input image, or ``None`` when nothing has been provided.
        viz_type: Label of the visualization to produce. A label that matches
            none of the known ones falls back to the plain mask view.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a PIL image. If
        inference or rendering raises, the error is printed and an image
        carrying the error message is returned instead.
    """
    if image is None:
        return None

    try:
        # Label -> renderer pairs, scanned in order with plain equality.
        renderers = (
            ("Mask", visualize_mask),
            ("Overlay", visualize_overlay),
            ("Contours", visualize_contours),
            ("Instance Masks", visualize_instance_masks),
            ("Edge Detection", visualize_edges),
            ("Segment Isolation", visualize_segment_isolation),
            ("Heatmap", visualize_heatmap),
        )

        mask = run_inference(image)["segmentation"]

        for label, render in renderers:
            if viz_type == label:
                return render(image, mask)

        # No label matched: use the default view.
        return visualize_mask(image, mask)

    except Exception as e:
        print(f"Error in visualization: {e}")

        # Turn the failure into a picture so the UI still has something to show.
        message = f"Error during processing:\n{str(e)}"
        fig, ax = plt.subplots(1, 1, figsize=(12, 8))
        ax.text(
            0.5,
            0.5,
            message,
            transform=ax.transAxes,
            ha="center",
            va="center",
            fontsize=12,
            color="red",
            weight="bold",
        )
        ax.axis("off")

        png = io.BytesIO()
        plt.savefig(png, format="png", bbox_inches="tight", dpi=150)
        png.seek(0)
        plt.close()

        return Image.open(png)
| |
|
| |
|
def load_sample_image(img_path, timeout=10):
    """Download an image from a URL and return it as a PIL image.

    The whole response body is read into memory and decoded inside the
    ``try`` block, so network and decoding errors are both handled here, and
    the connection is not held open by a lazily read stream.

    Args:
        img_path: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the caller indefinitely.

    Returns:
        The decoded PIL image, or ``None`` if the download or decoding fails
        (the error is printed).
    """
    try:
        response = requests.get(img_path, timeout=timeout)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content))
        # Force decoding now; ``Image.open`` alone only reads the header.
        img.load()
        return img
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
| |
|
| |
|
| | |
def create_interface():
    """Build the Gradio Blocks UI for the segmentation visualizer.

    The layout has an input column (image upload, visualization selector,
    process button, help text), an output column with the rendered result,
    and a gallery of sample images that load into the input when clicked.

    Sample thumbnails are downloaded while the interface is being built.
    Samples that fail to download are left out of the gallery.

    Returns:
        The constructed ``gr.Blocks`` application.
    """
    with gr.Blocks(
        title="Panoptic Segmentation Visualizer", theme=gr.themes.Soft()
    ) as demo:
        gr.Markdown("""
        # 🎨 Panoptic Segmentation Visualizer

        Upload an image and select a visualization type to see different ways of viewing the panoptic segmentation results.
        The model used is `tue-mps/coco_panoptic_eomt_large_640`.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil", height=400)

                viz_type = gr.Radio(
                    choices=[
                        "Mask",
                        "Overlay",
                        "Contours",
                        "Instance Masks",
                        "Edge Detection",
                        "Segment Isolation",
                        "Heatmap",
                    ],
                    label="Visualization Type",
                    value="Mask",
                    info="Choose how to visualize the segmentation results",
                )

                process_btn = gr.Button(
                    "🚀 Process Image", variant="primary", size="lg"
                )

                gr.Markdown("""
                ### Visualization Types:
                - **Mask**: Segmentation mask with color-coded segments
                - **Overlay**: Transparent segmentation overlay on original image
                - **Contours**: Segment boundaries outlined on original image
                - **Instance Masks**: Individual instance masks in a grid (top 9 by size)
                - **Edge Detection**: Segmentation boundaries highlighted in yellow
                - **Segment Isolation**: Shows the largest segment isolated from the rest
                - **Heatmap**: Boundary density visualization with color mapping
                """)

            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Segmentation Result", type="pil", height=600
                )

        process_btn.click(
            fn=create_visualization,
            inputs=[image_input, viz_type],
            outputs=output_image,
        )

        gr.Markdown("### 📸 Try with sample images:")

        sample_images = [
            ("http://images.cocodataset.org/val2017/000000039769.jpg", "Cats on Couch"),
            ("http://images.cocodataset.org/val2017/000000397133.jpg", "Street Scene"),
            ("http://images.cocodataset.org/val2017/000000037777.jpg", "Living Room"),
            (
                "http://images.cocodataset.org/val2017/000000174482.jpg",
                "Person with Laptop",
            ),
            ("http://images.cocodataset.org/val2017/000000000785.jpg", "Dining Table"),
        ]

        def create_thumbnail_gallery():
            """Download the samples and return ``(gallery_items, gallery_urls)``.

            ``gallery_items`` holds the ``(thumbnail, caption)`` pairs shown in
            the gallery. ``gallery_urls`` holds the source URL of each item in
            the same order, so a gallery index always maps back to the right
            sample even when some downloads were skipped.
            """
            gallery_items = []
            gallery_urls = []
            for img_url, img_name in sample_images:
                try:
                    img = load_sample_image(img_url)
                    if img is None:
                        continue
                    # Shrink in place; the gallery only needs a small preview.
                    img.thumbnail((200, 200), Image.Resampling.LANCZOS)
                except Exception as e:
                    print(f"Failed to load {img_name}: {e}")
                    continue
                gallery_items.append((img, img_name))
                gallery_urls.append(img_url)
            return gallery_items, gallery_urls

        gallery_items, gallery_urls = create_thumbnail_gallery()

        with gr.Row():
            thumbnail_gallery = gr.Gallery(
                value=gallery_items,
                label="Sample Images",
                show_label=True,
                elem_id="thumbnail_gallery",
                columns=5,
                rows=1,
                object_fit="contain",
                height=200,
                allow_preview=False,
            )

        def select_from_gallery(evt: gr.SelectData):
            """Load the full-size sample behind the clicked thumbnail."""
            selected_idx = evt.index
            # Index into the URLs that actually made it into the gallery, not
            # the full sample list: the two differ when a download failed.
            if selected_idx < len(gallery_urls):
                return load_sample_image(gallery_urls[selected_idx])
            return None

        thumbnail_gallery.select(select_from_gallery, outputs=image_input)

    return demo
| |
|
| |
|
if __name__ == "__main__":
    # Build the UI, then serve it on port 7860 bound to 0.0.0.0 with a public
    # Gradio share link enabled.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo = create_interface()
    demo.launch(**launch_options)
| |
|