Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch.nn as nn | |
| from mmpose.registry import MODELS | |
class FeaLoss(nn.Module):
    """Feature-based distillation loss, PyTorch port of the DWPose version.

    Modified from the official implementation:
    <https://github.com/IDEA-Research/DWPose>

    The student's feature map is optionally projected to the teacher's
    channel count with a 1x1 convolution, then compared to the teacher's
    feature map with a summed squared error that is divided by the batch
    size and scaled by ``alpha_fea``.

    Args:
        name: Accepted for interface compatibility; not used by any code
            in this class. NOTE(review): presumably consumed by the
            surrounding distillation config -- confirm against the caller.
        use_this: Accepted for interface compatibility; not used by any
            code in this class. NOTE(review): confirm against the caller.
        student_channels(int): Number of channels in the student's feature
            map.
        teacher_channels(int): Number of channels in the teacher's feature
            map.
        alpha_fea (float, optional): Weight of dis_loss. Defaults to 0.00007
    """

    def __init__(
        self,
        name,
        use_this,
        student_channels,
        teacher_channels,
        alpha_fea=0.00007,
    ):
        super().__init__()
        self.alpha_fea = alpha_fea

        # A 1x1 projection is only needed when the channel counts differ;
        # otherwise the student's features are compared as-is.
        self.align = (
            nn.Conv2d(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
            if teacher_channels != student_channels
            else None
        )

    def forward(self, preds_S, preds_T):
        """Compute the distillation loss between student and teacher.

        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map

        Returns:
            Tensor: The weighted distillation loss produced by
            ``get_dis_loss``.
        """
        aligned = preds_S if self.align is None else self.align(preds_S)
        return self.get_dis_loss(aligned, preds_T)

    def get_dis_loss(self, preds_S, preds_T):
        """Return the summed squared error per sample, scaled by alpha_fea.

        Args:
            preds_S(Tensor): Student's feature map, already channel-aligned
                with the teacher's.
            preds_T(Tensor): Teacher's feature map; must have exactly four
                dimensions, the first of which is used as the batch size.

        Returns:
            Tensor: ``sum((preds_S - preds_T) ** 2) / N * alpha_fea``.
        """
        # Four-way unpacking keeps the requirement that the teacher map is
        # 4-D; only the leading (batch) dimension is actually needed.
        batch_size, _, _, _ = preds_T.shape
        squared_error = nn.functional.mse_loss(
            preds_S, preds_T, reduction='sum')
        return squared_error / batch_size * self.alpha_fea