msdhon393 committed on
Commit 2598287 · verified · 1 Parent(s): cc6f2bd

Upload 9 files

roop/processors/Enhance_CodeFormer.py ADDED
@@ -0,0 +1,70 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+# THREAD_LOCK = threading.Lock()
+
+
+class Enhance_CodeFormer():
+    model_codeformer = None
+    devicename = None
+
+    processorname = 'codeformer'
+    type = 'enhance'
+
+
+    def Initialize(self, devicename: str):
+        if self.model_codeformer is None:
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+            model_path = resolve_relative_path('../models/CodeFormer/CodeFormerv0.1.onnx')
+            self.model_codeformer = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
+            self.model_inputs = self.model_codeformer.get_inputs()
+            model_outputs = self.model_codeformer.get_outputs()
+            self.io_binding = self.model_codeformer.io_binding()
+            self.io_binding.bind_cpu_input(self.model_inputs[1].name, np.array([0.5]))  # CodeFormer fidelity weight w (0 = quality, 1 = fidelity)
+            self.io_binding.bind_output(model_outputs[0].name, self.devicename)
+
+
+    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
+        input_size = temp_frame.shape[1]
+        # preprocess
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype('float32') / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame.astype(np.float32))
+        self.model_codeformer.run_with_iobinding(self.io_binding)
+        ort_outs = self.io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+        del ort_outs
+
+        # post-process
+        result = result.transpose((1, 2, 0))
+
+        un_min = -1.0
+        un_max = 1.0
+        result = np.clip(result, un_min, un_max)
+        result = (result - un_min) / (un_max - un_min)
+
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        result = (result * 255.0).round()
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+
+    def Release(self):
+        del self.model_codeformer
+        self.model_codeformer = None
+        del self.io_binding
+        self.io_binding = None
+
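All of the enhancer processors added in this commit share the same Initialize/Run/Release lifecycle and return a tuple of (enhanced BGR frame, integer scale factor). A minimal usage sketch, not part of this commit (the file names and the CPU device choice are assumptions, and roop.globals.execution_providers must already be configured):

    import cv2
    from roop.processors.Enhance_CodeFormer import Enhance_CodeFormer

    enhancer = Enhance_CodeFormer()
    enhancer.Initialize('cpu')                   # 'mps' is mapped to 'cpu' internally
    face_crop = cv2.imread('face_crop.png')      # assumed: an aligned BGR face crop
    enhanced, scale = enhancer.Run(None, None, face_crop)  # faceset/face args are unused by this processor
    cv2.imwrite('face_enhanced.png', enhanced)   # uint8 BGR, upscaled by `scale`
    enhancer.Release()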
roop/processors/Enhance_DMDNet.py ADDED
@@ -0,0 +1,893 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.utils.spectral_norm as SpectralNorm
+import threading
+from torchvision.ops import roi_align
+
+from math import sqrt
+
+from torchvision.transforms.functional import normalize
+
+from roop.typing import Face, Frame, FaceSet
+
+
+THREAD_LOCK_DMDNET = threading.Lock()
+
+
+class Enhance_DMDNet():
+
+    model_dmdnet = None
+    torchdevice = None
+
+    processorname = 'dmdnet'
+    type = 'enhance'
+
+
+    def Initialize(self, devicename):
+        if self.model_dmdnet is None:
+            self.model_dmdnet = self.create(devicename)
+
+
+    # temp_frame is already cropped+aligned, bbox is not
+    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
+        input_size = temp_frame.shape[1]
+
+        result = self.enhance_face(source_faceset, temp_frame, target_face)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+
+    def Release(self):
+        self.model_dmdnet = None
+
+
+    # https://stackoverflow.com/a/67174339
+    def landmarks106_to_68(self, pt106):
+        map106to68 = [1,10,12,14,16,3,5,7,0,23,21,19,32,30,28,26,17,
+                      43,48,49,51,50,
+                      102,103,104,105,101,
+                      72,73,74,86,78,79,80,85,84,
+                      35,41,42,39,37,36,
+                      89,95,96,93,91,90,
+                      52,64,63,71,67,68,61,58,59,53,56,55,65,66,62,70,69,57,60,54
+                      ]
+
+        pt68 = []
+        for i in range(68):
+            index = map106to68[i]
+            pt68.append(pt106[index])
+        return pt68
+
+
+    def check_bbox(self, imgs, boxes):
+        boxes = boxes.view(-1, 4, 4)
+        colors = [(0, 255, 0), (0, 255, 0), (255, 255, 0), (255, 0, 0)]
+        i = 0
+        for img, box in zip(imgs, boxes):
+            img = (img + 1)/2 * 255
+            img2 = img.permute(1, 2, 0).float().cpu().flip(2).numpy().copy()
+            for idx, point in enumerate(box):
+                cv2.rectangle(img2, (int(point[0]), int(point[1])), (int(point[2]), int(point[3])), color=colors[idx], thickness=2)
+            cv2.imwrite('dmdnet_{:02d}.png'.format(i), img2)
+            i += 1
+
+
+    def trans_points2d(self, pts, M):
+        new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
+        for i in range(pts.shape[0]):
+            pt = pts[i]
+            new_pt = np.array([pt[0], pt[1], 1.0], dtype=np.float32)
+            new_pt = np.dot(M, new_pt)
+            new_pts[i] = new_pt[0:2]
+
+        return new_pts
+
+
+    def enhance_face(self, ref_faceset: FaceSet, temp_frame, face: Face):
+        # preprocess
+        start_x, start_y, end_x, end_y = map(int, face['bbox'])
+        lm106 = face.landmark_2d_106
+        lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
+
+        if temp_frame.shape[0] != 512 or temp_frame.shape[1] != 512:
+            # scale to 512x512
+            scale_factor = 512 / temp_frame.shape[1]
+
+            M = face.matrix * scale_factor
+
+            lq_landmarks = self.trans_points2d(lq_landmarks, M)
+            temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_AREA)
+
+        if temp_frame.ndim == 2:
+            temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_GRAY2RGB) # GGG
+        # else:
+        #     temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB) # RGB
+
+        lq = read_img_tensor(temp_frame)
+
+        LQLocs = get_component_location(lq_landmarks)
+        # self.check_bbox(lq, LQLocs.unsqueeze(0))
+
+        # specific restoration: use the reference dictionary when more than one reference face is available
+        if len(ref_faceset.faces) > 1:
+            SpecificImgs = []
+            SpecificLocs = []
+            for i, face in enumerate(ref_faceset.faces):
+                lm106 = face.landmark_2d_106
+                lq_landmarks = np.asarray(self.landmarks106_to_68(lm106))
+                ref_image = ref_faceset.ref_images[i]
+                if ref_image.shape[0] != 512 or ref_image.shape[1] != 512:
+                    # scale to 512x512
+                    scale_factor = 512 / ref_image.shape[1]
+
+                    M = face.matrix * scale_factor
+
+                    lq_landmarks = self.trans_points2d(lq_landmarks, M)
+                    ref_image = cv2.resize(ref_image, (512, 512), interpolation=cv2.INTER_AREA)
+
+                if ref_image.ndim == 2:
+                    ref_image = cv2.cvtColor(ref_image, cv2.COLOR_GRAY2RGB) # GGG
+                # else:
+                #     ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) # RGB
+
+                ref_tensor = read_img_tensor(ref_image)
+                ref_locs = get_component_location(lq_landmarks)
+                # self.check_bbox(ref_tensor, ref_locs.unsqueeze(0))
+
+                SpecificImgs.append(ref_tensor)
+                SpecificLocs.append(ref_locs.unsqueeze(0))
+
+            SpecificImgs = torch.cat(SpecificImgs, dim=0)
+            SpecificLocs = torch.cat(SpecificLocs, dim=0)
+            # check_bbox(SpecificImgs, SpecificLocs)
+            SpMem256, SpMem128, SpMem64 = self.model_dmdnet.generate_specific_dictionary(sp_imgs=SpecificImgs.to(self.torchdevice), sp_locs=SpecificLocs)
+            SpMem256Para = {}
+            SpMem128Para = {}
+            SpMem64Para = {}
+            for k, v in SpMem256.items():
+                SpMem256Para[k] = v
+            for k, v in SpMem128.items():
+                SpMem128Para[k] = v
+            for k, v in SpMem64.items():
+                SpMem64Para[k] = v
+        else:
+            # generic
+            SpMem256Para, SpMem128Para, SpMem64Para = None, None, None
+
+        with torch.no_grad():
+            with THREAD_LOCK_DMDNET:
+                try:
+                    GenericResult, SpecificResult = self.model_dmdnet(lq=lq.to(self.torchdevice), loc=LQLocs.unsqueeze(0), sp_256=SpMem256Para, sp_128=SpMem128Para, sp_64=SpMem64Para)
+                except Exception as e:
+                    print(f'Error {e}: there may be something wrong with the detected component locations.')
+                    return temp_frame
+
+        if SpecificResult is not None:
+            save_specific = SpecificResult * 0.5 + 0.5
+            save_specific = save_specific.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+            save_specific = np.clip(save_specific.float().cpu().numpy(), 0, 1) * 255.0
+            temp_frame = save_specific.astype("uint8")
+            if False:  # debug output: write a side-by-side comparison image
+                save_generic = GenericResult * 0.5 + 0.5
+                save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+                save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
+                check_lq = lq * 0.5 + 0.5
+                check_lq = check_lq.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+                check_lq = np.clip(check_lq.float().cpu().numpy(), 0, 1) * 255.0
+                cv2.imwrite('dmdnet_comparison.png', cv2.cvtColor(np.hstack((check_lq, save_generic, save_specific)), cv2.COLOR_RGB2BGR))
+        else:
+            save_generic = GenericResult * 0.5 + 0.5
+            save_generic = save_generic.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
+            save_generic = np.clip(save_generic.float().cpu().numpy(), 0, 1) * 255.0
+            temp_frame = save_generic.astype("uint8")
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_RGB2BGR)
+        return temp_frame
+
+
+    def create(self, devicename):
+        self.torchdevice = torch.device(devicename)
+        model_dmdnet = DMDNet().to(self.torchdevice)
+        weights = torch.load('./models/DMDNet.pth', map_location=self.torchdevice)  # map checkpoint to the active device
+        model_dmdnet.load_state_dict(weights, strict=True)
+
+        model_dmdnet.eval()
+        num_params = 0
+        for param in model_dmdnet.parameters():
+            num_params += param.numel()
+        return model_dmdnet
+
+        # print('{:>8s} : {}'.format('Using device', device))
+        # print('{:>8s} : {:.2f}M'.format('Model params', num_params/1e6))
+
+
+
+def read_img_tensor(Img=None): # rgb -1~1
+    Img = Img.transpose((2, 0, 1))/255.0
+    Img = torch.from_numpy(Img).float()
+    normalize(Img, [0.5,0.5,0.5], [0.5,0.5,0.5], inplace=True)
+    ImgTensor = Img.unsqueeze(0)
+    return ImgTensor
+
+
+def get_component_location(Landmarks, re_read=False):
+    if re_read:
+        ReadLandmark = []
+        with open(Landmarks, 'r') as f:
+            for line in f:
+                tmp = [float(i) for i in line.split(' ') if i != '\n']
+                ReadLandmark.append(tmp)
+        ReadLandmark = np.array(ReadLandmark)
+        Landmarks = np.reshape(ReadLandmark, [-1, 2]) # 68*2
+    Map_LE_B = list(np.hstack((range(17,22), range(36,42))))
+    Map_RE_B = list(np.hstack((range(22,27), range(42,48))))
+    Map_LE = list(range(36,42))
+    Map_RE = list(range(42,48))
+    Map_NO = list(range(29,36))
+    Map_MO = list(range(48,68))
+
+    Landmarks[Landmarks > 504] = 504
+    Landmarks[Landmarks < 8] = 8
+
+    # left eye
+    Mean_LE = np.mean(Landmarks[Map_LE], 0)
+    L_LE1 = Mean_LE[1] - np.min(Landmarks[Map_LE_B, 1])
+    L_LE1 = L_LE1 * 1.3
+    L_LE2 = L_LE1 / 1.9
+    L_LE_xy = L_LE1 + L_LE2
+    L_LE_lt = [L_LE_xy/2, L_LE1]
+    L_LE_rb = [L_LE_xy/2, L_LE2]
+    Location_LE = np.hstack((Mean_LE - L_LE_lt + 1, Mean_LE + L_LE_rb)).astype(int)
+
+    # right eye
+    Mean_RE = np.mean(Landmarks[Map_RE], 0)
+    L_RE1 = Mean_RE[1] - np.min(Landmarks[Map_RE_B, 1])
+    L_RE1 = L_RE1 * 1.3
+    L_RE2 = L_RE1 / 1.9
+    L_RE_xy = L_RE1 + L_RE2
+    L_RE_lt = [L_RE_xy/2, L_RE1]
+    L_RE_rb = [L_RE_xy/2, L_RE2]
+    Location_RE = np.hstack((Mean_RE - L_RE_lt + 1, Mean_RE + L_RE_rb)).astype(int)
+
+    # nose
+    Mean_NO = np.mean(Landmarks[Map_NO], 0)
+    L_NO1 = (np.max([Mean_NO[0] - Landmarks[31][0], Landmarks[35][0] - Mean_NO[0]])) * 1.25
+    L_NO2 = (Landmarks[33][1] - Mean_NO[1]) * 1.1
+    L_NO_xy = L_NO1 * 2
+    L_NO_lt = [L_NO_xy/2, L_NO_xy - L_NO2]
+    L_NO_rb = [L_NO_xy/2, L_NO2]
+    Location_NO = np.hstack((Mean_NO - L_NO_lt + 1, Mean_NO + L_NO_rb)).astype(int)
+
+    # mouth
+    Mean_MO = np.mean(Landmarks[Map_MO], 0)
+    L_MO = np.max((np.max(np.max(Landmarks[Map_MO], 0) - np.min(Landmarks[Map_MO], 0))/2, 16)) * 1.1
+    MO_O = Mean_MO - L_MO + 1
+    MO_T = Mean_MO + L_MO
+    MO_T[MO_T > 510] = 510
+    Location_MO = np.hstack((MO_O, MO_T)).astype(int)
+    return torch.cat([torch.FloatTensor(Location_LE).unsqueeze(0), torch.FloatTensor(Location_RE).unsqueeze(0), torch.FloatTensor(Location_NO).unsqueeze(0), torch.FloatTensor(Location_MO).unsqueeze(0)], dim=0)
+
+
+def calc_mean_std_4D(feat, eps=1e-5):
+    # eps is a small value added to the variance to avoid divide-by-zero.
+    size = feat.size()
+    assert (len(size) == 4)
+    N, C = size[:2]
+    feat_var = feat.view(N, C, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(N, C, 1, 1)
+    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+    return feat_mean, feat_std
+
+def adaptive_instance_normalization_4D(content_feat, style_feat): # content_feat is the reference feature, style_feat is the degraded feature
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std_4D(style_feat)
+    content_mean, content_std = calc_mean_std_4D(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def convU(in_channels, out_channels, conv_layer, norm_layer, kernel_size=3, stride=1, dilation=1, bias=True):
+    return nn.Sequential(
+        SpectralNorm(conv_layer(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(conv_layer(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias)),
+    )
+
+
+class MSDilateBlock(nn.Module):
+    def __init__(self, in_channels, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, kernel_size=3, dilation=[1,1,1,1], bias=True):
+        super(MSDilateBlock, self).__init__()
+        self.conv1 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[0], bias=bias)
+        self.conv2 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[1], bias=bias)
+        self.conv3 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[2], bias=bias)
+        self.conv4 = convU(in_channels, in_channels, conv_layer, norm_layer, kernel_size, dilation=dilation[3], bias=bias)
+        self.convi = SpectralNorm(conv_layer(in_channels*4, in_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=bias))
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        conv2 = self.conv2(x)
+        conv3 = self.conv3(x)
+        conv4 = self.conv4(x)
+        cat = torch.cat([conv1, conv2, conv3, conv4], 1)
+        out = self.convi(cat) + x
+        return out
+
+
+class AdaptiveInstanceNorm(nn.Module):
+    def __init__(self, in_channel):
+        super().__init__()
+        self.norm = nn.InstanceNorm2d(in_channel)
+
+    def forward(self, input, style):
+        style_mean, style_std = calc_mean_std_4D(style)
+        out = self.norm(input)
+        size = input.size()
+        out = style_std.expand(size) * out + style_mean.expand(size)
+        return out
+
+class NoiseInjection(nn.Module):
+    def __init__(self, channel):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
+    def forward(self, image, noise):
+        if noise is None:
+            b, c, h, w = image.shape
+            noise = image.new_empty(b, 1, h, w).normal_()
+        return image + self.weight * noise
+
+class StyledUpBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1, upsample=False, noise_inject=False):
+        super().__init__()
+
+        self.noise_inject = noise_inject
+        if upsample:
+            self.conv1 = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+            )
+        else:
+            self.conv1 = nn.Sequential(
+                SpectralNorm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+                nn.LeakyReLU(0.2),
+                SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            )
+        self.convup = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+        )
+        if self.noise_inject:
+            self.noise1 = NoiseInjection(out_channel)
+
+        self.lrelu1 = nn.LeakyReLU(0.2)
+
+        self.ScaleModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1))
+        )
+        self.ShiftModel1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+        )
+
+    def forward(self, input, style):
+        out = self.conv1(input)
+        out = self.lrelu1(out)
+        Shift1 = self.ShiftModel1(style)
+        Scale1 = self.ScaleModel1(style)
+        out = out * Scale1 + Shift1
+        if self.noise_inject:
+            out = self.noise1(out, noise=None)
+        outup = self.convup(out)
+        return outup
+
+
+####################################################################
+############### Face Dictionary Generator
+####################################################################
+def AttentionBlock(in_channel):
+    return nn.Sequential(
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+        nn.LeakyReLU(0.2),
+        SpectralNorm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+    )
+
+class DilateResBlock(nn.Module):
+    def __init__(self, dim, dilation=[5,3]):
+        super(DilateResBlock, self).__init__()
+        self.Res = nn.Sequential(
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[0], dilation[0])),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(dim, dim, 3, 1, ((3-1)//2)*dilation[1], dilation[1])),
+        )
+    def forward(self, x):
+        out = x + self.Res(x)
+        return out
+
+
+class KeyValue(nn.Module):
+    def __init__(self, indim, keydim, valdim):
+        super(KeyValue, self).__init__()
+        self.Key = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(keydim, keydim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.Value = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(valdim, valdim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Key(x), self.Value(x)
+
+class MaskAttention(nn.Module):
+    def __init__(self, indim):
+        super(MaskAttention, self).__init__()
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim//3, indim//3, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+        self.convCat = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim//3 * 3, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(indim, indim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x, y, z):
+        c1 = self.conv1(x)
+        c2 = self.conv2(y)
+        c3 = self.conv3(z)
+        return self.convCat(torch.cat([c1, c2, c3], dim=1))
+
+class Query(nn.Module):
+    def __init__(self, indim, quedim):
+        super(Query, self).__init__()
+        self.Query = nn.Sequential(
+            SpectralNorm(nn.Conv2d(indim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(quedim, quedim, kernel_size=(3,3), padding=(1,1), stride=1)),
+        )
+    def forward(self, x):
+        return self.Query(x)
+
+def roi_align_self(input, location, target_size):
+    test = (target_size.item(), target_size.item())
+    return torch.cat([F.interpolate(input[i:i+1, :, location[i,1]:location[i,3], location[i,0]:location[i,2]], test, mode='bilinear', align_corners=False) for i in range(input.size(0))], 0)
+
+class FeatureExtractor(nn.Module):
+    def __init__(self, ngf=64, key_scale=4):
+        super().__init__()
+
+        self.key_scale = 4
+        self.part_sizes = np.array([80, 80, 50, 110])
+        self.feature_sizes = np.array([256, 128, 64])
+
+        self.conv1 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(3, ngf, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+        )
+        self.conv2 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1))
+        )
+        self.res1 = DilateResBlock(ngf, [5,3])
+        self.res2 = DilateResBlock(ngf, [5,3])
+
+
+        self.conv3 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf*2, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+        )
+        self.conv4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*2, 3, 1, 1))
+        )
+        self.res3 = DilateResBlock(ngf*2, [3,1])
+        self.res4 = DilateResBlock(ngf*2, [3,1])
+
+        self.conv5 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*2, ngf*4, 3, 2, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+        )
+        self.conv6 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(nn.Conv2d(ngf*4, ngf*4, 3, 1, 1))
+        )
+        self.res5 = DilateResBlock(ngf*4, [1,1])
+        self.res6 = DilateResBlock(ngf*4, [1,1])
+
+        self.LE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.RE_256_Q = Query(ngf, ngf // self.key_scale)
+        self.MO_256_Q = Query(ngf, ngf // self.key_scale)
+        self.LE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.RE_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.MO_128_Q = Query(ngf * 2, ngf * 2 // self.key_scale)
+        self.LE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.RE_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+        self.MO_64_Q = Query(ngf * 4, ngf * 4 // self.key_scale)
+
+
+    def forward(self, img, locs):
+        le_location = locs[:,0,:].int().cpu().numpy()
+        re_location = locs[:,1,:].int().cpu().numpy()
+        no_location = locs[:,2,:].int().cpu().numpy()
+        mo_location = locs[:,3,:].int().cpu().numpy()
+
+
+        f1_0 = self.conv1(img)
+        f1_1 = self.res1(f1_0)
+        f2_0 = self.conv2(f1_1)
+        f2_1 = self.res2(f2_0)
+
+        f3_0 = self.conv3(f2_1)
+        f3_1 = self.res3(f3_0)
+        f4_0 = self.conv4(f3_1)
+        f4_1 = self.res4(f4_0)
+
+        f5_0 = self.conv5(f4_1)
+        f5_1 = self.res5(f5_0)
+        f6_0 = self.conv6(f5_1)
+        f6_1 = self.res6(f6_0)
+
+
+        #### ROI align
+        le_part_256 = roi_align_self(f2_1.clone(), le_location//2, self.part_sizes[0]//2)
+        re_part_256 = roi_align_self(f2_1.clone(), re_location//2, self.part_sizes[1]//2)
+        mo_part_256 = roi_align_self(f2_1.clone(), mo_location//2, self.part_sizes[3]//2)
+
+        le_part_128 = roi_align_self(f4_1.clone(), le_location//4, self.part_sizes[0]//4)
+        re_part_128 = roi_align_self(f4_1.clone(), re_location//4, self.part_sizes[1]//4)
+        mo_part_128 = roi_align_self(f4_1.clone(), mo_location//4, self.part_sizes[3]//4)
+
+        le_part_64 = roi_align_self(f6_1.clone(), le_location//8, self.part_sizes[0]//8)
+        re_part_64 = roi_align_self(f6_1.clone(), re_location//8, self.part_sizes[1]//8)
+        mo_part_64 = roi_align_self(f6_1.clone(), mo_location//8, self.part_sizes[3]//8)
+
+
+        le_256_q = self.LE_256_Q(le_part_256)
+        re_256_q = self.RE_256_Q(re_part_256)
+        mo_256_q = self.MO_256_Q(mo_part_256)
+
+        le_128_q = self.LE_128_Q(le_part_128)
+        re_128_q = self.RE_128_Q(re_part_128)
+        mo_128_q = self.MO_128_Q(mo_part_128)
+
+        le_64_q = self.LE_64_Q(le_part_64)
+        re_64_q = self.RE_64_Q(re_part_64)
+        mo_64_q = self.MO_64_Q(mo_part_64)
+
+        return {'f256': f2_1, 'f128': f4_1, 'f64': f6_1,
+                'le256': le_part_256, 're256': re_part_256, 'mo256': mo_part_256,
+                'le128': le_part_128, 're128': re_part_128, 'mo128': mo_part_128,
+                'le64': le_part_64, 're64': re_part_64, 'mo64': mo_part_64,
+                'le_256_q': le_256_q, 're_256_q': re_256_q, 'mo_256_q': mo_256_q,
+                'le_128_q': le_128_q, 're_128_q': re_128_q, 'mo_128_q': mo_128_q,
+                'le_64_q': le_64_q, 're_64_q': re_64_q, 'mo_64_q': mo_64_q}
+
+
+class DMDNet(nn.Module):
+    def __init__(self, ngf=64, banks_num=128):
+        super().__init__()
+        self.part_sizes = np.array([80, 80, 50, 110]) # size for 512
+        self.feature_sizes = np.array([256, 128, 64]) # size for 512
+
+        self.banks_num = banks_num
+        self.key_scale = 4
+
+        self.E_lq = FeatureExtractor(key_scale=self.key_scale)
+        self.E_hq = FeatureExtractor(key_scale=self.key_scale)
+
+        self.LE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.RE_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+        self.MO_256_KV = KeyValue(ngf, ngf // self.key_scale, ngf)
+
+        self.LE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
+        self.RE_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
+        self.MO_128_KV = KeyValue(ngf * 2, ngf * 2 // self.key_scale, ngf * 2)
+
+        self.LE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
+        self.RE_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
+        self.MO_64_KV = KeyValue(ngf * 4, ngf * 4 // self.key_scale, ngf * 4)
+
+
+        self.LE_256_Attention = AttentionBlock(64)
+        self.RE_256_Attention = AttentionBlock(64)
+        self.MO_256_Attention = AttentionBlock(64)
+
+        self.LE_128_Attention = AttentionBlock(128)
+        self.RE_128_Attention = AttentionBlock(128)
+        self.MO_128_Attention = AttentionBlock(128)
+
+        self.LE_64_Attention = AttentionBlock(256)
+        self.RE_64_Attention = AttentionBlock(256)
+        self.MO_64_Attention = AttentionBlock(256)
+
+        self.LE_256_Mask = MaskAttention(64)
+        self.RE_256_Mask = MaskAttention(64)
+        self.MO_256_Mask = MaskAttention(64)
+
+        self.LE_128_Mask = MaskAttention(128)
+        self.RE_128_Mask = MaskAttention(128)
+        self.MO_128_Mask = MaskAttention(128)
+
+        self.LE_64_Mask = MaskAttention(256)
+        self.RE_64_Mask = MaskAttention(256)
+        self.MO_64_Mask = MaskAttention(256)
+
+        self.MSDilate = MSDilateBlock(ngf*4, dilation=[4,3,2,1])
+
+        self.up1 = StyledUpBlock(ngf*4, ngf*2, noise_inject=False)
+        self.up2 = StyledUpBlock(ngf*2, ngf, noise_inject=False)
+        self.up3 = StyledUpBlock(ngf, ngf, noise_inject=False)
+        self.up4 = nn.Sequential(
+            SpectralNorm(nn.Conv2d(ngf, ngf, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            UpResBlock(ngf),
+            UpResBlock(ngf),
+            SpectralNorm(nn.Conv2d(ngf, 3, kernel_size=3, stride=1, padding=1)),
+            nn.Tanh()
+        )
+
+        # define generic memory; revise register_buffer to register_parameter for backward update
+        self.register_buffer('le_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('re_256_mem_key', torch.randn(128,16,40,40))
+        self.register_buffer('mo_256_mem_key', torch.randn(128,16,55,55))
+        self.register_buffer('le_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('re_256_mem_value', torch.randn(128,64,40,40))
+        self.register_buffer('mo_256_mem_value', torch.randn(128,64,55,55))
+
+
+        self.register_buffer('le_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('re_128_mem_key', torch.randn(128,32,20,20))
+        self.register_buffer('mo_128_mem_key', torch.randn(128,32,27,27))
+        self.register_buffer('le_128_mem_value', torch.randn(128,128,20,20))
+        self.register_buffer('re_128_mem_value', torch.randn(128,128,20,20))
+        self.register_buffer('mo_128_mem_value', torch.randn(128,128,27,27))
+
+        self.register_buffer('le_64_mem_key', torch.randn(128,64,10,10))
+        self.register_buffer('re_64_mem_key', torch.randn(128,64,10,10))
+        self.register_buffer('mo_64_mem_key', torch.randn(128,64,13,13))
+        self.register_buffer('le_64_mem_value', torch.randn(128,256,10,10))
+        self.register_buffer('re_64_mem_value', torch.randn(128,256,10,10))
+        self.register_buffer('mo_64_mem_value', torch.randn(128,256,13,13))
+
+
+    def readMem(self, k, v, q):
+        sim = F.conv2d(q, k)
+        score = F.softmax(sim/sqrt(sim.size(1)), dim=1) # B * S * 1 * 1
+        sb, sn, sw, sh = score.size()
+        s_m = score.view(sb, -1).unsqueeze(1) # B * 1 * M
+        vb, vn, vw, vh = v.size()
+        v_in = v.view(vb, -1).repeat(sb, 1, 1) # B * M * (c*w*h)
+        mem_out = torch.bmm(s_m, v_in).squeeze(1).view(sb, vn, vw, vh)
+        max_inds = torch.argmax(score, dim=1).squeeze()
+        return mem_out, max_inds
+
+
+    def memorize(self, img, locs):
+        fs = self.E_hq(img, locs)
+        LE256_key, LE256_value = self.LE_256_KV(fs['le256'])
+        RE256_key, RE256_value = self.RE_256_KV(fs['re256'])
+        MO256_key, MO256_value = self.MO_256_KV(fs['mo256'])
+
+        LE128_key, LE128_value = self.LE_128_KV(fs['le128'])
+        RE128_key, RE128_value = self.RE_128_KV(fs['re128'])
+        MO128_key, MO128_value = self.MO_128_KV(fs['mo128'])
+
+        LE64_key, LE64_value = self.LE_64_KV(fs['le64'])
+        RE64_key, RE64_value = self.RE_64_KV(fs['re64'])
+        MO64_key, MO64_value = self.MO_64_KV(fs['mo64'])
+
+        Mem256 = {'LE256Key': LE256_key, 'LE256Value': LE256_value, 'RE256Key': RE256_key, 'RE256Value': RE256_value, 'MO256Key': MO256_key, 'MO256Value': MO256_value}
+        Mem128 = {'LE128Key': LE128_key, 'LE128Value': LE128_value, 'RE128Key': RE128_key, 'RE128Value': RE128_value, 'MO128Key': MO128_key, 'MO128Value': MO128_value}
+        Mem64 = {'LE64Key': LE64_key, 'LE64Value': LE64_value, 'RE64Key': RE64_key, 'RE64Value': RE64_value, 'MO64Key': MO64_key, 'MO64Value': MO64_value}
+
+        FS256 = {'LE256F': fs['le256'], 'RE256F': fs['re256'], 'MO256F': fs['mo256']}
+        FS128 = {'LE128F': fs['le128'], 'RE128F': fs['re128'], 'MO128F': fs['mo128']}
+        FS64 = {'LE64F': fs['le64'], 'RE64F': fs['re64'], 'MO64F': fs['mo64']}
+
+        return Mem256, Mem128, Mem64
+
+    def enhancer(self, fs_in, sp_256=None, sp_128=None, sp_64=None):
+        le_256_q = fs_in['le_256_q']
+        re_256_q = fs_in['re_256_q']
+        mo_256_q = fs_in['mo_256_q']
+
+        le_128_q = fs_in['le_128_q']
+        re_128_q = fs_in['re_128_q']
+        mo_128_q = fs_in['mo_128_q']
+
+        le_64_q = fs_in['le_64_q']
+        re_64_q = fs_in['re_64_q']
+        mo_64_q = fs_in['mo_64_q']
+
+
+        #### for 256
+        le_256_mem_g, le_256_inds = self.readMem(self.le_256_mem_key, self.le_256_mem_value, le_256_q)
+        re_256_mem_g, re_256_inds = self.readMem(self.re_256_mem_key, self.re_256_mem_value, re_256_q)
+        mo_256_mem_g, mo_256_inds = self.readMem(self.mo_256_mem_key, self.mo_256_mem_value, mo_256_q)
+
+        le_128_mem_g, le_128_inds = self.readMem(self.le_128_mem_key, self.le_128_mem_value, le_128_q)
+        re_128_mem_g, re_128_inds = self.readMem(self.re_128_mem_key, self.re_128_mem_value, re_128_q)
+        mo_128_mem_g, mo_128_inds = self.readMem(self.mo_128_mem_key, self.mo_128_mem_value, mo_128_q)
+
+        le_64_mem_g, le_64_inds = self.readMem(self.le_64_mem_key, self.le_64_mem_value, le_64_q)
+        re_64_mem_g, re_64_inds = self.readMem(self.re_64_mem_key, self.re_64_mem_value, re_64_q)
+        mo_64_mem_g, mo_64_inds = self.readMem(self.mo_64_mem_key, self.mo_64_mem_value, mo_64_q)
+
+        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
+            le_256_mem_s, _ = self.readMem(sp_256['LE256Key'], sp_256['LE256Value'], le_256_q)
+            re_256_mem_s, _ = self.readMem(sp_256['RE256Key'], sp_256['RE256Value'], re_256_q)
+            mo_256_mem_s, _ = self.readMem(sp_256['MO256Key'], sp_256['MO256Value'], mo_256_q)
+            le_256_mask = self.LE_256_Mask(fs_in['le256'], le_256_mem_s, le_256_mem_g)
+            le_256_mem = le_256_mask*le_256_mem_s + (1-le_256_mask)*le_256_mem_g
+            re_256_mask = self.RE_256_Mask(fs_in['re256'], re_256_mem_s, re_256_mem_g)
+            re_256_mem = re_256_mask*re_256_mem_s + (1-re_256_mask)*re_256_mem_g
+            mo_256_mask = self.MO_256_Mask(fs_in['mo256'], mo_256_mem_s, mo_256_mem_g)
+            mo_256_mem = mo_256_mask*mo_256_mem_s + (1-mo_256_mask)*mo_256_mem_g
+
+            le_128_mem_s, _ = self.readMem(sp_128['LE128Key'], sp_128['LE128Value'], le_128_q)
+            re_128_mem_s, _ = self.readMem(sp_128['RE128Key'], sp_128['RE128Value'], re_128_q)
+            mo_128_mem_s, _ = self.readMem(sp_128['MO128Key'], sp_128['MO128Value'], mo_128_q)
+            le_128_mask = self.LE_128_Mask(fs_in['le128'], le_128_mem_s, le_128_mem_g)
+            le_128_mem = le_128_mask*le_128_mem_s + (1-le_128_mask)*le_128_mem_g
+            re_128_mask = self.RE_128_Mask(fs_in['re128'], re_128_mem_s, re_128_mem_g)
+            re_128_mem = re_128_mask*re_128_mem_s + (1-re_128_mask)*re_128_mem_g
+            mo_128_mask = self.MO_128_Mask(fs_in['mo128'], mo_128_mem_s, mo_128_mem_g)
+            mo_128_mem = mo_128_mask*mo_128_mem_s + (1-mo_128_mask)*mo_128_mem_g
+
+            le_64_mem_s, _ = self.readMem(sp_64['LE64Key'], sp_64['LE64Value'], le_64_q)
+            re_64_mem_s, _ = self.readMem(sp_64['RE64Key'], sp_64['RE64Value'], re_64_q)
+            mo_64_mem_s, _ = self.readMem(sp_64['MO64Key'], sp_64['MO64Value'], mo_64_q)
+            le_64_mask = self.LE_64_Mask(fs_in['le64'], le_64_mem_s, le_64_mem_g)
+            le_64_mem = le_64_mask*le_64_mem_s + (1-le_64_mask)*le_64_mem_g
+            re_64_mask = self.RE_64_Mask(fs_in['re64'], re_64_mem_s, re_64_mem_g)
+            re_64_mem = re_64_mask*re_64_mem_s + (1-re_64_mask)*re_64_mem_g
+            mo_64_mask = self.MO_64_Mask(fs_in['mo64'], mo_64_mem_s, mo_64_mem_g)
+            mo_64_mem = mo_64_mask*mo_64_mem_s + (1-mo_64_mask)*mo_64_mem_g
+        else:
+            le_256_mem = le_256_mem_g
+            re_256_mem = re_256_mem_g
+            mo_256_mem = mo_256_mem_g
+            le_128_mem = le_128_mem_g
+            re_128_mem = re_128_mem_g
+            mo_128_mem = mo_128_mem_g
+            le_64_mem = le_64_mem_g
+            re_64_mem = re_64_mem_g
+            mo_64_mem = mo_64_mem_g
+
+        le_256_mem_norm = adaptive_instance_normalization_4D(le_256_mem, fs_in['le256'])
+        re_256_mem_norm = adaptive_instance_normalization_4D(re_256_mem, fs_in['re256'])
+        mo_256_mem_norm = adaptive_instance_normalization_4D(mo_256_mem, fs_in['mo256'])
+
+        #### for 128
+        le_128_mem_norm = adaptive_instance_normalization_4D(le_128_mem, fs_in['le128'])
+        re_128_mem_norm = adaptive_instance_normalization_4D(re_128_mem, fs_in['re128'])
+        mo_128_mem_norm = adaptive_instance_normalization_4D(mo_128_mem, fs_in['mo128'])
+
+        #### for 64
+        le_64_mem_norm = adaptive_instance_normalization_4D(le_64_mem, fs_in['le64'])
+        re_64_mem_norm = adaptive_instance_normalization_4D(re_64_mem, fs_in['re64'])
+        mo_64_mem_norm = adaptive_instance_normalization_4D(mo_64_mem, fs_in['mo64'])
+
+
+        EnMem256 = {'LE256Norm': le_256_mem_norm, 'RE256Norm': re_256_mem_norm, 'MO256Norm': mo_256_mem_norm}
+        EnMem128 = {'LE128Norm': le_128_mem_norm, 'RE128Norm': re_128_mem_norm, 'MO128Norm': mo_128_mem_norm}
+        EnMem64 = {'LE64Norm': le_64_mem_norm, 'RE64Norm': re_64_mem_norm, 'MO64Norm': mo_64_mem_norm}
+        Ind256 = {'LE': le_256_inds, 'RE': re_256_inds, 'MO': mo_256_inds}
+        Ind128 = {'LE': le_128_inds, 'RE': re_128_inds, 'MO': mo_128_inds}
+        Ind64 = {'LE': le_64_inds, 'RE': re_64_inds, 'MO': mo_64_inds}
+        return EnMem256, EnMem128, EnMem64, Ind256, Ind128, Ind64
+
+    def reconstruct(self, fs_in, locs, memstar):
+        le_256_mem_norm, re_256_mem_norm, mo_256_mem_norm = memstar[0]['LE256Norm'], memstar[0]['RE256Norm'], memstar[0]['MO256Norm']
+        le_128_mem_norm, re_128_mem_norm, mo_128_mem_norm = memstar[1]['LE128Norm'], memstar[1]['RE128Norm'], memstar[1]['MO128Norm']
+        le_64_mem_norm, re_64_mem_norm, mo_64_mem_norm = memstar[2]['LE64Norm'], memstar[2]['RE64Norm'], memstar[2]['MO64Norm']
+
+        le_256_final = self.LE_256_Attention(le_256_mem_norm - fs_in['le256']) * le_256_mem_norm + fs_in['le256']
+        re_256_final = self.RE_256_Attention(re_256_mem_norm - fs_in['re256']) * re_256_mem_norm + fs_in['re256']
+        mo_256_final = self.MO_256_Attention(mo_256_mem_norm - fs_in['mo256']) * mo_256_mem_norm + fs_in['mo256']
+
+        le_128_final = self.LE_128_Attention(le_128_mem_norm - fs_in['le128']) * le_128_mem_norm + fs_in['le128']
+        re_128_final = self.RE_128_Attention(re_128_mem_norm - fs_in['re128']) * re_128_mem_norm + fs_in['re128']
+        mo_128_final = self.MO_128_Attention(mo_128_mem_norm - fs_in['mo128']) * mo_128_mem_norm + fs_in['mo128']
+
+        le_64_final = self.LE_64_Attention(le_64_mem_norm - fs_in['le64']) * le_64_mem_norm + fs_in['le64']
+        re_64_final = self.RE_64_Attention(re_64_mem_norm - fs_in['re64']) * re_64_mem_norm + fs_in['re64']
+        mo_64_final = self.MO_64_Attention(mo_64_mem_norm - fs_in['mo64']) * mo_64_mem_norm + fs_in['mo64']
+
+
+        le_location = locs[:,0,:]
+        re_location = locs[:,1,:]
+        mo_location = locs[:,3,:]
+
+        # Recent Torch versions no longer accept the numpy-wrapped locations here
+        # le_location = le_location.cpu().int().numpy()
+        # re_location = re_location.cpu().int().numpy()
+        # mo_location = mo_location.cpu().int().numpy()
+        le_location = le_location.cpu().int()
+        re_location = re_location.cpu().int()
+        mo_location = mo_location.cpu().int()
+
+        up_in_256 = fs_in['f256'].clone() # * 0
+        up_in_128 = fs_in['f128'].clone() # * 0
+        up_in_64 = fs_in['f64'].clone() # * 0
+
+        for i in range(fs_in['f256'].size(0)):
+            up_in_256[i:i+1, :, le_location[i,1]//2:le_location[i,3]//2, le_location[i,0]//2:le_location[i,2]//2] = F.interpolate(le_256_final[i:i+1,:,:,:].clone(), (le_location[i,3]//2-le_location[i,1]//2, le_location[i,2]//2-le_location[i,0]//2), mode='bilinear', align_corners=False)
+            up_in_256[i:i+1, :, re_location[i,1]//2:re_location[i,3]//2, re_location[i,0]//2:re_location[i,2]//2] = F.interpolate(re_256_final[i:i+1,:,:,:].clone(), (re_location[i,3]//2-re_location[i,1]//2, re_location[i,2]//2-re_location[i,0]//2), mode='bilinear', align_corners=False)
+            up_in_256[i:i+1, :, mo_location[i,1]//2:mo_location[i,3]//2, mo_location[i,0]//2:mo_location[i,2]//2] = F.interpolate(mo_256_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//2-mo_location[i,1]//2, mo_location[i,2]//2-mo_location[i,0]//2), mode='bilinear', align_corners=False)
+
+            up_in_128[i:i+1, :, le_location[i,1]//4:le_location[i,3]//4, le_location[i,0]//4:le_location[i,2]//4] = F.interpolate(le_128_final[i:i+1,:,:,:].clone(), (le_location[i,3]//4-le_location[i,1]//4, le_location[i,2]//4-le_location[i,0]//4), mode='bilinear', align_corners=False)
+            up_in_128[i:i+1, :, re_location[i,1]//4:re_location[i,3]//4, re_location[i,0]//4:re_location[i,2]//4] = F.interpolate(re_128_final[i:i+1,:,:,:].clone(), (re_location[i,3]//4-re_location[i,1]//4, re_location[i,2]//4-re_location[i,0]//4), mode='bilinear', align_corners=False)
+            up_in_128[i:i+1, :, mo_location[i,1]//4:mo_location[i,3]//4, mo_location[i,0]//4:mo_location[i,2]//4] = F.interpolate(mo_128_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//4-mo_location[i,1]//4, mo_location[i,2]//4-mo_location[i,0]//4), mode='bilinear', align_corners=False)
+
+            up_in_64[i:i+1, :, le_location[i,1]//8:le_location[i,3]//8, le_location[i,0]//8:le_location[i,2]//8] = F.interpolate(le_64_final[i:i+1,:,:,:].clone(), (le_location[i,3]//8-le_location[i,1]//8, le_location[i,2]//8-le_location[i,0]//8), mode='bilinear', align_corners=False)
+            up_in_64[i:i+1, :, re_location[i,1]//8:re_location[i,3]//8, re_location[i,0]//8:re_location[i,2]//8] = F.interpolate(re_64_final[i:i+1,:,:,:].clone(), (re_location[i,3]//8-re_location[i,1]//8, re_location[i,2]//8-re_location[i,0]//8), mode='bilinear', align_corners=False)
+            up_in_64[i:i+1, :, mo_location[i,1]//8:mo_location[i,3]//8, mo_location[i,0]//8:mo_location[i,2]//8] = F.interpolate(mo_64_final[i:i+1,:,:,:].clone(), (mo_location[i,3]//8-mo_location[i,1]//8, mo_location[i,2]//8-mo_location[i,0]//8), mode='bilinear', align_corners=False)
+
+        ms_in_64 = self.MSDilate(fs_in['f64'].clone())
+        fea_up1 = self.up1(ms_in_64, up_in_64)
+        fea_up2 = self.up2(fea_up1, up_in_128)
+        fea_up3 = self.up3(fea_up2, up_in_256)
+        output = self.up4(fea_up3)
+        return output
+
+    def generate_specific_dictionary(self, sp_imgs=None, sp_locs=None):
+        return self.memorize(sp_imgs, sp_locs)
+
+    def forward(self, lq=None, loc=None, sp_256=None, sp_128=None, sp_64=None):
+        try:
+            fs_in = self.E_lq(lq, loc) # low quality images
+        except Exception as e:
+            print(e)
+
+        GeMemNorm256, GeMemNorm128, GeMemNorm64, Ind256, Ind128, Ind64 = self.enhancer(fs_in)
+        GeOut = self.reconstruct(fs_in, loc, memstar=[GeMemNorm256, GeMemNorm128, GeMemNorm64])
+        if sp_256 is not None and sp_128 is not None and sp_64 is not None:
+            GSMemNorm256, GSMemNorm128, GSMemNorm64, _, _, _ = self.enhancer(fs_in, sp_256, sp_128, sp_64)
+            GSOut = self.reconstruct(fs_in, loc, memstar=[GSMemNorm256, GSMemNorm128, GSMemNorm64])
+        else:
+            GSOut = None
+        return GeOut, GSOut
+
+class UpResBlock(nn.Module):
+    def __init__(self, dim, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d):
+        super(UpResBlock, self).__init__()
+        self.Model = nn.Sequential(
+            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+            nn.LeakyReLU(0.2),
+            SpectralNorm(conv_layer(dim, dim, 3, 1, 1)),
+        )
+    def forward(self, x):
+        out = x + self.Model(x)
+        return out
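The generic and specific dictionaries above are read through readMem, which is a soft attention over memory slots: each query map is convolved against every stored key (producing one similarity per slot), the similarities are softmax-normalized, and the stored value maps are blended with those weights. A standalone sketch with toy sizes (the dimensions are illustrative; the model itself uses 128 slots and the component sizes registered above):

    import torch
    import torch.nn.functional as F
    from math import sqrt

    slots, key_c, val_c, s = 8, 4, 16, 5      # toy sizes, not the model's
    k = torch.randn(slots, key_c, s, s)       # memory keys, used as conv weights
    v = torch.randn(slots, val_c, s, s)       # memory values
    q = torch.randn(2, key_c, s, s)           # a batch of two queries

    sim = F.conv2d(q, k)                      # (2, slots, 1, 1): one score per slot
    score = F.softmax(sim / sqrt(sim.size(1)), dim=1)
    out = torch.einsum('bs,svhw->bvhw', score.flatten(1), v)  # weighted sum of value maps
    print(out.shape)                          # torch.Size([2, 16, 5, 5])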
roop/processors/Enhance_GFPGAN.py ADDED
@@ -0,0 +1,72 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+# THREAD_LOCK = threading.Lock()
+
+
+class Enhance_GFPGAN():
+
+    model_gfpgan = None
+    name = None
+    devicename = None
+
+    processorname = 'gfpgan'
+    type = 'enhance'
+
+
+    def Initialize(self, devicename):
+        if self.model_gfpgan is None:
+            model_path = resolve_relative_path('../models/GFPGANv1.4.onnx')
+            self.model_gfpgan = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+
+        self.name = self.model_gfpgan.get_inputs()[0].name
+
+    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype('float32') / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        io_binding = self.model_gfpgan.io_binding()
+        io_binding.bind_cpu_input("input", temp_frame)
+        io_binding.bind_output("1288", self.devicename)
+        self.model_gfpgan.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+
+        # post-process
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+
+    def Release(self):
+        self.model_gfpgan = None
+
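Unlike CodeFormer above, GFPGAN binds its input and output by hard-coded tensor names ("input" and "1288"), which are specific to this particular GFPGANv1.4.onnx export. If the export ever changes, the names can be read from the session instead of guessed; a small sketch (the model path follows the Initialize code above):

    import onnxruntime

    session = onnxruntime.InferenceSession('models/GFPGANv1.4.onnx', providers=['CPUExecutionProvider'])
    print([i.name for i in session.get_inputs()])   # expected: ['input']
    print([o.name for o in session.get_outputs()])  # expected: ['1288'] for this export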
roop/processors/Enhance_GPEN.py ADDED
@@ -0,0 +1,58 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+
+class Enhance_GPEN():
+
+    model_gpen = None
+    name = None
+    devicename = None
+
+    processorname = 'gpen'
+    type = 'enhance'
+
+
+    def Initialize(self, devicename):
+        if self.model_gpen is None:
+            model_path = resolve_relative_path('../models/GPEN-BFR-512.onnx')
+            self.model_gpen = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+
+        self.name = self.model_gpen.get_inputs()[0].name
+
+    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype('float32') / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        io_binding = self.model_gpen.io_binding()
+        io_binding.bind_cpu_input("input", temp_frame)
+        io_binding.bind_output("output", self.devicename)
+        self.model_gpen.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+
+        # post-process
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+
+    def Release(self):
+        self.model_gpen = None
roop/processors/Enhance_RestoreFormerPPlus.py ADDED
@@ -0,0 +1,59 @@
+from typing import Any, List, Callable
+import cv2
+import numpy as np
+import onnxruntime
+import roop.globals
+
+from roop.typing import Face, Frame, FaceSet
+from roop.utilities import resolve_relative_path
+
+class Enhance_RestoreFormerPPlus():
+    model_restoreformerpplus = None
+    devicename = None
+    name = None
+
+    processorname = 'restoreformer++'
+    type = 'enhance'
+
+
+    def Initialize(self, devicename: str):
+        if self.model_restoreformerpplus is None:
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+            model_path = resolve_relative_path('../models/restoreformer_plus_plus.onnx')
+            self.model_restoreformerpplus = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
+            self.model_inputs = self.model_restoreformerpplus.get_inputs()
+            model_outputs = self.model_restoreformerpplus.get_outputs()
+            self.io_binding = self.model_restoreformerpplus.io_binding()
+            self.io_binding.bind_output(model_outputs[0].name, self.devicename)
+
+    def Run(self, source_faceset: FaceSet, target_face: Face, temp_frame: Frame) -> Frame:
+        # preprocess
+        input_size = temp_frame.shape[1]
+        temp_frame = cv2.resize(temp_frame, (512, 512), interpolation=cv2.INTER_CUBIC)
+        temp_frame = cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB)
+        temp_frame = temp_frame.astype('float32') / 255.0
+        temp_frame = (temp_frame - 0.5) / 0.5
+        temp_frame = np.expand_dims(temp_frame, axis=0).transpose(0, 3, 1, 2)
+
+        self.io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame) # .astype(np.float32)
+        self.model_restoreformerpplus.run_with_iobinding(self.io_binding)
+        ort_outs = self.io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+        del ort_outs
+
+        result = np.clip(result, -1, 1)
+        result = (result + 1) / 2
+        result = result.transpose(1, 2, 0) * 255.0
+        result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
+        scale_factor = int(result.shape[1] / input_size)
+        return result.astype(np.uint8), scale_factor
+
+
+    def Release(self):
+        del self.model_restoreformerpplus
+        self.model_restoreformerpplus = None
+        del self.io_binding
+        self.io_binding = None
+
roop/processors/FaceSwapInsightFace.py ADDED
@@ -0,0 +1,66 @@
+import roop.globals
+import cv2
+import numpy as np
+import onnx
+import onnxruntime
+
+from roop.typing import Face, Frame
+from roop.utilities import resolve_relative_path
+
+
+
+class FaceSwapInsightFace():
+    model_swap_insightface = None
+
+
+    processorname = 'faceswap'
+    type = 'swap'
+
+
+    def Initialize(self, devicename):
+        if self.model_swap_insightface is None:
+            model_path = resolve_relative_path('../models/inswapper_128.onnx')
+            graph = onnx.load(model_path).graph
+            self.emap = onnx.numpy_helper.to_array(graph.initializer[-1])
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+            self.input_mean = 0.0
+            self.input_std = 255.0
+            # cuda_options = {"arena_extend_strategy": "kSameAsRequested", 'cudnn_conv_algo_search': 'DEFAULT'}
+            sess_options = onnxruntime.SessionOptions()
+            sess_options.enable_cpu_mem_arena = False
+            self.model_swap_insightface = onnxruntime.InferenceSession(model_path, sess_options, providers=roop.globals.execution_providers)
+
+
+
+    def Run(self, source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
+        blob = cv2.dnn.blobFromImage(temp_frame, 1.0 / self.input_std, (128, 128),
+                                     (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+        latent = source_face.normed_embedding.reshape((1, -1))
+        latent = np.dot(latent, self.emap)
+        latent /= np.linalg.norm(latent)
+        io_binding = self.model_swap_insightface.io_binding()
+        io_binding.bind_cpu_input("target", blob)
+        io_binding.bind_cpu_input("source", latent)
+        io_binding.bind_output("output", self.devicename)
+        self.model_swap_insightface.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()[0]
+        img_fake = ort_outs.transpose((0, 2, 3, 1))[0]
+        return np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:, :, ::-1]
+
+
+        # dead code from the earlier insightface API path, kept commented out for reference:
+        # img_fake, M = self.model_swap_insightface.get(temp_frame, target_face, source_face, paste_back=False)
+        # target_face.matrix = M
+        # return img_fake
+
+
+    def Release(self):
+        del self.model_swap_insightface
+        self.model_swap_insightface = None
+
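The swap is driven by two inputs: a 128x128 RGB blob of the target crop and a 512-d identity latent derived from the source face. The latent preparation — project the ArcFace embedding through the emap matrix read from the ONNX graph's last initializer, then re-normalize to unit length — can be reproduced in isolation; the sketch below uses random stand-ins for the embedding and emap:

    import numpy as np

    embedding = np.random.randn(512).astype(np.float32)  # stand-in for source_face.normed_embedding
    emap = np.random.randn(512, 512).astype(np.float32)  # stand-in for the graph initializer

    latent = embedding.reshape((1, -1))
    latent = np.dot(latent, emap)          # project into the swapper's identity space
    latent /= np.linalg.norm(latent)       # unit length, as the 'source' input expects
    print(latent.shape)                    # (1, 512)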
roop/processors/Mask_Clip2Seg.py ADDED
@@ -0,0 +1,88 @@
+import cv2
+import numpy as np
+import torch
+import threading
+from torchvision import transforms
+from clip.clipseg import CLIPDensePredT
+
+from roop.typing import Frame
+
+THREAD_LOCK_CLIP = threading.Lock()
+
+
+class Mask_Clip2Seg():
+
+    model_clip = None
+
+    processorname = 'clip2seg'
+    type = 'mask'
+
+
+    def Initialize(self, devicename):
+        if self.model_clip is None:
+            self.model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
+            self.model_clip.eval()
+            self.model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
+
+        device = torch.device(devicename)
+        self.model_clip.to(device)
+
+
+    def Run(self, img1, keywords: str) -> Frame:
+        if keywords is None or len(keywords) < 1 or img1 is None:
+            return img1
+
+        source_image_small = cv2.resize(img1, (256, 256))
+
+        img_mask = np.full((source_image_small.shape[0], source_image_small.shape[1]), 0, dtype=np.float32)
+        mask_border = 1
+        l = 0
+        t = 0
+        r = 1
+        b = 1
+
+        mask_blur = 5
+        clip_blur = 5
+
+        img_mask = cv2.rectangle(img_mask, (mask_border + int(l), mask_border + int(t)),
+                                 (256 - mask_border - int(r), 256 - mask_border - int(b)), (255, 255, 255), -1)
+        img_mask = cv2.GaussianBlur(img_mask, (mask_blur*2+1, mask_blur*2+1), 0)
+        img_mask /= 255
+
+
+        input_image = source_image_small
+
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            transforms.Resize((256, 256)),
+        ])
+        img = transform(input_image).unsqueeze(0)
+
+        thresh = 0.5
+        prompts = keywords.split(',')
+        with THREAD_LOCK_CLIP:
+            with torch.no_grad():
+                preds = self.model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
+        clip_mask = torch.sigmoid(preds[0][0])
+        for i in range(len(prompts) - 1):
+            clip_mask += torch.sigmoid(preds[i+1][0])
+
+        clip_mask = clip_mask.data.cpu().numpy()
+        clip_mask = np.clip(clip_mask, 0, 1)
+
+        clip_mask[clip_mask > thresh] = 1.0
+        clip_mask[clip_mask <= thresh] = 0.0
+        kernel = np.ones((5, 5), np.float32)
+        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
+        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur*2+1, clip_blur*2+1), 0)
+
+        img_mask *= clip_mask
+        img_mask[img_mask < 0.0] = 0.0
+        return img_mask
+
+
+
+    def Release(self):
+        self.model_clip = None
+
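Note that Run returns a 256x256 float mask in [0, 1] (the border rectangle multiplied by the thresholded, dilated, blurred CLIPSeg prediction), not a finished frame. A caller is expected to resize it to the frame and use it as a per-pixel blend weight; a sketch of that step, with the blending convention stated as an assumption rather than part of this commit:

    import cv2
    import numpy as np

    def apply_mask(mask_256: np.ndarray, swapped: np.ndarray, original: np.ndarray) -> np.ndarray:
        # resize the 256x256 mask to the frame size and broadcast over the BGR channels
        mask = cv2.resize(mask_256, (original.shape[1], original.shape[0]))[:, :, None]
        # assumed convention: mask==1 keeps the original pixels, mask==0 keeps the swapped ones
        out = swapped.astype(np.float32) * (1.0 - mask) + original.astype(np.float32) * mask
        return out.astype(np.uint8)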
roop/processors/Mask_XSeg.py ADDED
@@ -0,0 +1,54 @@
+import numpy as np
+import cv2
+import onnxruntime
+import threading
+import roop.globals
+
+from roop.typing import Frame
+from roop.utilities import resolve_relative_path
+
+THREAD_LOCK_CLIP = threading.Lock()
+
+
+class Mask_XSeg():
+
+    model_xseg = None
+
+    processorname = 'xseg'
+    type = 'mask'
+
+
+    def Initialize(self, devicename):
+        if self.model_xseg is None:
+            model_path = resolve_relative_path('../models/xseg.onnx')
+            onnxruntime.set_default_logger_severity(3)
+            self.model_xseg = onnxruntime.InferenceSession(model_path, None, providers=roop.globals.execution_providers)
+            self.model_inputs = self.model_xseg.get_inputs()
+            self.model_outputs = self.model_xseg.get_outputs()
+
+            # replace Mac mps with cpu for the moment
+            devicename = devicename.replace('mps', 'cpu')
+            self.devicename = devicename
+
+
+    def Run(self, img1, keywords: str) -> Frame:
+        temp_frame = cv2.resize(img1, (256, 256), interpolation=cv2.INTER_CUBIC)
+        temp_frame = temp_frame.astype('float32') / 255.0
+        temp_frame = temp_frame[None, ...]
+        io_binding = self.model_xseg.io_binding()
+        io_binding.bind_cpu_input(self.model_inputs[0].name, temp_frame)
+        io_binding.bind_output(self.model_outputs[0].name, self.devicename)
+        self.model_xseg.run_with_iobinding(io_binding)
+        ort_outs = io_binding.copy_outputs_to_cpu()
+        result = ort_outs[0][0]
+        result = np.clip(result, 0, 1.0)
+        result[result < 0.1] = 0
+        # invert values to mask areas to keep
+        result = 1.0 - result
+        return result
+
+
+    def Release(self):
+        self.model_xseg = None
+
+
roop/processors/__init__.py ADDED
File without changes