Spaces:
Runtime error
Runtime error
| import os | |
| import PIL | |
| from functools import lru_cache | |
| from random import randint | |
| import gradio as gr | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
| from typing import List | |
# SAM weights file (the filename indicates the ViT-H variant). This is a relative
# path, so it is resolved against the current working directory when loaded.
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
# Key looked up in segment_anything's `sam_model_registry`.
MODEL_TYPE = "default"
# Per-side pixel limits; larger inputs are downscaled before inference.
MAX_WIDTH = 800
MAX_HEIGHT = 800
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@lru_cache(maxsize=1)
def load_mask_generator(model_size: str = "large") -> SamAutomaticMaskGenerator:
    """Build SAM's automatic mask generator once and reuse it.

    Building the model loads the full checkpoint from disk and moves it to
    ``device``, which is far too expensive to repeat per request. The result is
    therefore memoized: repeated calls with the same arguments return the same
    generator object. ``maxsize=1`` keeps at most one model resident in memory.

    Args:
        model_size: Currently ignored. The model is always built from
            ``MODEL_TYPE`` and ``CHECKPOINT_PATH``. Kept for interface
            compatibility. Note that ``load_mask_generator()`` and
            ``load_mask_generator("large")`` are distinct cache keys.

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the model on ``device``. The
        instance is shared between callers, so do not mutate it.
    """
    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
    return SamAutomaticMaskGenerator(sam)
def adjust_image_size(image: np.ndarray) -> np.ndarray:
    """Downscale *image* so it fits within ``MAX_WIDTH`` x ``MAX_HEIGHT``.

    The aspect ratio is preserved, and images are never upscaled. An image that
    already fits is returned as-is, with no copy and no resampling.

    Args:
        image: Array whose first two dimensions are (height, width), e.g. the
            result of ``cv2.imread``.

    Returns:
        The original array if it already fits, otherwise a resized copy.
    """
    height, width = image.shape[:2]
    # One factor for both axes keeps the aspect ratio. Taking the min over both
    # limits keeps the result in bounds even if MAX_WIDTH != MAX_HEIGHT, and the
    # 1.0 cap means we only ever shrink.
    scale = min(1.0, MAX_WIDTH / width, MAX_HEIGHT / height)
    if scale >= 1.0:
        return image
    # Round to the nearest pixel. Clamp to at least 1 (cv2.resize rejects 0-sized
    # axes for extreme aspect ratios) and to at most the limit, guarding
    # against float error.
    new_width = min(MAX_WIDTH, max(1, round(width * scale)))
    new_height = min(MAX_HEIGHT, max(1, round(height * scale)))
    # INTER_AREA is OpenCV's recommended interpolation for decimation.
    return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
def draw_masks(
    image: np.ndarray, masks: List[dict], alpha: float = 0.7
) -> np.ndarray:
    """Overlay each mask in a random light color and outline it.

    Args:
        image: 3-channel image to draw on. In this file it is the BGR uint8
            array from ``cv2.imread``.
        masks: Mask records such as those returned by
            ``SamAutomaticMaskGenerator.generate``. Only the ``"segmentation"``
            entry is read. It must be a (height, width) mask matching *image*.
        alpha: Opacity of the colored overlay, from 0 (invisible) to 1 (solid).

    Returns:
        A new image with overlays and contours. *image* itself is not modified,
        because every drawing step works on the array returned by
        ``cv2.addWeighted``. With no masks, *image* is returned unchanged.
    """
    for mask in masks:
        color = [randint(127, 255) for _ in range(3)]
        # Coerce to bool so it acts as a pixel mask in the indexing below; a 0/1
        # integer array would be treated as row indices instead.
        segmentation = np.asarray(mask["segmentation"], dtype=bool)

        # Paint the region solid on a copy, then blend the copy back. Pixels
        # outside the mask are identical in both inputs, so the blend leaves
        # them as they were.
        overlay = image.copy()
        overlay[segmentation] = color
        image = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)

        # Outline the region's outer boundary. (255, 0, 0) is blue on a BGR image.
        contours, _ = cv2.findContours(
            segmentation.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
    return image
def segment(image_path: str, query: str) -> "PIL.Image.Image":
    """Run SAM automatic mask generation on an image and render the masks.

    Args:
        image_path: Path to the input image. The Gradio input component is
            configured with ``type="filepath"``.
        query: Text from the second input box. Currently unused: every mask SAM
            proposes is drawn regardless of the query.

    Returns:
        An RGB PIL image with a colored overlay and contour for every mask.

    Raises:
        gr.Error: If no image was given or the file cannot be decoded. Gradio
            shows the message to the user.
    """
    # `import PIL` alone does not load the Image submodule, so import it explicitly.
    from PIL import Image

    if not image_path:
        raise gr.Error("Please provide an image.")
    bgr = cv2.imread(image_path)
    if bgr is None:  # cv2.imread signals failure by returning None, not raising
        raise gr.Error(f"Could not read image: {image_path}")

    # Reduce the size to save GPU memory during mask generation.
    bgr = adjust_image_size(bgr)
    # SAM's mask generator expects RGB input, but cv2.imread returns BGR.
    masks = load_mask_generator().generate(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    # Draw on the BGR image so overlay and contour colors render as before,
    # then convert once for display.
    rendered = draw_masks(bgr, masks)
    return Image.fromarray(cv2.cvtColor(rendered, cv2.COLOR_BGR2RGB))
# Gradio UI: an image input (handed to `segment` as a file path) plus a free-text
# box. Each example pairs an image stored next to this file with an empty query.
demo = gr.Interface(
    fn=segment,
    inputs=[gr.Image(type="filepath"), "text"],
    outputs="image",
    title="Segment Anything with CLIP",
    allow_flagging="never",
    examples=[
        [os.path.join(os.path.dirname(__file__), relative_path), ""]
        for relative_path in (
            "examples/dog.jpg",
            "examples/city.jpg",
            "examples/food.jpg",
            "examples/horse.jpg",
        )
    ],
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not when imported.
    demo.launch()