# DIINN model definitions: an encoder (built via models.make) followed by an
# implicit decoder that renders RGB output at an arbitrary target size.
import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models import register
class DIINN(nn.Module):
    """Encoder + implicit decoder pipeline.

    The encoder is constructed from ``encoder_spec`` via ``models.make``; its
    feature map is handed to an :class:`ImplicitDecoder` that renders the
    output at the requested spatial ``size``.
    """

    def __init__(self, encoder_spec, mode=3, init_q=False):
        super().__init__()
        self.encoder = models.make(encoder_spec)
        self.decoder = ImplicitDecoder(mode=mode, init_q=init_q)

    def forward(self, x, size, bsize=None):
        # Encode, then decode to the target resolution (optionally in strips
        # of ~bsize pixels — see ImplicitDecoder).
        feats = self.encoder(x)
        return self.decoder(feats, size, bsize)
class SineAct(nn.Module):
    """Element-wise sine activation (as used by SIREN-style layers)."""

    def forward(self, x):
        return x.sin()
def patch_norm_2d(x, kernel_size=3):
    """Normalize each pixel by the mean/std of its local patch.

    Equivalent to the (commented-out) unfold-based reference: per-pixel
    mean and variance over a ``kernel_size`` x ``kernel_size`` window
    (zero-padded at the borders), then ``(x - mean) / sqrt(var + 1e-6)``.

    Args:
        x: input tensor of shape (B, C, H, W).
        kernel_size: odd window size for the local statistics.

    Returns:
        Tensor of the same shape as ``x``.
    """
    pad = kernel_size // 2
    # stride=1 keeps the statistics at full resolution. The previous code
    # relied on avg_pool2d's default stride (== kernel_size), which
    # downsampled mean/var and made `(x - mean)` fail to broadcast.
    mean = F.avg_pool2d(x, kernel_size, stride=1, padding=pad)
    mean_sq = F.avg_pool2d(x * x, kernel_size, stride=1, padding=pad)
    # Clamp tiny negative values caused by floating-point rounding.
    var = (mean_sq - mean * mean).clamp_min(0)
    # Divide by the standard deviation (sqrt), matching the unfold-based
    # reference implementation; the old code divided by the variance.
    return (x - mean) / torch.sqrt(var + 1e-6)
class ImplicitDecoder(nn.Module):
    """Implicit decoder mapping encoder features to RGB at an arbitrary size.

    Two parallel stacks of 1x1 convolutions are run per depth:

    * ``K`` ("key") layers with ReLU over the 3x3-unfolded encoder features;
    * ``Q`` ("query") layers with a sine activation over a synthesized input
      (relative coordinates + area ratio), gated multiplicatively by ``k``.

    ``mode`` selects how the K stack is fed after the first layer (see
    ``step``); mode 4 additionally uses a 3x3 reflection-padded output conv.
    """

    def __init__(self, in_channels=64, hidden_dims=(256, 256, 256, 256), mode=3, init_q=False):
        super().__init__()
        if mode not in (1, 2, 3, 4):
            # Previously an unknown mode silently built empty K/Q stacks,
            # making step() fall through and return None; fail fast instead.
            raise ValueError(f"unsupported mode: {mode!r}")
        self.mode = mode
        self.init_q = init_q

        # K consumes the unfolded encoder features: 9 values per channel.
        last_dim_K = in_channels * 9
        if self.init_q:
            # Lift the 3-channel synthesized input to the feature width so it
            # can gate x element-wise in step().
            self.first_layer = nn.Sequential(nn.Conv2d(3, in_channels * 9, 1),
                                             SineAct())
            last_dim_Q = in_channels * 9
        else:
            last_dim_Q = 3

        self.K = nn.ModuleList()
        self.Q = nn.ModuleList()
        for hidden_dim in hidden_dims:
            self.K.append(nn.Sequential(nn.Conv2d(last_dim_K, hidden_dim, 1),
                                        nn.ReLU()))
            self.Q.append(nn.Sequential(nn.Conv2d(last_dim_Q, hidden_dim, 1),
                                        SineAct()))
            # Modes 2-4 re-concatenate the unfolded features at every depth;
            # mode 1 chains the hidden state alone. (The former per-mode
            # branches were identical copies of this loop.)
            last_dim_K = hidden_dim if self.mode == 1 else hidden_dim + in_channels * 9
            last_dim_Q = hidden_dim

        if self.mode == 4:
            self.last_layer = nn.Conv2d(hidden_dims[-1], 3, 3, padding=1,
                                        padding_mode='reflect')
        else:
            self.last_layer = nn.Conv2d(hidden_dims[-1], 3, 1)

    def _make_pos_encoding(self, x, size):
        """Relative-coordinate map between the grid of ``x`` and ``size``.

        Returns a detached (1, 2, H_up, W_up) tensor holding, per output
        pixel, the offset to the nearest input pixel centre (in [-1, 1]
        coordinates), scaled by the input resolution so the values are
        resolution-independent.
        """
        B, C, H, W = x.shape
        H_up, W_up = size

        # Pixel-centre coordinates in [-1, 1] for the input grid...
        h_idx = -1 + 1 / H + 2 / H * torch.arange(H, device=x.device).float()
        w_idx = -1 + 1 / W + 2 / W * torch.arange(W, device=x.device).float()
        in_grid = torch.stack(torch.meshgrid(h_idx, w_idx, indexing='ij'), dim=0)

        # ...and for the output grid.
        h_idx_up = -1 + 1 / H_up + 2 / H_up * torch.arange(H_up, device=x.device).float()
        w_idx_up = -1 + 1 / W_up + 2 / W_up * torch.arange(W_up, device=x.device).float()
        up_grid = torch.stack(torch.meshgrid(h_idx_up, w_idx_up, indexing='ij'), dim=0)

        # important! mode='nearest' gives inconsistent results
        rel_grid = up_grid - F.interpolate(in_grid.unsqueeze(0), size=(H_up, W_up),
                                           mode='nearest-exact')
        rel_grid[:, 0, :, :] *= H
        rel_grid[:, 1, :, :] *= W
        return rel_grid.contiguous().detach()

    def step(self, x, syn_inp):
        """One full decoder pass over a (possibly strip-sliced) grid.

        Args:
            x: unfolded encoder features, (B, in_channels*9, H, W).
            syn_inp: synthesized input (rel. coords + area ratio), (B, 3, H, W).

        Returns:
            RGB prediction, (B, 3, H, W).
        """
        if self.init_q:
            syn_inp = self.first_layer(syn_inp)
            x = syn_inp * x
        k = self.K[0](x)
        q = k * self.Q[0](syn_inp)
        for i in range(1, len(self.K)):
            if self.mode == 1:
                k = self.K[i](k)
            elif self.mode == 2:
                k = self.K[i](torch.cat([k, x], dim=1))
            else:
                # Modes 3 and 4 share the same recurrence (the former
                # branches were byte-identical): feed q back alongside x.
                k = self.K[i](torch.cat([q, x], dim=1))
            q = k * self.Q[i](q)
        return self.last_layer(q)

    def batched_step(self, x, syn_inp, bsize):
        """Evaluate ``step`` over vertical strips of ~``bsize`` pixels each to
        bound peak memory. Inference only (runs under no_grad)."""
        with torch.no_grad():
            h, w = syn_inp.shape[-2:]
            ql = 0
            preds = []
            while ql < w:
                # max(..., 1) guarantees progress even when bsize < h
                # (the old code looped forever in that case).
                qr = min(ql + max(bsize // h, 1), w)
                preds.append(self.step(x[:, :, :, ql:qr], syn_inp[:, :, :, ql:qr]))
                ql = qr
            return torch.cat(preds, dim=-1)

    def forward(self, x, size, bsize=None):
        """Decode features ``x`` (B, C, H_in, W_in) to a (B, 3, *size) image.

        ``bsize`` (optional) enables strip-wise evaluation via batched_step.
        """
        B, C, H_in, W_in = x.shape
        rel_coord = self._make_pos_encoding(x, size).expand(B, -1, *size)  # (B, 2, *size)
        # Scalar area ratio (input pixels per output pixel) broadcast over the grid.
        ratio = x.new_tensor([(H_in * W_in) / (size[0] * size[1])]).view(1, -1, 1, 1).expand(B, -1, *size)
        syn_inp = torch.cat([rel_coord, ratio], dim=1)
        # Unfold 3x3 neighbourhoods, then nearest-exact upsample to the target grid.
        x = F.interpolate(F.unfold(x, 3, padding=1).view(B, C * 9, H_in, W_in),
                          size=syn_inp.shape[-2:], mode='nearest-exact')
        if bsize is None:
            return self.step(x, syn_inp)
        return self.batched_step(x, syn_inp, bsize)