Spaces:
Build error
Build error
| r""" Convolutional Hough Matching Networks """ | |
| import torch.nn as nn | |
| import torch | |
| from . import chmlearner as chmlearner | |
| from .base import backbone | |
class CHMNet(nn.Module):
    """Convolutional Hough Matching network.

    Runs the source and target images jointly through a backbone, splits the
    layer-3 feature map back into two per-image halves, and feeds the pair to
    a ``CHMLearner`` whose output is returned as the correlation.
    """

    def __init__(self, ktype):
        """Build the backbone and the CHM learner.

        Args:
            ktype: Kernel type, forwarded unchanged to
                ``chmlearner.CHMLearner``.
        """
        super().__init__()
        # NOTE(review): ``backbone.resnet101`` is a project helper; it is
        # presumably a ResNet-101 variant that accepts both images stacked on
        # dim 1 (see ``extract_features``) — confirm against ``.base.backbone``.
        self.backbone = backbone.resnet101(pretrained=True)
        # feat_dim=1024 is assumed to equal the per-image channel count of the
        # layer-3 halves returned by ``extract_features`` — TODO confirm.
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        """Return the learner's correlation output for a source/target pair."""
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        return self.learner(src_feat, trg_feat)

    def extract_features(self, src_img, trg_img):
        """Extract layer-3 backbone features for both images in one pass.

        The two images are concatenated along dim 1 and pushed through the
        backbone stem (conv1, bn1, relu, maxpool) and ``layer1``..``layer3``;
        ``layer4`` is never run. The resulting feature map is then split along
        dim 1 into two equal halves.

        Args:
            src_img: Source image tensor; assumed (B, C, H, W) — TODO confirm.
            trg_img: Target image tensor, same shape as ``src_img``.

        Returns:
            Tuple ``(src_feat, trg_feat)``: cloned copies of the first and
            second halves of the layer-3 feature map along dim 1. If that
            dimension is odd, its last channel is dropped.
        """
        # Calling the modules (instead of ``.forward``) also runs any
        # registered forward hooks, as the PyTorch docs recommend.
        feat = self.backbone.conv1(torch.cat([src_img, trg_img], dim=1))
        feat = self.backbone.bn1(feat)
        feat = self.backbone.relu(feat)
        feat = self.backbone.maxpool(feat)

        # Only layers 1-3 are needed; layer4 is intentionally skipped.
        for idx in range(1, 4):
            feat = getattr(self.backbone, 'layer%d' % idx)(feat)

        half = feat.size(1) // 2
        # clone() gives each half its own storage instead of a view of ``feat``.
        src_feat = feat.narrow(1, 0, half).clone()
        trg_feat = feat.narrow(1, half, half).clone()
        return src_feat, trg_feat

    @classmethod
    def training_objective(cls, prd_kps, trg_kps, npts):
        """Mean squared keypoint distance over each sample's valid keypoints.

        This is a classmethod because it reads no instance state. It can be
        called as ``CHMNet.training_objective(...)`` or as
        ``model.training_objective(...)``.

        Args:
            prd_kps: Predicted keypoints. Squared differences are summed over
                dim 1, so presumably (B, 2, N) — TODO confirm.
            trg_kps: Ground-truth keypoints, same shape as ``prd_kps``.
            npts: Per-sample number of valid keypoints. Only the first
                ``npts[b]`` entries of sample ``b`` contribute.

        Returns:
            Scalar tensor: the batch mean of every sample's mean squared
            distance. A sample with ``npt == 0`` averages an empty tensor and
            makes the result NaN.
        """
        # Squared Euclidean distance per keypoint (no square root is taken).
        sq_dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        per_sample = [dist[:npt].mean() for dist, npt in zip(sq_dist, npts)]
        return torch.stack(per_sample).mean()