# -*- coding:UTF-8 -*-
"""FastAPI service that swaps the face from one uploaded image onto another.

Face detection uses insightface's ``buffalo_l`` analysis pack; the swap itself
uses the ``inswapper_128.onnx`` model, which is downloaded into ``models/`` at
import time if it is not already present.
"""
import logging
import os
import shutil  # noqa: F401  (kept: pre-existing file-level import)
from functools import lru_cache
from pathlib import Path

import cv2
import numpy as np
import requests
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add CORS middleware
# NOTE(review): "*" combined with allow_credentials=True lets any origin make
# credentialed requests; restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with Framer domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_DIR = Path("models")
MODEL_PATH = MODEL_DIR / "inswapper_128.onnx"
MODEL_URL = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
# (connect, read) timeouts in seconds for the model download; without a
# timeout a stalled connection would hang startup forever.
DOWNLOAD_TIMEOUT = (10, 60)


class FaceSwapInputError(ValueError):
    """Raised when the uploaded images cannot be used for a face swap.

    Subclasses ValueError so existing callers that catch ValueError still work.
    """


def download_model() -> None:
    """Download ``inswapper_128.onnx`` into ``models/`` unless it already exists.

    The file is streamed to a ``.part`` file next to the target and only
    renamed into place once complete, so an interrupted download never leaves
    a truncated model that a later startup would mistake for a valid one.

    Raises:
        RuntimeError: if the download fails for any reason.
    """
    if MODEL_PATH.exists():
        return

    logger.info("Model not found. Downloading inswapper_128.onnx...")
    MODEL_DIR.mkdir(exist_ok=True)
    tmp_path = MODEL_PATH.with_suffix(".onnx.part")
    try:
        with requests.get(MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: MODEL_PATH is either absent or complete.
        os.replace(tmp_path, MODEL_PATH)
        logger.info("Model downloaded successfully.")
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        logger.error("Failed to download model: %s", e)
        raise RuntimeError("Could not download inswapper_128.onnx. Please check logs.") from e


# Download model on startup
download_model()


@lru_cache(maxsize=1)
def _get_face_analyzer():
    """Build (once) and return the shared, prepared insightface FaceAnalysis."""
    # Imported lazily so the module loads even before insightface is needed.
    from insightface.app import FaceAnalysis

    analyzer = FaceAnalysis(name="buffalo_l")
    analyzer.prepare(ctx_id=0, det_size=(640, 640))
    return analyzer


@lru_cache(maxsize=1)
def _get_swapper():
    """Load (once) and return the inswapper model from MODEL_PATH."""
    from insightface.model_zoo import get_model

    return get_model(str(MODEL_PATH))


def get_many_faces(image: np.ndarray) -> list:
    """Detect all faces in a BGR image using the shared insightface analyzer.

    Args:
        image: image array as returned by OpenCV decoding (BGR channel order).

    Returns:
        The list of detected face objects; empty if no face was found.
    """
    faces = _get_face_analyzer().get(image)
    return faces if faces else []


def swap_faces(source_img: np.ndarray, target_img: np.ndarray) -> np.ndarray:
    """Paste the single face from ``source_img`` onto the single face in ``target_img``.

    Args:
        source_img: BGR image containing exactly one face (the face to copy).
        target_img: BGR image containing exactly one face (the face to replace).

    Returns:
        A BGR image with the same width and height as ``target_img``.

    Raises:
        FaceSwapInputError: (a ValueError) if either image has no face or more
            than one face.
    """
    source_faces = get_many_faces(source_img)
    target_faces = get_many_faces(target_img)

    if not source_faces or not target_faces:
        raise FaceSwapInputError("No faces detected in one or both images.")
    if len(source_faces) > 1 or len(target_faces) > 1:
        raise FaceSwapInputError("Multiple faces detected; only one face per image is supported.")

    result = _get_swapper().get(target_img, target_faces[0], source_faces[0], paste_back=True)

    # Guarantee the output matches the target's size; only pay for the
    # PIL round-trip when the sizes actually differ.
    target_h, target_w = target_img.shape[:2]
    if result.shape[:2] != (target_h, target_w):
        result_pil = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
        result_pil = result_pil.resize((target_w, target_h), Image.Resampling.LANCZOS)
        result = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)
    return result


def _decode_image(content: bytes, label: str) -> np.ndarray:
    """Decode uploaded image bytes into a BGR array in memory.

    Raises:
        FaceSwapInputError: if the bytes are empty or not a decodable image.
    """
    # cv2.imdecode cannot take an empty buffer, so reject it up front.
    if not content:
        raise FaceSwapInputError(f"Failed to load {label} image.")
    img = cv2.imdecode(np.frombuffer(content, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise FaceSwapInputError(f"Failed to load {label} image.")
    return img


@app.post("/swap-face/")
async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...), doFaceEnhancer: bool = True):
    """Swap the face in ``source_file`` onto ``target_file`` and return a JPEG.

    Images are decoded and encoded in memory, so concurrent requests never
    share temporary files.

    Args:
        source_file: uploaded image providing the face.
        target_file: uploaded image whose face is replaced.
        doFaceEnhancer: accepted for API compatibility; currently unused.

    Returns:
        The swapped image as an ``image/jpeg`` response.

    Raises:
        HTTPException: 400 for unusable input (undecodable image, no face,
            multiple faces); 500 for any other failure.
    """
    try:
        source_img = _decode_image(await source_file.read(), "source")
        target_img = _decode_image(await target_file.read(), "target")

        result_img = swap_faces(source_img, target_img)

        ok, encoded = cv2.imencode(".jpg", result_img)
        if not ok:
            raise RuntimeError("Failed to encode result image.")
    except FaceSwapInputError as e:
        logger.warning("Rejected swap_face request: %s", e)
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error in swap_face: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=encoded.tobytes(), media_type="image/jpeg")