| import numpy as np | |
| from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrMLPPredictionHead | |
| import torch.nn as nn | |
| import torch | |
def PairDetr(model, num_queries, num_classes):
    """Replace the query embeddings and prediction heads of a Deformable DETR model.

    The model is modified in place and also returned.  Three things are
    swapped out:

    * ``model.model.query_position_embeddings`` becomes a fresh embedding table
      with ``num_queries`` rows.
    * ``model.class_embed`` becomes a list of linear classifiers producing
      ``num_classes`` logits.
    * ``model.bbox_embed`` becomes a list of 3-layer MLP heads producing 8
      values per query (presumably two 4-value boxes, going by the function
      name -- confirm against the loss/matcher that consumes them).

    Args:
        model: Object-detection model exposing ``class_embed`` (an indexable,
            sized container of ``nn.Linear``) and ``model.query_position_embeddings``.
        num_queries: Number of object queries for the new embedding table.
        num_classes: Number of output logits of the new classifier.

    Returns:
        The same ``model`` instance, with the new modules attached.
    """
    # Derive sizes from the model instead of hard-coding 256 / 512 / 6, so the
    # new heads stay compatible with models of a different hidden width or a
    # different number of prediction heads.  For a 256-wide model with six
    # heads this yields exactly the previously hard-coded values.
    hidden_dim = model.class_embed[0].in_features
    num_heads = len(model.class_embed)

    # The query embedding is twice the hidden width (512 for a 256-wide model).
    model.model.query_position_embeddings = nn.Embedding(num_queries, 2 * hidden_dim)

    class_embed = nn.Linear(hidden_dim, num_classes)
    bbox_embed = DeformableDetrMLPPredictionHead(
        input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=8, num_layers=3
    )

    # The same module instance is repeated, so all entries share parameters.
    model.class_embed = nn.ModuleList([class_embed for _ in range(num_heads)])
    model.bbox_embed = nn.ModuleList([bbox_embed for _ in range(num_heads)])
    return model
def inverse_sigmoid(x, eps=1e-5):
    """Compute the logit ``log(x / (1 - x))`` with clamping for stability.

    ``x`` is first limited to ``[0, 1]``; numerator and denominator are then
    each floored at ``eps`` so the logarithm never sees zero or infinity.

    Args:
        x: Tensor of values to map back through the sigmoid.
        eps: Lower bound applied to both ``x`` and ``1 - x``.

    Returns:
        Tensor of the same shape as ``x`` holding the clamped logits.
    """
    probs = torch.clamp(x, min=0, max=1)
    numerator = torch.clamp(probs, min=eps)
    denominator = torch.clamp(1 - probs, min=eps)
    return (numerator / denominator).log()
def forward(model,
            pixel_values,
            pixel_mask=None,
            decoder_attention_mask=None,
            encoder_outputs=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,) -> dict:
    """Run the backbone/transformer and decode class logits and box coordinates.

    ``model.model`` is called with the given inputs.  For every decoder level
    in its intermediate hidden states, ``model.class_embed[level]`` and
    ``model.bbox_embed[level]`` are applied, the box deltas are offset by the
    inverse-sigmoid of the level's reference points, and a sigmoid is applied.
    Only the final level's predictions are returned.

    Args:
        model: Detection model exposing ``config.use_return_dict``, ``model``,
            ``class_embed`` and ``bbox_embed``.
        pixel_values, pixel_mask, decoder_attention_mask, encoder_outputs,
            inputs_embeds, decoder_inputs_embeds, output_attentions,
            output_hidden_states: Forwarded unchanged to ``model.model``.
        labels: Accepted for signature compatibility; not used here (no loss
            is computed).
        return_dict: Whether ``model.model`` should return a structured output
            (attribute access) or a tuple.  Defaults to
            ``model.config.use_return_dict``.

    Returns:
        A dict with keys ``"logits"`` and ``"pred_boxes"`` (last decoder
        level) and ``"init_reference_points"``.

    Raises:
        ValueError: If the reference points' last dimension is neither 4 nor 2.
    """
    return_dict = return_dict if return_dict is not None else model.config.use_return_dict
    outputs = model.model(
        pixel_values,
        pixel_mask=pixel_mask,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # With return_dict disabled the base model returns a plain tuple, so the
    # fields must be picked by position.
    if return_dict:
        hidden_states = outputs.intermediate_hidden_states
        init_reference = outputs.init_reference_points
        inter_references = outputs.intermediate_reference_points
    else:
        hidden_states = outputs[2]
        init_reference = outputs[0]
        inter_references = outputs[3]

    outputs_classes = []
    outputs_coords = []
    # Logits of the initial reference points: used as the level-0 reference and,
    # for 2-d references, as the offset of box outputs 4:6 at every level.
    cons = inverse_sigmoid(init_reference)

    # Dimension 1 of hidden_states indexes the decoder levels.
    for level in range(hidden_states.shape[1]):
        if level == 0:
            reference = cons
        else:
            reference = inverse_sigmoid(inter_references[:, level - 1])

        level_hidden = hidden_states[:, level]
        outputs_class = model.class_embed[level](level_hidden)
        delta_bbox = model.bbox_embed[level](level_hidden)

        if reference.shape[-1] == 4:
            delta_bbox[..., :4] += reference
        elif reference.shape[-1] == 2:
            delta_bbox[..., :2] += reference
            delta_bbox[..., 4:6] += cons
        else:
            raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")

        outputs_classes.append(outputs_class)
        outputs_coords.append(delta_bbox.sigmoid())

    # Stack per-level predictions along dim 1, then keep only the last level.
    outputs_class = torch.stack(outputs_classes, dim=1)
    outputs_coord = torch.stack(outputs_coords, dim=1)

    return {
        "logits": outputs_class[:, -1],
        "pred_boxes": outputs_coord[:, -1],
        # Use the value resolved above: `outputs` is a tuple when return_dict
        # is False and has no `init_reference_points` attribute.
        "init_reference_points": init_reference,
    }