# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def forward(
        self,
        batched_input: Dict[str, Any],
        multimask_output: bool,
    ) -> Dict[str, torch.Tensor]:
        """
        Predicts masks from the given image and prompts.

        Arguments:
          batched_input (dict): A dictionary with the following keys; prompt
            keys may be omitted when that prompt type is not present.
              'image': The input image as a torch tensor in Bx3xHxW format,
                already transformed for input to the image encoder.
              'original_size': (tuple(int, int)) The original image size,
                before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts, BxNx2.
              'point_labels': (torch.Tensor) Batched point labels, BxN.
              'boxes': (torch.Tensor) Batched box inputs, Bx4.
              'mask_inputs': (torch.Tensor) Batched mask inputs, Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks or a single mask.

        Returns:
          (dict): A dictionary with keys 'masks' (upscaled mask logits at the
            original image resolution, BxCxHxW), 'iou_predictions' (predicted
            mask quality, BxC), and 'low_res_logits' (low-resolution logits,
            which can be passed back to the model as mask input).
        """
        input_images = batched_input.get("image")
        image_embeddings = self.image_encoder(input_images)

        if "point_coords" in batched_input and batched_input["point_coords"] is not None:
            points = (batched_input["point_coords"], batched_input["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=batched_input.get("boxes", None),
            masks=batched_input.get("mask_inputs", None),
        )  # e.g. sparse_embeddings: (B, N, 256), dense_embeddings: (B, 256, 64, 64)
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        masks = self.postprocess_masks(
            low_res_masks,
            input_size=batched_input["image"].shape[-2:],
            original_size=batched_input["original_size"],
        )
        outputs = {
            "masks": masks,
            "iou_predictions": iou_predictions,
            "low_res_logits": low_res_masks,
        }
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder, BxCxHxW.
          input_size (tuple(int, int)): The size of the image input to the
            model, as (H, W). Used to remove padding.
          original_size (tuple(int, int)): The original image size, as (H, W).

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is
            given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )  # upscale to the square encoder input size, e.g. 1024x1024
        masks = masks[..., : input_size[0], : input_size[1]]  # strip padding
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
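

# ---------------------------------------------------------------------------
# A minimal standalone sketch (not part of the model) of the geometry handled
# by `preprocess` above: normalize a dummy image with the default pixel_mean /
# pixel_std and pad it bottom/right to a square. The 1024 input size is an
# assumption matching the default ImageEncoderViT configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    img_size = 1024  # assumed encoder input size
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

    # Dummy RGB image with H=768, W=683 (e.g. a resized portrait photo).
    x = torch.randint(0, 256, (3, 768, 683)).float()
    x = (x - pixel_mean) / pixel_std                  # normalize channels
    h, w = x.shape[-2:]
    x = F.pad(x, (0, img_size - w, 0, img_size - h))  # pad to 1024x1024
    print(x.shape)                                    # torch.Size([3, 1024, 1024])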