| from torch import nn | |
| import torch | |
| from torchvision import models | |
class KPDetector(nn.Module):
    """
    Predict K*5 keypoints, where K is ``num_tps``.

    A torchvision ResNet-18 is used as the encoder; its final fully connected
    layer is replaced by a linear head that emits ``num_tps * 5 * 2`` values
    per sample, i.e. 2 values for each of the ``num_tps * 5`` keypoints
    (presumably x/y coordinates -- confirm against the consumer of ``fg_kp``).
    """

    def __init__(self, num_tps, **kwargs):
        """
        Args:
            num_tps: K; the detector predicts ``num_tps * 5`` keypoints.
            **kwargs: accepted for call-site compatibility and ignored here.
        """
        super(KPDetector, self).__init__()
        self.num_tps = num_tps

        # Build the backbone with its default arguments instead of the
        # deprecated ``pretrained=False``: the defaults also mean "no
        # pretrained weights", and this spelling is accepted by both old and
        # new torchvision releases without a deprecation warning.
        self.fg_encoder = models.resnet18()

        # Swap the classification layer for the keypoint regression head.
        num_features = self.fg_encoder.fc.in_features
        self.fg_encoder.fc = nn.Linear(num_features, num_tps * 5 * 2)

    def forward(self, image):
        """
        Run the encoder on ``image`` and return the predicted keypoints.

        Args:
            image: batch of inputs for the ResNet-18 encoder
                (presumably shaped (batch, 3, H, W) -- TODO confirm).

        Returns:
            dict with a single key ``'fg_kp'`` holding a tensor of shape
            ``(batch, num_tps * 5, 2)`` whose values lie in [-1, 1].
        """
        fg_kp = self.fg_encoder(image)
        bs = fg_kp.shape[0]

        # Sigmoid maps to [0, 1]; rescale to [-1, 1].
        fg_kp = torch.sigmoid(fg_kp) * 2 - 1

        # The last dimension is spelled out (rather than inferred with -1) so
        # the reshape is also well defined for an empty batch, where an
        # inferred dimension would be ambiguous.
        return {'fg_kp': fg_kp.view(bs, self.num_tps * 5, 2)}