import torch.nn.functional as F
from mmengine.model import BaseModel

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


class SAMSegmentor(BaseModel):
    """SAM-style segmentor combining a backbone, neck, prompt encoder and
    mask decoder, each built from a registry config."""

    MASK_THRESHOLD = 0.5

    def __init__(
        self,
        backbone: ConfigType,
        neck: ConfigType,
        prompt_encoder: ConfigType,
        mask_decoder: ConfigType,
        data_preprocessor: OptConfigType = None,
        fpn_neck: OptConfigType = None,
        init_cfg: OptMultiConfig = None,
        use_clip_feat: bool = False,
        use_head_feat: bool = False,
        use_gt_prompt: bool = False,
        use_point: bool = False,
        enable_backbone: bool = False,
    ) -> None:
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.pe = MODELS.build(prompt_encoder)
        self.mask_decoder = MODELS.build(mask_decoder)
        # The FPN neck is optional; keep it None so extract_feat can skip it.
        if fpn_neck is not None:
            self.fpn_neck = MODELS.build(fpn_neck)
        else:
            self.fpn_neck = None
        self.use_clip_feat = use_clip_feat
        self.use_head_feat = use_head_feat
        self.use_gt_prompt = use_gt_prompt
        self.use_point = use_point
        self.enable_backbone = enable_backbone

    def extract_feat(self, inputs):
        """Run the image through backbone and neck(s), caching every
        intermediate feature so prompts can be decoded later without
        re-running the image encoder."""
        backbone_feat = self.backbone(inputs)
        neck_feat = self.neck(backbone_feat)
        if self.fpn_neck is not None:
            fpn_feat = self.fpn_neck(backbone_feat)
        else:
            fpn_feat = None
        return dict(
            backbone_feat=backbone_feat,
            neck_feat=neck_feat,
            fpn_feat=fpn_feat,
        )

    def extract_masks(self, feat_cache, prompts):
        """Decode masks and class scores from cached image features and
        point/box prompts."""
        sparse_embed, dense_embed = self.pe(
            prompts,
            image_size=(1024, 1024),
            with_points='point_coords' in prompts,
            with_bboxes='bboxes' in prompts,
        )
        kwargs = dict()
        if self.enable_backbone:
            # Expose the raw backbone and FPN features to the decoder.
            kwargs['backbone_feats'] = feat_cache['backbone_feat']
            kwargs['backbone'] = self.backbone
            kwargs['fpn_feats'] = feat_cache['fpn_feat']
        low_res_masks, iou_predictions, cls_pred = self.mask_decoder(
            image_embeddings=feat_cache['neck_feat'],
            image_pe=self.pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multi_mask_output=False,
            **kwargs,
        )
        # Upsample the low-resolution decoder logits 4x back to the
        # (1024, 1024) input resolution, then map them to probabilities.
        masks = F.interpolate(
            low_res_masks,
            scale_factor=4.,
            mode='bilinear',
            align_corners=False,
        )
        masks = masks.sigmoid()
        # Softmax over classes, dropping the trailing background column.
        cls_pred = cls_pred.softmax(-1)[..., :-1]
        return masks.detach().cpu().numpy(), cls_pred.detach().cpu()

    def forward(self, inputs):
        # Placeholder forward that returns inputs unchanged; inference is
        # driven through extract_feat / extract_masks instead.
        return inputs
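

if __name__ == '__main__':
    # Minimal sketch of the post-processing performed in extract_masks,
    # assuming the usual SAM low-resolution mask size of 256x256, so that
    # the 4x bilinear upsampling recovers the (1024, 1024) input size
    # passed to the prompt encoder. The dummy tensor and the final
    # thresholding step are illustrative only: the class defines
    # MASK_THRESHOLD but leaves binarization to downstream code.
    import torch

    low_res_masks = torch.randn(2, 1, 256, 256)  # dummy decoder logits
    masks = F.interpolate(
        low_res_masks, scale_factor=4., mode='bilinear', align_corners=False)
    probs = masks.sigmoid()  # per-pixel foreground probabilities
    binary = probs > SAMSegmentor.MASK_THRESHOLD  # hypothetical binarization
    print(binary.shape)  # torch.Size([2, 1, 1024, 1024])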