Spaces:
Configuration error
Configuration error
| import os | |
| import cv2 | |
| import torch | |
| from gfpgan import GFPGANer | |
| from tqdm import tqdm | |
| from visualizr import logger | |
| from visualizr.face_sr.videoio import load_video_to_cv2 | |
| class GeneratorWithLen(object): | |
| """From https://stackoverflow.com/a/7460929""" | |
| def __init__(self, gen, length): | |
| self.gen = gen | |
| self.length = length | |
| def __len__(self): | |
| return self.length | |
| def __iter__(self): | |
| return self.gen | |
| def enhancer_list(images, method="gfpgan", bg_upsampler="realesrgan"): | |
| gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) | |
| return list(gen) | |
| def enhancer_generator_with_len(images, method="gfpgan", bg_upsampler="realesrgan"): | |
| """Provide a generator with a __len__ method so that it can passed to functions that | |
| call len()""" | |
| if os.path.isfile(images): # handle video to images | |
| images = load_video_to_cv2(images) | |
| gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler) | |
| gen_with_len = GeneratorWithLen(gen, len(images)) | |
| return gen_with_len | |
| def enhancer_generator_no_len(images, method="gfpgan", bg_upsampler="realesrgan"): | |
| """Provide a generator function so that all of the enhanced images don't need | |
| to be stored in memory at the same time. This can save tons of RAM compared to | |
| the enhancer function.""" | |
| if method not in ["gfpgan", "RestoreFormer", "codeformer"]: | |
| raise ValueError(f"Wrong model version {method}.") | |
| logger.info("face enhancer....") | |
| if not isinstance(images, list) and os.path.isfile( | |
| images | |
| ): # handle video to images | |
| images = load_video_to_cv2(images) | |
| # ------------------------ set up GFPGAN restorer ------------------------ | |
| match method: | |
| case "gfpgan": | |
| arch = "clean" | |
| channel_multiplier = 2 | |
| model_name = "GFPGANv1.4" | |
| url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" | |
| case "RestoreFormer": | |
| arch = "RestoreFormer" | |
| channel_multiplier = 2 | |
| model_name = "RestoreFormer" | |
| url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth" | |
| case "codeformer": | |
| arch = "CodeFormer" | |
| channel_multiplier = 2 | |
| model_name = "CodeFormer" | |
| url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" | |
| # ------------------------ set up background upsampler ------------------------ | |
| if bg_upsampler == "realesrgan": | |
| if not torch.cuda.is_available(): # CPU | |
| import warnings | |
| warnings.warn( | |
| "The unoptimized RealESRGAN is slow on CPU. We do not use it. " | |
| "If you really want to use it, please modify the corresponding codes." | |
| ) | |
| bg_upsampler = None | |
| else: | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer | |
| model = RRDBNet( | |
| num_in_ch=3, | |
| num_out_ch=3, | |
| num_feat=64, | |
| num_block=23, | |
| num_grow_ch=32, | |
| scale=2, | |
| ) | |
| bg_upsampler = RealESRGANer( | |
| scale=2, | |
| model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", | |
| model=model, | |
| tile=400, | |
| tile_pad=10, | |
| pre_pad=0, | |
| half=True, | |
| ) # need to set False in CPU mode | |
| else: | |
| bg_upsampler = None | |
| # determine model paths | |
| model_path = os.path.join("gfpgan/weights", model_name + ".pth") | |
| if not os.path.isfile(model_path): | |
| model_path = os.path.join("checkpoints", model_name + ".pth") | |
| if not os.path.isfile(model_path): | |
| # download pre-trained models from url | |
| model_path = url | |
| restorer = GFPGANer( | |
| model_path=model_path, | |
| upscale=2, | |
| arch=arch, | |
| channel_multiplier=channel_multiplier, | |
| bg_upsampler=bg_upsampler, | |
| ) | |
| # ------------------------ restore ------------------------ | |
| for idx in tqdm(range(len(images)), "Face Enhancer:"): | |
| img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR) | |
| # restore faces and background if necessary | |
| cropped_faces, restored_faces, r_img = restorer.enhance( | |
| img, has_aligned=False, only_center_face=False, paste_back=True | |
| ) | |
| r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB) | |
| yield r_img | |