Spaces:
Runtime error
Runtime error
import functools
import os
import sys
import urllib.request

import cv2
import gradio as gr
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image

from codeformer_arch import CodeFormer
# Function to download a file from a URL
def download_file(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    The data is first written to ``dest + ".part"`` and moved into place only
    after the transfer completes. An interrupted or failed download therefore
    never leaves a truncated file at ``dest`` that a later run would mistake
    for a finished one.

    Args:
        url: source URL (anything ``urllib.request.urlretrieve`` accepts).
        dest: destination path; missing parent directories are created.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``);
        the partial ``.part`` file is removed first.
    """
    if os.path.exists(dest):
        return
    parent = os.path.dirname(dest)
    if parent:  # os.makedirs("") raises, e.g. for a bare filename like "x.pth"
        os.makedirs(parent, exist_ok=True)
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: dest is either absent or complete.
        os.replace(tmp_path, dest)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Downloaded {dest}")
# Download pretrained models
def setup_environment():
    """Download the model weights this app reads, skipping files already present.

    - CodeFormer restoration weights -> ``weights/CodeFormer/codeformer.pth``
      (loaded by ``load_codeformer``).
    - Real-ESRGAN x4plus weights -> ``weights/RealESRGAN_x4plus.pth``
      (used for the optional background upscaling in ``enhance_image``).
    """
    # CodeFormer pretrained model
    model_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
    model_path = "weights/CodeFormer/codeformer.pth"
    download_file(model_url, model_path)

    # Real-ESRGAN model for background upsampling. RealESRGAN_x4plus.pth is
    # published under the v0.1.0 release tag (v0.2.5.0 ships the
    # realesr-general-* models), so the old v0.2.5.0 URL 404'd.
    realesrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    realesrgan_path = "weights/RealESRGAN_x4plus.pth"
    download_file(realesrgan_url, realesrgan_path)

    # NOTE(review): the former "v0.1.0/facelib.pth" download was dropped.
    # That asset does not appear to exist in the CodeFormer release, so the
    # download failed and aborted startup. The saved file was also never read
    # by this app. FaceRestoreHelper(det_model='retinaface_resnet50') obtains
    # its detector weights through facexlib; confirm for the installed
    # facexlib version.
# Load CodeFormer model
@functools.lru_cache(maxsize=1)
def load_codeformer():
    """Build CodeFormer, load its pretrained weights and return it in eval mode on CPU.

    The weights are downloaded first if missing (via ``setup_environment``).
    The result is cached, so the checkpoint is read from disk once per
    process instead of on every request. If loading raises, nothing is
    cached and the next call retries.

    Returns:
        The ``CodeFormer`` network, in eval mode, on the CPU.
    """
    setup_environment()
    model = CodeFormer(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layer=9,
        connect_list=['32', '64', '128', '256'],
    )
    checkpoint = torch.load("weights/CodeFormer/codeformer.pth", map_location='cpu')
    # The released checkpoint nests the weights under 'params_ema' (BasicSR
    # convention); loading the outer dict directly fails with missing keys.
    # A bare state dict is still accepted.
    if isinstance(checkpoint, dict) and 'params_ema' in checkpoint:
        checkpoint = checkpoint['params_ema']
    model.load_state_dict(checkpoint)
    model.eval()
    model = model.to('cpu')  # Force CPU
    return model
# Background upscaling helper
def _upscale_with_realesrgan(img_bgr):
    """Upscale a BGR image 4x with the RealESRGAN_x4plus model on CPU."""
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer

    # RealESRGANer loads the checkpoint into the network passed as `model`;
    # it does not build one itself. This is the x4plus architecture.
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path="weights/RealESRGAN_x4plus.pth",
        model=model,
        tile=400,  # process in tiles to bound peak memory on CPU
        half=False,
        device='cpu',
    )
    upscaled, _ = upsampler.enhance(img_bgr, outscale=4)
    return upscaled


# Inference function
def enhance_image(input_image, fidelity_weight=0.5, background_enhance=True, face_upsample=False):
    """Restore the faces in ``input_image`` with CodeFormer.

    Args:
        input_image: PIL image from the Gradio ``Image(type="pil")`` input.
            It is ``None`` when nothing was uploaded.
        fidelity_weight: CodeFormer ``w`` in [0, 1]. 0 favours restoration,
            1 favours staying close to the original.
        background_enhance: if True, upscale the whole restored image 4x with
            Real-ESRGAN.
        face_upsample: if True, the face helper is created with
            ``upscale_factor=2`` (otherwise 1).

    Returns:
        The restored image as an RGB PIL image.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    # Uploads may be RGBA, grayscale or palette images; the RGB->BGR
    # conversion and CodeFormer both expect 3 channels.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")
    img = np.array(input_image)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Detect, align and crop faces.
    face_helper = FaceRestoreHelper(
        upscale_factor=2 if face_upsample else 1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        device='cpu',
    )
    face_helper.clean_all()
    face_helper.read_image(img)
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    net = load_codeformer()

    for cropped_face in face_helper.cropped_faces:
        # CodeFormer expects RGB input in [-1, 1]: scale to [0, 1], then
        # normalise with mean = std = 0.5. This matches output decoding below,
        # which uses min_max=(-1, 1).
        cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
        cropped_face_t = (cropped_face_t - 0.5) / 0.5
        with torch.no_grad():
            output = net(cropped_face_t.unsqueeze(0), w=fidelity_weight, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored_face.astype('uint8'))

    # Paste restored faces back into the (possibly resized) input image.
    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image()

    if background_enhance:
        restored_img = _upscale_with_realesrgan(restored_img)

    # Convert back to RGB PIL for Gradio.
    restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(restored_img)
# Gradio interface: an upload pane and a result pane side by side, followed by
# the restoration options and a button that runs enhance_image.
with gr.Blocks() as demo:
    gr.Markdown("# CodeFormer Face Restoration (CPU)")
    gr.Markdown("Upload an image to enhance faces using CodeFormer. Runs on CPU in Hugging Face Spaces.")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil")
        output_image = gr.Image(label="Enhanced Image", type="pil")

    # Options, listed in the same order as enhance_image's parameters.
    fidelity_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.01,
        value=0.5,
        label="Fidelity Weight (0 = more restoration, 1 = more original)",
    )
    background_enhance = gr.Checkbox(value=True, label="Enhance Background (Real-ESRGAN)")
    face_upsample = gr.Checkbox(value=False, label="Upsample Restored Faces")

    submit_btn = gr.Button("Enhance")
    submit_btn.click(
        enhance_image,
        [input_image, fidelity_slider, background_enhance, face_upsample],
        output_image,
    )
def _run_app():
    """Fetch the model weights, then start the Gradio server."""
    setup_environment()
    demo.launch()


if __name__ == "__main__":
    _run_app()