import numpy as np
import skimage
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torchvision.ops import roi_align
from torchvision.transforms import Resize
from torch.cuda.amp import autocast

from utils.box_ops import boxes_with_scores

from .query_generator import C_base
from .sam_mask import MaskProcessor


class CNT(nn.Module):
    """Exemplar-prompted counting/detection head on top of a frozen SAM2 image encoder.

    A SAM2 Hiera-B+ backbone (weights downloaded from the official release and
    loaded with the ``image_encoder.`` prefix stripped) produces multi-scale
    features. Exemplar boxes are ROI-aligned into prototype embeddings, fused
    with a learned shape (width/height) embedding, and used by ``C_base`` to
    adapt the image features. A per-location centerness score and an l/r/t/b
    box regression are decoded into detections via ``boxes_with_scores``;
    optionally, a SAM-style mask head (``MaskProcessor``) refines boxes and
    rescores them with predicted IoUs.
    """

    def __init__(
        self,
        image_size: int,
        num_objects: int,
        emb_dim: int,
        kernel_dim: int,
        reduction: int,
        zero_shot: bool,
    ):
        """
        Args:
            image_size: Side length of the (square) input image in pixels.
            num_objects: Number of exemplar boxes per image (overwritten per
                batch in ``forward`` from ``bboxes.size(1)``).
            emb_dim: Backbone/feature embedding dimension.
            kernel_dim: Prototype kernel size (``forward`` currently forces 1).
            reduction: Backbone stride (overwritten per batch in ``forward``).
            zero_shot: Stored flag; not read anywhere in this module.
        """
        super(CNT, self).__init__()
        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot

        # Per-location objectness ("centerness") score and l/r/t/b box head.
        self.class_embed = nn.Sequential(nn.Linear(emb_dim, 1), nn.LeakyReLU())
        self.bbox_embed = MLP(emb_dim, emb_dim, 4, 3)

        # Cross-attention module that adapts image features with the prototypes.
        self.adapt_features = C_base(
            transformer_dim=self.emb_dim,
            num_prototype_attn_steps=3,
            num_image_attn_steps=2,
        )

        # Local import kept to avoid a circular import at module load time.
        from .prompt_encoder import PromptEncoder
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.emb_dim,
            image_embedding_size=(
                self.image_size // self.reduction,
                self.image_size // self.reduction,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )

        # Build the SAM2 image encoder from its Hydra config and load the
        # official checkpoint (keys are prefixed "image_encoder." in the
        # release file; strict=False because only the encoder subset is kept).
        config_name = '../configs/sam2_hiera_base_plus.yaml'
        cfg = compose(config_name=config_name)
        OmegaConf.resolve(cfg)
        self.backbone = instantiate(cfg.backbone, _recursive_=True)
        checkpoint = torch.hub.load_state_dict_from_url(
            'https://dl.fbaipublicfiles.com/segment_anything_2/072824/'
            + config_name.split('/')[-1].replace('.yaml', '.pt'),
            map_location="cpu",
        )['model']
        state_dict = {
            k.replace("image_encoder.", ""): v for k, v in checkpoint.items()
        }
        self.backbone.load_state_dict(state_dict, strict=False)

        # Embeds exemplar (width, height) into a single emb_dim "shape" token
        # (1 ** 2 == 1 token per exemplar, matching kernel_dim == 1).
        self.shape_or_objectness = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1 ** 2 * emb_dim),
        )
        self.resize = Resize((1024, 1024))
        self.sam_mask = MaskProcessor(self.emb_dim, self.image_size, reduction)
        # When True, SAM mask refinement corrects boxes and rescores with IoU.
        self.sam_corr = True

    def forward(self, x, bboxes):
        """Detect/count objects in ``x`` guided by exemplar boxes.

        Args:
            x: Image batch; assumes (B, 3, 1024, 1024) since the backbone
                stride is derived from a 1024-px input — TODO confirm.
            bboxes: Exemplar boxes, (B, num_objects, 4) in xyxy pixel
                coordinates (widths/heights are computed as x2-x1 / y2-y1).

        Returns:
            Tuple of (outputs, ref_points, centerness, outputs_coord, masks);
            ``masks`` is None when ``self.sam_corr`` is False.
        """
        # NOTE: forward mutates module state (num_objects, reduction,
        # kernel_dim) per batch; kept for backward compatibility.
        self.num_objects = bboxes.size(1)

        # Backbone is frozen — no gradients through the SAM2 encoder.
        with torch.no_grad():
            feats = self.backbone(x)

        src = feats['vision_features']
        bs, c, w, h = src.shape
        self.reduction = 1024 / w  # effective stride of the coarsest level

        # roi_align expects boxes as (batch_index, x1, y1, x2, y2).
        batch_index = (
            torch.arange(bs, device=bboxes.device)
            .repeat_interleave(self.num_objects)
            .reshape(-1, 1)
        )
        bboxes_roi = torch.cat([batch_index, bboxes.flatten(0, 1)], dim=1)

        self.kernel_dim = 1  # one prototype vector per exemplar

        # Exemplar prototypes pooled from the coarsest feature map.
        exemplars = roi_align(
            src,
            boxes=bboxes_roi,
            output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction,
            aligned=True,
        ).permute(0, 2, 3, 1).reshape(
            bs, self.num_objects * self.kernel_dim ** 2, self.emb_dim
        )

        # High-resolution FPN levels (4x and 2x the coarse resolution).
        l1 = feats['backbone_fpn'][0]
        l2 = feats['backbone_fpn'][1]
        exemplars_l1 = roi_align(
            l1,
            boxes=bboxes_roi,
            output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction * 2 * 2,
            aligned=True,
        ).permute(0, 2, 3, 1).reshape(
            bs, self.num_objects * self.kernel_dim ** 2, self.emb_dim
        )
        exemplars_l2 = roi_align(
            l2,
            boxes=bboxes_roi,
            output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction * 2,
            aligned=True,
        ).permute(0, 2, 3, 1).reshape(
            bs, self.num_objects * self.kernel_dim ** 2, self.emb_dim
        )

        # Encode exemplar shape (width, height) as one extra token per exemplar.
        box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
        box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
        box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
        shape = self.shape_or_objectness(box_hw).reshape(bs, -1, self.emb_dim)

        prototype_embeddings = torch.cat([exemplars, shape], dim=1)
        prototype_embeddings_l1 = torch.cat([exemplars_l1, shape], dim=1)
        prototype_embeddings_l2 = torch.cat([exemplars_l2, shape], dim=1)
        hq_prototype_embeddings = [prototype_embeddings_l1, prototype_embeddings_l2]

        masks = None  # only produced when SAM mask correction runs below

        # The adaptation/decoding path is numerically sensitive — force fp32.
        with autocast(enabled=False):
            # BUGFIX: was `src.type != torch.float32`, comparing the bound
            # method `Tensor.type` to a dtype (always True). Compare dtypes.
            if src.dtype != torch.float32:
                src = src.float()
                prototype_embeddings = prototype_embeddings.float()
                hq_prototype_embeddings = [hq.float() for hq in hq_prototype_embeddings]
                feats['backbone_fpn'] = [f.float() for f in feats['backbone_fpn']]
                feats['vision_pos_enc'] = [f.float() for f in feats['vision_pos_enc']]

            # Adapt image features with prototypes.
            adapted_f, adapted_f_aux = self.adapt_features(
                image_embeddings=src,
                image_pe=self.sam_prompt_encoder.get_dense_pe(),
                prototype_embeddings=prototype_embeddings,
                hq_features=feats['backbone_fpn'],
                hq_prototypes=hq_prototype_embeddings,
                hq_pos=feats['vision_pos_enc'],
            )

            # Predict per-location centerness and l,r,t,b box offsets.
            bs, c, w, h = adapted_f.shape
            adapted_f = adapted_f.view(bs, self.emb_dim, -1).permute(0, 2, 1)
            centerness = (
                self.class_embed(adapted_f).view(bs, w, h, 1).permute(0, 3, 1, 2)
            )
            outputs_coord = (
                self.bbox_embed(adapted_f)
                .sigmoid()
                .view(bs, w, h, 4)
                .permute(0, 3, 1, 2)
            )
            outputs, ref_points = boxes_with_scores(
                centerness, outputs_coord, sort=False, validate=True
            )

            if self.sam_corr:
                # SAM mask refinement: rescore with predicted IoU and replace
                # boxes with mask-derived ones, normalized by image width.
                masks, ious, corrected_bboxes = self.sam_mask(feats, outputs)
                for i in range(len(outputs)):
                    outputs[i]["scores"] = ious[i]
                    outputs[i]["pred_boxes"] = (
                        corrected_bboxes[i]
                        .to(outputs[i]["pred_boxes"].device)
                        .unsqueeze(0)
                        / x.shape[-1]
                    )
            else:
                for i in range(len(outputs)):
                    outputs[i]["scores"] = outputs[i]["box_v"]

        return outputs, ref_points, centerness, outputs_coord, masks


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        # ReLU between layers; no activation after the final layer.
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build_model(args):
    """Construct a CNT model from an argparse/config namespace.

    Raises:
        AssertionError: if ``args.reduction`` is not one of the supported
            backbone strides (4, 8, 16).
    """
    assert args.reduction in [4, 8, 16]
    return CNT(
        image_size=args.image_size,
        num_objects=args.num_objects,
        zero_shot=args.zero_shot,
        emb_dim=args.emb_dim,
        kernel_dim=args.kernel_dim,
        reduction=args.reduction,
    )