| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import geffnet |
|
|
|
|
# Feature-map channel counts keyed by EfficientNet variant index B (the same
# B passed to Encoder(B=...)).
# NOTE(review): the lists appear to be ordered deepest feature first (e.g.
# 1280 matches the B0 final conv width) — confirm against the decoder that
# consumes this table; the decoder is not visible in this file.
INPUT_CHANNELS_DICT = {
    0: [1280, 112, 40, 24, 16],
    1: [1280, 112, 40, 24, 16],
    2: [1408, 120, 48, 24, 16],
    3: [1536, 136, 48, 32, 24],
    4: [1792, 160, 56, 32, 24],
    5: [2048, 176, 64, 40, 24],
    6: [2304, 200, 72, 40, 32],
    7: [2560, 224, 80, 48, 32],
}
|
|
|
|
class Encoder(nn.Module):
    """EfficientNet feature extractor that returns every intermediate map.

    The classification head is stripped; `forward` collects the output of
    each top-level module (and each sub-block inside `blocks`) in order.
    """

    def __init__(self, B=5, pretrained=True):
        """e.g. B=5 will return EfficientNet-B5"""
        super(Encoder, self).__init__()
        backbone = geffnet.create_model(
            "tf_efficientnet_b%s_ap" % B, pretrained=pretrained
        )
        # Replace pooling/classifier with identities: only features are needed.
        backbone.global_pool = nn.Identity()
        backbone.classifier = nn.Identity()
        self.original_model = backbone

    def forward(self, x):
        features = [x]
        for name, module in self.original_model._modules.items():
            if name != "blocks":
                features.append(module(features[-1]))
            else:
                # Record the output of every sub-block individually so the
                # decoder can tap intermediate resolutions.
                for sub_block in module._modules.values():
                    features.append(sub_block(features[-1]))
        return features
|
|
|
|
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on (B, C, H, W) feature maps."""

    def __init__(self, hidden_dim, input_dim, ks=3):
        super(ConvGRU, self).__init__()
        pad = (ks - 1) // 2
        combined = hidden_dim + input_dim
        # Update gate, reset gate, and candidate-state convolutions.
        self.convz = nn.Conv2d(combined, hidden_dim, ks, padding=pad)
        self.convr = nn.Conv2d(combined, hidden_dim, ks, padding=pad)
        self.convq = nn.Conv2d(combined, hidden_dim, ks, padding=pad)

    def forward(self, h, x):
        """One GRU step; returns the updated hidden state (same shape as h)."""
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(stacked))
        reset = torch.sigmoid(self.convr(stacked))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        # Standard GRU blend of previous state and candidate.
        return (1 - update) * h + update * candidate
|
|
|
|
class RayReLU(nn.Module):
    """Clamp a predicted normal's component along the viewing ray to at
    least `eps`, then re-normalize (keeps normals from facing away from
    the camera ray)."""

    def __init__(self, eps=1e-2):
        super(RayReLU, self).__init__()
        self.eps = eps

    def forward(self, pred_norm, ray):
        # Cosine between predicted normal and ray, kept as a channel dim.
        cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(1)

        # Component of the normal along the ray, before and after clamping
        # the cosine at eps.
        along_ray = ray * cos
        along_ray_clamped = ray * (torch.relu(cos - self.eps) + self.eps)

        # Apply only the clamping correction, then restore unit length.
        corrected = pred_norm + (along_ray_clamped - along_ray)
        return F.normalize(corrected, dim=1)
|
|
|
|
class UpSampleBN(nn.Module):
    """Upsample `x` to the spatial size of a skip connection, concatenate,
    and refine with two Conv-BN-LeakyReLU stages."""

    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleBN, self).__init__()
        layers = [
            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        ]
        self._net = nn.Sequential(*layers)
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(
            x, size=target_hw, mode="bilinear", align_corners=self.align_corners
        )
        return self._net(torch.cat([upsampled, concat_with], dim=1))
|
|
|
|
class Conv2d_WS(nn.Conv2d):
    """Conv2d whose weights are standardized (zero mean, unit std per
    output filter) on every forward pass."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(Conv2d_WS, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        w = self.weight
        # Zero-center each output filter (chained means over in-channel and
        # both kernel dims keep the numerics identical to the reference).
        filter_mean = (
            w.mean(dim=1, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=3, keepdim=True)
        )
        w = w - filter_mean
        # Scale each filter by its std; the small offset avoids divide-by-zero.
        filter_std = w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        w = w / filter_std.expand_as(w)
        return F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
|
|
|
|
class UpSampleGN(nn.Module):
    """UpSample with GroupNorm: upsample `x` to the skip connection's size,
    concatenate, and refine with two weight-standardized Conv-GN-LeakyReLU
    stages."""

    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleGN, self).__init__()
        stages = [
            Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
            Conv2d_WS(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
        ]
        self._net = nn.Sequential(*stages)
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        target_hw = [concat_with.size(2), concat_with.size(3)]
        upsampled = F.interpolate(
            x, size=target_hw, mode="bilinear", align_corners=self.align_corners
        )
        return self._net(torch.cat([upsampled, concat_with], dim=1))
|
|
|
|
def upsample_via_bilinear(out, up_mask, downsample_ratio):
    """Bilinearly upsample `out` by `downsample_ratio`.

    `up_mask` is unused; it exists only so the signature matches
    `upsample_via_mask`.
    """
    del up_mask  # dummy argument, kept for interface compatibility
    upsampled = F.interpolate(
        out, scale_factor=downsample_ratio, mode="bilinear", align_corners=True
    )
    return upsampled
|
|
|
|
def upsample_via_mask(out, up_mask, downsample_ratio):
    """Convex upsampling: every high-res pixel is a softmax-weighted
    combination of the 3x3 low-res neighborhood, with per-subpixel weights
    taken from `up_mask` (shape N x 9*k*k x H x W)."""
    k = downsample_ratio
    N, C, H, W = out.shape

    # One 9-way weight distribution per (k x k) sub-pixel position.
    weights = torch.softmax(up_mask.view(N, 1, 9, k, k, H, W), dim=2)

    # 3x3 neighborhoods of every low-res pixel (zero padding at borders).
    neighborhoods = F.unfold(out, [3, 3], padding=1).view(N, C, 9, 1, 1, H, W)

    # Blend the 9 neighbors, then interleave the k x k sub-pixels spatially.
    blended = torch.sum(weights * neighborhoods, dim=2)
    blended = blended.permute(0, 1, 4, 2, 5, 3)
    return blended.reshape(N, C, k * H, k * W)
|
|
|
|
def convex_upsampling(out, up_mask, k):
    """Convex upsampling with replicate-padded borders: every high-res pixel
    is a softmax-weighted average of its 3x3 low-res neighborhood."""
    B, C, H, W = out.shape

    # One 9-way weight distribution per (k x k) sub-pixel position.
    weights = torch.softmax(up_mask.view(B, 1, 9, k, k, H, W), dim=2)

    # Replicate-pad so border neighborhoods reuse edge values instead of zeros.
    padded = F.pad(out, pad=(1, 1, 1, 1), mode="replicate")
    patches = F.unfold(padded, [3, 3], padding=0).view(B, C, 9, 1, 1, H, W)

    # Blend the 9 neighbors, then interleave the k x k sub-pixels spatially.
    mixed = torch.sum(weights * patches, dim=2)
    mixed = mixed.permute(0, 1, 4, 2, 5, 3)
    return mixed.reshape(B, C, k * H, k * W)
|
|
|
|
def get_unfold(pred_norm, ps, pad):
    """Return the (ps x ps) replicate-padded neighborhood of every pixel,
    shaped (B, C, ps*ps, H, W)."""
    B, C, H, W = pred_norm.shape
    padded = F.pad(pred_norm, pad=(pad, pad, pad, pad), mode="replicate")
    patches = F.unfold(padded, [ps, ps], padding=0)
    return patches.view(B, C, ps * ps, H, W)
|
|
|
|
def get_prediction_head(input_dim, hidden_dim, output_dim):
    """Build a small conv head: 3x3 conv -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv."""
    layers = [
        nn.Conv2d(input_dim, hidden_dim, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, hidden_dim, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, output_dim, 1),
    ]
    return nn.Sequential(*layers)
|
|