| | from typing import Any, List, Callable |
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.nn.utils.spectral_norm as SpectralNorm |
| | import threading |
| | from torchvision.ops import roi_align |
| |
|
| | from math import sqrt |
| |
|
| | from torchvision.transforms.functional import normalize |
| |
|
| | from roop.typing import Face, Frame, FaceSet |
| |
|
| |
|
# Module-wide lock; Enhance_DMDNet.enhance_face holds it around the DMDNet forward pass.
THREAD_LOCK_DMDNET = threading.Lock()
| |
|
| |
|
class Enhance_DMDNet():
    """Face-enhancement processor that runs a DMDNet model on 512x512 face images.

    The model is created lazily by ``Initialize`` and dropped again by ``Release``.
    """

    # Loaded DMDNet instance; stays None until Initialize() has run.
    model_dmdnet = None
    # torch.device the model was created on; set by create().
    torchdevice = None

    processorname = 'dmdnet'
    type = 'enhance'

    def Initialize(self, devicename):
        """Load the model onto ``devicename`` unless one is already loaded."""
        if self.model_dmdnet is None:
            self.model_dmdnet = self.create(devicename)

    def Run(self, source_faceset: "FaceSet", target_face: "Face", temp_frame: "Frame") -> "tuple[Frame, int]":
        """Enhance ``temp_frame`` and report how much larger the result is.

        Returns:
            ``(frame, scale_factor)`` where ``frame`` is the uint8 result and
            ``scale_factor`` is the integer ratio of output width to input width.
        """
        input_size = temp_frame.shape[1]
        result = self.enhance_face(source_faceset, temp_frame, target_face)
        scale_factor = int(result.shape[1] / input_size)
        return result.astype(np.uint8), scale_factor

    def Release(self):
        """Drop the loaded model so a later Initialize() creates it again."""
        self.model_dmdnet = None

    def landmarks106_to_68(self, pt106):
        """Pick the 68 landmarks used by DMDNet out of a 106-point landmark set.

        Returns a list of 68 entries of ``pt106`` in the order given by the table below.
        """
        map106to68 = [1, 10, 12, 14, 16, 3, 5, 7, 0, 23, 21, 19, 32, 30, 28, 26, 17,
                      43, 48, 49, 51, 50,
                      102, 103, 104, 105, 101,
                      72, 73, 74, 86, 78, 79, 80, 85, 84,
                      35, 41, 42, 39, 37, 36,
                      89, 95, 96, 93, 91, 90,
                      52, 64, 63, 71, 67, 68, 61, 58, 59, 53, 56, 55, 65, 66, 62, 70, 69, 57, 60, 54
                      ]
        return [pt106[index] for index in map106to68]

    def check_bbox(self, imgs, boxes):
        """Debug helper: draw the four boxes of every image and write ``dmdnet_NN.png``.

        ``boxes`` is viewed as (-1, 4, 4), i.e. four ``(x0, y0, x1, y1)`` rows per image.
        """
        boxes = boxes.view(-1, 4, 4)
        colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
        for i, (img, box) in enumerate(zip(imgs, boxes)):
            img = (img + 1) / 2 * 255
            canvas = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
            for idx, point in enumerate(box):
                cv2.rectangle(canvas, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2)
            cv2.imwrite('dmdnet_{:02d}.png'.format(i), canvas)

    def trans_points2d(self, pts, M):
        """Apply the affine matrix ``M`` to an (N, 2) array of points.

        Each point is extended to ``(x, y, 1)``, multiplied by ``M`` and the
        first two components are kept. Returns a float32 (N, 2) array.
        """
        pts = np.asarray(pts)
        homogeneous = np.column_stack((pts[:, 0], pts[:, 1], np.ones(len(pts)))).astype(np.float32)
        return (homogeneous @ np.asarray(M).T)[:, :2].astype(np.float32)

    def _prepare_face_input(self, image, face):
        """Return ``(image_512, tensor, component_locs)`` for one face image.

        Images that are not already 512x512 are resized, and their landmarks are
        mapped with ``face.matrix`` scaled by ``512 / width``.
        """
        landmarks = np.asarray(self.landmarks106_to_68(face.landmark_2d_106))

        if image.shape[0] != 512 or image.shape[1] != 512:
            scale_factor = 512 / image.shape[1]
            landmarks = self.trans_points2d(landmarks, face.matrix * scale_factor)
            image = cv2.resize(image, (512, 512), interpolation=cv2.INTER_AREA)

        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        return image, read_img_tensor(image), get_component_location(landmarks)

    @staticmethod
    def _result_to_frame(result):
        """Turn a network output tensor into an HxWx3 uint8 image.

        NOTE(review): ``flip(2)`` followed by ``COLOR_RGB2BGR`` swaps the channel
        order twice, so the output keeps the channel order of the input frame.
        """
        image = (result * 0.5 + 0.5).squeeze(0).permute(1, 2, 0).flip(2)
        image = np.clip(image.float().cpu().numpy(), 0, 1) * 255.0
        return cv2.cvtColor(image.astype("uint8"), cv2.COLOR_RGB2BGR)

    def enhance_face(self, ref_faceset: "FaceSet", temp_frame, face: "Face"):
        """Run DMDNet on ``temp_frame`` and return the enhanced image.

        When ``ref_faceset`` holds more than one face, its reference images are
        turned into a specific memory and the specific result is returned;
        otherwise the generic result is used. If the forward pass raises, the
        error is printed and the (possibly resized) input frame is returned.
        """
        temp_frame, lq, LQLocs = self._prepare_face_input(temp_frame, face)

        # Inference only: keep autograd off for the memory build as well as the forward pass.
        with torch.no_grad():
            if len(ref_faceset.faces) > 1:
                ref_tensors = []
                ref_locs = []
                for i, ref_face in enumerate(ref_faceset.faces):
                    _, ref_tensor, locs = self._prepare_face_input(ref_faceset.ref_images[i], ref_face)
                    ref_tensors.append(ref_tensor)
                    ref_locs.append(locs.unsqueeze(0))

                SpecificImgs = torch.cat(ref_tensors, dim=0)
                SpecificLocs = torch.cat(ref_locs, dim=0)

                SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs = SpecificImgs.to(self.torchdevice), sp_locs = SpecificLocs)
                SpMem256Para = dict(SpMem256)
                SpMem128Para = dict(SpMem128)
                SpMem64Para = dict(SpMem64)
            else:
                SpMem256Para, SpMem128Para, SpMem64Para = None, None, None

            with THREAD_LOCK_DMDNET:
                try:
                    GenericResult, SpecificResult = self.model_dmdnet(lq = lq.to(self.torchdevice), loc = LQLocs.unsqueeze(0), sp_256 = SpMem256Para, sp_128 = SpMem128Para, sp_64 = SpMem64Para)
                except Exception as e:
                    # Best effort: report the failure and hand back the unenhanced frame.
                    print(f'Error {e} there may be something wrong with the detected component locations.')
                    return temp_frame

        chosen = SpecificResult if SpecificResult is not None else GenericResult
        return self._result_to_frame(chosen)

    def create(self, devicename):
        """Build DMDNet on ``devicename``, load ``./models/DMDNet.pth`` and switch to eval mode."""
        self.torchdevice = torch.device(devicename)
        model_dmdnet = DMDNet().to(self.torchdevice)
        # map_location lets a checkpoint saved on another device load on this one.
        weights = torch.load('./models/DMDNet.pth', map_location=self.torchdevice)
        model_dmdnet.load_state_dict(weights, strict=True)
        model_dmdnet.eval()
        return model_dmdnet
| |
|
| | |
| | |
| |
|
| |
|
| |
|
def read_img_tensor(Img=None):
    """Convert an HxWx3 image array with 0..255 values into a 1x3xHxW float tensor in [-1, 1]."""
    chw = torch.from_numpy(Img.transpose((2, 0, 1)) / 255.0).float()
    # Same arithmetic as normalize(mean=0.5, std=0.5, inplace=True); the (3, 1, 1)
    # operands keep the requirement that the image has exactly three channels.
    half = torch.full((3, 1, 1), 0.5, dtype=chw.dtype)
    chw.sub_(half).div_(half)
    return chw.unsqueeze(0)
| |
|
| |
|
def get_component_location(Landmarks, re_read=False):
    """Compute four component boxes from a 68-point landmark set.

    Args:
        Landmarks: (68, 2) array of ``(x, y)`` points, or, when ``re_read`` is
            true, the path of a text file holding whitespace-separated numbers.
        re_read: read the landmarks from the file named by ``Landmarks``.

    Returns:
        A (4, 4) float tensor with one ``(x0, y0, x1, y1)`` row each for the
        LE, RE, NO and MO components (landmark groups 36-41, 42-47, 29-35 and
        48-67), with coordinates truncated to integers.

    The input array is not modified; clamping is done on a copy.
    """
    if re_read:
        with open(Landmarks, 'r') as f:
            values = [float(token) for line in f for token in line.split()]
        Landmarks = np.array(values).reshape(-1, 2)

    # Clamp coordinates into [8, 504] without touching the caller's array.
    pts = np.clip(np.asarray(Landmarks), 8, 504)

    def eye_box(eye_idx, brow_and_eye_idx):
        # Box height is driven by the distance from the eye centre to the highest brow/eye point.
        centre = pts[eye_idx].mean(axis=0)
        above = (centre[1] - pts[brow_and_eye_idx, 1].min()) * 1.3
        below = above / 1.9
        half_width = (above + below) / 2
        return np.hstack((centre - [half_width, above] + 1, centre + [half_width, below])).astype(int)

    Location_LE = eye_box(np.r_[36:42], np.r_[17:22, 36:42])
    Location_RE = eye_box(np.r_[42:48], np.r_[22:27, 42:48])

    nose_centre = pts[29:36].mean(axis=0)
    nose_half = max(nose_centre[0] - pts[31][0], pts[35][0] - nose_centre[0]) * 1.25
    nose_below = (pts[33][1] - nose_centre[1]) * 1.1
    nose_side = nose_half * 2
    Location_NO = np.hstack((
        nose_centre - [nose_side / 2, nose_side - nose_below] + 1,
        nose_centre + [nose_side / 2, nose_below],
    )).astype(int)

    mouth = pts[48:68]
    mouth_centre = mouth.mean(axis=0)
    # Half side of the mouth box: at least 16 before the 1.1 margin is applied.
    mouth_half = max(np.max(mouth.max(axis=0) - mouth.min(axis=0)) / 2, 16) * 1.1
    mouth_lt = mouth_centre - mouth_half + 1
    mouth_rb = np.minimum(mouth_centre + mouth_half, 510)
    Location_MO = np.hstack((mouth_lt, mouth_rb)).astype(int)

    boxes = np.stack([Location_LE, Location_RE, Location_NO, Location_MO]).astype(np.float32)
    return torch.from_numpy(boxes)
| |
|
| |
|
| |
|
| |
|
def calc_mean_std_4D(feat, eps=1e-5):
    """Return per-sample, per-channel ``(mean, std)`` of a 4-D tensor, each shaped (N, C, 1, 1).

    ``eps`` is added to the variance before the square root.
    """
    shape = feat.size()
    assert (len(shape) == 4)
    n, c = shape[0], shape[1]
    flat = feat.view(n, c, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    return mean, std
| |
|
def adaptive_instance_normalization_4D(content_feat, style_feat):
    """Re-scale ``content_feat`` so its per-channel mean/std match those of ``style_feat``."""
    target_shape = content_feat.size()
    style_mean, style_std = calc_mean_std_4D(style_feat)
    content_mean, content_std = calc_mean_std_4D(content_feat)
    centred = content_feat - content_mean.expand(target_shape)
    whitened = centred / content_std.expand(target_shape)
    return whitened * style_std.expand(target_shape) + style_mean.expand(target_shape)
| |
|
| |
|
def convU(in_channels, out_channels,conv_layer, norm_layer, kernel_size=3, stride=1,dilation=1, bias=True):
    """Build conv -> LeakyReLU(0.2) -> conv, with spectral norm on both convolutions.

    Both convolutions share kernel size, stride, dilation and bias; padding is
    ``((kernel_size - 1) // 2) * dilation``. ``norm_layer`` is accepted but not used.
    """
    shared = dict(
        kernel_size=kernel_size,
        stride=stride,
        dilation=dilation,
        padding=((kernel_size - 1) // 2) * dilation,
        bias=bias,
    )
    first = SpectralNorm(conv_layer(in_channels, out_channels, **shared))
    activation = nn.LeakyReLU(0.2)
    second = SpectralNorm(conv_layer(out_channels, out_channels, **shared))
    return nn.Sequential(first, activation, second)
| | |
| |
|
class MSDilateBlock(nn.Module):
    """Four parallel ``convU`` branches with different dilations, fused and added to the input."""

    def __init__(self, in_channels,conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=(1,1,1,1), bias=True):
        """``dilation`` gives the dilation of each of the four branches (any 4-item sequence)."""
        super().__init__()
        self.conv1 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[0], bias=bias)
        self.conv2 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[1], bias=bias)
        self.conv3 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[2], bias=bias)
        self.conv4 = convU(in_channels, in_channels,conv_layer, norm_layer, kernel_size,dilation=dilation[3], bias=bias)
        # Fuses the concatenated branch outputs back to in_channels.
        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))

    def forward(self, x):
        """Return ``convi(cat(branches)) + x``."""
        branches = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        return self.convi(torch.cat(branches, 1)) + x
| |
|
| |
|
class AdaptiveInstanceNorm(nn.Module):
    """Instance-normalize the input, then apply the per-channel mean/std of a style tensor."""

    def __init__(self, in_channel):
        super().__init__()
        self.norm = nn.InstanceNorm2d(in_channel)

    def forward(self, input, style):
        """Return ``style_std * norm(input) + style_mean`` expanded to the input's shape."""
        style_mean, style_std = calc_mean_std_4D(style)
        normalized = self.norm(input)
        target_shape = input.size()
        return style_std.expand(target_shape) * normalized + style_mean.expand(target_shape)
| |
|
class NoiseInjection(nn.Module):
    """Add per-channel weighted noise to an image tensor; the weight starts at zero."""

    def __init__(self, channel):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))

    def forward(self, image, noise):
        """Return ``image + weight * noise``; when ``noise`` is None a (B, 1, H, W) normal sample is drawn."""
        if noise is not None:
            return image + self.weight * noise
        batch, _, height, width = image.shape
        sampled = image.new_empty(batch, 1, height, width).normal_()
        return image + self.weight * sampled
| |
|
class StyledUpBlock(nn.Module):
    """Convolve the input, modulate it with scale/shift maps from a style tensor, then upsample 2x."""

    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1,upsample=False, noise_inject=False):
        super().__init__()

        def sn_conv(cin, cout, *args, **kwargs):
            # Spectral-normalized 2-D convolution.
            return SpectralNorm(nn.Conv2d(cin, cout, *args, **kwargs))

        self.noise_inject = noise_inject

        if not upsample:
            self.conv1 = nn.Sequential(
                sn_conv(in_channel, out_channel, kernel_size, padding=padding),
                nn.LeakyReLU(0.2),
                sn_conv(out_channel, out_channel, kernel_size, padding=padding),
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                sn_conv(in_channel, out_channel, kernel_size, padding=padding),
                nn.LeakyReLU(0.2),
            )

        self.convup = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            sn_conv(out_channel, out_channel, kernel_size, padding=padding),
            nn.LeakyReLU(0.2),
            sn_conv(out_channel, out_channel, kernel_size, padding=padding),
        )

        if self.noise_inject:
            self.noise1 = NoiseInjection(out_channel)

        self.lrelu1 = nn.LeakyReLU(0.2)

        # Scale and shift maps are both predicted from the style tensor.
        self.ScaleModel1 = nn.Sequential(
            sn_conv(in_channel, out_channel, 3, 1, 1),
            nn.LeakyReLU(0.2),
            sn_conv(out_channel, out_channel, 3, 1, 1),
        )
        self.ShiftModel1 = nn.Sequential(
            sn_conv(in_channel, out_channel, 3, 1, 1),
            nn.LeakyReLU(0.2),
            sn_conv(out_channel, out_channel, 3, 1, 1),
        )

    def forward(self, input, style):
        """Return ``convup(lrelu(conv1(input)) * Scale(style) + Shift(style))`` (noise added first if enabled)."""
        features = self.lrelu1(self.conv1(input))
        shift = self.ShiftModel1(style)
        scale = self.ScaleModel1(style)
        modulated = features * scale + shift
        if self.noise_inject:
            modulated = self.noise1(modulated, noise=None)
        return self.convup(modulated)
| |
|
| |
|
| | |
| | |
| | |
def AttentionBlock(in_channel):
    """Build a channel-preserving 3x3 conv -> LeakyReLU(0.2) -> 3x3 conv stack with spectral norm."""
    layers = [
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
    ]
    return nn.Sequential(*layers)
| |
|
class DilateResBlock(nn.Module):
    """Residual block of two spectral-normalized 3x3 dilated convolutions."""

    def __init__(self, dim, dilation=(5,3) ):
        """``dilation`` holds the dilation of the first and second convolution (any 2-item sequence)."""
        super().__init__()
        first, second = dilation[0], dilation[1]
        # For a 3x3 kernel, padding equal to the dilation keeps the spatial size.
        self.Res = nn.Sequential(
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*first, first)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*second, second)),
        )

    def forward(self, x):
        """Return ``x + Res(x)``."""
        return x + self.Res(x)
| |
|
| |
|
class KeyValue(nn.Module):
    """Two parallel conv stacks that map a feature map to a key map and a value map."""

    def __init__(self, indim, keydim, valdim):
        super().__init__()
        conv_args = dict(kernel_size=(3, 3), padding=(1, 1), stride=1)
        self.Key = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, keydim, **conv_args)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(keydim, keydim, **conv_args)),
        )
        self.Value = nn.Sequential(
            SpectralNorm(nn.Conv2d(indim, valdim, **conv_args)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(valdim, valdim, **conv_args)),
        )

    def forward(self, x):
        """Return ``(Key(x), Value(x))``."""
        key = self.Key(x)
        value = self.Value(x)
        return key, value
| |
|
class MaskAttention(nn.Module):
    """Reduce three feature maps to ``indim // 3`` channels each and fuse them back to ``indim`` channels."""

    def __init__(self, indim):
        super().__init__()

        def conv_pair(cin, cout):
            # 3x3 conv -> LeakyReLU(0.2) -> 3x3 conv, both spectral-normalized.
            return nn.Sequential(
                SpectralNorm(nn.Conv2d(cin, cout, kernel_size=(3, 3), padding=(1, 1), stride=1)),
                nn.LeakyReLU(0.2),
                SpectralNorm(nn.Conv2d(cout, cout, kernel_size=(3, 3), padding=(1, 1), stride=1)),
            )

        reduced = indim // 3
        self.conv1 = conv_pair(indim, reduced)
        self.conv2 = conv_pair(indim, reduced)
        self.conv3 = conv_pair(indim, reduced)
        self.convCat = conv_pair(reduced * 3, indim)

    def forward(self, x, y, z):
        """Return ``convCat`` applied to the channel-wise concatenation of the three reduced inputs."""
        reduced_maps = [self.conv1(x), self.conv2(y), self.conv3(z)]
        return self.convCat(torch.cat(reduced_maps, dim=1))
| |
|
class Query(nn.Module):
    """Conv stack that maps a feature map with ``indim`` channels to a query map with ``quedim`` channels."""

    def __init__(self, indim, quedim):
        super().__init__()
        conv_args = dict(kernel_size=(3, 3), padding=(1, 1), stride=1)
        layers = [
            SpectralNorm(nn.Conv2d(indim, quedim, **conv_args)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(quedim, quedim, **conv_args)),
        ]
        self.Query = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``Query(x)``."""
        query = self.Query(x)
        return query
| |
|
def roi_align_self(input, location, target_size):
    """Crop one box per batch item and resize every crop to a square.

    Args:
        input: 4-D tensor (batch first, spatial dims last).
        location: per-item boxes indexed as ``location[i, 0..3]`` = ``(x0, y0, x1, y1)``.
        target_size: output side length; a plain int or any scalar exposing ``.item()``.

    Returns:
        Tensor with ``input.size(0)`` crops, each bilinearly resized to
        ``(target_size, target_size)``.
    """
    side = target_size.item() if hasattr(target_size, 'item') else target_size
    out_size = (side, side)
    crops = []
    for i in range(input.size(0)):
        x0, y0, x1, y1 = location[i, 0], location[i, 1], location[i, 2], location[i, 3]
        crop = input[i:i+1, :, y0:y1, x0:x1]
        crops.append(F.interpolate(crop, out_size, mode='bilinear', align_corners=False))
    return torch.cat(crops, 0)
| |
|
class FeatureExtractor(nn.Module):
    """Encoder producing three feature maps plus cropped component features and their queries.

    ``forward`` returns a dict with the feature maps (``f256``, ``f128``, ``f64``),
    the LE/RE/MO crops at each scale (``le256`` ... ``mo64``) and the matching
    query maps (``le_256_q`` ... ``mo_64_q``).
    """

    def __init__(self, ngf = 64, key_scale = 4):
        """``ngf`` is the base channel width; query maps have ``channels // key_scale`` channels."""
        super().__init__()

        self.key_scale = key_scale
        # Crop sizes for the four component boxes; index 2 is not cropped in forward().
        self.part_sizes = np.array([80,80,50,110])
        self.feature_sizes = np.array([256,128,64])

        self.conv1 = self._conv_pair(3, ngf, stride=2)
        self.conv2 = self._conv_pair(ngf, ngf, stride=1)
        self.res1 = DilateResBlock(ngf, [5,3])
        self.res2 = DilateResBlock(ngf, [5,3])

        self.conv3 = self._conv_pair(ngf, ngf*2, stride=2)
        self.conv4 = self._conv_pair(ngf*2, ngf*2, stride=1)
        self.res3 = DilateResBlock(ngf*2, [3,1])
        self.res4 = DilateResBlock(ngf*2, [3,1])

        self.conv5 = self._conv_pair(ngf*2, ngf*4, stride=2)
        self.conv6 = self._conv_pair(ngf*4, ngf*4, stride=1)
        self.res5 = DilateResBlock(ngf*4, [1,1])
        self.res6 = DilateResBlock(ngf*4, [1,1])

        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)

    @staticmethod
    def _conv_pair(in_channels, out_channels, stride):
        """3x3 conv (given stride) -> LeakyReLU(0.2) -> 3x3 conv (stride 1), both spectral-normalized."""
        return nn.Sequential(
            SpectralNorm(nn.Conv2d(in_channels, out_channels, 3, stride, 1)),
            nn.LeakyReLU(0.2),
            SpectralNorm(nn.Conv2d(out_channels, out_channels, 3, 1, 1)),
        )

    def forward(self, img, locs):
        """Encode ``img`` and crop the component boxes given by ``locs``.

        ``locs`` is indexed as ``locs[:, k, :]`` with k = 0 (LE), 1 (RE), 3 (MO);
        the boxes are divided by 2, 4 and 8 to address the three feature maps.
        """
        le_location = locs[:,0,:].int().cpu().numpy()
        re_location = locs[:,1,:].int().cpu().numpy()
        mo_location = locs[:,3,:].int().cpu().numpy()

        f1_0 = self.conv1(img)
        f1_1 = self.res1(f1_0)
        f2_0 = self.conv2(f1_1)
        f2_1 = self.res2(f2_0)

        f3_0 = self.conv3(f2_1)
        f3_1 = self.res3(f3_0)
        f4_0 = self.conv4(f3_1)
        f4_1 = self.res4(f4_0)

        f5_0 = self.conv5(f4_1)
        f5_1 = self.res5(f5_0)
        f6_0 = self.conv6(f5_1)
        f6_1 = self.res6(f6_0)

        # (resolution tag, feature map, downscale factor relative to the box coordinates)
        scales = ((256, f2_1, 2), (128, f4_1, 4), (64, f6_1, 8))
        # (key prefix, boxes, index into part_sizes)
        parts = (('le', le_location, 0), ('re', re_location, 1), ('mo', mo_location, 3))

        out = {'f256': f2_1, 'f128': f4_1, 'f64': f6_1}
        for res, fmap, down in scales:
            for name, location, size_idx in parts:
                out[f'{name}{res}'] = roi_align_self(fmap, location // down, self.part_sizes[size_idx] // down)
        for res, _, _ in scales:
            for name, _, _ in parts:
                query_net = getattr(self, f'{name.upper()}_{res}_Q')
                out[f'{name}_{res}_q'] = query_net(out[f'{name}{res}'])
        return out
| |
|
| |
|
| | class DMDNet(nn.Module): |
def __init__(self, ngf = 64, banks_num = 128):
    """Build the DMDNet sub-modules and the generic memory buffers.

    Args:
        ngf: base channel width; the three scales use ``ngf``, ``2*ngf`` and ``4*ngf`` channels.
        banks_num: number of entries in every generic memory buffer.

    With the defaults, module names and buffer shapes are identical to the
    previous hard-coded layout (64/128/256 channels, 128 memory entries).
    """
    super().__init__()
    self.part_sizes = np.array([80,80,50,110])
    self.feature_sizes = np.array([256,128,64])

    self.banks_num = banks_num
    self.key_scale = 4

    self.E_lq = FeatureExtractor(ngf = ngf, key_scale = self.key_scale)
    self.E_hq = FeatureExtractor(ngf = ngf, key_scale = self.key_scale)

    parts = ('LE', 'RE', 'MO')
    # (resolution tag, channel width, downscale factor of the crop sizes)
    scales = ((256, ngf, 2), (128, ngf * 2, 4), (64, ngf * 4, 8))

    # Registers LE_256_KV ... MO_64_KV, then *_Attention, then *_Mask, in that order.
    for res, channels, _ in scales:
        for part in parts:
            setattr(self, f'{part}_{res}_KV', KeyValue(channels, channels // self.key_scale, channels))
    for res, channels, _ in scales:
        for part in parts:
            setattr(self, f'{part}_{res}_Attention', AttentionBlock(channels))
    for res, channels, _ in scales:
        for part in parts:
            setattr(self, f'{part}_{res}_Mask', MaskAttention(channels))

    self.MSDilate = MSDilateBlock(ngf*4, dilation = [4,3,2,1])

    self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False)
    self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False)
    self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False)
    self.up4 = nn.Sequential(
        SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
        nn.LeakyReLU(0.2),
        UpResBlock(ngf),
        UpResBlock(ngf),
        SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
        nn.Tanh()
    )

    # Generic memory buffers: <part>_<res>_mem_key / <part>_<res>_mem_value.
    # Spatial sizes follow the crop sizes (part_sizes // downscale).
    for res, channels, down in scales:
        sides = (
            ('le', int(self.part_sizes[0] // down)),
            ('re', int(self.part_sizes[1] // down)),
            ('mo', int(self.part_sizes[3] // down)),
        )
        for kind, depth in (('key', channels // self.key_scale), ('value', channels)):
            for part, side in sides:
                self.register_buffer(f'{part}_{res}_mem_{kind}', torch.randn(banks_num, depth, side, side))
| |
|
| | |
def readMem(self, k, v, q):
    """Read from a memory bank with softmax attention.

    ``k`` is used as the convolution weight against query ``q``; the resulting
    scores are soft-maxed over dim 1 (after division by the square root of
    that dim's size) and used to blend the flattened rows of ``v``.

    Returns:
        ``(mem_out, max_inds)``: the blended value reshaped to
        ``(batch, *v.shape[1:])`` and the squeezed argmax of the scores.
    """
    similarity = F.conv2d(q, k)
    weights = F.softmax(similarity / sqrt(similarity.size(1)), dim=1)
    batch, _, _, _ = weights.size()
    slots, channels, width, height = v.size()
    flat_weights = weights.view(batch, -1).unsqueeze(1)
    flat_values = v.view(slots, -1).repeat(batch, 1, 1)
    blended = torch.bmm(flat_weights, flat_values).squeeze(1)
    mem_out = blended.view(batch, channels, width, height)
    max_inds = torch.argmax(weights, dim=1).squeeze()
    return mem_out, max_inds
| | |
| |
|
def memorize(self, img, locs):
    """Encode ``img`` with ``E_hq`` and build key/value memories for every component crop.

    Returns:
        ``(Mem256, Mem128, Mem64)``: one dict per scale with the entries
        ``<PART><res>Key`` and ``<PART><res>Value`` for PART in LE, RE, MO.
    """
    fs = self.E_hq(img, locs)
    memories = []
    for res in (256, 128, 64):
        bank = {}
        for part in ('LE', 'RE', 'MO'):
            key_value_net = getattr(self, f'{part}_{res}_KV')
            key, value = key_value_net(fs[f'{part.lower()}{res}'])
            bank[f'{part}{res}Key'] = key
            bank[f'{part}{res}Value'] = value
        memories.append(bank)
    Mem256, Mem128, Mem64 = memories
    return Mem256, Mem128, Mem64
| |
|
def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
    """Read the memories for every component crop and normalize them to the crop's statistics.

    For each scale (256, 128, 64) and part (LE, RE, MO) the generic memory is
    read with the crop's query. When all three specific memories are given,
    the specific memory is read as well and blended with the generic one via
    the part's mask module. The result is passed through
    ``adaptive_instance_normalization_4D`` with the crop as style.

    Returns:
        ``(EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64)``: dicts keyed
        ``<PART><res>Norm`` with the normalized memories, and dicts keyed
        ``LE``/``RE``/``MO`` with the indices returned by the generic reads.
    """
    specific = {256: sp_256, 128: sp_128, 64: sp_64}
    use_specific = sp_256 is not None and sp_128 is not None and sp_64 is not None

    normed = {}
    indices = {}
    for res in (256, 128, 64):
        normed[res] = {}
        indices[res] = {}
        for part in ('le', 're', 'mo'):
            tag = part.upper()
            query = fs_in[f'{part}_{res}_q']
            crop = fs_in[f'{part}{res}']

            generic, inds = self.readMem(
                getattr(self, f'{part}_{res}_mem_key'),
                getattr(self, f'{part}_{res}_mem_value'),
                query,
            )
            if use_specific:
                bank = specific[res]
                personal, _ = self.readMem(bank[f'{tag}{res}Key'], bank[f'{tag}{res}Value'], query)
                mask = getattr(self, f'{tag}_{res}_Mask')(crop, personal, generic)
                memory = mask*personal + (1-mask)*generic
            else:
                memory = generic

            normed[res][f'{tag}{res}Norm'] = adaptive_instance_normalization_4D(memory, crop)
            indices[res][tag] = inds

    return normed[256], normed[128], normed[64], indices[256], indices[128], indices[64]
| |
|
def reconstruct(self, fs_in, locs, memstar):
    """Fuse normalized memory features into the feature maps and decode.

    For each resolution tag (256, 128, 64) and each component ('le', 're',
    'mo' -- presumably left eye, right eye and mouth), the fused feature is

        Attention(mem_norm - feat) * mem_norm + feat

    where ``Attention`` is ``self.<LE|RE|MO>_<res>_Attention``. The fused
    feature is resized bilinearly to the component's box and written into a
    copy of the matching full feature map, which is then fed to the
    ``up1``..``up4`` decoder chain.

    Args:
        fs_in: mapping with the full feature maps ``'f256'``, ``'f128'``,
            ``'f64'`` and the component features ``'le256'`` ... ``'mo64'``.
        locs: tensor indexed as ``locs[:, row, :]``; rows 0, 1 and 3 are
            read as the 'le', 're' and 'mo' boxes. Each box is used as
            ``(col_start, row_start, col_end, row_end)`` and is divided by
            2, 4 and 8 for the 256, 128 and 64 maps respectively.
        memstar: sequence of three mappings (for 256, 128, 64) holding the
            normalized memory features under ``'LE256Norm'`` ... ``'MO64Norm'``.

    Returns:
        Whatever ``self.up4`` returns for the decoded features.
    """
    # (resolution tag, divisor mapping box coordinates onto that feature map)
    scales = ((256, 2), (128, 4), (64, 8))
    # Row of ``locs`` holding each component's box; row 2 is not used here.
    # Components are pasted in this order, so on overlap the later one wins.
    loc_rows = (('le', 0), ('re', 1), ('mo', 3))

    # Convert once to plain Python ints so slice bounds and interpolation
    # sizes are ordinary integers instead of 0-dim tensors.
    boxes = {
        part: locs[:, row, :].cpu().int().tolist() for part, row in loc_rows
    }
    batch_size = fs_in['f256'].size(0)

    canvases = {}
    for idx, (res, divisor) in enumerate(scales):
        mem = memstar[idx]
        # Work on a copy so the caller's feature maps are left untouched.
        canvas = fs_in[f'f{res}'].clone()
        for part, _ in loc_rows:
            tag = part.upper()
            mem_norm = mem[f'{tag}{res}Norm']
            feat = fs_in[f'{part}{res}']
            attention = getattr(self, f'{tag}_{res}_Attention')
            fused = attention(mem_norm - feat) * mem_norm + feat

            for i in range(batch_size):
                col0, row0, col1, row1 = (
                    boxes[part][i][k] // divisor for k in range(4)
                )
                height, width = row1 - row0, col1 - col0
                if height <= 0 or width <= 0:
                    # The box collapsed after integer division; there is no
                    # region to fill and F.interpolate rejects empty sizes.
                    continue
                canvas[i:i + 1, :, row0:row1, col0:col1] = F.interpolate(
                    fused[i:i + 1],
                    (height, width),
                    mode='bilinear',
                    align_corners=False,
                )
        canvases[res] = canvas

    ms_in_64 = self.MSDilate(fs_in['f64'].clone())
    fea_up1 = self.up1(ms_in_64, canvases[64])
    fea_up2 = self.up2(fea_up1, canvases[128])
    fea_up3 = self.up3(fea_up2, canvases[256])
    output = self.up4(fea_up3)
    return output
| |
|
def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
    """Return ``self.memorize(sp_imgs, sp_locs)``.

    Thin public alias: both arguments are forwarded positionally and the
    result of ``memorize`` is handed back unchanged.
    """
    specific_dictionary = self.memorize(sp_imgs, sp_locs)
    return specific_dictionary
| |
|
def forward(self, lq=None, loc=None, sp_256 = None, sp_128 = None, sp_64 = None):
    """Run the generic restoration pass and, optionally, the specific one.

    Args:
        lq: input passed to ``self.E_lq`` together with ``loc``.
        loc: component locations; also passed to ``self.reconstruct``.
        sp_256, sp_128, sp_64: optional specific memories forwarded to
            ``self.enhancer``. The specific pass only runs when all three
            are provided.

    Returns:
        ``(GeOut, GSOut)``: the output reconstructed without the specific
        memories, and the output reconstructed with them (``None`` when any
        of ``sp_256`` / ``sp_128`` / ``sp_64`` is missing).

    Raises:
        Any exception raised by ``self.E_lq`` propagates unchanged. It used
        to be printed and swallowed, which only deferred the failure to an
        unrelated ``UnboundLocalError`` on ``fs_in`` one line later.
    """
    fs_in = self.E_lq(lq, loc)

    GeMemNorm256, GeMemNorm128, GeMemNorm64, _, _, _ = self.enhancer(fs_in)
    GeOut = self.reconstruct(fs_in, loc, memstar = [GeMemNorm256, GeMemNorm128, GeMemNorm64])

    if sp_256 is None or sp_128 is None or sp_64 is None:
        return GeOut, None

    GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
    GSOut = self.reconstruct(fs_in, loc, memstar = [GSMemNorm256, GSMemNorm128, GSMemNorm64])
    return GeOut, GSOut
| |
|
class UpResBlock(nn.Module):
    """Residual block: ``x + conv(LeakyReLU(conv(x)))``.

    Both convolutions map ``dim`` channels to ``dim`` channels with kernel 3,
    stride 1 and padding 1, and are wrapped in ``SpectralNorm``. The
    activation is ``LeakyReLU(0.2)``.

    ``norm_layer`` is accepted but not used by this block.
    """

    def __init__(self, dim, conv_layer = nn.Conv2d, norm_layer = nn.BatchNorm2d):
        super().__init__()
        # Build the layers in forward order so parameter initialisation
        # consumes the random generator in the same sequence as before.
        first_conv = SpectralNorm(conv_layer(dim, dim, 3, 1, 1))
        activation = nn.LeakyReLU(0.2)
        second_conv = SpectralNorm(conv_layer(dim, dim, 3, 1, 1))
        self.Model = nn.Sequential(first_conv, activation, second_conv)

    def forward(self, x):
        """Return the input plus the output of the two-convolution branch."""
        residual = self.Model(x)
        return x + residual
| |
|