|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import random |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import requests |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline |
|
|
|
|
|
|
|
|
def create_palette():
    """Build a flat 768-entry (256 colors x RGB) palette for PIL "P"-mode images.

    Returns:
        List of 768 ints: 24 distinct colors followed by zero (black) padding.
        Index 0 (black) is the background color.
    """
    colors = [
        (0, 0, 0),  # background
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
        (128, 0, 0),
        (0, 128, 0),
        (0, 0, 128),
        (128, 128, 0),
        (128, 0, 128),
        (0, 128, 128),
        (64, 0, 0),
        (0, 64, 0),
        (0, 0, 64),
        (64, 64, 0),
        (64, 0, 64),
        (0, 64, 64),
        (192, 192, 192),
        (128, 128, 128),
        (255, 165, 0),
        (75, 0, 130),
        (238, 130, 238),
    ]

    # Flatten (r, g, b) triplets into the flat layout PIL's putpalette expects.
    flat = [channel for rgb in colors for channel in rgb]

    # Pad to the full 256-color (768-value) palette with black.
    flat.extend([0] * (768 - len(flat)))
    return flat
|
|
|
|
|
|
|
|
# Shared 256-color palette, computed once at import time, for "P"-mode
# segmentation images.
PALETTE = create_palette()
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BoundingBox:
    """Axis-aligned bounding box given by its top-left and bottom-right corners."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the box as a flat ``[xmin, ymin, xmax, ymax]`` list."""
        corners = (self.xmin, self.ymin, self.xmax, self.ymax)
        return list(corners)
|
|
|
|
|
|
|
|
@dataclass
class DetectionResult:
    """One zero-shot detection: score/label/box from the detector, plus an
    optional binary mask filled in later by SAM (see ``segment``)."""

    score: Optional[float] = None
    label: Optional[str] = None
    box: Optional[BoundingBox] = None
    # Fixed annotation: np.ndarray is the array *type*; np.array (the previous
    # annotation) is a factory function, not a type.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Build a DetectionResult from a transformers detection-pipeline dict.

        Args:
            detection_dict: Mapping with "score", "label", and a "box" mapping
                carrying "xmin"/"ymin"/"xmax"/"ymax".

        Returns:
            A DetectionResult with ``mask`` left as None.
        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
|
|
|
|
|
|
|
|
|
|
|
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Extract the largest external contour of a binary mask as a polygon.

    Args:
        mask: 2D binary mask; non-zero pixels are foreground.

    Returns:
        List of ``[x, y]`` vertices of the largest contour, or an empty list
        when the mask has no foreground pixels.
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Robustness fix: max() over an empty sequence raises ValueError, so an
    # all-background mask used to crash here.
    if not contours:
        return []

    largest_contour = max(contours, key=cv2.contourArea)

    # (N, 1, 2) contour -> flat list of [x, y] points.
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon
|
|
|
|
|
|
|
|
def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: Segmentation mask (uint8) with the polygon filled with 255;
      all zeros when the polygon is empty.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Robustness fix: an empty polygon fills nothing; skip cv2.fillPoly, which
    # rejects an empty point array.
    if not polygon:
        return mask

    pts = np.array(polygon, dtype=np.int32)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask
|
|
|
|
|
|
|
|
def load_image(image_str: str) -> Image.Image:
    """Load an RGB PIL image from a local path or an http(s) URL."""
    if image_str.startswith("http"):
        # Stream the response body and hand the raw file object to PIL.
        source = requests.get(image_str, stream=True).raw
    else:
        source = image_str
    return Image.open(source).convert("RGB")
|
|
|
|
|
|
|
|
def get_boxes(results: List["DetectionResult"]) -> List[List[List[float]]]:
    """Collect xyxy boxes from detections, batched for SAM's processor.

    Args:
        results: Detection results, each with a populated ``box``. (Annotation
            fixed: the previous ``DetectionResult`` hint was wrong — this
            iterates a list of them.)

    Returns:
        A single-element batch ``[[xyxy, ...]]``, the shape expected by
        ``processor(images=..., input_boxes=...)`` for one image.
    """
    boxes = [result.box.xyxy for result in results]

    # Wrap in an outer list: one list of boxes per image, and we segment
    # exactly one image.
    return [boxes]
|
|
|
|
|
|
|
|
def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    """Collapse SAM's mask tensor into a list of 2D uint8 binary masks.

    Args:
        masks: 4D boolean tensor (the permute below assumes NCHW layout —
            one entry per box).
        polygon_refinement: When True, round-trip each mask through its
            largest polygon to smooth ragged edges.

    Returns:
        One (H, W) uint8 array per box.
    """
    # Move channels last, average them, and binarize: any positive mean
    # counts as foreground.
    channel_last = masks.cpu().float().permute(0, 2, 3, 1)
    binary = (channel_last.mean(axis=-1) > 0).int()
    mask_list = list(binary.numpy().astype(np.uint8))

    if polygon_refinement:
        # Replace each mask with the filled largest contour.
        mask_list = [
            polygon_to_mask(mask_to_polygon(m), m.shape) for m in mask_list
        ]

    return mask_list
|
|
|
|
|
|
|
|
|
|
|
def generate_colored_segmentation(label_image: np.ndarray) -> Image.Image:
    """Convert an integer label map into a palettized color PIL image.

    Args:
        label_image: 2D array of per-pixel label ids (0 = background).

    Returns:
        "P"-mode PIL image colored with the module-level PALETTE.
    """
    label_image_pil = Image.fromarray(label_image.astype(np.uint8), mode="P")

    # Reuse the palette computed once at import time (module-level PALETTE)
    # instead of rebuilding it on every call.
    label_image_pil.putpalette(PALETTE)

    return label_image_pil
|
|
|
|
|
|
|
|
def plot_segmentation(image, detections):
    """Render detections as a color segmentation map.

    Args:
        image: PIL image (only its size is used; PIL size is (W, H), hence
            the reversal to (H, W)).
        detections: DetectionResult objects with populated ``mask`` fields.

    Returns:
        "P"-mode PIL image where detection k is painted with palette id k+1
        (later detections overwrite earlier ones where masks overlap).
    """
    height_width = image.size[::-1]
    seg_map = np.zeros(height_width, dtype=np.uint8)

    label_id = 0
    for detection in detections:
        label_id += 1
        seg_map[detection.mask > 0] = label_id

    return generate_colored_segmentation(seg_map)
|
|
|
|
|
|
|
|
|
|
|
def prepare_model(
    device: str = "cuda",
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None,
):
    """Instantiate the Grounding DINO detector pipeline and the SAM model.

    Args:
        device: Torch device string for both models.
        detector_id: HF hub id of the zero-shot detector; defaults to
            "IDEA-Research/grounding-dino-tiny".
        segmenter_id: HF hub id of the SAM checkpoint; defaults to
            "facebook/sam-vit-base".

    Returns:
        Tuple of (object_detector pipeline, SAM processor, SAM model on device).
    """
    if detector_id is None:
        detector_id = "IDEA-Research/grounding-dino-tiny"
    if segmenter_id is None:
        segmenter_id = "facebook/sam-vit-base"

    object_detector = pipeline(
        task="zero-shot-object-detection", model=detector_id, device=device
    )

    processor = AutoProcessor.from_pretrained(segmenter_id)
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)

    return object_detector, processor, segmentator
|
|
|
|
|
|
|
|
def detect(
    object_detector: Any,
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.

    Args:
        object_detector: A transformers zero-shot-object-detection pipeline.
        image: Input RGB image.
        labels: Text prompts to detect.
        threshold: Minimum detection confidence.

    Returns:
        Parsed DetectionResult objects. (Return annotation fixed: the previous
        ``List[Dict[str, Any]]`` was inaccurate — the raw pipeline dicts are
        converted below.)
    """
    # Grounding DINO expects every prompt to be "."-terminated.
    labels = [label if label.endswith(".") else label + "." for label in labels]

    results = object_detector(image, candidate_labels=labels, threshold=threshold)
    results = [DetectionResult.from_dict(result) for result in results]

    return results
|
|
|
|
|
|
|
|
def segment(
    processor: Any,
    segmentator: Any,
    image: Image.Image,
    boxes: Optional[List[List[List[float]]]] = None,
    detection_results: Optional[List[DetectionResult]] = None,
    polygon_refinement: bool = False,
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Args:
        processor: SAM AutoProcessor (prepares inputs and post-processes masks).
        segmentator: SAM AutoModelForMaskGeneration model.
        image: Input RGB image.
        boxes: Optional pre-batched boxes ``[[xyxy, ...]]``; derived from
            ``detection_results`` via get_boxes() when omitted.
        detection_results: Optional detections whose ``mask`` fields are
            populated in place; fresh empty DetectionResults are created when
            omitted. (Annotation corrected: callers pass DetectionResult
            objects, not dicts.)
        polygon_refinement: Smooth each mask via a polygon round-trip
            (see refine_masks).

    Returns:
        The detection results, each with its ``mask`` attribute set.

    Raises:
        ValueError: If neither ``boxes`` nor ``detection_results`` is given.
    """
    if detection_results is None and boxes is None:
        raise ValueError(
            "Either detection_results or detection_boxes must be provided."
        )

    if boxes is None:
        boxes = get_boxes(detection_results)

    # Move input tensors onto the model's device/dtype so inference also works
    # on GPU / half precision.
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(
        segmentator.device, segmentator.dtype
    )

    outputs = segmentator(**inputs)
    # post_process_masks returns one entry per image; a single image is
    # processed here, hence the [0].
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes,
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    # When only raw boxes were supplied, create empty results to carry masks.
    if detection_results is None:
        detection_results = [DetectionResult() for _ in masks]

    # Attach masks in place, pairing detections with masks in box order.
    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results
|
|
|
|
|
|
|
|
def grounded_segmentation(
    object_detector,
    processor,
    segmentator,
    image: Union[Image.Image, str],
    labels: Union[str, List[str]],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
) -> Tuple[np.ndarray, List[DetectionResult], Image.Image]:
    """Run detection + SAM segmentation end to end and render a color seg map.

    Args:
        object_detector: Zero-shot detection pipeline (see prepare_model).
        processor: SAM processor.
        segmentator: SAM model.
        image: PIL image, local path, or URL.
        labels: List of prompts, or one comma-separated string.
        threshold: Detection confidence threshold.
        polygon_refinement: Forwarded to segment()/refine_masks().

    Returns:
        (image as np.ndarray, detections with masks, color segmentation PIL image).
    """
    if isinstance(image, str):
        image = load_image(image)
    if isinstance(labels, str):
        labels = labels.split(",")

    detections = detect(object_detector, image, labels, threshold)
    # Bug fix: pass by keyword. segment()'s 4th positional parameter is
    # `boxes`, so the previous positional call bound the detections to `boxes`
    # and the refinement flag to `detection_results`, crashing inside segment.
    detections = segment(
        processor,
        segmentator,
        image,
        detection_results=detections,
        polygon_refinement=polygon_refinement,
    )

    seg_map_pil = plot_segmentation(image, detections)

    return np.array(image), detections, seg_map_pil
|
|
|
|
|
|
|
|
def _main() -> None:
    """CLI entry point: detect + segment the given image and save the map."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--labels", type=str, nargs="+", required=True)
    parser.add_argument("--output", type=str, default="./", help="Output directory")
    parser.add_argument("--threshold", type=float, default=0.3)
    parser.add_argument(
        "--detector_id", type=str, default="IDEA-Research/grounding-dino-base"
    )
    parser.add_argument("--segmenter_id", type=str, default="facebook/sam-vit-base")
    args = parser.parse_args()

    # Prefer GPU when available.
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector, sam_processor, sam_model = prepare_model(
        device=run_device, detector_id=args.detector_id, segmenter_id=args.segmenter_id
    )

    # Only the rendered segmentation map is saved; the raw array and
    # detections are discarded here.
    _, _, seg_map = grounded_segmentation(
        detector,
        sam_processor,
        sam_model,
        image=args.image,
        labels=args.labels,
        threshold=args.threshold,
        polygon_refinement=True,
    )

    os.makedirs(args.output, exist_ok=True)
    seg_map.save(os.path.join(args.output, "segmentation.png"))


if __name__ == "__main__":
    _main()
|
|
|