Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn.functional as F | |
| # from spatial_correlation_sampler import SpatialCorrelationSampler | |
class UpSample(torch.nn.Module):
    """Upsample a 2-channel flow field with learned per-pixel 3x3 weights.

    A small convolutional head looks at the flow concatenated with a feature
    map and predicts, for every coarse pixel and every one of the K*K fine
    sub-pixels it expands into, 9 logits. After a softmax over those 9
    logits, each fine pixel is a convex combination of the 3x3 coarse-flow
    neighbourhood centred on its parent pixel (zero padded at the border).

    Args:
        feature_dim: Channel count of the feature map given to ``forward``.
        upsample_factor: Integer factor K by which height and width grow.
    """

    def __init__(self, feature_dim, upsample_factor):
        super().__init__()
        self.upsample_factor = upsample_factor
        # Input channels: 2 flow channels followed by the feature channels.
        self.conv1 = torch.nn.Conv2d(
            in_channels=2 + feature_dim,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=256,
            out_channels=512,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # One logit per (3x3 neighbour, fine sub-pixel) pair: 9 * K * K.
        self.conv3 = torch.nn.Conv2d(
            in_channels=512,
            out_channels=9 * upsample_factor ** 2,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, feature, flow):
        """Return ``flow`` upsampled by ``upsample_factor``.

        Args:
            feature: Tensor of shape [B, feature_dim, H, W].
            flow: Tensor of shape [B, 2, H, W].

        Returns:
            Tensor of shape [B, 2, K*H, K*W].

        Note:
            The flow values themselves are not multiplied by the upsample
            factor; only the spatial resolution changes.
        """
        k = self.upsample_factor
        batch, _, height, width = flow.shape

        # Predict the combination logits from flow + features.
        hidden = self.relu(self.conv1(torch.cat((flow, feature), dim=1)))
        hidden = self.relu(self.conv2(hidden))
        logits = self.conv3(hidden)

        # [B, 1, 9, K, K, H, W]; normalise over the 9 neighbours.
        weights = logits.view(batch, 1, 9, k, k, height, width)
        weights = torch.softmax(weights, dim=2)

        # 3x3 neighbourhood of every coarse pixel: [B, 2, 9, 1, 1, H, W].
        patches = F.unfold(flow, [3, 3], padding=1)
        patches = patches.view(batch, 2, 9, 1, 1, height, width)

        # Weighted sum over the neighbours: [B, 2, K, K, H, W].
        blended = (weights * patches).sum(dim=2)
        # Interleave sub-pixel and coarse axes: [B, 2, H, K, W, K].
        blended = blended.permute(0, 1, 4, 2, 5, 3)
        # Merge (H, K) and (W, K): [B, 2, K*H, K*W].
        return blended.reshape(batch, 2, k * height, k * width)