from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from mmdet.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def knowledge_distillation_kl_div_loss(pred: Tensor,
                                       soft_label: Tensor,
                                       T: int,
                                       detach_target: bool = True) -> Tensor:
    r"""Loss function for knowledge distilling using KL divergence.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1).
        T (int): Temperature for distillation.
        detach_target (bool): Remove soft_label from automatic
            differentiation. Defaults to True.

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()
    target = F.softmax(soft_label / T, dim=1)
    if detach_target:
        target = target.detach()

    # F.kl_div expects log-probabilities as input and probabilities as
    # target. The per-sample divergence is averaged over the class
    # dimension and scaled by T^2 to keep gradient magnitudes roughly
    # constant across temperatures (Hinton et al., 2015).
    kd_loss = F.kl_div(
        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
            T * T)

    return kd_loss
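
# A minimal sketch of calling the decorated function directly, assuming
# mmdet's ``weighted_loss`` decorator exposes the usual ``weight``,
# ``reduction`` and ``avg_factor`` arguments; the 81-way logits are
# purely illustrative (e.g. 80 classes plus background):
#
#     >>> import torch
#     >>> pred = torch.randn(4, 81)        # student logits
#     >>> soft_label = torch.randn(4, 81)  # teacher logits
#     >>> loss = knowledge_distillation_kl_div_loss(
#     ...     pred, soft_label, weight=None, reduction='mean', T=10)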


@MODELS.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        T (int): Temperature for distillation. Defaults to 10.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 T: int = 10) -> None:
        super().__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred: Tensor,
                soft_label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)

        return loss_kd
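

# A minimal usage sketch of the module form; shapes and values are
# illustrative, not part of this file's API:
#
#     >>> import torch
#     >>> loss_fn = KnowledgeDistillationKLDivLoss(loss_weight=0.5, T=2)
#     >>> student = torch.randn(8, 81)
#     >>> teacher = torch.randn(8, 81)
#     >>> loss = loss_fn(student, teacher)
#
# In a config, the loss would typically be built through the registry,
# e.g. ``dict(type='KnowledgeDistillationKLDivLoss', T=10)``.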