# Source: sam3-zerogpu / NeuFlow/upsample.py
# Last commit: "deploy sam3-zerogpu" (e529923, verified) by magboola
import torch
import torch.nn.functional as F
# from spatial_correlation_sampler import SpatialCorrelationSampler
class UpSample(torch.nn.Module):
    """Learned convex upsampling of a 2-channel flow field.

    A small convolutional stack predicts, for every coarse pixel and for each
    of the ``upsample_factor ** 2`` fine pixels it expands into, softmax
    weights over the 3x3 coarse neighbourhood. Every fine-resolution flow
    vector is the weighted sum of those nine coarse flow vectors.
    """

    def __init__(self, feature_dim, upsample_factor):
        """Create the mask-prediction layers.

        Args:
            feature_dim: Channel count of the feature map passed to
                ``forward``.
            upsample_factor: Factor K by which height and width are enlarged.
        """
        super().__init__()

        self.upsample_factor = upsample_factor

        # The first layer sees the flow (2 channels) stacked with the features.
        self.conv1 = torch.nn.Conv2d(
            in_channels=2 + feature_dim,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=256,
            out_channels=512,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # One logit per (3x3 neighbour, sub-pixel row, sub-pixel column).
        self.conv3 = torch.nn.Conv2d(
            in_channels=512,
            out_channels=upsample_factor ** 2 * 9,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, feature, flow):
        """Upsample ``flow`` by ``upsample_factor`` in both spatial dimensions.

        Args:
            feature: Tensor of shape [B, feature_dim, H, W].
            flow: Tensor of shape [B, 2, H, W].

        Returns:
            Tensor of shape [B, 2, K*H, K*W], where K is ``upsample_factor``.
        """
        factor = self.upsample_factor

        stacked = torch.cat((flow, feature), dim=1)
        hidden = self.relu(self.conv1(stacked))
        hidden = self.relu(self.conv2(hidden))
        logits = self.conv3(hidden)

        batch, _, height, width = flow.shape

        # [B, 9*K*K, H, W] -> [B, 1, 9, K, K, H, W]; normalise over the
        # nine neighbours so each fine pixel gets a convex combination.
        weights = logits.view(batch, 1, 9, factor, factor, height, width)
        weights = torch.softmax(weights, dim=2)

        # Gather each coarse pixel's zero-padded 3x3 neighbourhood.
        # NOTE: the flow values are not multiplied by the upsample factor here
        # (the variant ``F.unfold(self.upsample_factor * flow, ...)`` is
        # disabled); presumably any magnitude rescaling is done by the
        # caller -- verify.
        neighbours = F.unfold(flow, kernel_size=(3, 3), padding=1)
        neighbours = neighbours.view(batch, 2, 9, 1, 1, height, width)

        fine = (weights * neighbours).sum(dim=2)  # [B, 2, K, K, H, W]
        # Put each sub-pixel axis right after its coarse axis so the reshape
        # interleaves them: row = h*K + k_row, col = w*K + k_col.
        fine = fine.permute(0, 1, 4, 2, 5, 3)  # [B, 2, H, K, W, K]
        return fine.reshape(batch, 2, factor * height, factor * width)