| import sys |
|
|
| import modules.config |
| import numpy as np |
| import torch |
| from extras.GroundingDINO.util.inference import default_groundingdino |
| from extras.sam.predictor import SamPredictor |
| from rembg import remove, new_session |
| from segment_anything import sam_model_registry |
| from segment_anything.utils.amg import remove_small_regions |
|
|
|
|
class SAMOptions:
    """Settings for the GroundingDINO-then-SAM masking pipeline.

    Attributes (each stored unchanged from the constructor argument):
        dino_prompt: text caption handed to GroundingDINO.
        dino_box_threshold: box threshold handed to GroundingDINO.
        dino_text_threshold: text threshold handed to GroundingDINO.
        dino_erode_or_dilate: amount subtracted from the top-left corner and
            added to the bottom-right corner of every detected box
            (0 leaves the boxes untouched).
        dino_debug: when true, the pipeline returns a picture of the
            detected boxes instead of a SAM mask.
        max_detections: upper bound on how many masks are merged into the
            final mask; 0 means "no limit".
        model_type: key used to pick the SAM model/checkpoint.
    """

    def __init__(self,
                 dino_prompt: str = '',
                 dino_box_threshold=0.3,
                 dino_text_threshold=0.25,
                 dino_erode_or_dilate=0,
                 dino_debug=False,
                 max_detections=2,
                 model_type='vit_b'
                 ):
        # One bulk update keeps the attribute order identical to the
        # parameter order.
        vars(self).update(
            dino_prompt=dino_prompt,
            dino_box_threshold=dino_box_threshold,
            dino_text_threshold=dino_text_threshold,
            dino_erode_or_dilate=dino_erode_or_dilate,
            dino_debug=dino_debug,
            max_detections=max_detections,
            model_type=model_type,
        )
|
|
|
|
def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
    """Fill small holes in every mask of a batch.

    Each mask is passed through ``remove_small_regions`` with an area
    threshold of 400 and ``mode="holes"``; only the hole-filling mode is
    applied here.

    Args:
        masks: batch of masks; the first axis enumerates the masks and the
            second axis is a singleton (only index 0 of it is read).

    Returns:
        A CPU tensor with the cleaned masks stacked back into the same
        ``[num_masks, 1, ...]`` layout. An empty batch is returned as-is
        (moved to the CPU) because ``np.stack`` cannot stack zero arrays.
    """
    if masks.shape[0] == 0:
        return masks.to('cpu')

    cleaned = [
        # remove_small_regions returns (mask, changed_flag); keep the mask.
        remove_small_regions(mask[0], 400, mode="holes")[0]
        for mask in masks.to('cpu').numpy()
    ]
    # Re-insert the singleton axis that was dropped by indexing mask[0].
    return torch.from_numpy(np.stack(cleaned, axis=0)[:, np.newaxis])
|
|
|
|
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
                             sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
    """Build a mask for ``image`` with rembg or with GroundingDINO + SAM.

    Args:
        image: the input picture, or a dict carrying it under the ``'image'``
            key (the dict is unwrapped).
        mask_model: ``'sam'`` selects the GroundingDINO + SAM path; any other
            value is used as the rembg session name.
        extras: optional keyword arguments forwarded to both
            ``new_session`` and ``remove`` on the rembg path.
        sam_options: settings for the SAM path. ``None`` forces the rembg
            path. The ``SAMOptions`` class itself (the declared default) is
            accepted and instantiated with its defaults.

    Returns:
        ``(mask, dino_detection_count, sam_detection_count,
        sam_detection_on_mask_count)``. ``mask`` is ``None`` when there is no
        image, the rembg result on the rembg path, an RGB picture of the
        detected boxes when ``sam_options.dino_debug`` is set and something
        was detected, and otherwise a 3-channel uint8 mask (0 / 255).
    """
    dino_detection_count = 0
    sam_detection_count = 0
    sam_detection_on_mask_count = 0

    if image is None:
        return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    if extras is None:
        extras = {}

    # Only a dict can carry the picture under an 'image' key; on an ndarray
    # the `in` operator would be an element-wise comparison instead.
    if isinstance(image, dict) and 'image' in image:
        image = image['image']
        if image is None:
            return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    if mask_model != 'sam' or sam_options is None:
        result = remove(
            image,
            session=new_session(mask_model, **extras),
            only_mask=True,
            **extras
        )
        return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

    # The declared default is the SAMOptions class, not an instance; build an
    # instance so the attribute reads below work.
    if isinstance(sam_options, type):
        sam_options = sam_options()

    _, boxes, logits, _ = default_groundingdino(
        image=image,
        caption=sam_options.dino_prompt,
        box_threshold=sam_options.dino_box_threshold,
        text_threshold=sam_options.dino_text_threshold
    )

    # Scale the boxes by the image size, then convert
    # (center_x, center_y, width, height) to (x0, y0, x1, y1).
    height, width = image.shape[0], image.shape[1]
    boxes = boxes * torch.Tensor([width, height, width, height])
    boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
    boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

    final_mask_tensor = torch.zeros((height, width))
    dino_detection_count = boxes.size(0)

    if dino_detection_count > 0:
        margin = sam_options.dino_erode_or_dilate
        if margin != 0:
            assert boxes.size(1) == 4
            # Move the top-left corner by -margin and the bottom-right corner
            # by +margin for every box at once.
            boxes[:, :2] -= margin
            boxes[:, 2:] += margin

        if sam_options.dino_debug:
            from PIL import ImageDraw, Image
            debug_dino_image = Image.new("RGB", (width, height), color="black")
            draw = ImageDraw.Draw(debug_dino_image)
            for box in boxes.numpy():
                draw.rectangle(box.tolist(), fill="white")
            return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count

        # SAM is only loaded once it is certain that a mask will be predicted,
        # so the debug path and the no-detection path skip the checkpoint
        # load and the image embedding.
        sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
        sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)
        sam_predictor = SamPredictor(sam)
        sam_predictor.set_image(image)

        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        masks = optimize_masks(masks)
        sam_detection_count = len(masks)

        # 0 means "no limit". Use a local so the caller's options object is
        # not modified.
        max_detections = sam_options.max_detections
        if max_detections == 0:
            max_detections = sys.maxsize

        # Clamp at zero so a negative limit merges nothing.
        sam_objects = max(min(len(masks), max_detections), 0)
        for obj_ind in range(sam_objects):
            final_mask_tensor += masks[obj_ind][0]
            sam_detection_on_mask_count += 1

    # Any pixel covered by at least one merged mask becomes white in all
    # three channels.
    final_mask = (final_mask_tensor > 0).to('cpu').numpy()
    mask_image = np.dstack((final_mask, final_mask, final_mask)) * 255
    mask_image = np.array(mask_image, dtype=np.uint8)
    return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
|
|