File size: 6,700 Bytes
456bcc4
0ed24f7
 
 
411c7f2
456bcc4
0ed24f7
456bcc4
0ed24f7
279af2f
 
0ed24f7
 
a73b3bc
279af2f
b8c2f1b
 
 
 
 
 
 
279af2f
456bcc4
 
 
 
 
279af2f
 
249a867
279af2f
 
 
 
 
249a867
 
 
 
 
b8c2f1b
 
 
 
 
a73b3bc
 
 
456bcc4
a73b3bc
 
 
 
 
456bcc4
 
 
 
 
 
 
 
a73b3bc
456bcc4
 
 
 
 
a73b3bc
456bcc4
 
0ed24f7
a73b3bc
 
 
 
b8c2f1b
411c7f2
 
 
a73b3bc
 
 
 
 
 
 
411c7f2
 
 
 
 
456bcc4
0ed24f7
a73b3bc
 
 
 
 
 
 
 
 
 
0ed24f7
 
 
a73b3bc
 
 
0ed24f7
a73b3bc
 
 
0ed24f7
a73b3bc
 
 
0ed24f7
a73b3bc
 
 
 
0ed24f7
a73b3bc
 
0ed24f7
a73b3bc
 
 
 
 
0ed24f7
a73b3bc
 
0ed24f7
a73b3bc
 
 
 
0ed24f7
a73b3bc
 
 
 
279af2f
 
0ed24f7
279af2f
 
0ed24f7
 
279af2f
 
456bcc4
 
279af2f
456bcc4
0ed24f7
 
 
456bcc4
0ed24f7
456bcc4
279af2f
456bcc4
279af2f
0ed24f7
 
456bcc4
 
0ed24f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279af2f
456bcc4
a73b3bc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# -*- coding:UTF-8 -*-
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import cv2
import numpy as np
from PIL import Image
import os
import shutil
import logging
import requests
from pathlib import Path
import uvicorn

# Create the FastAPI application object. The interactive Swagger UI is served
# at /docs and the ReDoc view at /redoc.
app = FastAPI(
    title="Face Swap API",
    description="API for swapping faces in images.",
    docs_url="/docs",  # Explicitly set docs URL
    redoc_url="/redoc",  # Explicitly set redoc URL
)

# Configure logging: root logger at INFO level, plus the module-level logger
# shared by the functions defined below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add CORS middleware. Every origin, method and header is currently allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- restrict allow_origins to the real front-end origin(s) before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with your Framer domain in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Root endpoint: lets a caller confirm the app is running
@app.get("/")
async def root():
    """Return a short welcome message pointing at the swap and docs endpoints."""
    welcome = "Welcome to the Face Swap API! Use /swap-face/ to swap faces or /docs to test the API."
    return {"message": welcome}

# Health check endpoint
@app.get("/health")
async def health_check():
    """Report a constant ``healthy`` status payload."""
    payload = dict(status="healthy")
    return payload

# Global flag to prevent multiple downloads
MODEL_DOWNLOADED = False


def download_model():
    """Ensure ``models/inswapper_128.onnx`` exists locally, downloading it if needed.

    The check is skipped entirely once ``MODEL_DOWNLOADED`` has been set by a
    previous successful call in this process.

    The file is streamed into a sibling ``.part`` file and only renamed to its
    final name after the whole body has been written, so an interrupted
    download can never leave a truncated file that a later start would
    mistake for a valid model.

    Raises:
        RuntimeError: if the model cannot be downloaded or written to disk.
    """
    global MODEL_DOWNLOADED
    if MODEL_DOWNLOADED:
        logger.info("Model already downloaded, skipping.")
        return

    model_dir = Path("models")
    model_path = model_dir / "inswapper_128.onnx"
    model_url = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"

    if model_path.exists():
        logger.info("Model already exists at: %s", model_path)
        MODEL_DOWNLOADED = True
        return

    logger.info("Model not found. Downloading inswapper_128.onnx...")
    partial_path = model_path.with_name(model_path.name + ".part")
    try:
        model_dir.mkdir(parents=True, exist_ok=True)
        # The context manager releases the HTTP connection even on failure.
        with requests.get(model_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file under its final name only once it is complete.
        partial_path.replace(model_path)
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        try:
            partial_path.unlink()
        except OSError:
            # Best-effort cleanup: the partial file may never have been created.
            pass
        raise RuntimeError("Could not download inswapper_128.onnx. Please check logs.") from e

    logger.info("Model downloaded successfully.")
    MODEL_DOWNLOADED = True

# Use lifespan event handler
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log on shutdown.

    Startup makes sure the face-swap model file is available by calling
    ``download_model()``; any failure is logged and re-raised so the server
    does not start without the model.

    Args:
        app: The application instance the lifespan is attached to (unused).
    """
    # Startup code
    logger.info("Starting up application...")
    try:
        download_model()
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise
    logger.info("Startup completed successfully.")
    yield
    # Shutdown code (if any)
    logger.info("Shutting down application...")


# A lifespan handler is normally passed to ``FastAPI(lifespan=...)``. The app
# object above was created before this function existed, and assigning
# ``app.lifespan`` merely sets an attribute nothing reads, so the handler is
# installed on the router, which is where the lifespan context is looked up.
app.router.lifespan_context = lifespan

def get_many_faces(image):
    """Detect faces in ``image`` with insightface's ``buffalo_l`` analysis pack.

    Returns the detections produced by ``FaceAnalysis.get``, or an empty list
    when nothing was detected. Any failure is logged and re-raised.
    """
    try:
        from insightface.app import FaceAnalysis

        analyzer = FaceAnalysis(name="buffalo_l")
        analyzer.prepare(ctx_id=0, det_size=(640, 640))
        detections = analyzer.get(image)
    except Exception as e:
        logger.error(f"Face detection failed: {e}")
        raise

    if detections:
        return detections
    return []

def swap_faces(source_img, target_img):
    """Swap the face found in ``source_img`` onto the face found in ``target_img``.

    Both images are handled as BGR arrays (they are converted with
    ``cv2.COLOR_BGR2RGB`` when a resize is needed). Exactly one face must be
    detected in each image.

    Args:
        source_img: Image providing the face identity.
        target_img: Image whose face is replaced.

    Returns:
        The swapped image as a BGR array with the same width and height as
        ``target_img``.

    Raises:
        ValueError: if either image has no detected face, or more than one.
        FileNotFoundError: if ``models/inswapper_128.onnx`` is missing.
        Exception: any other failure is logged and re-raised unchanged.
    """
    try:
        # FaceAnalysis must be imported in this scope: the import inside
        # get_many_faces() is local to that function and not visible here.
        from insightface.app import FaceAnalysis
        # The inswapper model is loaded through the model zoo factory, which
        # picks the swapper implementation from the ONNX file itself.
        from insightface.model_zoo import get_model

        # Initialize face analysis
        face_analyzer = FaceAnalysis(name="buffalo_l")
        face_analyzer.prepare(ctx_id=0, det_size=(640, 640))

        # Detect faces
        source_faces = face_analyzer.get(source_img)
        target_faces = face_analyzer.get(target_img)

        if not source_faces or not target_faces:
            raise ValueError("No faces detected in one or both images.")
        if len(source_faces) > 1 or len(target_faces) > 1:
            raise ValueError("Multiple faces detected; only one face per image is supported.")

        source_face = source_faces[0]
        target_face = target_faces[0]

        # Load the face swapper model
        model_path = Path("models/inswapper_128.onnx")
        if not model_path.exists():
            raise FileNotFoundError("Model file inswapper_128.onnx not found.")
        # get_model() expects a string path, not a Path object.
        swapper = get_model(str(model_path))

        # Perform face swap
        result = swapper.get(target_img, target_face, source_face, paste_back=True)

        # Nothing to do when the result already matches the target size;
        # this avoids a pointless BGR -> RGB -> BGR round trip through PIL.
        if result.shape[:2] == target_img.shape[:2]:
            return result

        # Resize to match target image size
        target_height, target_width = target_img.shape[:2]
        result_pil = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
        result_pil = result_pil.resize((target_width, target_height), Image.Resampling.LANCZOS)

        return cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)
    except Exception as e:
        logger.error(f"Face swap failed: {e}")
        raise

def _decode_upload(data, label):
    """Decode uploaded image bytes into a BGR array, or raise ValueError."""
    # An empty buffer cannot be decoded; treat it like any unreadable image.
    if not data:
        raise ValueError(f"Failed to load {label} image.")
    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError(f"Failed to load {label} image.")
    return image


@app.post("/swap-face/")
async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...), doFaceEnhancer: bool = True):
    """Swap the face from the source upload onto the target upload.

    The uploads are decoded and the result is encoded entirely in memory, so
    concurrent requests never share (or leave behind) temporary files.

    Args:
        source_file: Uploaded image providing the face.
        target_file: Uploaded image whose face is replaced.
        doFaceEnhancer: Accepted for API compatibility; currently unused.

    Returns:
        A JPEG ``Response`` containing the swapped image.

    Raises:
        HTTPException: 400 when an upload cannot be decoded or the images do
            not contain exactly one face each (``ValueError``); 500 for any
            other failure.
    """
    try:
        # Read and decode source image
        source_img = _decode_upload(await source_file.read(), "source")

        # Read and decode target image
        target_img = _decode_upload(await target_file.read(), "target")

        # Perform face swap
        # NOTE(review): this is CPU-bound and runs on the event loop.
        result_img = swap_faces(source_img, target_img)

        # Encode the result as JPEG in memory
        encoded_ok, encoded = cv2.imencode(".jpg", result_img)
        if not encoded_ok:
            raise RuntimeError("Failed to encode result image.")

        # Return the swapped image
        return Response(content=encoded.tobytes(), media_type="image/jpeg")

    except ValueError as e:
        # Bad input (unreadable image, wrong number of faces) is a client error.
        logger.error("Error in swap_face: %s", str(e))
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.error("Error in swap_face: %s", str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    # Listen on all interfaces; Hugging Face Spaces expects the app on port 7860
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)