| |
| import functools |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from mmengine.fileio import load |
|
|
|
|
def get_class_weight(class_weight):
    """Resolve a class-weight specification for a loss function.

    Args:
        class_weight (list[float] | str | None): Either the weights
            themselves, which are returned unchanged, or a path to a file
            holding them. Paths ending in ``.npy`` are read with
            ``np.load``; any other path is read with ``mmengine``'s
            ``fileio.load``.

    Returns:
        The loaded weights when ``class_weight`` is a str, otherwise
        ``class_weight`` exactly as passed in.
    """
    if not isinstance(class_weight, str):
        # Already materialised (or None): nothing to read.
        return class_weight
    # Only the selected branch of the conditional is evaluated, so the
    # generic loader is touched only for non-.npy paths.
    reader = np.load if class_weight.endswith('.npy') else load
    return reader(class_weight)
|
|
|
|
def reduce_loss(loss, reduction) -> torch.Tensor:
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Element-wise loss tensor.
        reduction (str): Options are "none", "mean" and "sum". The name is
            translated to an integer code by PyTorch's private
            ``F._Reduction.get_enum`` helper (0 -> "none", 1 -> "mean",
            2 -> "sum").

    Return:
        Tensor: ``loss`` itself for "none", otherwise its mean or sum.
    """
    code = F._Reduction.get_enum(reduction)
    if code == 1:
        return loss.mean()
    if code == 2:
        return loss.sum()
    if code == 0:
        return loss
    # NOTE(review): any other code falls through and returns None, matching
    # the original; presumably get_enum never yields one in practice.
|
|
|
|
def weight_reduce_loss(loss,
                       weight=None,
                       reduction='mean',
                       avg_factor=None) -> torch.Tensor:
    """Apply element-wise weight and reduce loss.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor, optional): Element-wise weights multiplied into
            ``loss`` before reduction. Must have the same number of
            dimensions as ``loss``; when it has more than one dimension,
            its size along dim 1 must be 1 (broadcast) or equal to
            ``loss.size(1)``. Defaults to None (no weighting).
        reduction (str): Same as built-in losses of PyTorch ("none",
            "mean" or "sum"). Defaults to "mean".
        avg_factor (float | Tensor, optional): Average factor when
            computing the mean of losses. When given, "mean" is computed
            as ``loss.sum() / avg_factor`` instead of ``loss.mean()``, and
            "none" returns the weighted element-wise loss unchanged.
            Defaults to None.

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than "mean" or "none".
    """
    if weight is not None:
        assert weight.dim() == loss.dim(), (
            f'weight.dim() ({weight.dim()}) must equal loss.dim() '
            f'({loss.dim()})')
        if weight.dim() > 1:
            # Dim 1 may be broadcast (size 1) or match the loss exactly.
            assert weight.size(1) == 1 or weight.size(1) == loss.size(1), (
                f'weight.size(1) ({weight.size(1)}) must be 1 or equal '
                f'loss.size(1) ({loss.size(1)})')
        loss = loss * weight

    if avg_factor is None:
        # No custom normaliser: defer to the standard reduction.
        return reduce_loss(loss, reduction)

    if reduction == 'mean':
        # eps keeps the division finite when avg_factor is 0.
        eps = torch.finfo(torch.float32).eps
        return loss.sum() / (avg_factor + eps)
    if reduction == 'none':
        # Nothing to average; avg_factor has no effect.
        return loss
    # Report the value actually passed instead of assuming it was "sum";
    # the message is unchanged for reduction="sum".
    raise ValueError(
        f'avg_factor can not be used with reduction="{reduction}"')
|
|
|
|
def weighted_loss(loss_func):
    """Decorator that adds weighting and reduction to an element-wise loss.

    The decorated ``loss_func`` must have a signature like
    ``loss_func(pred, target, **kwargs)`` and return an unreduced,
    element-wise loss. The returned wrapper has the signature
    ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``. It forwards ``pred``, ``target`` and any
    extra keyword arguments to ``loss_func``, then passes the result
    through ``weight_reduce_loss`` with the weight/reduction arguments.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()
    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])
    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # Unreduced per-element loss from the user-supplied function.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(
            elementwise,
            weight=weight,
            reduction=reduction,
            avg_factor=avg_factor)

    return wrapper
|
|