Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from abc import ABCMeta | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| from mmengine.config import Config | |
| from mmengine.logging import MessageHub | |
| from mmengine.model import BaseModel | |
| from mmengine.runner.checkpoint import load_checkpoint | |
| from torch import Tensor | |
| from mmpose.evaluation.functional import simcc_pck_accuracy | |
| from mmpose.models import build_pose_estimator | |
| from mmpose.registry import MODELS | |
| from mmpose.utils.tensor_utils import to_numpy | |
| from mmpose.utils.typing import (ForwardResults, OptConfigType, OptMultiConfig, | |
| OptSampleList, SampleList) | |
class DWPoseDistiller(BaseModel, metaclass=ABCMeta):
    """Distiller introduced in `DWPose`_ by Yang et al (2023). This distiller
    is designed for distillation of RTMPose.

    It typically consists of teacher_model and student_model. Please use the
    script `tools/misc/pth_transfer.py` to transfer the distilled model to the
    original RTMPose model.

    Two training modes are supported:

    - First stage (``two_dis=False``): the student backbone and head are
      trained with the original keypoint loss plus every enabled
      distillation loss; the distillation losses are scaled by the linear
      decay factor ``1 - epoch / max_epochs``.
    - Second stage (``two_dis=True``): the student head is applied to the
      frozen teacher backbone features. Only the enabled distillation losses
      are returned (without decay); the keypoint loss is computed only to
      obtain predictions and keypoint weights. Feature distillation
      (``loss_fea``) is not available in this stage.

    Args:
        teacher_cfg (str): Config file of the teacher model.
        student_cfg (str): Config file of the student model.
        two_dis (bool): Whether this is the second stage of distillation.
            Defaults to False.
        distill_cfg (dict): Config for distillation. Defaults to None.
        teacher_pretrained (str): Path of the pretrained teacher model.
            Defaults to None.
        train_cfg (dict, optional): The runtime config for training process.
            Defaults to ``None``
        data_preprocessor (dict, optional): The data preprocessing config to
            build the instance of :class:`BaseDataPreprocessor`. Defaults to
            ``None``
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to ``None``

    .. _`DWPose`: https://arxiv.org/abs/2307.15880
    """

    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 two_dis=False,
                 distill_cfg=None,
                 teacher_pretrained=None,
                 train_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        # The teacher is frozen: none of its parameters receive gradients.
        # Its eval mode is also re-enforced in ``train()`` below.
        self.teacher = build_pose_estimator(
            (Config.fromfile(teacher_cfg)).model)
        self.teacher_pretrained = teacher_pretrained
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

        self.student = build_pose_estimator(
            (Config.fromfile(student_cfg)).model)

        # Build only the distillation losses whose ``use_this`` flag is set,
        # keyed by their ``name`` (``loss()`` looks up 'loss_fea' and
        # 'loss_logit').
        self.distill_cfg = distill_cfg
        self.distill_losses = nn.ModuleDict()
        if self.distill_cfg is not None:
            for item_loc in distill_cfg:
                for item_loss in item_loc.methods:
                    loss_name = item_loss.name
                    use_this = item_loss.use_this
                    if use_this:
                        self.distill_losses[loss_name] = MODELS.build(
                            item_loss)

        self.two_dis = two_dis
        self.train_cfg = train_cfg if train_cfg else self.student.train_cfg
        self.test_cfg = self.student.test_cfg
        self.metainfo = self.student.metainfo

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping the frozen teacher in eval mode.

        ``nn.Module.train`` recurses into every submodule, so without this
        override the teacher would be switched back to training mode whenever
        the runner calls ``model.train()``. Mode-dependent layers such as
        BatchNorm or Dropout would then behave as in training (BatchNorm
        would even update its running buffers under ``torch.no_grad``).

        Args:
            mode (bool): Whether to set training mode. Defaults to True.

        Returns:
            DWPoseDistiller: ``self``.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def init_weights(self):
        """Load the pretrained teacher checkpoint (if one is given) and
        initialize the student weights."""
        if self.teacher_pretrained is not None:
            load_checkpoint(
                self.teacher, self.teacher_pretrained, map_location='cpu')
        self.student.init_weights()

    def set_epoch(self):
        """Set epoch for distiller.

        Reads the current ``epoch`` and ``max_epochs`` from the current
        :class:`MessageHub`; they are used for the decay of the distillation
        loss.
        """
        self.message_hub = MessageHub.get_current_instance()
        self.epoch = self.message_hub.get_info('epoch')
        self.max_epochs = self.message_hub.get_info('max_epochs')

    def _distill_decay(self) -> float:
        """Return the linear decay factor ``1 - epoch / max_epochs``.

        Must be called after :meth:`set_epoch`.

        Raises:
            RuntimeError: If the epoch information is missing from the
                message hub, or ``max_epochs`` is zero.
        """
        if self.epoch is None or not self.max_epochs:
            raise RuntimeError(
                'Distillation loss decay requires "epoch" and "max_epochs" '
                'runtime info in the MessageHub, got '
                f'epoch={self.epoch!r}, max_epochs={self.max_epochs!r}.')
        return 1 - self.epoch / self.max_epochs

    def forward(self,
                inputs: torch.Tensor,
                data_samples: OptSampleList,
                mode: str = 'tensor') -> ForwardResults:
        """Dispatch to :meth:`loss`, :meth:`predict` or :meth:`_forward`.

        Args:
            inputs (Tensor): Batched input images.
            data_samples (List[:obj:`PoseDataSample`], optional): The batch
                data samples.
            mode (str): One of ``'loss'``, ``'predict'`` or ``'tensor'``.
                Defaults to ``'tensor'``.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        elif mode == 'predict':
            # use customed metainfo to override the default metainfo
            if self.metainfo is not None:
                for data_sample in data_samples:
                    data_sample.set_metainfo(self.metainfo)
            return self.predict(inputs, data_samples)
        elif mode == 'tensor':
            return self._forward(inputs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode.')

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples.

        Returns:
            dict: A dictionary of losses. In the first stage it holds the
            student's ``loss_kpt``/``acc_pose`` plus the decayed
            distillation losses; in the second stage only the distillation
            losses.

        Raises:
            ValueError: If ``loss_fea`` is enabled in the second stage, where
                no student features are computed.
        """
        self.set_epoch()

        losses = dict()

        # Teacher outputs are only used as targets/features; no gradients.
        with torch.no_grad():
            fea_t = self.teacher.extract_feat(inputs)
            lt_x, lt_y = self.teacher.head(fea_t)
            pred_t = (lt_x, lt_y)

        if not self.two_dis:
            fea_s = self.student.extract_feat(inputs)
            ori_loss, pred, _, target_weight = self.head_loss(
                fea_s, data_samples, train_cfg=self.train_cfg)
            losses.update(ori_loss)
        else:
            # Second stage: the student head runs on the frozen teacher
            # features. Its keypoint loss is intentionally not returned; the
            # call only provides predictions and keypoint weights.
            ori_loss, pred, _, target_weight = self.head_loss(
                fea_t, data_samples, train_cfg=self.train_cfg)

        all_keys = self.distill_losses.keys()

        if 'loss_fea' in all_keys:
            if self.two_dis:
                raise ValueError(
                    "'loss_fea' requires student backbone features, which "
                    'are not computed in the second distillation stage '
                    '(two_dis=True). Disable it via use_this=False.')
            loss_name = 'loss_fea'
            losses[loss_name] = self._distill_decay() * self.distill_losses[
                loss_name](fea_s[-1], fea_t[-1])

        if 'loss_logit' in all_keys:
            loss_name = 'loss_logit'
            loss_logit = self.distill_losses[loss_name](
                pred, pred_t, self.student.head.loss_module.beta,
                target_weight)
            if not self.two_dis:
                loss_logit = self._distill_decay() * loss_logit
            losses[loss_name] = loss_logit

        return losses

    def predict(self, inputs, data_samples):
        """Predict results from a batch of inputs and data samples with post-
        processing.

        In the second stage (``two_dis=True``) the student head is applied to
        the teacher's features (see :meth:`extract_feat`). Otherwise, this
        delegates to the student's own ``predict``.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W)
            data_samples (List[:obj:`PoseDataSample`]): The batch
                data samples

        Returns:
            list[:obj:`PoseDataSample`]: The pose estimation results of the
            input images. The return value is `PoseDataSample` instances with
            ``pred_instances`` and ``pred_fields``(optional) field , and
            ``pred_instances`` usually contains the following keys:

                - keypoints (Tensor): predicted keypoint coordinates in shape
                    (num_instances, K, D) where K is the keypoint number and D
                    is the keypoint dimension
                - keypoint_scores (Tensor): predicted keypoint scores in shape
                    (num_instances, K)
        """
        if not self.two_dis:
            return self.student.predict(inputs, data_samples)

        assert self.student.with_head, (
            'The model must have head to perform prediction.')

        if self.test_cfg.get('flip_test', False):
            _feats = self.extract_feat(inputs)
            _feats_flip = self.extract_feat(inputs.flip(-1))
            feats = [_feats, _feats_flip]
        else:
            feats = self.extract_feat(inputs)

        preds = self.student.head.predict(
            feats, data_samples, test_cfg=self.student.test_cfg)

        # The head may return instances alone or (instances, fields).
        if isinstance(preds, tuple):
            batch_pred_instances, batch_pred_fields = preds
        else:
            batch_pred_instances = preds
            batch_pred_fields = None

        results = self.student.add_pred_to_datasample(
            batch_pred_instances, batch_pred_fields, data_samples)

        return results

    def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
        """Extract features with the (frozen) teacher's ``extract_feat``.

        Used by :meth:`predict` in the second distillation stage, where the
        student head is applied to teacher features.

        Args:
            inputs (Tensor): Image tensor with shape (N, C, H ,W).

        Returns:
            tuple[Tensor]: Multi-level features that may have various
            resolutions.
        """
        x = self.teacher.extract_feat(inputs)
        return x

    def head_loss(
        self,
        feats: Tuple[Tensor],
        batch_data_samples: OptSampleList,
        train_cfg: OptConfigType = None,
    ) -> Tuple[dict, Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]:
        """Run the student SimCC head on ``feats`` and compute its loss.

        Args:
            feats (Tuple[Tensor]): Features fed to the student head.
            batch_data_samples (List[:obj:`PoseDataSample`]): The batch data
                samples carrying ``keypoint_x_labels``, ``keypoint_y_labels``
                and ``keypoint_weights`` in ``gt_instance_labels``.
            train_cfg (dict, optional): Unused; kept for interface
                compatibility. Defaults to ``None``.

        Returns:
            tuple: ``(losses, pred_simcc, gt_simcc, keypoint_weights)`` where
            ``losses`` has ``loss_kpt`` and ``acc_pose``, ``pred_simcc`` and
            ``gt_simcc`` are ``(x, y)`` pairs, and ``keypoint_weights`` is
            the concatenated keypoint weights of the batch.
        """
        pred_x, pred_y = self.student.head.forward(feats)

        gt_x = torch.cat([
            d.gt_instance_labels.keypoint_x_labels for d in batch_data_samples
        ],
                         dim=0)
        gt_y = torch.cat([
            d.gt_instance_labels.keypoint_y_labels for d in batch_data_samples
        ],
                         dim=0)
        keypoint_weights = torch.cat(
            [
                d.gt_instance_labels.keypoint_weights
                for d in batch_data_samples
            ],
            dim=0,
        )

        pred_simcc = (pred_x, pred_y)
        gt_simcc = (gt_x, gt_y)

        # calculate losses
        losses = dict()
        loss = self.student.head.loss_module(pred_simcc, gt_simcc,
                                             keypoint_weights)

        losses.update(loss_kpt=loss)

        # calculate accuracy
        _, avg_acc, _ = simcc_pck_accuracy(
            output=to_numpy(pred_simcc),
            target=to_numpy(gt_simcc),
            simcc_split_ratio=self.student.head.simcc_split_ratio,
            mask=to_numpy(keypoint_weights) > 0,
        )

        acc_pose = torch.tensor(avg_acc, device=gt_x.device)
        losses.update(acc_pose=acc_pose)

        return losses, pred_simcc, gt_simcc, keypoint_weights

    def _forward(self, inputs: Tensor):
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Delegates to the student's ``_forward``.

        Args:
            inputs (Tensor): Inputs with shape (N, C, H, W).

        Returns:
            Union[Tensor | Tuple[Tensor]]: forward output of the network.
        """
        return self.student._forward(inputs)