Spaces:
Running
Running
| """ | |
| Add additional grasp decoder for Segment Anything model. | |
| The structure should follow the grasp decoder structure in GraspDETR. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from transformers.models.detr.configuration_detr import DetrConfig | |
| from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss | |
| from typing import Any, Dict, List, Tuple | |
| from transformers.models.detr.modeling_detr import generalized_box_iou | |
| from transformers.image_transforms import center_to_corners_format | |
| from scipy.optimize import linear_sum_assignment | |
def modify_matcher_forward(self):
    """Build a replacement ``forward`` for a ``DetrHungarianMatcher`` instance.

    Args:
        self: the matcher whose ``class_cost``, ``bbox_cost`` and ``giou_cost``
            weights are read each time the returned function runs.

    Returns:
        ``matcher_forward(outputs, targets)``. It returns, per image, a pair
        ``(query_indices, target_indices)`` of int64 tensors giving the
        minimum-cost assignment between predictions and targets.
    """

    # Matching is a non-differentiable assignment step. Upstream
    # DetrHungarianMatcher.forward is decorated the same way. Without it the
    # cost matrix requires grad during training, and handing it to scipy
    # (which converts it with numpy) raises.
    @torch.no_grad()
    def matcher_forward(outputs, targets):
        """Compute a Hungarian matching between predictions and targets.

        Args:
            outputs: dict with ``"logits"`` (batch, num_queries, num_classes) and
                ``"pred_boxes"`` (batch, num_queries, D). All ``D`` columns enter
                the L1 cost, while only the first 4 enter the GIoU cost.
            targets: one dict per image with ``"class_labels"`` and ``"boxes"``.

        Returns:
            List of ``(query_indices, target_indices)`` int64 tensor pairs, one per image.
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, D]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        class_cost = -out_prob[:, target_ids]

        # L1 cost between boxes, using every regressed column (rotation included)
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # GIoU cost between the axis-aligned (cx, cy, w, h) part of the boxes
        giou_cost = -generalized_box_iou(
            center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4])
        )

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        # Give the column count explicitly. view(..., -1) cannot infer it when the
        # batch has no target boxes at all (a 0-element tensor).
        num_targets = target_bbox.shape[0]
        cost_matrix = cost_matrix.view(batch_size, num_queries, num_targets).cpu()

        # Split the columns per image and solve each image's assignment independently
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    return matcher_forward
def modify_grasp_loss_forward(self):
    """Build replacement ``loss_labels``, ``loss_boxes`` and ``forward`` for a ``DetrLoss``.

    Args:
        self: the ``DetrLoss`` instance. The returned closures use its
            ``_get_source_permutation_idx``, ``matcher``, ``loss_labels`` and
            ``loss_boxes`` attributes at call time.

    Returns:
        Tuple ``(modified_loss_labels, modified_loss_boxes, modified_forward)``.
    """

    def modified_loss_labels(outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL). Targets dicts must contain the key "class_labels" containing a tensor of dim
        [nb_target_boxes].

        Unmatched queries are assigned class index 1 (the extra "no object" class).
        The loss is an unweighted cross entropy averaged over all queries.
        ``num_boxes`` is accepted for interface compatibility and is not used.
        """
        num_classes = 1  # model v9 always uses class-agnostic grasps
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes)
        return {"loss_ce": loss_ce}

    def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False):
        """L1 and GIoU losses on matched grasp boxes, each normalized by ``num_boxes``.

        Boxes are read as ``(cx, cy, w, h, theta)``. Columns 0-3 feed
        ``center_to_corners_format``, and column 4 is treated as the rotation.

        Args:
            ignore_wh: when True, width and height are excluded from both terms.
                The L1 term uses only (cx, cy, theta), rescaled by 5/3 to keep its
                magnitude comparable to the 5-column loss. The GIoU term uses the
                predicted center with the target's w and h.

        Raises:
            KeyError: if ``outputs`` has no ``"pred_boxes"``.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        if ignore_wh:
            xytheta = [0, 1, 4]
            loss_bbox = nn.functional.l1_loss(
                source_boxes[:, xytheta], target_boxes[:, xytheta], reduction="none"
            ) * 5 / 3
            # Replace exactly w and h (columns 2 and 3) with the targets' values so
            # the GIoU term only depends on the predicted center. Build a new tensor
            # rather than writing in place into a tensor that is part of the graph.
            source_giou_boxes = torch.cat(
                [source_boxes[:, :2], target_boxes[:, 2:4].to(source_boxes.dtype)], dim=1
            )
        else:
            loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
            source_giou_boxes = source_boxes[:, :4]

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                center_to_corners_format(source_giou_boxes), center_to_corners_format(target_boxes[:, :4])
            )
        )
        return {
            "loss_bbox": loss_bbox.sum() / num_boxes,
            "loss_giou": loss_giou.sum() / num_boxes,
        }

    def modified_forward(outputs, targets, ignore_wh=False):
        """
        This performs the loss computation.

        Args:
            outputs (`dict`, *optional*):
                Dictionary of tensors, see the output specification of the model for the format.
            targets (`List[dict]`, *optional*):
                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
                losses applied, see each loss' doc.
            ignore_wh (`bool`, *optional*):
                Forwarded to ``loss_boxes``.

        Returns:
            Dict with the ``loss_ce``, ``loss_bbox`` and ``loss_giou`` entries.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Number of target boxes used for normalization (single process only;
        # the original distributed all_reduce is not implemented here)
        num_boxes = sum(len(t["class_labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
        losses.update(self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh))
        return losses

    return modified_loss_labels, modified_loss_boxes, modified_forward
def modify_forward(self):
    """
    Modify the following methods to make SAM perform grasp detection after segmentation:
    1. Add a parallel decoder for grasping detection: 1(+1) classes, 5 values to regress (bbox & rotation)

    The new grasp submodules are attached to ``self`` in place. SAM's own
    ``forward`` is not replaced here; presumably the caller installs the
    returned function (confirm at the call site).

    Args:
        self: the SAM model being extended. Attributes used: ``device``,
            ``image_encoder``, ``prompt_encoder``, ``mask_decoder`` and
            ``postprocess_masks``. ``mask_decoder`` must return a third value
            (``src``) besides the masks and IoU predictions. Presumably this is
            a patched SAM mask decoder; TODO confirm.

    Returns:
        ``modified_forward(batched_input, multimask_output)``, a closure over
        ``self`` that returns a ``DetrSegmentationOutput``.
    """
    # 1. Instantiate a new DETR decoder (default DetrConfig) directly on self, as the grasp decoder
    self.grasp_decoder_config = DetrConfig()
    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
    # 20 learned query embeddings of width 256, i.e. 20 grasp predictions per decoder batch entry.
    # NOTE(review): 256 presumably has to match the DetrConfig hidden size; confirm.
    self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
    # 2. Base model forward method is not directly used, no modification needs to be done
    # self.detr.model.forward = modify_base_model_forward(self.detr.model)
    # 3. Add additional classification head & bbox regression head for grasp_decoder output.
    # The regression head outputs 5 values per query.
    # NOTE(review): there is no non-linearity between the Linear layers, so this stack is
    # equivalent to a single affine map. Confirm this is intended (changing it alters the
    # state_dict keys).
    self.grasp_predictor = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 5)
    ).to(self.device)
    # 2 logits per query: the single grasp class (label 0) plus one extra class for unmatched queries
    self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)
    # 4. Add positional embedding
    # name it as grasp_img_pos_embed to avoid name conflict
    class ImagePosEmbed(nn.Module):
        """Learned additive positional embedding of shape (1, img_size, img_size, hidden_dim)."""
        def __init__(self, img_size=64, hidden_dim=256):
            super().__init__()
            # randomly initialized (torch.randn) and learned
            self.pos_embed = nn.Parameter(
                torch.randn(1, img_size, img_size, hidden_dim)
            )
        def forward(self, x):
            # x must broadcast against (1, img_size, img_size, hidden_dim), i.e. channels-last
            return x + self.pos_embed
    self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)
    def modified_forward(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        """Run SAM's mask branch and the grasp decoder, adding losses when targets are present.

        Args:
            batched_input: one dict per image. Keys read here:
                - ``"image"`` and ``"pixel_mask"``: required.
                - ``"point_coords"``/``"point_labels"``, ``"boxes"``, ``"mask_inputs"``:
                  optional prompts.
                - ``"grasp_labels"`` and ``"masks"``: optional targets. Presence is checked
                  on the first record only, but they are then read from every record.
                - ``"ignore_wh"``: optional, read from the first record only.
            multimask_output: forwarded to ``self.mask_decoder``.

        Returns:
            ``DetrSegmentationOutput`` with these fields:
                - ``loss``: the int 0 when neither ``"grasp_labels"`` nor ``"masks"`` is present.
                - ``loss_dict``.
                - ``logits``: grasp class logits.
                - ``pred_boxes``: sigmoid-squashed 5-value grasp predictions.
                - ``pred_masks``: upsampled masks.
        """
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        batch_size = len(batched_input)
        outputs = []
        srcs = []
        # SAM's prompt encoder and mask decoder run one image at a time, with that image's prompts
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions, src = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            outputs.append(
                {
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
            # keep only src[0] per image (presumably the features for the first prompt; confirm)
            srcs.append(src[0])
        srcs = torch.stack(srcs, dim=0)
        # forward grasp decoder here
        # 1. Get encoder hidden states: move channels last, then add the learned 2-D embedding.
        #    Assumes srcs is (batch, 256, 64, 64) to match the (1, 64, 64, 256) parameter. TODO confirm.
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
        # 2. Get query embeddings (the 20 learned grasp queries)
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        # repeat to batchsize
        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
        # Downsample the pixel masks (nearest) to the 64x64 token grid and flatten to (batch, 64*64)
        pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
        downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode='nearest').squeeze(1).bool()
        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64*64).contiguous()
        # Flatten the spatial grid into a 64*64 token sequence for cross-attention
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64*64, 256).contiguous()
        # Queries start from zeros and differ only via query_position_embeddings.
        # Encoder position_embeddings are zero because position was already added by grasp_img_pos_embed.
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        # sigmoid keeps every regressed grasp value in (0, 1)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
        # 3. Calculate loss
        loss, loss_dict = 0, {}
        if "grasp_labels" in batched_input[0]:
            config = self.grasp_decoder_config
            # Every ground-truth grasp gets class 0 (class-agnostic grasping)
            grasp_labels = [{
                "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
                "boxes": x["grasp_labels"],
            } for x in batched_input]
            # NOTE(review): the matcher and criterion below are rebuilt on every forward call.
            # First: create the matcher, with its forward replaced by the grasp-specific one
            matcher = DetrHungarianMatcher(
                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
            )
            matcher.forward = modify_matcher_forward(matcher)
            # Second: create the criterion, then swap in the grasp-specific loss methods and forward
            losses = ["labels", "boxes"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=config.num_labels,
                eos_coef=config.eos_coefficient,
                losses=losses,
            )
            criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = grasp_logits
            outputs_loss["pred_boxes"] = pred_grasps
            grasp_loss_dict = criterion(outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False))
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = config.giou_loss_coefficient
            # outputs_loss carries no auxiliary outputs, so these suffixed weights are
            # presumably never matched by a key in grasp_loss_dict
            if config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)
            # merge grasp branch loss into variable loss & loss_dict
            loss += grasp_loss
            loss_dict.update(grasp_loss_dict)
        # `image_record` is the loop variable left over from the loop above (the last record).
        # All images share one shape because they were stacked, so any record would give the same size.
        pred_masks = self.postprocess_masks(
            torch.cat([x['low_res_logits'] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        if 'masks' in batched_input[0]:
            # 4. Calculate segmentation loss (sigmoid focal + dice, normalized by batch size).
            # The target gets a singleton channel via unsqueeze(1). This presumably assumes a single
            # predicted mask per image (multimask_output=False); confirm.
            sf_loss = sigmoid_focal_loss(pred_masks.flatten(1),
                                         torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
            d_loss = dice_loss(pred_masks.flatten(1),
                               torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
            loss += sf_loss + d_loss
            loss_dict["sf_loss"] = sf_loss
            loss_dict["d_loss"] = d_loss
        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )
    return modified_forward
def add_inference_method(self):
    """Build an inference-only forward for a SAM model extended with a grasp decoder.

    Args:
        self: the SAM model. It must already carry the grasp submodules
            (``grasp_img_pos_embed``, ``grasp_query_position_embeddings``,
            ``grasp_decoder``, ``grasp_label_classifier``, ``grasp_predictor``)
            as well as SAM's ``image_encoder``, ``prompt_encoder``,
            ``mask_decoder``, ``postprocess_masks`` and ``device``.

    Returns:
        ``infer(batched_input, multimask_output)``, a closure over ``self``.
    """

    def infer(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        """Predict masks and grasps for the first record of ``batched_input``.

        Each prompt returned by the mask decoder becomes its own batch entry for
        the grasp decoder, so grasps are predicted per prompt.

        Args:
            batched_input: list of per-image dicts. Only ``batched_input[0]`` is
                used. Keys read: ``"image"``, ``"pixel_mask"`` and the optional
                prompts ``"point_coords"``/``"point_labels"``, ``"boxes"``,
                ``"mask_inputs"``.
            multimask_output: forwarded to ``self.mask_decoder``.

        Returns:
            ``DetrSegmentationOutput`` with ``loss=0``, an empty ``loss_dict``,
            grasp class ``logits`` and sigmoid-squashed ``pred_boxes`` per prompt,
            and the upsampled ``pred_masks``.

        Raises:
            ValueError: if ``batched_input`` is empty.
        """
        if not batched_input:
            raise ValueError("batched_input must contain at least one record")

        # Only the first record is used, so encode just that image rather than
        # running the image encoder on records whose results would be discarded.
        image_record = batched_input[0]
        image_embeddings = self.image_encoder(image_record["image"].unsqueeze(0))
        curr_embedding = image_embeddings[0]

        if "point_coords" in image_record:
            points = (image_record["point_coords"], image_record["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=image_record.get("boxes", None),
            masks=image_record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions, src = self.mask_decoder(
            image_embeddings=curr_embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # One grasp-decoder batch entry per prompt.
        n_queries = iou_predictions.size(0)

        # 1. Encoder hidden states: channels-last features plus the learned 2-D embedding
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
        # 2. The 20 learned grasp queries, repeated for every prompt
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)

        # Use this record's own pixel mask, shared by all prompts. Concatenating the
        # masks of other records would not line up with the per-prompt batch.
        # Downsample (nearest) to the 64x64 token grid.
        pixel_masks = image_record["pixel_mask"].repeat(n_queries, 1, 1)
        downsampled_pixel_masks = nn.functional.interpolate(
            pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest"
        ).squeeze(1).bool()
        downsampled_pixel_masks = downsampled_pixel_masks.view(n_queries, 64 * 64).contiguous()
        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(n_queries, 64 * 64, 256).contiguous()

        # Queries start from zeros and differ only via query_position_embeddings.
        # Encoder position_embeddings are zero because position was already added above.
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=downsampled_pixel_masks,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        pred_masks = self.postprocess_masks(
            low_res_masks,
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        return DetrSegmentationOutput(
            loss=0,
            loss_dict={},
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return infer