import os
import copy
from typing import List

import torch
import torch.nn.functional as F
from torch import nn

from util.keypoint_ops import keypoint_xyzxyz_to_xyxyzz
from util.misc import NestedTensor, nested_tensor_from_tensor_list, inverse_sigmoid

from .utils import MLP
from .backbone import build_backbone
from ..registry import MODULE_BUILD_FUNCS
from .mask_generate import prepare_for_mask, post_process
from .deformable_transformer import build_deformable_transformer


class UniPose(nn.Module):
    """Cross-attention detector that predicts boxes and keypoints for
    objects described by textual/visual prompts."""
    def __init__(self, backbone, transformer, num_classes, num_queries,
                 aux_loss=False, iter_update=False,
                 query_dim=4,
                 random_refpoints_xy=False,
                 fix_refpoints_hw=-1,
                 num_feature_levels=1,
                 nheads=8,
                 two_stage_type='no',
                 two_stage_add_query_num=0,
                 dec_pred_class_embed_share=True,
                 dec_pred_bbox_embed_share=True,
                 two_stage_class_embed_share=True,
                 two_stage_bbox_embed_share=True,
                 decoder_sa_type='sa',
                 num_patterns=0,
                 dn_number=100,
                 dn_box_noise_scale=0.4,
                 dn_label_noise_ratio=0.5,
                 dn_labelbook_size=100,
                 use_label_enc=True,
                 text_encoder_type='bert-base-uncased',
                 binary_query_selection=False,
                 use_cdn=True,
                 sub_sentence_present=True,
                 num_body_points=68,
                 num_box_decoder_layers=2,
                 ):
| """ Initializes the model. |
| Parameters: |
| backbone: torch module of the backbone to be used. See backbone.py |
| transformer: torch module of the transformer architecture. See transformer.py |
| num_classes: number of object classes |
| num_queries: number of object queries, ie detection slot. This is the maximal number of objects |
| Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. |
| aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. |
| |
| fix_refpoints_hw: -1(default): learn w and h for each box seperately |
| >0 : given fixed number |
| -2 : learn a shared w and h |
| """ |
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels
        self.nheads = nheads
        self.use_label_enc = use_label_enc
        if use_label_enc:
            self.label_enc = nn.Embedding(dn_labelbook_size + 1, hidden_dim)
        else:
            raise NotImplementedError("use_label_enc=False is not supported")
        self.max_text_len = 256
        self.binary_query_selection = binary_query_selection
        self.sub_sentence_present = sub_sentence_present

        self.query_dim = query_dim
        assert query_dim == 4
        self.random_refpoints_xy = random_refpoints_xy
        self.fix_refpoints_hw = fix_refpoints_hw

        self.num_patterns = num_patterns
        self.dn_number = dn_number
        self.dn_box_noise_scale = dn_box_noise_scale
        self.dn_label_noise_ratio = dn_label_noise_ratio
        self.dn_labelbook_size = dn_labelbook_size
        self.use_cdn = use_cdn
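
        # Project 512-d prompt embeddings (object-level and keypoint-level
        # text embeddings; the 512 width is assumed to match the upstream
        # CLIP-style text encoder) into the transformer's hidden dimension.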
        self.projection = MLP(512, hidden_dim, hidden_dim, 3)
        self.projection_kpt = MLP(512, hidden_dim, hidden_dim, 3)
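
        # 768 matches the hidden size of bert-base-uncased, the default
        # text_encoder_type.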
        self.pos_proj = nn.Linear(hidden_dim, 768)
        self.padding = nn.Embedding(1, 768)

        # Input projections: one 1x1 conv per backbone level, plus stride-2
        # 3x3 convs for any extra feature levels.
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.num_channels)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = backbone.num_channels[i]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            assert two_stage_type == 'no', "two_stage_type should be 'no' if num_feature_levels == 1"
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[-1], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.box_pred_damping = box_pred_damping = None

        self.iter_update = iter_update
        assert iter_update, "iter_update must be True: boxes are refined layer by layer"

        self.dec_pred_class_embed_share = dec_pred_class_embed_share
        self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share

        _class_embed = ContrastiveAssign()

        _bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        _pose_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        _pose_hw_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        nn.init.constant_(_pose_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_pose_embed.layers[-1].bias.data, 0)

        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [_bbox_embed for i in range(transformer.num_decoder_layers)]
        else:
            box_embed_layerlist = [copy.deepcopy(_bbox_embed) for i in range(transformer.num_decoder_layers)]
        if dec_pred_class_embed_share:
            class_embed_layerlist = [_class_embed for i in range(transformer.num_decoder_layers)]
        else:
            class_embed_layerlist = [copy.deepcopy(_class_embed) for i in range(transformer.num_decoder_layers)]
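
        # Keypoint heads only exist for the decoder layers after the first
        # num_box_decoder_layers (the early layers refine boxes only); the
        # extra "+1" head presumably covers the step where keypoint queries
        # are first instantiated.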
        if dec_pred_bbox_embed_share:
            pose_embed_layerlist = [_pose_embed for i in
                                    range(transformer.num_decoder_layers - num_box_decoder_layers + 1)]
        else:
            pose_embed_layerlist = [copy.deepcopy(_pose_embed) for i in
                                    range(transformer.num_decoder_layers - num_box_decoder_layers + 1)]

        pose_hw_embed_layerlist = [_pose_hw_embed for i in
                                   range(transformer.num_decoder_layers - num_box_decoder_layers)]

        self.num_box_decoder_layers = num_box_decoder_layers
        self.bbox_embed = nn.ModuleList(box_embed_layerlist)
        self.class_embed = nn.ModuleList(class_embed_layerlist)
        self.num_body_points = num_body_points
        self.pose_embed = nn.ModuleList(pose_embed_layerlist)
        self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)

        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed
        self.transformer.decoder.pose_embed = self.pose_embed
        self.transformer.decoder.pose_hw_embed = self.pose_hw_embed
        self.transformer.decoder.num_body_points = num_body_points
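
        # Two-stage setup: the encoder proposes boxes/classes, and the
        # proposal heads either share parameters with the decoder heads or
        # use their own deep copies.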
        self.two_stage_type = two_stage_type
        self.two_stage_add_query_num = two_stage_add_query_num
        assert two_stage_type in ['no', 'standard'], "unknown param {} of two_stage_type".format(two_stage_type)
        if two_stage_type != 'no':
            if two_stage_bbox_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(_bbox_embed)

            if two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_class_embed = _class_embed
            else:
                self.transformer.enc_out_class_embed = copy.deepcopy(_class_embed)

        self.refpoint_embed = None
        if self.two_stage_add_query_num > 0:
            self.init_ref_points(two_stage_add_query_num)

        self.decoder_sa_type = decoder_sa_type
        assert decoder_sa_type in ['sa', 'ca_label', 'ca_content']
        if decoder_sa_type == 'ca_label':
            self.label_embedding = nn.Embedding(num_classes, hidden_dim)
            for layer in self.transformer.decoder.layers:
                layer.label_embedding = self.label_embedding
        else:
            for layer in self.transformer.decoder.layers:
                layer.label_embedding = None
            self.label_embedding = None

        self._reset_parameters()

    def open_set_transfer_init(self):
        # Freeze all parameters except the text/fusion modules and the
        # prediction heads listed below.
        trainable_keywords = (
            'fusion_layers', 'ca_text', 'catext_norm', 'catext_dropout',
            'text_layers', 'bert', 'bbox_embed', 'label_enc.weight',
            'feat_map', 'enc_output',
        )
        for name, param in self.named_parameters():
            if any(keyword in name for keyword in trainable_keywords):
                continue
            param.requires_grad_(False)

    def _reset_parameters(self):
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def init_ref_points(self, use_num_queries):
        self.refpoint_embed = nn.Embedding(use_num_queries, self.query_dim)

        if self.random_refpoints_xy:
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
            self.refpoint_embed.weight.data[:, :2].requires_grad = False

        if self.fix_refpoints_hw > 0:
            print("fix_refpoints_hw: {}".format(self.fix_refpoints_hw))
            assert self.random_refpoints_xy
            self.refpoint_embed.weight.data[:, 2:] = self.fix_refpoints_hw
            self.refpoint_embed.weight.data[:, 2:] = inverse_sigmoid(self.refpoint_embed.weight.data[:, 2:])
            self.refpoint_embed.weight.data[:, 2:].requires_grad = False
        elif int(self.fix_refpoints_hw) == -1:
            pass
        elif int(self.fix_refpoints_hw) == -2:
            print('learn a shared h and w')
            assert self.random_refpoints_xy
            self.refpoint_embed = nn.Embedding(use_num_queries, 2)
            self.refpoint_embed.weight.data[:, :2].uniform_(0, 1)
            self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
            self.refpoint_embed.weight.data[:, :2].requires_grad = False
            self.hw_embed = nn.Embedding(1, 1)
        else:
            raise NotImplementedError('Unknown fix_refpoints_hw {}'.format(self.fix_refpoints_hw))

    def forward(self, samples: NestedTensor, targets: List = None, **kw):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape = [batch_size x num_queries x num_classes]
               - "pred_boxes": the normalized box coordinates for all queries, represented as
                               (center_x, center_y, width, height). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "pred_keypoints": the normalized keypoint coordinates and visibilities for all queries.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
        """

        # Text prompts: pad per-image object embeddings to a fixed length so
        # they can be batched, then project them to the model width.
        captions = [t['instance_text_prompt'] for t in targets]
        bs = len(captions)
        tensor_list = [tgt["object_embeddings_text"] for tgt in targets]
        max_size = 350  # maximum number of object-text embeddings per image
        padded_tensors = [
            torch.cat([tensor, torch.zeros(max_size - tensor.size(0), tensor.size(1), device=tensor.device)])
            if tensor.size(0) < max_size else tensor
            for tensor in tensor_list
        ]
        object_embeddings_text = torch.stack(padded_tensors)

        kpts_embeddings_text = torch.stack([tgt["kpts_embeddings_text"] for tgt in targets])[:, :self.num_body_points]
        encoded_text = self.projection(object_embeddings_text)
        kpt_embeddings_specific = self.projection_kpt(kpts_embeddings_text)

        # kpt_mask: an always-true flag followed by the per-keypoint visibilities.
        kpt_vis = torch.stack([tgt["kpt_vis_text"] for tgt in targets])[:, :self.num_body_points]
        kpt_mask = torch.cat((torch.ones_like(kpt_vis, device=kpt_vis.device)[..., 0].unsqueeze(-1), kpt_vis), dim=-1)

        # Text-side masks: each token attends only to itself (identity mask),
        # and real (non-padded) tokens are flagged in text_token_mask.
        num_classes = encoded_text.shape[1]  # number of (padded) text tokens
        text_self_attention_masks = torch.eye(num_classes).unsqueeze(0).expand(bs, -1, -1).bool().to(samples.device)
        text_token_mask = torch.zeros(samples.shape[0], num_classes).to(samples.device) > 0
        for i in range(bs):
            text_token_mask[i, :len(captions[i])] = True

        position_ids = torch.zeros(samples.shape[0], num_classes).to(samples.device)
        for i in range(bs):
            position_ids[i, :len(captions[i])] = 1

        text_dict = {
            'encoded_text': encoded_text,
            'text_token_mask': text_token_mask,
            'position_ids': position_ids,
            'text_self_attention_masks': text_self_attention_masks
        }

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, poss = self.backbone(samples)
        if os.environ.get("SHILONG_AMP_INFNAN_DEBUG") == '1':
            import ipdb; ipdb.set_trace()

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None

        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                poss.append(pos_l)

        if self.label_enc is not None:
            label_enc = self.label_enc
        else:
            raise NotImplementedError

        if self.dn_number > 0 or targets is not None:
            input_query_label, input_query_bbox, attn_mask, attn_mask2, dn_meta = \
                prepare_for_mask(kpt_mask=kpt_mask)
        else:
            assert targets is None
            input_query_bbox = input_query_label = attn_mask = attn_mask2 = dn_meta = None
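
        # Decoder pass: hs are the per-layer decoder states and reference the
        # per-layer reference points; hs_enc/ref_enc and init_box_proposal
        # come from the two-stage encoder output.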
        hs, reference, hs_enc, ref_enc, init_box_proposal = self.transformer(srcs, masks, input_query_bbox, poss,
                                                                             input_query_label, attn_mask, attn_mask2,
                                                                             text_dict, dn_meta, targets,
                                                                             kpt_embeddings_specific)

        # Touch potentially-unused parameters with zero-weighted terms so
        # DDP does not complain about parameters that receive no gradient.
        if self.label_enc is not None:
            hs[0] += self.label_enc.weight[0, 0] * 0.0
        hs[0] += self.pos_proj.weight[0, 0] * 0.0
        hs[0] += self.pos_proj.bias[0] * 0.0
        hs[0] += self.padding.weight[0, 0] * 0.0

        num_group = 50  # number of query groups, each holding one box token plus num_body_points keypoint tokens
        effective_dn_number = dn_meta['pad_size'] if self.training else 0
        outputs_coord_list = []
        outputs_class = []

        # Box & class predictions per decoder layer. For layers before
        # num_box_decoder_layers every query is a box query; afterwards the
        # queries interleave box and keypoint tokens, so box tokens are
        # picked out with stride (num_body_points + 1).
        for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, layer_hs) in enumerate(
                zip(reference[:-1], self.bbox_embed, self.class_embed, hs)):
            if dec_lid < self.num_box_decoder_layers:
                layer_delta_unsig = layer_bbox_embed(layer_hs)
                layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(layer_ref_sig)
                layer_outputs_unsig = layer_outputs_unsig.sigmoid()
                layer_cls = layer_cls_embed(layer_hs, text_dict)
                outputs_coord_list.append(layer_outputs_unsig)
                outputs_class.append(layer_cls)
            else:
                # Split denoising (dn) queries from matching queries, then
                # keep only the box token of each group.
                layer_hs_bbox_dn = layer_hs[:, :effective_dn_number, :]
                layer_hs_bbox_norm = layer_hs[:, effective_dn_number:, :][:, 0::(self.num_body_points + 1), :]
                bs = layer_ref_sig.shape[0]
                reference_before_sigmoid_bbox_dn = layer_ref_sig[:, :effective_dn_number, :]
                reference_before_sigmoid_bbox_norm = layer_ref_sig[:, effective_dn_number:, :][:, 0::(self.num_body_points + 1), :]
                layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn)
                layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm)
                layer_outputs_unsig_dn = layer_delta_unsig_dn + inverse_sigmoid(reference_before_sigmoid_bbox_dn)
                layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid()
                layer_outputs_unsig_norm = layer_delta_unsig_norm + inverse_sigmoid(reference_before_sigmoid_bbox_norm)
                layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid()
                layer_outputs_unsig = torch.cat((layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1)
                layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn, text_dict)
                layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm, text_dict)
                layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
                outputs_class.append(layer_cls)
                outputs_coord_list.append(layer_outputs_unsig)
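
        # Keypoint predictions: kpt_index selects, within each group of
        # (num_body_points + 1) queries, every token except the leading box
        # token.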
        outputs_keypoints_list = []
        outputs_keypoints_hw = []
        kpt_index = [x for x in range(num_group * (self.num_body_points + 1)) if x % (self.num_body_points + 1) != 0]
        for dec_lid, (layer_ref_sig, layer_hs) in enumerate(zip(reference[:-1], hs)):
            if dec_lid < self.num_box_decoder_layers:
                assert isinstance(layer_hs, torch.Tensor)
                bs = layer_hs.shape[0]
                layer_res = layer_hs.new_zeros((bs, self.num_queries, self.num_body_points * 3))
                outputs_keypoints_list.append(layer_res)
            else:
                bs = layer_ref_sig.shape[0]
                layer_hs_kpt = layer_hs[:, effective_dn_number:, :].index_select(
                    1, torch.tensor(kpt_index, device=layer_hs.device))
                delta_xy_unsig = self.pose_embed[dec_lid - self.num_box_decoder_layers](layer_hs_kpt)
                layer_ref_sig_kpt = layer_ref_sig[:, effective_dn_number:, :].index_select(
                    1, torch.tensor(kpt_index, device=layer_hs.device))
                layer_outputs_unsig_keypoints = delta_xy_unsig + inverse_sigmoid(layer_ref_sig_kpt[..., :2])
                vis_xy_unsig = torch.ones_like(layer_outputs_unsig_keypoints,
                                               device=layer_outputs_unsig_keypoints.device)
                xyv = torch.cat((layer_outputs_unsig_keypoints, vis_xy_unsig[:, :, 0].unsqueeze(-1)), dim=-1)
                xyv = xyv.sigmoid()
                layer_res = xyv.reshape((bs, num_group, self.num_body_points, 3)).flatten(2, 3)
                layer_hw = layer_ref_sig_kpt[..., 2:].reshape(bs, num_group, self.num_body_points, 2).flatten(2, 3)
                layer_res = keypoint_xyzxyz_to_xyxyzz(layer_res)
                outputs_keypoints_list.append(layer_res)
                outputs_keypoints_hw.append(layer_hw)
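
        # Strip the denoising part from the outputs and (optionally) collect
        # per-layer auxiliary outputs.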
        if self.dn_number > 0 and dn_meta is not None:
            outputs_class, outputs_coord_list = \
                post_process(outputs_class, outputs_coord_list,
                             dn_meta, self.aux_loss, self._set_aux_loss)
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord_list[-1],
               'pred_keypoints': outputs_keypoints_list[-1]}

        return out


@MODULE_BUILD_FUNCS.registe_with_name(module_name='UniPose')
def build_unipose(args):
    num_classes = args.num_classes
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_deformable_transformer(args)

    # Optional args: fall back to defaults when missing from the config.
    match_unstable_error = getattr(args, 'match_unstable_error', True)
    dn_labelbook_size = getattr(args, 'dn_labelbook_size', num_classes)
    dec_pred_class_embed_share = getattr(args, 'dec_pred_class_embed_share', True)
    dec_pred_bbox_embed_share = getattr(args, 'dec_pred_bbox_embed_share', True)
    binary_query_selection = getattr(args, 'binary_query_selection', False)
    use_cdn = getattr(args, 'use_cdn', True)
    sub_sentence_present = getattr(args, 'sub_sentence_present', True)
    model = UniPose(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,
        fix_refpoints_hw=args.fix_refpoints_hw,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_class_embed_share=dec_pred_class_embed_share,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        two_stage_type=args.two_stage_type,
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        decoder_sa_type=args.decoder_sa_type,
        num_patterns=args.num_patterns,
        dn_number=args.dn_number if args.use_dn else 0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
        use_label_enc=args.use_label_enc,
        text_encoder_type=args.text_encoder_type,
        binary_query_selection=binary_query_selection,
        use_cdn=use_cdn,
        sub_sentence_present=sub_sentence_present,
    )

    return model


class ContrastiveAssign(nn.Module):
    """Contrastive classification head: scores each query against the encoded
    text embeddings with a dot product, masking out padding tokens."""

    def __init__(self, project=False, cal_bias=None, max_text_len=256):
        super().__init__()
        self.project = project
        self.cal_bias = cal_bias
        self.max_text_len = max_text_len

    def forward(self, x, text_dict):
        """
        Args:
            x: query features, shape [bs, num_queries, d_model]
            text_dict: dict with at least
                'encoded_text': [bs, num_tokens, d_model]
                'text_token_mask': [bs, num_tokens]; True for real tokens,
                    False for padding tokens
        Returns:
            similarity logits of shape [bs, num_queries, num_tokens], with
            padding positions set to -inf
        """
        assert isinstance(text_dict, dict)

        y = text_dict['encoded_text']
        max_text_len = y.shape[1]
        text_token_mask = text_dict['text_token_mask']

        if self.cal_bias is not None:
            raise NotImplementedError

        res = x @ y.transpose(-1, -2)
        res.masked_fill_(~text_token_mask[:, None, :], float('-inf'))

        # Pad the logits to a fixed width so outputs are batchable.
        new_res = torch.full((*res.shape[:-1], max_text_len), float('-inf'), device=res.device)
        new_res[..., :res.shape[-1]] = res

        return new_res