Spaces:
Running
Running
import asyncio
import concurrent.futures
import os
import sys
import threading
import time
import traceback
import uuid
from queue import Queue

import cv2
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
# ==============================
# 📁 DIRECTORIES
# ==============================
# All runtime folders live under the current working directory.
BASE = os.getcwd()
CACHE_DIR = os.path.join(BASE, "cache_weights")  # model checkpoints
UPLOAD_DIR = os.path.join(BASE, "uploads")  # images received from clients
OUTPUT_DIR = os.path.join(BASE, "outputs")  # restored images sent back
for d in [CACHE_DIR, UPLOAD_DIR, OUTPUT_DIR]:
    os.makedirs(d, exist_ok=True)
# ==============================
# 🔧 PATH
# ==============================
# Put the CodeFormer checkout FIRST on sys.path. CodeFormer ships its own
# modified ``basicsr`` (with ``basicsr.utils.realesrgan_utils`` and the
# registered "CodeFormer" arch used below); appending would let any basicsr
# already installed in site-packages win the import instead.
sys.path.insert(0, os.path.abspath("CodeFormer"))
| # ============================== | |
| # 📦 IMPORTS | |
| # ============================== | |
| from basicsr.utils.download_util import load_file_from_url | |
| from torchvision.transforms.functional import normalize | |
| from basicsr.utils import imwrite, img2tensor, tensor2img | |
| from facelib.utils.face_restoration_helper import FaceRestoreHelper | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from basicsr.utils.realesrgan_utils import RealESRGANer | |
| from basicsr.utils.registry import ARCH_REGISTRY | |
# ==============================
# ⚙️ SETTINGS
# ==============================
# Scale factor handed to both the face helper and the background upsampler.
UPSCALE: int = 2
# CodeFormer ``w`` argument applied to every detected face.
FIDELITY: float = 0.5
# Every model and tensor in this service is placed on the CPU.
DEVICE = torch.device("cpu")
# ==============================
# 📥 DOWNLOAD WEIGHTS (FIXED)
# ==============================
def download_weight(filename, url):
    """Return the path of ``filename`` inside ``CACHE_DIR``.

    If the file is not there yet it is fetched from ``url`` with basicsr's
    ``load_file_from_url`` first; an existing file is reused as-is.
    """
    target = os.path.join(CACHE_DIR, filename)
    if os.path.exists(target):
        return target
    load_file_from_url(url, model_dir=CACHE_DIR, file_name=filename)
    return target
# Local checkpoint paths keyed by model name; each is downloaded at import
# time if it is not already cached.
weights = dict(
    codeformer=download_weight(
        "codeformer.pth",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
    ),
    realesrgan=download_weight(
        "RealESRGAN_x2plus.pth",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
    ),
)
# ==============================
# 🧠 LAZY LOAD MODELS
# ==============================
# Module-level model state, populated by load_models() on first use.
models_loaded = False
upsampler = None
codeformer = None
lock = threading.Lock()


def _make_upsampler():
    """Build the scale-2 RRDBNet backbone and wrap it in a RealESRGANer."""
    backbone = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=2,
    )
    return RealESRGANer(
        scale=2,
        model_path=weights["realesrgan"],
        model=backbone,
        tile=200,
        tile_pad=10,
        pre_pad=0,
        half=False,
    )


def load_models():
    """Construct the background upsampler and the CodeFormer model once.

    The ``models_loaded`` flag is tested before and again after taking
    ``lock``. Only the first caller performs the construction; callers that
    were waiting on the lock return as soon as they see the flag set.
    """
    global models_loaded, upsampler, codeformer
    if models_loaded:
        return
    with lock:
        if models_loaded:  # another thread finished while we waited
            return
        print("Loading models...")
        upsampler = _make_upsampler()
        codeformer = ARCH_REGISTRY.get("CodeFormer")(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=["32", "64", "128", "256"],
        ).to(DEVICE)
        checkpoint = torch.load(weights["codeformer"], map_location=DEVICE)
        codeformer.load_state_dict(checkpoint["params_ema"])
        codeformer.eval()
        models_loaded = True
        print("Models loaded.")
# ==============================
# 🧹 AUTO CLEAN (5 MIN)
# ==============================
def purge_expired_files(folders, max_age=300, now=None):
    """Delete regular files older than ``max_age`` seconds from ``folders``.

    This is best effort. A folder that cannot be listed is skipped. A file
    that disappears or cannot be stat'ed/removed is also skipped. Sub-directories
    are never touched.

    Args:
        folders: Iterable of directory paths to scan (non-recursively).
        max_age: Age threshold in seconds, compared against file mtime.
        now: Reference timestamp; defaults to ``time.time()``.

    Returns:
        The number of files removed.
    """
    if now is None:
        now = time.time()
    removed = 0
    for folder in folders:
        try:
            names = os.listdir(folder)
        except OSError:
            continue
        for name in names:
            path = os.path.join(folder, name)
            try:
                # getmtime is inside the try: the file may vanish after listdir.
                if os.path.isfile(path) and now - os.path.getmtime(path) > max_age:
                    os.remove(path)
                    removed += 1
            except OSError:
                continue
    return removed


def cleanup(max_age=300, interval=60):
    """Forever: purge old uploads/outputs, then sleep ``interval`` seconds.

    Note: a queued job whose input waits longer than ``max_age`` seconds
    will find its upload already deleted.
    """
    while True:
        try:
            purge_expired_files([UPLOAD_DIR, OUTPUT_DIR], max_age)
        except Exception as exc:  # keep the janitor thread alive no matter what
            print("Cleanup error:", exc)
        time.sleep(interval)


threading.Thread(target=cleanup, daemon=True).start()
# ==============================
# 🧠 INFERENCE
# ==============================
# RealESRGANer.enhance() keeps its intermediate tensors on the instance
# (self.img / self.output), so the single shared upsampler must not be used by
# several worker threads at the same time.
_upsampler_lock = threading.Lock()


def _upscale(img):
    """Upscale the full frame by UPSCALE with the shared upsampler (serialized)."""
    with _upsampler_lock:
        return upsampler.enhance(img, outscale=UPSCALE)[0]


def _write_atomic(image, path):
    """Write ``image`` to ``path`` via a temporary sibling file plus a rename.

    Anything waiting for ``path`` to exist therefore never sees a half-written
    file. The temp name keeps ``path``'s extension so the encoder is the same.
    """
    root, ext = os.path.splitext(path)
    tmp_path = root + ".partial" + ext
    imwrite(image, tmp_path)
    os.replace(tmp_path, path)


def process_image(inp, out):
    """Restore faces in the image at ``inp`` and write the result to ``out``.

    If no face is detected, only the Real-ESRGAN upscale of the whole image is
    written. Otherwise each aligned face is passed through CodeFormer (with
    ``w=FIDELITY``) and pasted back onto the upscaled background.

    Raises:
        ValueError: if ``inp`` cannot be read/decoded as an image.
    """
    load_models()
    img = cv2.imread(inp)
    if img is None:
        # cv2.imread reports missing/undecodable files by returning None.
        raise ValueError(f"Could not read image: {inp}")
    face_helper = FaceRestoreHelper(
        UPSCALE,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        device=DEVICE,
    )
    face_helper.read_image(img)
    num_faces = face_helper.get_face_landmarks_5()
    if num_faces == 0:
        _write_atomic(_upscale(img), out)
        return
    face_helper.align_warp_face()
    for face in face_helper.cropped_faces:
        # Scale to [0, 1], convert BGR->RGB, then normalize to [-1, 1].
        t = img2tensor(face / 255., bgr2rgb=True, float32=True)
        normalize(t, (0.5,) * 3, (0.5,) * 3, inplace=True)
        t = t.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out_face = codeformer(t, w=FIDELITY, adain=True)[0]
        restored = tensor2img(out_face, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored.astype("uint8"), face)
    bg = _upscale(img)
    face_helper.get_inverse_affine(None)
    final = face_helper.paste_faces_to_input_image(upsample_img=bg)
    _write_atomic(final, out)
# ==============================
# 🔁 QUEUE
# ==============================
# Pending jobs: (input_path, output_path) or (input_path, output_path, future).
queue = Queue()

# Longest a request waits for its result (queueing + inference), in seconds.
RESULT_TIMEOUT = 600


def _run_job(inp, out, done):
    """Run one job and report its outcome on the optional future ``done``."""
    if done is not None and not done.set_running_or_notify_cancel():
        return  # the waiting request already gave up; skip the work
    try:
        process_image(inp, out)
    except Exception as e:
        print("Worker error:", e)
        traceback.print_exc()
        if done is not None:
            done.set_exception(e)
        return
    if done is not None:
        done.set_result(out)


def worker():
    """Consume jobs from ``queue`` forever.

    When a job carries a ``concurrent.futures.Future``, it is resolved with the
    output path on success, or with the exception on failure. The submitting
    request therefore learns the outcome instead of waiting on the filesystem.
    """
    while True:
        job = queue.get()
        try:
            inp, out, *extra = job
            _run_job(inp, out, extra[0] if extra else None)
        finally:
            queue.task_done()


# multiple workers
for _ in range(3):
    threading.Thread(target=worker, daemon=True).start()
# ==============================
# 🌐 FASTAPI
# ==============================
app = FastAPI()


@app.post("/restore-image")
async def restore_image(file: UploadFile = File(...)):
    """Queue the uploaded image for restoration and return the restored PNG.

    Raises:
        HTTPException: 500 if processing failed, or 504 if no result arrived
            within ``RESULT_TIMEOUT`` seconds.
    """
    uid = str(uuid.uuid4())
    input_path = os.path.join(UPLOAD_DIR, uid + ".png")
    output_path = os.path.join(OUTPUT_DIR, uid + ".png")
    data = await file.read()
    with open(input_path, "wb") as f:
        f.write(data)
    done = concurrent.futures.Future()
    queue.put((input_path, output_path, done))
    # Await the worker's signal without blocking the event loop. A blocking
    # time.sleep poll here would freeze every other request, and it would wait
    # forever if the job failed. On timeout, wait_for cancels the future, so a
    # still-queued job is skipped.
    try:
        await asyncio.wait_for(asyncio.wrap_future(done), timeout=RESULT_TIMEOUT)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Image processing timed out") from None
    except Exception as exc:
        raise HTTPException(status_code=500, detail="Image processing failed") from exc
    return FileResponse(output_path, media_type="image/png")
# ==============================
# 💻 UI
# ==============================
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page upload UI that posts to ``/restore-image``.

    The page shows the selected image next to the restored one. It reports a
    non-2xx response or a network failure instead of showing a broken image.
    """
    return """
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
body {background:#111;color:white;text-align:center;font-family:sans-serif;}
.box {max-width:600px;margin:auto;padding:20px;}
input,button {width:100%;padding:10px;margin-top:10px;}
.row {display:flex;gap:10px;margin-top:20px;}
img {width:50%;}
</style>
</head>
<body>
<div class="box">
<h2>AI Face Restore</h2>
<input type="file" id="file">
<button onclick="upload()">Restore</button>
<p id="status"></p>
<div class="row">
<img id="preview">
<img id="result">
</div>
</div>
<script>
async function upload(){
  let file=document.getElementById('file').files[0];
  if(!file) return;
  let status=document.getElementById('status');
  let result=document.getElementById('result');
  document.getElementById('preview').src = URL.createObjectURL(file);
  result.removeAttribute('src');
  status.innerText = "Processing...";
  let fd=new FormData();
  fd.append('file',file);
  try {
    let res=await fetch('/restore-image',{method:'POST',body:fd});
    if(!res.ok){
      status.innerText = "Error " + res.status + ": restoration failed";
      return;
    }
    let blob=await res.blob();
    result.src = URL.createObjectURL(blob);
    status.innerText = "Done";
  } catch(err) {
    status.innerText = "Network error: " + err;
  }
}
</script>
</body>
</html>
"""
# ==============================
# ▶️ RUN
# ==============================
if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string. The string form
    # makes uvicorn import this file a second time as module "app". That
    # re-runs every module-level side effect: another cleanup thread, three
    # more workers and a second queue. It also fails when the file is not
    # named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)