|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
from torch.nn import functional as F
|
|
|
from typing import Any, Dict, List, Tuple
|
|
|
from .image_encoder import ImageEncoderViT
|
|
|
from .mask_decoder import MaskDecoder
|
|
|
from .prompt_encoder import PromptEncoder
|
|
|
|
|
|
|
|
|
class Sam(nn.Module):
    """Segment Anything model: image encoder + prompt encoder + mask decoder."""

    # Class-level default; not read by any method in this class.
    # NOTE(review): presumably used by callers to binarize ``masks`` logits — confirm.
    mask_threshold: float = 0.0
    # Channel order the model expects for input images.
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Non-persistent buffers: they move with the module on .to(device) but are
        # not saved in the state_dict. Reshaped to (C, 1, 1) so they broadcast
        # over the trailing spatial dimensions in ``preprocess``.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), persistent=False
        )

    @property
    def device(self) -> Any:
        """Device of the ``pixel_mean`` buffer, used as the model's device."""
        return self.pixel_mean.device

    def forward(
        self, batched_input: Dict[str, Any], multimask_output: bool
    ) -> Dict[str, torch.Tensor]:
        """
        Predict masks for a batch of images and their prompts.

        The image is passed to the image encoder as-is: this method does NOT
        call ``preprocess`` (normalization / padding).
        NOTE(review): presumably callers preprocess images beforehand — confirm.

        Arguments:
          batched_input (dict): Input with the keys
            'image': Tensor given to the image encoder; its last two
              dimensions are taken as the un-padded input size.
            'original_size': Size the final masks are resized to.
            'point_coords' (optional): Point prompts; when present and not
              None, 'point_labels' is required as well.
            'boxes' (optional): Box prompts, forwarded to the prompt encoder.
            'mask_inputs' (optional): Mask prompts, forwarded to the prompt encoder.
          multimask_output (bool): Forwarded to the mask decoder.

        Returns:
          (dict): 'masks' (logits resized to 'original_size'),
            'iou_predictions' (mask decoder quality scores), and
            'low_res_logits' (mask decoder output before upsampling).

        Raises:
          KeyError: if 'image' or 'original_size' is missing, or if
            'point_coords' is given without 'point_labels'.
        """
        input_images = batched_input["image"]
        image_embeddings = self.image_encoder(input_images)

        # Identity check: ``!= None`` on an array/tensor compares elementwise
        # (or relies on a fallback) instead of testing for absence.
        point_coords = batched_input.get("point_coords")
        if point_coords is not None:
            points = (point_coords, batched_input["point_labels"])
        else:
            points = None

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )

        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        masks = self.postprocess_masks(
            low_res_masks,
            input_size=input_images.shape[-2:],
            original_size=batched_input["original_size"],
        )

        return {
            "masks": masks,
            "iou_predictions": iou_predictions,
            "low_res_logits": low_res_masks,
        }

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Upscale mask logits back to the original image resolution.

        Three steps:
          1. Bilinearly resize to the encoder's square input
             (img_size x img_size).
          2. Crop away the bottom/right padding added by ``preprocess``,
             keeping ``input_size``.
          3. Bilinearly resize to ``original_size``.

        Arguments:
          masks (torch.Tensor): Mask logits. Bilinear interpolation requires a
            4-D (batch, channels, H, W) tensor.
          input_size (tuple(int, int)): (H, W) of the image before padding.
          original_size (tuple(int, int)): (H, W) of the output masks.

        Returns:
          (torch.Tensor): Masks resized to ``original_size``.
        """
        encoder_size = self.image_encoder.img_size
        masks = F.interpolate(
            masks,
            (encoder_size, encoder_size),
            mode="bilinear",
            align_corners=False,
        )
        # ``preprocess`` pads only on the bottom/right, so the valid region is
        # the top-left corner.
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input.

        Subtracts ``pixel_mean`` and divides by ``pixel_std`` per channel, then
        zero-pads the bottom and right edges up to
        ``image_encoder.img_size`` on each spatial side.
        Assumes h, w <= image_encoder.img_size (a negative pad amount would
        crop instead) — TODO confirm callers resize first.
        """
        x = (x - self.pixel_mean) / self.pixel_std

        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        # F.pad pads the last dim first: (left, right, top, bottom).
        x = F.pad(x, (0, padw, 0, padh))
        return x
|
|
|
|