|
|
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
| from fastapi.responses import StreamingResponse, JSONResponse, FileResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| import insightface |
| from insightface.app import FaceAnalysis |
| from insightface.model_zoo import get_model |
| import numpy as np |
| import cv2 |
| import io |
|
|
app = FastAPI()

# Model identifiers and detector input size used for the shared models below.
_ANALYSIS_MODEL_PACK = 'buffalo_l'
_DETECTION_SIZE = (640, 640)
_SWAPPER_MODEL_FILE = 'inswapper_128.onnx'

# Both models are created once, when this module is imported, and are reused
# by every call to the swap endpoint.
# NOTE(review): ctx_id=0 presumably selects the first GPU in insightface — confirm
# for CPU-only deployments.
face_analysis = FaceAnalysis(name=_ANALYSIS_MODEL_PACK)
face_analysis.prepare(ctx_id=0, det_size=_DETECTION_SIZE)
swapper = get_model(_SWAPPER_MODEL_FILE, download=True, download_zip=True)
|
|
def _decode_image(data, label):
    """Decode uploaded image bytes into an OpenCV colour image, or raise HTTP 400.

    ``label`` ("source" / "destination") is used only in the error message.
    """
    # An empty upload gives an empty array, which cv2.imdecode may reject with
    # a cv2.error rather than returning None (OpenCV-version dependent), so
    # check for it up front.
    if not data:
        raise HTTPException(status_code=400, detail=f"The {label} image is empty.")
    image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    # imdecode returns None when the bytes are not a decodable image.
    if image is None:
        raise HTTPException(status_code=400, detail=f"The {label} image could not be decoded.")
    return image


def _select_face(image, face_index, label):
    """Detect faces in ``image`` and return the ``face_index``-th one (1-based, left to right).

    Raises HTTPException(400) when no face is detected. An out-of-range index
    propagates the ValueError raised by ``get_face``.
    """
    faces = sort_faces(face_analysis.get(image))
    if not faces:
        raise HTTPException(status_code=400, detail=f"No faces detected in the {label} image.")
    return get_face(faces, face_index)


@app.post("/swap_faces/")
async def swap_faces(source_file: UploadFile = File(...),
                     source_face_index: int = Form(...),
                     destination_file: UploadFile = File(...),
                     destination_face_index: int = Form(...)):
    """Swap a face from the source image onto a face in the destination image.

    Faces in each image are numbered from 1, ordered left to right by the
    left edge of their bounding box.

    Args:
        source_file: Uploaded image that provides the face to copy.
        source_face_index: 1-based index of the face to take from the source image.
        destination_file: Uploaded image whose face is replaced.
        destination_face_index: 1-based index of the face to replace in the destination image.

    Returns:
        StreamingResponse: the JPEG-encoded result of ``swapper.get`` with
        ``paste_back=True`` on the destination image.

    Raises:
        HTTPException: 400 if an upload is empty or cannot be decoded, or if no
            face is detected in an image. 500 if the result cannot be JPEG-encoded.
        ValueError: if a face index is out of range. The module's ValueError
            handler turns this into a 400 response.
    """
    source_image = _decode_image(await source_file.read(), "source")
    destination_image = _decode_image(await destination_file.read(), "destination")

    # NOTE(review): detection and swapping below are synchronous calls inside an
    # async handler, so they block the event loop while they run. Consider
    # offloading them to a thread pool if concurrent requests matter.
    source_face = _select_face(source_image, source_face_index, "source")
    destination_face = _select_face(destination_image, destination_face_index, "destination")

    result_image = swapper.get(destination_image, destination_face, source_face, paste_back=True)

    # imencode reports failure through its first return value, not by raising.
    ok, encoded = cv2.imencode('.jpg', result_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the result image.")
    return StreamingResponse(io.BytesIO(encoded.tobytes()), media_type="image/jpeg")
|
|
def sort_faces(faces):
    """Return a new list of the given faces ordered left to right.

    The order key is ``bbox[0]`` of each face. Faces with equal keys keep
    their input order, and the input itself is not modified.
    """
    ordered = list(faces)
    ordered.sort(key=lambda face: face.bbox[0])
    return ordered
|
|
def get_face(faces, face_id):
    """Return the face at 1-based position ``face_id`` in ``faces``.

    Raises:
        ValueError: if ``face_id`` is less than 1 or greater than ``len(faces)``.
    """
    count = len(faces)
    if not 1 <= face_id <= count:
        raise ValueError(f"The image includes only {count} faces, however, you asked for face {face_id}")
    return faces[face_id - 1]
|
|
@app.exception_handler(ValueError)
async def value_error_handler(request, exc):
    """Turn an unhandled ``ValueError`` into an HTTP 400 JSON response.

    The response body is ``{"detail": <exception message>}``.
    """
    # NOTE(review): this is registered for every ValueError raised while handling
    # any request, so messages from unexpected internal ValueErrors are also sent
    # to the client verbatim.
    payload = {"detail": str(exc)}
    return JSONResponse(content=payload, status_code=400)
|
|
|
|
|
|
|
|