| | |
| |
|
| | """ |
| | Motion extractor (M), which directly predicts the canonical keypoints, head pose, and expression deformation of the input image. |
| | """ |
| |
|
| | from torch import nn |
| | import torch |
| |
|
| | from .convnextv2 import convnextv2_tiny |
| | from .util import filter_state_dict |
| |
|
# Registry mapping a backbone name to the callable that constructs that backbone.
model_dict = {'convnextv2_tiny': convnextv2_tiny}
| |
|
| |
|
class MotionExtractor(nn.Module):
    """Thin ``nn.Module`` wrapper around a backbone selected by name.

    Per the module description, the motion extractor predicts the canonical
    keypoints, head pose and expression deformation of an input image. This
    class itself only builds the backbone registered in ``model_dict`` and
    delegates ``forward`` to it; the structure of the returned prediction is
    defined by the backbone, not here.
    """

    def __init__(self, **kwargs):
        """Build the backbone named by ``kwargs['backbone']``.

        Args:
            **kwargs: Passed unchanged (including the ``backbone`` entry
                itself, when present) to the backbone constructor found in
                ``model_dict``. ``backbone`` defaults to
                ``'convnextv2_tiny'``.

        Raises:
            ValueError: If ``backbone`` is not a key of ``model_dict``.
        """
        super().__init__()

        backbone = kwargs.get('backbone', 'convnextv2_tiny')
        builder = model_dict.get(backbone)
        if builder is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable" that calling the missing
            # registry entry would produce.
            supported = ', '.join(repr(name) for name in model_dict)
            raise ValueError(
                f'Unknown backbone {backbone!r}; supported backbones: {supported}'
            )
        self.detector = builder(**kwargs)

    def load_pretrained(self, init_path: str):
        """Load pretrained backbone weights from a checkpoint file.

        Does nothing when ``init_path`` is ``None`` or an empty string.

        Args:
            init_path: Path to a checkpoint readable by ``torch.load`` whose
                top-level object has a ``'model'`` entry holding the state
                dict.
        """
        if init_path in (None, ''):
            return

        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources.
        # map_location='cpu' keeps every tensor on the CPU regardless of the
        # device it was saved from.
        checkpoint = torch.load(init_path, map_location='cpu')
        # presumably drops the entries belonging to the 'head' layers --
        # confirm against filter_state_dict in .util
        state_dict = filter_state_dict(checkpoint['model'], remove_name='head')
        # strict=False tolerates missing/unexpected keys; they are reported
        # in ``ret`` and printed below rather than raising.
        ret = self.detector.load_state_dict(state_dict, strict=False)
        print(f'Load pretrained model from {init_path}, ret: {ret}')

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged."""
        return self.detector(x)
| |
|