Ahmad Faris committed on
Commit
d738105
·
1 Parent(s): 62692d5

add : init commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .DS_Store
README.md CHANGED
@@ -1,13 +1,11 @@
1
  ---
2
- title: Api Swapface
3
- emoji: 😻
4
- colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.34.2
8
  app_file: app.py
9
- pinned: false
10
- short_description: API instance for swap face
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Face-Swap
3
+ emoji: 🔥
4
+ colorFrom: red
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ ---
 
 
app.py ADDED
@@ -0,0 +1,943 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import glob
4
+ import time
5
+ import torch
6
+ import shutil
7
+ import argparse
8
+ import platform
9
+ import datetime
10
+ import subprocess
11
+ import insightface
12
+ import onnxruntime
13
+ import numpy as np
14
+ import gradio as gr
15
+ import threading
16
+ import queue
17
+ from tqdm import tqdm
18
+ import concurrent.futures
19
+ from moviepy.editor import VideoFileClip
20
+ from telegram import Bot
21
+
22
+ from face_swapper import Inswapper, paste_to_whole
23
+ from face_analyser import detect_conditions, get_analysed_data, swap_options_list
24
+ from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
25
+ from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
26
+ from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid
27
+
28
+ ## ------------------------------ USER ARGS ------------------------------
29
+
30
+ parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
31
+ parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
32
+ parser.add_argument("--batch_size", help="Gpu batch size", default=32)
33
+ parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
34
+ parser.add_argument(
35
+ "--colab", action="store_true", help="Enable colab mode", default=False
36
+ )
37
+ user_args = parser.parse_args()
38
+
39
+ ## ------------------------------ DEFAULTS ------------------------------
40
+
41
+ USE_COLAB = user_args.colab
42
+ USE_CUDA = user_args.cuda
43
+ DEF_OUTPUT_PATH = user_args.out_dir
44
+ BATCH_SIZE = int(user_args.batch_size)
45
+ WORKSPACE = None
46
+ OUTPUT_FILE = None
47
+ CURRENT_FRAME = None
48
+ STREAMER = None
49
+ DETECT_CONDITION = "best detection"
50
+ DETECT_SIZE = 640
51
+ DETECT_THRESH = 0.6
52
+ NUM_OF_SRC_SPECIFIC = 10
53
+ MASK_INCLUDE = [
54
+ "Skin",
55
+ "R-Eyebrow",
56
+ "L-Eyebrow",
57
+ "L-Eye",
58
+ "R-Eye",
59
+ "Nose",
60
+ "Mouth",
61
+ "L-Lip",
62
+ "U-Lip"
63
+ ]
64
+ MASK_SOFT_KERNEL = 17
65
+ MASK_SOFT_ITERATIONS = 10
66
+ MASK_BLUR_AMOUNT = 0.1
67
+ MASK_ERODE_AMOUNT = 0.15
68
+
69
+ FACE_SWAPPER = None
70
+ FACE_ANALYSER = None
71
+ FACE_ENHANCER = None
72
+ FACE_PARSER = None
73
+ FACE_ENHANCER_LIST = ["NONE"]
74
+ FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
75
+ FACE_ENHANCER_LIST.extend(cv2_interpolations)
76
+
77
+ bot = Bot(token=os.environ.get("BOT_TOKEN"))
78
+ target_chat_id = os.environ.get("CHAT_ID")
79
+
80
def log_message(message):
    """Send ``message`` as a text message to the configured Telegram chat.

    Uses the module-level ``bot`` and ``target_chat_id``.
    """
    bot.send_message(text=message, chat_id=target_chat_id)
82
+
83
def log_result(pathfile):
    """Upload the video stored at ``pathfile`` to the configured Telegram chat.

    The file is opened in binary mode inside a ``with`` block so the handle
    is closed as soon as the upload call returns or raises, rather than being
    left open until garbage collection.

    Args:
        pathfile: Path of the video file to send.
    """
    # NOTE(review): assumes ``bot.send_video`` is a blocking call that has
    # finished reading the file when it returns - confirm for the installed
    # python-telegram-bot version.
    with open(pathfile, 'rb') as video_file:
        bot.send_video(chat_id=target_chat_id, video=video_file, caption='Fresh from oven')
85
+
86
+ ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
87
+ # Note: Non CUDA users may change settings here
88
+
89
+ PROVIDER = ["CPUExecutionProvider"]
90
+
91
+ if USE_CUDA:
92
+ available_providers = onnxruntime.get_available_providers()
93
+ if "CUDAExecutionProvider" in available_providers:
94
+ print("\n********** Running on CUDA **********\n")
95
+ PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
96
+ else:
97
+ USE_CUDA = False
98
+ print("\n********** CUDA unavailable running on CPU **********\n")
99
+ else:
100
+ USE_CUDA = False
101
+ print("\n********** Running on CPU **********\n")
102
+
103
+ device = "cuda" if USE_CUDA else "cpu"
104
+ EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
105
+
106
+ ## ------------------------------ LOAD MODELS ------------------------------
107
+
108
def load_face_analyser_model(name="buffalo_l"):
    """Create and prepare the shared face analyser the first time it is needed.

    Returns immediately when ``FACE_ANALYSER`` is already set. Otherwise an
    insightface ``FaceAnalysis`` instance called ``name`` is created with the
    selected execution providers and prepared with the module-level
    detection size and threshold.
    """
    global FACE_ANALYSER
    if FACE_ANALYSER is not None:
        return
    FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
    square_detection_size = (DETECT_SIZE, DETECT_SIZE)
    FACE_ANALYSER.prepare(
        ctx_id=0,
        det_size=square_detection_size,
        det_thresh=DETECT_THRESH,
    )
115
+
116
+
117
def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
    """Create the shared ``Inswapper`` instance the first time it is needed.

    Returns immediately when ``FACE_SWAPPER`` is already set. The configured
    ``BATCH_SIZE`` is only used when running on CUDA; otherwise faces are
    processed one at a time.
    """
    global FACE_SWAPPER
    if FACE_SWAPPER is not None:
        return
    if device == "cuda":
        swap_batch_size = int(BATCH_SIZE)
    else:
        swap_batch_size = 1
    FACE_SWAPPER = Inswapper(
        model_file=path,
        batch_size=swap_batch_size,
        providers=PROVIDER,
    )
122
+
123
+
124
def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
    """Initialise the shared face-parsing model the first time it is needed.

    Returns immediately when ``FACE_PARSER`` is already set; otherwise the
    weights at ``path`` are loaded onto the active ``device``.
    """
    global FACE_PARSER
    if FACE_PARSER is not None:
        return
    FACE_PARSER = init_parsing_model(path, device=device)
128
+
129
+
130
+ load_face_analyser_model()
131
+ load_face_swapper_model()
132
+
133
+ ## ------------------------------ MAIN PROCESS ------------------------------
134
+
135
+
136
def process(
    input_type,
    image_path,
    video_path,
    directory_path,
    source_path,
    output_path,
    output_name,
    keep_output_sequence,
    condition,
    age,
    distance,
    face_enhancer_name,
    enable_face_parser,
    mask_includes,
    mask_soft_kernel,
    mask_soft_iterations,
    blur_amount,
    erode_amount,
    face_scale,
    enable_laplacian_blend,
    crop_top,
    crop_bott,
    crop_left,
    crop_right,
    *specifics,
):
    """Run a face swap on an image, a video or a directory of images.

    Generator used as the Gradio click handler. Every yielded value is a
    5-tuple: a markdown status string followed by the four component updates
    built by ``ui_before`` / ``ui_after`` / ``ui_after_vid``.

    Args:
        input_type: "Image", "Video", "Directory" or "Stream" (the last one
            does nothing).
        image_path: Target image, read when ``input_type`` is "Image".
        video_path: Target video, read when ``input_type`` is "Video".
        directory_path: Folder scanned for images when ``input_type`` is
            "Directory".
        source_path: Source face image; used unless ``condition`` is
            "Specific Face".
        output_path: Directory the results are written to.
        output_name: Base name (without extension) of the result.
        keep_output_sequence: Keep the extracted frame folder after the
            video has been merged.
        condition: Swap condition forwarded to ``get_analysed_data``.
        age: Paired with ``source_path`` as the source data.
        distance: Paired with the specific source/target faces when
            ``condition`` is "Specific Face".
        face_enhancer_name: "NONE" or the name of the enhancer applied to
            the generated faces.
        enable_face_parser: Whether parsing masks are computed.
        mask_includes: Region names converted by ``mask_regions_to_list``.
        mask_soft_kernel: Accepted for interface compatibility; not read.
        mask_soft_iterations: Forwarded as ``softness`` to
            ``get_parsed_mask``.
        blur_amount: Forwarded to ``paste_to_whole``.
        erode_amount: Forwarded to ``paste_to_whole``.
        face_scale: Forwarded as ``scale`` to ``get_analysed_data``.
        enable_laplacian_blend: Selects the 'laplacian' blend method instead
            of 'linear'.
        crop_top: Crop-mask slider value; swapped with ``crop_bott`` when
            larger.
        crop_bott: Crop-mask slider value.
        crop_left: Crop-mask slider value; swapped with ``crop_right`` when
            larger.
        crop_right: Crop-mask slider value.
        *specifics: The first half are source faces, the second half the
            matching specific faces.

    Yields:
        ``(status_markdown, update, update, update, update)`` tuples.
    """
    global WORKSPACE
    global OUTPUT_FILE
    global PREVIEW
    WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None

    ## ------------------------------ GUI UPDATE FUNC ------------------------------

    def ui_before():
        # State while work is running: preview shown, both buttons locked.
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
        )

    def ui_after():
        # State after an image/directory run: preview shown, buttons unlocked.
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(visible=False),
        )

    def ui_after_vid():
        # State after a video run: image preview hidden, video output shown.
        return (
            gr.update(visible=False),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(value=OUTPUT_FILE, visible=True),
        )

    start_time = time.time()
    total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
    get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."

    ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------

    yield "### \n ⌛ Loading face analyser model...", *ui_before()
    load_face_analyser_model()

    yield "### \n ⌛ Loading face swapper model...", *ui_before()
    load_face_swapper_model()

    if face_enhancer_name != "NONE":
        if face_enhancer_name not in cv2_interpolations:
            yield f"### \n ⌛ Loading {face_enhancer_name} model...", *ui_before()
        # Loaded for every enhancer name (cv2 interpolations included) because
        # swap_process unpacks FACE_ENHANCER whenever the name is not "NONE".
        FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
    else:
        FACE_ENHANCER = None

    if enable_face_parser:
        yield "### \n ⌛ Loading face parsing model...", *ui_before()
        load_face_parser_model()

    includes = mask_regions_to_list(mask_includes)
    specifics = list(specifics)
    half = len(specifics) // 2
    sources = specifics[:half]
    specifics = specifics[half:]
    if crop_top > crop_bott:
        crop_top, crop_bott = crop_bott, crop_top
    if crop_left > crop_right:
        crop_left, crop_right = crop_right, crop_left
    crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)

    def swap_process(image_sequence):
        """Swap faces in every image file of ``image_sequence`` in place.

        This is a plain function, not a generator: progress is reported with
        ``print`` and ``log_message`` only, and nothing is returned.
        """
        global PREVIEW

        ## ------------------------------ CONTENT CHECK ------------------------------
        print("### \n ⌛ Analysing face data...")
        log_message("⌛ Analysing face data...")
        if condition != "Specific Face":
            source_data = source_path, age
        else:
            source_data = ((sources, specifics), distance)
        analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
            FACE_ANALYSER,
            image_sequence,
            source_data,
            swap_condition=condition,
            detect_condition=DETECT_CONDITION,
            scale=face_scale
        )

        ## ------------------------------ SWAP FUNC ------------------------------

        print("### \n ⌛ Generating faces...")
        log_message("⌛ Generating faces...")
        preds = []
        matrs = []
        count = 0
        print("Is face swapper None: {}".format(FACE_SWAPPER is None))
        for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
            preds.extend(batch_pred)
            matrs.extend(batch_matr)
            EMPTY_CACHE()
            count += 1
            print("Count: {}".format(count))

            if USE_CUDA:
                image_grid = create_image_grid(batch_pred, size=128)
                PREVIEW = image_grid[:, :, ::-1]
                print("### \n ⌛ Generating face Batch {}".format(count))

        ## ------------------------------ FACE ENHANCEMENT ------------------------------

        generated_len = len(preds)
        print("Generated len: {}".format(generated_len))
        print("Face enhancer name: {}".format(face_enhancer_name))
        if face_enhancer_name != "NONE":
            print("### \n ⌛ Upscaling faces with {}...".format(face_enhancer_name))
            log_message("⌛ Upscaling faces with {}...".format(face_enhancer_name))
            # The (model, runner) pair does not change between faces.
            enhancer_model, enhancer_model_runner = FACE_ENHANCER
            for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
                pred = enhancer_model_runner(pred, enhancer_model)
                preds[idx] = cv2.resize(pred, (512,512))
        EMPTY_CACHE()

        ## ------------------------------ FACE PARSING ------------------------------

        if enable_face_parser:
            print("### \n ⌛ Face-parsing mask...")
            log_message("⌛ Face-parsing mask...")
            masks = []
            count = 0
            for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
                masks.append(batch_mask)
                EMPTY_CACHE()
                count += 1
                print("Count: {}".format(count))

                if len(batch_mask) > 1:
                    image_grid = create_image_grid(batch_mask, size=128)
                    PREVIEW = image_grid[:, :, ::-1]
                    print("### \n ⌛ Face parsing Batch {}".format(count))
                    log_message("⌛ Face parsing Batch {}".format(count))
            masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
        else:
            masks = [None] * generated_len

        ## ------------------------------ SPLIT LIST ------------------------------

        split_preds = split_list_by_lengths(preds, num_faces_per_frame)
        del preds
        split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
        del matrs
        split_masks = split_list_by_lengths(masks, num_faces_per_frame)
        del masks

        ## ------------------------------ PASTE-BACK ------------------------------

        print("### \n ⌛ Pasting back...")
        log_message("⌛ Pasting back...")

        def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
            # Paste every generated face of one frame back and overwrite the file.
            print("Entering post process")
            whole_img_path = frame_img
            print("Whole image path: {}".format(whole_img_path))
            whole_img = cv2.imread(whole_img_path)
            blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
            for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
                p = cv2.resize(p, (512,512))
                mask = cv2.resize(mask, (512,512)) if mask is not None else None
                m /= 0.25
                whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
            cv2.imwrite(whole_img_path, whole_img)
            print("Done writing")

        def concurrent_post_process(image_sequence, *args):
            # Run post_process for every frame on a thread pool.
            print("Entering concurrent_post_process")
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(post_process, idx, frame_img, *args)
                    for idx, frame_img in enumerate(image_sequence)
                ]

                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
                    # result() re-raises any exception raised in the worker.
                    future.result()

        concurrent_post_process(
            image_sequence,
            split_preds,
            split_matrs,
            split_masks,
            enable_laplacian_blend,
            crop_mask,
            blur_amount,
            erode_amount
        )
        print("Done do concurrent_post_process")

    ## ------------------------------ IMAGE ------------------------------

    if input_type == "Image":
        target = cv2.imread(image_path)
        output_file = os.path.join(output_path, output_name + ".png")
        cv2.imwrite(output_file, target)

        # swap_process returns None, so it is called directly rather than
        # iterated (iterating None raised TypeError).
        swap_process([output_file])

        OUTPUT_FILE = output_file
        WORKSPACE = output_path
        PREVIEW = cv2.imread(output_file)[:, :, ::-1]

        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ VIDEO ------------------------------

    elif input_type == "Video":
        temp_path = os.path.join(output_path, output_name, "sequence")
        os.makedirs(temp_path, exist_ok=True)

        print("### \n ⌛ Extracting video frames...")
        log_message("⌛ Extracting video frames...")
        image_sequence = []
        cap = cv2.VideoCapture(video_path)
        curr_idx = 0
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
                cv2.imwrite(frame_path, frame)
                image_sequence.append(frame_path)
                curr_idx += 1
                print("Curr IDX: {}".format(curr_idx))
        finally:
            # Release the capture even when writing a frame fails.
            cap.release()
        cv2.destroyAllWindows()

        print("Total image sequence: {}".format(len(image_sequence)))
        swap_process(image_sequence)

        print("End swap_process")

        print("### \n ⌛ Merging sequence...")
        log_message("⌛ Merging sequence...")
        output_video_path = os.path.join(output_path, output_name + ".mp4")
        merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)

        if os.path.exists(temp_path) and not keep_output_sequence:
            print("### \n ⌛ Removing temporary files...")
            print("⌛ Removing temporary files...")
            shutil.rmtree(temp_path)

        WORKSPACE = output_path
        OUTPUT_FILE = output_video_path
        log_result(OUTPUT_FILE)

        yield get_finsh_text(start_time), *ui_after_vid()

    ## ------------------------------ DIRECTORY ------------------------------

    elif input_type == "Directory":
        extensions = ("jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp")
        temp_path = os.path.join(output_path, output_name)
        if os.path.exists(temp_path):
            shutil.rmtree(temp_path)
        os.mkdir(temp_path)

        file_paths = []
        for file_path in glob.glob(os.path.join(directory_path, "*")):
            if file_path.lower().endswith(extensions):
                img = cv2.imread(file_path)
                new_file_path = os.path.join(temp_path, os.path.basename(file_path))
                cv2.imwrite(new_file_path, img)
                file_paths.append(new_file_path)

        if not file_paths:
            # Nothing to swap; report it instead of failing on file_paths[-1].
            WORKSPACE = temp_path
            yield "### \n ❌ No images found in the given directory.", *ui_after()
            return

        # swap_process returns None, so it is called directly rather than
        # iterated (iterating None raised TypeError).
        swap_process(file_paths)

        PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
        WORKSPACE = temp_path
        OUTPUT_FILE = file_paths[-1]

        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ STREAM ------------------------------

    elif input_type == "Stream":
        pass
450
+
451
+
452
+ ## ------------------------------ GRADIO FUNC ------------------------------
453
+
454
+
455
def update_radio(value):
    """Return visibility updates for the three target-input groups.

    The updates are ordered (image group, video group, directory group).
    "Directory" and "Stream" both reveal the directory group. Any other
    value falls through and returns ``None``.
    """
    visibility_by_type = (
        ("Image", (True, False, False)),
        ("Video", (False, True, False)),
        ("Directory", (False, False, True)),
        ("Stream", (False, False, True)),
    )
    for target_type, flags in visibility_by_type:
        if value == target_type:
            return tuple(gr.update(visible=flag) for flag in flags)
    return None
480
+
481
+
482
def swap_option_changed(value):
    """Return visibility updates matching the chosen swap condition.

    Produces three updates. For values starting with "Age" the first and
    third are visible; for "Specific Face" only the second; for anything
    else only the third.
    """
    if value.startswith("Age"):
        flags = (True, False, True)
    elif value == "Specific Face":
        flags = (False, True, False)
    else:
        flags = (False, False, True)
    first, second, third = flags
    return gr.update(visible=first), gr.update(visible=second), gr.update(visible=third)
496
+
497
+
498
def video_changed(video_path):
    """Return slider and fps updates describing the selected video.

    Args:
        video_path: Path of the video, or ``None`` when none is selected.

    Returns:
        A 3-tuple of updates: start-frame slider, end-frame slider and fps
        number. With no video the sliders collapse to fixed ranges; if the
        video cannot be read, fallback values are returned.
    """
    sliders_update = gr.Slider.update
    number_update = gr.Number.update

    if video_path is None:
        return (
            sliders_update(minimum=0, maximum=0, value=0),
            sliders_update(minimum=1, maximum=1, value=1),
            number_update(value=1),
        )
    try:
        clip = VideoFileClip(video_path)
        try:
            fps = clip.fps
            total_frames = clip.reader.nframes
        finally:
            # Close the clip even when reading its metadata fails.
            clip.close()
    except Exception as e:
        # Catch Exception (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate; report the reason on the console.
        print(e)
        return (
            sliders_update(value=0),
            sliders_update(value=0),
            number_update(value=1),
        )
    return (
        sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
        sliders_update(
            minimum=0, maximum=total_frames, value=total_frames, interactive=True
        ),
        number_update(value=fps),
    )
527
+
528
+
529
def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
    """Rebuild the shared face analyser with new detection settings.

    Generator: yields a progress message first, then stores
    ``detect_condition`` in ``DETECT_CONDITION``, replaces ``FACE_ANALYSER``
    with a freshly prepared "buffalo_l" analyser, and finally yields a
    confirmation message.
    """
    global FACE_ANALYSER
    global DETECT_CONDITION
    yield "### \n ⌛ Applying new values..."
    DETECT_CONDITION = detect_condition
    FACE_ANALYSER = insightface.app.FaceAnalysis(providers=PROVIDER, name="buffalo_l")
    prepare_options = {
        "ctx_id": 0,
        "det_size": (int(detection_size), int(detection_size)),
        "det_thresh": float(detection_threshold),
    }
    FACE_ANALYSER.prepare(**prepare_options)
    yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
541
+
542
+
543
def stop_running():
    """Stop the active streamer, if any, and clear the global reference.

    Returns:
        The string "Cancelled".
    """
    global STREAMER
    active_streamer = STREAMER
    if hasattr(active_streamer, "stop"):
        active_streamer.stop()
    STREAMER = None
    return "Cancelled"
549
+
550
+
551
def slider_changed(show_frame, video_path, frame_index):
    """Return preview updates showing the frame at ``frame_index``.

    Args:
        show_frame: When falsy, no preview is produced.
        video_path: Path of the video, or ``None`` when none is selected.
        frame_index: Frame number; converted to a timestamp using the
            clip's fps.

    Returns:
        ``(None, None)`` when previewing is disabled or no video is set;
        otherwise an image update carrying the frame and a video update that
        hides the video player.
    """
    if not show_frame or video_path is None:
        return None, None
    clip = VideoFileClip(video_path)
    try:
        frame_array = np.array(clip.get_frame(frame_index / clip.fps))
    finally:
        # Close the clip even when the frame cannot be read.
        clip.close()
    return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
        visible=False
    )
563
+
564
+
565
def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
    """Trim the video to a frame range and hand the result back to the UI.

    Generator yielding ``(video, status_markdown)`` pairs: first the original
    video with a progress message, then either the trimmed video with a
    success message or the original video with a failure message.
    """
    progress_text = f"### \n ⌛ Trimming video frame {start_frame} to {stop_frame}..."
    yield video_path, progress_text
    try:
        destination = os.path.join(output_path, output_name)
        trimmed_video = trim_video(video_path, destination, start_frame, stop_frame)
        yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
    except Exception as error:
        print(error)
        yield video_path, "### \n ❌ Video trimming failed. See console for more info."
574
+
575
+
576
+ ## ------------------------------ GRADIO GUI ------------------------------
577
+
578
+ css = """
579
+ footer{display:none !important}
580
+ """
581
+
582
+ with gr.Blocks(css=css) as interface:
583
+ gr.Markdown("# 🗿 Swap Mukham")
584
+ gr.Markdown("### Face swap app based on insightface inswapper.")
585
+ with gr.Row():
586
+ with gr.Row():
587
+ with gr.Column(scale=0.4):
588
+ with gr.Tab("📄 Swap Condition"):
589
+ swap_option = gr.Dropdown(
590
+ swap_options_list,
591
+ info="Choose which face or faces in the target image to swap.",
592
+ multiselect=False,
593
+ show_label=False,
594
+ value=swap_options_list[0],
595
+ interactive=True,
596
+ )
597
+ age = gr.Number(
598
+ value=25, label="Value", interactive=True, visible=False
599
+ )
600
+
601
+ with gr.Tab("🎚️ Detection Settings"):
602
+ detect_condition_dropdown = gr.Dropdown(
603
+ detect_conditions,
604
+ label="Condition",
605
+ value=DETECT_CONDITION,
606
+ interactive=True,
607
+ info="This condition is only used when multiple faces are detected on source or specific image.",
608
+ )
609
+ detection_size = gr.Number(
610
+ label="Detection Size", value=DETECT_SIZE, interactive=True
611
+ )
612
+ detection_threshold = gr.Number(
613
+ label="Detection Threshold",
614
+ value=DETECT_THRESH,
615
+ interactive=True,
616
+ )
617
+ apply_detection_settings = gr.Button("Apply settings")
618
+
619
+ with gr.Tab("📤 Output Settings"):
620
+ output_directory = gr.Text(
621
+ label="Output Directory",
622
+ value=DEF_OUTPUT_PATH,
623
+ interactive=True,
624
+ )
625
+ output_name = gr.Text(
626
+ label="Output Name", value="Result", interactive=True
627
+ )
628
+ keep_output_sequence = gr.Checkbox(
629
+ label="Keep output sequence", value=False, interactive=True
630
+ )
631
+
632
+ with gr.Tab("🪄 Other Settings"):
633
+ face_scale = gr.Slider(
634
+ label="Face Scale",
635
+ minimum=0,
636
+ maximum=2,
637
+ value=1,
638
+ interactive=True,
639
+ )
640
+
641
+ face_enhancer_name = gr.Dropdown(
642
+ FACE_ENHANCER_LIST, label="Face Enhancer", value="NONE", multiselect=False, interactive=True
643
+ )
644
+
645
+ with gr.Accordion("Advanced Mask", open=False):
646
+ enable_face_parser_mask = gr.Checkbox(
647
+ label="Enable Face Parsing",
648
+ value=False,
649
+ interactive=True,
650
+ )
651
+
652
+ mask_include = gr.Dropdown(
653
+ mask_regions.keys(),
654
+ value=MASK_INCLUDE,
655
+ multiselect=True,
656
+ label="Include",
657
+ interactive=True,
658
+ )
659
+ mask_soft_kernel = gr.Number(
660
+ label="Soft Erode Kernel",
661
+ value=MASK_SOFT_KERNEL,
662
+ minimum=3,
663
+ interactive=True,
664
+ visible = False
665
+ )
666
+ mask_soft_iterations = gr.Number(
667
+ label="Soft Erode Iterations",
668
+ value=MASK_SOFT_ITERATIONS,
669
+ minimum=0,
670
+ interactive=True,
671
+
672
+ )
673
+
674
+
675
+ with gr.Accordion("Crop Mask", open=False):
676
+ crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
677
+ crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
678
+ crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
679
+ crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
680
+
681
+
682
+ erode_amount = gr.Slider(
683
+ label="Mask Erode",
684
+ minimum=0,
685
+ maximum=1,
686
+ value=MASK_ERODE_AMOUNT,
687
+ step=0.05,
688
+ interactive=True,
689
+ )
690
+
691
+ blur_amount = gr.Slider(
692
+ label="Mask Blur",
693
+ minimum=0,
694
+ maximum=1,
695
+ value=MASK_BLUR_AMOUNT,
696
+ step=0.05,
697
+ interactive=True,
698
+ )
699
+
700
+ enable_laplacian_blend = gr.Checkbox(
701
+ label="Laplacian Blending",
702
+ value=True,
703
+ interactive=True,
704
+ )
705
+
706
+
707
+ source_image_input = gr.Image(
708
+ label="Source face", type="filepath", interactive=True
709
+ )
710
+
711
+ with gr.Box(visible=False) as specific_face:
712
+ for i in range(NUM_OF_SRC_SPECIFIC):
713
+ idx = i + 1
714
+ code = "\n"
715
+ code += f"with gr.Tab(label='({idx})'):"
716
+ code += "\n\twith gr.Row():"
717
+ code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
718
+ code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
719
+ exec(code)
720
+
721
+ distance_slider = gr.Slider(
722
+ minimum=0,
723
+ maximum=2,
724
+ value=0.6,
725
+ interactive=True,
726
+ label="Distance",
727
+ info="Lower distance is more similar and higher distance is less similar to the target face.",
728
+ )
729
+
730
+ with gr.Group():
731
+ input_type = gr.Radio(
732
+ ["Image", "Video"],
733
+ label="Target Type",
734
+ value="Image",
735
+ )
736
+
737
+ with gr.Box(visible=True) as input_image_group:
738
+ image_input = gr.Image(
739
+ label="Target Image", interactive=True, type="filepath"
740
+ )
741
+
742
+ with gr.Box(visible=False) as input_video_group:
743
+ vid_widget = gr.Video if USE_COLAB else gr.Text
744
+ video_input = gr.Video(
745
+ label="Target Video", interactive=True
746
+ )
747
+ with gr.Accordion("✂️ Trim video", open=False):
748
+ with gr.Column():
749
+ with gr.Row():
750
+ set_slider_range_btn = gr.Button(
751
+ "Set frame range", interactive=True
752
+ )
753
+ show_trim_preview_btn = gr.Checkbox(
754
+ label="Show frame when slider change",
755
+ value=True,
756
+ interactive=True,
757
+ )
758
+
759
+ video_fps = gr.Number(
760
+ value=30,
761
+ interactive=False,
762
+ label="Fps",
763
+ visible=False,
764
+ )
765
+ start_frame = gr.Slider(
766
+ minimum=0,
767
+ maximum=1,
768
+ value=0,
769
+ step=1,
770
+ interactive=True,
771
+ label="Start Frame",
772
+ info="",
773
+ )
774
+ end_frame = gr.Slider(
775
+ minimum=0,
776
+ maximum=1,
777
+ value=1,
778
+ step=1,
779
+ interactive=True,
780
+ label="End Frame",
781
+ info="",
782
+ )
783
+ trim_and_reload_btn = gr.Button(
784
+ "Trim and Reload", interactive=True
785
+ )
786
+
787
+ with gr.Box(visible=False) as input_directory_group:
788
+ direc_input = gr.Text(label="Path", interactive=True)
789
+
790
+ with gr.Column(scale=0.6):
791
+ info = gr.Markdown(value="...")
792
+
793
+ with gr.Row():
794
+ swap_button = gr.Button("✨ Swap", variant="primary")
795
+ cancel_button = gr.Button("⛔ Cancel")
796
+
797
+ preview_image = gr.Image(label="Output", interactive=False)
798
+ preview_video = gr.Video(
799
+ label="Output", interactive=False, visible=False
800
+ )
801
+
802
+ with gr.Row():
803
+ output_directory_button = gr.Button(
804
+ "📂", interactive=False, visible=False
805
+ )
806
+ output_video_button = gr.Button(
807
+ "🎬", interactive=False, visible=False
808
+ )
809
+
810
+ with gr.Box():
811
+ with gr.Row():
812
+ gr.Markdown(
813
+ "### [🤝 Sponsor](https://github.com/sponsors/harisreedhar)"
814
+ )
815
+ gr.Markdown(
816
+ "### [👨‍💻 Source code](https://github.com/harisreedhar/Swap-Mukham)"
817
+ )
818
+ gr.Markdown(
819
+ "### [⚠️ Disclaimer](https://github.com/harisreedhar/Swap-Mukham#disclaimer)"
820
+ )
821
+ gr.Markdown(
822
+ "### [🌐 Run in Colab](https://colab.research.google.com/github/harisreedhar/Swap-Mukham/blob/main/swap_mukham_colab.ipynb)"
823
+ )
824
+ gr.Markdown(
825
+ "### [🤗 Acknowledgements](https://github.com/harisreedhar/Swap-Mukham#acknowledgements)"
826
+ )
827
+
828
+ ## ------------------------------ GRADIO EVENTS ------------------------------
829
+
830
+ set_slider_range_event = set_slider_range_btn.click(
831
+ video_changed,
832
+ inputs=[video_input],
833
+ outputs=[start_frame, end_frame, video_fps],
834
+ )
835
+
836
+ trim_and_reload_event = trim_and_reload_btn.click(
837
+ fn=trim_and_reload,
838
+ inputs=[video_input, output_directory, output_name, start_frame, end_frame],
839
+ outputs=[video_input, info],
840
+ )
841
+
842
+ start_frame_event = start_frame.release(
843
+ fn=slider_changed,
844
+ inputs=[show_trim_preview_btn, video_input, start_frame],
845
+ outputs=[preview_image, preview_video],
846
+ show_progress=True,
847
+ )
848
+
849
+ end_frame_event = end_frame.release(
850
+ fn=slider_changed,
851
+ inputs=[show_trim_preview_btn, video_input, end_frame],
852
+ outputs=[preview_image, preview_video],
853
+ show_progress=True,
854
+ )
855
+
856
+ input_type.change(
857
+ update_radio,
858
+ inputs=[input_type],
859
+ outputs=[input_image_group, input_video_group, input_directory_group],
860
+ )
861
+ swap_option.change(
862
+ swap_option_changed,
863
+ inputs=[swap_option],
864
+ outputs=[age, specific_face, source_image_input],
865
+ )
866
+
867
+ apply_detection_settings.click(
868
+ analyse_settings_changed,
869
+ inputs=[detect_condition_dropdown, detection_size, detection_threshold],
870
+ outputs=[info],
871
+ )
872
+
873
+ src_specific_inputs = []
874
+ gen_variable_txt = ",".join(
875
+ [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
876
+ + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
877
+ )
878
+ exec(f"src_specific_inputs = ({gen_variable_txt})")
879
+ swap_inputs = [
880
+ input_type,
881
+ image_input,
882
+ video_input,
883
+ direc_input,
884
+ source_image_input,
885
+ output_directory,
886
+ output_name,
887
+ keep_output_sequence,
888
+ swap_option,
889
+ age,
890
+ distance_slider,
891
+ face_enhancer_name,
892
+ enable_face_parser_mask,
893
+ mask_include,
894
+ mask_soft_kernel,
895
+ mask_soft_iterations,
896
+ blur_amount,
897
+ erode_amount,
898
+ face_scale,
899
+ enable_laplacian_blend,
900
+ crop_top,
901
+ crop_bott,
902
+ crop_left,
903
+ crop_right,
904
+ *src_specific_inputs,
905
+ ]
906
+
907
+ swap_outputs = [
908
+ info,
909
+ preview_image,
910
+ output_directory_button,
911
+ output_video_button,
912
+ preview_video,
913
+ ]
914
+
915
+ swap_event = swap_button.click(
916
+ fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
917
+ )
918
+
919
+ cancel_button.click(
920
+ fn=stop_running,
921
+ inputs=None,
922
+ outputs=[info],
923
+ cancels=[
924
+ swap_event,
925
+ trim_and_reload_event,
926
+ set_slider_range_event,
927
+ start_frame_event,
928
+ end_frame_event,
929
+ ],
930
+ show_progress=True,
931
+ )
932
+ output_directory_button.click(
933
+ lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
934
+ )
935
+ output_video_button.click(
936
+ lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
937
+ )
938
+
939
+ if __name__ == "__main__":
940
+ if USE_COLAB:
941
+ print("Running in colab mode")
942
+
943
+ interface.queue(concurrency_count=2, max_size=20).launch(share=USE_COLAB)
assets/images/logo.png ADDED
face_analyser.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from utils import scale_bbox_from_center
6
+
7
# Strategy names understood by get_single_face() for choosing one face when
# several are detected.
detect_conditions = [
    "best detection",
    "left most",
    "right most",
    "top most",
    "bottom most",
    "middle",
    "biggest",
    "smallest",
]

# Swap-condition labels that get_analysed_data() compares against when deciding
# which detected faces are paired with a source face.
swap_options_list = [
    "All Face",
    "Specific Face",
    "Age less than",
    "Age greater than",
    "All Male",
    "All Female",
    "Left Most",
    "Right Most",
    "Top Most",
    "Bottom Most",
    "Middle",
    "Biggest",
    "Smallest",
]
33
+
34
def get_single_face(faces, method="best detection"):
    """Pick one face out of ``faces`` according to ``method``.

    Each face is indexed with ``face["det_score"]`` and ``face["bbox"]``
    (``bbox`` is read as ``[x0, y0, x1, y1]``). A single-element list is
    returned as-is without consulting ``method``. An unrecognised ``method``
    yields ``None``.
    """
    if len(faces) == 1:
        return faces[0]

    print(f"{len(faces)} face detected. Using {method} face.")

    def bbox_area(face):
        x0, y0, x1, y1 = face["bbox"][0], face["bbox"][1], face["bbox"][2], face["bbox"][3]
        return (x1 - x0) * (y1 - y0)

    def centre_offset(face):
        # NOTE(review): bbox values look like pixel coordinates while 0.5 looks
        # like a normalised centre, and the *median* entry is selected below --
        # confirm this is the intended meaning of "middle".
        return (
            (face["bbox"][0] + face["bbox"][2]) / 2 - 0.5) ** 2 + \
            ((face["bbox"][1] + face["bbox"][3]) / 2 - 0.5) ** 2

    # method name -> (sort key, position taken from the ascending sort)
    selectors = {
        "best detection": (lambda face: face["det_score"], -1),
        "left most": (lambda face: face["bbox"][0], 0),
        "right most": (lambda face: face["bbox"][0], -1),
        "top most": (lambda face: face["bbox"][1], 0),
        "bottom most": (lambda face: face["bbox"][1], -1),
        "middle": (centre_offset, len(faces) // 2),
        "biggest": (bbox_area, -1),
        "smallest": (bbox_area, 0),
    }

    if method not in selectors:
        return None

    sort_key, position = selectors[method]
    return sorted(faces, key=sort_key)[position]
58
+
59
+
60
def analyse_face(image, model, return_single_face=True, detect_condition="best detection", scale=1.0):
    """Run ``model.get(image)`` and optionally rescale each face's key points.

    When ``scale`` differs from 1, every face's ``'kps'`` entry is replaced by
    the same points scaled about their own mean. Returns the full list of
    faces when ``return_single_face`` is false, otherwise the face chosen by
    ``get_single_face`` using ``detect_condition``.
    """
    detected = model.get(image)

    if scale != 1:  # landmark-scale
        for face in detected:
            points = face['kps']
            centroid = np.mean(points, axis=0)
            face['kps'] = centroid + (points - centroid) * scale

    if return_single_face:
        return get_single_face(detected, method=detect_condition)
    return detected
73
+
74
+
75
def cosine_distance(a, b):
    """Return ``1 - cos(angle)`` between vectors ``a`` and ``b``.

    The inputs are left untouched: the previous in-place ``/=`` silently
    normalised the caller's arrays (e.g. stored face embeddings) and failed
    outright on integer arrays.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    unit_a = a / np.linalg.norm(a)
    unit_b = b / np.linalg.norm(b)
    return 1 - np.dot(unit_a, unit_b)
79
+
80
+
81
def get_analysed_data(face_analyser, image_sequence, source_data, swap_condition="All face", detect_condition="left most", scale=1.0):
    """Analyse every frame and pair each face to swap with its source face.

    Args:
        face_analyser: model handed to ``analyse_face`` (must expose ``get``).
        image_sequence: sequence of frame paths, each read with ``cv2.imread``.
        source_data: ``(source_path, age)`` for every condition except
            ``"Specific Face"``; for ``"Specific Face"`` it is
            ``(source_specifics, threshold)`` where ``zip(*source_specifics)``
            yields ``(source, specific)`` image pairs.
        swap_condition: one of the labels in ``swap_options_list``.
        detect_condition: strategy used to pick the single source face.
        scale: landmark scale forwarded to ``analyse_face``.

    Returns:
        ``(analysed_target_list, analysed_source_list, whole_frame_eql_list,
        num_faces_per_frame)``. The first three lists are parallel (target
        face, source face, frame path); the last holds one count per frame.

    NOTE(review): the default ``"All face"`` does not match the ``"All Face"``
    label (case differs), so relying on the default selects no faces --
    confirm callers always pass ``swap_condition`` explicitly.
    """
    print("get_analysed_data")
    print("Swap condition: {}".format(swap_condition))

    age = None
    threshold = None
    analysed_source = None
    analysed_source_specifics = []
    if swap_condition != "Specific Face":
        source_path, age = source_data
        print("Source path: {}".format(source_path))
        source_image = cv2.imread(source_path)
        analysed_source = analyse_face(source_image, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
    else:
        source_specifics, threshold = source_data
        for source, specific in zip(*source_specifics):
            if source is None or specific is None:
                continue
            pair_source = analyse_face(source, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
            pair_specific = analyse_face(specific, face_analyser, return_single_face=True, detect_condition=detect_condition, scale=scale)
            analysed_source_specifics.append([pair_source, pair_specific])

    # Conditions that keep or drop each detected face independently.
    face_filters = {
        "All Face": lambda face: True,
        "Age less than": lambda face: face["age"] < age,
        "Age greater than": lambda face: face["age"] > age,
        "All Male": lambda face: face["gender"] == 1,
        "All Female": lambda face: face["gender"] == 0,
    }
    # Conditions that keep exactly one face per frame, chosen by position/size.
    positional_methods = {
        "Left Most": "left most",
        "Right Most": "right most",
        "Top Most": "top most",
        "Bottom Most": "bottom most",
        "Middle": "middle",
        "Biggest": "biggest",
        "Smallest": "smallest",
    }

    analysed_target_list = []
    analysed_source_list = []
    whole_frame_eql_list = []
    num_faces_per_frame = []

    total_frames = len(image_sequence)
    print("Total frame: {}\nCurrent IDX:{}".format(total_frames, 0))
    for frame_path in tqdm(image_sequence, total=total_frames, desc="Analysing face data"):
        print("Read frame")
        frame = cv2.imread(frame_path)
        print("Get frame")
        analysed_faces = analyse_face(frame, face_analyser, return_single_face=False, detect_condition=detect_condition, scale=scale)

        # (target face, source face) pairs selected in this frame
        pairs = []
        if swap_condition in face_filters:
            keep = face_filters[swap_condition]
            pairs = [(face, analysed_source) for face in analysed_faces if keep(face)]
        elif swap_condition == "Specific Face":
            for face in analysed_faces:
                for pair_source, pair_specific in analysed_source_specifics:
                    distance = cosine_distance(pair_specific["embedding"], face["embedding"])
                    if distance < threshold:
                        pairs.append((face, pair_source))
        elif swap_condition in positional_methods and len(analysed_faces) > 0:
            # Frames without any detection are skipped here; previously
            # get_single_face() was called on an empty list and raised
            # IndexError, aborting the whole run.
            chosen = get_single_face(analysed_faces, method=positional_methods[swap_condition])
            pairs.append((chosen, analysed_source))

        for target_face, source_face in pairs:
            analysed_target_list.append(target_face)
            analysed_source_list.append(source_face)
            whole_frame_eql_list.append(frame_path)

        n_faces = len(pairs)
        print("Total faces: {}".format(n_faces))
        num_faces_per_frame.append(n_faces)

    return analysed_target_list, analysed_source_list, whole_frame_eql_list, num_faces_per_frame
face_enhancer.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import gfpgan
5
+ from PIL import Image
6
+ from upscaler.RealESRGAN import RealESRGAN
7
+ from upscaler.codeformer import CodeFormerEnhancer
8
+
9
def gfpgan_runner(img, model):
    """Call ``model.enhance`` on an aligned crop and return the first image of its second result."""
    _first, enhanced, _last = model.enhance(img, paste_back=True, has_aligned=True)
    return enhanced[0]
12
+
13
+
14
def realesrgan_runner(img, model):
    """Return ``model.predict(img)``."""
    return model.predict(img)
17
+
18
+
19
def codeformer_runner(img, model):
    """Return ``model.enhance(img)``."""
    return model.enhance(img)
22
+
23
+
24
# Enhancer name -> (weights path relative to this file, runner function).
# get_available_enhancer_names() only reports entries whose weights file exists.
supported_enhancers = {
    "CodeFormer": ("./assets/pretrained_models/codeformer.onnx", codeformer_runner),
    "GFPGAN": ("./assets/pretrained_models/GFPGANv1.4.pth", gfpgan_runner),
    "REAL-ESRGAN 2x": ("./assets/pretrained_models/RealESRGAN_x2.pth", realesrgan_runner),
    "REAL-ESRGAN 4x": ("./assets/pretrained_models/RealESRGAN_x4.pth", realesrgan_runner),
    "REAL-ESRGAN 8x": ("./assets/pretrained_models/RealESRGAN_x8.pth", realesrgan_runner)
}

# Model-free fallbacks: plain cv2.resize interpolation modes, always accepted
# by load_face_enhancer_model().
cv2_interpolations = ["LANCZOS4", "CUBIC", "NEAREST"]
33
+
34
def get_available_enhancer_names():
    """List the ``supported_enhancers`` names whose weights file exists next to this module."""
    base_dir = os.path.abspath(os.path.dirname(__file__))
    return [
        name
        for name, data in supported_enhancers.items()
        if os.path.exists(os.path.join(base_dir, data[0]))
    ]
41
+
42
+
43
def load_face_enhancer_model(name='GFPGAN', device="cpu"):
    """Build the enhancer called ``name`` and return ``(model, model_runner)``.

    ``name`` must be an available entry of ``supported_enhancers`` or one of
    ``cv2_interpolations``. For the cv2 interpolation modes ``model`` is
    ``None`` and the runner resizes the image to 512x512.

    Raises:
        AssertionError: if ``name`` is not available. The exception type is
            kept for backward compatibility, but it is raised explicitly so
            the check still runs under ``python -O``; previously a stripped
            ``assert`` led to an unbound ``model_runner`` at the return.
    """
    if name not in get_available_enhancer_names() + cv2_interpolations:
        raise AssertionError(f"Face enhancer {name} unavailable.")

    if name in supported_enhancers:
        model_path, model_runner = supported_enhancers[name]
        model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model_path)

        realesrgan_scales = {
            'REAL-ESRGAN 2x': 2,
            'REAL-ESRGAN 4x': 4,
            'REAL-ESRGAN 8x': 8,
        }
        if name == 'CodeFormer':
            model = CodeFormerEnhancer(model_path=model_path, device=device)
        elif name == 'GFPGAN':
            model = gfpgan.GFPGANer(model_path=model_path, upscale=1, device=device)
        elif name in realesrgan_scales:
            model = RealESRGAN(device, scale=realesrgan_scales[name])
            model.load_weights(model_path, download=False)
        else:
            # A registry entry without a loader here would otherwise leave
            # ``model`` unbound.
            raise AssertionError(f"Face enhancer {name} unavailable.")
        return (model, model_runner)

    interpolation_flags = {
        'LANCZOS4': cv2.INTER_LANCZOS4,
        'CUBIC': cv2.INTER_CUBIC,
        'NEAREST': cv2.INTER_NEAREST,
    }
    flag = interpolation_flags[name]

    def model_runner(img, _):
        return cv2.resize(img, (512, 512), interpolation=flag)

    return (None, model_runner)
face_parsing/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .swap import init_parser, swap_regions, mask_regions, mask_regions_to_list
2
+ from .model import BiSeNet
3
+ from .parse_mask import init_parsing_model, get_parsed_mask, SoftErosion
face_parsing/model.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+
10
+ from .resnet import Resnet18
11
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
12
+
13
+
14
class ConvBNReLU(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False
        )
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
36
+
37
class BiSeNetOutput(nn.Module):
    """Output head: a 3x3 ConvBNReLU followed by a bias-free 1x1 conv to ``n_classes`` channels."""

    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        return self.conv_out(self.conv(x))

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Split parameters into (conv/linear weights, biases + batch-norm parameters)."""
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
65
+
66
+
67
class AttentionRefinementModule(nn.Module):
    """Scale a ConvBNReLU feature map by a per-channel sigmoid gate from its global average."""

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        # Pool over the whole spatial extent, then gate the features.
        pooled = F.avg_pool2d(feat, feat.size()[2:])
        gate = self.sigmoid_atten(self.bn_atten(self.conv_atten(pooled)))
        return torch.mul(feat, gate)

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
90
+
91
+
92
class ContextPath(nn.Module):
    """Resnet18 backbone whose 1/16 and 1/32 features are refined and merged top-down."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        feat8, feat16, feat32 = self.resnet(x)
        h8, w8 = feat8.size()[2:]
        h16, w16 = feat16.size()[2:]
        h32, w32 = feat32.size()[2:]

        # Global-average context, broadcast back to the 1/32 resolution.
        pooled = self.conv_avg(F.avg_pool2d(feat32, feat32.size()[2:]))
        avg_up = F.interpolate(pooled, (h32, w32), mode='nearest')

        feat32_sum = self.arm32(feat32) + avg_up
        feat32_up = self.conv_head32(F.interpolate(feat32_sum, (h16, w16), mode='nearest'))

        feat16_sum = self.arm16(feat16) + feat32_up
        feat16_up = self.conv_head16(F.interpolate(feat16_sum, (h8, w8), mode='nearest'))

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Split parameters into (conv/linear weights, biases + batch-norm parameters)."""
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
            elif isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
        return wd_params, nowd_params
143
+
144
+
145
### Not used: the author replaced this path with the resnet feature of the same size
class SpatialPath(nn.Module):
    """Three stride-2 ConvBNReLU stages followed by a 1x1 ConvBNReLU to 128 channels."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv_out):
            x = stage(x)
        return x

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Split parameters into (conv/linear weights, biases + batch-norm parameters)."""
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
178
+
179
+
180
class FeatureFusionModule(nn.Module):
    """Concatenate two feature maps, fuse them with a 1x1 block and add a channel-gated copy."""

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super().__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        squeeze_kwargs = dict(kernel_size=1, stride=1, padding=0, bias=False)
        self.conv1 = nn.Conv2d(out_chan, out_chan // 4, **squeeze_kwargs)
        self.conv2 = nn.Conv2d(out_chan // 4, out_chan, **squeeze_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        feat = self.convblk(torch.cat([fsp, fcp], dim=1))
        # Squeeze (global average) then excite through the two 1x1 convs.
        gate = F.avg_pool2d(feat, feat.size()[2:])
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return torch.mul(feat, gate) + feat

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Split parameters into (conv/linear weights, biases + batch-norm parameters)."""
        wd_params, nowd_params = [], []
        for _, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
228
+
229
+
230
class BiSeNet(nn.Module):
    """Segmentation network: ContextPath features fused and decoded by three output heads."""

    def __init__(self, n_classes, *args, **kwargs):
        super().__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        height, width = x.size()[2:]
        # The res3b1 feature returned first stands in for the spatial path.
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)
        feat_fuse = self.ffm(feat_res8, feat_cp8)

        def upsample(feat):
            return F.interpolate(feat, (height, width), mode='bilinear', align_corners=True)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)
        return upsample(feat_out), upsample(feat_out16), upsample(feat_out32)

    def init_weight(self):
        """Kaiming-initialise (a=1) direct Conv2d children and zero any bias."""
        for layer in self.children():
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, a=1)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def get_params(self):
        """Collect each child's ``get_params()``; fusion/output-head parameters go to the ``lr_mul`` lists."""
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for _, child in self.named_children():
            child_wd, child_nowd = child.get_params()
            if isinstance(child, (FeatureFusionModule, BiSeNetOutput)):
                lr_mul_wd_params += child_wd
                lr_mul_nowd_params += child_nowd
            else:
                wd_params += child_wd
                nowd_params += child_nowd
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273
+
274
+
275
if __name__ == "__main__":
    # Manual smoke test: build a 19-class BiSeNet, push one random batch
    # through it on the GPU (the .cuda() calls need a CUDA device) and print
    # the shape of the main output.
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()
face_parsing/parse_mask.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import torchvision
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as transforms
10
+
11
+ from . model import BiSeNet
12
+
13
class SoftErosion(nn.Module):
    """Soften binary masks by repeated convolution with a radial kernel.

    ``forward`` processes a batch one sample at a time and returns a NumPy
    array with one ``(1, H, W)`` entry per sample.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Create kernel: weight falls off linearly with distance from the
        # centre and the kernel is normalised to sum to 1.
        y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer('weight', kernel)

    def forward(self, x):
        batch_size = x.size(0)
        output = []

        for i in tqdm(range(batch_size), desc="Soft-Erosion", leave=False):
            # One sample, as float, with an explicit channel dimension.
            input_tensor = x[i:i+1].float().unsqueeze(1)

            for _ in range(self.iterations - 1):
                input_tensor = torch.min(input_tensor, F.conv2d(input_tensor, weight=self.weight,
                                                                groups=input_tensor.shape[1],
                                                                padding=self.padding))
            input_tensor = F.conv2d(input_tensor, weight=self.weight, groups=input_tensor.shape[1],
                                    padding=self.padding)

            mask = input_tensor >= self.threshold
            input_tensor[mask] = 1.0

            # Rescale the sub-threshold region so its peak becomes 1. Skip when
            # that region is empty (max() of an empty tensor raises) or all
            # zero (0 / 0 would fill the mask with NaN).
            soft = input_tensor[~mask]
            if soft.numel() > 0:
                peak = soft.max()
                if peak != 0:
                    input_tensor[~mask] = soft / peak

            input_tensor = input_tensor.squeeze(1)  # drop the channel dimension
            output.append(input_tensor.detach().cpu().numpy())

        return np.array(output)
53
+
54
# Preprocessing applied to every PIL image in transform_images(): resize to
# 512x512, convert to a tensor and normalise each channel with the
# mean/std constants below.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
59
+
60
+
61
+
62
def init_parsing_model(model_path, device="cpu"):
    """Create a 19-class BiSeNet on ``device``, load weights from ``model_path`` and switch to eval mode."""
    parser = BiSeNet(19)
    parser.to(device)
    weights = torch.load(model_path, map_location=torch.device('cpu'))
    parser.load_state_dict(weights)
    parser.eval()
    return parser
68
+
69
def transform_images(imgs):
    """Convert BGR images to RGB, apply ``transform`` to each and stack them along dim 0."""
    prepared = []
    for bgr in imgs:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        prepared.append(transform(Image.fromarray(rgb)))
    return torch.stack(prepared, dim=0)
72
+
73
def get_parsed_mask(net, imgs, classes=[1, 2, 3, 4, 5, 10, 11, 12, 13], device="cpu", batch_size=8, softness=20):
    """Yield one float32 mask batch per ``batch_size`` images.

    Each image is parsed by ``net``; pixels whose arg-max class is in
    ``classes`` become 1, the rest 0. When ``softness`` is positive the masks
    are additionally smoothed with ``SoftErosion`` (``softness`` iterations).

    This is a generator: nothing runs until it is iterated. ``classes`` is
    only read, never mutated, so the list default is safe here.
    """
    smooth_mask = None
    if softness > 0:
        smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=softness).to(device)

    # Ceiling division so the progress bar also counts a final partial batch
    # (floor division reported one batch too few).
    total_batches = -(-len(imgs) // batch_size)
    for start in tqdm(range(0, len(imgs), batch_size), total=total_batches, desc="Face-parsing"):
        batch_imgs = imgs[start:start + batch_size]

        tensor_images = transform_images(batch_imgs).to(device)
        with torch.no_grad():
            out = net(tensor_images)[0]

        # np.isin is used instead of torch.isin; the original author measured
        # torch.isin as slightly slower.
        parsing = out.argmax(dim=1).detach().cpu().numpy()
        batch_masks = np.isin(parsing, classes).astype('float32')

        if smooth_mask is not None:
            mask_tensor = torch.from_numpy(batch_masks.copy()).float().to(device)
            # SoftErosion returns (batch, 1, H, W); move the channel axis first
            # and take it to get (batch, H, W).
            batch_masks = smooth_mask(mask_tensor).transpose(1, 0, 2, 3)[0]

        yield batch_masks
face_parsing/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
18
+
19
+
20
class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut.

    The shortcut is projected by a 1x1 conv + batch-norm whenever the channel
    count or the stride changes.
    """

    def __init__(self, in_chan, out_chan, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chan, out_chan, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = nn.Conv2d(out_chan, out_chan, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = in_chan != out_chan or stride != 1
        self.downsample = None
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + residual)
49
+
50
+
51
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    """Chain ``bnum`` BasicBlocks; only the first changes channels and applies ``stride``."""
    first = BasicBlock(in_chan, out_chan, stride=stride)
    rest = [BasicBlock(out_chan, out_chan, stride=1) for _ in range(bnum - 1)]
    return nn.Sequential(first, *rest)
56
+
57
+
58
class Resnet18(nn.Module):
    """ResNet-18 trunk returning the 1/8, 1/16 and 1/32 feature maps.

    Constructing an instance fetches the weights at ``resnet18_url`` through
    ``modelzoo.load_url`` and loads every entry except the ``fc`` ones.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        stem = self.maxpool(F.relu(self.bn1(self.conv1(x))))

        feat8 = self.layer2(self.layer1(stem))  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        """Overlay the downloaded weights (minus ``fc`` entries) on the current state dict."""
        pretrained = modelzoo.load_url(resnet18_url)
        merged = self.state_dict()
        merged.update({k: v for k, v in pretrained.items() if 'fc' not in k})
        self.load_state_dict(merged)

    def get_params(self):
        """Split parameters into (conv/linear weights, biases + batch-norm parameters)."""
        wd_params, nowd_params = [], []
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
            elif isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
        return wd_params, nowd_params
100
+
101
+
102
if __name__ == "__main__":
    # Manual smoke test: build the trunk (this downloads the pretrained
    # weights), run a random batch and print the three feature-map sizes.
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
face_parsing/swap.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from .model import BiSeNet
9
+
10
# Region label -> integer class id (0-18). mask_regions_to_list() translates
# labels into these ids.
# NOTE(review): presumably these ids match the 19 output classes of the
# BiSeNet parser -- confirm against the model's training labels.
mask_regions = {
    "Background":0,
    "Skin":1,
    "L-Eyebrow":2,
    "R-Eyebrow":3,
    "L-Eye":4,
    "R-Eye":5,
    "Eye-G":6,
    "L-Ear":7,
    "R-Ear":8,
    "Ear-R":9,
    "Nose":10,
    "Mouth":11,
    "U-Lip":12,
    "L-Lip":13,
    "Neck":14,
    "Neck-L":15,
    "Cloth":16,
    "Hair":17,
    "Hat":18
}
31
+
32
# Borrowed from simswap
# https://github.com/neuralchen/SimSwap/blob/26c84d2901bd56eda4d5e3c5ca6da16e65dc82a6/util/reverse2original.py#L30
class SoftErosion(nn.Module):
    """Soften a binary mask by repeated convolution with a radial kernel.

    ``forward`` returns ``(soft_mask, hard_mask)`` where ``hard_mask`` marks
    the pixels whose smoothed value reached ``threshold``.
    """

    def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
        super(SoftErosion, self).__init__()
        r = kernel_size // 2
        self.padding = r
        self.iterations = iterations
        self.threshold = threshold

        # Create kernel: weight falls off linearly with distance from the
        # centre and the kernel is normalised to sum to 1.
        y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
        dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
        kernel = dist.max() - dist
        kernel /= kernel.sum()
        kernel = kernel.view(1, 1, *kernel.shape)
        self.register_buffer('weight', kernel)

    def forward(self, x):
        x = x.float()
        for _ in range(self.iterations - 1):
            x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0

        # Rescale the sub-threshold region so its peak becomes 1. Skip when
        # that region is empty (max() of an empty tensor raises) or all zero
        # (0 / 0 would fill the mask with NaN).
        soft = x[~mask]
        if soft.numel() > 0:
            peak = soft.max()
            if peak != 0:
                x[~mask] = soft / peak

        return x, mask
61
+
62
# Module-wide device name; init_parser() overwrites it and image_to_parsing()
# and swap_regions() read it.
device = "cpu"


def init_parser(pth_path, mode="cpu"):
    """Build a 19-class BiSeNet, load weights from ``pth_path`` and return it in eval mode.

    Also stores ``mode`` in the module-level ``device``. Only ``"cuda"``
    moves the network to the GPU; any other value loads the weights mapped
    to the CPU.
    """
    global device
    device = mode

    net = BiSeNet(n_classes=19)
    on_gpu = device == "cuda"
    if on_gpu:
        net.cuda()
        state = torch.load(pth_path)
    else:
        state = torch.load(pth_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net
76
+
77
+
78
def image_to_parsing(img, net):
    """Return the per-pixel arg-max class map of ``net`` for ``img`` resized to 512x512."""
    # Resize, then reverse the channel order (last axis).
    flipped = cv2.resize(img, (512, 512))[:, :, ::-1]
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # .copy() makes the reversed view contiguous before conversion.
    batch = torch.unsqueeze(to_tensor(flipped.copy()), 0)

    with torch.no_grad():
        logits = net(batch.to(device))[0]
        return logits.squeeze(0).cpu().numpy().argmax(0)
93
+
94
+
95
def get_mask(parsing, classes):
    """Return a boolean array marking where ``parsing`` equals any value in ``classes``.

    An empty ``classes`` now yields an all-False mask instead of raising
    ``IndexError``.
    """
    return np.isin(parsing, list(classes))
100
+
101
+
102
def swap_regions(source, target, net, smooth_mask, includes=[1,2,3,4,5,10,11,12,13], blur=10):
    """Blend the ``includes`` regions of ``source`` over ``target``.

    ``source`` is parsed with ``net``; pixels whose class is in ``includes``
    form the mask, optionally softened by ``smooth_mask`` and Gaussian-blurred
    by ``blur``. Blending happens at 512x512 and the result is resized back
    to ``source``'s size.

    ``includes`` is only read, never mutated, so the list default is safe.

    NOTE(review): with an empty ``includes`` this returns the tuple
    ``(source, zeros)`` whereas every other path returns a single image --
    confirm what callers expect before unifying.
    """
    # Decide before parsing: running the network is wasted work when nothing
    # is going to be blended.
    if len(includes) == 0:
        return source, np.zeros_like(source)

    parsing = image_to_parsing(source, net)
    include_mask = get_mask(parsing, includes)
    mask = np.repeat(include_mask[:, :, np.newaxis], 3, axis=2).astype("float32")

    if smooth_mask is not None:
        mask_tensor = torch.from_numpy(mask.copy().transpose((2, 0, 1))).float().to(device)
        face_mask_tensor = mask_tensor[0] + mask_tensor[1]
        soft_face_mask_tensor, _ = smooth_mask(face_mask_tensor.unsqueeze_(0).unsqueeze_(0))
        soft_face_mask_tensor.squeeze_()
        mask = np.repeat(soft_face_mask_tensor.cpu().numpy()[:, :, np.newaxis], 3, axis=2)

    if blur > 0:
        mask = cv2.GaussianBlur(mask, (0, 0), blur)

    resized_source = cv2.resize((source).astype("float32"), (512, 512))
    resized_target = cv2.resize((target).astype("float32"), (512, 512))
    result = mask * resized_source + (1 - mask) * resized_target
    result = cv2.resize(result.astype("uint8"), (source.shape[1], source.shape[0]))

    return result
127
+
128
def mask_regions_to_list(values):
    """Translate region labels into their ``mask_regions`` ids, dropping unknown labels."""
    return [mask_regions[label] for label in values if label in mask_regions]
face_swapper.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import onnx
4
+ import cv2
5
+ import onnxruntime
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import torch.nn as nn
9
+ from onnx import numpy_helper
10
+ from skimage import transform as trans
11
+ import torchvision.transforms.functional as F
12
+ import torch.nn.functional as F
13
+ from utils import mask_crop, laplacian_blending
14
+
15
+
16
# Five reference landmark positions (x, y) used by estimate_norm as the
# destination points of the alignment transform. The coordinates are for a
# 112 x 112 crop; estimate_norm rescales them for other sizes.
arcface_dst = np.array(
    [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
     [41.5493, 92.3655], [70.7299, 92.2041]],
    dtype=np.float32)
20
+
21
def estimate_norm(lmk, image_size=112, mode='arcface'):
    """Estimate the similarity transform aligning landmarks to the template.

    Args:
        lmk: Array of shape (5, 2) with the detected landmark coordinates.
        image_size: Side length of the aligned crop; must be a multiple of
            112 or of 128.
        mode: Accepted for API compatibility; not used by this function.

    Returns:
        numpy.ndarray: The 2 x 3 affine matrix mapping ``lmk`` onto the
        scaled ``arcface_dst`` template.
    """
    assert lmk.shape == (5, 2)
    multiple_of_112 = image_size % 112 == 0
    assert multiple_of_112 or image_size % 128 == 0

    if multiple_of_112:
        ratio, diff_x = float(image_size) / 112.0, 0
    else:
        # 128-based crops shift the template horizontally by 8 px per unit ratio.
        ratio = float(image_size) / 128.0
        diff_x = 8.0 * ratio

    # The multiplication creates a new array, so the module-level template is
    # not modified by the in-place shift below.
    template = arcface_dst * ratio
    template[:, 0] += diff_x

    similarity = trans.SimilarityTransform()
    similarity.estimate(lmk, template)
    return similarity.params[0:2, :]
36
+
37
+
38
def norm_crop2(img, landmark, image_size=112, mode='arcface'):
    """Warp ``img`` into an aligned square crop and return it with its matrix.

    Args:
        img: Source image array.
        landmark: (5, 2) landmark array forwarded to ``estimate_norm``.
        image_size: Side length of the output crop.
        mode: Forwarded to ``estimate_norm``.

    Returns:
        tuple: ``(aligned_crop, affine_matrix)`` where the matrix is the
        2 x 3 transform used for the warp.
    """
    affine = estimate_norm(landmark, image_size, mode)
    output_size = (image_size, image_size)
    aligned = cv2.warpAffine(img, affine, output_size, borderValue=0.0)
    return aligned, affine
42
+
43
+
44
class Inswapper():
    """Wrapper around an ONNX face-swapping model run through onnxruntime.

    The model is fed a ``target`` image blob and a ``source`` latent vector
    and produces an ``output`` image blob.
    """

    def __init__(self, model_file=None, batch_size=32, providers=['CPUExecutionProvider']):
        """Load the model and create the inference session.

        Args:
            model_file: Path to the ONNX model file.
            batch_size: Number of samples grouped per ``batch_forward`` step.
            providers: onnxruntime execution providers (not mutated here).
        """
        self.model_file = model_file
        self.batch_size = batch_size

        # The last graph initializer is used as the projection matrix applied
        # to source embeddings in prepare_data.
        model = onnx.load(self.model_file)
        graph = model.graph
        self.emap = numpy_helper.to_array(graph.initializer[-1])

        self.session_options = onnxruntime.SessionOptions()
        self.session = onnxruntime.InferenceSession(self.model_file, sess_options=self.session_options, providers=providers)

    def forward(self, imgs, latents):
        """Run the model on already-prepared image blobs and latents.

        Args:
            imgs: Iterable of image blobs with values in [0, 255].
            latents: Iterable of latent vectors, paired with ``imgs``.

        Returns:
            list: One raw model output per (img, latent) pair.
        """
        preds = []
        for img, latent in zip(imgs, latents):
            img = img / 255
            pred = self.session.run(['output'], {'target': img, 'source': latent})[0]
            preds.append(pred)
        # The original built this list but forgot to return it.
        return preds

    def get(self, imgs, target_faces, source_faces):
        """Swap faces for a batch of frames.

        Args:
            imgs: Iterable of frames (arrays or file paths).
            target_faces: Face objects (with ``kps``) to be replaced.
            source_faces: Face objects (with ``normed_embedding``) to insert.

        Returns:
            tuple: ``(preds, matrs)`` - lists of uint8 swapped face crops and
            the affine matrices used to align each target face.
        """
        imgs = list(imgs)

        preds = [None] * len(imgs)
        matrs = [None] * len(imgs)

        for idx, (img, target_face, source_face) in enumerate(zip(imgs, target_faces, source_faces)):
            matrix, blob, latent = self.prepare_data(img, target_face, source_face)
            pred = self.session.run(['output'], {'target': blob, 'source': latent})[0]
            # NCHW -> HWC, scale to [0, 255] and reverse the channel order.
            pred = pred.transpose((0, 2, 3, 1))[0]
            pred = np.clip(255 * pred, 0, 255).astype(np.uint8)[:, :, ::-1]

            preds[idx] = pred
            matrs[idx] = matrix

        return (preds, matrs)

    def prepare_data(self, img, target_face, source_face):
        """Build the model inputs for one frame.

        Args:
            img: Frame array, or a path that is read with ``cv2.imread``.
            target_face: Face object exposing ``kps`` landmarks.
            source_face: Face object exposing ``normed_embedding``.

        Returns:
            tuple: ``(matrix, blob, latent)`` - the alignment matrix, the
            128 x 128 input blob and the unit-norm projected embedding.
        """
        if isinstance(img, str):
            img = cv2.imread(img)

        aligned_img, matrix = norm_crop2(img, target_face.kps, 128)

        blob = cv2.dnn.blobFromImage(aligned_img, 1.0 / 255, (128, 128), (0., 0., 0.), swapRB=True)

        latent = source_face.normed_embedding.reshape((1, -1))
        latent = np.dot(latent, self.emap)
        latent /= np.linalg.norm(latent)

        return (matrix, blob, latent)

    def batch_forward(self, img_list, target_f_list, source_f_list):
        """Yield ``get`` results over the inputs in ``batch_size`` chunks.

        Args:
            img_list: Sequence of frames.
            target_f_list: Sequence of target faces, aligned with ``img_list``.
            source_f_list: Sequence of source faces, aligned with ``img_list``.

        Yields:
            tuple: ``(batch_pred, batch_matr)`` for each chunk.
        """
        num_samples = len(img_list)
        # Ceiling division so a trailing partial batch is still processed.
        num_batches = (num_samples + self.batch_size - 1) // self.batch_size

        for i in tqdm(range(num_batches), desc="Generating face"):
            start_idx = i * self.batch_size
            end_idx = min((i + 1) * self.batch_size, num_samples)

            batch_img = img_list[start_idx:end_idx]
            batch_target_f = target_f_list[start_idx:end_idx]
            batch_source_f = source_f_list[start_idx:end_idx]

            batch_pred, batch_matr = self.get(batch_img, batch_target_f, batch_source_f)

            yield batch_pred, batch_matr
109
+
110
+
111
def paste_to_whole(foreground, background, matrix, mask=None, crop_mask=(0,0,0,0), blur_amount=0.1, erode_amount = 0.15, blend_method='linear'):
    """Warp a swapped face crop back into the full frame and blend it in.

    Args:
        foreground: The face crop to paste.
        background: The full frame the crop was taken from.
        matrix: 2 x 3 affine matrix that produced the crop; it is inverted to
            map the crop back onto the frame.
        mask: Optional mask with the same height/width as ``foreground``.
            When ``None`` a full mask of ones is used.
        crop_mask: Crop amounts forwarded to ``mask_crop``.
        blur_amount: Gaussian kernel size as a fraction of the mask size;
            ``0`` disables blurring.
        erode_amount: Erosion kernel size as a fraction of the mask size;
            ``0`` disables erosion.
        blend_method: ``'laplacian'`` for pyramid blending, anything else for
            a linear alpha blend.

    Returns:
        numpy.ndarray: The composited uint8 frame.

    Raises:
        AssertionError: If ``mask`` is given and its size differs from
            ``foreground``.
    """
    inv_matrix = cv2.invertAffineTransform(matrix)
    fg_shape = foreground.shape[:2]
    bg_shape = (background.shape[1], background.shape[0])
    foreground = cv2.warpAffine(foreground, inv_matrix, bg_shape, borderValue=0.0)

    if mask is None:
        mask = np.full(fg_shape, 1., dtype=np.float32)
        mask = mask_crop(mask, crop_mask)
    else:
        assert fg_shape == mask.shape[:2], "foreground & mask shape mismatch!"
        mask = mask_crop(mask, crop_mask).astype('float32')
    mask = cv2.warpAffine(mask, inv_matrix, bg_shape, borderValue=0.0)

    # Estimate the mask's extent so erosion/blur kernels scale with face size.
    _mask = mask.copy()
    _mask[_mask > 0.05] = 1.
    non_zero_points = cv2.findNonZero(_mask)
    if non_zero_points is None:
        # Nothing of the mask landed inside the frame: there is nothing to
        # paste, and boundingRect would fail on the empty point set.
        return np.clip(background, 0, 255).astype("uint8")
    _, _, w, h = cv2.boundingRect(non_zero_points)
    mask_size = int(np.sqrt(w * h))

    if erode_amount > 0:
        kernel_size = max(int(mask_size * erode_amount), 1)
        structuring_element = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
        mask = cv2.erode(mask, structuring_element)

    if blur_amount > 0:
        kernel_size = max(int(mask_size * blur_amount), 3)
        # GaussianBlur requires an odd kernel size.
        if kernel_size % 2 == 0:
            kernel_size += 1
        mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)

    mask = np.tile(np.expand_dims(mask, axis=-1), (1, 1, 3))

    if blend_method == 'laplacian':
        composite_image = laplacian_blending(foreground, background, mask.clip(0,1), num_levels=4)
    else:
        composite_image = mask * foreground + (1 - mask) * background

    # Clip BEFORE casting: clipping a uint8 array to [0, 255] is a no-op and
    # would let out-of-range floats wrap around during the cast.
    return np.clip(composite_image, 0, 255).astype("uint8")
requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+
3
+ gfpgan==1.3.8
4
+ gradio==3.40.1
5
+ insightface==0.7.3
6
+ moviepy>=1.0.3
7
+ numpy==1.24.3
8
+ onnx==1.14.0
9
+ onnxruntime==1.15.1; python_version != '3.9' and sys_platform == 'darwin' and platform_machine != 'arm64'
10
+ onnxruntime-coreml==1.13.1; python_version == '3.9' and sys_platform == 'darwin' and platform_machine != 'arm64'
11
+ onnxruntime-gpu==1.15.1; sys_platform != 'darwin'
12
+ onnxruntime-silicon==1.13.1; sys_platform == 'darwin' and platform_machine == 'arm64'
13
+ opencv-python==4.8.0.74
14
+ opennsfw2==0.10.2
15
+ pillow==10.0.0
16
+ protobuf==4.23.4
17
+ psutil==5.9.5
18
+ realesrgan==0.3.0
19
+ tensorflow==2.13.0
20
+ tqdm==4.65.0
21
+ python-telegram-bot==22.1
22
+
upscaler/RealESRGAN/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import RealESRGAN
upscaler/RealESRGAN/arch_utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize the weights of the given module(s) in place.

    Conv2d and Linear layers get Kaiming-normal weights multiplied by
    ``scale``; batch-norm layers get a constant weight of 1. Any bias that
    exists on those layers is filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Factor applied to the initialized weights, especially
            useful for residual blocks. Default: 1.
        bias_fill (float): The value to fill bias with. Default: 0.
        kwargs (dict): Extra arguments for ``init.kaiming_normal_``.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                # Other layer types are left untouched.
                continue
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
37
+
38
+
39
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` fresh instances of a block sequentially.

    Args:
        basic_block (nn.module): nn.module class for the basic block.
        num_basic_block (int): Number of blocks to create.
        kwarg (dict): Keyword arguments passed to every block constructor.

    Returns:
        nn.Sequential: The stacked blocks.
    """
    blocks = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*blocks)
53
+
54
+
55
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Layout:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Factor applied to the residual branch. Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default initialization;
            otherwise use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled Conv-ReLU-Conv branch output."""
        hidden = self.relu(self.conv1(x))
        residual = self.conv2(hidden)
        return x + residual * self.res_scale
84
+
85
+
86
class Upsample(nn.Sequential):
    """Pixel-shuffle upsampling module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        stages = []
        power_of_two = (scale & (scale - 1)) == 0
        if power_of_two:
            # One x2 conv + pixel-shuffle stage per factor of two.
            doublings = int(math.log(scale, 2))
            for _ in range(doublings):
                stages.extend([
                    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1),
                    nn.PixelShuffle(2),
                ])
        elif scale == 3:
            stages.extend([
                nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1),
                nn.PixelShuffle(3),
            ])
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super().__init__(*stages)
106
+
107
+
108
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), in pixel units; it is
            added to a pixel-coordinate grid before normalization.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Passed to ``F.grid_sample``. Default: True.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()

    # Base sampling grid holding each pixel's own (x, y) coordinate.
    row_idx = torch.arange(0, h).type_as(x)
    col_idx = torch.arange(0, w).type_as(x)
    grid_y, grid_x = torch.meshgrid(row_idx, col_idx)
    base_grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    base_grid.requires_grad = False

    displaced = base_grid + flow

    # grid_sample expects coordinates normalized to [-1, 1].
    norm_x = 2.0 * displaced[:, :, :, 0] / max(w - 1, 1) - 1.0
    norm_y = 2.0 * displaced[:, :, :, 1] / max(h - 1, 1) - 1.0
    sampling_grid = torch.stack((norm_x, norm_y), dim=3)

    # TODO, what if align_corners=False
    return F.grid_sample(
        x,
        sampling_grid,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )
140
+
141
+
142
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow field by ratio or to an explicit shape.

    The flow values are rescaled by the same factors as the spatial size so
    that displacements stay consistent with the new resolution.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): For 'ratio', ``[ratio_h, ratio_w]``
            (< 1.0 downsamples, > 1.0 upsamples). For 'shape',
            ``[out_h, out_w]``.
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()

    if size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    elif size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # Work on a copy so the caller's tensor is not modified in place.
    scaled_flow = flow.clone()
    scaled_flow[:, 0, :, :] *= output_w / flow_w
    scaled_flow[:, 1, :, :] *= output_h / flow_h

    return F.interpolate(
        input=scaled_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners,
    )
178
+
179
+
180
+ # TODO: may write a cpp file
181
def pixel_unshuffle(x, scale):
    """Rearrange spatial blocks into channels (inverse of pixel shuffle).

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); both spatial
            sizes must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature with shape (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    out_h, out_w = hh // scale, hw // scale
    # Split each spatial axis into (coarse, within-block) and move the two
    # within-block axes next to the channel axis.
    blocks = x.view(b, c, out_h, scale, out_w, scale)
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(b, c * (scale**2), out_h, out_w)
upscaler/RealESRGAN/model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.nn import functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import cv2
7
+
8
+ from .rrdbnet_arch import RRDBNet
9
+ from .utils import pad_reflect, split_image_into_overlapping_patches, stich_together, \
10
+ unpad_image
11
+
12
+
13
# Hugging Face Hub location of the pretrained weights for each supported
# upscale factor; RealESRGAN.load_weights looks entries up by self.scale when
# the requested weight file is missing locally.
HF_MODELS = {
    2: dict(
        repo_id='sberbank-ai/Real-ESRGAN',
        filename='RealESRGAN_x2.pth',
    ),
    4: dict(
        repo_id='sberbank-ai/Real-ESRGAN',
        filename='RealESRGAN_x4.pth',
    ),
    8: dict(
        repo_id='sberbank-ai/Real-ESRGAN',
        filename='RealESRGAN_x8.pth',
    ),
}
27
+
28
+
29
class RealESRGAN:
    """RRDBNet-based super-resolution model with patch-wise inference.

    Args:
        device: Device the model is moved to and inference runs on.
        scale (int): Upscale factor used to build the network. Default: 4.
    """

    def __init__(self, device, scale=4):
        self.device = device
        self.scale = scale
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=scale
        )

    def load_weights(self, model_path, download=True):
        """Load a checkpoint into the model, downloading it first if needed.

        Args:
            model_path (str): Path of the weight file.
            download (bool): If True and ``model_path`` does not exist, fetch
                the weights for ``self.scale`` from the Hugging Face Hub.

        Raises:
            AssertionError: If a download is needed for an unsupported scale.
        """
        if not os.path.exists(model_path) and download:
            # NOTE(review): cached_download is deprecated in recent
            # huggingface_hub releases; confirm the pinned version provides it.
            from huggingface_hub import hf_hub_url, cached_download
            assert self.scale in [2,4,8], 'You can download models only with scales: 2, 4, 8'
            config = HF_MODELS[self.scale]
            cache_dir = os.path.dirname(model_path)
            local_filename = os.path.basename(model_path)
            config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
            cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
            print('Weights downloaded to:', os.path.join(cache_dir, local_filename))

        # Deserialize onto the CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts; the model is moved to self.device afterwards.
        loadnet = torch.load(model_path, map_location='cpu')
        if 'params' in loadnet:
            self.model.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            self.model.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            self.model.load_state_dict(loadnet, strict=True)
        self.model.eval()
        self.model.to(self.device)

    @torch.cuda.amp.autocast()
    def predict(self, lr_image, batch_size=4, patches_size=192,
                padding=24, pad_size=15):
        """Upscale an image by processing overlapping patches in batches.

        Args:
            lr_image: Low-resolution image (anything ``np.array`` accepts)
                with three channels.
            batch_size (int): Number of patches per forward pass.
            patches_size (int): Side length of each patch (without overlap).
            padding (int): Overlap added around each patch.
            pad_size (int): Reflective border added around the whole image
                before splitting and removed again after stitching.

        Returns:
            numpy.ndarray: The upscaled uint8 image.
        """
        scale = self.scale
        device = self.device
        lr_image = np.array(lr_image)
        lr_image = pad_reflect(lr_image, pad_size)

        patches, p_shape = split_image_into_overlapping_patches(
            lr_image, patch_size=patches_size, padding_size=padding
        )
        img = torch.FloatTensor(patches/255).permute((0,3,1,2)).to(device).detach()

        with torch.no_grad():
            # Collect per-batch outputs and concatenate once instead of
            # re-copying the growing result tensor on every iteration.
            outputs = [self.model(img[0:batch_size])]
            for i in range(batch_size, img.shape[0], batch_size):
                outputs.append(self.model(img[i:i+batch_size]))
            res = torch.cat(outputs, 0)

        sr_image = res.permute((0,2,3,1)).clamp_(0, 1).cpu()
        np_sr_image = sr_image.numpy()

        padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
        scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
        np_sr_image = stich_together(
            np_sr_image, padded_image_shape=padded_size_scaled,
            target_shape=scaled_image_shape, padding_size=padding * scale
        )
        sr_img = (np_sr_image*255).astype(np.uint8)
        sr_img = unpad_image(sr_img, pad_size*scale)

        return sr_img
upscaler/RealESRGAN/rrdbnet_arch.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from .arch_utils import default_init_weights, make_layer, pixel_unshuffle
6
+
7
+
8
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the building unit of the RRDB block in ESRGAN.

    Each of the first four convolutions sees the block input concatenated
    with all previous convolution outputs; the fifth maps back to
    ``num_feat`` channels.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Start from small weights (scaled by 0.1).
        convs = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
        default_init_weights(convs, 0.1)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the dense-branch output."""
        dense_inputs = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            grown = self.lrelu(conv(dense_inputs))
            # Append the new features so later convs see everything so far.
            dense_inputs = torch.cat((dense_inputs, grown), 1)
        fused = self.conv5(dense_inputs)
        # Empirically, the residual is scaled by 0.2 for better performance.
        return fused * 0.2 + x
39
+
40
+
41
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in the RRDB-Net of ESRGAN.

    Chains three ResidualDenseBlocks and adds the scaled result back to the
    input.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the three dense blocks."""
        features = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            features = dense_block(features)
        # Empirically, the residual is scaled by 0.2 for better performance.
        return features * 0.2 + x
63
+
64
+
65
class RRDBNet(nn.Module):
    """Network built from Residual in Residual Dense Blocks, as used in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The architecture is extended to scale x2 and scale x1: for those scales
    the input is first pixel-unshuffled (the inverse of pixel shuffle), which
    shrinks the spatial size and enlarges the channel count before the main
    ESRGAN trunk. For scale 8 a third x2 upsampling stage is added.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall upscale factor. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.scale = scale
        # Pixel-unshuffle multiplies the channel count by the square of its
        # factor (2 for scale 2, 4 for scale 1).
        if scale == 1:
            num_in_ch = num_in_ch * 16
        elif scale == 2:
            num_in_ch = num_in_ch * 4
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsampling stages
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Upscale ``x`` and return the reconstructed tensor."""
        if self.scale == 1:
            trunk_in = pixel_unshuffle(x, scale=4)
        elif self.scale == 2:
            trunk_in = pixel_unshuffle(x, scale=2)
        else:
            trunk_in = x

        feat = self.conv_first(trunk_in)
        feat = feat + self.conv_body(self.body(feat))

        # Each stage doubles the resolution with nearest-neighbour
        # interpolation followed by a convolution.
        up_convs = [self.conv_up1, self.conv_up2]
        if self.scale == 8:
            up_convs.append(self.conv_up3)
        for up_conv in up_convs:
            enlarged = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(enlarged))

        return self.conv_last(self.lrelu(self.conv_hr(feat)))
upscaler/RealESRGAN/utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ import os
5
+ import io
6
+
7
def pad_reflect(image, pad_size):
    """Pad an H x W x C image on all four sides by mirroring its border.

    The border pixels themselves are repeated (numpy's ``symmetric`` mode),
    which matches the previous hand-written flips, including the corners.
    ``pad_size == 0`` is now supported and returns an unpadded uint8 copy
    (it used to fail on the empty ``-0`` slices).

    Args:
        image (numpy.ndarray): Image with shape (height, width, channels).
        pad_size (int): Number of pixels to add on every side.

    Returns:
        numpy.ndarray: uint8 array of shape
        (height + 2 * pad_size, width + 2 * pad_size, channels).
    """
    border = (pad_size, pad_size)
    padded = np.pad(image, (border, border, (0, 0)), mode='symmetric')
    return padded.astype(np.uint8)
19
+
20
def unpad_image(image, pad_size):
    """Remove ``pad_size`` pixels from every side of an H x W x C image.

    Explicit end indices are used instead of ``-pad_size`` so that
    ``pad_size == 0`` returns the whole image rather than an empty array.

    Args:
        image (numpy.ndarray): Padded image with shape (height, width, channels).
        pad_size (int): Border width to strip from each side.

    Returns:
        numpy.ndarray: View of the image without its border.
    """
    row_end = image.shape[0] - pad_size
    col_end = image.shape[1] - pad_size
    return image[pad_size:row_end, pad_size:col_end, :]
22
+
23
+
24
def process_array(image_array, expand=True):
    """Scale an image array from [0, 255] to [0, 1], optionally adding a batch axis.

    Args:
        image_array: Image array (typically 3-dimensional).
        expand (bool): If True, prepend a batch dimension of size 1.

    Returns:
        The scaled array, 4-dimensional when ``expand`` is True.
    """
    scaled = image_array / 255.0
    if not expand:
        return scaled
    return np.expand_dims(scaled, axis=0)
31
+
32
+
33
def process_output(output_tensor):
    """Convert a model output in [0, 1] into a uint8 image array.

    Values outside [0, 1] are clipped before scaling to [0, 255].

    Args:
        output_tensor: Array-like with a ``clip`` method.

    Returns:
        The uint8 image data.
    """
    clipped = output_tensor.clip(0, 1)
    return np.uint8(clipped * 255)
39
+
40
+
41
def pad_patch(image_patch, padding_size, channel_last=True):
    """Pad the two spatial axes of a patch by repeating its edge values.

    Args:
        image_patch (numpy.ndarray): 3-dimensional patch.
        padding_size (int): Number of pixels added on each spatial side.
        channel_last (bool): True for (H, W, C) layout, False for (C, H, W).

    Returns:
        numpy.ndarray: The padded patch; the channel axis is left unchanged.
    """
    spatial = (padding_size, padding_size)
    channel = (0, 0)
    if channel_last:
        pad_widths = (spatial, spatial, channel)
    else:
        pad_widths = (channel, spatial, spatial)
    return np.pad(image_patch, pad_widths, 'edge')
56
+
57
+
58
def unpad_patches(image_patches, padding_size):
    """Strip ``padding_size`` pixels from every spatial side of each patch.

    Explicit end indices are used instead of ``-padding_size`` so that
    ``padding_size == 0`` returns the patches unchanged rather than empty.

    Args:
        image_patches (numpy.ndarray): Batch of patches with shape
            (n, height, width, channels).
        padding_size (int): Border width to remove from each side.

    Returns:
        numpy.ndarray: View of the patches without their borders.
    """
    row_end = image_patches.shape[1] - padding_size
    col_end = image_patches.shape[2] - padding_size
    return image_patches[:, padding_size:row_end, padding_size:col_end, :]
60
+
61
+
62
def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
    """Split an image into partially overlapping square patches.

    Neighbouring patches overlap by ``padding_size`` pixels on each side.
    The image is padded twice with edge values:
        - first (bottom/right only) so both sizes are multiples of
          ``patch_size``,
        - then on all sides by ``padding_size`` so every patch has a full
          border.

    Args:
        image_array: numpy array of the input image, shape (H, W, C).
        patch_size: size of the patches from the original image (without padding).
        padding_size: size of the overlapping area.

    Returns:
        tuple: ``(patches, padded_shape)`` - an array of patches, each of
        side ``patch_size + 2 * padding_size``, in row-major order, and the
        shape of the fully padded image.
    """
    n_rows, n_cols, _ = image_array.shape

    # The outer modulo keeps the extension at 0 (not patch_size) when the
    # size is already a multiple of patch_size.
    extra_rows = (patch_size - n_rows % patch_size) % patch_size
    extra_cols = (patch_size - n_cols % patch_size) % patch_size

    extended_image = np.pad(image_array, ((0, extra_rows), (0, extra_cols), (0, 0)), 'edge')
    padded_image = pad_patch(extended_image, padding_size, channel_last=True)

    padded_rows, padded_cols, _ = padded_image.shape
    window = patch_size + 2 * padding_size

    # Window origins step by patch_size over the padded image, so each window
    # covers one patch plus its border on every side.
    row_starts = range(0, padded_rows - 2 * padding_size, patch_size)
    col_starts = range(0, padded_cols - 2 * padding_size, patch_size)
    patches = [
        padded_image[top:top + window, left:left + window, :]
        for top in row_starts
        for left in col_starts
    ]

    return np.array(patches), padded_image.shape
104
+
105
+
106
def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
    """Reassemble an image from overlapping patches.

    After upscaling, the shapes and the padding passed in must be scaled too.

    Args:
        patches: patches obtained with split_image_into_overlapping_patches
            (row-major order).
        padded_image_shape: shape of the padded image constructed in
            split_image_into_overlapping_patches.
        target_shape: shape of the final image.
        padding_size: size of the overlapping area.

    Returns:
        numpy.ndarray: The stitched image cropped to the first two entries of
        ``target_shape``, with three channels.
    """
    canvas_rows, canvas_cols, _ = padded_image_shape
    cores = unpad_patches(patches, padding_size)
    side = cores.shape[1]
    per_row = canvas_cols // side

    canvas = np.zeros((canvas_rows, canvas_cols, 3))

    for index in range(len(cores)):
        # Patches are laid out row by row, per_row patches in each row.
        row = index // per_row
        col = index % per_row
        top, left = row * side, col * side
        canvas[top:top + side, left:left + side, :] = cores[index]

    return canvas[0: target_shape[0], 0: target_shape[1], :]
upscaler/__init__.py ADDED
File without changes
upscaler/codeformer.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import onnx
4
+ import onnxruntime
5
+ import numpy as np
6
+
7
+ import time
8
+
9
+ # codeformer converted to onnx
10
+ # using https://github.com/redthing1/CodeFormer
11
+
12
+
13
class CodeFormerEnhancer:
    """Face enhancer running a CodeFormer ONNX model through onnxruntime.

    Args:
        model_path (str): Path to the ONNX model file.
        device (str): ``'cuda'`` to prefer the CUDA execution provider (with
            CPU fallback); anything else runs on the CPU provider only.
    """

    def __init__(self, model_path="codeformer.onnx", device='cpu'):
        # The model is loaded only once, by the inference session; the former
        # extra onnx.load() result was never used and just cost time/memory.
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        providers = ["CPUExecutionProvider"]
        if device == 'cuda':
            providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),"CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers)

    def enhance(self, img, w=0.9):
        """Enhance a face image.

        Args:
            img: H x W x 3 image array with values in [0, 255]; it is resized
                to 512 x 512 and its channel order is reversed before
                inference.
            w (float): Value fed to the model's ``w`` input.

        Returns:
            numpy.ndarray: 512 x 512 x 3 uint8 image in the input's channel
            order.
        """
        img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32)[:,:,::-1] / 255.0
        img = img.transpose((2, 0, 1))
        # Normalize each channel from [0, 1] to [-1, 1].
        nrm_mean = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
        nrm_std = np.array([0.5, 0.5, 0.5]).reshape((-1, 1, 1))
        img = (img - nrm_mean) / nrm_std

        img = np.expand_dims(img, axis=0)

        out = self.session.run(None, {'x':img.astype(np.float32), 'w':np.array([w], dtype=np.double)})[0]
        # Map the output from [-1, 1] back to [0, 255] and restore channel order.
        out = (out[0].transpose(1,2,0).clip(-1,1) + 1) * 0.5
        out = (out * 255)[:,:,::-1]

        return out.astype('uint8')
utils.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import shutil
6
+ import platform
7
+ import datetime
8
+ import subprocess
9
+ import numpy as np
10
+ from threading import Thread
11
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
12
+ from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
13
+
14
+
15
# Logo image, loaded once at import time. IMREAD_UNCHANGED keeps the file's
# channels as stored (including an alpha channel when present). Note that
# cv2.imread returns None instead of raising when the file cannot be read.
logo_image = cv2.imread("./assets/images/logo.png", cv2.IMREAD_UNCHANGED)


# Quality preset names; these are the inner keys of the two tables below.
quality_types = ["poor", "low", "medium", "high", "best"]


# Bitrate strings per quality preset, keyed by a resolution value (the
# callers in this file pass the smaller frame dimension).
bitrate_quality_by_resolution = {
    240: {"poor": "300k", "low": "500k", "medium": "800k", "high": "1000k", "best": "1200k"},
    360: {"poor": "500k","low": "800k","medium": "1200k","high": "1500k","best": "2000k"},
    480: {"poor": "800k","low": "1200k","medium": "2000k","high": "2500k","best": "3000k"},
    720: {"poor": "1500k","low": "2500k","medium": "4000k","high": "5000k","best": "6000k"},
    1080: {"poor": "2500k","low": "4000k","medium": "6000k","high": "7000k","best": "8000k"},
    1440: {"poor": "4000k","low": "6000k","medium": "8000k","high": "10000k","best": "12000k"},
    2160: {"poor": "8000k","low": "10000k","medium": "12000k","high": "15000k","best": "20000k"}
}


# CRF values per quality preset, keyed like the bitrate table. Within each
# row the value decreases from "poor" to "best" (presumably an encoder CRF,
# where lower means higher quality - confirm against the encoder call).
crf_quality_by_resolution = {
    240: {"poor": 45, "low": 35, "medium": 28, "high": 23, "best": 20},
    360: {"poor": 35, "low": 28, "medium": 23, "high": 20, "best": 18},
    480: {"poor": 28, "low": 23, "medium": 20, "high": 18, "best": 16},
    720: {"poor": 23, "low": 20, "medium": 18, "high": 16, "best": 14},
    1080: {"poor": 20, "low": 18, "medium": 16, "high": 14, "best": 12},
    1440: {"poor": 18, "low": 16, "medium": 14, "high": 12, "best": 10},
    2160: {"poor": 16, "low": 14, "medium": 12, "high": 10, "best": 8}
}
41
+
42
+
43
def get_bitrate_for_resolution(resolution, quality):
    """Look up the bitrate string for the table row nearest to *resolution*.

    Args:
        resolution: Numeric resolution to match against the table keys.
        quality: Preset name, one of the inner keys of
            ``bitrate_quality_by_resolution``.

    Returns:
        The bitrate string (e.g. ``"4000k"``) of the closest table row.
        On a tie the row listed first in the table wins.

    Raises:
        KeyError: If *quality* is not a known preset.
    """
    nearest = min(
        bitrate_quality_by_resolution,
        key=lambda candidate: abs(candidate - resolution),
    )
    presets = bitrate_quality_by_resolution[nearest]
    return presets[quality]
47
+
48
+
49
def get_crf_for_resolution(resolution, quality):
    """Look up the CRF value for the table row nearest to *resolution*.

    Args:
        resolution: Numeric resolution to match against the table keys.
        quality: Preset name, one of the inner keys of
            ``crf_quality_by_resolution``.

    Returns:
        The integer CRF of the closest table row. On a tie the row listed
        first in the table wins.

    Raises:
        KeyError: If *quality* is not a known preset.
    """
    nearest = min(
        crf_quality_by_resolution,
        key=lambda candidate: abs(candidate - resolution),
    )
    presets = crf_quality_by_resolution[nearest]
    return presets[quality]
53
+
54
+
55
def get_video_bitrate(video_file):
    """Return the bitrate of *video_file* as a string such as ``"2500k"``.

    The first video stream is queried with ``ffprobe``. Some containers do
    not store a per-stream bitrate (ffprobe then prints ``N/A`` or nothing),
    so the container-level bitrate is used as a fallback; that figure covers
    all streams and may therefore be slightly higher.

    Args:
        video_file: Path of the media file to inspect.

    Returns:
        The bitrate in kbit/s followed by ``'k'``, never lower than ``"10k"``.

    Raises:
        ValueError: If ffprobe reports no usable bitrate at all.
        FileNotFoundError: If the ``ffprobe`` executable is not available.
    """

    def _probe_bit_rate(selector_args, entry):
        # Run ffprobe and return the first whitespace-separated token it
        # printed ("" when it printed nothing).
        ffprobe_cmd = ['ffprobe', '-v', 'error', *selector_args, '-show_entries', entry,
                       '-of', 'default=noprint_wrappers=1:nokey=1', video_file]
        result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE)
        tokens = result.stdout.decode(errors="replace").split()
        return tokens[0] if tokens else ""

    queries = (
        (['-select_streams', 'v:0'], 'stream=bit_rate'),
        ([], 'format=bit_rate'),
    )
    for selector_args, entry in queries:
        raw = _probe_bit_rate(selector_args, entry)
        if raw.isdigit():
            kbps = max(int(raw) // 1000, 10)
            return str(kbps) + 'k'

    raise ValueError(f"ffprobe reported no bitrate for {video_file!r}")
61
+
62
+
63
def trim_video(video_path, output_path, start_frame, stop_frame):
    """Cut a frame range out of a video and write it as an MP4 file.

    The result is written to ``<output_path>/trim/<name>_trimmed.mp4``
    (the ``trim`` directory is created when missing), encoded with libx264
    video and AAC audio at the "high" bitrate preset for the clip's
    smaller dimension.

    Args:
        video_path: Path of the source video.
        output_path: Directory under which the ``trim`` folder is placed.
        start_frame: Index of the first frame to keep.
        stop_frame: Index of the frame at which the cut ends.

    Returns:
        The path of the trimmed video file.
    """
    video_name, _ = os.path.splitext(os.path.basename(video_path))
    trimmed_video_filename = video_name + "_trimmed" + ".mp4"
    temp_path = os.path.join(output_path, "trim")
    os.makedirs(temp_path, exist_ok=True)
    trimmed_video_file_path = os.path.join(temp_path, trimmed_video_filename)

    video = VideoFileClip(video_path, fps_source="fps")
    # try/finally so the clips are closed even when encoding fails.
    try:
        # Frame indices are converted to seconds using the clip's fps.
        fps = video.fps
        start_time = start_frame / fps
        duration = (stop_frame - start_frame) / fps

        bitrate = get_bitrate_for_resolution(min(*video.size), "high")

        trimmed_video = video.subclip(start_time, start_time + duration)
        try:
            trimmed_video.write_videofile(
                trimmed_video_file_path, codec="libx264", audio_codec="aac", bitrate=bitrate,
            )
        finally:
            trimmed_video.close()
    finally:
        video.close()

    return trimmed_video_file_path
85
+
86
+
87
def open_directory(path=None):
    """Open *path* with the operating system's default handler.

    Windows uses ``os.startfile``, macOS the ``open`` command and every
    other platform ``xdg-open``. Nothing happens when *path* is ``None``.

    Args:
        path: File or directory to open, or ``None`` for a no-op.
    """
    if path is None:
        return

    # Pick the opener explicitly instead of catching every exception from
    # os.startfile: that hid real errors and sent macOS to a non-existent
    # xdg-open.
    system = platform.system()
    if system == "Windows":
        os.startfile(path)
    elif system == "Darwin":
        subprocess.Popen(["open", path])
    else:
        subprocess.Popen(["xdg-open", path])
94
+
95
+
96
class StreamerThread(object):
    """Read frames from a ``cv2.VideoCapture`` on a background thread.

    The most recent read result is exposed through ``status`` and ``frame``.
    ``stop()`` ends the reader thread; it does not release ``capture``.
    """

    def __init__(self, src=0):
        """Open the capture source.

        Args:
            src: Source passed to ``cv2.VideoCapture``.
        """
        self.capture = cv2.VideoCapture(src)
        self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        # Despite the name this is the delay between reads in seconds.
        self.FPS = 1 / 30
        self.FPS_MS = int(self.FPS * 1000)
        self.thread = None
        self.stopped = False
        # Defined up front so both attributes are readable before the first
        # frame has been read.
        self.status = False
        self.frame = None

    def start(self):
        """Start the daemon thread that runs ``update``."""
        self.thread = Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Ask the reader loop to finish and wait for the thread to exit."""
        self.stopped = True
        # The thread only exists after start(); stop() without start() must
        # not fail.
        if self.thread is not None:
            self.thread.join()
        print("stopped")

    def update(self):
        """Reader loop: fetch a frame, then sleep, until stopped."""
        while not self.stopped:
            if self.capture.isOpened():
                (self.status, self.frame) = self.capture.read()
            time.sleep(self.FPS)
121
+
122
+
123
class ProcessBar:
    """Text progress bar with an estimate of the remaining time."""

    def __init__(self, bar_length, total, before="⬛", after="🟨"):
        """Set up an empty bar and start the clock.

        Args:
            bar_length: Number of cells in the bar.
            total: Total number of iterations being tracked.
            before: Symbol for a cell that has not been reached yet.
            after: Symbol for a completed cell.
        """
        self.bar_length = bar_length
        self.total = total
        self.before = before
        self.after = after
        self.bar = [self.before] * bar_length
        self.start_time = time.time()

    def get(self, index):
        """Return the progress line for the zero-based iteration *index*.

        Args:
            index: Zero-based index of the iteration that just finished.

        Returns:
            Text of the form ``"(done/total) <bar> (ETR: M min S sec)"``.
        """
        total = self.total
        done = index + 1
        elapsed_time = time.time() - self.start_time
        average_time_per_iteration = elapsed_time / done
        # Never report a negative remaining time when index runs past total.
        remaining_iterations = max(total - done, 0)
        estimated_remaining_time = remaining_iterations * average_time_per_iteration

        if self.bar_length > 0:
            # A non-positive total counts as finished instead of dividing by
            # zero.
            fraction = index / total if total > 0 else 1.0
            # Clamp so that index >= total cannot address a cell past the end.
            last_cell = min(max(int(fraction * self.bar_length), 0), self.bar_length - 1)
            # Fill every cell up to the current one; marking only the current
            # cell leaves gaps when total is smaller than bar_length.
            for cell in range(last_cell + 1):
                self.bar[cell] = self.after

        info_text = f"({index+1}/{total}) {''.join(self.bar)} "
        info_text += f"(ETR: {int(estimated_remaining_time // 60)} min {int(estimated_remaining_time % 60)} sec)"
        return info_text
143
+
144
+
145
def add_logo_to_image(img, logo=logo_image):
    """Alpha-blend a logo into the bottom-right corner of *img*, in place.

    The logo is resized to a square whose side is 10% of the image width and
    placed with a margin of 10% of the logo size from the right and bottom
    edges. When the logo has four channels the fourth is used as alpha;
    otherwise the logo is treated as fully opaque.

    Args:
        img: Image array with at least three channels; modified in place.
        logo: Logo array with three or four channels. Defaults to the
            module-level logo.

    Returns:
        *img* (the same array). It is returned unchanged when no logo is
        available or the image is too small to hold the logo.
    """
    logo_size = int(img.shape[1] * 0.1)
    # cv2.imread yields None for an unreadable file, and a zero-sized target
    # is not a valid resize; in both cases there is nothing to draw.
    if logo is None or logo_size < 1:
        return img

    padding = int(logo_size * 0.1)
    top = img.shape[0] - logo_size - padding
    left = img.shape[1] - logo_size - padding
    # A frame shorter than logo plus margin would give a negative offset.
    if top < 0:
        return img

    logo = cv2.resize(logo, (logo_size, logo_size))
    if logo.shape[2] == 4:
        alpha = logo[:, :, 3]
    else:
        alpha = np.ones_like(logo[:, :, 0]) * 255

    # Blend all three colour channels at once; the trailing axis lets the
    # per-pixel weight broadcast over them.
    weight = (alpha / 255.0)[:, :, None]
    region = img[top:top + logo_size, left:left + logo_size, :3]
    img[top:top + logo_size, left:left + logo_size, :3] = (
        weight * logo[:, :, :3] + (1 - weight) * region
    )
    return img
161
+
162
+
163
def split_list_by_lengths(data, length_list):
    """Cut *data* into consecutive slices of the given lengths.

    Args:
        data: Any sliceable sequence.
        length_list: Lengths of the successive slices, in order.

    Returns:
        A list with one slice of *data* per entry in *length_list*. Slices
        that run past the end of *data* are shortened by normal slicing.
    """
    chunks = []
    offset = 0
    for size in length_list:
        chunks.append(data[offset:offset + size])
        offset += size
    return chunks
172
+
173
+
174
def merge_img_sequence_from_ref(ref_video_path, image_sequence, output_file_name):
    """Encode an image sequence as a video that mirrors a reference video.

    The frame rate, duration and (when present) audio track are taken from
    the reference video. The output is encoded with libx264 at the "high"
    bitrate preset for the smaller dimension of the image sequence.

    Args:
        ref_video_path: Path of the reference video.
        image_sequence: Frames in the form accepted by ``ImageSequenceClip``.
        output_file_name: Path of the video file to write.
    """
    video_clip = VideoFileClip(ref_video_path, fps_source="fps")
    edited_video_clip = None
    # try/finally so both clips are closed even when encoding fails.
    try:
        edited_video_clip = ImageSequenceClip(image_sequence, fps=video_clip.fps)

        audio_clip = video_clip.audio
        if audio_clip is not None:
            edited_video_clip = edited_video_clip.set_audio(audio_clip)

        bitrate = get_bitrate_for_resolution(min(*edited_video_clip.size), "high")

        # Force the reference duration on the clip that gets written.
        edited_video_clip.set_duration(video_clip.duration).write_videofile(
            output_file_name, codec="libx264", bitrate=bitrate,
        )
    finally:
        if edited_video_clip is not None:
            edited_video_clip.close()
        video_clip.close()
192
+
193
+
194
def scale_bbox_from_center(bbox, scale_width, scale_height, image_width, image_height):
    """Scale a bounding box about its centre and clamp it to the image.

    Args:
        bbox: Box as ``(x1, y1, x2, y2)``.
        scale_width: Factor applied to the box width.
        scale_height: Factor applied to the box height.
        image_width: Image width; ``x2`` is limited to ``image_width - 1``.
        image_height: Image height; ``y2`` is limited to ``image_height - 1``.

    Returns:
        ``[x1, y1, x2, y2]`` of the scaled box, with ``x1``/``y1`` not below
        0 and ``x2``/``y2`` not beyond the last pixel index.
    """
    left, top, right, bottom = bbox

    # The centre stays fixed while the extents grow or shrink around it.
    mid_x = (left + right) / 2
    mid_y = (top + bottom) / 2
    half_w = (right - left) * scale_width / 2
    half_h = (bottom - top) * scale_height / 2

    return [
        max(0, mid_x - half_w),
        max(0, mid_y - half_h),
        min(image_width - 1, mid_x + half_w),
        min(image_height - 1, mid_y + half_h),
    ]
223
+
224
+
225
def laplacian_blending(A, B, m, num_levels=7):
    """Blend two images through Laplacian pyramids weighted by a mask.

    The inputs are zero-padded onto a square float32 canvas whose side is
    the smallest listed power of two strictly greater than the larger image
    dimension. Each pyramid level is combined as ``A * m + B * (1 - m)``
    and the levels are collapsed back into a single image.

    Args:
        A: First image; must have the same shape as *B* and *m* and fit a
            three-channel canvas.
        B: Second image.
        m: Blend mask; where it is 1 the result takes *A*, where it is 0 it
            takes *B* (presumably in [0, 1] - confirm against callers).
        num_levels: Number of pyramid levels used in the blend.

    Returns:
        The blended float32 image, cropped to the input height and width
        and clipped to [0, 255].
    """
    assert A.shape == B.shape
    assert B.shape == m.shape
    height, width = m.shape[0], m.shape[1]

    size_list = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192])
    size = size_list[np.where(size_list > max(height, width))][0]

    def _pad_to_square(src):
        # Place src in the top-left corner of a zeroed square canvas.
        canvas = np.zeros((size, size, 3), dtype=np.float32)
        canvas[:height, :width, :] = src
        return canvas

    def _gaussian_pyramid(base):
        # base followed by num_levels successively downsampled copies.
        levels = [base]
        current = base
        for _ in range(num_levels):
            current = cv2.pyrDown(current)
            levels.append(np.float32(current))
        return levels

    gpA = _gaussian_pyramid(_pad_to_square(A))
    gpB = _gaussian_pyramid(_pad_to_square(B))
    gpM = _gaussian_pyramid(_pad_to_square(m))

    # Laplacian pyramids ordered coarse to fine, starting from Gaussian level
    # num_levels - 1 (the deepest level built above is not consumed here).
    coarsest = num_levels - 1
    finer_levels = range(coarsest, 0, -1)
    lpA = [gpA[coarsest]] + [
        np.subtract(gpA[i - 1], cv2.pyrUp(gpA[i])) for i in finer_levels
    ]
    lpB = [gpB[coarsest]] + [
        np.subtract(gpB[i - 1], cv2.pyrUp(gpB[i])) for i in finer_levels
    ]
    gpMr = [gpM[coarsest]] + [gpM[i - 1] for i in finer_levels]

    # Mix the two pyramids level by level with the matching mask level.
    LS = [la * gm + lb * (1.0 - gm) for la, lb, gm in zip(lpA, lpB, gpMr)]

    # Collapse: upsample the running result and add the next finer band.
    blended = LS[0]
    for band in LS[1:]:
        blended = cv2.add(cv2.pyrUp(blended), band)

    return blended[:height, :width, :].clip(0, 255)
268
+
269
+
270
def mask_crop(mask, crop):
    """Zero out border strips of *mask*, in place.

    Args:
        mask: Array of at least two dimensions; modified in place.
        crop: ``(top, bottom, left, right)`` strip widths in pixels. Values
            are truncated to ``int``; non-positive values leave that side
            untouched.

    Returns:
        *mask* (the same array). A pair of opposite strips is applied only
        when together they are smaller than the mask along that axis, so the
        mask is never blanked completely.
    """
    top, bottom, left, right = (int(value) for value in crop)
    height, width = mask.shape[0], mask.shape[1]

    # Top and bottom strips remove rows, so they are limited by the height.
    if top + bottom < height:
        if top > 0:
            mask[:top, :] = 0
        if bottom > 0:
            mask[-bottom:, :] = 0

    # Left and right strips remove columns, so they are limited by the width.
    if left + right < width:
        if left > 0:
            mask[:, :left] = 0
        if right > 0:
            mask[:, -right:] = 0

    return mask
286
+
287
def create_image_grid(images, size=128):
    """Arrange images as square tiles in a near-square grid.

    Each image is resized to ``size`` x ``size``. Tiles that are not
    ``uint8`` are multiplied by 255 (i.e. assumed to be in [0, 1]) and
    two-dimensional tiles are expanded to three channels. Tiles are placed
    row by row; unused cells stay black.

    Args:
        images: Sequence of image arrays.
        size: Side length of one tile in pixels.

    Returns:
        A ``uint8`` array of shape ``(rows * size, cols * size, 3)``, or an
        empty ``(0, 0, 3)`` array when *images* is empty.
    """
    num_images = len(images)
    # Without this guard the row count below divides by zero.
    if num_images == 0:
        return np.zeros((0, 0, 3), dtype=np.uint8)

    num_cols = int(np.ceil(np.sqrt(num_images)))
    num_rows = int(np.ceil(num_images / num_cols))
    grid = np.zeros((num_rows * size, num_cols * size, 3), dtype=np.uint8)

    for i, image in enumerate(images):
        row_idx = (i // num_cols) * size
        col_idx = (i % num_cols) * size
        image = cv2.resize(image.copy(), (size, size))
        if image.dtype != np.uint8:
            # Clip before the cast so out-of-range values saturate instead of
            # wrapping around.
            image = (image.astype('float32') * 255).clip(0, 255).astype('uint8')
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        grid[row_idx:row_idx + size, col_idx:col_idx + size] = image

    return grid