|
|
import torch |
|
|
import torchvision |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from typing import Tuple |
|
|
|
|
|
|
|
|
class PatchRefiner_Updated(nn.Module):
    """
    Patch-Refiner refines the selected patches of a coarse matting prediction.

    Unlike the original refiner, the error map is not an input: it is derived
    internally from the per-pixel entropy of the upsampled coarse alpha (see
    ``compute_err``). The highest-error patches are cropped, refined by a small
    convolutional network, and written back into the bilinearly upsampled
    coarse prediction.

    Args:
        mode: area selection mode. Options:
            "sampling" - Refine fixed amount of pixels ranked by the top most errors.
            "thresholding" - Refine varying amount of pixels that have greater error than the threshold.
            "full" - Accepted by the constructor; ``select_refinement_regions``
                currently treats it the same as "thresholding".
        sample_pixels: number of pixels to refine. Only used when mode == "sampling".
        threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
        kernel_size: The convolution kernel_size. Options: [1, 3]
        prevent_oversampling: True for regular cases, False for speedtest.

    Compatibility Args:
        patch_crop_method: the method for cropping patches. Options:
            "unfold" - Best performance for PyTorch and TorchScript.
            "roi_align" - Another way for cropping patches.
            "gather" - Another way for cropping patches.
        patch_replace_method: the method for replacing patches. Options:
            "scatter_nd" - Best performance for PyTorch and TorchScript.
            "scatter_element" - Another way for replacing patches.

    Input (each tensor may also carry a leading time dimension, i.e. (B, T, ...)):
        src: (B, 3, H, W) full resolution source image.
        pha: (B, 1, Hc, Wc) coarse alpha prediction.
        fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
        hid: (B, 16, Hc, Wc) coarse hidden encoding. Must have 16 channels:
            conv1 consumes channels[0] + 3 + 4 = 16 (hid) + 7 (pha/fgr/src) inputs.

    Output:
        pha: (B, 1, H, W) full resolution alpha prediction.
        fgr: (B, 3, H, W) full resolution foreground residual prediction.
    """

    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']

    def __init__(self,
                 mode: str,
                 sample_pixels: int,
                 threshold: float,
                 kernel_size: int = 3,
                 prevent_oversampling: bool = True,
                 patch_crop_method: str = 'unfold',
                 patch_replace_method: str = 'scatter_nd'):
        super().__init__()
        assert mode in ['full', 'sampling', 'thresholding']
        assert kernel_size in [1, 3]
        assert patch_crop_method in ['unfold', 'roi_align', 'gather']
        assert patch_replace_method in ['scatter_nd', 'scatter_element']

        self.mode = mode
        self.sample_pixels = sample_pixels
        self.threshold = threshold
        self.kernel_size = kernel_size
        self.prevent_oversampling = prevent_oversampling
        self.patch_crop_method = patch_crop_method
        self.patch_replace_method = patch_replace_method
        # Side length of the square patches replaced at full resolution.
        self.patch_size = 16

        channels = [16, 16, 12, 8, 4]
        # Stage 1 operates at half resolution on hid (16) + [pha, fgr, src] (1+3+3).
        self.conv1 = nn.Conv2d(channels[0] + 3 + 4, channels[1], kernel_size, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
        self.bn2 = nn.BatchNorm2d(channels[2])
        # Stage 2 operates at full resolution with src re-concatenated.
        self.conv3 = nn.Conv2d(channels[2] + 3, channels[3], kernel_size, bias=False)
        self.bn3 = nn.BatchNorm2d(channels[3])
        # Final 4 channels: 1 alpha + 3 foreground residual.
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
        self.relu = nn.ReLU(True)

    def forward(self,
                src: torch.Tensor,
                pha: torch.Tensor,
                fgr: torch.Tensor,
                hid: torch.Tensor):
        """Dispatch on rank: 5-D inputs are (B, T, C, H, W) time series."""
        if src.ndim == 5:
            return self.forward_time_series(src, pha, fgr, hid)
        else:
            return self.forward_single_frame(src, pha, fgr, hid)

    def forward_time_series(self, fine_src, base_pha, base_fgr, base_hid):
        """Fold time into batch, refine, then restore the (B, T) leading dims."""
        B, T = fine_src.shape[:2]
        pha, fgr = self.forward_single_frame(
            fine_src.flatten(0, 1),
            base_pha.flatten(0, 1),
            base_fgr.flatten(0, 1),
            base_hid.flatten(0, 1))
        fgr = fgr.unflatten(0, (B, T))
        pha = pha.unflatten(0, (B, T))
        return pha, fgr

    def forward_single_frame(self, src, pha, fgr, hid):
        """Refine a single batch of frames. See the class docstring for shapes."""
        h_orig, w_orig = src.shape[2:]
        # Snap the working resolution down to a multiple of patch_size so the
        # image tiles exactly into refinement patches.
        h = self.patch_size * (h_orig // self.patch_size)
        w = self.patch_size * (w_orig // self.patch_size)

        src = F.interpolate(src, (h, w), mode='bilinear', align_corners=False)
        pha = F.interpolate(pha, (h, w), mode='bilinear', align_corners=False)
        fgr = F.interpolate(fgr, (h, w), mode='bilinear', align_corners=False)

        H_full, W_full = src.shape[2:]
        H_half, W_half = H_full // 2, W_full // 2
        # Patch-grid resolution: one cell per patch_size x patch_size tile.
        H_c, W_c = H_full // self.patch_size, W_full // self.patch_size

        src_bgr = src

        # Error map from the entropy of the upsampled coarse alpha.
        err = self.compute_err(pha)
        # Per-patch error: bilinear downsample + average pooling, both landing
        # on the (H_c, W_c) patch grid since H_full/W_full are multiples of
        # patch_size.
        err = F.interpolate(err, (H_c, W_c), mode='bilinear', align_corners=False) \
            + F.avg_pool2d(err, kernel_size=self.patch_size)
        ref = self.select_refinement_regions(err)
        # NOTE: gradients cannot flow through the nonzero() indices below, so
        # the previous `ref.requires_grad = True` was a no-op (and raised under
        # torch.inference_mode); it has been removed.
        idx = torch.nonzero(ref.squeeze(1))
        idx = idx[:, 0], idx[:, 1], idx[:, 2]

        if idx[0].size(0) > 0:
            # --- Stage 1: half resolution ---
            x = F.interpolate(hid, (H_half, W_half), mode='bilinear', align_corners=False)
            x = self.crop_patch(x, idx, self.patch_size // 2, 3 if self.kernel_size == 3 else 0)

            y = torch.cat([pha, fgr, src_bgr], dim=1)
            y = F.interpolate(y, (H_half, W_half), mode='bilinear', align_corners=False)
            y = self.crop_patch(y, idx, self.patch_size // 2, 3 if self.kernel_size == 3 else 0)

            x = self.conv1(torch.cat([x, y], dim=1))
            x = self.bn1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)

            # --- Stage 2: full resolution ---
            # Upsample the patch features 2x; the extra border (+4 when
            # kernel_size == 3) compensates for the valid convolutions below.
            x = F.interpolate(x, self.patch_size + 4 if self.kernel_size == 3 else 4, mode='nearest')
            y = self.crop_patch(src_bgr, idx, self.patch_size, 2 if self.kernel_size == 3 else 0)

            x = self.conv3(torch.cat([x, y], dim=1))
            x = self.bn3(x)
            x = self.relu(x)
            x = self.conv4(x)

            # Scatter the refined patches back into the upsampled coarse output.
            out = torch.cat([pha, fgr], dim=1)
            out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
            out = self.replace_patch_hr(out, x, idx)
            # Restore the caller's original resolution.
            out = F.interpolate(out, (h_orig, w_orig), mode='bilinear', align_corners=False)

            pha = out[:, :1]
            fgr = out[:, 1:]
        else:
            # No patch selected: just upsample the coarse prediction.
            pha = F.interpolate(pha, (h_orig, w_orig), mode='bilinear', align_corners=False)
            fgr = F.interpolate(fgr, (h_orig, w_orig), mode='bilinear', align_corners=False)

        return pha, fgr

    def select_refinement_regions(self, err: torch.Tensor):
        """
        Select refinement regions.
        Input:
            err: error map (B, 1, H, W)
        Output:
            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
        """
        if self.mode == 'sampling':
            b, _, h, w = err.shape
            err = err.view(b, -1)
            # Never request more patches than exist in the error map.
            k = min(self.sample_pixels // self.patch_size, err.shape[-1])
            idx = err.topk(k, dim=1, sorted=False).indices
            ref = torch.zeros_like(err)
            ref.scatter_(1, idx, 1.)
            if self.prevent_oversampling:
                # Drop zero-error picks so we do not refine perfect regions.
                ref.mul_(err.gt(0).float())
            ref = ref.view(b, 1, h, w)
        else:
            # 'thresholding' (and, currently, 'full') both use the threshold.
            ref = err.gt(self.threshold).float()
        return ref

    def crop_patch(self,
                   x: torch.Tensor,
                   idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                   size: int,
                   padding: int):
        """
        Crops selected patches from image given indices.

        Inputs:
            x: image (B, C, H, W).
            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
            size: center size of the patch, also stride of the crop.
            padding: expansion size of the patch.
        Output:
            patch: (P, C, h, w), where h = w = size + 2 * padding.
        """
        if padding != 0:
            x = F.pad(x, (padding,) * 4)

        if self.patch_crop_method == 'unfold':
            # Unfold into overlapping (size+2p)^2 windows with stride `size`,
            # then advanced-index the selected windows.
            return x.permute(0, 2, 3, 1) \
                    .unfold(1, size + 2 * padding, size) \
                    .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
        elif self.patch_crop_method == 'roi_align':
            # Fixed: the original built a 4-tuple with idx[2] duplicated.
            idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
            b = idx[0]
            # -0.5 aligns box corners with pixel centers for roi_align.
            x1 = idx[2] * size - 0.5
            y1 = idx[1] * size - 0.5
            x2 = idx[2] * size + size + 2 * padding - 0.5
            y2 = idx[1] * size + size + 2 * padding - 0.5
            boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
        else:
            # 'gather': flatten the image and gather pixel-by-pixel.
            idx_pix = self.compute_pixel_indices(x, idx, size, padding)
            pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
            pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
            return pat

    def replace_patch(self,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """
        Replaces patches back into image given index.

        Inputs:
            x: image (B, C, H, W)
            y: patches (P, C, h, w)
            idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.

        Output:
            image: (B, C, H, W), where patches at idx locations are replaced with y.
        """
        xB, xC, xH, xW = x.shape
        yB, yC, yH, yW = y.shape
        if self.patch_replace_method == 'scatter_nd':
            # Re-tile the image into the patch grid, assign, tile back.
            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
            x[idx[0], idx[1], idx[2]] = y
            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
            return x
        else:
            # 'scatter_element': flatten and scatter pixel-by-pixel.
            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
            return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)

    def replace_patch_hr(self,
                         x: torch.Tensor,
                         y: torch.Tensor,
                         idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
        """
        Replaces patches back into image given index, casting the patches to
        the image dtype (e.g. float16 patches into a float32 image under AMP).

        Inputs:
            x: image (B, C, H, W)
            y: patches (P, C, h, w)
            idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.

        Output:
            image: (B, C, H, W), where patches at idx locations are replaced with y.
        """
        xB, xC, xH, xW = x.shape
        yB, yC, yH, yW = y.shape
        # Fixed: the original only handled float16/float32 explicitly, silently
        # skipping the write (or returning None) for any other dtype. Casting
        # y to x.dtype covers every case, float16/float32 included.
        if self.patch_replace_method == 'scatter_nd':
            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
            x[idx[0], idx[1], idx[2]] = y.to(x.dtype)
            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
            return x
        else:
            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
            return x.view(-1).scatter_(0, idx_pix.view(-1), y.to(x.dtype).view(-1)).view(x.shape)

    def compute_pixel_indices(self,
                              x: torch.Tensor,
                              idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                              size: int,
                              padding: int):
        """
        Compute selected pixel indices in the tensor.
        Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
        Input:
            x: image: (B, C, H, W)
            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
            size: center size of the patch, also stride of the crop.
            padding: expansion size of the patch.
        Output:
            idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
                the element are indices pointing to the input x.view(-1).
        """
        B, C, H, W = x.shape
        # Fixed: ranges are allocated on x's device instead of a hard-coded
        # .cuda(), so this also works on CPU.
        device = x.device
        S, P = size, padding
        O = S + 2 * P
        b, y, x = idx
        n = b.size(0)
        c = torch.arange(C, device=device)
        o = torch.arange(O, device=device)
        # Offsets within one patch (per channel/row/col) ...
        idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) \
            + (o * W).view(1, O, 1).expand([C, O, O]) \
            + o.view(1, 1, O).expand([C, O, O])
        # ... plus the flat offset of each patch's top-left corner.
        idx_loc = b * W * H + y * W * S + x * S
        idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) \
            + idx_pat.view(1, C, O, O).expand([n, C, O, O])
        return idx_pix

    def compute_err(self,
                    x: torch.Tensor):
        """
        Compute an error map from the binary entropy of x.

        Input:
            x: alpha prediction in [0, 1], any shape.
        Output:
            entropy map of the same shape, normalized to bits (divided by
            log10(2)) and clamped to [0, 1]. High values mark uncertain
            (mid-range) alpha, which is where refinement is most useful.
        """
        eps = torch.tensor(1e-6)
        # Binary entropy: -x*log(x) - (1-x)*log(1-x), eps avoids log(0).
        ent = -(x * torch.log10(x + eps)) - ((torch.tensor(1) - x) * torch.log10(torch.tensor(1) - x + eps))
        # Change of base log10 -> log2 so the maximum entropy is 1.
        ent = ent / torch.log10(torch.tensor(2))
        ent = ent.clamp(0., 1.)
        return ent