AkashKumarave commited on
Commit
e9cafa4
·
verified ·
1 Parent(s): c7d6b3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -73
app.py CHANGED
@@ -16,66 +16,60 @@ import uvicorn
16
  app = FastAPI(
17
  title="Face Swap API",
18
  description="API for swapping faces in images.",
19
- docs_url="/docs",
20
- redoc_url="/redoc"
21
  )
22
 
23
- # Configure logging
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
- # Add CORS middleware
28
  app.add_middleware(
29
  CORSMiddleware,
30
- allow_origins=["*"],
31
  allow_credentials=True,
32
  allow_methods=["*"],
33
  allow_headers=["*"],
34
  )
35
 
36
- # Root endpoint
37
  @app.get("/")
38
  async def root():
39
- return {"message": "Welcome to the Face Swap API! Use /swap-face/ to swap faces or /docs to test the API."}
40
 
41
- # Health check
42
  @app.get("/health")
43
  async def health_check():
44
  return {"status": "healthy"}
45
 
46
- # Global flag for model download
47
- MODEL_DOWNLOADED = False
 
48
 
49
  def download_model():
50
- global MODEL_DOWNLOADED
51
- model_dir = Path("models")
52
- model_path = model_dir / "inswapper_128.onnx"
53
- model_url = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
54
-
55
- if not model_path.exists():
56
- logger.info("Downloading inswapper_128.onnx...")
57
- model_dir.mkdir(exist_ok=True)
58
- try:
59
- response = requests.get(model_url, stream=True, timeout=30)
60
- response.raise_for_status()
61
- with open(model_path, 'wb') as f:
62
- for chunk in response.iter_content(chunk_size=8192):
63
- f.write(chunk)
64
- logger.info("Model downloaded successfully.")
65
- MODEL_DOWNLOADED = True
66
- except Exception as e:
67
- logger.error(f"Failed to download model: {e}")
68
- raise RuntimeError("Could not download inswapper_128.onnx.")
69
- else:
70
- logger.info("Model already exists.")
71
- MODEL_DOWNLOADED = True
72
 
 
73
  @asynccontextmanager
74
  async def lifespan(app: FastAPI):
75
  logger.info("Starting application...")
76
  try:
77
  download_model()
78
- logger.info("Application started successfully.")
79
  except Exception as e:
80
  logger.error(f"Startup failed: {e}")
81
  raise
@@ -84,43 +78,37 @@ async def lifespan(app: FastAPI):
84
 
85
  app.lifespan = lifespan
86
 
87
- def swap_faces(source_img, target_img):
 
88
  try:
89
  from insightface.app import FaceAnalysis
 
 
 
 
 
 
 
 
 
90
  from insightface.utils import face_align
91
  from insightface.model_zoo import face_swapper
92
 
93
- # Initialize face analysis
94
  face_analyzer = FaceAnalysis(name="buffalo_l")
95
- face_analyzer.prepare(ctx_id=-1, det_size=(640, 640))
96
 
97
- # Detect faces
98
  source_faces = face_analyzer.get(source_img)
99
  target_faces = face_analyzer.get(target_img)
100
 
101
  if not source_faces or not target_faces:
102
  raise ValueError("No faces detected.")
103
  if len(source_faces) > 1 or len(target_faces) > 1:
104
- raise ValueError("Multiple faces detected; only one per image supported.")
105
-
106
- source_face = source_faces[0]
107
- target_face = target_faces[0]
108
 
109
- # Load the face swapper model
110
- model_path = Path("models/inswapper_128.onnx")
111
- if not model_path.exists():
112
- raise FileNotFoundError("Model file inswapper_128.onnx not found.")
113
- swapper = face_swapper.FaceSwapper(str(model_path))
114
 
115
- # Perform face swap
116
- result = swapper.get(target_img, target_face, source_face, paste_back=True)
117
-
118
- # Resize result to match target image size
119
- result_pil = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
120
- target_pil = Image.fromarray(cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB))
121
- result_pil = result_pil.resize(target_pil.size, Image.Resampling.LANCZOS)
122
-
123
- return cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)
124
  except Exception as e:
125
  logger.error(f"Face swap failed: {e}")
126
  raise
@@ -128,26 +116,32 @@ def swap_faces(source_img, target_img):
128
  @app.post("/swap-face/")
129
  async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...)):
130
  try:
131
- # Read source image
132
- source_content = await source_file.read()
133
- source_np = np.frombuffer(source_content, np.uint8)
134
- source_img = cv2.imdecode(source_np, cv2.IMREAD_COLOR)
135
- if source_img is None:
136
- raise ValueError("Invalid source image.")
137
-
138
- # Read target image
139
- target_content = await target_file.read()
140
- target_np = np.frombuffer(target_content, np.uint8)
141
- target_img = cv2.imdecode(target_np, cv2.IMREAD_COLOR)
142
- if target_img is None:
143
- raise ValueError("Invalid target image.")
144
-
145
- # Perform face swap
146
  result_img = swap_faces(source_img, target_img)
147
 
148
- # Convert result to bytes
149
- _, img_encoded = cv2.imencode(".jpg", result_img)
150
- return Response(content=img_encoded.tobytes(), media_type="image/jpeg")
 
 
 
 
 
 
151
 
152
  except Exception as e:
153
  logger.error("Error in swap_face: %s", str(e))
 
16
  app = FastAPI(
17
  title="Face Swap API",
18
  description="API for swapping faces in images.",
19
+ docs_url="/docs",
20
+ redoc_url="/redoc",
21
  )
22
 
23
+ # Logging setup
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
+ # CORS setup
28
  app.add_middleware(
29
  CORSMiddleware,
30
+ allow_origins=["*"], # Update with your domain in production
31
  allow_credentials=True,
32
  allow_methods=["*"],
33
  allow_headers=["*"],
34
  )
35
 
36
+ # Health check route
37
  @app.get("/")
38
  async def root():
39
+ return {"message": "Face Swap API is running. Use /docs to test the API."}
40
 
 
41
  @app.get("/health")
42
  async def health_check():
43
  return {"status": "healthy"}
44
 
45
+ # Prevent multiple downloads
46
+ MODEL_PATH = Path("models/inswapper_128.onnx")
47
+ MODEL_URL = "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"
48
 
49
  def download_model():
50
+ if MODEL_PATH.exists():
51
+ logger.info("Model already exists, skipping download.")
52
+ return
53
+ logger.info("Downloading model...")
54
+ MODEL_PATH.parent.mkdir(exist_ok=True)
55
+ try:
56
+ response = requests.get(MODEL_URL, stream=True, timeout=30)
57
+ response.raise_for_status()
58
+ with open(MODEL_PATH, 'wb') as f:
59
+ for chunk in response.iter_content(chunk_size=8192):
60
+ f.write(chunk)
61
+ logger.info("Model downloaded successfully.")
62
+ except Exception as e:
63
+ logger.error(f"Failed to download model: {e}")
64
+ raise RuntimeError("Could not download inswapper_128.onnx.")
 
 
 
 
 
 
 
65
 
66
+ # FastAPI startup event
67
  @asynccontextmanager
68
  async def lifespan(app: FastAPI):
69
  logger.info("Starting application...")
70
  try:
71
  download_model()
72
+ logger.info("Startup completed successfully.")
73
  except Exception as e:
74
  logger.error(f"Startup failed: {e}")
75
  raise
 
78
 
79
  app.lifespan = lifespan
80
 
81
+ # Face detection and swap functions
82
+ def get_faces(image):
83
  try:
84
  from insightface.app import FaceAnalysis
85
+ app = FaceAnalysis(name="buffalo_l")
86
+ app.prepare(ctx_id=0, det_size=(640, 640))
87
+ return app.get(image) or []
88
+ except Exception as e:
89
+ logger.error(f"Face detection failed: {e}")
90
+ raise
91
+
92
+ def swap_faces(source_img, target_img):
93
+ try:
94
  from insightface.utils import face_align
95
  from insightface.model_zoo import face_swapper
96
 
 
97
  face_analyzer = FaceAnalysis(name="buffalo_l")
98
+ face_analyzer.prepare(ctx_id=0, det_size=(640, 640))
99
 
 
100
  source_faces = face_analyzer.get(source_img)
101
  target_faces = face_analyzer.get(target_img)
102
 
103
  if not source_faces or not target_faces:
104
  raise ValueError("No faces detected.")
105
  if len(source_faces) > 1 or len(target_faces) > 1:
106
+ raise ValueError("Multiple faces detected. Only one face per image is supported.")
 
 
 
107
 
108
+ swapper = face_swapper.FaceSwapper(MODEL_PATH)
109
+ result = swapper.get(target_img, target_faces[0], source_faces[0], paste_back=True)
 
 
 
110
 
111
+ return cv2.cvtColor(np.array(Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
112
  except Exception as e:
113
  logger.error(f"Face swap failed: {e}")
114
  raise
 
116
  @app.post("/swap-face/")
117
  async def swap_face(source_file: UploadFile = File(...), target_file: UploadFile = File(...)):
118
  try:
119
+ source_path = "temp_source.jpg"
120
+ target_path = "temp_target.jpg"
121
+ output_path = "output.jpg"
122
+
123
+ with open(source_path, "wb") as f:
124
+ f.write(await source_file.read())
125
+ with open(target_path, "wb") as f:
126
+ f.write(await target_file.read())
127
+
128
+ source_img = cv2.imread(source_path)
129
+ target_img = cv2.imread(target_path)
130
+
131
+ if source_img is None or target_img is None:
132
+ raise ValueError("Invalid images provided.")
133
+
134
  result_img = swap_faces(source_img, target_img)
135
 
136
+ cv2.imwrite(output_path, result_img)
137
+ with open(output_path, "rb") as f:
138
+ image_data = f.read()
139
+
140
+ for path in [source_path, target_path, output_path]:
141
+ if os.path.exists(path):
142
+ os.remove(path)
143
+
144
+ return Response(content=image_data, media_type="image/jpeg")
145
 
146
  except Exception as e:
147
  logger.error("Error in swap_face: %s", str(e))