# -*- coding:UTF-8 -*-
"""Face Swap API: a FastAPI service that pastes the face from one image onto another.

Faces are found with insightface's ``FaceAnalysis`` (model pack ``buffalo_l``) and
swapped with the ``inswapper_128.onnx`` model, which is downloaded on startup if
it is not already on disk.
"""
from contextlib import asynccontextmanager
from functools import lru_cache
import logging
import os
from pathlib import Path

import cv2
import numpy as np
import requests
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image  # noqa: F401  (no longer used here; kept for external importers)

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_PATH = Path("models/inswapper_128.onnx")
MODEL_URL = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"

# Settings for insightface's FaceAnalysis, shared by detection and swapping.
FACE_MODEL_NAME = "buffalo_l"
FACE_CTX_ID = 0
FACE_DET_SIZE = (640, 640)


class InvalidInputError(ValueError):
    """Raised for problems with the caller's images (unreadable, wrong face count).

    Subclasses ValueError so existing ``except ValueError`` handlers keep working.
    """


def download_model():
    """Download the inswapper model to MODEL_PATH unless it is already present.

    The file is streamed into a sibling ``.part`` file and only moved into place
    once the download completes. An interrupted download therefore never leaves
    a truncated model that a later startup would mistake for a valid one.

    Raises:
        RuntimeError: if the download fails for any reason.
    """
    if MODEL_PATH.exists():
        logger.info("Model already exists, skipping download.")
        return
    logger.info("Downloading model...")
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
    try:
        with requests.get(MODEL_URL, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        os.replace(tmp_path, MODEL_PATH)
        logger.info("Model downloaded successfully.")
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        logger.error("Failed to download model: %s", e)
        raise RuntimeError("Could not download inswapper_128.onnx.") from e


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Ensure the swap model is on disk before the app starts serving requests."""
    logger.info("Starting application...")
    try:
        download_model()
        logger.info("Startup completed successfully.")
    except Exception as e:
        logger.error("Startup failed: %s", e)
        raise
    yield
    logger.info("Shutting down application...")


# Initialize FastAPI. The lifespan handler must be passed to the constructor;
# assigning ``app.lifespan`` after construction does not register it.
app = FastAPI(
    title="Face Swap API",
    description="API for swapping faces in images.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# CORS setup
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with your domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Health check routes
@app.get("/")
async def root():
    """Return a short banner pointing users at the interactive docs."""
    return {"message": "Face Swap API is running. Use /docs to test the API."}


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "healthy"}


@lru_cache(maxsize=1)
def _get_face_analyzer():
    """Build the insightface FaceAnalysis pipeline once and reuse it across calls.

    Failures are not cached by lru_cache, so a failed load is retried next call.
    """
    from insightface.app import FaceAnalysis

    analyzer = FaceAnalysis(name=FACE_MODEL_NAME)
    analyzer.prepare(ctx_id=FACE_CTX_ID, det_size=FACE_DET_SIZE)
    return analyzer


@lru_cache(maxsize=1)
def _get_swapper():
    """Load the inswapper ONNX model from MODEL_PATH once and reuse it."""
    from insightface.model_zoo import get_model

    # download=False: the file is provided by download_model(), not insightface.
    return get_model(str(MODEL_PATH), download=False, download_zip=False)


# Face detection and swap functions
def get_faces(image):
    """Detect faces in an OpenCV (BGR) image array.

    Returns:
        The list of insightface face objects; empty if none were found.
    """
    try:
        return _get_face_analyzer().get(image) or []
    except Exception as e:
        logger.error("Face detection failed: %s", e)
        raise


def swap_faces(source_img, target_img):
    """Paste the single face in ``source_img`` onto the single face in ``target_img``.

    Both arguments are OpenCV (BGR) image arrays. The return value is the swapper's
    output with ``paste_back=True``, i.e. the target image with the face replaced.

    Raises:
        InvalidInputError (a ValueError): if either image has zero or several faces.
    """
    try:
        source_faces = get_faces(source_img)
        target_faces = get_faces(target_img)
        if not source_faces or not target_faces:
            raise InvalidInputError("No faces detected.")
        if len(source_faces) > 1 or len(target_faces) > 1:
            raise InvalidInputError("Multiple faces detected. Only one face per image is supported.")
        return _get_swapper().get(target_img, target_faces[0], source_faces[0], paste_back=True)
    except Exception as e:
        logger.error("Face swap failed: %s", e)
        raise


def _decode_image(data):
    """Decode encoded image bytes into a BGR array; return None if undecodable."""
    if not data:
        # cv2.imdecode raises (rather than returning None) on an empty buffer.
        return None
    return cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)


def _swap_image_bytes(source_bytes, target_bytes):
    """Decode both uploads, swap faces, and return the result as JPEG bytes."""
    source_img = _decode_image(source_bytes)
    target_img = _decode_image(target_bytes)
    if source_img is None or target_img is None:
        raise InvalidInputError("Invalid images provided.")
    result_img = swap_faces(source_img, target_img)
    ok, encoded = cv2.imencode(".jpg", result_img)
    if not ok:
        raise RuntimeError("Failed to encode result image.")
    return encoded.tobytes()


@app.post("/swap-face/")
async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...)):
    """Swap the face from ``source_file`` onto ``target_file`` and return a JPEG.

    Images are processed in memory, with no shared temp files, so concurrent
    requests cannot clobber each other. The CPU-heavy work runs in a worker
    thread so the event loop stays responsive.

    Responds 400 for unreadable images or unsupported face counts, 500 otherwise.
    """
    try:
        source_bytes = await source_file.read()
        target_bytes = await target_file.read()
        image_data = await run_in_threadpool(_swap_image_bytes, source_bytes, target_bytes)
    except InvalidInputError as e:
        logger.warning("Rejected swap_face request: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error in swap_face: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return Response(content=image_data, media_type="image/jpeg")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)