| from __future__ import division |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.init as init |
| import torch.utils.model_zoo as model_zoo |
| from torchvision import models |
| from torchvision import transforms |
| import cv2 |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import numpy as np |
| import math |
| import time |
| import tqdm |
| import os |
| import argparse |
| import copy |
| import sys |
| import networks as N |
| from model_module import * |
# Prepend the working directory and ``../utils/`` (both relative to the
# process CWD, not to this file) to the module search path. The final order is
# ['../utils/', '.', ...], the same as inserting '.' and then '../utils/' at 0.
# NOTE(review): this runs after ``networks`` and ``model_module`` have already
# been imported, so it cannot influence how those two modules resolve; it
# presumably serves imports performed later elsewhere — confirm before moving.
sys.path[0:0] = ['../utils/', '.']
|
|
|
|
class LiteISPNet(nn.Module):
    """Wavelet U-Net mapping a 3-channel image to a 3-channel image.

    ``forward`` raises its input to the power 1/2.2 before the network runs.
    Each encoder stage ends in ``N.DWTForward``; the conv that follows expects
    4x the channels, so the transform presumably stacks four sub-bands. Each
    decoder stage starts with ``N.DWTInverse``, and encoder outputs are added
    back as skip connections. Submodule names and creation order define the
    state-dict layout and must stay as they are.
    """

    def __init__(self):
        super(LiteISPNet, self).__init__()

        c_top, c_mid, c_bottom = 64, 128, 128  # channel widths per level
        n_rca = 4                              # blocks per RCAGroup

        def rca(channels):
            # Channel-preserving RCAGroup with the shared block count.
            return N.RCAGroup(in_channels=channels, out_channels=channels, nb=n_rca)

        self.head = N.seq(N.conv(3, c_top, mode='C'))

        # Encoder stages (DWT at the end of each).
        self.down1 = N.seq(
            N.conv(c_top, c_top, mode='C'),
            rca(c_top),
            N.conv(c_top, c_top, mode='C'),
            N.DWTForward(c_top),
        )
        self.down2 = N.seq(
            N.conv(4 * c_top, c_top, mode='C'),
            rca(c_top),
            N.DWTForward(c_top),
        )
        self.down3 = N.seq(
            N.conv(4 * c_top, c_mid, mode='C'),
            rca(c_mid),
            N.DWTForward(c_mid),
        )

        self.middle = N.seq(
            N.conv(4 * c_mid, c_bottom, mode='C'),
            rca(c_bottom),
            rca(c_bottom),
            N.conv(c_bottom, 4 * c_mid, mode='C'),
        )

        # Decoder stages (inverse DWT at the start of each).
        self.up3 = N.seq(
            N.DWTInverse(4 * c_mid),
            rca(c_mid),
            N.conv(c_mid, 4 * c_top, mode='C'),
        )
        self.up2 = N.seq(
            N.DWTInverse(4 * c_top),
            rca(c_top),
            N.conv(c_top, 4 * c_top, mode='C'),
        )
        self.up1 = N.seq(
            N.DWTInverse(4 * c_top),
            rca(c_top),
            N.conv(c_top, c_top, mode='C'),
        )

        self.tail = N.seq(N.conv(c_top, 3, mode='C'))

    def forward(self, raw):
        """Apply the 1/2.2 power curve to ``raw`` and run the U-Net.

        Args:
            raw: input image tensor; presumably (B, 3, H, W), non-negative,
                with H and W divisible by 8 for the three DWT stages — TODO
                confirm.

        Returns:
            The 3-channel output of ``self.tail``.
        """
        gamma = torch.pow(raw, 1 / 2.2)

        skip0 = self.head(gamma)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)

        feat = self.middle(skip3) + skip3
        feat = self.up3(feat) + skip2
        feat = self.up2(feat) + skip1
        feat = self.up1(feat) + skip0
        return self.tail(feat)
|
|
|
|
class LiteAWBISPNet(nn.Module):
    """Wavelet U-Net mapping a 3-channel image to a 3-channel image.

    Same topology as the gamma variant in this file, but ``forward`` feeds
    its input to the network unchanged (no power curve). Submodule names and
    creation order define the state-dict layout and must stay as they are.
    """

    def __init__(self):
        super(LiteAWBISPNet, self).__init__()

        width_hi, width_lo, width_core = 64, 128, 128  # channel widths per level
        blocks = 4                                      # blocks per RCAGroup

        def group(ch):
            # Channel-preserving RCAGroup with the shared block count.
            return N.RCAGroup(in_channels=ch, out_channels=ch, nb=blocks)

        self.head = N.seq(N.conv(3, width_hi, mode='C'))

        self.down1 = N.seq(
            N.conv(width_hi, width_hi, mode='C'),
            group(width_hi),
            N.conv(width_hi, width_hi, mode='C'),
            N.DWTForward(width_hi),
        )
        self.down2 = N.seq(
            N.conv(width_hi * 4, width_hi, mode='C'),
            group(width_hi),
            N.DWTForward(width_hi),
        )
        self.down3 = N.seq(
            N.conv(width_hi * 4, width_lo, mode='C'),
            group(width_lo),
            N.DWTForward(width_lo),
        )

        self.middle = N.seq(
            N.conv(width_lo * 4, width_core, mode='C'),
            group(width_core),
            group(width_core),
            N.conv(width_core, width_lo * 4, mode='C'),
        )

        self.up3 = N.seq(
            N.DWTInverse(width_lo * 4),
            group(width_lo),
            N.conv(width_lo, width_hi * 4, mode='C'),
        )
        self.up2 = N.seq(
            N.DWTInverse(width_hi * 4),
            group(width_hi),
            N.conv(width_hi, width_hi * 4, mode='C'),
        )
        self.up1 = N.seq(
            N.DWTInverse(width_hi * 4),
            group(width_hi),
            N.conv(width_hi, width_hi, mode='C'),
        )

        self.tail = N.seq(N.conv(width_hi, 3, mode='C'))

    def forward(self, raw):
        """Run the U-Net directly on ``raw``.

        Args:
            raw: input image tensor; presumably (B, 3, H, W) with H and W
                divisible by 8 — TODO confirm.

        Returns:
            The 3-channel output of ``self.tail``.
        """
        x = self.head(raw)
        skips = [x]  # becomes [head, down1, down2, down3] outputs
        for stage in (self.down1, self.down2, self.down3):
            x = stage(x)
            skips.append(x)

        x = self.middle(x) + skips[3]
        # up3 pairs with down2's output, up2 with down1's, up1 with head's.
        for stage, skip in zip((self.up3, self.up2, self.up1), reversed(skips[:3])):
            x = stage(x) + skip
        return self.tail(x)
| |
| |
| |
class A_Encoder(nn.Module):
    """Alignment feature encoder for a single RGB image.

    The input is normalised with ImageNet mean/std buffers, resized to a fixed
    224x224, and passed through seven conv+ReLU layers, three of them with
    stride 2. The 256-channel result is what ``A_Regressor`` consumes, two
    maps at a time.
    """

    def __init__(self):
        super(A_Encoder, self).__init__()
        self.conv12 = Conv2d(3, 64, kernel_size=5, stride=2, padding=2, activation=nn.ReLU())
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv34 = Conv2d(128, 256, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv4a = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv4b = Conv2d(256, 256, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        init_He(self)
        # ImageNet statistics, registered as buffers so they follow .to()/.cuda().
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f):
        """Encode one image.

        Args:
            in_f: RGB tensor broadcastable against the (1, 3, 1, 1)
                statistics; presumably (B, 3, H, W) in [0, 1] — TODO confirm.

        Returns:
            A 256-channel feature map. If the project ``Conv2d`` wrapper
            follows ``nn.Conv2d`` size arithmetic, this is (B, 256, 28, 28)
            regardless of H and W.
        """
        x = (in_f - self.mean) / self.std
        # Fixed-size input makes the feature map (and A_Regressor's pooling)
        # independent of the image resolution. F.interpolate is the
        # non-deprecated replacement for F.upsample with the same arguments.
        x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3(x)
        x = self.conv34(x)
        x = self.conv4a(x)
        x = self.conv4b(x)
        return x
|
|
| |
class A_Regressor(nn.Module):
    """Regress a 2x3 affine matrix from two 256-channel feature maps.

    The two maps are concatenated (512 channels) and passed through two
    conv stages (stride-2 conv followed by two stride-1 convs). The result is
    globally average-pooled and fed to a linear layer. That layer starts with
    zero weights and an identity bias, so an untrained regressor returns the
    identity transform.
    """

    def __init__(self, legacy_stage6=False):
        """Build the regressor.

        Args:
            legacy_stage6: if True, reproduce the previous forward pass, which
                re-applied ``conv5a``/``conv5b`` after ``conv56`` instead of
                using ``conv6a``/``conv6b``. Checkpoints trained that way have
                ``conv6a``/``conv6b`` weights that never received gradients,
                so evaluate them with this flag set.
        """
        super(A_Regressor, self).__init__()
        self.legacy_stage6 = legacy_stage6
        self.conv45 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv5a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv56 = Conv2d(512, 512, kernel_size=3, stride=2, padding=1, activation=nn.ReLU())
        self.conv6a = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv6b = Conv2d(512, 512, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        init_He(self)

        self.fc = nn.Linear(512, 6)
        # Zero weight + identity bias => theta starts as [[1, 0, 0], [0, 1, 0]].
        self.fc.weight.data.zero_()
        self.fc.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float32))

    def forward(self, feat1, feat2):
        """Predict the affine transform between two feature maps.

        Args:
            feat1, feat2: feature maps whose channel counts sum to 512
                (presumably two 256-channel ``A_Encoder`` outputs).

        Returns:
            Tensor of shape (B, 2, 3), suitable for ``F.affine_grid``.
        """
        x = torch.cat([feat1, feat2], dim=1)
        x = self.conv45(x)
        x = self.conv5a(x)
        x = self.conv5b(x)
        x = self.conv56(x)
        if self.legacy_stage6:
            # Old behaviour: stage-5 weights reused, conv6a/conv6b unused.
            x = self.conv5a(x)
            x = self.conv5b(x)
        else:
            x = self.conv6a(x)
            x = self.conv6b(x)

        # Global average pool over the whole spatial extent. This also works
        # for non-square maps; avg_pool2d(x, H) would not.
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(-1, x.shape[1])

        theta = self.fc(x)
        return theta.view(-1, 2, 3)
|
|
| |
class Encoder(nn.Module):
    """Value encoder for a normalised RGB frame plus a 1-channel map.

    The 3-channel frame (ImageNet-normalised) and the map are concatenated
    into the 4-channel input of ``conv12``. Two of the convolutions use
    stride 2, and a final conv without activation produces 128 channels.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        relu = nn.ReLU  # a fresh ReLU instance is created per layer below
        self.conv12 = Conv2d(4, 64, kernel_size=5, stride=2, padding=2, activation=relu())
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=relu())
        self.conv23 = Conv2d(64, 128, kernel_size=3, stride=2, padding=1, activation=relu())
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=relu())
        self.value3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)
        init_He(self)

        imagenet_mean = torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', imagenet_mean)
        self.register_buffer('std', imagenet_std)

    def forward(self, in_f, in_v):
        """Encode a frame together with its map.

        Args:
            in_f: RGB frame, presumably (B, 3, H, W) in [0, 1] — TODO confirm.
            in_v: single-channel map concatenated after the normalised frame
                (a visibility/hole map, judging by the callers in this file).

        Returns:
            128-channel feature map from ``value3``.
        """
        normed = (in_f - self.mean) / self.std
        out = torch.cat([normed, in_v], dim=1)
        for layer in (self.conv12, self.conv2, self.conv23, self.conv3, self.value3):
            out = layer(out)
        return out
|
|
| |
class Decoder(nn.Module):
    """Decode 257-channel features into a 3-channel image.

    The decoder applies three plain 3x3 convs and then four dilated 3x3 convs
    (dilation 2, 4, 8, 16, each padded to keep size). Channels are then
    reduced to 128 and 64 with two nearest-neighbour 2x upsamplings, and a
    final 5x5 conv outputs 3 channels. The output is mapped back from
    ImageNet normalisation with ``x * std + mean``.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.conv4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv5_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())

        # Dilated convs enlarge the receptive field at constant resolution.
        self.convA4_1 = Conv2d(257, 257, kernel_size=3, stride=1, padding=2, D=2, activation=nn.ReLU())
        self.convA4_2 = Conv2d(257, 257, kernel_size=3, stride=1, padding=4, D=4, activation=nn.ReLU())
        self.convA4_3 = Conv2d(257, 257, kernel_size=3, stride=1, padding=8, D=8, activation=nn.ReLU())
        self.convA4_4 = Conv2d(257, 257, kernel_size=3, stride=1, padding=16, D=16, activation=nn.ReLU())

        self.conv3c = Conv2d(257, 257, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv3b = Conv2d(257, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv3a = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv32 = Conv2d(128, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
        self.conv21 = Conv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)

        # ImageNet statistics used to undo the encoders' normalisation.
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Decode ``x``.

        Args:
            x: 257-channel feature map (e.g. the first output of CM_Module).

        Returns:
            3-channel image at 4x the spatial size of ``x`` (two 2x
            upsamplings), in the un-normalised value range.
        """
        x = self.conv4(x)
        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.convA4_1(x)
        x = self.convA4_2(x)
        x = self.convA4_3(x)
        x = self.convA4_4(x)

        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        # F.interpolate replaces the deprecated F.upsample (same semantics).
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv21(x)

        p = (x * self.std) + self.mean
        return p
|
|
|
|
| |
class CM_Module(nn.Module):
    """Context matching: blend reference-frame features by masked similarity.

    For each reference frame, the module computes a per-sample similarity
    score between target and reference features over the region valid in
    both. It then takes a masked softmax of the scores across references and
    returns the target features, the softmax-weighted reference features and
    a "not covered" mask.
    """

    def __init__(self):
        super(CM_Module, self).__init__()

    def masked_softmax(self, vec, mask, dim):
        """Softmax of ``vec`` along ``dim``, restricted to entries where ``mask`` > 0.

        Masked-out entries get weight 0. A slice whose valid weights sum to
        less than 1e-4 (including a slice with no valid entry) is divided by
        1 instead, so it yields (near-)zeros rather than NaN.

        Args:
            vec: scores; broadcastable against ``mask``.
            mask: validity weights (binary in this file), cast to float.
            dim: dimension to normalise over.
        """
        valid = mask.float()
        masked_vec = vec * valid
        # Shift by the max over *valid* entries only. Including the masked
        # zeros let them dominate when every valid score was strongly
        # negative, and the valid weights then collapsed below the 1e-4 guard.
        keep = (valid > 0).expand_as(masked_vec)
        logits = torch.where(keep, masked_vec, torch.full_like(masked_vec, float('-inf')))
        shift = torch.max(logits, dim=dim, keepdim=True)[0]
        # A fully masked slice has max -inf; any finite shift works there.
        shift = torch.where(shift == float('-inf'), torch.zeros_like(shift), shift)
        exps = torch.exp(logits - shift)  # exp(-inf) == 0 for masked entries
        masked_exps = exps * valid
        masked_sums = masked_exps.sum(dim, keepdim=True)
        zeros = (masked_sums < 1e-4)
        masked_sums += zeros.float()
        return masked_exps / masked_sums

    def forward(self, values, tvmap, rvmaps):
        """Match the target features against the reference features.

        Args:
            values: (B, C, 1 + T, H, W). Index 0 along dim 2 holds the target
                features and indices 1..T hold T reference frames.
            tvmap: target validity map; bilinearly resized to (H, W) and
                thresholded at 0.5. Presumably (B, 1, H', W') — TODO confirm.
            rvmaps: reference validity maps; ``rvmaps[:, :, r]`` is resized
                and thresholded like ``tvmap``. Presumably (B, 1, T, H', W').

        Returns:
            ``(features, c_mask)``. ``features`` concatenates, along dim 1,
            the target features (C), the matched reference features (C) and
            ``c_mask`` (1). ``c_mask`` is (B, 1, H, W): one minus the
            channel-mean of the matched reference validity.
        """
        _, C, _, H, W = values.size()
        t_feat = values[:, :, 0]
        r_feats = values[:, :, 1:]
        B = r_feats.size(0)
        num_refs = r_feats.size(2)

        def binarize(vmap):
            # Resize a validity map to feature resolution and threshold it.
            resized = F.interpolate(vmap, size=(H, W), mode='bilinear', align_corners=False)
            return (resized > 0.5).float()

        tvmap_t = binarize(tvmap)
        scores, ref_masks = [], []
        for r in range(num_refs):
            rvmap_t = binarize(rvmaps[:, :, r])
            vmap = tvmap_t * rvmap_t

            # Sum of target*reference products over the jointly valid region,
            # normalised by the valid area and the channel count.
            gs = (vmap * t_feat * r_feats[:, :, r]).sum(-1).sum(-1).sum(-1)
            v_sum = vmap[:, 0].sum(-1).sum(-1)
            zeros = (v_sum < 1e-4)
            gs[zeros] = 0
            v_sum += zeros.float()
            gs = gs / v_sum / C

            # Broadcast the per-sample score to the feature shape on t_feat's
            # own device (a hard-coded .cuda() failed on CPU and ignored which
            # GPU the inputs live on).
            gs = torch.ones_like(t_feat) * gs.view(B, 1, 1, 1)
            scores.append(gs)
            ref_masks.append(rvmap_t)

        gss = torch.stack(scores, dim=2)
        vmaps = torch.stack(ref_masks, dim=2)

        c_match = self.masked_softmax(gss, vmaps, dim=2)
        c_out = torch.sum(r_feats * c_match, dim=2)

        # Fraction of matching weight that landed on valid reference pixels.
        c_mask = torch.sum(c_match * vmaps, 2)
        c_mask = 1. - torch.mean(c_mask, 1, keepdim=True)

        return torch.cat([t_feat, c_out, c_mask], dim=1), c_mask
|
|
|
|
class GCMModel(nn.Module):
    """Globally guided colour mapping on a gamma-corrected 3-channel input.

    A guide branch (strided conv, convs, global average pool, 1x1 conv)
    produces a per-sample ``ch_2``-channel vector. The vector modulates the
    ``align_head`` features as ``guide * feat + feat`` before a stack of
    kernel-size-1 convs. The tail output is added back to the gamma-corrected
    input.
    """

    def __init__(self):
        super(GCMModel, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        self.gcm_coord = None  # coordinate inputs are not used by this model
        rgb = 3                # both branches take the 3-channel image

        self.guide_net = N.seq(
            N.conv(rgb, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C'),
        )

        self.align_head = N.conv(rgb, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCRCRCR'),
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C'),
        )

    def forward(self, demosaic_raw):
        """Map ``demosaic_raw`` (presumably (B, 3, H, W), non-negative) to 3 channels.

        The input is raised to the power 1/2.2 first. The same gamma-corrected
        tensor feeds both branches and the final residual addition.
        """
        gamma = torch.pow(demosaic_raw, 1 / 2.2)
        guide = self.guide_net(gamma)
        feat = self.align_head(gamma)
        feat = self.align_base(guide * feat + feat)
        return self.align_tail(feat) + gamma
|
|
class Fusion(nn.Module):
    """Guided fusion of a 9-channel input into a 3-channel output.

    Same guide/align structure as the colour-mapping models in this file, but
    with 9 input channels, a shorter kernel-size-1 conv stack, no gamma curve
    and no residual connection to the input.
    """

    def __init__(self):
        super(Fusion, self).__init__()
        self.ch_1 = 16
        self.ch_2 = 32
        self.gcm_coord = None  # coordinate inputs are not used by this model
        stacked = 9            # both branches take the 9-channel input

        self.guide_net = N.seq(
            N.conv(stacked, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C'),
        )

        self.align_head = N.conv(stacked, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR'),
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C'),
        )

    def forward(self, demosaic_raw):
        """Fuse ``demosaic_raw`` (9 channels, presumably three stacked RGB images) to 3 channels."""
        guide = self.guide_net(demosaic_raw)
        feat = self.align_head(demosaic_raw)
        feat = self.align_base(guide * feat + feat)
        return self.align_tail(feat)
|
|
|
|
|
|
|
|
class CPNet(nn.Module):
    """Affine alignment network, plus an inpainting path built on it.

    ``forward`` predicts an affine transform from ``A_Encoder`` features of
    two images and warps ``Source`` onto ``Target``. ``encoding`` and
    ``inpainting`` combine the same aligner with ``Encoder``, ``CM_Module``
    and ``Decoder`` to fill holes in a frame from reference frames.
    """

    def __init__(self, mode='Train'):
        # NOTE(review): ``mode`` is accepted but not used anywhere in this class.
        super(CPNet, self).__init__()
        self.A_Encoder = A_Encoder()
        self.A_Regressor = A_Regressor()
        # NOTE(review): registered (so it appears in the state dict) but not
        # called by any method of this class.
        self.GCMModel = GCMModel()
        self.Encoder = Encoder()
        self.CM_Module = CM_Module()
        self.Decoder = Decoder()

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))

    def encoding(self, frames, holes):
        """Encode every frame of a clip with ``A_Encoder``.

        Args:
            frames: clip laid out as (B, 3, T, H, W); frame ``t`` is
                ``frames[:, :, t]``.
            holes: hole maps with the same (B, ?, T, H, W) layout. They are
                padded together with ``frames`` but not encoded, because
                ``A_Encoder.forward`` takes a single image.

        Returns:
            Per-frame ``A_Encoder`` features stacked along dim 2.
        """
        num_frames = frames.size(2)
        (frames, holes), _ = pad_divide_by([frames, holes], 8, (frames.size(3), frames.size(4)))
        # A_Encoder.forward(in_f) takes one argument; passing the hole map
        # as well always raised TypeError.
        feats = [self.A_Encoder(frames[:, :, t]) for t in range(num_frames)]
        return torch.stack(feats, dim=2)

    def inpainting(self, rfeats, rframes, rholes, frame, hole, gt):
        """Fill the holes of ``frame`` using aligned reference frames.

        Args:
            rfeats: reference features from ``encoding``; references indexed
                along dim 2.
            rframes, rholes: reference frames / hole maps, references along dim 2.
            frame, hole: target frame and its hole map, (B, ?, H, W).
            gt: image used outside the hole (``comp = pred*hole + gt*(1-hole)``).

        Returns:
            The composite cropped back to (H, W) and clamped to [0, 1].
        """
        _, _, height, width = frame.size()
        num_r = rfeats.size(2)

        (rframes, rholes, frame, hole, gt), pad = pad_divide_by(
            [rframes, rholes, frame, hole, gt], 8, (height, width))

        tfeat = self.A_Encoder(frame)
        c_feat_ = [self.Encoder(frame, hole)]
        rvmap_ = []
        for r in range(num_r):
            # Warp reference r (and its validity = 1 - hole) onto the target.
            theta_rt = self.A_Regressor(tfeat, rfeats[:, :, r])
            grid_rt = F.affine_grid(theta_rt, frame.size())
            aligned_r = F.grid_sample(rframes[:, :, r], grid_rt)
            aligned_v = F.grid_sample(1 - rholes[:, :, r], grid_rt)
            aligned_v = (aligned_v > 0.5).float()

            c_feat_.append(self.Encoder(aligned_r, aligned_v))
            rvmap_.append(aligned_v)

        c_feats = torch.stack(c_feat_, dim=2)
        rvmaps = torch.stack(rvmap_, dim=2)

        p_in, _ = self.CM_Module(c_feats, 1 - hole, rvmaps)
        pred = self.Decoder(p_in)

        comp = pred * hole + gt * (1. - hole)

        # Remove the padding. Explicit end indices keep this correct when a
        # trailing pad is 0 (``x[a:-0]`` would be empty).
        top, bottom = pad[2], pad[3]
        left, right = pad[0], pad[1]
        comp = comp[:, :, top:comp.size(2) - bottom, left:comp.size(3) - right]

        return torch.clamp(comp, 0, 1)

    def forward(self, Source, Target):
        """Warp ``Source`` onto ``Target`` with a predicted affine transform.

        Returns:
            ``(aligned, mask)``. ``aligned`` is ``Source`` resampled on the
            sampling grid built for ``Target.size()``. ``mask`` is a tensor
            of ones warped the same way (zero padding), so it is in [0, 1]
            and marks where the warped source has support.
        """
        feat_target = self.A_Encoder(Target)
        feat_source = self.A_Encoder(Source)

        theta = self.A_Regressor(feat_target, feat_source)
        grid_rt = F.affine_grid(theta, Target.size())
        aligned = F.grid_sample(Source, grid_rt)
        mask = F.grid_sample(torch.ones_like(Source), grid_rt)

        return aligned, mask
|
|
|
|
class AC(nn.Module):
    """Guided colour alignment of a demosaiced image towards a reference.

    The guide branch sees the processed input concatenated with ``dslr`` (and
    with ``coord`` when ``gcm_coord`` is truthy). It yields a per-sample
    ``ch_2``-channel vector that modulates the align branch as
    ``guide * feat + feat``. The tail output is added back to the processed
    input.
    """

    def __init__(self):
        super(AC, self).__init__()
        self.ch_1 = 32
        self.ch_2 = 64
        self.gcm_coord = None

        # With coordinates: guide 8 / align 5 channels; without: 6 / 3
        # (presumably 3-channel image + 3-channel dslr, and the image alone).
        if self.gcm_coord:
            guide_in, align_in = 8, 5
        else:
            guide_in, align_in = 6, 3

        self.guide_net = N.seq(
            N.conv(guide_in, self.ch_1, 7, stride=2, padding=0, mode='CR'),
            N.conv(self.ch_1, self.ch_1, kernel_size=3, stride=1, padding=1, mode='CRC'),
            nn.AdaptiveAvgPool2d(1),
            N.conv(self.ch_1, self.ch_2, 1, stride=1, padding=0, mode='C'),
        )

        self.align_head = N.conv(align_in, self.ch_2, 1, padding=0, mode='CR')

        self.align_base = N.seq(
            N.conv(self.ch_2, self.ch_2, kernel_size=1, stride=1, padding=0, mode='CRCRCR'),
        )
        self.align_tail = N.seq(
            N.conv(self.ch_2, 3, 1, padding=0, mode='C'),
        )

    def forward(self, demosaic_raw, dslr, coord=None):
        """Align ``demosaic_raw`` towards ``dslr``.

        The input is offset by 0.01, raised to the power 1/2.2 and halved.
        The offset presumably keeps the power's gradient finite at zero.
        ``coord`` is only used when ``gcm_coord`` is truthy.
        """
        shifted = demosaic_raw + 0.01 * torch.ones_like(demosaic_raw)
        base = torch.pow(shifted, 1 / 2.2) / 2

        if not self.gcm_coord:
            guide_input = torch.cat((base, dslr), 1)
            align_input = base
        else:
            guide_input = torch.cat((base, dslr, coord), 1)
            align_input = torch.cat((base, coord), 1)

        guide = self.guide_net(guide_input)
        feat = self.align_head(align_input)
        feat = self.align_base(guide * feat + feat)
        return self.align_tail(feat) + base