| | from __future__ import annotations
|
| |
|
| | import logging
|
| | import os
|
| |
|
| | import torch
|
| |
|
| | from modules import (
|
| | devices,
|
| | errors,
|
| | face_restoration,
|
| | face_restoration_utils,
|
| | modelloader,
|
| | shared,
|
| | )
|
| |
|
# Module logger; gfpgan_fix_faces() uses it to warn when no restorer exists.
logger = logging.getLogger(__name__)

# File name the GFPGAN v1.4 weights are stored under. The release asset URL is
# derived from it; note the v1.4 weights are published under the v1.3.0 tag.
model_download_name = "GFPGANv1.4.pth"
model_url = (
    "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/"
    f"{model_download_name}"
)

# Restorer instance created by setup_model(); stays None until that succeeds.
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
| |
|
| |
|
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
    """Face restorer backed by a GFPGAN network.

    Shared plumbing (``restore_with_helper``, ``self.net``,
    ``self.model_path``) comes from ``CommonFaceRestoration``, which is not
    visible here. This subclass supplies four things:

    * the restorer name;
    * the target device;
    * how the network is located and loaded;
    * the per-face inference call.
    """

    def name(self):
        """Return the identifier ``"GFPGAN"`` for this restorer."""
        return "GFPGAN"

    def get_device(self):
        """Return the device configured for GFPGAN (``devices.device_gfpgan``)."""
        return devices.device_gfpgan

    def load_net(self) -> torch.nn.Module:
        """Locate and load the GFPGAN network.

        Walks the ``.pth`` files reported by ``modelloader.load_models`` for
        ``self.model_path``. It loads the first file whose *file name*
        contains ``"GFPGAN"`` with ``modelloader.load_spandrel_model``,
        requiring the ``'GFPGAN'`` architecture.

        ``model_url`` and ``model_download_name`` are passed along as well,
        presumably so the weights can be fetched when absent (TODO confirm
        against ``modelloader``).

        Returns:
            The unwrapped network (``.model`` of the spandrel descriptor),
            with ``different_w`` set to ``True``.

        Raises:
            ValueError: if no candidate file name contains ``"GFPGAN"``.
        """
        for candidate_path in modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        ):
            # Only the file name is checked, so a directory whose name happens
            # to contain "GFPGAN" does not make unrelated checkpoints match.
            if 'GFPGAN' not in os.path.basename(candidate_path):
                continue
            model = modelloader.load_spandrel_model(
                candidate_path,
                device=self.get_device(),
                expected_architecture='GFPGAN',
            ).model
            # NOTE(review): presumably mirrors upstream GFPGANer, which builds
            # the v1.4 network with different_w=True — confirm.
            model.different_w = True
            return model
        raise ValueError("No GFPGAN model found")

    def restore(self, np_image):
        """Restore faces in *np_image* with the GFPGAN network.

        Delegates to the base class's ``restore_with_helper``. It supplies a
        callback that runs ``self.net`` on a cropped face tensor (presumably
        invoked once per detected face — TODO confirm in the base class).
        """
        def restore_face(cropped_face_t):
            # restore_with_helper is expected to have loaded self.net before
            # calling back; the assert documents and guards that invariant.
            assert self.net is not None
            # The network returns several outputs; only the first is kept
            # (presumably the restored face image — confirm).
            return self.net(cropped_face_t, return_rgb=False)[0]

        return self.restore_with_helper(np_image, restore_face)
|
| |
|
| |
|
def gfpgan_fix_faces(np_image):
    """Apply GFPGAN face restoration to *np_image* when a restorer exists.

    If no restorer has been created (``gfpgan_face_restorer`` is falsy, e.g.
    still ``None`` because ``setup_model`` was not run or failed), a warning
    is logged and the very same *np_image* object is returned unchanged.
    """
    restorer = gfpgan_face_restorer
    if not restorer:
        logger.warning("GFPGAN face restorer not set up")
        return np_image
    return restorer.restore(np_image)
|
| |
|
| |
|
def setup_model(dirname: str) -> None:
    """Create the GFPGAN restorer for *dirname* and register it.

    Steps, in order:

    1. Call ``face_restoration_utils.patch_facexlib(dirname)``.
    2. Build ``FaceRestorerGFPGAN(model_path=dirname)``.
    3. Store it in the module-level ``gfpgan_face_restorer``.
    4. Append it to ``shared.face_restorers``.

    Any exception is reported through ``errors.report`` instead of being
    raised, so a GFPGAN failure does not abort the caller.
    """
    global gfpgan_face_restorer

    try:
        face_restoration_utils.patch_facexlib(dirname)
        restorer = FaceRestorerGFPGAN(model_path=dirname)
        # Publish the global before registering, matching the original order.
        gfpgan_face_restorer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        errors.report("Error setting up GFPGAN", exc_info=True)
|
| |
|