# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS

@MODELS.register_module()
class KDLoss(nn.Module):
    """PyTorch version of the logit-based distillation loss from DWPose.

    Modified from the official implementation
    <https://github.com/IDEA-Research/DWPose>.

    Args:
        name (str): Name of the loss term, used to identify it in the
            distillation config.
        use_this (bool): Whether this loss term is enabled in the
            distillation config.
        weight (float, optional): Weight of the distillation loss.
            Defaults to 1.0.
    """
    def __init__(
        self,
        name,
        use_this,
        weight=1.0,
    ):
        super(KDLoss, self).__init__()

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')
        self.weight = weight
    def forward(self, pred, pred_t, beta, target_weight):
        # Student and teacher SimCC logits, one tensor per axis,
        # each of shape (N, K, num_bins).
        ls_x, ls_y = pred
        lt_x, lt_y = pred_t

        # Stop gradients from flowing into the teacher.
        lt_x = lt_x.detach()
        lt_y = lt_y.detach()

        num_joints = ls_x.size(1)
        loss = 0

        loss += self.loss(ls_x, lt_x, beta, target_weight)
        loss += self.loss(ls_y, lt_y, beta, target_weight)

        return loss / num_joints
    def loss(self, logit_s, logit_t, beta, weight):
        # NOTE: ``weight`` (the target weight) is accepted for API
        # compatibility but is not used below.
        N = logit_s.shape[0]
        K = 1  # default so the reshape below stays valid for 2-D inputs
        if len(logit_s.shape) == 3:
            K = logit_s.shape[1]
            logit_s = logit_s.reshape(N * K, -1)
            logit_t = logit_t.reshape(N * K, -1)

        # Temperature-scaled distributions over the SimCC bins: (N*K, W or H).
        s_i = self.log_softmax(logit_s * beta)
        t_i = F.softmax(logit_t * beta, dim=1)

        # KL divergence per row, summed over bins, then summed over keypoints
        # and averaged over the batch.
        loss_all = torch.sum(self.kl_loss(s_i, t_i), dim=1)
        loss_all = loss_all.reshape(N, K).sum(dim=1).mean()
        loss_all = self.weight * loss_all

        return loss_all
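
# ----------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): a minimal smoke
# test with random SimCC-style logits. The shapes and ``beta`` value below are
# illustrative only; in DWPose this loss is normally built from the config via
# the MODELS registry rather than instantiated directly.
if __name__ == '__main__':
    N, K, Wx, Wy = 2, 17, 384, 512  # batch, keypoints, x-bins, y-bins

    criterion = KDLoss(name='kd_loss', use_this=True, weight=1.0)

    pred = (torch.randn(N, K, Wx), torch.randn(N, K, Wy))    # student logits
    pred_t = (torch.randn(N, K, Wx), torch.randn(N, K, Wy))  # teacher logits
    target_weight = torch.ones(N, K)  # accepted but unused by KDLoss.loss

    value = criterion(pred, pred_t, beta=1.0, target_weight=target_weight)
    print(value)  # scalar distillation loss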