import os
from typing import Literal, Tuple, List, Optional

import torch
import torch.nn.functional as F
from torch import nn
from mmcv.cnn import ConvModule
from mmdet.registry import MODELS
from mmdet.structures.bbox import bbox2roi
from mmdet.structures.mask import mask2bbox
from mmengine import MMLogger
from mmengine.model import BaseModule

from ext.sam import MaskDecoder
from ext.sam.mask_decoder import MLP as SAMMLP
from ext.meta.sam_meta import meta_dict
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class OVSAMHead(BaseModule):
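    """SAM mask decoder head with an optional open-vocabulary label token.

    On top of SAM's mask tokens, a learnable label token is decoded and
    projected into the embedding space of a fixed, L2-normalized set of class
    prototypes, so each predicted mask can be classified open-vocabulary.
    RoI-pooled image features (FPN and, optionally, CLIP backbone features)
    provide the embedding that is actually classified.
    """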

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            with_label_token: bool = False,
            ov_classifier_name: Optional[str] = None,
            logit: Optional[float] = None,
            roi_extractor=None,
            fix: bool = True,
            init_cfg=None,
            cur_mask=1,
            roi_extractor_single=None,
            load_roi_conv=None,
            gen_box=False,
    ):
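        """
        Args:
            model_name: SAM variant; used to look up the prompt embedding dim.
            with_label_token: if True, decode a label token and classify each
                mask against the prototypes in ``./models/{ov_classifier_name}.pth``.
            logit: optional fixed logit scale; defaults to 4.6052 (= ln 100).
            roi_extractor: optional RoI extractor over FPN features.
            fix: if True, freeze the whole head.
            init_cfg: required; ``'sam_pretrain'`` or ``'Pretrained'`` (only
                the latter is implemented here).
            cur_mask: total number of mask tokens to decode.
            roi_extractor_single: optional extra RoI extractor over CLIP features.
            load_roi_conv: optional checkpoint spec for the RoI projection conv.
            gen_box: if True, derive RoI boxes from the predicted masks.
        """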
        assert init_cfg is not None and \
               init_cfg['type'] in ['sam_pretrain', 'Pretrained'], f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
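
        # Optional second RoI path that pools CLIP backbone features; its
        # output is fused with the FPN RoI features via ``roi_merge_proj``
        # (see ``predict_masks``).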
        if roi_extractor_single is not None:
            self.roi_extractor_single = MODELS.build(roi_extractor_single)
            self.roi_merge_proj = nn.Linear(768 * 2, 768)
        else:
            self.roi_extractor_single = None
            self.roi_merge_proj = None
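
        # SAM's decoder allocates ``num_multimask_outputs + 1`` mask tokens, so
        # ``cur_mask - 1`` yields ``cur_mask`` tokens in total; ``with_iou=False``
        # disables the IoU prediction head of this SAM fork.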
        mask_decoder = MaskDecoder(
            num_multimask_outputs=cur_mask - 1,
            transformer_dim=meta_dict[model_name]['prompt_embed_dim'],
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            with_iou=False
        )

        if self.init_cfg['type'] == 'sam_pretrain':
            raise NotImplementedError

        self.mask_decoder = mask_decoder

        self.with_label_token = with_label_token
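
        # Load the frozen classifier: class prototype embeddings of shape
        # (num_classes, num_prototypes, embed_dim), expected L2-normalized,
        # with an all-zero background class appended.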
        if self.with_label_token:
            ov_path = os.path.join(os.path.expanduser('./models/'), f"{ov_classifier_name}.pth")
            # map_location='cpu' so the loaded tensor matches the CPU background token below
            cls_embed = torch.load(ov_path, map_location='cpu')
            cls_embed_norm = cls_embed.norm(p=2, dim=-1)
            assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))

            _dim = cls_embed.size(2)
            _prototypes = cls_embed.size(1)
            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
            cls_embed = torch.cat([
                cls_embed, back_token.repeat(_prototypes, 1)[None]
            ], dim=0)
            self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)
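
            # Logit temperature; 4.6052 = ln(100), i.e. CLIP's maximum
            # exponentiated logit scale.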
            if logit is None:
                logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            else:
                logit_scale = torch.tensor(logit, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)
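
            # Learnable label token (decoded alongside the mask tokens) and an
            # MLP projecting it into the classifier embedding space.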
            transformer_dim = self.mask_decoder.mask_tokens.weight.shape[1]
            self.label_token = nn.Embedding(1, transformer_dim)
            self.label_mlp = SAMMLP(transformer_dim, transformer_dim, _dim, 3)

        self.gen_box = gen_box
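
        # RoI path over FPN features; note that ``_dim`` is only defined when
        # ``with_label_token`` is set, so this path requires it.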
        if roi_extractor is not None:
            self.roi = MODELS.build(roi_extractor)
            self.roi_conv = nn.Sequential(*[
                ConvModule(in_channels=self.roi.out_channels, out_channels=_dim, kernel_size=1, bias=False)
            ])
        else:
            self.roi = None
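
        # Load pretrained weights for this head, and optionally a separate
        # checkpoint for the RoI projection conv.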
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            if roi_extractor is not None and load_roi_conv is not None:
                checkpoint_path = load_roi_conv['checkpoint']
                state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=load_roi_conv['prefix'])
                self.roi_conv.load_state_dict(state_dict, strict=True)
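
        # Optionally freeze the head: switch to eval mode and stop gradients.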
        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
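        """Weights are loaded in ``__init__``; only log the init config here."""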
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def forward_logit(self, cls_embd):
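        """Cosine-similarity logits: normalize the query embedding, score it
        against every class prototype, keep the best prototype per class, and
        scale by the exponentiated logit scale."""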
        cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
        cls_pred = cls_pred.max(-1).values
        cls_pred = self.logit_scale.exp() * cls_pred
        return cls_pred

    def predict_masks(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            fpn_feats: List[torch.Tensor],
            roi_list: Optional[List[torch.Tensor]],
            backbone_feature: torch.Tensor,
            backbone=None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Predicts masks. See 'forward' for more details."""
        num_instances = int(sparse_prompt_embeddings.size(0))
        output_tokens = torch.cat([
            self.label_token.weight,
            self.mask_decoder.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(num_instances, -1, -1)
        queries = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
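
        # Fuse the dense prompt embedding into the image features and tile the
        # positional encoding so each instance gets its own copy.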
        image_embeddings = image_embeddings + dense_prompt_embeddings
        pos_img = torch.repeat_interleave(image_pe, num_instances, dim=0)
        b, c, h, w = image_embeddings.shape
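
        # Run the two-way transformer; it returns the refined queries and the
        # attended image features.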
        queries, mask_feats = self.mask_decoder.transformer(image_embeddings, pos_img, queries)
        label_query = queries[:, 0, :]
        mask_embeds = queries[:, 1:(1 + self.mask_decoder.num_mask_tokens), :]
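
        # Upscale the image features; each mask token becomes a dynamic kernel
        # via its hypernetwork MLP, and masks are the kernel/feature dot products.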
        mask_feats = mask_feats.transpose(1, 2).view(b, c, h, w)
        mask_feats = self.mask_decoder.output_upscaling(mask_feats)
        mask_queries_list: List[torch.Tensor] = []
        for i in range(self.mask_decoder.num_mask_tokens):
            mask_queries_list.append(self.mask_decoder.output_hypernetworks_mlps[i](mask_embeds[:, i, :]))
        mask_queries = torch.stack(mask_queries_list, dim=1)
        b, c, h, w = mask_feats.shape
        masks = (mask_queries @ mask_feats.view(b, c, h * w)).view(b, -1, h, w)
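
        # Open-vocabulary classification of the predicted masks.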
        if self.with_label_token:
            cls_embed_list = []
            assert self.mask_decoder.num_mask_tokens == 1
            for i in range(self.mask_decoder.num_mask_tokens):
                cls_embed_list.append(self.label_mlp(label_query))
            cls_embed = torch.stack(cls_embed_list, dim=1)
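            # Derive boxes from the thresholded masks; the 4x factor maps
            # mask-resolution coordinates back to the input scale (SAM decodes
            # masks at 1/4 of the input resolution).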
            if self.gen_box:
                bboxes = mask2bbox(masks.sigmoid()[:, 0] > 0.5) * 4
                roi_list = bbox2roi([bboxes])
            roi_feats = self.roi(fpn_feats, roi_list)
            roi_feats = self.roi_conv(roi_feats)
            roi_feats = roi_feats.mean(dim=-1).mean(dim=-1)
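            # Optionally fuse CLIP RoI features, then broadcast the RoI feature
            # over the token dimension; ``+ 0 * cls_embed`` leaves the values
            # unchanged while keeping the label-token branch in the graph.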
            if self.roi_extractor_single:
                roi_feats_clip = self.roi_extractor_single(
                    backbone.get_clip_feature(backbone_feature[-1:]), roi_list
                )
                roi_feats_clip = backbone.forward_feat(roi_feats_clip)
                roi_feats = self.roi_merge_proj(torch.cat([roi_feats, roi_feats_clip], dim=-1))
            roi_feats = roi_feats[:, None] + 0 * cls_embed
            cls_pred = self.forward_logit(roi_feats)
        else:
            cls_pred = None
        return masks, None, cls_pred

    def forward(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            multi_mask_output: bool,
            data_samples=None,
            fpn_feats=None,
            backbone_feats=None,
            backbone=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
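        """Decode masks (and, if enabled, class logits) for every prompt on a
        single image; the image embedding is tiled so each prompt is decoded
        independently."""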
        num_prompts = len(sparse_prompt_embeddings)
        image_embeddings = torch.repeat_interleave(image_embeddings, num_prompts, dim=0)

        masks, _, cls_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            fpn_feats=fpn_feats,
            roi_list=None,
            backbone_feature=backbone_feats,
            backbone=backbone,
        )
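
        # SAM convention: token 0 is the single-mask output, the remaining
        # tokens are the multi-mask outputs.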
        if multi_mask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]

        return masks, None, cls_pred