| import argparse
|
| import cv2
|
| import os
|
|
|
| from imutils import paths
|
| from tqdm import tqdm
|
| from config import *
|
| from utils import get_face_enhancer, get_upsampler
|
|
|
|
|
def process(image_path, upsampler_name, face_enhancer_name=None, scale=2, device="cpu"):
    """Upscale a single image file and return the processed array.

    Images larger than 3500 px on either side are returned unchanged. Images
    smaller than 300 px on both sides are resized with Lanczos interpolation
    instead of running a model, unless the upsampler is ``"srcnn"``. Every
    other image goes through the upsampler named by ``upsampler_name``,
    optionally wrapped by a face enhancer.

    Args:
        image_path: Path of the image, read with ``cv2.IMREAD_UNCHANGED``.
        upsampler_name: Model name forwarded to ``get_upsampler``.
        face_enhancer_name: Optional model name forwarded to
            ``get_face_enhancer``; when falsy only the upsampler is used.
        scale: Requested output scale; values above 4 are clamped to 4.
        device: Device string forwarded to the model factories.

    Returns:
        The processed image, or ``None`` if the file could not be read or
        processing failed. The pass-through and resize paths keep
        ``cv2.imread``'s channel order so the result can be saved with
        ``cv2.imwrite``; the model paths return whatever ``enhance`` yields
        (assumed to be the same order — TODO confirm against ``utils``).
    """
    # Upscaling is capped at 4x.
    scale = min(scale, 4)

    try:
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread reports failure (missing file, unsupported format) by
        # returning None instead of raising.
        if img is None:
            print(f"Could not read image: {image_path}")
            return None

        h, w = img.shape[0:2]

        # Very large images are passed through untouched. Return them in the
        # same channel order as every other path: the caller writes results
        # with cv2.imwrite, so converting to RGB here swapped red and blue.
        if h > 3500 or w > 3500:
            return img

        # Small images skip the model (except SRCNN) and are resampled
        # directly at the requested scale.
        if h < 300 and w < 300 and upsampler_name != "srcnn":
            new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
            return cv2.resize(img, new_size, interpolation=cv2.INTER_LANCZOS4)

        upsampler = get_upsampler(upsampler_name, device=device)
        face_enhancer = (
            get_face_enhancer(face_enhancer_name, scale, upsampler, device=device)
            if face_enhancer_name
            else None
        )

        try:
            if face_enhancer is not None:
                _, _, output = face_enhancer.enhance(
                    img, has_aligned=False, only_center_face=False, paste_back=True
                )
            else:
                output, _ = upsampler.enhance(img, outscale=scale)
        except RuntimeError as error:
            # `output` was never assigned, so report failure explicitly rather
            # than falling through to an UnboundLocalError.
            print(f"Runtime error: {error}")
            return None

        return output

    except Exception as error:
        # Best-effort per image: report and let the caller skip this file.
        print(f"global exception: {error}")
        return None
|
|
|
|
|
def main(args: argparse.Namespace) -> None:
    """Upscale every image under ``args.input`` and save results to ``args.output``.

    Args:
        args: Parsed command-line namespace providing ``input``, ``output``,
            ``upsampler``, ``face_enhancer``, ``scale`` and ``device``.

    Raises:
        ValueError: If ``args.input`` does not exist.
    """
    device = args.device
    scale = args.scale
    upsampler_name = args.upsampler
    face_enhancer_name = args.face_enhancer

    if face_enhancer_name and ("srcnn" in upsampler_name or "anime" in upsampler_name):
        print(
            "Warnings: SRCNN and Anime model aren't compatible with face enhance. We will turn it off for you"
        )
        face_enhancer_name = None

    # Validate the input before touching the filesystem so a bad path does not
    # leave an empty output directory behind.
    if not os.path.exists(args.input):
        raise ValueError(f"The input path doesn't exist: {args.input}")

    input_is_dir = os.path.isdir(args.input)
    # Materialise the listing so tqdm knows the total and can show progress.
    image_paths = list(paths.list_images(args.input)) if input_is_dir else [args.input]

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            upsampled_image = process(
                image_path=image_path,
                upsampler_name=upsampler_name,
                face_enhancer_name=face_enhancer_name,
                scale=scale,
                device=device,
            )
            if upsampled_image is None:
                continue

            # Keep the path relative to the input directory. If list_images
            # yields files from sub-directories, same-named files would
            # otherwise overwrite each other in a flat output directory.
            if input_is_dir:
                relative_path = os.path.relpath(image_path, args.input)
            else:
                relative_path = os.path.basename(image_path)
            save_path = os.path.join(args.output, relative_path)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # cv2.imwrite signals failure (e.g. unsupported extension) by
            # returning False rather than raising.
            if not cv2.imwrite(save_path, upsampled_image):
                print(f"Failed to write {save_path}")
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser accepting ``--input``/``-i`` and ``--output``/``-o``
        (required), plus ``--upsampler``, ``--face-enhancer``, ``--scale``
        and ``--device``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Upscales an input image or directory of images with a "
            "super-resolution model, optionally enhancing faces"
        )
    )

    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the output directory.",
    )

    parser.add_argument(
        "--upsampler",
        type=str,
        default="realesr-general-x4v3",
        choices=[
            "srcnn",
            "RealESRGAN_x2plus",
            "RealESRGAN_x4plus",
            "RealESRNet_x4plus",
            "realesr-general-x4v3",
            "RealESRGAN_x4plus_anime_6B",
            "realesr-animevideov3",
        ],
        help="The type of upsampler model to load",
    )

    parser.add_argument(
        "--face-enhancer",
        type=str,
        choices=["GFPGANv1.3", "GFPGANv1.4", "RestoreFormer"],
        help="The type of face enhancer model to load",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=2,
        choices=[1.5, 2, 2.5, 3, 3.5, 4],
        help="scaling factor",
    )

    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run upsampling on."
    )

    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
|
|
|