from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import insightface
from insightface.app import FaceAnalysis
from insightface.model_zoo import get_model
import numpy as np
import cv2
import io

app = FastAPI()

# Initialize FaceAnalysis app and swapper.
# NOTE: this runs at import time; get_model(..., download=True) may fetch the
# model over the network before the app can serve requests.
face_analysis = FaceAnalysis(name='buffalo_l')
face_analysis.prepare(ctx_id=0, det_size=(640, 640))
swapper = get_model('inswapper_128.onnx', download=True, download_zip=True)


@app.post("/swap_faces/")
async def swap_faces(
    source_file: UploadFile = File(...),
    source_face_index: int = Form(...),
    destination_file: UploadFile = File(...),
    destination_face_index: int = Form(...),
):
    """Paste a face from the source image onto a face in the destination image.

    The operation is one-directional: only the destination image is modified
    and returned. Face indices are 1-based and count detected faces ordered by
    ``bbox[0]`` (see ``sort_faces``).

    Returns:
        A JPEG (``image/jpeg``) of the destination image with the selected
        destination face replaced by the selected source face.

    Raises:
        HTTPException(400): an upload is empty or not a decodable image, or no
            face was detected in one of the images.
        HTTPException(500): the result image could not be JPEG-encoded.
        ValueError: a face index is out of range (turned into a 400 JSON
            response by ``value_error_handler``).
    """
    source_bytes = await source_file.read()
    destination_bytes = await destination_file.read()

    # Decoding, detection, swapping and encoding are synchronous compute-bound
    # calls; running them in a worker thread keeps the event loop responsive.
    # NOTE(review): this allows concurrent calls into face_analysis/swapper;
    # relies on the underlying ONNX sessions tolerating that — confirm.
    result_bytes = await run_in_threadpool(
        _swap_faces_sync,
        source_bytes,
        source_face_index,
        destination_bytes,
        destination_face_index,
    )
    # The whole JPEG is already in memory, so a plain Response is sufficient.
    return Response(content=result_bytes, media_type="image/jpeg")


def _decode_image(data: bytes, label: str) -> np.ndarray:
    """Decode uploaded bytes with ``cv2.IMREAD_COLOR``; raise a 400 if impossible.

    ``label`` ("source"/"destination") is used only in the error message.
    """
    # An empty buffer makes cv2.imdecode raise, and undecodable data makes it
    # return None; both are client errors, so report them as 400, not 500.
    if not data:
        raise HTTPException(status_code=400, detail=f"The {label} file is empty.")
    image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(status_code=400, detail=f"The {label} file is not a valid image.")
    return image


def _swap_faces_sync(
    source_bytes: bytes,
    source_face_index: int,
    destination_bytes: bytes,
    destination_face_index: int,
) -> bytes:
    """Run the blocking swap pipeline and return the result as JPEG bytes."""
    source_image = _decode_image(source_bytes, "source")
    destination_image = _decode_image(destination_bytes, "destination")

    # Face detection and sorting
    faces_source = sort_faces(face_analysis.get(source_image))
    if not faces_source:
        raise HTTPException(status_code=400, detail="No faces detected in the source image.")
    source_face = get_face(faces_source, source_face_index)

    faces_destination = sort_faces(face_analysis.get(destination_image))
    if not faces_destination:
        raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
    destination_face = get_face(faces_destination, destination_face_index)

    # Swap: paste source_face over destination_face in destination_image.
    result_image = swapper.get(destination_image, destination_face, source_face, paste_back=True)

    # cv2.imencode reports failure via its first return value; don't ignore it.
    ok, encoded = cv2.imencode('.jpg', result_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the result image.")
    return encoded.tobytes()


def sort_faces(faces):
    """Return detected faces as a new list sorted ascending by ``bbox[0]``.

    Presumably ``bbox[0]`` is the left x-coordinate of the bounding box, giving
    a left-to-right order — confirm against the insightface ``Face`` layout.
    """
    return sorted(faces, key=lambda x: x.bbox[0])


def get_face(faces, face_id):
    """Return the face at 1-based position ``face_id`` in ``faces``.

    Raises:
        ValueError: ``face_id`` is not in ``1..len(faces)``.
    """
    if not 1 <= face_id <= len(faces):
        raise ValueError(f"The image includes only {len(faces)} faces, however, you asked for face {face_id}")
    return faces[face_id - 1]


@app.exception_handler(ValueError)
async def value_error_handler(request, exc):
    """Custom exception handler to return JSON error responses for ValueError.

    NOTE(review): this applies to every ValueError raised in any handler, not
    just ``get_face``, so unrelated internal errors also become 400s.
    """
    return JSONResponse(status_code=400, content={"detail": str(exc)})