"""FaceSwap API.

FastAPI service that swaps the first detected face from a source image onto
the first detected face in a target image, using insightface's inswapper
ONNX model downloaded from the Hugging Face Hub at startup.
"""

import io
import traceback
from typing import Optional

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
import insightface
import numpy as np
from PIL import Image, UnidentifiedImageError
import cv2

app = FastAPI(title="FaceSwap API", version="1.0.0")

# CORS (so that GitHub Pages frontend can call this API)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_REPO = "deepinsight/inswapper"
MODEL_FILE = "inswapper_128.onnx"


# NOTE: on_event is deprecated in recent FastAPI in favor of lifespan
# handlers, but it still works; kept for compatibility with the deployed app.
@app.on_event("startup")
def load_models():
    """Download and initialize the swap model and face analyzer once at startup."""
    # Download ONNX model from Hub
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, repo_type="model")
    # Load swap model and face analyzer (CPU providers only on free Spaces)
    app.state.swapper = insightface.model_zoo.get_model(model_path, providers=["CPUExecutionProvider"])
    app.state.analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
    app.state.analyser.prepare(ctx_id=0, det_size=(640, 640))


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.get("/", response_class=HTMLResponse)
def index():
    """Minimal landing page describing how to call the API."""
    return """

FaceSwap API

Source:

Target:

POST /swap with form-data: source, target

"""


def _read_image(file: UploadFile) -> np.ndarray:
    """Decode an uploaded image into a BGR numpy array.

    insightface (like OpenCV) expects BGR channel order; PIL decodes to RGB,
    so we flip the channels here. Raises HTTPException(400) on undecodable data.
    """
    data = file.file.read()
    try:
        pil = Image.open(io.BytesIO(data)).convert("RGB")
    except UnidentifiedImageError:
        raise HTTPException(status_code=400, detail=f"File '{file.filename}' is not a valid image.")
    # BUGFIX: previously the RGB array was passed straight to insightface,
    # which works in BGR — swapping red/blue through detection and output.
    return cv2.cvtColor(np.array(pil), cv2.COLOR_RGB2BGR)


@app.post("/swap")
def swap_faces(source: UploadFile = File(...), target: UploadFile = File(...)):
    """Swap the first face of `source` onto the first face of `target`.

    Returns the resulting image as a PNG stream; 400 when no face is found
    in either image, 500 (with the error message) on unexpected failures.
    """
    try:
        src = _read_image(source)
        tgt = _read_image(target)

        faces_src = app.state.analyser.get(src)
        faces_tgt = app.state.analyser.get(tgt)
        if not faces_src:
            raise HTTPException(status_code=400, detail="No face detected in source image.")
        if not faces_tgt:
            raise HTTPException(status_code=400, detail="No face detected in target image.")

        face_src = faces_src[0]
        face_tgt = faces_tgt[0]

        # InsightFace returns BGR
        result = app.state.swapper.get(tgt, face_tgt, face_src, paste_back=True)
        result_rgb = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)

        buf = io.BytesIO()
        Image.fromarray(result_rgb).save(buf, format="PNG")
        buf.seek(0)
        return StreamingResponse(buf, media_type="image/png")
    except HTTPException:
        # Re-raise client errors unchanged so FastAPI renders proper 4xx responses.
        raise
    except Exception as e:
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})