# Spaces:
# Running
# Running
import concurrent.futures
import os
import time
import traceback
import uuid

import cv2
import insightface
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, HTMLResponse, StreamingResponse
# HEIC SUPPORT
# Optional: pillow_heif registers a HEIC/HEIF opener with Pillow, which
# decode_image() uses as a fallback when OpenCV cannot decode an upload.
try:
    import pillow_heif
    from PIL import Image

    pillow_heif.register_heif_opener()
    HEIC_SUPPORTED = True
except Exception:
    # Best-effort feature: a missing package or a failed registration just
    # disables the fallback. `Exception` instead of a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not swallowed at import time.
    HEIC_SUPPORTED = False
# ============================================================
# CONFIG
# ============================================================
# Uploads larger than this many megabytes get an extra downscale pass in
# compress_and_resize().
MAX_FILE_MB = 10
# Longest image side, in pixels, that compress_and_resize() allows.
MAX_DIM = 640
# Number of swap jobs the executor runs in parallel.
MAX_WORKERS = 3
# Age in seconds (compared against time.time() stamps) after which cleanup()
# drops a task entry and its output file.
CLEANUP_TIME = 5 * 60

# In-memory task registry: task id -> status dict. Written by swap() and
# run_task(), read and pruned by status(), result() and cleanup().
TASKS = dict()

# Worker pool that runs run_task() off the request path.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
# ============================================================
# LOAD MODELS
# ============================================================
# Both models are loaded once at import time and shared by every
# run_task() call on the worker pool.
# NOTE(review): ctx_id=-1 presumably selects CPU execution in insightface;
# confirm against the insightface documentation.
_DETECTOR_INPUT_SIZE = (640, 640)

face_app = insightface.app.FaceAnalysis(name="buffalo_l")
face_app.prepare(ctx_id=-1, det_size=_DETECTOR_INPUT_SIZE)

# Face-swap model, looked up as "inswapper_128.onnx" relative to root=".".
swapper = insightface.model_zoo.get_model("inswapper_128.onnx", root=".")
# ============================================================
# IMAGE HELPERS
# ============================================================
def decode_image(file_bytes):
    """Decode raw uploaded bytes into a 3-channel BGR OpenCV image.

    OpenCV's decoder is tried first. If it cannot read the data and HEIC
    support was registered at import time, Pillow (with the pillow_heif
    opener) is tried as a fallback.

    Args:
        file_bytes: Raw bytes of the uploaded file.

    Returns:
        The decoded image as a BGR array.

    Raises:
        ValueError: If the bytes are empty or no decoder can read them.
    """
    # cv2.imdecode asserts on an empty buffer (raising cv2.error); report an
    # empty upload with the same exception type as any other bad input.
    if not file_bytes:
        raise ValueError("Empty image file")

    arr = np.frombuffer(file_bytes, np.uint8)
    img = cv2.imdecode(arr, cv2.IMREAD_COLOR)

    if img is None and HEIC_SUPPORTED:
        try:
            import io

            from PIL import Image

            pil_img = Image.open(io.BytesIO(file_bytes)).convert("RGB")
            # Pillow yields RGB; the rest of the pipeline expects BGR.
            img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
        except Exception:
            # Pillow could not read it either; fall through to ValueError.
            img = None

    if img is None:
        raise ValueError("Unsupported image format")
    return img
def compress_and_resize(file_bytes, max_dim=None):
    """Decode an upload and shrink it so its longest side is at most ``max_dim``.

    Args:
        file_bytes: Raw bytes of the uploaded image (see decode_image()).
        max_dim: Longest allowed side in pixels; defaults to ``MAX_DIM``.

    Returns:
        The decoded BGR image, downscaled if needed (never upscaled).

    Raises:
        ValueError: Propagated from decode_image() for unreadable input.
    """
    if max_dim is None:
        max_dim = MAX_DIM

    img = decode_image(file_bytes)

    # Heuristic kept from the original: uploads above MAX_FILE_MB get an
    # extra 60% downscale. It only changes the final size when the shrunk
    # image ends up smaller than max_dim.
    size_mb = len(file_bytes) / (1024 * 1024)
    if size_mb > MAX_FILE_MB:
        img = cv2.resize(img, None, fx=0.6, fy=0.6, interpolation=cv2.INTER_AREA)

    h, w = img.shape[:2]
    longest = max(h, w)
    if longest > max_dim:
        scale = max_dim / longest
        # max(1, ...) keeps extreme aspect ratios from collapsing a side to
        # 0 px, which cv2.resize rejects. INTER_AREA avoids the aliasing the
        # default bilinear filter produces when shrinking.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
    return img
def enhance(img):
    """Lightly sharpen ``img`` with an unsharp mask.

    Computes ``1.2 * img - 0.2 * blurred`` via cv2.addWeighted, where
    ``blurred`` is a Gaussian blur with sigma 1.2 (kernel size derived from
    sigma because ksize is (0, 0)).
    """
    blurred = cv2.GaussianBlur(img, (0, 0), 1.2)
    sharpened = cv2.addWeighted(img, 1.2, blurred, -0.2, 0)
    return sharpened
def cleanup():
    """Expire finished tasks older than ``CLEANUP_TIME`` seconds.

    Tasks whose status is "done" or "failed" are removed from TASKS once
    their timestamp is older than CLEANUP_TIME, and their output file (if
    any) is deleted. Queued and processing tasks are never expired, so a
    running worker never finds its entry missing. Terminal tasks without a
    timestamp are stamped with the current time so they expire later
    instead of leaking.
    """
    now = time.time()
    expired = []

    # Iterate over a snapshot: swap() and the worker threads may add or
    # replace entries concurrently, and changing a dict's size while
    # iterating it raises RuntimeError.
    for tid, task in list(TASKS.items()):
        if task.get("status") not in ("done", "failed"):
            continue
        stamp = task.get("time")
        if stamp is None:
            task["time"] = now
            continue
        if now - stamp > CLEANUP_TIME:
            expired.append(tid)

    for tid in expired:
        task = TASKS.pop(tid, None)
        path = task.get("result") if task else None
        if path:
            try:
                os.remove(path)
            except OSError:
                # Already gone or not removable; the entry is dropped anyway.
                pass
# ============================================================
# WORKER
# ============================================================
def _safe_stem(filename):
    """Return a filesystem- and header-safe base name derived from ``filename``."""
    # Drop any client-supplied directory part (either separator style).
    base = os.path.basename((filename or "").replace("\\", "/"))
    stem = os.path.splitext(base)[0]
    cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_"
        for ch in stem
    ).strip("._")
    # Cap the length so "<name>_<uuid>.png" stays under filename limits.
    return cleaned[:100] or "faceswap"


def run_task(tid, src_bytes, tgt_bytes, filename, face_index):
    """Run one face swap on a worker thread and record the outcome in TASKS.

    The first face detected in the source image is swapped onto face number
    ``face_index`` of the target image. The result is sharpened and written
    as a PNG under /tmp. ``TASKS[tid]`` then becomes either
    ``{"status": "done", "result": path, "filename": name, "time": t}`` or
    ``{"status": "failed", "error": message, "time": t}``. Never raises:
    errors are recorded on the task and the traceback is printed.

    Args:
        tid: Task id (key into TASKS).
        src_bytes: Raw bytes of the image that provides the face.
        tgt_bytes: Raw bytes of the image whose face gets replaced.
        filename: Client-supplied name of the source upload. Only a
            sanitized stem is used for the output and download names.
        face_index: Index of the target face; clamped into the range of
            detected faces.
    """
    try:
        # setdefault tolerates the entry having been dropped meanwhile.
        TASKS.setdefault(tid, {"time": time.time()})["status"] = "processing"

        src = compress_and_resize(src_bytes)
        tgt = compress_and_resize(tgt_bytes)

        s_faces = face_app.get(src)
        t_faces = face_app.get(tgt)
        if not s_faces or not t_faces:
            raise ValueError("Face not detected")

        # Clamp into [0, len - 1]. A negative index would otherwise silently
        # count from the end, or raise IndexError.
        face_index = max(0, min(int(face_index), len(t_faces) - 1))
        result = swapper.get(tgt, t_faces[face_index], s_faces[0], paste_back=True)
        result = enhance(result)

        # The filename is untrusted: "../x" would escape /tmp, and quotes or
        # non-ASCII characters would break the Content-Disposition header.
        name = _safe_stem(filename)
        out_path = f"/tmp/{name}_{tid}.png"
        if not cv2.imwrite(out_path, result, [cv2.IMWRITE_PNG_COMPRESSION, 3]):
            raise OSError("Could not write result image")

        TASKS[tid] = {
            "status": "done",
            "result": out_path,
            "filename": f"{name}.png",
            "time": time.time(),
        }
    except Exception as e:
        # Keep a timestamp so cleanup() can eventually expire failed tasks.
        TASKS[tid] = {"status": "failed", "error": str(e), "time": time.time()}
        print(traceback.format_exc())
# ============================================================
# FASTAPI
# ============================================================
# Application object; the request handlers defined below are intended to
# be registered on it as routes.
app = FastAPI()
# ============================================================
# UI PAGE
# ============================================================
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve a single-page browser UI for manually exercising the swap API.

    The page posts both images and the face index to /swap, then polls
    /status/<task_id> every 800 ms. When the task is done, it previews the
    image from /result/<task_id> and offers it as a download. HTTP errors
    from any of these calls are shown to the user and stop polling.
    """
    # response_class=HTMLResponse makes FastAPI send this string as
    # text/html instead of JSON-encoding it.
    return """
<!DOCTYPE html>
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Face Swap API Test</title>
<style>
body{font-family:sans-serif;background:#0f172a;color:white;text-align:center;padding:20px}
.box{border:2px dashed #444;padding:20px;margin:10px;border-radius:10px;cursor:pointer}
img{max-width:200px;margin-top:10px;border-radius:8px}
button{padding:12px 20px;background:#6366f1;color:white;border:none;border-radius:8px}
button:disabled{opacity:.5}
.progress{height:10px;background:#333;margin-top:10px;border-radius:10px;overflow:hidden;display:none}
.bar{height:100%;width:0;background:#22c55e}
</style>
</head>
<body>
<h2>⚡ Face Swap Test UI</h2>
<div class="box" onclick="src.click()">Upload Source<input type="file" id="src" hidden></div>
<img id="p1">
<div class="box" onclick="tgt.click()">Upload Target<input type="file" id="tgt" hidden></div>
<img id="p2">
<br>
<label>Select Face Index:</label>
<input type="number" id="faceIndex" value="0" min="0">
<br><br>
<button id="startBtn" onclick="start()">Start Swap</button>
<div class="progress" id="progress"><div class="bar" id="bar"></div></div>
<br>
<img id="out">
<br>
<a id="dl" download="faceswap.png" style="display:none;color:lightgreen">Download PNG</a>
<script>
const $=(id)=>document.getElementById(id);
const src=$("src"), tgt=$("tgt");
const p1=$("p1"), p2=$("p2"), out=$("out"), dl=$("dl");
const progress=$("progress"), bar=$("bar"), startBtn=$("startBtn");
src.onchange=()=>{ if(src.files[0]) p1.src=URL.createObjectURL(src.files[0]); };
tgt.onchange=()=>{ if(tgt.files[0]) p2.src=URL.createObjectURL(tgt.files[0]); };
function errorText(body,status){
  if(body&&body.detail) return typeof body.detail==="string"?body.detail:JSON.stringify(body.detail);
  return "Request failed (HTTP "+status+")";
}
function upload(url,fd){
  return new Promise((res,rej)=>{
    let xhr=new XMLHttpRequest();
    xhr.open("POST",url);
    xhr.upload.onprogress=(e)=>{
      if(e.lengthComputable){
        progress.style.display="block";
        bar.style.width=(e.loaded/e.total*100)+"%";
      }
    };
    xhr.onload=()=>{
      let body=null;
      try{ body=JSON.parse(xhr.responseText); }catch(_){ body=null; }
      if(xhr.status>=200&&xhr.status<300&&body&&body.task_id) res(body);
      else rej(new Error(errorText(body,xhr.status)));
    };
    xhr.onerror=()=>rej(new Error("Network error"));
    xhr.send(fd);
  });
}
async function start(){
  if(!src.files[0]||!tgt.files[0]) return alert("Upload both");
  let fd=new FormData();
  fd.append("source",src.files[0]);
  fd.append("target",tgt.files[0]);
  fd.append("face_index",$("faceIndex").value||"0");
  startBtn.disabled=true;
  dl.style.display="none";
  bar.style.width="0";
  try{
    let data=await upload("/swap",fd);
    await poll(data.task_id);
  }catch(err){
    alert(err&&err.message?err.message:err);
  }finally{
    startBtn.disabled=false;
  }
}
async function poll(id){
  for(;;){
    let r=await fetch("/status/"+id);
    let j=null;
    try{ j=await r.json(); }catch(_){ j=null; }
    if(!r.ok||!j) throw new Error(errorText(j,r.status));
    if(j.status==="done"){
      let img=await fetch("/result/"+id);
      if(!img.ok) throw new Error("Could not download result (HTTP "+img.status+")");
      let blob=await img.blob();
      if(dl.href&&dl.href.startsWith("blob:")) URL.revokeObjectURL(dl.href);
      let url=URL.createObjectURL(blob);
      out.src=url;
      dl.href=url;
      dl.style.display="block";
      bar.style.width="100%";
      return;
    }
    if(j.status==="failed") throw new Error(j.error||"Swap failed");
    await new Promise((resolve)=>setTimeout(resolve,800));
  }
}
</script>
</body>
</html>
"""
# ============================================================
# API
# ============================================================
@app.post("/swap")
async def swap(
    source: UploadFile = File(...),
    target: UploadFile = File(...),
    face_index: int = Form(0),
):
    """Queue a swap of the face in ``source`` onto face ``face_index`` of ``target``.

    Both uploads are read fully into memory and handed to run_task() on the
    worker pool. The request returns immediately without waiting for the swap.

    Returns:
        ``{"task_id": <id>}``; pass the id to status() and result().
    """
    # Read the uploads before registering the task so a failed read does not
    # leave an orphan "queued" entry behind.
    # NOTE(review): uploads are read without a size cap; MAX_FILE_MB only
    # triggers extra downscaling later, in compress_and_resize().
    src_bytes = await source.read()
    tgt_bytes = await target.read()

    tid = str(uuid.uuid4())
    TASKS[tid] = {"status": "queued", "time": time.time()}
    executor.submit(
        run_task,
        tid,
        src_bytes,
        tgt_bytes,
        # UploadFile.filename may be None; give run_task a usable fallback.
        source.filename or "faceswap",
        face_index,
    )
    return {"task_id": tid}
@app.get("/status/{tid}")
def status(tid: str):
    """Return the current state of task ``tid``.

    Runs cleanup() first, so expired tasks are reported as unknown. The
    response is a copy of the task entry without the server-side output
    path.

    Raises:
        HTTPException: 404 if the task is unknown or has expired.
    """
    cleanup()
    task = TASKS.get(tid)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    # Copy so the worker/cleanup threads cannot mutate the dict while it is
    # serialized; the filesystem path is internal and not useful to clients.
    snapshot = dict(task)
    snapshot.pop("result", None)
    return snapshot
@app.get("/result/{tid}")
def result(tid: str):
    """Send the PNG produced by a finished task as a file download.

    Raises:
        HTTPException: 404 if the task is unknown, not finished, or its
            output file no longer exists.
    """
    task = TASKS.get(tid)
    if not task or task.get("status") != "done":
        raise HTTPException(status_code=404, detail="Result not available")

    path = task["result"]
    # The output file may already have been removed (e.g. by cleanup()).
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="Result not available")

    # FileResponse opens and closes the file itself (the previous
    # StreamingResponse(open(...)) left the handle to the garbage collector)
    # and sets Content-Length plus an attachment Content-Disposition header.
    return FileResponse(
        path,
        media_type="image/png",
        filename=task["filename"],
    )