import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr
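# Assumes Pillow >= 9.1 (for Image.Resampling) and a Gradio version with the
# Blocks API (3.x or later).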
# ---------------------------
# Load model & processor
# ---------------------------
model_checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
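# The checkpoint pairs a MobileViT backbone with a DeepLabV3 head trained on
# PASCAL VOC (21 classes). The processor resizes the image, rescales pixels to
# [0, 1], and flips channels to BGR; predict() below undoes this for display.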
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8,
)
labels = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
    "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
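# Sanity check: every class index needs both a label and a palette color.
assert len(labels) == len(palette) == 21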
# ---------------------------
# Prediction function
# ---------------------------
def predict(image):
    if image is None:
        return None, None

    # Preprocess and run the forward pass without tracking gradients.
    with torch.no_grad():
        inputs = image_processor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Undo preprocessing for display: flip BGR back to RGB and rescale
    # from [0, 1] to uint8.
    resized = (
        inputs["pixel_values"]
        .numpy()
        .squeeze()
        .transpose(1, 2, 0)[..., ::-1] * 255
    ).astype(np.uint8)

    # Per-pixel class map: argmax over the 21 class logits.
    classes = outputs.logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)

    # Vectorized lookup-table coloring: palette[classes] maps every class
    # index to its RGB color in a single indexing operation.
    colored = palette[classes]

    # The logits are lower-resolution than the input, so upsample the colored
    # map with nearest-neighbor to keep class boundaries crisp.
    colored_img = Image.fromarray(colored).resize(
        (resized.shape[1], resized.shape[0]),
        resample=Image.Resampling.NEAREST,
    )

    # Binary foreground mask (every non-background pixel) for the overlay.
    mask = (classes != 0).astype(np.uint8) * 255
    mask_img = Image.fromarray(mask).resize(
        (resized.shape[1], resized.shape[0]),
        resample=Image.Resampling.NEAREST,
    ).convert("RGB")

    resized_img = Image.fromarray(resized)
    highlighted = Image.blend(resized_img, mask_img, 0.4)

    return colored_img, highlighted
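# Quick smoke test without the UI (assumes an example image such as
# "cat-3.jpg" sits next to this script):
#
#   mask_img, overlay_img = predict(Image.open("cat-3.jpg"))
#   mask_img.save("mask.png"); overlay_img.save("overlay.png")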
# ---------------------------
# Labels HTML
# ---------------------------
# Class indices whose swatch color is dark enough to need white text.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}
labels_html = " ".join(
    f"<span style='background-color: rgb{tuple(palette[i])}; "
    f"color: {'white' if i in inverted else 'black'}; padding: 2px 4px;'>"
    f"{labels[i]}</span>"
    for i in range(len(labels))
)
| description = f""" | |
| Semantic Segmentation with MobileViT + DeepLabV3 | |
| Model trained on Pascal VOC.<br><br> | |
| Classes:<br>{labels_html} | |
| """ | |
| article = """ | |
| <p>Sources:</p> | |
| <ul> | |
| <li><a href="https://arxiv.org/abs/2110.02178">MobileViT Paper</a></li> | |
| <li><a href="https://github.com/apple/ml-cvnets">Apple ML-CVnets</a></li> | |
| <li>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">Sample Images Dataset</a></li> | |
| </ul> | |
| """ | |
# ---------------------------
# Gradio App (Blocks)
# ---------------------------
with gr.Blocks(title="Semantic Segmentation with MobileViT") as demo:
    gr.Markdown("# Semantic Segmentation with MobileViT & DeepLabV3")
    gr.Markdown(description)

    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        output_mask = gr.Image(label="Segmentation Mask")
        output_overlay = gr.Image(label="Highlighted Overlay")

    run_btn = gr.Button("Run")
    run_btn.click(
        predict,
        inputs=input_img,
        outputs=[output_mask, output_overlay],
    )

    gr.Markdown(article)

    gr.Examples(
        examples=[
            ["cat-3.jpg"],
            ["construction-site.jpg"],
            ["dog-cat.jpg"],
            ["football-match.jpg"],
        ],
        inputs=input_img,
    )
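# When running locally, launch() also accepts share=True for a temporary
# public URL: demo.launch(share=True).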
demo.launch()