| |
| |
| |
|
|
| import pdb |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from nets.ap_loss import APLoss |
|
|
|
|
class PixelAPLoss(nn.Module):
    """Computes the pixel-wise AP loss.

    Given two images and ground-truth optical flow, computes the AP per pixel
    and averages ``1 - AP`` over the valid sampled query pixels.

    Expected inputs (as documented by the original authors):
        feat1: (B, C, H, W) pixel-wise features extracted from img1
        feat2: (B, C, H, W) pixel-wise features extracted from img2
        aflow: (B, 2, H, W) absolute flow: aflow[..., y1, x1] = x2, y2

    In ``forward`` the features arrive together as ``descriptors``. They are
    handed, along with ``aflow`` and the optional ``reliability`` keyword, to
    the ``sampler`` given at construction time.
    """

    def __init__(self, sampler, nq=20):
        """
        Args:
            sampler: callable invoked as ``sampler(descriptors, reliability,
                aflow)``. It must return ``(scores, gt, msk, qconf)``.
            nq: forwarded to ``APLoss`` as its first argument (presumably the
                number of quantization bins -- confirm against ``APLoss``).
        """
        super().__init__()
        # NOTE(review): min=0 / max=1 presumably bound the score range that
        # APLoss quantizes -- confirm against nets.ap_loss.
        self.aploss = APLoss(nq, min=0, max=1, euc=False)
        self.name = "pixAP"
        self.sampler = sampler

    def loss_from_ap(self, ap, rel):
        """Map per-pixel AP to a per-pixel loss.

        The base class ignores ``rel``. The parameter exists so subclasses can
        weight the loss by a per-pixel confidence.
        """
        return 1 - ap

    def forward(self, descriptors, aflow, **kw):
        """Compute the mean pixel-wise loss over valid sampled query pixels.

        Args:
            descriptors: passed straight to ``self.sampler``.
            aflow: absolute flow, passed straight to ``self.sampler``.
            **kw: only ``reliability`` is read (``None`` if absent). It is
                forwarded to the sampler.

        Returns:
            A scalar tensor with the mean loss over pixels where ``msk`` is
            true. Returns the int ``0`` when the sampler yields no queries or
            when the mask selects no pixel.
        """
        scores, gt, msk, qconf = self.sampler(descriptors, kw.get("reliability"), aflow)

        # One AP value per sampled query; nothing to score if there are none.
        n = qconf.numel()
        if n == 0:
            return 0
        scores, gt = scores.view(n, -1), gt.view(n, -1)
        ap = self.aploss(scores, gt).view(msk.shape)

        # Subclasses may use the per-query confidence ``qconf`` here.
        pixel_loss = self.loss_from_ap(ap, qconf)

        valid = pixel_loss[msk]
        if valid.numel() == 0:
            # mean() of an empty selection is NaN, which would poison the
            # summed / logged loss; treat "no valid pixel" like "no query".
            return 0
        return valid.mean()
|
|
|
|
class ReliabilityLoss(PixelAPLoss):
    """Same as PixelAPLoss, but also trains a pixel-wise confidence that the
    pixel is going to have a good AP.

    The per-pixel loss is ``1 - (ap * rel + (1 - rel) * base)``. This is a
    confidence-weighted blend of the pixel's AP and the constant ``base``.
    Its derivative with respect to ``rel`` is ``base - ap``. Minimizing it
    therefore pushes ``rel`` up where ``ap > base`` and down where
    ``ap < base``.
    """

    def __init__(self, sampler, base=0.5, **kw):
        """
        Args:
            sampler: forwarded to ``PixelAPLoss``.
            base: AP credited to a pixel in proportion to ``1 - rel``. It
                must satisfy ``0 <= base < 1``.
            **kw: extra keyword arguments for ``PixelAPLoss`` (e.g. ``nq``).

        Raises:
            ValueError: if ``base`` is outside ``[0, 1)``.
        """
        # Explicit check rather than ``assert`` so it still runs under -O.
        if not 0 <= base < 1:
            raise ValueError("base must satisfy 0 <= base < 1, got {!r}".format(base))
        super().__init__(sampler, **kw)
        self.base = base
        self.name = "reliability"

    def loss_from_ap(self, ap, rel):
        """Per-pixel loss mixing ``ap`` and ``self.base`` according to ``rel``."""
        return 1 - ap * rel - (1 - rel) * self.base
|
|