Spaces:
Runtime error
Runtime error
import os
import urllib
import urllib.request  # bug fix: "import urllib" alone does not load the request submodule
from functools import lru_cache
from random import randint
from typing import Any, Callable, Dict, List, Tuple

import clip
import cv2
import gradio as gr
import numpy as np
import PIL
import torch

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
# Local cache directory and file name for the SAM checkpoint.
CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
# Official download URL for the ViT-H SAM weights.
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"
# Images larger than this on either side are downscaled before inference.
MAX_WIDTH = MAX_HEIGHT = 1024
# At most this many masks (largest by area) are considered per image.
TOP_K_OBJ = 100
THRESHOLD = 0.85
# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@lru_cache(maxsize=1)
def load_mask_generator() -> SamAutomaticMaskGenerator:
    """Download the SAM checkpoint if missing and build a mask generator.

    The result is cached with ``lru_cache`` so the (multi-GB) model is
    loaded onto the device only once per process, not on every request.

    Returns:
        A ``SamAutomaticMaskGenerator`` backed by the default ViT-H SAM
        model on the module-level ``device``.
    """
    # exist_ok avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() pair.
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)
    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
    if not os.path.exists(checkpoint):
        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(device)
    return SamAutomaticMaskGenerator(sam)
@lru_cache(maxsize=2)
def load_clip(
    name: str = "ViT-B/32",
) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
    """Load (and cache) a CLIP model plus its image preprocessing transform.

    Previously this reloaded the CLIP weights on *every* ``get_score`` call
    (i.e. once per mask crop); ``lru_cache`` makes the load happen once per
    model name. ``clip.load`` already places the model on ``device``, so no
    extra ``.to(device)`` is needed.

    Args:
        name: CLIP model variant understood by ``clip.load``.

    Returns:
        ``(model, preprocess)`` — the model on ``device`` and the callable
        that converts a PIL image into a normalized tensor.
    """
    model, preprocess = clip.load(name, device=device)
    return model, preprocess
def adjust_image_size(image: np.ndarray) -> np.ndarray:
    """Downscale ``image`` so its longer side fits MAX_HEIGHT/MAX_WIDTH.

    Aspect ratio is preserved; images already within bounds keep their
    original dimensions (they still pass through ``cv2.resize``).
    """
    h, w = image.shape[:2]
    if h > w and h > MAX_HEIGHT:
        # Portrait: clamp the height, scale the width proportionally.
        h, w = MAX_HEIGHT, int(MAX_HEIGHT / h * w)
    elif w >= h and w > MAX_WIDTH:
        # Landscape/square: clamp the width, scale the height proportionally.
        h, w = int(MAX_WIDTH / w * h), MAX_WIDTH
    return cv2.resize(image, (w, h))
def get_score(crop: PIL.Image.Image, texts: List[str]) -> torch.Tensor:
    """Return CLIP's softmax probability that ``crop`` matches ``texts[0]``.

    ``texts`` is expected to contain the query prompt first and the
    background prompt second (see ``get_texts``).
    """
    model, preprocess = load_clip()
    image_batch = preprocess(crop).unsqueeze(0).to(device)
    text_tokens = clip.tokenize(texts).to(device)
    logits_per_image, _ = model(image_batch, text_tokens)
    probabilities = logits_per_image.softmax(dim=-1).cpu()
    return probabilities[0, 0]
def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
    """Cut the masked object out of ``image`` and pad it to a square.

    Pixels outside the segmentation are zeroed before cropping to the
    bounding box; the shorter axis is then padded symmetrically with black
    so CLIP sees an (approximately) square image.
    """
    x, y, w, h = mask["bbox"]
    isolated = image * np.expand_dims(mask["segmentation"], -1)
    crop = isolated[y : y + h, x : x + w]
    # Split the size difference evenly between the two sides of the
    # shorter axis (integer division may leave the result 1px off square).
    pad = abs(h - w) // 2
    top, bottom, left, right = (0, 0, pad, pad) if h > w else (pad, pad, 0, 0)
    crop = cv2.copyMakeBorder(
        crop, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    return PIL.Image.fromarray(crop)
def get_texts(query: str) -> List[str]:
    """Build the CLIP prompt pair: the queried object vs. generic background."""
    prompts = [f"a picture of {query}", "a picture of background"]
    return prompts
def filter_masks(
    image: np.ndarray,
    masks: List[Dict[str, Any]],
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    query: str,
    clip_threshold: float,
) -> List[Dict[str, Any]]:
    """Keep the masks that pass SAM quality checks and, optionally, CLIP.

    Only the ``TOP_K_OBJ`` largest masks (by area) are considered. A mask
    survives when its predicted IoU and stability score meet the given
    thresholds, its segmentation matches the image size, and — if ``query``
    is non-empty — its CLIP score for the query reaches ``clip_threshold``.
    """
    candidates = sorted(masks, key=lambda m: m["area"])[-TOP_K_OBJ:]
    kept: List[Dict[str, Any]] = []
    for mask in candidates:
        if mask["predicted_iou"] < predicted_iou_threshold:
            continue
        if mask["stability_score"] < stability_score_threshold:
            continue
        if image.shape[:2] != mask["segmentation"].shape[:2]:
            continue
        # CLIP scoring (the expensive step) runs only when a query is given.
        if query and get_score(crop_image(image, mask), get_texts(query)) < clip_threshold:
            continue
        kept.append(mask)
    return kept
def draw_masks(
    image: np.ndarray, masks: List[np.ndarray], alpha: float = 0.7
) -> np.ndarray:
    """Blend each mask over ``image`` in a random bright color, outlined in red.

    Args:
        image: RGB image array to annotate.
        masks: SAM mask records (each with a boolean "segmentation" array).
        alpha: Blend weight of the colored overlay against the base image.
    """
    for mask in masks:
        fill = [randint(127, 255) for _ in range(3)]
        # Replicate the boolean segmentation across the three color channels
        # so it can serve as the mask of a MaskedArray over an HxWx3 image.
        seg3 = np.moveaxis(
            np.expand_dims(mask["segmentation"], 0).repeat(3, axis=0), 0, -1
        )
        # filled() swaps the segmented pixels for the random color.
        overlay = np.ma.MaskedArray(image, mask=seg3, fill_value=fill).filled()
        image = cv2.addWeighted(image, 1 - alpha, overlay, alpha, 0)
        # Trace the segmentation boundary in red on the blended image.
        contours, _ = cv2.findContours(
            np.uint8(mask["segmentation"]), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (0, 0, 255), 2)
    return image
def segment(
    predicted_iou_threshold: float,
    stability_score_threshold: float,
    clip_threshold: float,
    image_path: str,
    query: str,
) -> PIL.ImageFile.ImageFile:
    """Full pipeline: read an image, run SAM, filter masks, draw the result.

    Args:
        predicted_iou_threshold: Minimum SAM predicted IoU to keep a mask.
        stability_score_threshold: Minimum SAM stability score to keep a mask.
        clip_threshold: Minimum CLIP score when ``query`` is non-empty.
        image_path: Path to the input image on disk.
        query: Free-text object description; empty string skips CLIP filtering.
    """
    mask_generator = load_mask_generator()
    bgr = cv2.imread(image_path, cv2.IMREAD_COLOR)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Shrink oversized inputs up front to keep GPU memory in check.
    rgb = adjust_image_size(rgb)
    kept = filter_masks(
        rgb,
        mask_generator.generate(rgb),
        predicted_iou_threshold,
        stability_score_threshold,
        query,
        clip_threshold,
    )
    annotated = draw_masks(rgb, kept)
    return PIL.Image.fromarray(annotated)
# Gradio UI: three threshold sliders, an image upload, and a text query.
# The examples table is generated from (clip_threshold, relative path, query)
# triples; the IoU/stability defaults (0.9 / 0.8) are shared by every example.
demo = gr.Interface(
    fn=segment,
    inputs=[
        gr.Slider(0, 1, value=0.9, label="predicted_iou_threshold"),
        gr.Slider(0, 1, value=0.8, label="stability_score_threshold"),
        gr.Slider(0, 1, value=0.85, label="clip_threshold"),
        gr.Image(type="filepath"),
        "text",
    ],
    outputs="image",
    allow_flagging="never",
    title="Segment Anything with CLIP",
    examples=[
        [0.9, 0.8, clip_t, os.path.join(os.path.dirname(__file__), rel_path), query]
        for clip_t, rel_path, query in [
            (0.99, "examples/dog.jpg", "dog"),
            (0.75, "examples/city.jpg", "building"),
            (0.998, "examples/food.jpg", "strawberry"),
            (0.75, "examples/horse.jpg", "horse"),
            (0.99, "examples/bears.jpg", "bear"),
            (0.99, "examples/cats.jpg", "cat"),
            (0.99, "examples/fish.jpg", "fish"),
        ]
    ],
)
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script, not on import.
    demo.launch()