# kingloft's picture
# Update app.py
# 20421d6 verified
import asyncio
import os
import sys
import cv2
import time
import torch
import uuid
import threading
import traceback
from queue import Queue
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
# ==============================
# 📁 DIRECTORIES
# ==============================
# Every working folder lives under the directory the process was started in.
BASE = os.getcwd()
CACHE_DIR, UPLOAD_DIR, OUTPUT_DIR = (
    os.path.join(BASE, name) for name in ("cache_weights", "uploads", "outputs")
)
# Create the folders up front so later reads/writes never hit a missing dir.
for d in (CACHE_DIR, UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(d, exist_ok=True)
# ==============================
# 🔧 PATH
# ==============================
# Put the local "CodeFormer" directory (resolved against the current working
# directory) on the import path before the project imports that follow.
_codeformer_root = os.path.abspath("CodeFormer")
sys.path.append(_codeformer_root)
# ==============================
# 📦 IMPORTS
# ==============================
from basicsr.utils.download_util import load_file_from_url
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
from basicsr.utils.registry import ARCH_REGISTRY
# ==============================
# ⚙️ SETTINGS
# ==============================
# Output scale factor: passed to FaceRestoreHelper and as `outscale` to the
# background upsampler.
UPSCALE: int = 2
# Passed to the CodeFormer network as its `w` argument.
FIDELITY: float = 0.5
# All models and tensors are placed on the CPU.
DEVICE: torch.device = torch.device("cpu")
# ==============================
# 📥 DOWNLOAD WEIGHTS (FIXED)
# ==============================
def download_weight(filename, url):
    """Return the cached path of a weight file, fetching it if absent.

    Args:
        filename: Name the file is stored under inside ``CACHE_DIR``.
        url: Where to download the file from when it is not cached yet.

    Returns:
        The path ``CACHE_DIR/filename``.
    """
    cached = os.path.join(CACHE_DIR, filename)
    if os.path.exists(cached):
        return cached
    load_file_from_url(url, model_dir=CACHE_DIR, file_name=filename)
    return cached
# (key, cached file name, download URL) for every checkpoint the app needs.
_WEIGHT_SOURCES = (
    (
        "codeformer",
        "codeformer.pth",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth",
    ),
    (
        "realesrgan",
        "RealESRGAN_x2plus.pth",
        "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
    ),
)

# Maps each model key to its local checkpoint path; missing files are
# downloaded here, at import time, in the order listed above.
weights = {
    key: download_weight(file_name, url)
    for key, file_name, url in _WEIGHT_SOURCES
}
# ==============================
# 🧠 LAZY LOAD MODELS
# ==============================
# Guards the one-time model construction performed by load_models().
lock = threading.Lock()
# Flipped to True by load_models() once both models below are ready.
models_loaded = False
# Placeholders for the background upsampler and the CodeFormer network.
upsampler = codeformer = None
def load_models():
    """Construct the upsampler and CodeFormer network on first use.

    The first caller builds both models and publishes them through the module
    globals ``upsampler`` and ``codeformer``; once ``models_loaded`` is set,
    later calls return immediately. Construction happens while holding
    ``lock``, and the flag is checked again after acquiring it so a caller
    that waited on the lock does not build the models a second time.
    """
    global models_loaded, upsampler, codeformer

    if models_loaded:
        return

    with lock:
        if not models_loaded:
            print("Loading models...")

            # Background upsampler: RRDBNet backbone wrapped by RealESRGANer.
            esrgan_net = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=2,
            )
            upsampler = RealESRGANer(
                scale=2,
                model_path=weights["realesrgan"],
                model=esrgan_net,
                tile=200,
                tile_pad=10,
                pre_pad=0,
                half=False,
            )

            # Face restoration network, looked up by name in the registry.
            codeformer_cls = ARCH_REGISTRY.get("CodeFormer")
            codeformer = codeformer_cls(
                dim_embd=512,
                codebook_size=1024,
                n_head=8,
                n_layers=9,
                connect_list=["32", "64", "128", "256"],
            ).to(DEVICE)

            checkpoint = torch.load(weights["codeformer"], map_location=DEVICE)
            codeformer.load_state_dict(checkpoint["params_ema"])
            codeformer.eval()

            # Only flag success after everything above completed.
            models_loaded = True
            print("Models loaded.")
# ==============================
# 🧹 AUTO CLEAN (5 MIN)
# ==============================
def cleanup(max_age=300, interval=60, folders=None, run_once=False):
    """Delete stale files from the working folders, repeating forever.

    Each sweep removes every regular file whose modification time is more
    than ``max_age`` seconds in the past. Sub-directories are never touched.
    Filesystem errors are treated as best-effort: an unreadable folder is
    reported and skipped, and a file that disappears or cannot be removed is
    left for the next sweep, so one bad entry cannot stop the loop.

    Args:
        max_age: Age in seconds after which a file is removed.
        interval: Seconds to sleep between sweeps.
        folders: Iterable of directories to sweep. Defaults to
            ``(UPLOAD_DIR, OUTPUT_DIR)``.
        run_once: When true, perform a single sweep and return instead of
            looping.

    Returns:
        None. With ``run_once`` left false the function does not return.
    """
    if folders is None:
        folders = (UPLOAD_DIR, OUTPUT_DIR)

    while True:
        now = time.time()
        for folder in folders:
            try:
                names = os.listdir(folder)
            except OSError as exc:
                # Folder missing or unreadable: report it and try the others.
                print("Cleanup: cannot list", folder, "-", exc)
                continue

            for name in names:
                path = os.path.join(folder, name)
                try:
                    # getmtime is inside the try because the file may vanish
                    # between listdir() and this check.
                    if os.path.isfile(path) and now - os.path.getmtime(path) > max_age:
                        os.remove(path)
                except OSError:
                    # Vanished or locked file; a later sweep will retry.
                    continue

        if run_once:
            return
        time.sleep(interval)
# Run the periodic cleanup in the background; as a daemon thread it does not
# keep the process alive at shutdown.
_cleanup_thread = threading.Thread(target=cleanup, daemon=True)
_cleanup_thread.start()
# ==============================
# 🧠 INFERENCE
# ==============================
# The single `upsampler` instance is shared by every worker thread. Upstream
# RealESRGANer.enhance() keeps its per-call working tensors on the instance
# (NOTE(review): confirm against the vendored copy), so concurrent calls could
# overwrite each other's data. Serialise access to be safe.
_UPSAMPLER_LOCK = threading.Lock()


def _upscale(img):
    """Upscale ``img`` with the shared background upsampler, one caller at a time."""
    with _UPSAMPLER_LOCK:
        return upsampler.enhance(img, outscale=UPSCALE)[0]


def _write_atomically(image, out):
    """Write ``image`` to ``out`` so that ``out`` never exists half-written.

    The image is first written next to the target under a temporary name that
    keeps the original extension (the encoder is chosen from it), then moved
    into place with ``os.replace``.
    """
    root, ext = os.path.splitext(out)
    tmp_path = f"{root}.partial{ext}"
    imwrite(image, tmp_path)
    # Raises FileNotFoundError if the write above produced no file, which
    # surfaces the failure to the caller instead of hiding it.
    os.replace(tmp_path, out)


def process_image(inp, out):
    """Restore the faces in image ``inp`` and write the upscaled result to ``out``.

    When no face is detected the image is only upscaled. Otherwise each
    aligned face crop is run through CodeFormer and pasted back onto the
    upscaled background.

    Args:
        inp: Path of the image to read.
        out: Path the result is written to. The file appears only once it is
            complete.

    Raises:
        ValueError: If ``inp`` cannot be read or decoded as an image.
    """
    load_models()

    img = cv2.imread(inp)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read or decode image: {inp}")

    face_helper = FaceRestoreHelper(
        UPSCALE,
        face_size=512,
        crop_ratio=(1, 1),
        det_model="retinaface_resnet50",
        device=DEVICE,
    )
    face_helper.read_image(img)
    num_faces = face_helper.get_face_landmarks_5()

    if num_faces == 0:
        # Nothing to restore: upscale the whole picture and stop.
        _write_atomically(_upscale(img), out)
        return

    face_helper.align_warp_face()
    for face in face_helper.cropped_faces:
        # Scale to [0, 1], convert to an RGB tensor, then normalise to [-1, 1].
        t = img2tensor(face / 255., bgr2rgb=True, float32=True)
        normalize(t, (0.5,) * 3, (0.5,) * 3, inplace=True)
        t = t.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out_face = codeformer(t, w=FIDELITY, adain=True)[0]
        restored = tensor2img(out_face, rgb2bgr=True, min_max=(-1, 1))
        face_helper.add_restored_face(restored.astype("uint8"), face)

    bg = _upscale(img)
    face_helper.get_inverse_affine(None)
    final = face_helper.paste_faces_to_input_image(upsample_img=bg)
    _write_atomically(final, out)
# ==============================
# 🔁 QUEUE
# ==============================
# Unbounded FIFO of pending jobs. The HTTP handler puts
# (input_path, output_path) tuples on it and the worker threads take them off.
queue = Queue()
def worker():
    """Process ``(input_path, output_path)`` jobs from ``queue`` forever.

    A failing job is logged and dropped so the thread keeps serving later
    jobs. Every job taken from the queue is marked done, whether it
    succeeded or not.
    """
    while True:
        inp, out = queue.get()
        try:
            process_image(inp, out)
        except Exception as e:
            print("Worker error:", e)
            # str(e) is empty for many exceptions; the traceback shows the
            # exception type and where it was raised.
            traceback.print_exc()
        finally:
            queue.task_done()
# multiple workers
# Three daemon threads consume the job queue.
_worker_threads = [
    threading.Thread(target=worker, daemon=True) for _ in range(3)
]
for _worker_thread in _worker_threads:
    _worker_thread.start()
# ==============================
# 🌐 FASTAPI
# ==============================
# ASGI application object; the route decorators below register on it and the
# __main__ block refers to it as "app:app".
app = FastAPI()
# How often the handler checks for the finished file, and how long it waits
# before giving up on a job that never produces one (e.g. processing failed).
RESULT_POLL_SECONDS = 0.3
RESULT_TIMEOUT_SECONDS = 600


@app.post("/restore-image")
async def restore_image(file: UploadFile = File(...)):
    """Accept an uploaded image, queue it for restoration and return the result.

    The upload is stored under a fresh UUID in ``UPLOAD_DIR`` and a job is
    put on ``queue``. The handler then waits, without blocking the event
    loop, until the worker has produced the matching file in ``OUTPUT_DIR``.

    Args:
        file: The uploaded image.

    Returns:
        A ``FileResponse`` serving the restored PNG.

    Raises:
        HTTPException: 400 if the upload is empty; 504 if no result appears
            within ``RESULT_TIMEOUT_SECONDS``.
    """
    uid = str(uuid.uuid4())
    input_path = os.path.join(UPLOAD_DIR, uid + ".png")
    output_path = os.path.join(OUTPUT_DIR, uid + ".png")

    data = await file.read()
    if not data:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    with open(input_path, "wb") as f:
        f.write(data)
    queue.put((input_path, output_path))

    # Wait until ready. asyncio.sleep yields to the event loop so other
    # requests keep being served; time.sleep would freeze the whole server.
    deadline = time.monotonic() + RESULT_TIMEOUT_SECONDS
    while not os.path.exists(output_path):
        if time.monotonic() >= deadline:
            raise HTTPException(
                status_code=504,
                detail="Image restoration did not finish in time.",
            )
        await asyncio.sleep(RESULT_POLL_SECONDS)

    return FileResponse(output_path, media_type="image/png")
# ==============================
# 💻 UI
# ==============================
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page upload UI.

    The page posts the chosen file to ``/restore-image`` and shows the
    original next to the restored image. A non-2xx response or a network
    error is reported in the status line instead of being shown as "Done".

    Returns:
        The HTML document as a string.
    """
    return """
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
body {background:#111;color:white;text-align:center;font-family:sans-serif;}
.box {max-width:600px;margin:auto;padding:20px;}
input,button {width:100%;padding:10px;margin-top:10px;}
.row {display:flex;gap:10px;margin-top:20px;}
img {width:50%;}
</style>
</head>
<body>
<div class="box">
<h2>AI Face Restore</h2>
<input type="file" id="file">
<button onclick="upload()">Restore</button>
<p id="status"></p>
<div class="row">
<img id="preview">
<img id="result">
</div>
</div>
<script>
async function upload(){
  let file=document.getElementById('file').files[0];
  if(!file) return;
  let status=document.getElementById('status');
  document.getElementById('preview').src = URL.createObjectURL(file);
  status.innerText = "Processing...";
  let fd=new FormData();
  fd.append('file',file);
  try{
    let res=await fetch('/restore-image',{method:'POST',body:fd});
    if(!res.ok) throw new Error('HTTP '+res.status);
    let blob=await res.blob();
    document.getElementById('result').src = URL.createObjectURL(blob);
    status.innerText = "Done";
  }catch(err){
    status.innerText = "Failed: "+err.message;
  }
}
</script>
</body>
</html>
"""
# ==============================
# ▶️ RUN
# ==============================
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn
    # Listen on every interface, port 7860. The app is passed as the string
    # "app:app", i.e. object `app` in a module named `app`.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)