enhancer / app.py
lucky0146's picture
Update app.py
7b6c178 verified
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio import Interface
from PIL import Image
# --- Model setup -----------------------------------------------------------
# The CodeFormer checkout must be on sys.path *before* its packages are
# imported, which is why these imports cannot live at the top of the file.
sys.path.insert(0, '/CodeFormer')
from basicsr.utils import img2tensor, tensor2img
from basicsr.archs.codeformer_arch import CodeFormer
# NOTE(review): facexlib exports this class as ``FaceRestoreHelper``; the
# previous ``FaceRestorationHelper`` name does not exist in that module.
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# Use the GPU when one is present, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the network first, then load the weights into it (the weights were
# previously loaded into ``net`` before ``net`` had been created).
net = CodeFormer(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=['32', '64', '128', '256'],
).to(device)

# map_location lets a checkpoint saved from CUDA tensors be deserialized on
# a CPU-only host instead of raising at load time.
checkpoint = torch.load(
    '/CodeFormer/weights/CodeFormer/codeformer.pth',
    map_location=device,
)
net.load_state_dict(checkpoint['params_ema'])
net.eval()

# Shared helper that detects, aligns, and pastes back faces; it is reset at
# the start of every request via ``clean_all()``.
face_helper = FaceRestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    use_parse=True,
    device=device,
)
def process_image(img: np.ndarray, w: float = 0.7) -> np.ndarray:
    """Restore every detected face in *img* and paste the results back.

    Args:
        img: Image array handed to ``face_helper.read_image``. The crops
            derived from it are converted with ``bgr2rgb=True``, so the
            array is treated as BGR channel order.
        w: Fidelity weight forwarded to the network as its ``w`` argument.

    Returns:
        The array returned by ``face_helper.paste_faces_to_input_image()``.
    """
    # Drop any state left over from a previous call on the shared helper.
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    for cropped_face in face_helper.cropped_faces:
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        # Map [0, 1] -> [-1, 1] (mean 0.5, std 0.5) so the input range
        # matches the (-1, 1) range the output is decoded with below.
        face_t = (face_t - 0.5) / 0.5
        face_t = face_t.unsqueeze(0).to(device)

        with torch.no_grad():
            output = net(face_t, w=w, adain=True)[0]

        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face)

    face_helper.get_inverse_affine(None)
    return face_helper.paste_faces_to_input_image()
def predict(input_img: Image.Image, w: float = 0.7) -> Image.Image:
    """Gradio callback: enhance the faces in a PIL image.

    Args:
        input_img: Image supplied by the UI.
        w: Fidelity weight passed through to ``process_image``.

    Returns:
        The enhanced image as an RGB PIL image.

    Raises:
        ValueError: If no image was supplied.
    """
    if input_img is None:
        raise ValueError("No input image provided")

    # PIL arrays are RGB (and may carry alpha or be single-channel), while
    # process_image treats its array as BGR (it converts crops with
    # bgr2rgb/rgb2bgr). Normalise to 3-channel RGB, then swap to BGR.
    rgb = np.array(input_img.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    result = process_image(bgr, w)

    # Swap back so PIL interprets the channels correctly.
    return Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
# Gradio UI: an input image plus a fidelity slider, producing the enhanced
# image. Components are referenced through the ``gr`` namespace throughout.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Slider(0.0, 1.0, value=0.7, label="Fidelity Weight"),
    ],
    outputs=gr.Image(label="Enhanced Image", type="pil"),
    title="CodeFormer Face Restoration",
)

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)