# Source: YuqianFu's Hugging Face repository
# Commit: "Upload folder using huggingface_hub" — fe6c2e4 (verified)
import torch
import torch.nn.functional as F
from torch import nn
def refer_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        weight: torch.Tensor):
    """Class-weighted cross-entropy between ``inputs`` logits and ``targets``.

    Thin wrapper around ``F.cross_entropy`` kept as a standalone function so
    it can be compiled with ``torch.jit.script`` (hence the explicit tensor
    annotations on every parameter).
    """
    return F.cross_entropy(inputs, targets, weight=weight)
# TorchScript-compile the loss once at import time so repeated calls inside
# the training loop avoid Python-interpreter overhead.
refer_ce_loss_jit = torch.jit.script(
    refer_ce_loss
)  # type: torch.jit.ScriptModule
class ReferringCriterion(nn.Module):
    """Criterion for referring segmentation.

    Combines three weighted cross-entropy terms into a single ``loss_mask``:
    the full-resolution mask loss, a 10x10 "minimap" loss (weight 0.1), and a
    no-target ("nt") classification loss (weight 0.1).
    """

    def __init__(self, weight_dict, losses):
        super().__init__()
        # weight_dict / losses are stored but not read by the methods visible
        # here; forward() calls loss_masks_refer unconditionally.
        self.weight_dict = weight_dict
        self.losses = losses

    def get_loss(self, loss, outputs, targets):
        # Dispatch-by-name; only the 'masks' loss is implemented.
        loss_map = {
            'masks': self.loss_masks_refer,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets)

    def loss_masks_refer(self, outputs, targets):
        # Predictions. NOTE(review): pred_logits is permuted (0,2,1), which
        # suggests it arrives as (B, N, C) and is used as (B, C, N) logits
        # for cross-entropy — confirm against the model that fills `outputs`.
        src_masks = outputs["pred_masks"]
        src_minimap = outputs["pred_logits"].permute(0, 2, 1)
        src_nt_label = outputs["nt_label"]

        # Ground truth: pad the per-sample merged masks to a common size.
        masks = [t["gt_mask_merged"] for t in targets]
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)  # match dtype/device
        # "empty" flags (one per sample) serve as no-target labels.
        target_nts = torch.stack([t["empty"] for t in targets])

        # Upsample predictions to the ground-truth resolution, and downsample
        # the ground truth to a coarse 10x10 minimap flattened to (B, 100).
        h, w = target_masks.shape[-2:]
        src_masks = F.interpolate(src_masks, (h, w), mode='bilinear', align_corners=False)
        target_minimap = F.interpolate(target_masks, (10, 10), mode='bilinear', align_corners=False).flatten(start_dim=1)

        # Fixed per-class weights (background 0.9, foreground 1.1), moved to
        # the prediction's dtype/device via .to(src_masks).
        weight = torch.FloatTensor([0.9, 1.1]).to(src_masks)

        # NOTE(review): .squeeze(1).long() binarizes the (interpolated)
        # targets by truncation; interpolated minimap values in (0, 1) all
        # truncate to 0 — presumably intentional upstream, verify.
        loss_mask = \
            refer_ce_loss_jit(src_masks, target_masks.squeeze(1).long(), weight) + \
            refer_ce_loss_jit(src_minimap, target_minimap.squeeze(1).long(), weight) * 0.1 + \
            refer_ce_loss_jit(src_nt_label, target_nts, weight) * 0.1

        losses = {
            "loss_mask": loss_mask,
        }

        # Free the large upsampled tensors promptly to reduce peak memory.
        del src_masks
        del target_masks
        return losses

    def forward(self, outputs, targets):
        # Compute all the requested losses
        losses = {}
        losses.update(self.loss_masks_refer(outputs, targets))
        return losses
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("not supported")
return NestedTensor(tensor, mask)