|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
import math
|
|
|
from utils import segutils
|
|
|
|
|
|
|
|
|
def buildHyperCol(feat_pyram):
    """Stack a feature pyramid into a single hypercolumn tensor.

    Every level is bilinearly resized (``align_corners=False``) to the
    spatial size of the first level, ``feat_pyram[0]``. The resized levels
    are then concatenated along dim 1.

    Args:
        feat_pyram: sequence of feature tensors. Each is passed to
            ``F.interpolate`` in bilinear mode, so presumably 4-D
            ``(B, C_i, H_i, W_i)`` — TODO confirm with callers.

    Returns:
        Tensor concatenated on dim 1, with the spatial size of the first level.
    """
    out_hw = feat_pyram[0].shape[-2:]
    resized = [
        F.interpolate(level, size=out_hw, mode='bilinear', align_corners=False)
        for level in feat_pyram
    ]
    return torch.cat(resized, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def paste_supports_together(supports):
    """Lay the support slices side by side along the last axis.

    ``supports`` is split along dim 1, and the slices are concatenated on
    dim -1. For an assumed ``(B, K, C, H, W)`` input this yields
    ``(B, C, H, K*W)`` — TODO confirm the layout with callers.
    """
    per_shot = list(torch.unbind(supports, 1))
    return torch.cat(per_shot, -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def buildDenseAffinityMat(qfeat_volume, sfeat_volume, softmax_arg2=True):
    """Compute the dense affinity between every query and every support location.

    Args:
        qfeat_volume: query features ``(B, C, Hq, Wq)``.
        sfeat_volume: support features ``(B, C, Hs, Ws)``. The spatial size
            may differ from the query's; only ``B`` and ``C`` must match.
        softmax_arg2: if truthy (default), scale the dot products by
            ``1/sqrt(C)`` and softmax over the support axis (last dim).
            If falsy, return the raw dot products.

    Returns:
        Tensor ``(B, Hq*Wq, Hs*Ws)``. Rows index query positions and columns
        index support positions, both flattened row-major over ``(H, W)``.
    """
    C = qfeat_volume.shape[1]
    # Flatten each volume by its OWN spatial dims. The previous code reused
    # the query height for the support and called .view() on a permuted
    # tensor, which broke for differing heights and for non-contiguous input.
    q_flat = qfeat_volume.flatten(2).transpose(1, 2)  # (B, Hq*Wq, C)
    s_flat = sfeat_volume.flatten(2)                   # (B, C, Hs*Ws)

    dense_affinity_mat = torch.matmul(q_flat, s_flat)

    if not softmax_arg2:
        return dense_affinity_mat

    # Scaled dot-product softmax over the support positions.
    return (dense_affinity_mat / math.sqrt(C)).softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def filterDenseAffinityMap(dense_affinity_mat, downsampled_smask):
    """Propagate a support mask to query positions through an affinity matrix.

    Each query position receives the affinity-weighted sum of the support
    mask values.

    Args:
        dense_affinity_mat: tensor ``(B, HWq, HWs)``, e.g. as produced by
            ``buildDenseAffinityMat``.
        downsampled_smask: support mask with ``HWs`` elements per batch item.
            It is flattened in row-major order, so its spatial layout must
            match the one used to build the affinity columns.

    Returns:
        Tensor ``(B, HWq)`` in the affinity matrix's dtype.
    """
    bsz, HWq, HWs = dense_affinity_mat.shape
    # reshape (not view) tolerates non-contiguous masks. The cast lets
    # integer/bool masks, or masks of another float dtype, go through matmul.
    mask_col = downsampled_smask.reshape(bsz, HWs, 1).to(dense_affinity_mat.dtype)
    q_coarse = torch.matmul(dense_affinity_mat, mask_col)  # (B, HWq, 1)
    return q_coarse.reshape(bsz, HWq)
|
|
|
|
|
|
|
|
|
def upsample(volume, h, w):
    """Bilinearly resize ``volume`` to the spatial size ``(h, w)``.

    This is a thin wrapper around ``F.interpolate`` with ``mode='bilinear'``
    and ``align_corners=False``.
    """
    target = (h, w)
    return F.interpolate(volume, target, mode='bilinear', align_corners=False)
|
|
|
|
|
|
class DAMatComparison:
    """Coarse query-mask prediction from dense affinity matrices.

    For each selected pyramid level, the query features are compared with
    the side-by-side support features. The support masks, resized to that
    level, are propagated to the query through the affinity matrix. The
    per-level predictions are resized to a common resolution and averaged.
    """

    def algo_mean(self, q_pred_coarses_t, s_mask=None):
        """Average the per-level predictions over dim 1.

        ``s_mask`` is accepted because ``forward`` passes it to the
        post-processing callable. It is unused here.
        """
        return q_pred_coarses_t.mean(1)

    def calc_q_pred_coarses(self, q_feat_t, s_feat_t, s_mask, l0=3):
        """Compute one coarse query prediction per pyramid level from ``l0`` on.

        Args:
            q_feat_t: sequence of query feature maps, presumably
                ``(B, C, H, W)`` per level — TODO confirm.
            s_feat_t: sequence of support feature maps parallel to
                ``q_feat_t``. Each is split along dim 1, presumably the shot
                axis ``(B, K, C, H, W)`` — TODO confirm.
            s_mask: support masks split along dim 1, presumably
                ``(B, K, H, W)`` — TODO confirm.
            l0: index of the first pyramid level used; levels ``l0:`` are processed.

        Returns:
            Tensor ``(B, L, h0, w0)``. ``L`` is the number of processed
            levels and ``(h0, w0)`` is the spatial size of ``q_feat_t[l0]``.
        """
        h0, w0 = q_feat_t[l0].shape[-2:]
        preds = []
        for qft, sft in zip(q_feat_t[l0:], s_feat_t[l0:]):
            # Features are treated as fixed inputs: no gradient flows back.
            qft, sft = qft.detach(), sft.detach()
            bsz, _, hq, wq = qft.shape
            hs, ws = sft.shape[-2:]

            # Tile the support shots along the width axis. The masks get the
            # same layout, so flattened support positions line up with them.
            sft_row = paste_supports_together(sft)
            smask_row = torch.cat(
                [segutils.downsample_mask(m, hs, ws) for m in s_mask.unbind(1)],
                -1,
            )

            damat = buildDenseAffinityMat(qft, sft_row)
            filtered = filterDenseAffinityMap(damat, smask_row)
            # Module-level `upsample` function (no shadowing in this method).
            pred = upsample(filtered.view(bsz, 1, hq, wq), h0, w0).squeeze(1)
            preds.append(pred)

        return torch.stack(preds, dim=1)

    def forward(self, q_feat_t, s_feat_t, s_mask, upsample=True, debug=False, l0=3):
        """Predict a coarse query logit mask from query/support features.

        Args:
            q_feat_t, s_feat_t, s_mask: see ``calc_q_pred_coarses``.
            upsample: if True, resize the result to the last two dims of
                ``s_mask``. NOTE: this parameter shadows the module-level
                ``upsample`` function inside this method.
            debug: if True, show the per-level predictions and their mean.
                NOTE(review): ``display`` is not imported in this file.
                It is presumably the IPython/Jupyter global, so this likely
                raises NameError outside a notebook — confirm.
            l0: first pyramid level used, forwarded to ``calc_q_pred_coarses``
                (default 3, matching the previous hard-coded behaviour).

        Returns:
            The post-processed logit mask, optionally resized.
        """
        q_pred_coarses_t = self.calc_q_pred_coarses(q_feat_t, s_feat_t, s_mask, l0=l0)

        if debug:
            display(segutils.pilImageRow(*q_pred_coarses_t.unbind(1),
                                         q_pred_coarses_t.mean(1)))

        postprocessing_algorithm = self.algo_mean
        logit_mask = postprocessing_algorithm(q_pred_coarses_t, s_mask)

        if upsample:
            # segutils.downsample_mask is called with the full mask
            # resolution as the target size.
            logit_mask = segutils.downsample_mask(logit_mask, *s_mask.shape[-2:])

        return logit_mask