# common/modules/semanticsam.py
# Semantic mask-decoder adapters for the see-through layer decomposition demo (HF Space).
from typing import List, Optional, Tuple
from torch import nn
import torch
from einops import rearrange
from .sam.modeling import Sam
from .sam.modeling.mask_decoder import MLP
from .sam import sam_model_registry
from .extend_sam import BaseExtendSam, BaseMaskDecoderAdapter, MaskDecoder


class SemMaskDecoderAdapter(BaseMaskDecoderAdapter):
def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
        super().__init__(sam_mask_decoder, fix)
self.class_num = class_num
self.is_hq = self.sam_mask_decoder.is_hq
self.num_mask_tokens = self.sam_mask_decoder.num_mask_tokens
transformer_dim = self.sam_mask_decoder.transformer_dim
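        # One hypernetwork MLP per semantic class; each produces the weights that
        # project the upscaled image embedding into that class's mask logits.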
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for _ in range(self.class_num)
]
)
if init_from_sam:
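            # Seed every class MLP from SAM's mask-token-1 hypernetwork MLP so
            # training starts from SAM's single-mask behavior.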
target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
for ii in range(class_num):
self.output_hypernetworks_mlps[ii].load_state_dict(target_sd)
del self.sam_mask_decoder.output_hypernetworks_mlps
if self.is_hq:
self.hf_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for _ in range(self.class_num)
]
)
if init_from_sam:
target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
for ii in range(class_num):
self.hf_mlps[ii].load_state_dict(target_sd)
del self.sam_mask_decoder.hf_mlp
        # Input cond tokens: cat[1 x iou token, 4 x original mask tokens, 1 x hf token],
        # so num_mask_tokens = 4 + 1 and the hf token output is hs[:, num_mask_tokens].
        self.hf_token_idx = self.num_mask_tokens

    def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool = True,
hq_token_only: bool = False,
        interm_embeddings: Optional[torch.Tensor] = None,
        mask_scale: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
hq_features = None
# token processing
if self.is_hq:
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
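            # SAM-HQ: fuse the decoder image embedding with compressed early ViT features.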
hq_features = self.sam_mask_decoder.embedding_encoder(image_embeddings) + self.sam_mask_decoder.compress_vit_feat(vit_features)
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sam_mask_decoder.mask_tokens.weight, self.sam_mask_decoder.hf_token.weight]
else:
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sam_mask_decoder.mask_tokens.weight]
output_tokens = torch.cat(output_tokens, dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
        # tokens: (batch size, model-preserved tokens (1 iou, 4 mask, 1 hf) + user prompts, token dim)
        # Expand per-image data along the batch dimension to be per-mask; multiple user
        # prompts for the same image are split along the batch dimension.
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
src = src.to(dtype=pos_src.dtype)
tokens = tokens.to(dtype=pos_src.dtype)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Decode tokens, mask tokens -> iou preds, src tokens (input image tokens) to masks
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
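        # upscaled_embedding_sam: 4x-upscaled image features whose channel dim
        # (transformer_dim // 8) matches the hypernetwork MLP outputs.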
hyper_in_list: List[torch.Tensor] = []
hyper_hq_list: List[torch.Tensor] = []
for i in range(self.class_num):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, mask_scale, :]))
if self.is_hq:
hyper_hq_list.append(self.hf_mlps[i](hs[:, self.hf_token_idx, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding_sam.shape
masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
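        # masks_sam: (B, class_num, H', W'), one low-res logit map per semantic class.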
iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)
if self.is_hq:
hyper_hq = torch.stack(hyper_hq_list, dim=1)
upscaled_embedding_hq = self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
if hq_token_only:
masks = masks_hq
else:
masks = masks_sam + masks_hq
else:
masks = masks_sam
        return masks, iou_pred


class SemMaskDecoderAdapterTokenVariant(BaseMaskDecoderAdapter):
def __init__(self, sam_mask_decoder: MaskDecoder, fix=False, class_num=20, init_from_sam=True):
        super().__init__(sam_mask_decoder, fix)
self.class_num = class_num
self.is_hq = self.sam_mask_decoder.is_hq
transformer_dim = self.sam_mask_decoder.transformer_dim
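        # Learned per-class mask tokens replace SAM's generic mask tokens; a single
        # shared hypernetwork MLP decodes all of them.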
self.sem_mask_tokens = nn.Embedding(class_num, transformer_dim)
self.output_hypernetworks_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
if init_from_sam:
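            # Seed the shared hypernetwork MLP from SAM's mask-token-1 MLP.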
target_sd = self.sam_mask_decoder.output_hypernetworks_mlps[1].state_dict()
self.output_hypernetworks_mlp.load_state_dict(target_sd)
            # Seed every class token with a copy of SAM mask token 1.
            target_sd = self.sam_mask_decoder.mask_tokens.state_dict()
            target_sd = {'weight': target_sd['weight'][[1]].repeat(class_num, 1)}
            self.sem_mask_tokens.load_state_dict(target_sd)
del self.sam_mask_decoder.mask_tokens
del self.sam_mask_decoder.output_hypernetworks_mlps
if self.is_hq:
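            # Per-class HQ tokens plus a shared HQ MLP mirror the semantic tokens above.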
self.hq_mask_tokens = nn.Embedding(class_num, transformer_dim)
self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
if init_from_sam:
target_sd = self.sam_mask_decoder.hf_mlp.state_dict()
self.hf_mlp.load_state_dict(target_sd)
target_sd = self.sam_mask_decoder.hf_token.state_dict()
target_sd = {'weight': target_sd['weight'].repeat(class_num, 1)}
self.hq_mask_tokens.load_state_dict(target_sd)
del self.sam_mask_decoder.hf_mlp
del self.sam_mask_decoder.hf_token
        # Input cond tokens: cat[1 x iou token, class_num x sem mask tokens,
        # class_num x hq mask tokens (HQ variant only)].

    def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool = True,
hq_token_only: bool = False,
        interm_embeddings: Optional[torch.Tensor] = None,
        mask_scale: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
"""
hq_features = None
# token processing
if self.is_hq:
vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
hq_features = self.sam_mask_decoder.embedding_encoder(image_embeddings) + self.sam_mask_decoder.compress_vit_feat(vit_features)
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sem_mask_tokens.weight, self.hq_mask_tokens.weight]
else:
output_tokens = [self.sam_mask_decoder.iou_token.weight, self.sem_mask_tokens.weight]
output_tokens = torch.cat(output_tokens, dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
        # tokens: (batch size, model-preserved tokens (1 iou, class_num sem, class_num hq) + user prompts, token dim)
        # Expand per-image data along the batch dimension to be per-mask; multiple user
        # prompts for the same image are split along the batch dimension.
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
src = src.to(dtype=pos_src.dtype)
tokens = tokens.to(dtype=pos_src.dtype)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.sam_mask_decoder.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.class_num), :]
# Decode tokens, mask tokens -> iou preds, src tokens (input image tokens) to masks
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding_sam = self.sam_mask_decoder.output_upscaling(src)
hyper_in = self.output_hypernetworks_mlp(rearrange(mask_tokens_out, 'b c d -> (b c) d'))
hyper_in = rearrange(hyper_in, '(b c) d -> b c d', b=b)
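        # hyper_in: (B, class_num, transformer_dim // 8); the shared MLP is applied
        # to every class token in parallel via the (b c) flatten.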
        if self.is_hq:
            # HQ token outputs occupy hs indices [1 + class_num, 1 + 2 * class_num).
            hyper_hq = self.hf_mlp(rearrange(hs[:, 1 + self.class_num: (1 + 2 * self.class_num), :], 'b c d -> (b c) d'))
            hyper_hq = rearrange(hyper_hq, '(b c) d -> b c d', b=b)
b, c, h, w = upscaled_embedding_sam.shape
masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
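        # masks_sam: (B, class_num, H', W'), one low-res logit map per semantic class.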
iou_pred = self.sam_mask_decoder.iou_prediction_head(iou_token_out)
if self.is_hq:
upscaled_embedding_hq = self.sam_mask_decoder.embedding_maskfeature(upscaled_embedding_sam) + hq_features
masks_hq = (hyper_hq @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
if hq_token_only:
masks = masks_hq
else:
masks = masks_sam + masks_hq
else:
masks = masks_sam
        return masks, iou_pred


class SemanticSam(BaseExtendSam):
def __init__(self,
class_num,
sam: Sam = None,
fix_img_en=False,
fix_prompt_en=False,
fix_mask_de=False,
model_type: str = 'h_hq',
mask_decoder='mlp_variant',
**kwargs):
init_from_sam = sam is not None
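        # Copy SAM decoder weights into the new heads only when a pretrained Sam is passed in.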
if sam is None:
build_sam = sam_model_registry[model_type]['build']
sam = build_sam()
super().__init__(sam=sam, fix_img_en=fix_img_en, fix_mask_de=fix_mask_de, fix_prompt_en=fix_prompt_en)
sam_mask_decoder = self.mask_adapter.sam_mask_decoder
del self.mask_adapter
if mask_decoder == 'mlp_variant':
self.mask_adapter = SemMaskDecoderAdapter(sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de, class_num=class_num, init_from_sam=init_from_sam)
elif mask_decoder == 'token_variant':
self.mask_adapter = SemMaskDecoderAdapterTokenVariant(sam_mask_decoder=sam_mask_decoder, fix=fix_mask_de, class_num=class_num, init_from_sam=init_from_sam)
else:
            raise ValueError(f'Invalid mask decoder: {mask_decoder}')
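

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the demo pipeline. It assumes
    # sam_model_registry['h_hq']['build']() constructs an HQ SAM without a
    # checkpoint path, and that BaseExtendSam.forward accepts a batched image
    # tensor; the call signature below is a hypothetical illustration.
    model = SemanticSam(class_num=20, model_type='h_hq', mask_decoder='token_variant')
    image = torch.randn(1, 3, 1024, 1024)  # SAM's usual input resolution (assumed)
    masks, iou_pred = model(image)  # hypothetical forward signature
    print(masks.shape, iou_pred.shape)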