| | import torch |
| | from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator |
| | from PIL import Image |
| | from matplotlib import pyplot as plt |
| | import numpy as np |
| | import cv2 |
| | from glob import glob |
| | import gradio as gr |
| | import os |
| |
|
| |
|
| |
|
def show_example(path):
    """Load an image file from disk and return it with RGB channel order.

    OpenCV decodes images in BGR order; the result is converted to RGB so it
    can be handed directly to Gradio as an example input.

    Parameters:
        path: str -- filesystem path of the image to load.

    Returns:
        np.ndarray -- the decoded image (H, W, 3) in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file,
            unreadable file, or unsupported format).
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the error surfaces as an opaque assertion inside cv2.cvtColor.
    if bgr is None:
        raise FileNotFoundError(f"Could not read image file: {path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
| |
|
| |
|
| |
|
def overlay_masks_on_image(image, anns, borders=True, *, alpha=0.5, rng=None):
    """
    Overlay segmentation masks from 'anns' on top of 'image'.

    Each annotation is alpha-blended into the image with a random colour.
    Masks are painted from largest to smallest 'area', so smaller masks end up
    drawn on top of larger ones.

    Parameters:
        image: np.ndarray (H, W, 3) -- source RGB image with 0-255 values
            (it is divided by 255 before blending).
        anns: list of dicts -- each with a 'segmentation' key holding an
            (H, W) mask (converted to bool) and an 'area' key used to order
            drawing.
        borders: bool -- whether to outline each mask with a blue contour.
        alpha: float -- opacity of each mask colour (0 = invisible,
            1 = opaque). Defaults to 0.5.
        rng: optional object with a ``random(size)`` method, e.g.
            ``np.random.default_rng(seed)``, used to draw mask colours. When
            None, the global ``np.random`` state is used.

    Returns:
        np.ndarray (H, W, 3) uint8 -- image with overlays. If 'anns' is empty,
        'image' itself is returned unchanged.
    """
    if len(anns) == 0:
        return image

    random = np.random.random if rng is None else rng.random

    # Blend in [0, 1] floats so the mask colours and the contour colour below
    # share the same scale.
    masked_image = image.astype(np.float32) / 255.0

    # Largest first: later (smaller) masks are painted over earlier ones.
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        m = np.asarray(ann['segmentation']).astype(bool)
        color_mask = random(3)

        # masked_image[m] is (K, 3); the (3,) colour broadcasts across it.
        masked_image[m] = (1 - alpha) * masked_image[m] + alpha * color_mask

        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # Simplify each contour to ~1% of its perimeter to smooth the
            # jagged pixel-level outline.
            contours = [
                cv2.approxPolyDP(c, epsilon=0.01 * cv2.arcLength(c, True), closed=True)
                for c in contours
            ]
            # (0, 0, 1) is blue in this RGB float image.
            cv2.drawContours(masked_image, contours, -1, color=(0, 0, 1), thickness=1)

    # Round rather than truncate back to 8-bit: truncation biases every blended
    # value downward, and any float round-off landing just below an integer
    # would drop that pixel a full level.
    return np.clip(np.rint(masked_image * 255), 0, 255).astype(np.uint8)
| |
|
def get_response(image):
    """Segment an uploaded image and return it with the predicted masks overlaid.

    Gradio callback for the interface built under ``__main__``. It uses the
    module-level ``mask_generator`` that is created there before the app
    launches, so it only works when this file is run as a script.

    Parameters:
        image: PIL image or None -- the uploaded image. Gradio passes None
            when no image has been supplied.

    Returns:
        np.ndarray (H, W, 3) RGB image with mask overlays, or None when no
        image was provided (Gradio shows None as an empty output).
    """
    if image is None:
        return None
    rgb = np.array(image.convert("RGB"))
    masks = mask_generator.generate(rgb)
    return overlay_masks_on_image(rgb, masks)
| |
|
def download_checkpoint(file_id='1RHSO8lHko3IK3dmABOzFDJuq7wmKVcun', output=None):
    """Download the model checkpoint from Google Drive with the ``gdown`` CLI.

    Parameters:
        file_id: str -- Google Drive file id to fetch. Defaults to the
            checkpoint this demo was built around.
        output: str or None -- optional destination path, passed to gdown as
            ``-O``. When None, gdown chooses the output file name itself.

    Raises:
        FileNotFoundError: if the ``gdown`` executable is not on PATH.
        subprocess.CalledProcessError: if gdown exits with a non-zero status.
    """
    import subprocess

    # Argument list (no shell) so the id/path are never shell-interpreted, and
    # check=True so a failed download is reported here rather than silently
    # ignored as it was with os.system.
    cmd = ['gdown', file_id]
    if output is not None:
        cmd += ['-O', output]
    subprocess.run(cmd, check=True)
| |
|
| |
|
| |
|
| |
|
if __name__ == "__main__":

    # Example images shown under the input widget, pre-loaded as RGB arrays.
    example_paths = (
        'test-images/5fc8c5b53c.png',
        'test-images/80719af02f.png',
        'test-images/f32c7bd62b.png',
    )

    iface = gr.Interface(
        cache_examples=False,
        fn=get_response,
        inputs=[gr.Image(type="pil")],
        examples=[[show_example(p)] for p in example_paths],
        outputs=[gr.Image(type="numpy")],
        title="Segmenting Microscopic images with Segment Anything",
        description="Segmenting Microscopic images with Meta Segment Anything")

    model_path = 'model.pth'
    if not os.path.exists(model_path):
        print('Downloading model with weights')
        download_checkpoint()
        # Make sure the download actually produced the file loaded below before
        # reporting success; otherwise torch.load would fail with a less clear error.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"Checkpoint download finished but {model_path!r} was not found"
            )
        print('Model with weights Downloaded')

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. It is needed here because the checkpoint stores a whole model
    # object; only load checkpoints from a source you trust.
    model = torch.load(model_path, map_location="cpu", weights_only=False)
    # Module-level global read by get_response().
    mask_generator = SAM2AutomaticMaskGenerator(model)

    iface.launch()