|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
|
|
|
from models.core.extractor import ResidualBlock |
|
|
|
|
|
# Module-level alias for the CUDA automatic-mixed-precision autocast context
# manager.
# NOTE(review): not referenced in the visible part of this file; presumably
# kept so that other modules can import it from here - confirm before removing.
autocast = torch.cuda.amp.autocast
|
|
|
|
|
class ste_fn(torch.autograd.Function):
    """Hard threshold at zero with a clipped pass-through gradient.

    The forward pass maps every strictly positive entry to ``1.0`` and every
    other entry to ``0.0``.  The backward pass ignores the (zero almost
    everywhere) true derivative and instead hands back the incoming gradient
    limited to the interval ``[-1, 1]``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a float tensor that is 1 where ``x > 0`` and 0 elsewhere."""
        positive_mask = torch.gt(x, 0)
        return positive_mask.float()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient through, clipped to ``[-1, 1]``."""
        return F.hardtanh(grad_output, min_val=-1.0, max_val=1.0)
|
|
|
|
|
class STE(nn.Module):
    """Layer form of ``ste_fn`` so the estimator can be used like any module.

    The module holds no parameters or buffers; it only forwards its input to
    ``ste_fn.apply``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply ``ste_fn`` to ``x`` and return the result."""
        binarised = ste_fn.apply(x)
        return binarised
|
|
|
|
|
class sci_encoder(nn.Module):
    """Coded-exposure encoder with a learnable binary per-pixel code.

    A real-valued weight of shape ``(n_frame, in_channels, *resolution)`` is
    binarised by ``STE`` on every forward pass.  The binary code (and, when
    ``n_taps == 2``, its complement ``1 - code``) weights the input frames,
    which are then averaged over dim 1.  Gaussian noise is added to the coded
    image(s) and the result is stacked with the code itself along the channel
    dimension.

    Args:
        sigma_range: ``(low, high)`` bounds passed to ``np.random.uniform`` to
            draw the noise standard deviation for each forward call.
        n_frame: Number of frames combined into one coded image.
        in_channels: Channel count of the code weight.
            NOTE(review): the channel indexing in ``forward`` appears to
            assume ``in_channels == 1`` - confirm before using other values.
        n_taps: ``1`` (code only) or ``2`` (code and its complement).
        resolution: ``(height, width)`` of the code; inputs are cropped to it.
    """

    def __init__(
            self,
            sigma_range=(0, 1e-9),
            n_frame=8,
            in_channels=1,
            n_taps=2,
            resolution=(480, 640)):
        super(sci_encoder, self).__init__()

        assert n_taps in [1, 2], "[ERROR] n_taps should be either 1 or 2."

        # Store private list copies so that instances never share (or mutate)
        # the caller's sequence or the default argument object.
        self.sigma_range = list(sigma_range)
        self.n_frame = n_frame
        self.in_channels = in_channels
        self.n_taps = n_taps
        self.resolution = list(resolution)

        # Real-valued code weight; it is thresholded at zero in forward().
        self.ce_weight = nn.Parameter(
            torch.empty(n_frame, in_channels, *self.resolution))

        # Uniform in [-1, 1]; the sign of each entry decides its initial bit.
        nn.init.uniform_(self.ce_weight, a=-1, b=1)

        self.ste = STE()

    def forward(self, frames):
        """Encode ``frames`` into noisy coded image(s) plus the binary code.

        Args:
            frames: Input tensor whose first dimension is the batch size and
                whose last two dimensions are spatial.
                NOTE(review): presumably ``(batch, n_frame, H, W)`` with
                ``H``/``W`` at least ``resolution`` - confirm with the caller.

        Returns:
            Tensor of shape ``(batch, n_taps + n_frame, *resolution)``: the
            first ``n_taps`` channels are the noisy coded images, the
            remaining ``n_frame`` channels are the binary code.
        """
        batch = frames.shape[0]
        device = frames.device
        height, width = self.resolution[0], self.resolution[1]

        # Binarise the learnable weight into a {0, 1} exposure code.
        ce_code = self.ste(self.ce_weight)

        # Crop to the code resolution and add a singleton axis at dim 2 so the
        # frames broadcast against the code's channel axis.
        frames = frames[..., :height, :width].contiguous()
        frames = torch.unsqueeze(frames, 2)

        # Broadcast the code over the batch as a view (no per-sample copy).
        ce_code = ce_code.unsqueeze(0).expand(batch, -1, -1, -1, -1)

        ce_blur_img = torch.zeros(
            batch, self.in_channels * self.n_taps, height, width,
            device=device)

        # Tap 0 integrates the frames where the code is 1.
        ce_blur_img[:, 0, ...] = torch.sum(ce_code * frames, dim=1) / self.n_frame
        # Tap 1 (complementary exposure) only exists in the two-tap setup;
        # writing it unconditionally would index past the buffer for n_taps=1.
        if self.n_taps == 2:
            ce_blur_img[:, 1, ...] = (
                torch.sum((1. - ce_code) * frames, dim=1) / self.n_frame)

        # One noise level per call, shared by the whole batch.
        noise_level = float(np.random.uniform(*self.sigma_range))
        ce_blur_img_noisy = ce_blur_img + torch.randn_like(ce_blur_img) * noise_level

        out = torch.zeros(
            batch, self.n_taps + self.n_frame, height, width, device=device)

        # Coded image(s) first, then the code so the decoder can see it.
        out[:, :self.n_taps, :, :] = ce_blur_img_noisy
        out[:, self.n_taps:, :, :] = ce_code.squeeze(2)

        return out
|
|
|
|
|
class sci_decoder(nn.Module):
    """Convolutional decoder producing one feature map per encoded frame.

    The network is a 7x7 stride-2 stem convolution, a normalisation layer and
    ReLU, three stages of two ``ResidualBlock`` modules each, and a 1x1
    convolution with ``output_dim * n_frame`` output channels.  ``forward``
    reshapes that output so each of the ``n_frame`` channel groups becomes its
    own batch entry.

    ``ResidualBlock`` is imported from ``models.core.extractor``; its
    internals are not visible in this file.

    Args:
        n_frame: Number of frames; scales all channel widths and sets how many
            batch entries each input sample is expanded into.
        n_taps: Number of coded-image channels in the input; the stem expects
            ``n_taps + n_frame`` input channels.
        output_dim: Channels of each per-frame output feature map.
        norm_fn: One of ``"group"``, ``"batch"``, ``"instance"``, ``"none"``.
        dropout: ``Dropout2d`` probability; values ``<= 0`` disable dropout.

    Raises:
        ValueError: If ``norm_fn`` is not one of the supported names.
    """

    def __init__(self,
                 n_frame=8,
                 n_taps=2,
                 output_dim=128,
                 norm_fn="batch",
                 dropout=.0):
        super(sci_decoder, self).__init__()

        # Kept so forward() can split channels by the configured frame count
        # rather than a hard-coded constant.
        self.n_frame = n_frame

        self.norm_fn = norm_fn
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=4, num_channels=4*n_frame)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(4*n_frame)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(4*n_frame, affine=True)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
        else:
            # Fail here instead of leaving norm1 undefined until forward().
            raise ValueError(
                "norm_fn must be one of 'group', 'batch', 'instance', 'none'; "
                "got {!r}".format(norm_fn))

        # Stem: input is the encoder output (coded images + binary code).
        self.conv1 = nn.Conv2d(n_taps+n_frame, 4*n_frame, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # Residual stages; channel width grows 4x per stage.
        self.layer1 = self._make_layer( 4*n_frame,  4*n_frame, stride=1)
        self.layer2 = self._make_layer( 4*n_frame, 16*n_frame, stride=2)
        self.layer3 = self._make_layer(16*n_frame, 64*n_frame, stride=1)

        # 1x1 projection to output_dim channels for each of the n_frame frames.
        self.conv2 = nn.Conv2d(64*n_frame, output_dim*n_frame, kernel_size=1)

        if dropout > 0.:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        # Weight initialisation for every conv / norm layer built above,
        # including those nested inside the residual stages.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, n_ich, n_och, stride=1):
        """Build one stage: a (possibly strided) block followed by a stride-1 block."""
        layer1 = ResidualBlock(n_ich, n_och, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(n_och, n_och, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        return nn.Sequential(*layers)

    def forward(self, x):
        """Decode encoder output into per-frame feature maps.

        Args:
            x: A tensor with ``n_taps + n_frame`` channels, or a list/tuple of
                such tensors.  A list/tuple is concatenated along the batch
                dimension, processed in one pass, and split back in two.

        Returns:
            A tensor of shape ``(batch * n_frame, output_dim, h, w)`` in which
            rows ``b * n_frame + k`` hold frame ``k`` of sample ``b``; or, for
            list/tuple input, a tuple with the two halves of that tensor.
        """
        # Several inputs can share a single pass by stacking on the batch dim.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        # Move the n_frame channel groups onto the batch dimension:
        # (B, n_frame * D, h, w) -> (B * n_frame, D, h, w).
        x = x.contiguous()
        x = x.view(x.shape[0] * self.n_frame, x.shape[1] // self.n_frame,
                   x.shape[-2], x.shape[-1])

        if self.dropout is not None:
            x = self.dropout(x)

        # Undo the batch concatenation (two equally sized halves).
        if is_list:
            x = torch.split(x, x.shape[0] // 2, dim=0)

        return x
|
|
|
|
|
|