Spaces:
Sleeping
Sleeping
Update models/grasp_mods.py
Browse files- models/grasp_mods.py +64 -13
models/grasp_mods.py
CHANGED
|
@@ -68,25 +68,65 @@ def modify_grasp_loss_forward(self):
|
|
| 68 |
|
| 69 |
return losses
|
| 70 |
|
| 71 |
-
def modified_loss_boxes(outputs, targets, indices, num_boxes):
|
| 72 |
|
| 73 |
if "pred_boxes" not in outputs:
|
| 74 |
raise KeyError("No predicted boxes found in outputs")
|
| 75 |
idx = self._get_source_permutation_idx(indices)
|
| 76 |
source_boxes = outputs["pred_boxes"][idx]
|
| 77 |
target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
losses = {}
|
| 82 |
losses["loss_bbox"] = loss_bbox.sum() / num_boxes
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
losses["loss_giou"] = loss_giou.sum() / num_boxes
|
| 88 |
return losses
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
def modify_forward(self):
|
| 92 |
"""
|
|
@@ -127,7 +167,7 @@ def modify_forward(self):
|
|
| 127 |
):
|
| 128 |
input_images = torch.stack([x["image"] for x in batched_input], dim=0)
|
| 129 |
image_embeddings = self.image_encoder(input_images)
|
| 130 |
-
|
| 131 |
outputs = []
|
| 132 |
srcs = []
|
| 133 |
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
|
@@ -162,13 +202,17 @@ def modify_forward(self):
|
|
| 162 |
grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
|
| 163 |
# repeat to batchsize
|
| 164 |
grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
grasp_decoder_outputs = self.grasp_decoder(
|
| 166 |
inputs_embeds=torch.zeros_like(grasp_query_pe),
|
| 167 |
attention_mask=None,
|
| 168 |
position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
|
| 169 |
query_position_embeddings=grasp_query_pe,
|
| 170 |
encoder_hidden_states=grasp_encoder_hidden_states,
|
| 171 |
-
encoder_attention_mask=
|
| 172 |
output_attentions=False,
|
| 173 |
output_hidden_states=False,
|
| 174 |
return_dict=True,
|
|
@@ -198,14 +242,14 @@ def modify_forward(self):
|
|
| 198 |
eos_coef=config.eos_coefficient,
|
| 199 |
losses=losses,
|
| 200 |
)
|
| 201 |
-
criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
|
| 202 |
criterion.to(self.device)
|
| 203 |
# Third: compute the losses, based on outputs and labels
|
| 204 |
outputs_loss = {}
|
| 205 |
outputs_loss["logits"] = grasp_logits
|
| 206 |
outputs_loss["pred_boxes"] = pred_grasps
|
| 207 |
|
| 208 |
-
grasp_loss_dict = criterion(outputs_loss, grasp_labels)
|
| 209 |
# Fourth: compute total loss, as a weighted sum of the various losses
|
| 210 |
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
| 211 |
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
|
@@ -282,6 +326,8 @@ def add_inference_method(self):
|
|
| 282 |
|
| 283 |
n_queries = iou_predictions.size(0)
|
| 284 |
|
|
|
|
|
|
|
| 285 |
# forward grasp decoder here
|
| 286 |
# 1. Get encoder hidden states
|
| 287 |
grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
|
|
@@ -289,13 +335,18 @@ def add_inference_method(self):
|
|
| 289 |
grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
|
| 290 |
# repeat to batchsize
|
| 291 |
grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
grasp_decoder_outputs = self.grasp_decoder(
|
| 293 |
inputs_embeds=torch.zeros_like(grasp_query_pe),
|
| 294 |
attention_mask=None,
|
| 295 |
position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
|
| 296 |
query_position_embeddings=grasp_query_pe,
|
| 297 |
encoder_hidden_states=grasp_encoder_hidden_states,
|
| 298 |
-
encoder_attention_mask=
|
| 299 |
output_attentions=False,
|
| 300 |
output_hidden_states=False,
|
| 301 |
return_dict=True,
|
|
|
|
| 68 |
|
| 69 |
return losses
|
| 70 |
|
| 71 |
+
def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False):
    """
    Compute the L1 regression loss and the generalized IoU loss for the matched boxes.

    Boxes are indexed as five columns. Columns 0, 1 and 4 are used as (x, y, theta) and columns 0-3 are
    passed to `center_to_corners_format` (presumably (cx, cy, w, h) -- confirm against the dataset).

    Args:
        outputs (`dict`):
            Model outputs; must contain the key "pred_boxes".
        targets (`List[dict]`):
            One dict per sample, each providing a "boxes" entry.
        indices:
            Matching between predictions and targets, one pair per sample; the second element of each pair
            selects the matched rows of that sample's target boxes.
        num_boxes:
            Normalizer both summed losses are divided by.
        ignore_wh (`bool`, *optional*, defaults to `False`):
            When set, columns 2 and 3 are excluded from both losses.

    Returns:
        `dict` with the keys "loss_bbox" and "loss_giou".

    Raises:
        KeyError: if "pred_boxes" is missing from `outputs`.
    """
    if "pred_boxes" not in outputs:
        raise KeyError("No predicted boxes found in outputs")
    idx = self._get_source_permutation_idx(indices)
    source_boxes = outputs["pred_boxes"][idx]
    target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

    if ignore_wh:
        xytheta = [0, 1, 4]
        # Only three of the five columns contribute, so the loss is rescaled by 5/3
        # (presumably to keep its magnitude comparable to the five-column case).
        loss_bbox = (
            nn.functional.l1_loss(source_boxes[:, xytheta], target_boxes[:, xytheta], reduction="none") * 5 / 3
        )
        # Pair the predicted centre with the target's columns 2-3 so that the GIoU only reflects the
        # centre error. The previous `[:, -2:]` slice replaced columns 3 and 4 instead, leaving
        # column 2 in the GIoU, and wrote into the gathered predictions in place.
        giou_source = torch.cat(
            [source_boxes[:, :2], target_boxes[:, 2:4].to(source_boxes.dtype)], dim=1
        )
    else:
        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
        giou_source = source_boxes[:, :4]

    # The diagonal holds the GIoU of each prediction with its own matched target.
    loss_giou = 1 - torch.diag(
        generalized_box_iou(
            center_to_corners_format(giou_source),
            center_to_corners_format(target_boxes[:, :4]),
        )
    )

    losses = {}
    losses["loss_bbox"] = loss_bbox.sum() / num_boxes
    losses["loss_giou"] = loss_giou.sum() / num_boxes
    return losses
|
| 98 |
+
def modified_forward(outputs, targets, ignore_wh=False):
    """
    Run the matcher and evaluate the label and box losses.

    Args:
        outputs (`dict`):
            Dictionary of tensors produced by the model. The "auxiliary_outputs" entry, if present, is
            hidden from the matcher.
        targets (`List[dict]`):
            One dict per sample; each must provide "class_labels" plus whatever keys the individual
            losses read.
        ignore_wh (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to `self.loss_boxes`.

    Returns:
        `dict` merging the results of `self.loss_labels` and `self.loss_boxes`.
    """
    # Match last-layer predictions to targets, leaving auxiliary outputs out.
    matcher_inputs = {name: value for name, value in outputs.items() if name != "auxiliary_outputs"}
    indices = self.matcher(matcher_inputs, targets)

    # Total number of target boxes, used for normalization and clamped to at least 1.
    # No all-reduce across processes is performed here: distributed training is not handled yet.
    target_count = sum(len(target["class_labels"]) for target in targets)
    device = next(iter(outputs.values())).device
    count_tensor = torch.as_tensor([target_count], dtype=torch.float, device=device)
    num_boxes = torch.clamp(count_tensor, min=1).item()

    # Evaluate the label loss first, then the box loss; box entries win on any key clash.
    label_losses = self.loss_labels(outputs, targets, indices, num_boxes)
    box_losses = self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh)
    return {**label_losses, **box_losses}
|
| 129 |
+
return modified_loss_labels, modified_loss_boxes, modified_forward
|
| 130 |
|
| 131 |
def modify_forward(self):
|
| 132 |
"""
|
|
|
|
| 167 |
):
|
| 168 |
input_images = torch.stack([x["image"] for x in batched_input], dim=0)
|
| 169 |
image_embeddings = self.image_encoder(input_images)
|
| 170 |
+
batch_size = len(batched_input)
|
| 171 |
outputs = []
|
| 172 |
srcs = []
|
| 173 |
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
|
|
|
| 202 |
grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
|
| 203 |
# repeat to batchsize
|
| 204 |
grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
|
| 205 |
+
pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
|
| 206 |
+
downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode='nearest').squeeze(1).bool()
|
| 207 |
+
downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64*64).contiguous()
|
| 208 |
+
grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64*64, 256).contiguous()
|
| 209 |
grasp_decoder_outputs = self.grasp_decoder(
|
| 210 |
inputs_embeds=torch.zeros_like(grasp_query_pe),
|
| 211 |
attention_mask=None,
|
| 212 |
position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
|
| 213 |
query_position_embeddings=grasp_query_pe,
|
| 214 |
encoder_hidden_states=grasp_encoder_hidden_states,
|
| 215 |
+
encoder_attention_mask=downsampled_pixel_masks,
|
| 216 |
output_attentions=False,
|
| 217 |
output_hidden_states=False,
|
| 218 |
return_dict=True,
|
|
|
|
| 242 |
eos_coef=config.eos_coefficient,
|
| 243 |
losses=losses,
|
| 244 |
)
|
| 245 |
+
criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
|
| 246 |
criterion.to(self.device)
|
| 247 |
# Third: compute the losses, based on outputs and labels
|
| 248 |
outputs_loss = {}
|
| 249 |
outputs_loss["logits"] = grasp_logits
|
| 250 |
outputs_loss["pred_boxes"] = pred_grasps
|
| 251 |
|
| 252 |
+
grasp_loss_dict = criterion(outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False))
|
| 253 |
# Fourth: compute total loss, as a weighted sum of the various losses
|
| 254 |
weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
|
| 255 |
weight_dict["loss_giou"] = config.giou_loss_coefficient
|
|
|
|
| 326 |
|
| 327 |
n_queries = iou_predictions.size(0)
|
| 328 |
|
| 329 |
+
batch_size = n_queries
|
| 330 |
+
|
| 331 |
# forward grasp decoder here
|
| 332 |
# 1. Get encoder hidden states
|
| 333 |
grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
|
|
|
|
| 335 |
grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
|
| 336 |
# repeat to batchsize
|
| 337 |
grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
|
| 338 |
+
pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
|
| 339 |
+
downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64),
|
| 340 |
+
mode='nearest').squeeze(1).bool()
|
| 341 |
+
downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
|
| 342 |
+
grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
|
| 343 |
grasp_decoder_outputs = self.grasp_decoder(
|
| 344 |
inputs_embeds=torch.zeros_like(grasp_query_pe),
|
| 345 |
attention_mask=None,
|
| 346 |
position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
|
| 347 |
query_position_embeddings=grasp_query_pe,
|
| 348 |
encoder_hidden_states=grasp_encoder_hidden_states,
|
| 349 |
+
encoder_attention_mask=downsampled_pixel_masks,
|
| 350 |
output_attentions=False,
|
| 351 |
output_hidden_states=False,
|
| 352 |
return_dict=True,
|