# CodeFormer / app.py
# Hugging Face Space upload by lucky0146 (commit 8fd5dd4, 5.25 kB).
import os
import sys
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import urllib.request
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from codeformer_arch import CodeFormer
# Function to download a file from a URL
def download_file(url, dest):
    """Download `url` to the local path `dest`, creating parent dirs as needed.

    A no-op when `dest` already exists, so repeated app restarts do not
    re-download multi-hundred-MB weight files.
    """
    if os.path.exists(dest):
        return
    parent = os.path.dirname(dest)
    # os.path.dirname returns "" for a bare filename; makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    print(f"Downloaded {dest}")
# Download pretrained models
def setup_environment():
    """Fetch every pretrained weight file the app needs (skips existing files)."""
    weight_files = (
        # CodeFormer restoration network
        ("https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
         "weights/CodeFormer/codeformer.pth"),
        # facexlib RetinaFace detector
        ("https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
         "weights/facelib/detection_Resnet50_Final.pth"),
        # Real-ESRGAN background upsampler (optional at inference time)
        ("https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
         "weights/realesrgan/RealESRGAN_x4plus.pth"),
    )
    for url, path in weight_files:
        download_file(url, path)
# Load CodeFormer model
def load_codeformer():
    """Ensure weights are present, build the CodeFormer net, and return it
    in eval mode on CPU."""
    setup_environment()
    net = CodeFormer(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=['32', '64', '128', '256'],
    )
    # Release checkpoints keep the weights under 'params_ema'; fall back to
    # the raw checkpoint dict otherwise.
    checkpoint = torch.load("weights/CodeFormer/codeformer.pth", map_location='cpu')
    weights = checkpoint.get('params_ema', checkpoint)
    net.load_state_dict(weights, strict=False)  # strict=False: ignore missing keys
    net.eval()
    return net.to('cpu')  # Force CPU
# Image processing utilities (mimicking basicsr.utils)
def img2tensor(img, bgr2rgb=True, float32=True):
    """Convert an HWC uint8 numpy image to a CHW float tensor.

    When `bgr2rgb` is True the channels are swapped first; when `float32`
    is True the values are scaled from [0, 255] into [0, 1].
    """
    array = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if bgr2rgb else img
    tensor = torch.from_numpy(array.transpose(2, 0, 1)).float()
    return tensor / 255.0 if float32 else tensor
def tensor2img(tensor, rgb2bgr=True, min_max=(-1, 1)):
    """Convert a CHW float tensor in the `min_max` range to an HWC uint8 image.

    Mirrors basicsr's tensor2img: clamp to `min_max`, rescale to [0, 255],
    then round before casting to uint8 — a plain astype() would truncate
    and bias every pixel darker by up to one level.
    """
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) * 255.0
    img = tensor.numpy().transpose(1, 2, 0).round().astype(np.uint8)
    if rgb2bgr:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img
# Inference function
def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
    """Restore faces in `input_image` with CodeFormer and return a PIL image.

    Args:
        input_image: PIL image (RGB), as delivered by Gradio.
        fidelity_weight: CodeFormer `w` in [0, 1] — 0 favors restoration
            quality, 1 favors fidelity to the input face.
        background_enhance: if True, upsample the non-face background with
            Real-ESRGAN.
        face_upsample: if True, paste restored faces back at 2x instead of 1x.

    Returns:
        A PIL RGB image with restored faces pasted into the original frame.
    """
    # PIL (RGB) -> OpenCV (BGR)
    img = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)

    # Detect, crop, and align every face in the input.
    face_helper = FaceRestoreHelper(
        upscale_factor=2 if face_upsample else 1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        device='cpu'
    )
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    # Load CodeFormer model
    net = load_codeformer()

    # Restore each aligned face crop.
    for cropped_face in face_helper.cropped_faces:
        cropped_face_t = img2tensor(cropped_face, bgr2rgb=True, float32=True)
        # CodeFormer expects inputs normalized to [-1, 1] (mean 0.5, std 0.5),
        # matching the (-1, 1) range that tensor2img decodes below; img2tensor
        # only scales to [0, 1], so normalize here.
        cropped_face_t = (cropped_face_t - 0.5) / 0.5
        with torch.no_grad():
            output = net(cropped_face_t.unsqueeze(0), w=fidelity_weight, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)).astype('uint8')
        face_helper.add_restored_face(restored_face)

    # Paste the restored faces back into the full image.
    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image()

    # Background enhancement with Real-ESRGAN (optional)
    if background_enhance:
        from realesrgan import RealESRGANer
        # NOTE(review): RealESRGANer normally also requires a `model=`
        # (RRDBNet) argument; confirm this constructor succeeds without it
        # in the installed realesrgan version.
        upsampler = RealESRGANer(
            scale=4,
            model_path="weights/realesrgan/RealESRGAN_x4plus.pth",
            device='cpu'
        )
        restored_img, _ = upsampler.enhance(restored_img, outscale=4)

    # OpenCV (BGR) -> PIL (RGB) for Gradio.
    return Image.fromarray(cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB))
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CodeFormer Face Restoration (CPU)")
    gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")

    with gr.Row():
        image_in = gr.Image(type="pil", label="Input Image")
        image_out = gr.Image(type="pil", label="Enhanced Image")

    weight_slider = gr.Slider(0, 1, value=0.5, step=0.01,
                              label="Fidelity Weight (0 = more restoration, 1 = more original)")
    bg_checkbox = gr.Checkbox(label="Enhance Background (Real-ESRGAN)", value=True)
    upsample_checkbox = gr.Checkbox(label="Upsample Restored Faces", value=False)
    run_button = gr.Button("Enhance")

    run_button.click(
        fn=enhance_image,
        inputs=[image_in, weight_slider, bg_checkbox, upsample_checkbox],
        outputs=image_out,
    )

if __name__ == "__main__":
    # Pre-download weights before serving so the first request is not slow.
    setup_environment()
    demo.launch()