Upload 5 files
Browse files
scripts/crop_align_face.py
CHANGED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
|
| 3 |
+
author: lzhbrian (https://lzhbrian.me)
|
| 4 |
+
link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
|
| 5 |
+
date: 2020.1.5
|
| 6 |
+
note: code is heavily borrowed from
|
| 7 |
+
https://github.com/NVlabs/ffhq-dataset
|
| 8 |
+
http://dlib.net/face_landmark_detection.py.html
|
| 9 |
+
requirements:
|
| 10 |
+
conda install Pillow numpy scipy
|
| 11 |
+
conda install -c conda-forge dlib
|
| 12 |
+
# download face landmark model from:
|
| 13 |
+
# http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import glob
|
| 18 |
+
import numpy as np
|
| 19 |
+
import PIL
|
| 20 |
+
import PIL.Image
|
| 21 |
+
import scipy
|
| 22 |
+
import scipy.ndimage
|
| 23 |
+
import argparse
|
| 24 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import dlib
|
| 28 |
+
except ImportError:
|
| 29 |
+
print('Please install dlib by running:' 'conda install -c conda-forge dlib')
|
| 30 |
+
|
| 31 |
+
# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
|
| 32 |
+
shape_predictor_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_68_face_landmarks-fbdc2cb8.dat'
|
| 33 |
+
ckpt_path = load_file_from_url(url=shape_predictor_url,
|
| 34 |
+
model_dir='weights/dlib', progress=True, file_name=None)
|
| 35 |
+
predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_landmark(filepath, only_keep_largest=True):
    """Detect 68 dlib face landmarks for one image file.

    :param filepath: path to the input image.
    :param only_keep_largest: if True, use only the largest detected face;
        otherwise the landmarks of the last detected face are used.
    :return: np.array shape=(68, 2) of (x, y) landmark coordinates.
    :raises ValueError: if no face is detected. (The original code crashed
        with ``max()`` on an empty sequence / an unbound ``shape`` instead;
        callers such as ``align_face`` already catch exceptions here.)
    """
    detector = dlib.get_frontal_face_detector()

    img = dlib.load_rgb_image(filepath)
    dets = detector(img, 1)

    # Shangchen modified
    print("\tNumber of faces detected: {}".format(len(dets)))
    if len(dets) == 0:
        # Fail loudly with a clear message instead of an obscure crash below.
        raise ValueError(f'No face detected in {filepath}')

    if only_keep_largest:
        print('\tOnly keep the largest.')
        face_areas = [(d.right() - d.left()) * (d.bottom() - d.top()) for d in dets]
        largest_idx = face_areas.index(max(face_areas))
        d = dets[largest_idx]
        shape = predictor(img, d)
    else:
        # When several faces are present, the landmarks of the last one win.
        for k, d in enumerate(dets):
            shape = predictor(img, d)

    # Convert dlib points to a (68, 2) numpy array.
    lm = np.array([[tt.x, tt.y] for tt in shape.parts()])
    return lm
def align_face(filepath, out_path):
    """Crop and align one face with the FFHQ alignment procedure.

    :param filepath: path of the input image.
    :param out_path: path where the aligned 512x512 crop is saved.
    :return: tuple ``(PIL.Image, quad_width)`` on success, or ``None`` when
        no landmark could be detected.
    """
    try:
        lm = get_landmark(filepath)
    except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt
        print('No landmark ...')
        return

    # Landmark groups, following the dlib 68-point layout.
    lm_chin = lm[0:17]  # left-right
    lm_eyebrow_left = lm[17:22]  # left-right
    lm_eyebrow_right = lm[22:27]  # left-right
    lm_nose = lm[27:31]  # top-down
    lm_nostrils = lm[31:36]  # top-down
    lm_eye_left = lm[36:42]  # left-clockwise
    lm_eye_right = lm[42:48]  # left-clockwise
    lm_mouth_outer = lm[48:60]  # left-clockwise
    lm_mouth_inner = lm[60:68]  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)
    eye_right = np.mean(lm_eye_right, axis=0)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    mouth_left = lm_mouth_outer[0]
    mouth_right = lm_mouth_outer[6]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # Choose oriented crop rectangle: x/y are the half-axes, c the center.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2

    # read image
    img = PIL.Image.open(filepath)

    output_size = 512
    transform_size = 4096
    enable_padding = False

    # Shrink very large inputs first so the later quad transform stays cheap.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)),
                 int(np.rint(float(img.size[1]) / shrink)))
        # NOTE(review): PIL.Image.ANTIALIAS was removed in Pillow 10; with
        # newer Pillow this must become PIL.Image.LANCZOS — confirm pinned version.
        img = img.resize(rsize, PIL.Image.ANTIALIAS)
        quad /= shrink
        qsize /= shrink

    # Crop to the quad's bounding box plus a small border.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
            min(crop[2] + border,
                img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]

    # Pad (only when enable_padding is True) with reflected + blurred content
    # so faces near the image edge still get a full crop.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
           int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border,
               0), max(-pad[1] + border,
                       0), max(pad[2] - img.size[0] + border,
                               0), max(pad[3] - img.size[1] + border, 0))
    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(
            np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
            'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        # Feathering mask: 1 at the padded rim, 0 in the untouched interior.
        mask = np.maximum(
            1.0 -
            np.minimum(np.float32(x) / pad[0],
                       np.float32(w - 1 - x) / pad[2]), 1.0 -
            np.minimum(np.float32(y) / pad[1],
                       np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
                img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(
            np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]

    # Map the oriented quad onto an axis-aligned square.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
                        (quad + 0.5).flatten(), PIL.Image.BILINEAR)

    if output_size < transform_size:
        # NOTE(review): same ANTIALIAS/Pillow-10 caveat as above.
        img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

    # Save aligned image.
    img.save(out_path)

    return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--in_dir', type=str, default='./inputs/whole_imgs')
    parser.add_argument('-o', '--out_dir', type=str, default='./inputs/cropped_faces')
    args = parser.parse_args()

    if args.out_dir.endswith('/'):  # solve when path ends with /
        args.out_dir = args.out_dir[:-1]
    dir_name = os.path.abspath(args.out_dir)
    os.makedirs(dir_name, exist_ok=True)

    # Match .jpg/.jpeg/.png in upper or lower case.
    img_list = sorted(glob.glob(os.path.join(args.in_dir, '*.[jpJP][pnPN]*[gG]')))
    test_img_num = len(img_list)

    for i, in_path in enumerate(img_list):
        img_name = os.path.basename(in_path)
        print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
        # Reuse the portable basename; the original `in_path.split("/")[-1]`
        # broke on Windows path separators.
        out_path = os.path.join(args.out_dir, img_name)
        out_path = out_path.replace('.jpg', '.png')
        size_ = align_face(in_path, out_path)
|
scripts/download_pretrained_models.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
from basicsr.utils.download_util import load_file_from_url
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def download_pretrained_models(method, file_urls):
    """Download the pretrained models of one method into ./weights/<method>.

    :param method: model group name; used as the subdirectory under ./weights.
    :param file_urls: dict mapping target file name -> download URL.
    """
    save_path_root = f'./weights/{method}'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_url in file_urls.items():
        # The returned local path was previously bound to an unused variable;
        # drop it — only the download side effect matters here.
        load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        'method',
        type=str,
        help=("Options: 'CodeFormer' 'facelib' 'dlib'. Set to 'all' to download all the models."))
    args = parser.parse_args()

    # method -> {file name: download url}
    file_urls = {
        'CodeFormer': {
            'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
        },
        'facelib': {
            # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
            'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
            'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
        },
        'dlib': {
            'mmod_human_face_detector-4cb19393.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
            'shape_predictor_5_face_landmarks-c4b1e980.dat': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
        }
    }

    if args.method == 'all':
        for method in file_urls.keys():
            download_pretrained_models(method, file_urls[method])
    else:
        if args.method not in file_urls:
            # Fail with a readable message instead of a bare KeyError.
            raise ValueError(
                f"Unknown method '{args.method}'. Options: {sorted(file_urls)} or 'all'.")
        download_pretrained_models(args.method, file_urls[args.method])
|
scripts/download_pretrained_models_from_gdrive.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from os import path as osp
|
| 4 |
+
|
| 5 |
+
# from basicsr.utils.download_util import download_file_from_google_drive
|
| 6 |
+
import gdown
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def download_pretrained_models(method, file_ids):
    """Fetch the pretrained models of one method from Google Drive.

    Files land in ./weights/<method>; for files that already exist locally
    the user is asked interactively whether to overwrite them.

    :param method: model group name; used as the subdirectory under ./weights.
    :param file_ids: dict mapping target file name -> Google Drive file id.
    """
    save_path_root = f'./weights/{method}'
    os.makedirs(save_path_root, exist_ok=True)

    for file_name, file_id in file_ids.items():
        file_url = 'https://drive.google.com/uc?id=' + file_id
        save_path = osp.abspath(osp.join(save_path_root, file_name))

        should_download = True
        if osp.exists(save_path):
            user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
            answer = user_response.lower()
            if answer == 'y':
                print(f'Covering {file_name} to {save_path}')
            elif answer == 'n':
                print(f'Skipping {file_name}')
                should_download = False
            else:
                raise ValueError('Wrong input. Only accepts Y/N.')
        else:
            print(f'Downloading {file_name} to {save_path}')

        if should_download:
            gdown.download(file_url, save_path, quiet=False)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        'method',
        type=str,
        help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
    args = parser.parse_args()

    # method -> {file name: Google Drive file id}
    # dlib ids kept for reference:
    # 'dlib': {
    #     'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
    #     'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
    #     'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
    # }
    file_ids = {
        'CodeFormer': {
            'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
        },
        'facelib': {
            'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
            'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
        }
    }

    if args.method == 'all':
        for method_name, ids in file_ids.items():
            download_pretrained_models(method_name, ids)
    else:
        download_pretrained_models(args.method, file_ids[args.method])
|
scripts/generate_latent_gt.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision.transforms.functional import normalize
|
| 8 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
| 9 |
+
|
| 10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
    parser.add_argument('-o', '--save_root', type=str, default='./experiments/pretrained_models/vqgan')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    # map_location keeps this working on CPU-only machines even when the
    # checkpoint was saved from a GPU run; the original torch.load crashed there.
    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    # An unused accumulator (`sum_latent`) was removed here.
    # Each input maps to a size_latent x size_latent grid of codebook indices
    # (presumably 16x16 for 512px inputs — TODO confirm against the encoder).
    size_latent = 16
    latent = {'orig': {}, 'hflip': {}}
    for i in ['orig', 'hflip']:
        for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
            img_name = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if i == 'hflip':
                cv2.flip(img, 1, img)  # horizontal flip, in place
            img = img2tensor(img / 255., bgr2rgb=True, float32=True)
            normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            img = img.unsqueeze(0).to(device)
            with torch.no_grad():
                x, feat_dict = vqgan.encoder(img, True)
                x, _, log = vqgan.quantize(x)
            torch.cuda.empty_cache()

            min_encoding_indices = log['min_encoding_indices']
            min_encoding_indices = min_encoding_indices.view(size_latent, size_latent)
            # Key by file name without its 4-char extension (e.g. '.png').
            latent[i][img_name[:-4]] = min_encoding_indices.cpu().numpy()
            print(img_name, latent[i][img_name[:-4]].shape)

    latent_save_path = os.path.join(save_root, f'latent_gt_code{codebook_size}.pth')
    torch.save(latent, latent_save_path)
    print(f'\nLatent GT code are saved in {save_root}')
|
scripts/inference_vqgan.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision.transforms.functional import normalize
|
| 8 |
+
from basicsr.utils import imwrite, img2tensor, tensor2img
|
| 9 |
+
|
| 10 |
+
from basicsr.utils.registry import ARCH_REGISTRY
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--test_path', type=str, default='datasets/ffhq/ffhq_512')
    parser.add_argument('-o', '--save_root', type=str, default='./results/vqgan_rec')
    parser.add_argument('--codebook_size', type=int, default=1024)
    parser.add_argument('--ckpt_path', type=str, default='./experiments/pretrained_models/vqgan/net_g.pth')
    args = parser.parse_args()

    if args.save_root.endswith('/'):  # solve when path ends with /
        args.save_root = args.save_root[:-1]
    dir_name = os.path.abspath(args.save_root)
    os.makedirs(dir_name, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    test_path = args.test_path
    save_root = args.save_root
    ckpt_path = args.ckpt_path
    codebook_size = args.codebook_size

    vqgan = ARCH_REGISTRY.get('VQAutoEncoder')(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',
                                               codebook_size=codebook_size).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
    # the original torch.load failed there.
    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']

    vqgan.load_state_dict(checkpoint)
    vqgan.eval()

    for img_path in sorted(glob.glob(os.path.join(test_path, '*.[jp][pn]g'))):
        img_name = os.path.basename(img_path)
        print(img_name)
        img = cv2.imread(img_path)
        img = img2tensor(img / 255., bgr2rgb=True, float32=True)
        normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        img = img.unsqueeze(0).to(device)
        with torch.no_grad():
            output = vqgan(img)[0]
        # Save only the reconstruction. The original also built an
        # input/output side-by-side np.concatenate and immediately
        # discarded it; that dead work is removed.
        restored_img = tensor2img(output, min_max=[-1, 1])
        del output
        torch.cuda.empty_cache()

        path = os.path.splitext(os.path.join(save_root, img_name))[0]
        save_path = f'{path}.png'
        imwrite(restored_img, save_path)

    print(f'\nAll results are saved in {save_root}')
|