import os import sys import cv2 import time import torch import uuid import threading from queue import Queue from fastapi import FastAPI, UploadFile, File from fastapi.responses import HTMLResponse, FileResponse # ============================== # ๐ DIRECTORIES # ============================== BASE = os.getcwd() CACHE_DIR = os.path.join(BASE, "cache_weights") UPLOAD_DIR = os.path.join(BASE, "uploads") OUTPUT_DIR = os.path.join(BASE, "outputs") for d in [CACHE_DIR, UPLOAD_DIR, OUTPUT_DIR]: os.makedirs(d, exist_ok=True) # ============================== # ๐ง PATH # ============================== sys.path.append(os.path.abspath("CodeFormer")) # ============================== # ๐ฆ IMPORTS # ============================== from basicsr.utils.download_util import load_file_from_url from torchvision.transforms.functional import normalize from basicsr.utils import imwrite, img2tensor, tensor2img from facelib.utils.face_restoration_helper import FaceRestoreHelper from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils.realesrgan_utils import RealESRGANer from basicsr.utils.registry import ARCH_REGISTRY # ============================== # โ๏ธ SETTINGS # ============================== UPSCALE = 2 FIDELITY = 0.5 DEVICE = torch.device("cpu") # ============================== # ๐ฅ DOWNLOAD WEIGHTS (FIXED) # ============================== def download_weight(filename, url): path = os.path.join(CACHE_DIR, filename) if not os.path.exists(path): load_file_from_url(url, model_dir=CACHE_DIR, file_name=filename) return path weights = { "codeformer": download_weight( "codeformer.pth", "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth" ), "realesrgan": download_weight( "RealESRGAN_x2plus.pth", "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth" ) } # ============================== # ๐ง LAZY LOAD MODELS # ============================== models_loaded = False upsampler = None codeformer = None lock = threading.Lock() def load_models(): global models_loaded, upsampler, codeformer if models_loaded: return with lock: if models_loaded: return print("Loading models...") model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2 ) upsampler = RealESRGANer( scale=2, model_path=weights["realesrgan"], model=model, tile=200, tile_pad=10, pre_pad=0, half=False ) codeformer = ARCH_REGISTRY.get("CodeFormer")( dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=["32","64","128","256"], ).to(DEVICE) ckpt = torch.load(weights["codeformer"], map_location=DEVICE)["params_ema"] codeformer.load_state_dict(ckpt) codeformer.eval() models_loaded = True print("Models loaded.") # ============================== # ๐งน AUTO CLEAN (5 MIN) # ============================== def cleanup(): while True: now = time.time() for folder in [UPLOAD_DIR, OUTPUT_DIR]: for f in os.listdir(folder): path = os.path.join(folder, f) if os.path.isfile(path) and now - os.path.getmtime(path) > 300: try: os.remove(path) except: pass time.sleep(60) threading.Thread(target=cleanup, daemon=True).start() # ============================== # ๐ง INFERENCE # ============================== def process_image(inp, out): load_models() img = cv2.imread(inp) face_helper = FaceRestoreHelper( UPSCALE, face_size=512, crop_ratio=(1,1), det_model="retinaface_resnet50", device=DEVICE, ) face_helper.read_image(img) num_faces = face_helper.get_face_landmarks_5() if num_faces == 0: result = upsampler.enhance(img, outscale=UPSCALE)[0] imwrite(result, out) return face_helper.align_warp_face() for face in face_helper.cropped_faces: t = img2tensor(face / 255., bgr2rgb=True, float32=True) normalize(t, (0.5,)*3, (0.5,)*3, inplace=True) t = t.unsqueeze(0).to(DEVICE) with torch.no_grad(): out_face = codeformer(t, w=FIDELITY, adain=True)[0] restored = tensor2img(out_face, rgb2bgr=True, min_max=(-1,1)) face_helper.add_restored_face(restored.astype("uint8"), face) bg = upsampler.enhance(img, outscale=UPSCALE)[0] face_helper.get_inverse_affine(None) final = face_helper.paste_faces_to_input_image(upsample_img=bg) imwrite(final, out) # ============================== # ๐ QUEUE # ============================== queue = Queue() def worker(): while True: inp, out = queue.get() try: process_image(inp, out) except Exception as e: print("Worker error:", e) queue.task_done() # multiple workers for _ in range(3): threading.Thread(target=worker, daemon=True).start() # ============================== # ๐ FASTAPI # ============================== app = FastAPI() @app.post("/restore-image") async def restore_image(file: UploadFile = File(...)): uid = str(uuid.uuid4()) input_path = os.path.join(UPLOAD_DIR, uid + ".png") output_path = os.path.join(OUTPUT_DIR, uid + ".png") data = await file.read() with open(input_path, "wb") as f: f.write(data) queue.put((input_path, output_path)) # wait until ready while not os.path.exists(output_path): time.sleep(0.3) return FileResponse(output_path, media_type="image/png") # ============================== # ๐ป UI # ============================== @app.get("/", response_class=HTMLResponse) def home(): return """