# face-swap-app / app.py
# mobisoft's picture
# Update app.py
# 9941a23 verified
import os
import cv2
import uuid
import time
import numpy as np
import insightface
import concurrent.futures
import traceback
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import HTMLResponse, StreamingResponse
# HEIC SUPPORT
# pillow_heif is an optional dependency. When it (or Pillow) cannot be loaded,
# HEIC_SUPPORTED is False and the Pillow fallback in decode_image() is skipped.
try:
    import pillow_heif
    from PIL import Image

    pillow_heif.register_heif_opener()
    HEIC_SUPPORTED = True
except Exception:
    # Any load failure simply disables the feature. Unlike a bare ``except:``
    # this still lets KeyboardInterrupt / SystemExit propagate during startup.
    HEIC_SUPPORTED = False
# ============================================================
# CONFIG
# ============================================================
# Uploads larger than this many MiB get an extra 0.6x downscale in
# compress_and_resize(); it is not enforced as a hard upload limit.
MAX_FILE_MB: int = 10
# Upper bound, in pixels, for the longer image side after resizing.
MAX_DIM: int = 640
# Number of worker threads in the executor created below.
MAX_WORKERS: int = 3
# Age in seconds after which cleanup() drops a task record that has a "time" stamp.
CLEANUP_TIME: int = 300
# In-memory task registry: task id -> record dict ("status" plus, depending on
# the state, "time", "result", "filename" or "error").
TASKS: dict = {}
# Pool that runs run_task() jobs submitted by the /swap endpoint.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
# ============================================================
# LOAD MODELS
# ============================================================
# Both models are created once at import time and used by every run_task() call.
face_app = insightface.app.FaceAnalysis(name="buffalo_l")
# NOTE(review): ctx_id=-1 presumably selects CPU inference in insightface — confirm.
face_app.prepare(ctx_id=-1, det_size=(640, 640))
# NOTE(review): root="." suggests inswapper_128.onnx is resolved relative to the
# working directory — verify against the deployment layout.
swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=".")
# ============================================================
# IMAGE HELPERS
# ============================================================
def decode_image(file_bytes):
    """Decode raw upload bytes into a BGR image array.

    OpenCV is tried first. If it cannot decode the payload and HEIC support
    was registered at import time, Pillow is used as a fallback and its RGB
    output is converted to OpenCV's BGR channel order.

    Args:
        file_bytes: Encoded image file contents.

    Returns:
        The decoded image as a NumPy array in BGR channel order.

    Raises:
        ValueError: If the payload is empty or no decoder could read it.
    """
    # cv2.imdecode rejects an empty buffer with an OpenCV error instead of
    # returning None (observed on OpenCV 4.x), so report it here through the
    # same ValueError used for every other unreadable upload.
    if not file_bytes:
        raise ValueError("Unsupported image format")

    arr = np.frombuffer(file_bytes, np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)

    if img is None and HEIC_SUPPORTED:
        import io

        from PIL import Image

        try:
            pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
        except Exception as err:
            # Keep the public error type/message, but chain the decoder's
            # failure so it is visible in tracebacks instead of being dropped.
            raise ValueError("Unsupported image format") from err
        img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    if img is None:
        raise ValueError("Unsupported image format")
    return img
def compress_and_resize(file_bytes):
    """Decode an upload and shrink it to the service's working size.

    Payloads larger than MAX_FILE_MB are first scaled by 0.6. Afterwards, if
    the longer side still exceeds MAX_DIM, the image is scaled so that the
    longer side is at most MAX_DIM, keeping the aspect ratio.

    Args:
        file_bytes: Encoded image file contents.

    Returns:
        The decoded (and possibly downscaled) BGR image array.

    Raises:
        ValueError: Propagated from decode_image() for unreadable payloads.
    """
    img = decode_image(file_bytes)

    size_mb = len(file_bytes) / (1024 * 1024)
    if size_mb > MAX_FILE_MB:
        img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)

    h, w = img.shape[:2]
    longest = max(h, w)
    if longest > MAX_DIM:
        scale = MAX_DIM / longest
        # Truncation can produce 0 for the short side of a very elongated
        # image, which cv2.resize rejects; keep every side at least 1 pixel.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        # INTER_AREA matches the first downscale above and is the
        # interpolation OpenCV recommends for shrinking.
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
    return img
def enhance(img):
    """Return a sharpened copy of ``img`` using an unsharp-mask blend."""
    # A (0, 0) kernel size makes OpenCV derive the kernel from sigma = 1.2.
    softened = cv2.GaussianBlur(img, (0, 0), 1.2)
    # Weighted sum: 1.2 * img - 0.2 * softened + 0.
    sharpened = cv2.addWeighted(img, 1.2, softened, -0.2, 0)
    return sharpened
def cleanup():
    """Drop expired task records and delete their result files.

    A record expires when it carries a "time" stamp older than CLEANUP_TIME
    seconds. Records without a "time" stamp are left untouched. Deleting the
    result file is best-effort: a file that cannot be removed does not stop
    the record from being dropped.
    """
    now = time.time()
    # Iterate over a snapshot: worker threads and request handlers insert and
    # replace TASKS entries concurrently, and iterating the live dict while
    # its size changes raises RuntimeError.
    for tid, task in list(TASKS.items()):
        stamp = task.get("time")
        if stamp is None or now - stamp <= CLEANUP_TIME:
            continue

        path = task.get("result")
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                # File already gone or not removable; nothing more to do here.
                pass
        TASKS.pop(tid, None)
# ============================================================
# WORKER
# ============================================================
def _safe_stem(filename):
    """Return a path-safe base name (no directory, no extension) for an upload."""
    # The name comes from the client: normalise Windows separators and keep
    # only the last path component so it cannot escape the output directory.
    base = os.path.basename((filename or "").replace("\\", "/"))
    stem = os.path.splitext(base)[0]
    # Replace anything that is not a letter, digit or harmless punctuation
    # (this also removes quotes and control characters).
    cleaned = "".join(ch if ch.isalnum() or ch in " -_.()" else "_" for ch in stem)
    cleaned = cleaned.strip(" .")
    return cleaned or "result"


def run_task(tid, src_bytes, tgt_bytes, filename, face_index):
    """Run one face swap in a worker thread and record the outcome in TASKS.

    The first face found in the source image is swapped onto the selected
    face of the target image. The enhanced result is written as a PNG under
    /tmp and the task record is replaced with a "done" entry. Any failure
    replaces the record with a "failed" entry holding the error text; the
    exception is not re-raised.

    Args:
        tid: Task id; key of the record in TASKS.
        src_bytes: Encoded source image (provides the face).
        tgt_bytes: Encoded target image (receives the face).
        filename: Client-supplied file name, used only to derive the output
            name after sanitising. May be None or empty.
        face_index: Index of the target face to replace; clamped to the
            range of detected faces (negative values select face 0).
    """
    # The record may have been evicted by cleanup() while the job was queued;
    # recreate it instead of dying with a KeyError outside the handler below.
    record = TASKS.setdefault(tid, {"time": time.time()})
    record["status"] = "processing"
    try:
        src = compress_and_resize(src_bytes)
        tgt = compress_and_resize(tgt_bytes)
        s_faces = face_app.get(src)
        t_faces = face_app.get(tgt)
        if not s_faces or not t_faces:
            raise ValueError("Face not detected")

        # Clamp on both ends: a negative index would otherwise count from the
        # end of the list or raise IndexError.
        face_index = max(0, min(face_index, len(t_faces) - 1))
        result = swapper.get(tgt, t_faces[face_index], s_faces[0], paste_back=True)
        result = enhance(result)

        name = _safe_stem(filename)
        out_path = f"/tmp/{name}_{tid}.png"
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(out_path, result, [cv2.IMWRITE_PNG_COMPRESSION, 3]):
            raise OSError("Failed to save result image")

        TASKS[tid] = {
            "status": "done",
            "result": out_path,
            "filename": f"{name}.png",
            "time": time.time(),
        }
    except Exception as e:
        # Worker boundary: report the failure to the client via the task
        # record. The "time" stamp lets cleanup() evict failed records too.
        TASKS[tid] = {"status": "failed", "error": str(e), "time": time.time()}
        print(traceback.format_exc())
# ============================================================
# FASTAPI
# ============================================================
# Application object; the route decorators below register their handlers on it.
app = FastAPI()
# ============================================================
# UI PAGE
# ============================================================
# Static markup of the manual test page; built once at import time and
# returned unchanged by home().
_HOME_PAGE = """
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Face Swap API Test</title>
<style>
body{font-family:sans-serif;background:#0f172a;color:white;text-align:center;padding:20px}
.box{border:2px dashed #444;padding:20px;margin:10px;border-radius:10px;cursor:pointer}
img{max-width:200px;margin-top:10px;border-radius:8px}
button{padding:12px 20px;background:#6366f1;color:white;border:none;border-radius:8px}
.progress{height:10px;background:#333;margin-top:10px;border-radius:10px;overflow:hidden;display:none}
.bar{height:100%;width:0;background:#22c55e}
</style>
</head>
<body>
<h2>⚡ Face Swap Test UI</h2>
<div class="box" onclick="src.click()">Upload Source<input type="file" id="src" hidden></div>
<img id="p1">
<div class="box" onclick="tgt.click()">Upload Target<input type="file" id="tgt" hidden></div>
<img id="p2">
<br>
<label>Select Face Index:</label>
<input type="number" id="faceIndex" value="0" min="0">
<br><br>
<button onclick="start()">Start Swap</button>
<div class="progress" id="progress"><div class="bar" id="bar"></div></div>
<br>
<img id="out">
<br>
<a id="dl" download="faceswap.png" style="display:none;color:lightgreen">Download PNG</a>
<script>
const src=document.getElementById("src");
const tgt=document.getElementById("tgt");
src.onchange=()=>p1.src=URL.createObjectURL(src.files[0]);
tgt.onchange=()=>p2.src=URL.createObjectURL(tgt.files[0]);
function upload(url,fd){
return new Promise((res,rej)=>{
let xhr=new XMLHttpRequest();
xhr.open("POST",url);
xhr.upload.onprogress=(e)=>{
if(e.lengthComputable){
progress.style.display="block";
bar.style.width=(e.loaded/e.total*100)+"%";
}
};
xhr.onload=()=>res(JSON.parse(xhr.responseText));
xhr.onerror=rej;
xhr.send(fd);
});
}
async function start(){
if(!src.files[0]||!tgt.files[0]) return alert("Upload both");
let fd=new FormData();
fd.append("source",src.files[0]);
fd.append("target",tgt.files[0]);
fd.append("face_index",document.getElementById("faceIndex").value);
let data=await upload("/swap",fd);
poll(data.task_id);
}
async function poll(id){
let r=await fetch("/status/"+id);
let j=await r.json();
if(j.status==="done"){
let img=await fetch("/result/"+id);
let blob=await img.blob();
let url=URL.createObjectURL(blob);
out.src=url;
dl.href=url;
dl.style.display="block";
bar.style.width="100%";
}
else if(j.status==="failed"){
alert(j.error);
}
else{
setTimeout(()=>poll(id),800);
}
}
</script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the browser page used to exercise the swap API by hand."""
    return _HOME_PAGE
# ============================================================
# API
# ============================================================
@app.post("/swap")
async def swap(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    face_index: int = Form(0)
):
    """Queue a face-swap job and return its task id.

    The task is registered as "queued" before the uploads are read, then the
    two payloads are handed to run_task() on the worker pool. The response
    does not wait for the job; clients poll /status/{tid}.
    """
    task_id = str(uuid.uuid4())
    TASKS[task_id] = {"status": "queued", "time": time.time()}

    # Read both uploads fully into memory (source first, then target).
    source_bytes = await source.read()
    target_bytes = await target.read()

    job_args = (task_id, source_bytes, target_bytes, source.filename, face_index)
    executor.submit(run_task, *job_args)
    return {"task_id": task_id}
@app.get("/status/{tid}")
def status(tid: str):
    """Return the current record of task ``tid``.

    Expired tasks are purged first, so polling this endpoint is what drives
    cleanup().

    Raises:
        HTTPException: 404 when no record exists for ``tid``.
    """
    cleanup()
    # Single lookup: a separate membership test followed by indexing can hit
    # KeyError if a concurrent request's cleanup() evicts the record between
    # the two steps.
    task = TASKS.get(tid)
    if task is None:
        raise HTTPException(404)
    return task
@app.get("/result/{tid}")
def result(tid: str):
    """Send the finished PNG of task ``tid`` as an attachment.

    Raises:
        HTTPException: 404 when the task is unknown, not in the "done" state,
            or its result file no longer exists on disk.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.responses import FileResponse

    task = TASKS.get(tid)
    if not task or task.get("status") != "done":
        raise HTTPException(404)

    path = task["result"]
    # The file can disappear (e.g. removed by cleanup()) while the record is
    # still being read; answer 404 rather than failing with a server error.
    if not os.path.isfile(path):
        raise HTTPException(404)

    # FileResponse opens and closes the file itself and builds the
    # Content-Disposition header, including encoding of non-ASCII names,
    # instead of handing an unclosed file object to a streaming response.
    return FileResponse(path, media_type="image/png", filename=task["filename"])