testing1 / app.py
AkashKumarave's picture
Update app.py
e9cafa4 verified
raw
history blame
4.8 kB
# -*- coding:UTF-8 -*-
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import cv2
import numpy as np
from PIL import Image
import os
import logging
import requests
from pathlib import Path
import uvicorn
# Application object: the interactive Swagger UI is served at /docs and the
# ReDoc rendering at /redoc.
app = FastAPI(
    docs_url="/docs",
    redoc_url="/redoc",
    title="Face Swap API",
    description="API for swapping faces in images.",
)
# Logging: configure the root logger at INFO and expose a module-level logger
# that every function in this file writes to.
logging.basicConfig(
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# CORS: accept requests from any origin, with any method and any header.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins with allow_credentials=True; list explicit origins before going to
# production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # Update with your domain in production
    allow_headers=["*"],
    allow_methods=["*"],
)
# Landing route: returns a short JSON message pointing callers at /docs.
@app.get("/")
async def root():
    # A comment is used instead of a docstring so the generated OpenAPI schema
    # stays exactly as it was.
    greeting = {"message": "Face Swap API is running. Use /docs to test the API."}
    return greeting
# Health check route: always answers with a constant "healthy" payload.
@app.get("/health")
async def health_check():
    status_payload = {"status": "healthy"}
    return status_payload
# Where the face-swap ONNX model lives on disk, and the URL it is fetched from
# when the file is not there yet.
MODEL_URL = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
MODEL_PATH = Path("models", "inswapper_128.onnx")
def download_model():
    """Download the face-swap model to ``MODEL_PATH`` unless it already exists.

    The payload is streamed into a sibling ``.part`` file and renamed onto
    ``MODEL_PATH`` only after the whole transfer has succeeded. An interrupted
    download therefore never leaves a truncated file that a later run would
    mistake for a complete model and skip.

    Raises:
        RuntimeError: If the model could not be fetched or written. The
            underlying error is attached as ``__cause__``.
    """
    if MODEL_PATH.exists():
        logger.info("Model already exists, skipping download.")
        return

    logger.info("Downloading model...")
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    partial_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
    try:
        # The context manager releases the HTTP connection even on failure.
        with requests.get(MODEL_URL, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the finished file in a single rename.
        partial_path.replace(MODEL_PATH)
        logger.info("Model downloaded successfully.")
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        # Best-effort cleanup of the incomplete file.
        if partial_path.exists():
            partial_path.unlink()
        raise RuntimeError("Could not download inswapper_128.onnx.") from e
# FastAPI startup/shutdown lifecycle
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup makes sure the model file is present by calling
    ``download_model()``. Any failure is logged and re-raised so the server
    does not start without a model.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    logger.info("Starting application...")
    try:
        download_model()
        logger.info("Startup completed successfully.")
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise
    yield
    logger.info("Shutting down application...")


# Setting ``app.lifespan`` only creates an attribute nothing reads, so the
# handler above would never run. The app is already constructed at this point,
# so register the handler on the router, which is where the lifespan context
# is looked up when the server starts.
app.router.lifespan_context = lifespan
# Face detection and swap functions
def get_faces(image):
    """Detect faces in ``image`` with insightface's ``buffalo_l`` analysis pack.

    A new ``FaceAnalysis`` instance is created and prepared on every call.

    Args:
        image: The image handed to ``FaceAnalysis.get`` (presumably a BGR
            array as loaded by OpenCV; confirm against callers).

    Returns:
        Whatever ``FaceAnalysis.get`` returned, or an empty list when that
        result is falsy.

    Raises:
        Exception: Any error from importing or running insightface is logged
            and re-raised unchanged.
    """
    try:
        # Imported lazily so the module can be loaded without insightface.
        from insightface.app import FaceAnalysis

        analyzer = FaceAnalysis(name="buffalo_l")
        analyzer.prepare(ctx_id=0, det_size=(640, 640))
        detected = analyzer.get(image)
        return detected if detected else []
    except Exception as exc:
        logger.error(f"Face detection failed: {exc}")
        raise
def swap_faces(source_img, target_img):
    """Put the face found in ``source_img`` onto the face found in ``target_img``.

    Each image must contain exactly one detectable face.

    Args:
        source_img: Image supplying the identity to transfer (presumably a
            BGR array as loaded by OpenCV; confirm against callers).
        target_img: Image whose face is replaced; same format assumption.

    Returns:
        The image produced by the swapper with ``paste_back=True``, i.e. the
        target image with the swapped face pasted back in.

    Raises:
        ValueError: If either image has no face, or more than one face.
        RuntimeError: If the model file at ``MODEL_PATH`` cannot be loaded as
            a swapper model.
        Exception: Any other insightface error is logged and re-raised.
    """
    try:
        # Imported lazily so the module can be loaded without insightface.
        # FaceAnalysis has to be imported here: it was never in scope in the
        # original function.
        from insightface.app import FaceAnalysis
        from insightface.model_zoo import get_model

        face_analyzer = FaceAnalysis(name="buffalo_l")
        face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
        source_faces = face_analyzer.get(source_img)
        target_faces = face_analyzer.get(target_img)

        if not source_faces or not target_faces:
            raise ValueError("No faces detected.")
        if len(source_faces) > 1 or len(target_faces) > 1:
            raise ValueError("Multiple faces detected. Only one face per image is supported.")

        # get_model() inspects the ONNX file and returns the matching wrapper.
        # It expects a string path, not a pathlib.Path.
        swapper = get_model(str(MODEL_PATH))
        if swapper is None:
            raise RuntimeError(f"Could not load a face swapper model from {MODEL_PATH}.")

        # The result is returned as-is; the former BGR->RGB->PIL->BGR round
        # trip did not change any pixel.
        return swapper.get(target_img, target_faces[0], source_faces[0], paste_back=True)
    except Exception as e:
        logger.error(f"Face swap failed: {e}")
        raise
@app.post("/swap-face/")
async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...)):
    """Swap the face from the source image onto the target image.

    Both uploads must be decodable images containing exactly one face each.
    The result is returned as a JPEG.

    Invalid input (undecodable image, no face, several faces) is answered
    with HTTP 400; any other failure with HTTP 500.
    """
    try:
        # Decode the uploads in memory. The previous fixed temp-file names
        # were shared by every request, so concurrent calls could overwrite
        # each other's images, and files were left behind on failure.
        decoded = []
        for upload in (source_file, target_file):
            payload = await upload.read()
            image = None
            if payload:  # imdecode rejects an empty buffer outright
                image = cv2.imdecode(np.frombuffer(payload, dtype=np.uint8), cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError("Invalid images provided.")
            decoded.append(image)
        source_img, target_img = decoded

        result_img = swap_faces(source_img, target_img)

        encoded_ok, jpeg_buffer = cv2.imencode(".jpg", result_img)
        if not encoded_ok:
            raise RuntimeError("Failed to encode result image.")
        return Response(content=jpeg_buffer.tobytes(), media_type="image/jpeg")
    except ValueError as e:
        # Problems with the supplied images are the client's to fix.
        logger.error("Error in swap_face: %s", str(e))
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error("Error in swap_face: %s", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
# Script entry point: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )