import torch
import torch.nn.functional as F
from torch import nn

# The doubled "sam2.sam2" path assumes the sam2 repo is vendored inside a
# top-level sam2/ directory, as in this Space.
from sam2.sam2.modeling.sam.mask_decoder import MaskDecoder
from sam2.sam2.modeling.sam.prompt_encoder import PromptEncoder
from sam2.sam2.modeling.sam.transformer import TwoWayTransformer


class MaskProcessor(nn.Module):
    """Refine detector boxes into masks using a pretrained SAM2 prompt
    encoder and mask decoder."""

    def __init__(self, hidden_dim, image_size, reduction, **kwargs):
        super().__init__()
        self.sam_prompt_embed_dim = hidden_dim
        self.reduction = reduction
        # Side length of the stride-`reduction` image embedding,
        # e.g. 1024 // 16 = 64.
        self.sam_image_embedding_size = image_size // self.reduction
        self.image_size = image_size
        self.prompt_encoder_sam = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=True,
            iou_prediction_use_sigmoid=True,
            pred_obj_scores=True,
            pred_obj_scores_mlp=True,
            use_multimask_token_for_obj_ptr=True,
        )
        self.num_feature_levels = 3
        # Spatial dims of the backbone feature maps for a 1024x1024 input
        # (strides 4, 8, and 16).
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]
        # Learned "no memory" embedding added to the lowest-resolution
        # features, as in SAM2's single-image inference path.
        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim))
        # TODO: change loading, this is ugly
        checkpoint = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
            map_location="cpu",
        )["model"]
        # Strip the "sam_mask_decoder." / "sam_prompt_encoder." prefixes so the
        # checkpoint keys match the local submodules, e.g.
        # "sam_mask_decoder.transformer.*" -> "transformer.*".
        state_dict = {
            k.replace("mask_decoder.", "").replace("sam_", ""): v
            for k, v in checkpoint.items()
            if "mask_decoder" in k
        }
        self.mask_decoder.load_state_dict(state_dict)
        state_dict = {
            k.replace("prompt_encoder.", "").replace("sam_", ""): v
            for k, v in checkpoint.items()
            if "prompt_encoder" in k
        }
        self.prompt_encoder_sam.load_state_dict(state_dict)
        # Load the pretrained no_mem_embed into this module (strict=False
        # because the decoder and prompt encoder were already loaded above).
        state_dict = {k: v for k, v in checkpoint.items() if "no_mem_embed" in k}
        self.load_state_dict(state_dict, strict=False)

    def _prepare_backbone_features(self, backbone_out):
        """Prepare and flatten visual features."""
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels:]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels:]

        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
        # flatten NxCxHxW to HWxNxC
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
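
    # A minimal shape check for _prepare_backbone_features (a sketch; the
    # dummy tensors stand in for real SAM2 backbone outputs of one 1024x1024
    # image):
    #
    #     dummy = {
    #         "backbone_fpn": [torch.zeros(1, 256, s, s) for s in (256, 128, 64)],
    #         "vision_pos_enc": [torch.zeros(1, 256, s, s) for s in (256, 128, 64)],
    #     }
    #     _, feats, pos, sizes = self._prepare_backbone_features(dummy)
    #     # feats[i] and pos[i]: (H_i * W_i, 1, 256)
    #     # sizes == [(256, 256), (128, 128), (64, 64)]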

    def forward_feats(self, feats: dict):
        """Get the image features for the input batch."""
        # Precompute projected level 0 and level 1 features in the SAM decoder
        # to avoid running these convs again on every prompt.
        feats["backbone_fpn"][0] = self.mask_decoder.conv_s0(feats["backbone_fpn"][0])
        feats["backbone_fpn"][1] = self.mask_decoder.conv_s1(feats["backbone_fpn"][1])

        _, vision_feats, _, _ = self._prepare_backbone_features(feats)
        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        # Reshape the flattened (HW, B, C) features back to (B, C, H, W).
        bs = vision_feats[0].shape[1]
        feats = [
            feat.permute(1, 2, 0).view(bs, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        return feats
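
    # Resulting pyramid for one 1024x1024 image (a sketch; the 32/64 channel
    # widths assume SAM2's high-res feature convs, conv_s0 -> dim // 8 and
    # conv_s1 -> dim // 4, with dim = 256):
    #
    #     feats[0]: (B, 32, 256, 256)   stride-4 high-res features
    #     feats[1]: (B, 64, 128, 128)   stride-8 high-res features
    #     feats[2]: (B, 256, 64, 64)    stride-16 embedding for the decoder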

    def forward(self, features_orig, outputs):
        batch_masks = []
        batch_iou = []
        batch_bboxes = []
        for img_idx in range(len(outputs)):
            only_score = False
            # Optional fast path (disabled): with very many boxes, skip mask
            # decoding and keep the detector's boxes and scores as-is.
            # if len(outputs[img_idx]['pred_boxes'][0]) > 800:
            #     only_score = True
            #     batch_masks.append([])  # masks
            #     batch_bboxes.append(outputs[img_idx]['pred_boxes'].squeeze() * self.image_size)
            #     batch_iou.append(outputs[img_idx]['box_v'])
            #     continue

            # features_orig: dict with 'vision_features' (bs, c, h, w) and
            # 'vision_pos_enc' / 'backbone_fpn' (3 levels of (bs, c, h, w)).
            # Build a new dict holding only sample img_idx of the batch.
            features = {
                'vision_features': features_orig['vision_features'][img_idx].unsqueeze(0),
                'vision_pos_enc': [x[img_idx].unsqueeze(0) for x in features_orig['vision_pos_enc']],
                'backbone_fpn': [x[img_idx].unsqueeze(0) for x in features_orig['backbone_fpn']],
            }
            features = self.forward_feats(features)
            # Decode the boxes in chunks of 50 to bound peak memory.
            step = 50
            iou_predictions = []
            corrected_bboxes_ = []
            masks_ = []
            for box_i in range(step, len(outputs[img_idx]['pred_boxes'][0]) + step, step):
                box = outputs[img_idx]['pred_boxes'][0][(box_i - step):box_i] * self.image_size
                # Prompt SAM with each (x1, y1, x2, y2) box as two corner
                # points, using SAM's labels 2 (top-left) and 3 (bottom-right).
                box_coords = box.reshape(-1, 2, 2)
                box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=box.device)
                box_labels = box_labels.repeat(box.size(0), 1)
                sparse_embeddings, dense_embeddings = self.prompt_encoder_sam(
                    points=(box_coords, box_labels),
                    boxes=None,
                    masks=None,
                )
                low_res_masks_, iou_predictions_, _, _ = self.mask_decoder(
                    image_embeddings=features[-1],
                    image_pe=self.prompt_encoder_sam.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=True,
                    high_res_features=features[:-1],
                )
                # Keep only the IoU prediction for mask index 2 (see the TODO
                # below on mask selection).
                iou_predictions.append(iou_predictions_[:, 2])
                masks = F.interpolate(
                    low_res_masks_,
                    (self.image_size, self.image_size),
                    mode="bilinear",
                    align_corners=False,
                )
                masks = masks > 0
                corrected_bboxes = torch.zeros(
                    (masks.shape[0], 4), dtype=torch.float, device=masks.device
                )
                # TODO: select the best mask per box (e.g. by predicted IoU)
                # instead of hard-coding multimask output index 2.
                masks = masks[:, 2]
                # Tighten each box to the bounding box of its predicted mask.
                for index, mask in enumerate(masks):
                    y, x = torch.where(mask != 0)
                    if y.shape[0] > 0 and x.shape[0] > 0:
                        corrected_bboxes[index, 0] = torch.min(x)
                        corrected_bboxes[index, 1] = torch.min(y)
                        corrected_bboxes[index, 2] = torch.max(x)
                        corrected_bboxes[index, 3] = torch.max(y)
                masks_.append(masks)
                corrected_bboxes_.append(corrected_bboxes)
            if only_score:
                batch_masks.append([])  # masks
                batch_bboxes.append(outputs[img_idx]['pred_boxes'].squeeze() * self.image_size)
                batch_iou.append(torch.cat(iou_predictions).unsqueeze(0))
                continue  # skip the mask-based results below
            if len(corrected_bboxes_) > 0:
                batch_masks.append(masks_)  # masks
                batch_bboxes.append(torch.cat(corrected_bboxes_))
                batch_iou.append(torch.cat(iou_predictions).unsqueeze(0))
            else:
                # No boxes for this image: emit empty placeholders.
                batch_masks.append([])
                batch_bboxes.append(torch.tensor([]).to(features[0].device))
                batch_iou.append(torch.tensor([]).to(features[0].device))
        # Concatenate each image's mask chunks; use an all-zero mask when an
        # image produced none.
        batch_masks = [
            torch.cat(masks) if len(masks) > 0 else torch.zeros((1, self.image_size, self.image_size))
            for masks in batch_masks
        ]
        return batch_masks, batch_iou, batch_bboxes
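

# A minimal smoke test (a sketch, not part of the original module): the random
# tensors below are stand-ins for real SAM2 backbone features and detector
# outputs, and hidden_dim=256 / image_size=1024 / reduction=16 match the SAM2
# Hiera-B+ configuration assumed in __init__ (which downloads the checkpoint).
if __name__ == "__main__":
    processor = MaskProcessor(hidden_dim=256, image_size=1024, reduction=16)
    features = {
        "vision_features": torch.randn(1, 256, 64, 64),
        "vision_pos_enc": [torch.randn(1, 256, s, s) for s in (256, 128, 64)],
        "backbone_fpn": [torch.randn(1, 256, s, s) for s in (256, 128, 64)],
    }
    # Two normalized (x1, y1, x2, y2) boxes for a single image.
    outputs = [{"pred_boxes": torch.tensor([[[0.10, 0.10, 0.50, 0.50],
                                             [0.20, 0.30, 0.80, 0.90]]])}]
    with torch.no_grad():
        masks, ious, boxes = processor(features, outputs)
    print(masks[0].shape, ious[0].shape, boxes[0].shape)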