Spaces:
Runtime error
Runtime error
| from typing import Tuple, Literal | |
| import torch | |
| from mmengine import MMLogger | |
| from mmdet.registry import MODELS | |
| from mmengine.model import BaseModule | |
| from mmengine.structures import InstanceData | |
| from ext.sam import PromptEncoder | |
| from ext.meta.sam_meta import meta_dict, checkpoint_dict | |
| from utils.load_checkpoint import load_checkpoint_with_prefix | |
# NOTE(review): `MODELS` is imported at file level but this class is not
# registered with it here; confirm whether a `@MODELS.register_module()`
# decorator was dropped from this copy of the file.
class SAMPromptEncoder(BaseModule):
    """Prompt encoder assembled from pretrained SAM ``PromptEncoder`` weights.

    A ``PromptEncoder`` is built, its ``prompt_encoder``-prefixed weights are
    loaded from a checkpoint, and its sub-modules (positional encoding, mask
    downscaling, point / no-mask embeddings) are re-attached to this module.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        """Build the encoder and load the pretrained prompt-encoder weights.

        Args:
            model_name: Key into ``meta_dict`` that supplies ``image_size``
                and ``image_embedding_size``.
            fix: When true, the module is kept in eval mode and every
                parameter has ``requires_grad`` set to ``False``.
            init_cfg: Mapping with ``type == 'sam_pretrain'`` and a
                ``checkpoint`` entry that is a key of ``checkpoint_dict``.

        Raises:
            AssertionError: If ``init_cfg`` is missing or its ``type`` is not
                ``'sam_pretrain'``.
        """
        # Checked separately: the message of the second assert indexes
        # ``init_cfg`` and would itself raise TypeError if it were None.
        assert init_cfg is not None, "init_cfg is required (type='sam_pretrain')."
        assert init_cfg['type'] == 'sam_pretrain', f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']

        # The base class receives no init_cfg; the checkpoint is loaded
        # explicitly below and the config is only kept for logging.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]
        checkpoint_path = checkpoint_dict[pretrained]

        embedding_hw = backbone_meta['image_embedding_size']
        image_hw = backbone_meta['image_size']
        prompt_encoder = PromptEncoder(
            embed_dim=256,
            image_embedding_size=(embedding_hw, embedding_hw),
            input_image_size=(image_hw, image_hw),
            mask_in_chans=16,
        )
        state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='prompt_encoder')
        prompt_encoder.load_state_dict(state_dict, strict=True)

        # meta
        self.embed_dim = prompt_encoder.embed_dim
        self.input_image_size = prompt_encoder.input_image_size
        self.image_embedding_size = prompt_encoder.image_embedding_size
        self.num_point_embeddings = 4
        self.mask_input_size = prompt_encoder.mask_input_size

        # positional encoding
        self.pe_layer = prompt_encoder.pe_layer

        # mask encoding
        self.mask_downscaling = prompt_encoder.mask_downscaling
        self.no_mask_embed = prompt_encoder.no_mask_embed

        # point encoding
        self.point_embeddings = prompt_encoder.point_embeddings
        self.not_a_point_embed = prompt_encoder.not_a_point_embed

        # ``self.fix`` must be set before ``train`` is called: the override
        # below reads it.
        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    @property
    def device(self):
        """Device of the ``no_mask_embed`` weight.

        Exposed as a property because ``forward`` reads ``self.device`` as an
        attribute when allocating tensors.
        """
        return self.no_mask_embed.weight.device

    def init_weights(self):
        """Log the init config; the weights themselves are loaded in ``__init__``."""
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def train(self, mode: bool = True) -> torch.nn.Module:
        """Set the training mode, forcing eval mode while ``self.fix`` is true.

        Args:
            mode: Requested training mode; ignored when ``self.fix`` is true.

        Returns:
            This module.

        Raises:
            ValueError: If ``mode`` is not a ``bool``.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        effective_mode = False if self.fix else mode
        super().train(mode=effective_mode)
        return self

    def _embed_boxes(self, bboxes: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
        """Embeds box prompts.

        Args:
            bboxes: Box coordinates; reshaped to ``(-1, 2, 2)``, i.e. two 2-D
                corner points per box (presumably ``x1, y1, x2, y2`` order —
                confirm against the caller).
            image_size: Size passed to the positional encoding to normalize
                the corner coordinates.

        Returns:
            One embedding per corner, two corners per box.
        """
        bboxes = bboxes + 0.5  # Shift to center of pixel
        coords = bboxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, image_size)
        # Learned embeddings 2 and 3 mark the first and second corner.
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs by running them through ``mask_downscaling``."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def get_dense_pe(self) -> torch.Tensor:
        """Return the positional encoding of the embedding grid with a leading batch dim of 1."""
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
            self,
            points: torch.Tensor,
            labels: torch.Tensor,
            pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts.

        Args:
            points: Point coordinates; concatenated along dim 1 with a
                ``(B, 1, 2)`` padding point when ``pad`` is true.
            labels: Per-point labels. ``-1`` marks a padding ("not a point")
                entry; ``0`` and ``1`` select ``point_embeddings[0]`` and
                ``point_embeddings[1]`` respectively.
            pad: When true, append one zero point with label ``-1`` to every
                batch element.

        Returns:
            Positional encodings of the points with the label-specific
            learned embedding added.
        """
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        # NOTE(review): points are normalized by ``self.input_image_size``
        # while boxes use the caller-supplied ``image_size`` — confirm that
        # this difference is intended.
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        # Padding entries carry no position: zero them, then add the
        # dedicated "not a point" embedding.
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def forward(
            self,
            instances: InstanceData,
            image_size: Tuple[int, int],
            with_points: bool = False,
            with_bboxes: bool = False,
            with_masks: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the prompts stored in ``instances``.

        Args:
            instances: Prompt container. Must hold ``point_coords`` when
                ``with_points``, ``bboxes`` when ``with_bboxes`` and ``masks``
                (read via ``instances.masks.masks``) when ``with_masks``.
            image_size: Image size forwarded to the box embedding.
            with_points: Encode ``instances.point_coords``; every point gets
                label ``1``.
            with_bboxes: Encode ``instances.bboxes``.
            with_masks: Encode the masks instead of using the learned
                "no mask" embedding.

        Returns:
            ``(sparse_embeddings, dense_embeddings)``. The sparse tensor is
            ``(len(instances), N, embed_dim)`` with point and box embeddings
            concatenated along dim 1. Without masks the dense tensor is the
            "no mask" embedding expanded over the batch and the embedding
            grid.

        Raises:
            AssertionError: If no prompt type is enabled or a required field
                is missing from ``instances``.
        """
        assert with_points or with_bboxes or with_masks
        bs = len(instances)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self.device)

        if with_points:
            assert 'point_coords' in instances
            coords = instances.point_coords
            labels = torch.ones_like(coords)[:, :, 0]
            # A padding point is only appended when no box prompt follows.
            point_embeddings = self._embed_points(coords, labels, pad=not with_bboxes)
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)

        if with_bboxes:
            assert 'bboxes' in instances
            box_embeddings = self._embed_boxes(
                instances.bboxes, image_size=image_size
            )
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if with_masks:
            assert 'masks' in instances
            dense_embeddings = self._embed_masks(instances.masks.masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )
        return sparse_embeddings, dense_embeddings