Nav27 commited on
Commit
cae3c87
·
verified ·
1 Parent(s): 621ca38

Upload 3 files

Browse files
Files changed (3) hide show
  1. __init__.py +7 -0
  2. train.py +11 -0
  3. utils.py +148 -0
__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ from .archs import *
3
+ from .data import *
4
+ from .models import *
5
+ from .utils import *
6
+
7
+ # from .version import *
train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ import os.path as osp
3
+ from basicsr.train import train_pipeline
4
+
5
+ import gfpgan.archs
6
+ import gfpgan.data
7
+ import gfpgan.models
8
+
9
+ if __name__ == '__main__':
10
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
11
+ train_pipeline(root_path)
utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import torch
4
+ from basicsr.utils import img2tensor, tensor2img
5
+ from basicsr.utils.download_util import load_file_from_url
6
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
7
+ from torchvision.transforms.functional import normalize
8
+
9
+ from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
10
+ from gfpgan.archs.gfpganv1_arch import GFPGANv1
11
+ from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
12
+
13
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+
15
+
16
+ class GFPGANer():
17
+ """Helper for restoration with GFPGAN.
18
+
19
+ It will detect and crop faces, and then resize the faces to 512x512.
20
+ GFPGAN is used to restored the resized faces.
21
+ The background is upsampled with the bg_upsampler.
22
+ Finally, the faces will be pasted back to the upsample background image.
23
+
24
+ Args:
25
+ model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
26
+ upscale (float): The upscale of the final output. Default: 2.
27
+ arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
28
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
29
+ bg_upsampler (nn.Module): The upsampler for the background. Default: None.
30
+ """
31
+
32
+ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
33
+ self.upscale = upscale
34
+ self.bg_upsampler = bg_upsampler
35
+
36
+ # initialize model
37
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
38
+ # initialize the GFP-GAN
39
+ if arch == 'clean':
40
+ self.gfpgan = GFPGANv1Clean(
41
+ out_size=512,
42
+ num_style_feat=512,
43
+ channel_multiplier=channel_multiplier,
44
+ decoder_load_path=None,
45
+ fix_decoder=False,
46
+ num_mlp=8,
47
+ input_is_latent=True,
48
+ different_w=True,
49
+ narrow=1,
50
+ sft_half=True)
51
+ elif arch == 'bilinear':
52
+ self.gfpgan = GFPGANBilinear(
53
+ out_size=512,
54
+ num_style_feat=512,
55
+ channel_multiplier=channel_multiplier,
56
+ decoder_load_path=None,
57
+ fix_decoder=False,
58
+ num_mlp=8,
59
+ input_is_latent=True,
60
+ different_w=True,
61
+ narrow=1,
62
+ sft_half=True)
63
+ elif arch == 'original':
64
+ self.gfpgan = GFPGANv1(
65
+ out_size=512,
66
+ num_style_feat=512,
67
+ channel_multiplier=channel_multiplier,
68
+ decoder_load_path=None,
69
+ fix_decoder=True,
70
+ num_mlp=8,
71
+ input_is_latent=True,
72
+ different_w=True,
73
+ narrow=1,
74
+ sft_half=True)
75
+ elif arch == 'RestoreFormer':
76
+ from gfpgan.archs.restoreformer_arch import RestoreFormer
77
+ self.gfpgan = RestoreFormer()
78
+ # initialize face helper
79
+ self.face_helper = FaceRestoreHelper(
80
+ upscale,
81
+ face_size=512,
82
+ crop_ratio=(1, 1),
83
+ det_model='retinaface_resnet50',
84
+ save_ext='png',
85
+ use_parse=True,
86
+ device=self.device,
87
+ model_rootpath='gfpgan/weights')
88
+
89
+ if model_path.startswith('https://'):
90
+ model_path = load_file_from_url(
91
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None)
92
+ loadnet = torch.load(model_path)
93
+ if 'params_ema' in loadnet:
94
+ keyname = 'params_ema'
95
+ else:
96
+ keyname = 'params'
97
+ self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
98
+ self.gfpgan.eval()
99
+ self.gfpgan = self.gfpgan.to(self.device)
100
+
101
+ @torch.no_grad()
102
+ def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
103
+ self.face_helper.clean_all()
104
+
105
+ if has_aligned: # the inputs are already aligned
106
+ img = cv2.resize(img, (512, 512))
107
+ self.face_helper.cropped_faces = [img]
108
+ else:
109
+ self.face_helper.read_image(img)
110
+ # get face landmarks for each face
111
+ self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
112
+ # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
113
+ # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
114
+ # align and warp each face
115
+ self.face_helper.align_warp_face()
116
+
117
+ # face restoration
118
+ for cropped_face in self.face_helper.cropped_faces:
119
+ # prepare data
120
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
121
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
122
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
123
+
124
+ try:
125
+ output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
126
+ # convert to image
127
+ restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
128
+ except RuntimeError as error:
129
+ print(f'\tFailed inference for GFPGAN: {error}.')
130
+ restored_face = cropped_face
131
+
132
+ restored_face = restored_face.astype('uint8')
133
+ self.face_helper.add_restored_face(restored_face)
134
+
135
+ if not has_aligned and paste_back:
136
+ # upsample the background
137
+ if self.bg_upsampler is not None:
138
+ # Now only support RealESRGAN for upsampling background
139
+ bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
140
+ else:
141
+ bg_img = None
142
+
143
+ self.face_helper.get_inverse_affine(None)
144
+ # paste each restored face to the input image
145
+ restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
146
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
147
+ else:
148
+ return self.face_helper.cropped_faces, self.face_helper.restored_faces, None