# NOTE: removed scraped-page UI residue ("Spaces: Running on Zero") that was
# prepended to this file during extraction; it is not part of the source.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Type

import torch
from torch import nn

from .position_encoding import PositionEmbeddingRandom
class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image
            as input to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        # embed_dim // 2 features per coordinate axis — presumably the
        # random positional encoding emits paired features per axis so the
        # concatenated output has embed_dim channels; confirm in
        # position_encoding.py.
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
        # NOTE(review): mask_in_chans and activation are accepted but unused
        # in this visible portion of the class — presumably consumed by
        # mask-encoding layers defined elsewhere; verify against the full
        # source before removing them.

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        # pe_layer yields (embed_dim, H, W); unsqueeze adds the leading
        # batch dimension documented above.
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)