File size: 6,155 Bytes
2c76547 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.core.extractor import ResidualBlock
autocast = torch.cuda.amp.autocast
class ste_fn(torch.autograd.Function):
    """Straight-through estimator: hard-binarize in forward, pass a
    clipped (identity) gradient through in backward."""

    @staticmethod
    def forward(ctx, x):
        # Hard threshold at zero -> {0., 1.} float mask.
        return x.gt(0).float()

    @staticmethod
    def backward(ctx, grad):
        # Straight-through surrogate gradient: identity clipped to
        # [-1, 1] (same values as F.hardtanh with default bounds).
        return grad.clamp(-1.0, 1.0)
class STE(nn.Module):
    """nn.Module wrapper around :class:`ste_fn` so the straight-through
    binarization can sit inside a regular module graph."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Delegate to the autograd Function (modules cannot override backward).
        return ste_fn.apply(x)
class sci_encoder(nn.Module):
    """Coded-exposure (SCI) encoder with a learnable binary shutter code.

    The learnable weights are binarized with a straight-through estimator
    on every forward pass, used to collapse ``n_frame`` input frames into
    ``n_taps`` coded snapshots, and the snapshots are concatenated with
    the binary code itself to form the network input.
    """

    def __init__(
            self,
            sigma_range=(0, 1e-9),   # (low, high) of the uniform noise-std range
            n_frame=8,               # number of video frames per snapshot
            in_channels=1,           # image channels per frame
            n_taps=2,                # 1: single coded image; 2: complementary pair
            resolution=(480, 640)):  # (H, W) crop applied to the input frames
        super(sci_encoder, self).__init__()
        assert n_taps in [1, 2], "[ERROR] n_taps should be either 1 or 2."
        self.sigma_range = sigma_range
        self.n_frame = n_frame
        self.in_channels = in_channels
        self.n_taps = n_taps
        self.resolution = resolution
        # -- Shutter code; learnable parameters, binarized by the STE below.
        #    torch.empty (not torch.Tensor) makes the "uninitialized until
        #    uniform_" intent explicit.
        self.ce_weight = nn.Parameter(torch.empty(n_frame, in_channels, *resolution))
        # -- initialize
        nn.init.uniform_(self.ce_weight, a=-1, b=1)
        self.ste = STE()

    def forward(self, frames):
        """frames ~ (b, n_frame, h, w) -> (b, n_taps + n_frame, H, W)."""
        device = frames.device
        # -- Binarize the learnable weights into a {0, 1} exposure code.
        ce_code = self.ste(self.ce_weight)
        # -- Crop to the configured resolution and add a channel axis:
        #    (b, n_frame, H, W) -> (b, n_frame, 1, H, W)
        frames = frames[..., :self.resolution[0], :self.resolution[1]]
        frames = frames.contiguous().unsqueeze(2)
        # -- Repeat the code across the batch dimension:
        #    (n_frame, c, H, W) -> (b, n_frame, c, H, W)
        ce_code = ce_code.repeat(frames.shape[0], 1, 1, 1, 1)
        # -- Temporal collapse: one snapshot per tap.  (b, c, h, w)
        ce_blur_img = torch.zeros(
            frames.shape[0], self.in_channels * self.n_taps, *self.resolution,
            device=device)
        ce_blur_img[:, 0, ...] = torch.sum(ce_code * frames, dim=1) / self.n_frame
        if self.n_taps == 2:
            # -- Second tap integrates the complementary exposure code.
            #    (Guard fixes an index error when n_taps == 1: the original
            #    unconditionally wrote channel 1 of a 1-channel tensor.)
            ce_blur_img[:, 1, ...] = torch.sum((1. - ce_code) * frames, dim=1) / self.n_frame
        # -- Add read noise with a std sampled uniformly from sigma_range.
        #    Multiplying by the Python float (not torch.tensor(noise_level))
        #    avoids an accidental float64 upcast of the intermediate.
        noise_level = np.random.uniform(*self.sigma_range)
        ce_blur_img_noisy = ce_blur_img + noise_level * torch.randn(ce_blur_img.shape, device=device)
        # -- Concatenate snapshots and mask patterns channel-wise.
        out = torch.zeros(frames.shape[0], self.n_taps + self.n_frame, *self.resolution,
                          device=device)
        out[:, :self.n_taps, :, :] = ce_blur_img_noisy
        out[:, self.n_taps:, :, :] = ce_code.squeeze(2)
        return out
class sci_decoder(nn.Module):
    """Feature decoder for SCI reconstruction.

    Takes the encoder output (coded snapshots concatenated with the binary
    mask patterns, ``n_taps + n_frame`` channels) and produces per-frame
    features by unfolding the channel axis along time:
    (b, n_frame * output_dim, h', w') -> (b * n_frame, output_dim, h', w').
    Spatial size h', w' is reduced by the stride-2 conv1 and the stride-2
    residual stage (overall /4, assuming ResidualBlock downsamples by its
    ``stride`` — confirm against models.core.extractor).
    """

    def __init__(self,
                 n_frame=8,
                 n_taps=2,
                 output_dim=128,
                 norm_fn="batch",
                 dropout=.0):
        super(sci_decoder, self).__init__()
        self.norm_fn = norm_fn
        # -- Keep n_frame: forward() needs it to unfold the temporal axis.
        #    (Previously hard-coded to 8, which broke for n_frame != 8.)
        self.n_frame = n_frame
        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=4, num_channels=4*n_frame)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(4*n_frame)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(4*n_frame, affine=True)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
        else:
            # -- Fail fast instead of leaving self.norm1 undefined and
            #    crashing later with an AttributeError in forward().
            raise ValueError(f"[ERROR] Unknown norm_fn: {norm_fn!r}")
        # -- Input Convolution
        # -- Assuming n_frame=8: n_ich=10, n_och=32
        self.conv1 = nn.Conv2d(n_taps+n_frame, 4*n_frame, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)
        # -- Residual Blocks
        self.layer1 = self._make_layer( 4*n_frame,  4*n_frame, stride=1)
        self.layer2 = self._make_layer( 4*n_frame, 16*n_frame, stride=2)
        self.layer3 = self._make_layer(16*n_frame, 64*n_frame, stride=1)
        # -- Output Convolution
        self.conv2 = nn.Conv2d(64*n_frame, output_dim*n_frame, kernel_size=1)
        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0. else None
        # -- self.modules() walks every submodule recursively, so this loop
        #    initializes conv1, the residual layers, conv2, and all norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, n_ich, n_och, stride=1):
        """Build one residual stage: a (possibly strided) block followed by
        a stride-1 block at the output width."""
        layer1 = ResidualBlock(n_ich, n_och, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(n_och, n_och, self.norm_fn, stride=1)
        return nn.Sequential(layer1, layer2)

    def forward(self, x):
        # -- x = [L, R] or a single tensor; L, R ~ (b, c, h, w), c=n_taps+n_frame
        # -- If the input is a list/tuple, concatenate along the batch
        #    dimension so both views share one pass through the network.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            x = torch.cat(x, dim=0)
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)
        # -- Unfold the temporal dimension:
        #    (b, t*c', h, w) -> (b*t, c', h, w) with t = n_frame
        t = self.n_frame
        x = x.contiguous()
        x = x.view(x.shape[0]*t, x.shape[1]//t, x.shape[-2], x.shape[-1])
        if self.dropout is not None:
            x = self.dropout(x)
        # -- If the input was a list, split the batch dimension back in two.
        if is_list:
            x = torch.split(x, x.shape[0] // 2, dim=0)
        return x
|