✨ Face Enhancer
AI-powered face restoration using CodeFormer
import os from io import BytesIO import uuid import threading import cv2 import numpy as np import torch from fastapi import FastAPI, UploadFile, File from fastapi.responses import HTMLResponse, StreamingResponse, FileResponse from huggingface_hub import hf_hub_download from torchvision.transforms.functional import normalize from facelib.utils.face_restoration_helper import FaceRestoreHelper from models import CodeFormer from utils import img2tensor, tensor2img app = FastAPI( title="CodeFormer Face Enhancement", description="Face Restoration API", version="1.0.0" ) # ==================================================== # Job Storage # ==================================================== jobs = {} results = {} def update_progress(job_id, value, message): jobs[job_id] = { "progress": value, "message": message } # ==================================================== # Load Model Once # ==================================================== REPO_ID = "leonelhs/gfpgan" pretrain_model_path = hf_hub_download( repo_id=REPO_ID, filename="CodeFormer.pth" ) device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) net = CodeFormer( dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=["32", "64", "128", "256"] ).to(device) checkpoint = torch.load( pretrain_model_path, map_location=device )["params_ema"] net.load_state_dict(checkpoint) net.eval() face_helper = FaceRestoreHelper( upscale_factor=2, face_size=512, crop_ratio=(1, 1), det_model="retinaface_resnet50", save_ext="png", use_parse=True, device=device ) # ==================================================== # Face Enhancement Function # ==================================================== def enhance_face(image_rgb, job_id=None): face_helper.clean_all() if job_id: update_progress(job_id, 5, "Preparing image") image_bgr = cv2.cvtColor( image_rgb, cv2.COLOR_RGB2BGR ) face_helper.read_image(image_bgr) if job_id: update_progress(job_id, 20, "Detecting faces") face_helper.get_face_landmarks_5( only_center_face=False, resize=640, eye_dist_threshold=5 ) if job_id: update_progress(job_id, 35, "Aligning faces") face_helper.align_warp_face() total_faces = len(face_helper.cropped_faces) if total_faces == 0: if job_id: update_progress(job_id, 100, "No faces found") return image_rgb for idx, cropped_face in enumerate(face_helper.cropped_faces): if job_id: progress = 40 + int((idx / total_faces) * 50) update_progress( job_id, progress, f"Enhancing face {idx+1}/{total_faces}" ) cropped_face_t = img2tensor( cropped_face / 255.0, bgr2rgb=True, float32=True ) normalize( cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True ) cropped_face_t = cropped_face_t.unsqueeze(0).to(device) try: with torch.no_grad(): output = net( cropped_face_t, w=0.5, adain=True )[0] restored_face = tensor2img( output, rgb2bgr=True, min_max=(-1, 1) ) restored_face = restored_face.astype("uint8") except Exception as e: print(e) restored_face = tensor2img( cropped_face_t, rgb2bgr=True, min_max=(-1, 1) ) face_helper.add_restored_face( restored_face, cropped_face ) if job_id: update_progress(job_id, 95, "Finalizing image") face_helper.get_inverse_affine(None) restored_img = face_helper.paste_faces_to_input_image() restored_img = cv2.cvtColor( restored_img, cv2.COLOR_BGR2RGB ) if job_id: update_progress(job_id, 100, "Completed") return restored_img # ==================================================== # Background Worker # ==================================================== def process_job(job_id, image): try: result = enhance_face(image, job_id) result = cv2.cvtColor( result, cv2.COLOR_RGB2BGR ) _, buffer = cv2.imencode( ".png", result ) results[job_id] = buffer.tobytes() update_progress( job_id, 100, "Completed" ) except Exception as e: update_progress( job_id, -1, str(e) ) # ==================================================== # UI # ==================================================== HTML_PAGE = """
AI-powered face restoration using CodeFormer