|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from models.network import extractors |
|
|
from models.sync_batchnorm import SynchronizedBatchNorm2d |
|
|
from models.network.aspp import ASPP_no4level |
|
|
|
|
|
def make_coord(shape, ranges=None, flatten=True):
    """Return the coordinates of grid-cell centers.

    Args:
        shape: sizes of each grid dimension, e.g. (H, W).
        ranges: optional per-dimension (lo, hi) value ranges; defaults to
            (-1, 1) on every axis (the convention used by grid_sample).
        flatten: if True, return a (prod(shape), ndim) tensor; otherwise
            keep the grid layout (*shape, ndim).
    """
    seqs = []
    for axis, size in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[axis]
        # Half the width of one cell; centers sit half a cell in from each edge.
        half = (hi - lo) / (2 * size)
        seqs.append(lo + half + (2 * half) * torch.arange(size).float())
    grid = torch.stack(torch.meshgrid(*seqs), dim=-1)
    return grid.view(-1, grid.shape[-1]) if flatten else grid
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Plain fully-connected network: Linear+ReLU hidden layers, linear output.

    Accepts input of any leading shape (..., in_dim) and returns
    (..., out_dim); the leading dimensions are flattened for the forward
    pass and restored afterwards.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        widths = [in_dim] + list(hidden_list)
        modules = []
        # One Linear+ReLU pair per hidden width, chained in order.
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(w_in, w_out))
            modules.append(nn.ReLU())
        # Final projection has no activation.
        modules.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        lead_shape = x.shape[:-1]
        out = self.layers(x.view(-1, x.shape[-1]))
        return out.view(*lead_shape, -1)
|
|
|
|
|
|
|
|
class CRMNet(nn.Module):
    """Continuous refinement network for segmentation mask refinement.

    Extracts backbone + ASPP features from the concatenated image and
    coarse mask, then decodes a refined alpha value at arbitrary query
    coordinates with a LIIF-style local implicit MLP (local ensemble of
    the four nearest feature cells, weighted bilinearly by area).
    """

    def __init__(self, backend='resnet34', pretrained=True):
        """
        Args:
            backend: name of the backbone constructor looked up in
                `extractors` (e.g. 'resnet34').
            pretrained: forwarded to the backbone constructor.
        """
        super().__init__()
        # Backbone producing three feature levels from the 4-channel input.
        self.feats = getattr(extractors, backend)(pretrained)
        self.aspp_ = ASPP_no4level(backbone=backend, output_stride=8, BatchNorm=SynchronizedBatchNorm2d)
        # Implicit decoder over [feature(256), rel_coord(2), coord(2), rel_cell(2)].
        self.imnet = MLP(in_dim=256+6, out_dim=1, hidden_list=[32, 32, 32, 32])

    def forward(self, x, seg, coord, cell, inter_s8=None, inter_s4=None):
        """Predict refined mask logits/probabilities at the query coordinates.

        Args:
            x: input image batch (B, 3, H, W) — presumably; verify against caller.
            seg: coarse segmentation (B, 1, H, W), concatenated with `x`.
            coord: query coordinates (B, Q, 2) in [-1, 1] (y, x order, flipped
                before grid_sample).
            cell: per-query cell sizes (B, Q, 2) in normalized units.
            inter_s8, inter_s4: unused here; kept for interface compatibility
                with callers — NOTE(review): confirm before removing.

        Returns:
            dict with 'out_224' (raw logits, (B, Q, 1)) and
            'pred_224' (sigmoid probabilities, same shape).
        """
        p = torch.cat((x, seg), 1)

        x1_feat, x2_feat, x3_feat = self.feats(p)
        feat = self.aspp_(x1_feat, x2_feat, x3_feat)

        # Local ensemble: sample the four diagonal neighbours of each query.
        vx_lst = [-1, 1]
        vy_lst = [-1, 1]
        eps_shift = 1e-6

        # Half a feature-map cell in normalized [-1, 1] coordinates.
        rx = 2 / feat.shape[-2] / 2
        ry = 2 / feat.shape[-1] / 2

        # Coordinates of every feature-cell center, shaped (B, 2, H, W).
        # Fix: follow the feature tensor's device instead of hard-coding
        # .cuda(), so the model also runs on CPU and non-default GPUs.
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(feat.device) \
            .permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                # Shift the query towards one diagonal neighbour; the eps
                # nudge avoids landing exactly on a cell boundary.
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                # Nearest-neighbour sample of the feature and of the sampled
                # cell's own center coordinate. coord is (y, x); grid_sample
                # expects (x, y), hence the flip(-1).
                q_feat = F.grid_sample(
                    feat, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                q_coord = F.grid_sample(
                    feat_coord, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                # Offset of the query from the sampled cell center, scaled
                # to cell units.
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord, coord], dim=-1)

                # Query cell size in feature-cell units.
                rel_cell = cell.clone()
                rel_cell[:, :, 0] *= feat.shape[-2]
                rel_cell[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([inp, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                # Rectangle area between the query and this neighbour; the
                # epsilon keeps the normalization below well-defined.
                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)

        # Bilinear weighting: each prediction is weighted by the area of the
        # rectangle OPPOSITE its sample, so swap diagonal pairs (0<->3, 1<->2).
        t = areas[0]; areas[0] = areas[3]; areas[3] = t
        t = areas[1]; areas[1] = areas[2]; areas[2] = t
        ret = 0

        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)

        pred_224 = torch.sigmoid(ret)

        images = {}
        images['out_224'] = ret
        images['pred_224'] = pred_224

        return images
|
|
|