|
|
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio import Interface
from PIL import Image
|
|
|
|
|
# NOTE: the torch device and the CodeFormer weights are configured further
# down, after ``CodeFormer`` has been imported and ``net`` has been built.
# Loading the checkpoint at this point would call ``net.load_state_dict``
# before ``net`` exists (NameError at import time), so it is done only once,
# right after the model is constructed.
|
|
# Put the CodeFormer checkout first on sys.path so the imports below resolve
# against it.
sys.path.insert(0, '/CodeFormer')

from basicsr.utils import img2tensor, tensor2img
from basicsr.archs.codeformer_arch import CodeFormer
from facexlib.utils.face_restoration_helper import FaceRestorationHelper

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture hyper-parameters. They must match the checkpoint loaded below:
# load_state_dict is strict by default and raises on mismatched keys/shapes.
net = CodeFormer(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=['32', '64', '128', '256'],
).to(device)

# map_location remaps the stored tensors onto ``device``. Without it, a
# checkpoint saved from a GPU cannot be loaded on a CPU-only host, even though
# ``device`` above already falls back to the CPU. The loaded dict is not kept
# in a variable, so its tensors can be freed once they are copied into ``net``.
# 'params_ema' is presumably the EMA copy of the weights (TODO confirm).
net.load_state_dict(
    torch.load(
        '/CodeFormer/weights/CodeFormer/codeformer.pth',
        map_location=device,
    )['params_ema']
)
net.eval()
|
|
|
|
|
# Shared facexlib helper. process_image uses it to find face landmarks, warp
# each face into an aligned square crop for the network, and paste the
# restored crops back into the input image.
#   face_size=512      -> side length of the aligned crops fed to CodeFormer
#   upscale_factor=1   -> presumably no upscaling of the pasted-back result
#   use_parse=True     -> enables facexlib's face-parsing model (presumably for
#                         the paste-back blending mask; confirm in facexlib)
face_helper = FaceRestorationHelper(
    device=device,
    det_model='retinaface_resnet50',
    use_parse=True,
    face_size=512,
    crop_ratio=(1, 1),
    upscale_factor=1,
    save_ext='png',
)
|
|
|
|
|
def process_image(img: np.ndarray, w: float = 0.7) -> np.ndarray:
    """Restore every detected face in *img* with CodeFormer and paste it back.

    Args:
        img: Input image in BGR channel order (OpenCV convention). Face crops
            taken from it are converted BGR->RGB before inference, and the
            restored faces are converted back to BGR.
        w: Fidelity weight forwarded to the network as ``w``. Presumably lower
            values favour restoration quality and higher values favour
            fidelity to the input; confirm against the CodeFormer docs.

    Returns:
        The image produced by ``face_helper.paste_faces_to_input_image()``:
        the input with the restored faces blended back in.
    """
    # face_helper is a module-level object shared across calls, so clear the
    # state left over from the previous image before reading this one.
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    for cropped_face in face_helper.cropped_faces:
        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        # Map [0, 1] -> [-1, 1]. The network output is decoded below with
        # min_max=(-1, 1), so the input has to live in the same range.
        # Feeding [0, 1] skews the brightness and colours of the restoration.
        cropped_face_t = ((cropped_face_t - 0.5) / 0.5).unsqueeze(0).to(device)

        with torch.no_grad():
            # The model returns a tuple; element 0 is the restored face tensor.
            output = net(cropped_face_t, w=w, adain=True)[0]
            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))

        face_helper.add_restored_face(restored_face)

    face_helper.get_inverse_affine(None)
    return face_helper.paste_faces_to_input_image()
|
|
|
|
|
def predict(input_img: Image.Image, w: float = 0.7) -> Image.Image:
    """Gradio callback: restore the faces in a PIL image.

    Args:
        input_img: Image from the Gradio ``Image(type="pil")`` input. Any PIL
            mode is accepted; it is converted to RGB first.
        w: Fidelity weight, forwarded unchanged to ``process_image``.

    Returns:
        The restored image as an RGB PIL image.
    """
    # Normalise RGBA / grayscale / palette uploads to plain 3-channel RGB so the
    # colour conversions below always receive HxWx3 data.
    rgb = np.array(input_img.convert("RGB"))
    # PIL is RGB, but process_image works in OpenCV's BGR order (its face crops
    # go through img2tensor(..., bgr2rgb=True)). Convert on the way in and back
    # on the way out; otherwise the network sees red and blue swapped.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    restored_bgr = process_image(bgr, w)
    return Image.fromarray(cv2.cvtColor(restored_bgr, cv2.COLOR_BGR2RGB))
|
|
|
|
|
# Web UI: one PIL image plus a fidelity slider in, one PIL image out.
# Arguments are (fn, inputs, outputs), followed by keyword options.
iface = Interface(
    predict,
    [
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Fidelity Weight"),
    ],
    gr.Image(type="pil", label="Enhanced Image"),
    title="CodeFormer Face Restoration",
)
|
|
|
|
|
if __name__ == "__main__":
    # Serve the UI on every network interface (0.0.0.0), port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")