from typing import List, Tuple

import torch
from einops import rearrange
from torch import nn

from .extend_sam import BaseExtendSam, BaseMaskDecoderAdapter, MaskDecoder
from .sam import sam_model_registry
from .sam.modeling import Sam
from .sam.modeling.mask_decoder import MLP


class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
    """Adapts a SAM/HQ-SAM mask decoder to semantic segmentation by replacing
    SAM's per-mask hypernetwork MLPs with one MLP per semantic class.

    All class MLPs read the SAME mask token (selected by ``mask_scale``); only
    the hypernetworks differ per class, so the decoder emits ``class_num`` mask
    logits instead of SAM's generic masks.
    """

    def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
        """
        Arguments:
          sam_mask_decoder (MaskDecoder): pretrained SAM (or HQ-SAM) decoder to wrap.
          fix (bool): passed to the base adapter; presumably freezes the wrapped
            decoder's weights — confirm against BaseMaskDecoderAdapter.
          class_num (int): number of semantic classes (one hypernetwork MLP each).
          init_from_sam (bool): copy SAM's pretrained MLP weights into every
            per-class MLP instead of using random initialization.
        """
        super().__init__(sam_mask_decoder, fix)
        self.class_num = class_num
        self.is_hq = self.sam_mask_decoder.is_hq
        self.num_mask_tokens = self.sam_mask_decoder.num_mask_tokens
        transformer_dim = self.sam_mask_decoder.transformer_dim

        # One hypernetwork MLP per semantic class (SAM has one per mask token).
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for _ in range(self.class_num)
            ]
        )
        if init_from_sam:
            # Seed every class MLP from SAM's hypernetwork for mask token 1.
            target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
            for ii in range(class_num):
                self.output_hypernetworks_mlps[ii].load_state_dict(target_sd)
        # The originals are superseded by the per-class MLPs; drop them so they
        # are neither trained nor saved.
        del self.sam_mask_decoder.output_hypernetworks_mlps

        if self.is_hq:
            # HQ-SAM: mirror the per-class replacement for the HQ branch's MLP.
            self.hf_mlps = nn.ModuleList(
                [
                    MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                    for _ in range(self.class_num)
                ]
            )
            if init_from_sam:
                target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
                for ii in range(class_num):
                    self.hf_mlps[ii].load_state_dict(target_sd)
            del self.sam_mask_decoder.hf_mlp

        # Decoder token layout: [1 iou token, 4 original mask tokens, 1 hf token];
        # num_mask_tokens = 4 + 1, so the hf token sits right after the mask tokens.
        self.hf_token_idx = self.num_mask_tokens

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool = True,
        hq_token_only: bool = False,
        interm_embeddings: torch.Tensor = None,
        mask_scale=1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict per-class masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): unused here; kept for interface compatibility
            with SAM's decoder (the output always has class_num channels).
          hq_token_only (bool): if True (HQ decoders only), return only the HQ
            masks instead of SAM masks + HQ masks.
          interm_embeddings (torch.Tensor): required when is_hq — early-layer ViT
            features, after the 1st global attention block in the ViT.
          mask_scale (int): which of SAM's mask tokens all class hypernetworks read.

        Returns:
          torch.Tensor: batched predicted masks, class_num channels per image
          torch.Tensor: batched predictions of mask quality
        """
        hq_features = None
        # Assemble the decoder's learned tokens (HQ adds the hf token and a
        # high-quality feature map fused from ViT interm features).
        if self.is_hq:
            # early-layer ViT feature, after 1st global attention block in ViT
            vit_features = interm_embeddings[0].permute(0, 3, 1, 2)
            hq_features = (
                self.sam_mask_decoder.embedding_encoder(image_embeddings)
                + self.sam_mask_decoder.compress_vit_feat(vit_features)
            )
            output_tokens = [
                self.sam_mask_decoder.iou_token.weight,
                self.sam_mask_decoder.mask_tokens.weight,
                self.sam_mask_decoder.hf_token.weight,
            ]
        else:
            output_tokens = [
                self.sam_mask_decoder.iou_token.weight,
                self.sam_mask_decoder.mask_tokens.weight,
            ]
        output_tokens = torch.cat(output_tokens, dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        # tokens: (batch, preserved tokens (iou*1, mask*4, hf*1) + user prompts, token dim)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in the batch direction to be per-mask: multiple
        # user prompts for the same image are divided along the batch channel.
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        src = src.to(dtype=pos_src.dtype)
        tokens = tokens.to(dtype=pos_src.dtype)
        b, c, h, w = src.shape

        # Run the two-way transformer; hs holds the updated tokens, src the
        # updated image embedding (flattened).
        hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale the image embedding and predict masks from the mask tokens.
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        hyper_hq_list: List[torch.Tensor] = []
        for i in range(self.class_num):
            # Every class MLP reads the same mask token (index mask_scale).
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :]))
            if self.is_hq:
                hyper_hq_list.append(self.hf_mlps[i](hs[:, self.hf_token_idx, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding_sam.shape
        masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)

        if self.is_hq:
            hyper_hq = torch.stack(hyper_hq_list, dim=1)
            upscaled_embedding_hq = (
                self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
            )
            masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
            if hq_token_only:
                masks = masks_hq
            else:
                masks = masks_sam + masks_hq
        else:
            masks = masks_sam

        return masks, iou_pred


class SemMaskDecoderAdapterTokenVariant(BaseMaskDecoderAdapter):
    """Semantic adapter variant that replaces SAM's mask tokens themselves:
    one learned token per class (plus one HQ token per class for HQ decoders),
    decoded by a single shared hypernetwork MLP per branch.
    """

    def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
        """
        Arguments:
          sam_mask_decoder (MaskDecoder): pretrained SAM (or HQ-SAM) decoder to wrap.
          fix (bool): passed to the base adapter; presumably freezes the wrapped
            decoder's weights — confirm against BaseMaskDecoderAdapter.
          class_num (int): number of semantic classes (one mask token each).
          init_from_sam (bool): initialize the new tokens/MLPs from SAM's
            pretrained weights (mask token 1 / hf token, repeated per class).
        """
        super().__init__(sam_mask_decoder, fix)
        self.class_num = class_num
        self.is_hq = self.sam_mask_decoder.is_hq
        transformer_dim = self.sam_mask_decoder.transformer_dim

        # One learned token per class; a single shared hypernetwork MLP.
        self.sem_mask_tokens = nn.Embedding(class_num, transformer_dim)
        self.output_hypernetworks_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
        if init_from_sam:
            target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
            self.output_hypernetworks_mlp.load_state_dict(target_sd)
            # Replicate SAM's mask token 1 as the initial value of every class token.
            target_sd = self.sam_mask_decoder.mask_tokens.state_dict()
            target_sd = {'weight': target_sd['weight'][[1]].repeat(class_num, 1)}
            self.sem_mask_tokens.load_state_dict(target_sd)
        # The originals are superseded; drop them so they are neither trained nor saved.
        del self.sam_mask_decoder.mask_tokens
        del self.sam_mask_decoder.output_hypernetworks_mlps

        if self.is_hq:
            # HQ branch: per-class HQ tokens and a single shared HQ MLP.
            self.hq_mask_tokens = nn.Embedding(class_num, transformer_dim)
            self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
            if init_from_sam:
                target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
                self.hf_mlp.load_state_dict(target_sd)
                target_sd = self.sam_mask_decoder.hf_token.state_dict()
                target_sd = {'weight': target_sd['weight'].repeat(class_num, 1)}
                self.hq_mask_tokens.load_state_dict(target_sd)
            del self.sam_mask_decoder.hf_mlp
            del self.sam_mask_decoder.hf_token

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool = True,
        hq_token_only: bool = False,
        interm_embeddings: torch.Tensor = None,
        mask_scale=1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict per-class masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): unused here; kept for interface compatibility
            with SAM's decoder (the output always has class_num channels).
          hq_token_only (bool): if True (HQ decoders only), return only the HQ
            masks instead of SAM masks + HQ masks.
          interm_embeddings (torch.Tensor): required when is_hq — early-layer ViT
            features, after the 1st global attention block in the ViT.
          mask_scale (int): unused in this variant; kept for interface parity
            with SemMaskDecoderAdapter.

        Returns:
          torch.Tensor: batched predicted masks, class_num channels per image
          torch.Tensor: batched predictions of mask quality
        """
        hq_features = None
        # Token layout here: [1 iou token, class_num sem tokens(, class_num hq tokens)].
        if self.is_hq:
            # early-layer ViT feature, after 1st global attention block in ViT
            vit_features = interm_embeddings[0].permute(0, 3, 1, 2)
            hq_features = (
                self.sam_mask_decoder.embedding_encoder(image_embeddings)
                + self.sam_mask_decoder.compress_vit_feat(vit_features)
            )
            output_tokens = [
                self.sam_mask_decoder.iou_token.weight,
                self.sem_mask_tokens.weight,
                self.hq_mask_tokens.weight,
            ]
        else:
            output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sem_mask_tokens.weight]
        output_tokens = torch.cat(output_tokens, dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in the batch direction to be per-mask: multiple
        # user prompts for the same image are divided along the batch channel.
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        src = src.to(dtype=pos_src.dtype)
        tokens = tokens.to(dtype=pos_src.dtype)
        b, c, h, w = src.shape

        # Run the two-way transformer.
        hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.class_num), :]

        # Upscale the image embedding and predict masks from the class tokens.
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
        # Shared MLP applied to all class tokens at once: flatten (batch, class)
        # into one axis, run the MLP, then restore the batch axis.
        hyper_in = self.output_hypernetworks_mlp(rearrange(mask_tokens_out, 'b c d -> (b c) d'))
        hyper_in = rearrange(hyper_in, '(b c) d -> b c d', b=b)
        if self.is_hq:
            hyper_hq = self.hf_mlp(
                rearrange(hs[:, 1 + self.class_num : (1 + 2 * self.class_num), :], 'b c d -> (b c) d')
            )
            hyper_hq = rearrange(hyper_hq, '(b c) d -> b c d', b=b)
        b, c, h, w = upscaled_embedding_sam.shape
        masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
        iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)

        if self.is_hq:
            upscaled_embedding_hq = (
                self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
            )
            masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
            if hq_token_only:
                masks = masks_hq
            else:
                masks = masks_sam + masks_hq
        else:
            masks = masks_sam

        return masks, iou_pred


class SemanticSam(BaseExtendSam):
    """Extended SAM whose mask-decoder adapter is swapped for a semantic
    (per-class) decoder adapter — either the per-class-MLP variant or the
    per-class-token variant.
    """

    def __init__(self, class_num, sam: Sam = None, fix_img_en=False, fix_prompt_en=False,
                 fix_mask_de=False, model_type: str = 'h_hq', mask_decoder='mlp_variant', **kwargs):
        """
        Arguments:
          class_num (int): number of semantic classes.
          sam (Sam): an already-built SAM model; if None, one is built from the
            registry using model_type. The adapter copies SAM's pretrained
            decoder weights only when a model was passed in.
          fix_img_en / fix_prompt_en / fix_mask_de (bool): freeze flags forwarded
            to the base class / adapter.
          model_type (str): registry key used when sam is None.
          mask_decoder (str): 'mlp_variant' or 'token_variant'.

        Raises:
          ValueError: if mask_decoder names an unknown variant.
        """
        init_from_sam = sam is not None
        if sam is None:
            build_sam = sam_model_registry[model_type]['build']
            sam = build_sam()
        super().__init__(sam=sam, fix_img_en=fix_img_en, fix_mask_de=fix_mask_de,
                         fix_prompt_en=fix_prompt_en)
        # Replace the base adapter with a semantic one that wraps the same
        # underlying SAM mask decoder.
        sam_mask_decoder = self.mask_adapter.sam_mask_decoder
        del self.mask_adapter
        if mask_decoder == 'mlp_variant':
            self.mask_adapter = SemMaskDecoderAdapter(
                sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de,
                class_num=class_num, init_from_sam=init_from_sam)
        elif mask_decoder == 'token_variant':
            self.mask_adapter = SemMaskDecoderAdapterTokenVariant(
                sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de,
                class_num=class_num, init_from_sam=init_from_sam)
        else:
            # ValueError (was bare Exception): still caught by callers catching Exception.
            raise ValueError(f'Invalid mask decoder: {mask_decoder}')