dxcanh committed on
Commit
6e647b1
Β·
verified Β·
1 Parent(s): 9dadf50

Upload 5 files

Browse files
scripts/crop_align_face.py CHANGED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
author: lzhbrian (https://lzhbrian.me)
link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
date: 2020.1.5
note: code is heavily borrowed from
    https://github.com/NVlabs/ffhq-dataset
    http://dlib.net/face_landmark_detection.py.html
requirements:
    conda install Pillow numpy scipy
    conda install -c conda-forge dlib
    # download face landmark model from:
    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
"""

import argparse
import glob
import os

import numpy as np
import PIL
import PIL.Image
import scipy
import scipy.ndimage

from basicsr.utils.download_util import load_file_from_url

try:
    import dlib
except ImportError:
    print('Please install dlib by running:' 'conda install -c conda-forge dlib')

# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
shape_predictor_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
ckpt_path = load_file_from_url(url=shape_predictor_url,
                               model_dir='weights/dlib', progress=True, file_name=None)
# fix: use the path returned by the downloader instead of a hard-coded path,
# so the predictor always points at the file that was actually fetched.
predictor = dlib.shape_predictor(ckpt_path)
36
+
37
+
38
def get_landmark(filepath, only_keep_largest=True):
    """Detect face(s) in an image and return 68 dlib landmarks.

    Args:
        filepath (str): path to the input image.
        only_keep_largest (bool): if several faces are detected, keep only the
            one with the largest bounding-box area; otherwise the landmarks of
            the *last* detection win (matching the original behavior).

    Returns:
        np.ndarray: landmark coordinates, shape (68, 2).

    Raises:
        RuntimeError: if no face is detected. (Fix: the original code failed
            here with an unrelated ValueError or NameError depending on the
            branch; callers that wrap this in try/except still work.)
    """
    detector = dlib.get_frontal_face_detector()

    img = dlib.load_rgb_image(filepath)
    dets = detector(img, 1)

    # Shangchen modified
    print("\tNumber of faces detected: {}".format(len(dets)))
    if not dets:
        raise RuntimeError('No face detected in {}'.format(filepath))

    if only_keep_largest:
        print('\tOnly keep the largest.')
        # pick the detection with the largest bounding-box area
        d = max(dets, key=lambda det: (det.right() - det.left()) * (det.bottom() - det.top()))
        shape = predictor(img, d)
    else:
        # original behavior: run the predictor on every detection; the
        # landmarks of the last one are kept
        for d in dets:
            shape = predictor(img, d)

    # convert dlib points to a (68, 2) numpy array
    lm = np.array([[p.x, p.y] for p in shape.parts()])
    return lm
77
+
78
def align_face(filepath, out_path):
    """Crop and align one face with the FFHQ alignment recipe and save it.

    :param filepath: str, path to the input image
    :param out_path: str, path where the aligned 512x512 crop is saved
    :return: (PIL.Image, float) the aligned image and the horizontal extent
        of the alignment quad, or None when no landmark is found
    """
    try:
        lm = get_landmark(filepath)
    except:
        # NOTE(review): bare except silently swallows *any* failure
        # (including bugs) — consider narrowing to the expected exception.
        print('No landmark ...')
        return

    # slice the 68 landmarks into named facial regions
    lm_chin = lm[0:17]  # left-right
    lm_eyebrow_left = lm[17:22]  # left-right
    lm_eyebrow_right = lm[22:27]  # left-right
    lm_nose = lm[27:31]  # top-down
    lm_nostrils = lm[31:36]  # top-down
    lm_eye_left = lm[36:42]  # left-clockwise
    lm_eye_right = lm[42:48]  # left-clockwise
    lm_mouth_outer = lm[48:60]  # left-clockwise
    lm_mouth_inner = lm[60:68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle.
    # x/y are the (scaled) axes of the quad; flipud * [-1, 1] rotates a
    # 2-vector by 90 degrees, so the rectangle follows the face orientation.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    # quad center, nudged slightly from the eyes toward the mouth
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # read image
    img = PIL.Image.open(filepath)

    output_size = 512
    transform_size = 4096
    enable_padding = False  # padding branch below is disabled by default

    # Shrink: pre-downscale very large inputs so later steps stay cheap.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)),
                 int(np.rint(float(img.size[1]) / shrink)))
        # NOTE(review): PIL.Image.ANTIALIAS was removed in Pillow 10
        # (use Image.LANCZOS there) — confirm the pinned Pillow version.
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink

    # Crop: cut the image down to the quad's bounding box plus a border.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
            min(crop[2] + border,
                img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]  # keep the quad in the cropped coordinate frame

    # Pad: reflect-pad (with blurred blending) when the quad leaves the image.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
           int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border,
               0), max(-pad[1] + border,
                       0), max(pad[2] - img.size[0] + border,
                               0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(
            np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
            'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        # mask -> 1 near the padded edges, 0 in the interior
        mask = np.maximum(
            1.0 -
            np.minimum(np.float32(x) / pad[0],
                       np.float32(w - 1 - x) / pad[2]), 1.0 -
            np.minimum(np.float32(y) / pad[1],
                       np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        # blend blurred content and the median color into the padded region
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
                img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(
            np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Warp the quad to an axis-aligned square, then downscale to output_size.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
                        (quad + 0.5).flatten(), PIL.Image.BILINEAR)

    if output_size < transform_size:
        # NOTE(review): same ANTIALIAS deprecation as above.
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

    # Save aligned image.
    # print('saveing: ', out_path)
    img.save(out_path)

    return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
184
+
185
+
186
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--in_dir', type=str, default='./inputs/whole_imgs')
    parser.add_argument('-o', '--out_dir', type=str, default='./inputs/cropped_faces')
    args = parser.parse_args()

    if args.out_dir.endswith('/'):  # solve when path ends with /
        args.out_dir = args.out_dir[:-1]
    dir_name = os.path.abspath(args.out_dir)
    os.makedirs(dir_name, exist_ok=True)

    # match .jpg/.jpeg/.png files in any letter case
    img_list = sorted(glob.glob(os.path.join(args.in_dir, '*.[jpJP][pnPN]*[gG]')))
    test_img_num = len(img_list)

    for i, in_path in enumerate(img_list):
        img_name = os.path.basename(in_path)
        print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
        # fix: build the output path from the basename (portable on Windows,
        # where splitting on "/" fails) and always save with a .png extension
        # regardless of the source extension (.jpg, .jpeg, .JPG, ...).
        out_path = os.path.join(args.out_dir, os.path.splitext(img_name)[0] + '.png')
        size_ = align_face(in_path, out_path)
scripts/download_pretrained_models.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from os import path as osp
4
+
5
+ from basicsr.utils.download_util import load_file_from_url
6
+
7
+
8
def download_pretrained_models(method, file_urls):
    """Download the pretrained weight files for one method.

    Args:
        method (str): method name; files are stored under ``./weights/<method>``.
        file_urls (dict): mapping of file name -> download URL.
    """
    save_path_root = f'./weights/{method}'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_url in file_urls.items():
        # fix: the returned path was bound to an unused variable; the call is
        # performed purely for its download side effect.
        load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
14
+
15
+
16
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        'method',
        type=str,
        help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models."))
    args = parser.parse_args()

    file_urls = {
        'CodeFormer': {
            'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
        },
        'facelib': {
            # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
            'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
            'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
        },
        'dlib': {
            'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
            'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
        }
    }

    if args.method == 'all':
        for method, urls in file_urls.items():
            download_pretrained_models(method, urls)
    elif args.method in file_urls:
        download_pretrained_models(args.method, file_urls[args.method])
    else:
        # fix: report a clear usage error instead of an opaque KeyError
        parser.error(f"Unknown method '{args.method}'. Options: {', '.join(file_urls)} or 'all'.")
scripts/download_pretrained_models_from_gdrive.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from os import path as osp
4
+
5
+ # from basicsr.utils.download_util import download_file_from_google_drive
6
+ import gdown
7
+
8
+
9
def download_pretrained_models(method, file_ids):
    """Download pretrained model files for one method from Google Drive.

    Args:
        method (str): method name; files are stored under ``./weights/<method>``.
        file_ids (dict): mapping of file name -> Google Drive file id.

    Raises:
        ValueError: when the overwrite prompt is answered with anything
            other than Y/N (case-insensitive).
    """
    save_path_root = f'./weights/{method}'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_id in file_ids.items():
        file_url = 'https://drive.google.com/uc?id='+file_id
        save_path = osp.abspath(osp.join(save_path_root, file_name))
        if osp.exists(save_path):
            user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
            if user_response.lower() == 'y':
                print(f'Covering {file_name} to {save_path}')
            elif user_response.lower() == 'n':
                print(f'Skipping {file_name}')
                continue
            else:
                raise ValueError('Wrong input. Only accepts Y/N.')
        else:
            print(f'Downloading {file_name} to {save_path}')
        # fix: single download call for both the overwrite and the fresh-file
        # path (previously duplicated in each branch)
        gdown.download(file_url, save_path, quiet=False)
        # download_file_from_google_drive(file_id, save_path)
30
+
31
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        'method',
        type=str,
        help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
    args = parser.parse_args()

    # file name: file id
    # 'dlib': {
    #     'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
    #     'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
    #     'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
    # }
    file_ids = {
        'CodeFormer': {
            'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
        },
        'facelib': {
            'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
            'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
        }
    }

    if args.method == 'all':
        for method, ids in file_ids.items():
            download_pretrained_models(method, ids)
    elif args.method in file_ids:
        download_pretrained_models(args.method, file_ids[args.method])
    else:
        # fix: report a clear usage error instead of an opaque KeyError
        parser.error(f"Unknown method '{args.method}'. Options: {', '.join(file_ids)} or 'all'.")
scripts/generate_latent_gt.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import numpy as np
4
+ import os
5
+ import cv2
6
+ import torch
7
+ from torchvision.transforms.functional import normalize
8
+ from basicsr.utils import imwrite, img2tensor, tensor2img
9
+
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
if __name__ == '__main__':
    # Precompute VQGAN codebook indices ("latent GT codes") for every FFHQ
    # image (and its horizontal flip) and save them to a single .pth file.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
    parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    # build the VQ autoencoder from the project registry and load EMA weights
    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    checkpoint = torch.load(ckpt_path)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    # NOTE(review): sum_latent is never used below — likely leftover code.
    sum_latent = np.zeros((codebook_size)).astype('float64')
    size_latent = 16  # the quantized latent grid is 16x16 indices
    latent = {}
    latent['orig'] = {}
    latent['hflip'] = {}
    for i in ['orig', 'hflip']:
        # for i in ['hflip']:
        for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
            img_name = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if i == 'hflip':
                # in-place horizontal flip for the augmented pass
                cv2.flip(img, 1, img)
            # BGR uint8 -> RGB float tensor in [-1, 1]
            img = img2tensor(img / 255., bgr2rgb=True, float32=True)
            normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            img = img.unsqueeze(0).to(device)
            with torch.no_grad():
                # output = net(img)[0]
                # encode then quantize; log carries the codebook indices
                x, feat_dict = vqgan.encoder(img, True)
                x, _, log = vqgan.quantize(x)
            # del output
            torch.cuda.empty_cache()

            min_encoding_indices = log['min_encoding_indices']
            min_encoding_indices = min_encoding_indices.view(size_latent,size_latent)
            # key by the image name without its 4-char extension (".png"/".jpg")
            latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
            print(img_name, latent[i][img_name[:-4]].shape)

    latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
    torch.save(latent, latent_save_path)
    print(f'\nLatent GT code are saved in {save_root}')
scripts/inference_vqgan.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import numpy as np
4
+ import os
5
+ import cv2
6
+ import torch
7
+ from torchvision.transforms.functional import normalize
8
+ from basicsr.utils import imwrite, img2tensor, tensor2img
9
+
10
+ from basicsr.utils.registry import ARCH_REGISTRY
11
+
12
if __name__ == '__main__':
    # Run VQGAN reconstruction on every image in test_path and save the
    # reconstructed images as .png files under save_root.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
    parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    # build the VQ autoencoder from the project registry and load EMA weights
    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    checkpoint = torch.load(ckpt_path)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
        img_name = os.path.basename(img_path)
        print(img_name)
        # BGR uint8 -> RGB float tensor in [-1, 1]
        img = cv2.imread(img_path)
        img = img2tensor(img / 255., bgr2rgb=True, float32=True)
        normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        img = img.unsqueeze(0).to(device)
        with torch.no_grad():
            output = vqgan(img)[0]
            output = tensor2img(output, min_max=[-1,1])
            # fix: removed a dead store — the original concatenated the input
            # and reconstruction side-by-side and then immediately overwrote
            # the result with the reconstruction alone.
            restored_img = output
            del output
        torch.cuda.empty_cache()

        path = os.path.splitext(os.path.join(save_root, img_name))[0]
        save_path = f'{path}.png'
        imwrite(restored_img, save_path)

    print(f'\nAll results are saved in {save_root}')
+