Spaces:
Runtime error
Runtime error
# -*- coding:UTF-8 -*-
import logging
import os
from contextlib import asynccontextmanager
from functools import lru_cache
from pathlib import Path

import cv2
import numpy as np
import requests
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image
# Root logging at INFO plus a module-level logger shared by the handlers below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application; Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(
    title="Face Swap API",
    description="API for swapping faces in images.",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Cross-origin policy: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): narrow allow_origins to your own domain(s) in production; a
# wildcard origin together with allow_credentials=True is very permissive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Landing handler that points callers at the interactive API docs.
# NOTE(review): no route decorator (e.g. @app.get(...)) is visible on this
# handler in this view; confirm how it is registered with the app.
async def root():
    """Return a welcome message directing callers to ``/docs``."""
    greeting = "Face Swap API is running. Use /docs to test the API."
    return {"message": greeting}
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm how it is registered with the app.
async def health_check():
    """Report a static healthy status; no dependencies are actually probed."""
    status = "healthy"
    return {"status": status}
# Remote location of the inswapper face-swap model and the fixed local path it
# is stored at; download_model() skips the fetch when this path already exists.
MODEL_URL = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
MODEL_PATH = Path("models") / "inswapper_128.onnx"
def download_model():
    """Download the inswapper ONNX model to MODEL_PATH unless it is already there.

    The payload is streamed into a sibling ``.part`` file, which is renamed
    onto MODEL_PATH only after the whole body has been written. An
    interrupted download therefore never leaves a truncated file at
    MODEL_PATH, which the existence check below would otherwise accept as a
    valid model on every later start.

    Raises:
        RuntimeError: if the download fails for any reason (the original
            error is chained as ``__cause__``).
        OSError: if the model directory cannot be created.
    """
    if MODEL_PATH.exists():
        logger.info("Model already exists, skipping download.")
        return
    logger.info("Downloading model...")
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
    try:
        # Context-managing the response releases the connection on every path.
        with requests.get(MODEL_URL, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic rename: MODEL_PATH only ever appears fully written.
        tmp_path.replace(MODEL_PATH)
        logger.info("Model downloaded successfully.")
    except Exception as e:
        # Drop any partial download so a retry starts clean.
        tmp_path.unlink(missing_ok=True)
        logger.error(f"Failed to download model: {e}")
        raise RuntimeError("Could not download inswapper_128.onnx.") from e
# Application lifespan: fetch the model before serving, log on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run start-up work before requests are served and log on shutdown.

    Args:
        app: the application instance passed in when the lifespan starts
            (not used here).

    Raises:
        Exception: any start-up failure (e.g. from download_model) is logged
            and re-raised so that application start-up fails.
    """
    logger.info("Starting application...")
    try:
        download_model()
        logger.info("Startup completed successfully.")
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise
    try:
        yield
    finally:
        logger.info("Shutting down application...")


# Assigning ``app.lifespan`` has no effect: Starlette runs the context manager
# held by the router. ``FastAPI(lifespan=...)`` is the documented route, but
# ``app`` is constructed before this function exists, so register it on the
# router directly.
app.router.lifespan_context = lifespan
# Face detection and swap functions
@lru_cache(maxsize=1)
def _get_face_analyzer():
    """Build and prepare the insightface ``buffalo_l`` analyzer once, then reuse it."""
    from insightface.app import FaceAnalysis

    analyzer = FaceAnalysis(name="buffalo_l")
    analyzer.prepare(ctx_id=0, det_size=(640, 640))
    return analyzer


def get_faces(image):
    """Detect faces in ``image`` using a shared, lazily created analyzer.

    The analyzer used to be constructed and prepared on every call, and it was
    bound to a local named ``app`` that shadowed the FastAPI instance.
    Constructing and preparing it once avoids that repeated setup.

    Args:
        image: image array handed straight to ``FaceAnalysis.get``
            (presumably a BGR array as loaded by cv2; confirm with callers).

    Returns:
        The detected faces, or an empty list when none are found.

    Raises:
        Exception: any error from model setup or detection is logged and
            re-raised. A failed setup is not cached, so the next call retries.
    """
    try:
        return _get_face_analyzer().get(image) or []
    except Exception as e:
        logger.error(f"Face detection failed: {e}")
        raise
@lru_cache(maxsize=1)
def _get_swapper():
    """Load the inswapper model from MODEL_PATH once and reuse it across swaps."""
    from insightface.model_zoo import get_model

    swapper = get_model(str(MODEL_PATH))
    if swapper is None:
        # Raise instead of returning, so lru_cache never memoises a failed load.
        raise RuntimeError(f"Could not load a face-swap model from {MODEL_PATH}.")
    return swapper


def swap_faces(source_img, target_img):
    """Put the single face from ``source_img`` onto the single face in ``target_img``.

    Args:
        source_img: image supplying the face (presumably a BGR array from cv2;
            confirm with callers).
        target_img: image whose face is replaced, in the same format.

    Returns:
        The target image with the swapped face pasted back in.

    Raises:
        ValueError: if either image contains no face or more than one face.
        Exception: model-loading and inference errors are logged and re-raised.
    """
    try:
        # Detection goes through get_faces so both functions share one analyzer
        # configuration (buffalo_l, det_size 640x640).
        source_faces = get_faces(source_img)
        target_faces = get_faces(target_img)
        if not source_faces or not target_faces:
            raise ValueError("No faces detected.")
        if len(source_faces) > 1 or len(target_faces) > 1:
            raise ValueError("Multiple faces detected. Only one face per image is supported.")
        swapper = _get_swapper()
        return swapper.get(target_img, target_faces[0], source_faces[0], paste_back=True)
    except Exception as e:
        logger.error(f"Face swap failed: {e}")
        raise
def _decode_image(data):
    """Decode uploaded bytes into a BGR image array, or return None if undecodable."""
    if not data:
        # Handle empty uploads explicitly rather than relying on how
        # cv2.imdecode treats an empty buffer.
        return None
    return cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)


# NOTE(review): no route decorator (e.g. @app.post(...)) is visible on this
# handler in this view; confirm how it is registered with the app.
async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...)):
    """Swap the face in ``source_file`` onto the face in ``target_file``.

    Both uploads are decoded in memory, and the result is encoded in memory
    too. Nothing is written to shared on-disk paths, so concurrent requests
    cannot overwrite each other's files, and no files are left behind when
    processing fails. The face analysis and swap are CPU-bound, so they run
    in a worker thread to keep the event loop responsive.

    Returns:
        A ``Response`` with the swapped image as JPEG bytes.

    Raises:
        HTTPException: status 500 carrying the error message for any failure.
            This covers undecodable images, face-detection problems, and swap
            or encoding errors.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        source_img = _decode_image(await source_file.read())
        target_img = _decode_image(await target_file.read())
        if source_img is None or target_img is None:
            raise ValueError("Invalid images provided.")
        result_img = await run_in_threadpool(swap_faces, source_img, target_img)
        ok, encoded = cv2.imencode(".jpg", result_img)
        if not ok:
            raise RuntimeError("Failed to encode the result image.")
        return Response(content=encoded.tobytes(), media_type="image/jpeg")
    except Exception as e:
        logger.error("Error in swap_face: %s", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")