# Fs / main.py
# Author: Luisgust
# Last commit: 0ee03ec ("Update main.py", verified)
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import insightface
from insightface.app import FaceAnalysis
from insightface.model_zoo import get_model
import numpy as np
import cv2
import io
# FastAPI application; the /swap_faces/ route and the ValueError handler below register on it.
app = FastAPI()
# Initialize FaceAnalysis app and swapper.
# Both run at module import time, so the models are loaded once and shared by every request.
# 'buffalo_l' names the insightface model pack used for face detection/analysis.
face_analysis = FaceAnalysis(name='buffalo_l')
# NOTE(review): ctx_id=0 presumably selects device 0 (GPU); confirm behaviour on CPU-only hosts.
# det_size=(640, 640) is the detector's input size.
face_analysis.prepare(ctx_id=0, det_size=(640, 640))
# Face-swap model used by swap_faces(); download=True presumably fetches the .onnx if not cached.
# NOTE(review): verify the download source for inswapper_128.onnx is still reachable.
swapper = get_model('inswapper_128.onnx', download=True, download_zip=True)
def _decode_image(data, label):
    """Decode uploaded image bytes into an OpenCV color image.

    Args:
        data: Raw bytes read from an ``UploadFile``.
        label: Which upload this is ("source" or "destination"); used only in
            the error message.

    Returns:
        The image array produced by ``cv2.imdecode`` with ``IMREAD_COLOR``.

    Raises:
        HTTPException: 400 when the upload is empty or is not a decodable image.
    """
    # cv2.imdecode raises cv2.error on an empty buffer and returns None for
    # undecodable bytes; without these checks a bad upload surfaces as a 500
    # (or a crash inside face detection) instead of a client error.
    if not data:
        raise HTTPException(status_code=400, detail=f"The {label} file is empty.")
    image = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(
            status_code=400,
            detail=f"The {label} file could not be decoded as an image.",
        )
    return image


@app.post("/swap_faces/")
async def swap_faces(source_file: UploadFile = File(...),
                     source_face_index: int = Form(...),
                     destination_file: UploadFile = File(...),
                     destination_face_index: int = Form(...)):
    """Paste a face from the source image onto a face in the destination image.

    Faces detected in each image are ordered by ``sort_faces`` and picked with
    the 1-based indices supplied in the form fields.

    Args:
        source_file: Image containing the face to copy.
        source_face_index: 1-based index of the face to copy from the source image.
        destination_file: Image whose face is replaced; it is the base of the result.
        destination_face_index: 1-based index of the face to replace in the
            destination image.

    Returns:
        A ``StreamingResponse`` containing the modified destination image as JPEG.

    Raises:
        HTTPException: 400 if an upload is empty or not a decodable image, or if
            no face is detected in one of the images; 500 if the result cannot
            be JPEG-encoded.
        ValueError: If a face index is out of range (raised by ``get_face``;
            the app's ValueError handler turns it into a 400 JSON response).
    """
    source_bytes = await source_file.read()
    destination_bytes = await destination_file.read()

    # Decode both uploads, rejecting empty or non-image files with a 400.
    source_image = _decode_image(source_bytes, "source")
    destination_image = _decode_image(destination_bytes, "destination")

    # Detect faces and order them so the client's 1-based indices map
    # deterministically onto detected faces.
    faces_source = sort_faces(face_analysis.get(source_image))
    if not faces_source:
        raise HTTPException(status_code=400, detail="No faces detected in the source image.")
    source_face = get_face(faces_source, source_face_index)

    faces_destination = sort_faces(face_analysis.get(destination_image))
    if not faces_destination:
        raise HTTPException(status_code=400, detail="No faces detected in the destination image.")
    destination_face = get_face(faces_destination, destination_face_index)

    # NOTE(review): paste_back=True presumably returns the full destination
    # frame with the swapped face blended back in (rather than only the crop).
    result_image = swapper.get(destination_image, destination_face, source_face, paste_back=True)

    # cv2.imencode reports failure through its boolean first return value.
    ok, result_bytes = cv2.imencode('.jpg', result_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode the result image.")
    return StreamingResponse(io.BytesIO(result_bytes), media_type="image/jpeg")
def sort_faces(faces):
    """Order detected faces by the first coordinate of their bounding box.

    Presumably ``bbox`` follows insightface's ``[x1, y1, x2, y2]`` layout, which
    makes this a left-to-right ordering. Faces with equal ``bbox[0]`` keep their
    input order because the sort is stable.

    Args:
        faces: Iterable of face objects exposing an indexable ``bbox``.

    Returns:
        A new list holding the same face objects in ascending ``bbox[0]`` order.
    """
    def _left_edge(face):
        return face.bbox[0]

    ordered = list(faces)
    ordered.sort(key=_left_edge)
    return ordered
def get_face(faces, face_id):
    """Return the face at the 1-based position ``face_id`` in ``faces``.

    Args:
        faces: Sequence of faces (``swap_faces`` passes the output of ``sort_faces``).
        face_id: 1-based index of the face to pick.

    Returns:
        ``faces[face_id - 1]``.

    Raises:
        ValueError: If ``face_id`` is less than 1 or greater than ``len(faces)``.
    """
    if 1 <= face_id <= len(faces):
        return faces[face_id - 1]
    raise ValueError(f"The image includes only {len(faces)} faces, however, you asked for face {face_id}")
@app.exception_handler(ValueError)
async def value_error_handler(request, exc):
    """Translate an uncaught ``ValueError`` into a 400 JSON error response.

    The body uses the same ``{"detail": <message>}`` shape as the
    ``HTTPException`` errors raised by the endpoint.

    NOTE(review): this handler is registered app-wide, so a ValueError raised by
    any library call, not just ``get_face``'s index check, is also returned to
    the client as a 400 with its raw message.
    """
    error_body = {"detail": str(exc)}
    return JSONResponse(content=error_body, status_code=400)