|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import torch.distributed as dist
|
|
|
|
|
|
from torch import Tensor
|
|
|
import torchvision
|
|
|
import torch.distributed as dist
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
def _max_by_axis(the_list):
|
|
|
|
|
|
maxes = the_list[0]
|
|
|
for sublist in the_list[1:]:
|
|
|
for index, item in enumerate(sublist):
|
|
|
maxes[index] = max(maxes[index], item)
|
|
|
return maxes
|
|
|
|
|
|
|
|
|
class NestedTensor(object):
    """Bundle a tensor together with an optional mask tensor."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new ``NestedTensor`` with both members passed through ``.to(device)``.

        A missing mask (``None``) stays ``None``.
        """
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the data tensor is shown; the mask is not part of the repr.
        return str(self.tensors)
|
|
|
|
|
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3-D tensors to a common size and batch them.

    Args:
        tensor_list: tensors to batch; only the first one is checked for
            ``ndim == 3``.

    Returns:
        A ``NestedTensor`` holding the padded batch and a boolean mask that is
        ``False`` on the region copied from each input and ``True`` on padding.

    Raises:
        ValueError: if the first tensor is not 3-dimensional.
    """
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError("not supported")

    if torchvision._is_tracing():
        # Delegate to the dedicated ONNX-export implementation while tracing.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # Largest extent along every axis determines the padded size.
    shapes = [list(img.shape) for img in tensor_list]
    batch_shape = [len(tensor_list)] + _max_by_axis(shapes)
    b, c, h, w = batch_shape

    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for img, padded_slot, mask_slot in zip(tensor_list, tensor, mask):
        padded_slot[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        # Mark the area covered by real data as valid (not padding).
        mask_slot[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad-and-stack variant of ``nested_tensor_from_tensor_list`` built from pad/stack ops.

    Every input is padded at the end of each of its three axes up to the
    per-axis maximum, then images and masks are stacked. The mask is ``True``
    on padding and ``False`` on original data.

    NOTE(review): ``torch.stack`` over ``img.shape[i]`` needs the shape entries
    to be tensors, which presumably only holds while tracing -- confirm.
    """
    # Per-axis maximum, computed with tensor ops (float32 max, cast back to int64).
    max_size = tuple(
        torch.max(
            torch.stack([img.shape[axis] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        for axis in range(tensor_list[0].dim())
    )

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(largest - own) for largest, own in zip(max_size, tuple(img.shape))]

        # F.pad takes (left, right) pairs starting from the last dimension.
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        )

        # Zeros over the original extent, padded with ones, then cast to bool.
        valid_area = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(
            valid_area, (0, padding[2], 0, padding[1]), "constant", 1
        )
        padded_masks.append(padded_mask.to(torch.bool))

    return NestedTensor(torch.stack(padded_imgs), mask=torch.stack(padded_masks))
|
|
|
|
|
|
|
|
|
def is_dist_avail_and_initialized():
    """Return ``True`` only if ``torch.distributed`` is available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_world_size() -> int:
    """Return the distributed world size, or 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
|
|
|
|
|
def dice_loss(inputs, targets, num_masks):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of logits with at least two dimensions; the
                first dimension indexes examples, the rest are flattened.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer the summed per-example loss is divided by.
    Returns:
        Scalar tensor: sum over examples of ``1 - (2*|X.Y| + 1) / (|X| + |Y| + 1)``
        divided by ``num_masks``.
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    # Flatten targets the same way so non-flattened (N, H, W, ...) targets line
    # up with the flattened predictions. Targets with <= 2 dims are left alone
    # so existing broadcasting behaviour is unchanged.
    if targets.dim() > 2:
        targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # The +1 smoothing keeps the ratio defined when both masks are empty.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
|
|
|
|
|
|
|
|
|
def sigmoid_focal_loss(inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of logits with at least two dimensions.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer the summed loss is divided by.
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = 0.25; a negative
                value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor: the element-wise loss averaged over dim 1, summed, and
        divided by num_masks.
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability the model assigns to the true class of each element.
    prob_true_class = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - prob_true_class) ** gamma)

    if alpha < 0:
        return focal.mean(1).sum() / num_masks

    class_balance = alpha * targets + (1 - alpha) * (1 - targets)
    return (class_balance * focal).mean(1).sum() / num_masks
|
|
|
|
|
|
|
|
|
class SetCriterion(nn.Module):
    """Compute the classification and mask losses for a DETR-style prediction head.

    No Hungarian matcher is run here: ``forward`` pairs every ground-truth
    entry with the prediction slot whose index equals its class label, then
    supervises each pair (class and mask).
    """

    def __init__(self, num_classes, weight_dict, losses, eos_coef=0.1):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            eos_coef: relative classification weight applied to the no-object category
        """
        super().__init__()
        self.num_classes = num_classes
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        # Kept as a buffer so it follows the module through .to()/.cuda().
        # It is deliberately NOT forced onto "cuda" here (that fails on
        # CPU-only hosts); loss_labels moves it to the logits' device instead.
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (NLL)

        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        Slots without an assigned target are labelled with the no-object class
        (index ``num_classes``). ``num_masks`` is accepted for a uniform loss
        signature but not used.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device
        )
        target_classes[idx] = target_classes_o

        # Make sure the class weights live where the logits live, whether or
        # not the criterion itself was moved to that device.
        class_weight = self.empty_weight.to(src_logits.device)
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, class_weight)
        return {"loss_ce": loss_ce}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.

        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        Returns ``{"no_loss": 0}`` when ``pred_masks`` is not 4-dimensional.
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        if src_masks.dim() != 4:
            return {"no_loss": 0}
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]

        # Pad the per-image target masks to a common size; the padding mask is unused.
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # Upsample predictions to the target resolution.
        src_masks = F.interpolate(
            src_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        return {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }

    def _get_src_permutation_idx(self, indices):
        """Return ``(batch_idx, src_idx)`` index tensors for the prediction side of ``indices``."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return ``(batch_idx, tgt_idx)`` index tensors for the target side of ``indices``."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        """Dispatch to the loss named ``loss`` (``"labels"`` or ``"masks"``)."""
        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc

        Returns:
             dict mapping loss names to values; losses computed on the i-th
             entry of ``outputs["aux_outputs"]`` get the suffix ``_{i}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Fixed assignment instead of matching: the prediction slot index is
        # the class label itself, the target index is its position in the list.
        indices = []
        for target in targets:
            label = target["labels"]
            indices.append([label, torch.arange(len(label), device=label.device)])

        # Average number of target masks across all nodes, for normalization.
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks],
            dtype=torch.float,
            # Skip "aux_outputs" (a list) when looking for a tensor's device.
            device=next(iter(outputs_without_aux.values())).device,
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()

        # Losses on the final output.
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))

        # Repeat for the output of each intermediate layer, if provided.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks)
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses
|
|
|
|
|
|
|
|
|
|
|
|
class ATMLoss(nn.Module):
    """ATMLoss.

    Converts label maps into per-class binary mask targets, runs a set
    criterion on them and returns the weighted sum of the resulting losses.
    """

    def __init__(self,
                 ignore_index,
                 num_classes,
                 dec_layers,
                 mask_weight=20.0,
                 dice_weight=1.0,
                 cls_weight=1.0,
                 atm_loss_weight=1.0,
                 use_point=False):
        super(ATMLoss, self).__init__()
        self.ignore_index = ignore_index

        # Weights for the final layer, then the same weights again under a
        # "_{i}" suffix for each of the dec_layers - 1 auxiliary outputs.
        base_weights = {"loss_ce": cls_weight, "loss_mask": mask_weight, "loss_dice": dice_weight}
        weight_dict = dict(base_weights)
        for layer in range(dec_layers - 1):
            for name, weight in base_weights.items():
                weight_dict[name + f"_{layer}"] = weight

        criterion_cls = SetCriterion_point if use_point else SetCriterion
        self.criterion = criterion_cls(
            num_classes,
            weight_dict=weight_dict,
            losses=["labels", "masks"],
        )
        self.loss_weight = atm_loss_weight

    def forward(self,
                outputs,
                label,
                ):
        """Forward function.

        Returns a scalar tensor: the sum of every criterion loss that has an
        entry in the criterion's weight dict, each scaled by that weight and
        by the overall ATM loss weight.
        """
        targets = self.prepare_targets(label)
        losses = self.criterion(outputs, targets)
        weights = self.criterion.weight_dict

        total_loss = torch.as_tensor(0, dtype=torch.float, device=label.device)
        for name in list(losses.keys()):
            if name not in weights:
                # Entries without a configured weight are dropped.
                losses.pop(name)
                continue
            losses[name] = losses[name] * weights[name] * self.loss_weight
            total_loss += losses[name]

        return total_loss

    def prepare_targets(self, targets):
        """Turn each label map into ``{"labels": classes, "masks": binary masks}``.

        ``labels`` holds the unique values present, excluding ``ignore_index``;
        ``masks`` stacks one boolean mask per such value. When no value is
        left, a single mask of the ignored region is used so the stack is
        never empty.
        """
        prepared = []
        for label_map in targets:
            classes = label_map.unique()
            classes = classes[classes != self.ignore_index]
            class_masks = [label_map == cls for cls in classes]
            if len(classes) == 0:
                class_masks.append(label_map == self.ignore_index)
            prepared.append(
                {
                    "labels": classes,
                    "masks": torch.stack(class_masks, dim=0),
                }
            )
        return prepared