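"""SAM2-based mask head used by the GECO2 demo (models/sam_mask.py).

MaskProcessor wraps a pretrained SAM2 prompt encoder and mask decoder: given
backbone features and detector box predictions, it decodes a mask for every
box and returns the masks, their predicted IoU scores, and mask-tightened
bounding boxes.
"""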
import torch
import torch.nn.functional as F
from torch import nn
from sam2.sam2.modeling.sam.mask_decoder import MaskDecoder
from sam2.sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.sam2.modeling.sam.transformer import TwoWayTransformer
class MaskProcessor(nn.Module):
def __init__(self, hidden_dim, image_size, reduction, **kwargs):
super().__init__()
self.sam_prompt_embed_dim = hidden_dim
self.reduction = reduction
self.sam_image_embedding_size = image_size // self.reduction
self.image_size = image_size
self.prompt_encoder_sam = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
self.sam_image_embedding_size,
self.sam_image_embedding_size,
),
input_image_size=(self.image_size, self.image_size),
mask_in_chans=16,
)
self.mask_decoder = MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=self.sam_prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=self.sam_prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
use_high_res_features=True,
iou_prediction_use_sigmoid=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
use_multimask_token_for_obj_ptr=True,
)
self.num_feature_levels = 3
# Spatial dim for backbone feature maps
self._bb_feat_sizes = [
(256, 256),
(128, 128),
(64, 64),
]
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # TODO: replace this ad-hoc checkpoint loading.
        # Load the pretrained SAM2 Hiera-B+ weights and copy over the parts used
        # here: the mask decoder, the prompt encoder, and the no_mem_embed parameter.
        checkpoint = torch.hub.load_state_dict_from_url(
            'https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt',
            map_location="cpu"
        )['model']
        state_dict = {k.replace("mask_decoder.", "").replace("sam_", ""): v
                      for k, v in checkpoint.items() if "mask_decoder" in k}
        self.mask_decoder.load_state_dict(state_dict)
        state_dict = {k.replace("prompt_encoder.", "").replace("sam_", ""): v
                      for k, v in checkpoint.items() if "prompt_encoder" in k}
        self.prompt_encoder_sam.load_state_dict(state_dict)
        state_dict = {k: v for k, v in checkpoint.items() if "no_mem_embed" in k}
        self.load_state_dict(state_dict, strict=False)
def _prepare_backbone_features(self, backbone_out):
"""Prepare and flatten visual features."""
backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels:]
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels:]
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
# flatten NxCxHxW to HWxNxC
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
    def forward_feats(self, feats: dict):
        """Project and reshape the backbone features into the per-level image
        embeddings expected by the SAM2 mask decoder."""
        # Project the level-0 and level-1 FPN features with the decoder's
        # conv_s0 / conv_s1 once per image.
feats["backbone_fpn"][0] = self.mask_decoder.conv_s0(
feats["backbone_fpn"][0]
)
feats["backbone_fpn"][1] = self.mask_decoder.conv_s1(
feats["backbone_fpn"][1]
)
_, vision_feats, _, _ = self._prepare_backbone_features(feats)
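        # Add SAM2's learned no-memory embedding to the lowest-resolution features
        # (image-only inference, no memory bank).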
vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
bs = vision_feats[0].shape[1]
feats = [
feat.permute(1, 2, 0).view(bs, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
return feats
def forward(self, features_orig, outputs):
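        """Decode a segmentation mask for every predicted box in the batch.

        Args:
            features_orig: batched backbone output dict with 'vision_features',
                'vision_pos_enc', and 'backbone_fpn'.
            outputs: one dict per image; 'pred_boxes' holds boxes in normalised
                coordinates with shape (1, num_boxes, 4).

        Returns:
            Tuple (batch_masks, batch_iou, batch_bboxes), one entry per image.
        """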
batch_masks = []
batch_iou = []
batch_bboxes = []
for img_idx in range(len(outputs)):
            only_score = False
            # Optionally skip mask decoding for very dense predictions and return
            # only the detector boxes with their scores, e.g.:
            # if len(outputs[img_idx]['pred_boxes'][0]) > 800:
            #     only_score = True
            # Backbone output dict:
            #   'vision_features': (bs, c, h, w)
            #   'vision_pos_enc':  3 levels of (bs, c, h, w)
            #   'backbone_fpn':    3 levels of (bs, c, h, w)
            # Slice out the features of image img_idx from the batch.
features = {
'vision_features': features_orig['vision_features'][img_idx].unsqueeze(0),
'vision_pos_enc': [x[img_idx].unsqueeze(0) for x in features_orig['vision_pos_enc']],
'backbone_fpn': [x[img_idx].unsqueeze(0) for x in features_orig['backbone_fpn']],
}
features = self.forward_feats(features)
            step = 50  # decode the predicted boxes in chunks to limit the per-call batch size
low_res_masks = []
iou_predictions = []
corrected_bboxes_ = []
masks_ = []
for box_i in range(step, len(outputs[img_idx]['pred_boxes'][0]) + step, step):
box = outputs[img_idx]['pred_boxes'][0][(box_i - step):box_i] * self.image_size
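                # Encode each box prompt as its two corner points; SAM2 marks them
                # with point labels 2 (top-left) and 3 (bottom-right).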
box_coords = box.reshape(-1, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=box.device)
box_labels = box_labels.repeat(box.size(0), 1)
sparse_embeddings, dense_embeddings = self.prompt_encoder_sam(
                points=(box_coords, box_labels),
boxes=None,
masks=None,
)
low_res_masks_, iou_predictions_, _, _ = self.mask_decoder(
image_embeddings=features[-1],
image_pe=self.prompt_encoder_sam.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=True,
high_res_features=features[:-1],
)
low_res_masks.append(low_res_masks_)
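                # Keep the IoU prediction for mask index 2, matching the mask kept below.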
iou_predictions.append(iou_predictions_[:, 2])
                # Upsample the low-resolution mask logits to the input resolution and binarise.
                masks = F.interpolate(
                    low_res_masks_,
                    (self.image_size, self.image_size),
                    mode="bilinear",
                    align_corners=False,
                )
                masks = masks > 0
                # Allocate refined boxes on the same device as the masks.
                corrected_bboxes = torch.zeros((masks.shape[0], 4), dtype=torch.float, device=masks.device)
                # Keep the mask at index 2 of the multimask output.
                # TODO: select the best mask per box instead of a fixed index.
                masks = masks[:, 2]
for index, mask in enumerate(masks):
y, x = torch.where(mask != 0)
if y.shape[0] > 0 and x.shape[0] > 0:
corrected_bboxes[index, 0] = torch.min(x)
corrected_bboxes[index, 1] = torch.min(y)
corrected_bboxes[index, 2] = torch.max(x)
corrected_bboxes[index, 3] = torch.max(y)
                masks_.append(masks)
corrected_bboxes_.append(corrected_bboxes)
            if only_score:
                # Return the detector boxes and IoU scores without mask post-processing.
                batch_masks.append([])
                batch_bboxes.append(outputs[img_idx]['pred_boxes'].squeeze() * self.image_size)
                batch_iou.append(torch.cat(iou_predictions).unsqueeze(0))
                continue
            if len(corrected_bboxes_) > 0:
                batch_masks.append(masks_)
                batch_bboxes.append(torch.cat(corrected_bboxes_))
                batch_iou.append(torch.cat(iou_predictions).unsqueeze(0))
else:
batch_masks.append([])
batch_bboxes.append(torch.tensor([]).to(features[0].device))
batch_iou.append(torch.tensor([]).to(features[0].device))
        # Images with no masks get a single empty placeholder mask.
        batch_masks = [
            torch.cat(masks) if len(masks) > 0 else torch.zeros((1, self.image_size, self.image_size))
            for masks in batch_masks
        ]
return batch_masks, batch_iou, batch_bboxes
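

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original demo. The argument
    # values and feature shapes below (hidden_dim=256, image_size=1024,
    # reduction=16, three FPN levels with 256 channels at 256/128/64 spatial
    # resolution, and 'pred_boxes' as normalised xyxy boxes) are assumptions
    # inferred from forward()/forward_feats(), not values taken from GECO2.
    # Instantiating MaskProcessor downloads the SAM2 Hiera-B+ checkpoint.
    torch.manual_seed(0)
    model = MaskProcessor(hidden_dim=256, image_size=1024, reduction=16).eval()

    bs = 1
    features = {
        "vision_features": torch.randn(bs, 256, 64, 64),
        "vision_pos_enc": [torch.randn(bs, 256, s, s) for s in (256, 128, 64)],
        "backbone_fpn": [torch.randn(bs, 256, s, s) for s in (256, 128, 64)],
    }
    outputs = [{"pred_boxes": torch.tensor([[[0.10, 0.10, 0.40, 0.40],
                                             [0.50, 0.50, 0.90, 0.90]]])}]

    with torch.no_grad():
        masks, ious, boxes = model(features, outputs)
    print(masks[0].shape, ious[0].shape, boxes[0].shape)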