from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
| |
|
def refer_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    """Return the class-weighted cross-entropy between logits and labels.

    Thin wrapper around ``F.cross_entropy`` that forwards ``weight`` as the
    per-class rescaling weight and leaves every other option at its default.

    Args:
        inputs: Unnormalized class scores, as accepted by ``F.cross_entropy``.
        targets: Ground-truth targets, as accepted by ``F.cross_entropy``.
        weight: Per-class weights.

    Returns:
        The loss tensor produced by ``F.cross_entropy``.
    """
    return F.cross_entropy(inputs, targets, weight)
| |
|
| |
|
# Script the loss once, when the module is imported, so every call site
# shares the same TorchScript-compiled function.
refer_ce_loss_jit = torch.jit.script(refer_ce_loss)
| |
|
class ReferringCriterion(nn.Module):
    """Criterion that computes the referring mask loss.

    The loss is a single ``"loss_mask"`` entry made of three class-weighted
    cross-entropy terms (see :meth:`loss_masks_refer`).

    ``weight_dict`` and ``losses`` are stored on the instance as given; no
    method of this class reads ``weight_dict``, and ``forward`` does not
    consult ``losses``.
    """

    def __init__(self, weight_dict, losses):
        """Store the loss weighting dict and the list of loss names."""
        super().__init__()
        self.weight_dict = weight_dict
        self.losses = losses

    def get_loss(self, loss, outputs, targets):
        """Dispatch to the loss function registered under the name ``loss``.

        Only ``'masks'`` is registered. Any other name fails the assertion.
        """
        loss_map = {
            'masks': self.loss_masks_refer,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    def loss_masks_refer(self, outputs, targets):
        """Compute the combined mask loss.

        Args:
            outputs: Mapping providing ``"pred_masks"``, ``"pred_logits"`` and
                ``"nt_label"`` tensors.
            targets: Sequence of per-sample mappings, each providing
                ``"gt_mask_merged"`` and ``"empty"`` tensors.

        Returns:
            ``{"loss_mask": loss}`` where ``loss`` is the sum of:

            * cross-entropy of ``pred_masks`` (resized to the padded
              ground-truth resolution) against the ground-truth masks;
            * 0.1 x cross-entropy of the permuted ``pred_logits`` against the
              ground-truth masks resized to 10x10 and flattened;
            * 0.1 x cross-entropy of ``nt_label`` against the stacked
              ``"empty"`` targets.

            All three terms use the class weights ``[0.9, 1.1]``.
        """
        src_masks = outputs["pred_masks"]
        # F.cross_entropy expects the class dimension at dim 1, so swap the
        # last two dims (presumably (batch, cells, classes) -> (batch,
        # classes, cells) -- TODO confirm against the model head).
        src_minimap = outputs["pred_logits"].permute(0, 2, 1)
        src_nt_label = outputs["nt_label"]

        masks = [t["gt_mask_merged"] for t in targets]
        # Only the padded batch tensor is used; the padding mask is discarded.
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        # Match the dtype and device of the predictions.
        target_masks = target_masks.to(src_masks)

        target_nts = torch.stack([t["empty"] for t in targets])

        # Bring the predicted masks to the (padded) ground-truth resolution.
        h, w = target_masks.shape[-2:]
        src_masks = F.interpolate(src_masks, (h, w), mode='bilinear', align_corners=False)

        # Coarse 10x10 version of the ground truth, flattened per sample.
        # NOTE(review): the later ``.long()`` truncates the interpolated
        # values toward zero, so partially covered cells become class 0 --
        # confirm this is intended.
        target_minimap = F.interpolate(
            target_masks, (10, 10), mode='bilinear', align_corners=False
        ).flatten(start_dim=1)

        # Per-class weights, created directly with the predictions' dtype and
        # device instead of going through the legacy FloatTensor constructor.
        weight = torch.tensor(
            [0.9, 1.1], dtype=src_masks.dtype, device=src_masks.device
        )

        loss_mask = (
            refer_ce_loss_jit(src_masks, target_masks.squeeze(1).long(), weight)
            + refer_ce_loss_jit(src_minimap, target_minimap.long(), weight) * 0.1
            + refer_ce_loss_jit(src_nt_label, target_nts, weight) * 0.1
        )

        return {"loss_mask": loss_mask}

    def forward(self, outputs, targets):
        """Return the loss dict for ``outputs`` versus ``targets``.

        Always computes the referring mask loss, regardless of
        ``self.losses``.
        """
        losses = {}
        losses.update(self.loss_masks_refer(outputs, targets))
        return losses
| |
|
| |
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Batch a list of 3-D tensors by zero-padding them to a common shape.

    Args:
        tensor_list: Non-empty list of tensors; the first one must be 3-D.

    Returns:
        ``NestedTensor(tensor, mask)`` where ``tensor`` holds the zero-padded
        batch and ``mask`` is a boolean tensor over the last two dimensions
        that is ``False`` on each sample's own extent and ``True`` on padding.

    Raises:
        ValueError: If the first tensor is not 3-D.

    NOTE(review): ``torchvision``, ``_max_by_axis``,
    ``_onnx_nested_tensor_from_tensor_list`` and ``NestedTensor`` are not
    defined in the visible part of this file -- confirm they are imported or
    defined elsewhere.
    """
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # While tracing, delegate to the dedicated helper (presumably an
        # export-friendly variant -- confirm).
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # Target shape per axis (presumably the per-axis maximum over all
    # samples -- confirm against _max_by_axis).
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    # The channel count is not needed for the mask; the unpacking still
    # enforces that the batch shape has exactly four entries.
    b, _, h, w = batch_shape

    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        # Copy the sample into the top-left corner of its padded slot and
        # mark that region as valid (False) in the padding mask.
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)