HariLogicgo commited on
Commit
e6468b0
·
1 Parent(s): 0d9b89a

import fixed

Browse files
Files changed (1) hide show
  1. app.py +77 -69
app.py CHANGED
@@ -1,26 +1,25 @@
1
  import os
2
- os.environ["OMP_NUM_THREADS"] = "1"
3
- import gradio as gr
4
- import cv2
5
  import shutil
6
  import uuid
 
 
 
 
 
 
7
  import insightface
8
  from insightface.app import FaceAnalysis
9
  from huggingface_hub import hf_hub_download
10
- import subprocess
11
- import numpy as np
12
- import threading
13
  from fastapi import FastAPI, UploadFile, File, HTTPException, Response
14
  from fastapi.responses import RedirectResponse
15
  from pydantic import BaseModel
16
- from motor.motor_asyncio import AsyncIOMotorClient
17
- from bson.objectid import ObjectId
18
- from motor.motor_asyncio import AsyncIOMotorGridFSBucket
19
 
20
- from gradio import mount_gradio_app
21
  import uvicorn
22
- import logging
23
- import io
24
 
25
  # -------------------------------------------------
26
  # Logging
@@ -75,11 +74,9 @@ inswapper_path = download_models()
75
  # Face Analysis + Swapper
76
  # -------------------------------------------------
77
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
78
- logger.info(f"Initializing FaceAnalysis with providers: {providers}")
79
  face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
80
  face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
81
  swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
82
- logger.info("FaceAnalysis and swapper initialized")
83
 
84
  # -------------------------------------------------
85
  # CodeFormer setup
@@ -88,28 +85,46 @@ CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"
88
 
89
  def ensure_codeformer():
90
  if not os.path.exists("CodeFormer"):
91
- logger.info("Cloning CodeFormer repository...")
92
  subprocess.run("git clone https://github.com/sczhou/CodeFormer.git", shell=True, check=True)
93
  subprocess.run("pip install -r CodeFormer/requirements.txt", shell=True, check=True)
94
  subprocess.run("python CodeFormer/basicsr/setup.py develop", shell=True, check=True)
95
  subprocess.run("python CodeFormer/scripts/download_pretrained_models.py facelib", shell=True, check=True)
96
  subprocess.run("python CodeFormer/scripts/download_pretrained_models.py CodeFormer", shell=True, check=True)
97
- logger.info("CodeFormer setup complete")
98
 
99
  ensure_codeformer()
100
 
101
  # -------------------------------------------------
102
- # MongoDB + GridFS
103
  # -------------------------------------------------
104
  MONGODB_URL = os.getenv(
105
  "MONGODB_URL",
106
  "mongodb+srv://harilogicgo_db_user:logicgoinfotech@cluster0.dcs1tnb.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
107
  )
108
- client = AsyncIOMotorClient(MONGODB_URL)
109
- database = client.FaceSwap
110
- fs_bucket = AsyncIOMotorGridFSBucket(database)
111
 
112
- logger.info("MongoDB + GridFS initialized")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # -------------------------------------------------
115
  # Lock for face swap
@@ -117,10 +132,9 @@ logger.info("MongoDB + GridFS initialized")
117
  swap_lock = threading.Lock()
118
 
119
  # -------------------------------------------------
120
- # Face Swap Pipeline
121
  # -------------------------------------------------
122
  def face_swap_and_enhance(src_img, tgt_img):
123
- logger.info("Starting face swap and enhancement")
124
  try:
125
  with swap_lock:
126
  shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
@@ -128,16 +142,13 @@ def face_swap_and_enhance(src_img, tgt_img):
128
  os.makedirs(UPLOAD_DIR, exist_ok=True)
129
  os.makedirs(RESULT_DIR, exist_ok=True)
130
 
131
- if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
132
- return None, None, "❌ Invalid input images"
133
-
134
  src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
135
  tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
136
 
137
  src_faces = face_analysis_app.get(src_bgr)
138
  tgt_faces = face_analysis_app.get(tgt_bgr)
139
  if not src_faces or not tgt_faces:
140
- return None, None, "❌ Face not detected"
141
 
142
  swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
143
  swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
@@ -165,7 +176,7 @@ def face_swap_and_enhance(src_img, tgt_img):
165
  return None, None, f"❌ Error: {str(e)}"
166
 
167
  # -------------------------------------------------
168
- # Gradio Interface
169
  # -------------------------------------------------
170
  with gr.Blocks() as demo:
171
  gr.Markdown("Face Swap")
@@ -186,10 +197,8 @@ with gr.Blocks() as demo:
186
  btn.click(process, [src_input, tgt_input], [output_img, download, error_box])
187
 
188
  # -------------------------------------------------
189
- # FastAPI App
190
  # -------------------------------------------------
191
- fastapi_app = FastAPI()
192
-
193
  @fastapi_app.get("/")
194
  def root():
195
  return RedirectResponse("/gradio")
@@ -198,20 +207,18 @@ def root():
198
  async def health():
199
  return {"status": "healthy"}
200
 
201
- # -------- Upload Endpoints with GridFS --------
202
  @fastapi_app.post("/source")
203
  async def upload_source(image: UploadFile = File(...)):
204
  contents = await image.read()
205
- file_id = await fs_bucket.upload_from_stream(image.filename, contents)
206
  return {"source_id": str(file_id)}
207
 
208
  @fastapi_app.post("/target")
209
  async def upload_target(image: UploadFile = File(...)):
210
  contents = await image.read()
211
- file_id = await fs_bucket.upload_from_stream(image.filename, contents)
212
  return {"target_id": str(file_id)}
213
 
214
- # -------- Faceswap Endpoint --------
215
  class FaceSwapRequest(BaseModel):
216
  source_id: str
217
  target_id: str
@@ -219,50 +226,51 @@ class FaceSwapRequest(BaseModel):
219
  @fastapi_app.post("/faceswap")
220
  async def perform_faceswap(request: FaceSwapRequest):
221
  try:
222
- # Read source
223
- source_stream = await fs_bucket.open_download_stream(ObjectId(request.source_id))
224
- source_bytes = await source_stream.read()
225
- source_array = np.frombuffer(source_bytes, np.uint8)
226
- source_bgr = cv2.imdecode(source_array, cv2.IMREAD_COLOR)
227
- source_rgb = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB)
228
-
229
- # Read target
230
- target_stream = await fs_bucket.open_download_stream(ObjectId(request.target_id))
231
- target_bytes = await target_stream.read()
232
- target_array = np.frombuffer(target_bytes, np.uint8)
233
- target_bgr = cv2.imdecode(target_array, cv2.IMREAD_COLOR)
234
- target_rgb = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB)
235
-
236
- # Run pipeline
237
- final_img, final_path, err = face_swap_and_enhance(source_rgb, target_rgb)
238
- if err:
239
- raise HTTPException(status_code=500, detail=err)
240
-
241
- # Store result in GridFS
242
- with open(final_path, "rb") as f:
243
- final_bytes = f.read()
244
- result_id = await fs_bucket.upload_from_stream("enhanced.png", final_bytes)
245
-
246
- return {"result_id": str(result_id)}
247
 
248
- except Exception as e:
249
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- # -------- Download Endpoint --------
252
  @fastapi_app.get("/download/{result_id}")
253
  async def download_result(result_id: str):
254
  try:
255
  stream = await fs_bucket.open_download_stream(ObjectId(result_id))
256
- file_data = await stream.read()
257
- return Response(
258
- content=file_data,
259
- media_type="image/png",
260
- headers={"Content-Disposition": f"attachment; filename=enhanced.png"}
261
- )
262
  except Exception:
263
  raise HTTPException(status_code=404, detail="Result not found")
264
 
 
 
 
 
 
 
 
265
  # Mount Gradio
 
266
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
267
 
268
  if __name__ == "__main__":
 
1
  import os
 
 
 
2
  import shutil
3
  import uuid
4
+ import cv2
5
+ import numpy as np
6
+ import threading
7
+ import subprocess
8
+ import logging
9
+
10
  import insightface
11
  from insightface.app import FaceAnalysis
12
  from huggingface_hub import hf_hub_download
13
+
 
 
14
  from fastapi import FastAPI, UploadFile, File, HTTPException, Response
15
  from fastapi.responses import RedirectResponse
16
  from pydantic import BaseModel
17
+ from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
18
+ from bson import ObjectId
 
19
 
 
20
  import uvicorn
21
+ import gradio as gr
22
+ from gradio import mount_gradio_app
23
 
24
  # -------------------------------------------------
25
  # Logging
 
74
  # Face Analysis + Swapper
75
  # -------------------------------------------------
76
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
77
  face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
78
  face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
79
  swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
 
80
 
81
  # -------------------------------------------------
82
  # CodeFormer setup
 
85
 
86
  def ensure_codeformer():
87
  if not os.path.exists("CodeFormer"):
 
88
  subprocess.run("git clone https://github.com/sczhou/CodeFormer.git", shell=True, check=True)
89
  subprocess.run("pip install -r CodeFormer/requirements.txt", shell=True, check=True)
90
  subprocess.run("python CodeFormer/basicsr/setup.py develop", shell=True, check=True)
91
  subprocess.run("python CodeFormer/scripts/download_pretrained_models.py facelib", shell=True, check=True)
92
  subprocess.run("python CodeFormer/scripts/download_pretrained_models.py CodeFormer", shell=True, check=True)
 
93
 
94
  ensure_codeformer()
95
 
96
  # -------------------------------------------------
97
+ # MongoDB + GridFS setup
98
  # -------------------------------------------------
99
  MONGODB_URL = os.getenv(
100
  "MONGODB_URL",
101
  "mongodb+srv://harilogicgo_db_user:logicgoinfotech@cluster0.dcs1tnb.mongodb.net/?retryWrites=true&w=majority&appName=Cluster0"
102
  )
 
 
 
103
 
104
+ client: AsyncIOMotorClient = None
105
+ database = None
106
+ fs_bucket: AsyncIOMotorGridFSBucket = None
107
+
108
+ # -------------------------------------------------
109
+ # FastAPI App
110
+ # -------------------------------------------------
111
+ fastapi_app = FastAPI()
112
+
113
+ @fastapi_app.on_event("startup")
114
+ async def startup_db():
115
+ global client, database, fs_bucket
116
+ logger.info("Initializing MongoDB + GridFS...")
117
+ client = AsyncIOMotorClient(MONGODB_URL)
118
+ database = client.FaceSwap
119
+ fs_bucket = AsyncIOMotorGridFSBucket(database)
120
+ logger.info("MongoDB + GridFS initialized")
121
+
122
+ @fastapi_app.on_event("shutdown")
123
+ async def shutdown_db():
124
+ global client
125
+ if client:
126
+ client.close()
127
+ logger.info("MongoDB connection closed")
128
 
129
  # -------------------------------------------------
130
  # Lock for face swap
 
132
  swap_lock = threading.Lock()
133
 
134
  # -------------------------------------------------
135
+ # Pipeline Function
136
  # -------------------------------------------------
137
  def face_swap_and_enhance(src_img, tgt_img):
 
138
  try:
139
  with swap_lock:
140
  shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
 
142
  os.makedirs(UPLOAD_DIR, exist_ok=True)
143
  os.makedirs(RESULT_DIR, exist_ok=True)
144
 
 
 
 
145
  src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
146
  tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
147
 
148
  src_faces = face_analysis_app.get(src_bgr)
149
  tgt_faces = face_analysis_app.get(tgt_bgr)
150
  if not src_faces or not tgt_faces:
151
+ return None, None, "❌ Face not detected in one of the images"
152
 
153
  swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
154
  swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
 
176
  return None, None, f"❌ Error: {str(e)}"
177
 
178
  # -------------------------------------------------
179
+ # Gradio UI
180
  # -------------------------------------------------
181
  with gr.Blocks() as demo:
182
  gr.Markdown("Face Swap")
 
197
  btn.click(process, [src_input, tgt_input], [output_img, download, error_box])
198
 
199
  # -------------------------------------------------
200
+ # API Endpoints
201
  # -------------------------------------------------
 
 
202
  @fastapi_app.get("/")
203
  def root():
204
  return RedirectResponse("/gradio")
 
207
  async def health():
208
  return {"status": "healthy"}
209
 
 
210
  @fastapi_app.post("/source")
211
  async def upload_source(image: UploadFile = File(...)):
212
  contents = await image.read()
213
+ file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "source"})
214
  return {"source_id": str(file_id)}
215
 
216
  @fastapi_app.post("/target")
217
  async def upload_target(image: UploadFile = File(...)):
218
  contents = await image.read()
219
+ file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "target"})
220
  return {"target_id": str(file_id)}
221
 
 
222
  class FaceSwapRequest(BaseModel):
223
  source_id: str
224
  target_id: str
 
226
  @fastapi_app.post("/faceswap")
227
  async def perform_faceswap(request: FaceSwapRequest):
228
  try:
229
+ src_stream = await fs_bucket.open_download_stream(ObjectId(request.source_id))
230
+ src_bytes = await src_stream.read()
231
+ tgt_stream = await fs_bucket.open_download_stream(ObjectId(request.target_id))
232
+ tgt_bytes = await tgt_stream.read()
233
+ except Exception:
234
+ raise HTTPException(status_code=404, detail="Source or Target not found")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ src_array = np.frombuffer(src_bytes, np.uint8)
237
+ src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
238
+ tgt_array = np.frombuffer(tgt_bytes, np.uint8)
239
+ tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
240
+
241
+ if src_bgr is None or tgt_bgr is None:
242
+ raise HTTPException(status_code=400, detail="Invalid image data")
243
+
244
+ src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
245
+ tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
246
+
247
+ final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
248
+ if err:
249
+ raise HTTPException(status_code=500, detail=err)
250
+
251
+ with open(final_path, "rb") as f:
252
+ final_bytes = f.read()
253
+
254
+ result_id = await fs_bucket.upload_from_stream("enhanced.png", final_bytes, metadata={"type": "result"})
255
+ return {"result_id": str(result_id)}
256
 
 
257
  @fastapi_app.get("/download/{result_id}")
258
  async def download_result(result_id: str):
259
  try:
260
  stream = await fs_bucket.open_download_stream(ObjectId(result_id))
261
+ data = await stream.read()
 
 
 
 
 
262
  except Exception:
263
  raise HTTPException(status_code=404, detail="Result not found")
264
 
265
+ return Response(
266
+ content=data,
267
+ media_type="image/png",
268
+ headers={"Content-Disposition": "attachment; filename=result.png"}
269
+ )
270
+
271
+ # -------------------------------------------------
272
  # Mount Gradio
273
+ # -------------------------------------------------
274
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
275
 
276
  if __name__ == "__main__":