"""CodeFormer face-enhancement service.

Exposes a FastAPI app that accepts an uploaded image, runs CodeFormer face
restoration on it in a background thread, and lets the client poll
``/progress/{job_id}`` and then fetch the PNG from ``/result/{job_id}`` or
``/download/{job_id}``.
"""

import logging
import os
import threading
import uuid
from io import BytesIO

import cv2
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
from huggingface_hub import hf_hub_download
from torchvision.transforms.functional import normalize

from facelib.utils.face_restoration_helper import FaceRestoreHelper
from models import CodeFormer
from utils import img2tensor, tensor2img

logger = logging.getLogger(__name__)

app = FastAPI(
    title="CodeFormer Face Enhancement",
    description="Face Restoration API",
    version="1.0.0",
)

# ====================================================
# Job Storage
# ====================================================

# job_id -> {"progress": int, "message": str}. Progress -1 marks a failed job.
# NOTE(review): neither dict is ever pruned, so every job's status and PNG
# bytes stay in memory for the lifetime of the process.
jobs = {}
# job_id -> encoded PNG bytes. An entry appears only once the job has finished.
results = {}

# Largest upload /convert accepts; matches the "Max 10MB" hint shown in the UI.
MAX_UPLOAD_BYTES = 10 * 1024 * 1024

# Progress value process_job records when a job raises.
FAILED_PROGRESS = -1


def update_progress(job_id, value, message):
    """Record the current progress of *job_id* for the /progress endpoint.

    The job's status dict is replaced wholesale.
    """
    jobs[job_id] = {"progress": value, "message": message}


# ====================================================
# Load Model Once
# ====================================================

REPO_ID = "leonelhs/gfpgan"

pretrain_model_path = hf_hub_download(repo_id=REPO_ID, filename="CodeFormer.pth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = CodeFormer(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=["32", "64", "128", "256"],
).to(device)

# NOTE(review): torch.load unpickles the downloaded checkpoint. Unpickling can
# execute arbitrary code if the Hub repo is ever tampered with. Consider
# weights_only=True once the installed torch version is confirmed to support
# it for this file.
checkpoint = torch.load(pretrain_model_path, map_location=device)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()

face_helper = FaceRestoreHelper(
    upscale_factor=2,
    face_size=512,
    crop_ratio=(1, 1),
    det_model="retinaface_resnet50",
    save_ext="png",
    use_parse=True,
    device=device,
)

# face_helper keeps per-image state (clean_all / read_image / cropped_faces /
# add_restored_face ...). Every background job shares this one instance, so
# only one image may be processed at a time. Without this lock, concurrent
# /convert requests overwrite each other's faces mid-run.
_model_lock = threading.Lock()


# ====================================================
# Face Enhancement Function
# ====================================================


def _report(job_id, value, message):
    """Forward a progress update, but only when the call belongs to a job."""
    if job_id:
        update_progress(job_id, value, message)


def _restore_cropped_face(cropped_face):
    """Run CodeFormer on one aligned face crop and return the restored crop.

    The crop comes from face_helper, which was fed a BGR image, so it is
    converted to RGB for the network. The output is converted back to BGR.
    If inference fails, the (normalized) input crop is converted back and
    used instead, so one bad face does not abort the whole image.
    """
    cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
    # Map [0, 1] to [-1, 1], the range tensor2img is told to expect below.
    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

    try:
        with torch.no_grad():
            # w / adain go straight to CodeFormer.forward. They are presumably
            # the fidelity weight and the colour-correction switch; confirm in
            # models.CodeFormer.
            output = net(cropped_face_t, w=0.5, adain=True)[0]
        restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
        restored_face = restored_face.astype("uint8")
    except Exception:
        logger.exception("CodeFormer inference failed; keeping original face")
        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
    return restored_face


def enhance_face(image_rgb, job_id=None, *, mark_complete=True):
    """Restore every detected face in *image_rgb* with CodeFormer.

    Calls are serialized on a module lock because the face helper and the
    network are shared module-level objects.

    Args:
        image_rgb: RGB image array, as produced by /convert.
        job_id: Optional job id. When truthy, progress updates are written to
            ``jobs``.
        mark_complete: When True (the default), report 100% at the end.
            Pass False when the caller still has work to do before the
            result is available (see process_job); the caller then reports
            completion itself.

    Returns:
        The restored image in RGB, or *image_rgb* unchanged when no face is
        found.
    """
    with _model_lock:
        face_helper.clean_all()

        _report(job_id, 5, "Preparing image")
        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        face_helper.read_image(image_bgr)

        _report(job_id, 20, "Detecting faces")
        face_helper.get_face_landmarks_5(
            only_center_face=False,
            resize=640,
            eye_dist_threshold=5,
        )

        _report(job_id, 35, "Aligning faces")
        face_helper.align_warp_face()

        total_faces = len(face_helper.cropped_faces)
        if total_faces == 0:
            _report(job_id, 100 if mark_complete else 95, "No faces found")
            return image_rgb

        for idx, cropped_face in enumerate(face_helper.cropped_faces):
            # Face enhancement occupies the 40-90% slice of the progress bar.
            _report(
                job_id,
                40 + int((idx / total_faces) * 50),
                f"Enhancing face {idx+1}/{total_faces}",
            )
            restored_face = _restore_cropped_face(cropped_face)
            face_helper.add_restored_face(restored_face, cropped_face)

        _report(job_id, 95, "Finalizing image")
        face_helper.get_inverse_affine(None)
        restored_img = face_helper.paste_faces_to_input_image()
        restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)

        if mark_complete:
            _report(job_id, 100, "Completed")
        return restored_img


# ====================================================
# Background Worker
# ====================================================


def process_job(job_id, image):
    """Enhance *image* for *job_id* and store the encoded PNG in ``results``.

    The job is reported as 100% only after its bytes are in ``results``. A
    client that fetches /result as soon as progress reaches 100 therefore
    always gets the image. Any exception marks the job as failed: progress
    becomes FAILED_PROGRESS and the message is the exception text.
    """
    try:
        enhanced = enhance_face(image, job_id, mark_complete=False)
        enhanced_bgr = cv2.cvtColor(enhanced, cv2.COLOR_RGB2BGR)
        ok, buffer = cv2.imencode(".png", enhanced_bgr)
        if not ok:
            raise RuntimeError("PNG encoding failed")
        results[job_id] = buffer.tobytes()
        update_progress(job_id, 100, "Completed")
    except Exception as e:
        logger.exception("Job %s failed", job_id)
        update_progress(job_id, FAILED_PROGRESS, str(e))


# ====================================================
# UI
# ====================================================

# NOTE(review): as it appears here this literal contains no HTML markup or
# script (no upload form, no polling code). Confirm the page content is
# complete.
HTML_PAGE = """ CodeFormer Face Enhancement

✨ Face Enhancer

AI-powered face restoration using CodeFormer

📸
Drop your image here or click to browse
Supports JPG, PNG, WEBP (Max 10MB)
Processing... 0%
Enhancement complete! Download your image
📷 Original Input
🖼️ No image selected
✨ Enhanced Output
🎯 Waiting for enhancement
"""


@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the static UI page."""
    return HTML_PAGE


# ====================================================
# API Endpoints
# ====================================================


def _decode_upload(contents):
    """Decode uploaded bytes into an RGB array, or raise HTTP 400.

    cv2.imdecode returns None for data it cannot parse and raises on an empty
    buffer. Both cases are turned into a client error rather than a 500.
    """
    if not contents:
        raise HTTPException(status_code=400, detail="Empty upload")
    image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(status_code=400, detail="Could not decode image")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


@app.post("/convert")
async def convert(file: UploadFile = File(...)):
    """Queue an uploaded image for enhancement and return its job id.

    Raises:
        HTTPException: 413 if the upload exceeds MAX_UPLOAD_BYTES, 400 if it
            is empty or not a decodable image.
    """
    # Read at most one byte past the limit, so oversized uploads are
    # detected without buffering the whole body.
    contents = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(contents) > MAX_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="File too large (max 10MB)")
    image = _decode_upload(contents)

    job_id = str(uuid.uuid4())
    jobs[job_id] = {"progress": 0, "message": "Queued"}

    threading.Thread(
        target=process_job,
        args=(job_id, image),
        daemon=True,
    ).start()

    return {"job_id": job_id}


@app.get("/progress/{job_id}")
async def progress(job_id: str):
    """Return the job's status dict, or a placeholder for unknown ids."""
    return jobs.get(job_id, {"progress": 0, "message": "Unknown job"})


def _not_ready_body(job_id, pending_body):
    """Return the JSON body for a job that has no stored PNG.

    A failed job gets a "failed" status carrying the error message, so
    clients stop polling. Otherwise (still running, or unknown id)
    *pending_body* is returned unchanged.
    """
    job = jobs.get(job_id)
    if job is not None and job.get("progress") == FAILED_PROGRESS:
        return {"status": "failed", "message": job.get("message", "")}
    return pending_body


@app.get("/result/{job_id}")
async def result(job_id: str):
    """Stream the finished PNG, or report that the job is pending or failed."""
    png = results.get(job_id)
    if png is None:
        return _not_ready_body(job_id, {"status": "processing"})

    return StreamingResponse(
        BytesIO(png),
        media_type="image/png",
        headers={
            "Content-Disposition": f"attachment; filename=enhanced_{job_id[:8]}.png"
        },
    )


@app.get("/download/{job_id}")
async def download_result(job_id: str):
    """Direct download endpoint that forces browser to download as PNG"""
    png = results.get(job_id)
    if png is None:
        return _not_ready_body(
            job_id,
            {"status": "processing", "message": "Result not ready yet"},
        )

    return StreamingResponse(
        BytesIO(png),
        media_type="image/png",
        headers={
            "Content-Disposition": f"attachment; filename=enhanced_{job_id[:8]}.png",
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
        },
    )


# ====================================================
# Run
# ====================================================

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)