# CodeFormer / app.py — Hugging Face Space (lucky0146), CPU-only face restoration demo.
# Standard library
import functools
import os
import sys
import urllib.request

# Third-party
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# Local
from codeformer_arch import CodeFormer
def download_file(url, dest):
    """Download *url* to *dest* unless the file already exists.

    Parent directories are created as needed. No-op when *dest* is
    already present; prints a confirmation after each fresh download.
    """
    if not os.path.exists(dest):
        # dirname is "" for a bare filename; makedirs("") would raise
        # FileNotFoundError, so only create a directory when there is one.
        dest_dir = os.path.dirname(dest)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        urllib.request.urlretrieve(url, dest)
        print(f"Downloaded {dest}")
def setup_environment():
    """Fetch every pretrained checkpoint the app needs (cached on disk)."""
    # (URL, local path) for each required weight file, downloaded in order.
    # NOTE(review): the "facelib.pth" asset name looks unusual for the
    # CodeFormer v0.1.0 release — verify this URL actually resolves.
    checkpoints = (
        # CodeFormer restoration network
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
         "weights/CodeFormer/codeformer.pth"),
        # Face-detection weights for facexlib
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/facelib.pth",
         "weights/facelib.pth"),
        # Real-ESRGAN for optional background upsampling
        ("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/RealESRGAN_x4plus.pth",
         "weights/RealESRGAN_x4plus.pth"),
    )
    for url, path in checkpoints:
        download_file(url, path)
@functools.lru_cache(maxsize=1)
def load_codeformer():
    """Build the CodeFormer network and load its pretrained weights.

    Memoised with lru_cache so repeated calls from the inference path
    do not rebuild the model or re-read the checkpoint from disk.

    Returns:
        The CodeFormer network in eval mode on CPU.
    """
    setup_environment()
    model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9,
                       connect_list=['32', '64', '128', '256'])
    ckpt = torch.load("weights/CodeFormer/codeformer.pth", map_location='cpu')
    # The official release stores the weights under 'params_ema';
    # fall back to the raw dict for plain state_dict checkpoints.
    state_dict = ckpt.get('params_ema', ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model.to('cpu')  # force CPU — free HF Spaces tier has no GPU
def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
    """Restore faces in *input_image* with CodeFormer (CPU only).

    Args:
        input_image: PIL image (any mode — converted to RGB first).
        fidelity_weight: CodeFormer's `w` in [0, 1]; 0 favours restoration
            quality, 1 favours fidelity to the original face.
        background_enhance: if True, run Real-ESRGAN over the full frame.
        face_upsample: if True, crop/paste faces at 2x instead of 1x.

    Returns:
        A PIL RGB image with restored faces pasted back into the frame.
    """
    # PIL -> OpenCV BGR; convert("RGB") guards against RGBA/grayscale uploads,
    # which would otherwise break the RGB2BGR conversion below.
    img = cv2.cvtColor(np.array(input_image.convert("RGB")), cv2.COLOR_RGB2BGR)

    # Detect, crop and align faces at the canonical 512x512 CodeFormer size.
    face_helper = FaceRestoreHelper(
        upscale_factor=2 if face_upsample else 1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        device='cpu'
    )
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    net = load_codeformer()

    for cropped_face in face_helper.cropped_faces:
        # CodeFormer expects inputs in [-1, 1]: scale 0-255 -> [0, 1],
        # then normalize with mean 0.5 / std 0.5 (matches min_max=(-1, 1)
        # used when decoding the output back to an image).
        face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
        face_t = (face_t - 0.5) / 0.5
        try:
            with torch.no_grad():
                output = net(face_t.unsqueeze(0), w=fidelity_weight, adain=True)[0]
            restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        except RuntimeError as err:
            # Best effort: keep the un-restored crop instead of crashing the app.
            print(f"CodeFormer inference failed: {err}")
            restored_face = cropped_face
        face_helper.add_restored_face(restored_face.astype('uint8'))

    # Warp each restored face back into the original frame.
    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image()

    # Optional background enhancement with Real-ESRGAN (imported lazily so the
    # app still starts when realesrgan is absent and the box is unticked).
    if background_enhance:
        from realesrgan import RealESRGANer
        upsampler = RealESRGANer(
            scale=4,
            model_path="weights/RealESRGAN_x4plus.pth",
            device='cpu'
        )
        restored_img, _ = upsampler.enhance(restored_img, outscale=4)

    # OpenCV BGR -> PIL RGB for Gradio.
    return Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# CodeFormer Face Restoration (CPU)")
    gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")

    # Input/output images side by side.
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        output_image = gr.Image(type="pil", label="Enhanced Image")

    # Restoration controls.
    fidelity_slider = gr.Slider(
        0, 1, value=0.5, step=0.01,
        label="Fidelity Weight (0 = more restoration, 1 = more original)",
    )
    background_enhance = gr.Checkbox(label="Enhance Background (Real-ESRGAN)", value=True)
    face_upsample = gr.Checkbox(label="Upsample Restored Faces", value=False)
    submit_btn = gr.Button("Enhance")

    submit_btn.click(
        fn=enhance_image,
        inputs=[input_image, fidelity_slider, background_enhance, face_upsample],
        outputs=output_image,
    )

if __name__ == "__main__":
    # Pre-download weights before serving so the first request is not slow.
    setup_environment()
    demo.launch()