LogicGoInfotechSpaces commited on
Commit
b60a4ed
·
verified ·
1 Parent(s): 0b138e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -112
app.py CHANGED
@@ -14,9 +14,8 @@ import insightface
14
  from insightface.app import FaceAnalysis
15
  from huggingface_hub import hf_hub_download
16
 
17
- from fastapi import FastAPI, UploadFile, File, HTTPException, Response, Depends, Security, Query
18
  from fastapi.responses import RedirectResponse
19
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
20
  from pydantic import BaseModel
21
  from motor.motor_asyncio import AsyncIOMotorClient
22
 
@@ -28,7 +27,9 @@ from gradio import mount_gradio_app
28
  import boto3
29
  from botocore.client import Config
30
  from io import BytesIO
31
- from typing import Optional
 
 
32
  # --------------------- Logging ---------------------
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
@@ -42,13 +43,40 @@ os.makedirs(MODELS_DIR, exist_ok=True)
42
 
43
  # --------------------- Secrets ---------------------
44
  HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face private repo token
45
- API_SECRET_TOKEN = os.getenv("API_SECRET_TOKEN") # Bearer token for API
 
 
46
  # --------------------- DigitalOcean Spaces Credentials ---------------------
47
- DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1") # Default region = Bangalore
48
  DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com")
49
- DO_SPACES_KEY = os.getenv("DO_SPACES_KEY") # Your Access Key
50
- DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET") # Your Secret Key
51
- DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET") # Bucket Name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  # --------------------- Download Models ---------------------
54
  def download_models():
@@ -100,7 +128,7 @@ def ensure_codeformer():
100
 
101
  ensure_codeformer()
102
 
103
- # --------------------- MongoDB (for API logs only) ---------------------
104
  MONGODB_URL = os.getenv("MONGODB_URL")
105
 
106
  client = None
@@ -124,22 +152,14 @@ async def shutdown_db():
124
  client.close()
125
  logger.info("MongoDB connection closed")
126
 
127
- # --------------------- Auth ---------------------
128
- security = HTTPBearer()
129
-
130
- def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
131
- if credentials.credentials != API_SECRET_TOKEN:
132
- raise HTTPException(status_code=401, detail="Invalid or missing token")
133
- return credentials.credentials
134
-
135
  # --------------------- Logging API Hits ---------------------
136
- async def log_faceswap_hit(token: str, status: str = "success"):
137
  global database
138
  if database is None:
139
  return
140
  await database.api_logs.insert_one({
141
- "token": token,
142
- "endpoint": "/faceswap",
143
  "status": status,
144
  "timestamp": datetime.utcnow()
145
  })
@@ -150,7 +170,6 @@ swap_lock = threading.Lock()
150
  def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
151
  try:
152
  with swap_lock:
153
- # Prepare temporary directory
154
  if os.path.exists(temp_dir):
155
  shutil.rmtree(temp_dir)
156
  os.makedirs(temp_dir, exist_ok=True)
@@ -164,71 +183,16 @@ def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
164
  if not src_faces or not tgt_faces:
165
  return None, None, "❌ Face not detected in source or target image"
166
 
167
- def expand_bbox(bbox, img_shape, scale=1.6):
168
- ih, iw = img_shape[:2]
169
- x1, y1, x2, y2 = map(int, bbox)
170
- w, h = x2 - x1, y2 - y1
171
- cx, cy = x1 + w // 2, y1 + h // 2
172
- new_w, new_h = int(w * scale), int(h * scale)
173
- nx1 = max(0, cx - new_w // 2)
174
- ny1 = max(0, cy - new_h // 2)
175
- nx2 = min(iw, cx + new_w // 2)
176
- ny2 = min(ih, cy + new_h // 2)
177
- return nx1, ny1, nx2, ny2
178
-
179
  src_face0 = src_faces[0]
180
  tgt_face0 = tgt_faces[0]
181
 
182
- # More accurate source face crop with slight expansion
183
- s_x1, s_y1, s_x2, s_y2 = expand_bbox(src_face0.bbox, src_bgr.shape, scale=1.4)
184
- src_crop = src_bgr[s_y1:s_y2, s_x1:s_x2]
185
- src_crop_faces = face_analysis_app.get(src_crop)
186
- if src_crop_faces:
187
- src_for_swap = src_crop
188
- src_face_for_swap = src_crop_faces[0]
189
- else:
190
- src_for_swap = src_bgr
191
- src_face_for_swap = src_face0
192
-
193
- # More aggressive target crop for precise landmark detection
194
- t_x1, t_y1, t_x2, t_y2 = expand_bbox(tgt_face0.bbox, tgt_bgr_full.shape, scale=1.6)
195
- tgt_crop = tgt_bgr_full[t_y1:t_y2, t_x1:t_x2]
196
- tgt_crop_faces = face_analysis_app.get(tgt_crop)
197
-
198
- if tgt_crop_faces:
199
- tgt_for_swap = tgt_crop
200
- tgt_face_for_swap = tgt_crop_faces[0]
201
-
202
- swapped_crop = swapper.get(tgt_for_swap, tgt_face_for_swap, src_face_for_swap)
203
- if swapped_crop is None:
204
- return None, None, "❌ Face swap failed on crop"
205
-
206
- # Create mask with threshold for seamlessClone
207
- mask = cv2.cvtColor(swapped_crop, cv2.COLOR_BGR2GRAY)
208
- _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
209
-
210
- center = ((t_x1 + t_x2) // 2, (t_y1 + t_y2) // 2)
211
-
212
- try:
213
- blended = cv2.seamlessClone(swapped_crop, tgt_bgr_full, mask, center, cv2.NORMAL_CLONE)
214
- except Exception:
215
- # Fallback to direct paste if seamlessClone fails
216
- blended = tgt_bgr_full.copy()
217
- h, w = swapped_crop.shape[:2]
218
- blended[t_y1:t_y1+h, t_x1:t_x1+w] = swapped_crop
219
-
220
- swapped_path = os.path.join(temp_dir, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
221
- cv2.imwrite(swapped_path, blended)
222
-
223
- else:
224
- # Fallback: swap on full image if crop detection fails
225
- swapped_bgr_full = swapper.get(tgt_bgr_full, tgt_face0, src_face0)
226
- if swapped_bgr_full is None:
227
- return None, None, "❌ Face swap failed on full image"
228
- swapped_path = os.path.join(temp_dir, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
229
- cv2.imwrite(swapped_path, swapped_bgr_full)
230
-
231
- # Run CodeFormer enhancement on the swapped image
232
  cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {temp_dir} --bg_upsampler realesrgan --face_upsample"
233
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
234
  if result.returncode != 0:
@@ -247,8 +211,6 @@ def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
247
  except Exception as e:
248
  return None, None, f"❌ Error: {str(e)}"
249
 
250
-
251
-
252
  # --------------------- Gradio ---------------------
253
  with gr.Blocks() as demo:
254
  gr.Markdown("Face Swap")
@@ -300,72 +262,55 @@ def root():
300
  async def health():
301
  return {"status": "healthy"}
302
 
303
- from fastapi import Form
304
- import requests
305
-
306
- @fastapi_app.post("/face-swap", dependencies=[Depends(verify_token)])
307
  async def face_swap_api(
308
  source: UploadFile = File(...),
309
  target_category_id: str = Form(...),
310
- credentials: HTTPAuthorizationCredentials = Security(security)
311
  ):
312
  try:
313
- # Read source image
314
  src_bytes = await source.read()
315
-
316
- # Save source image to Spaces under bikini-theme/source/
317
  src_key = f"bikini-theme/source/{uuid.uuid4().hex}_{source.filename}"
318
  upload_to_spaces(src_bytes, src_key, content_type=source.content_type)
319
 
320
- # Build target image URL directly (no category scanning)
321
  target_filename = f"{target_category_id}.jpg"
322
  target_url = f"https://{DO_SPACES_BUCKET}.{DO_SPACES_REGION}.digitaloceanspaces.com/bikini-theme/target/{target_filename}"
323
 
324
- # Try to fetch the target image
325
  resp = requests.get(target_url)
326
  if resp.status_code != 200:
327
- await log_faceswap_hit(credentials.credentials, status="error")
328
- raise HTTPException(
329
- status_code=404,
330
- detail=f"Target image not found at {target_url}"
331
- )
332
 
333
  tgt_bytes = resp.content
334
 
335
- # Decode source and target images
336
  src_array = np.frombuffer(src_bytes, np.uint8)
337
  tgt_array = np.frombuffer(tgt_bytes, np.uint8)
338
  src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
339
  tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
340
 
341
  if src_bgr is None or tgt_bgr is None:
342
- await log_faceswap_hit(credentials.credentials, status="error")
343
  raise HTTPException(status_code=400, detail="Invalid image data")
344
 
345
  src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
346
  tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
347
 
348
- # Run face swap and enhancement
349
  final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
350
  if err:
351
- await log_faceswap_hit(credentials.credentials, status="error")
352
  raise HTTPException(status_code=500, detail=err)
353
 
354
- # Upload the result back to bikini-theme/result/
355
  with open(final_path, "rb") as f:
356
  result_bytes = f.read()
357
  result_key = f"bikini-theme/result/{uuid.uuid4().hex}_enhanced.png"
358
  result_url = upload_to_spaces(result_bytes, result_key, content_type="image/png")
359
 
360
- # Log success in MongoDB
361
- await log_faceswap_hit(credentials.credentials, status="success")
362
 
363
- return {
364
- "result_url": result_url
365
- }
366
 
367
  except Exception as e:
368
- await log_faceswap_hit(credentials.credentials, status="error")
369
  raise HTTPException(status_code=500, detail=f"Face swap failed: {str(e)}")
370
 
371
  @fastapi_app.get("/preview/{result_key:path}")
@@ -379,8 +324,9 @@ async def preview_result(result_key: str):
379
  media_type="image/png",
380
  headers={"Content-Disposition": "inline; filename=result.png"}
381
  )
 
382
  # --------------------- Mount Gradio ---------------------
383
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
384
 
385
  if __name__ == "__main__":
386
- uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
 
14
  from insightface.app import FaceAnalysis
15
  from huggingface_hub import hf_hub_download
16
 
17
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Response, Depends, Security, Form
18
  from fastapi.responses import RedirectResponse
 
19
  from pydantic import BaseModel
20
  from motor.motor_asyncio import AsyncIOMotorClient
21
 
 
27
  import boto3
28
  from botocore.client import Config
29
  from io import BytesIO
30
+ from typing import Optional
31
+ import requests
32
+
33
  # --------------------- Logging ---------------------
34
  logging.basicConfig(level=logging.INFO)
35
  logger = logging.getLogger(__name__)
 
43
 
44
  # --------------------- Secrets ---------------------
45
  HF_TOKEN = os.getenv("HF_TOKEN") # Hugging Face private repo token
46
+ # Firebase credentials JSON
47
+ FIREBASE_CREDENTIALS_PATH = os.getenv("FIREBASE_CREDENTIALS_PATH")
48
+
49
  # --------------------- DigitalOcean Spaces Credentials ---------------------
50
+ DO_SPACES_REGION = os.getenv("DO_SPACES_REGION", "blr1")
51
  DO_SPACES_ENDPOINT = os.getenv("DO_SPACES_ENDPOINT", f"https://{DO_SPACES_REGION}.digitaloceanspaces.com")
52
+ DO_SPACES_KEY = os.getenv("DO_SPACES_KEY")
53
+ DO_SPACES_SECRET = os.getenv("DO_SPACES_SECRET")
54
+ DO_SPACES_BUCKET = os.getenv("DO_SPACES_BUCKET")
55
+
56
+ # --------------------- Firebase Auth ---------------------
57
+ import firebase_admin
58
+ from firebase_admin import credentials, auth
59
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
60
+
61
+ if not firebase_admin._apps:
62
+ cred = credentials.Certificate(FIREBASE_CREDENTIALS_PATH)
63
+ firebase_admin.initialize_app(cred)
64
+ logger.info("✅ Firebase initialized successfully")
65
+
66
+ security = HTTPBearer()
67
+
68
+ def verify_firebase_token(credentials: HTTPAuthorizationCredentials = Security(security)):
69
+ """Verify Firebase ID token from Authorization header."""
70
+ try:
71
+ id_token = credentials.credentials
72
+ decoded_token = auth.verify_id_token(id_token)
73
+ user_email = decoded_token.get("email")
74
+ if not user_email:
75
+ raise HTTPException(status_code=401, detail="Firebase token invalid or missing email")
76
+ return user_email
77
+ except Exception as e:
78
+ logger.error(f"Firebase auth failed: {e}")
79
+ raise HTTPException(status_code=401, detail="Unauthorized: Invalid Firebase token")
80
 
81
  # --------------------- Download Models ---------------------
82
  def download_models():
 
128
 
129
  ensure_codeformer()
130
 
131
+ # --------------------- MongoDB ---------------------
132
  MONGODB_URL = os.getenv("MONGODB_URL")
133
 
134
  client = None
 
152
  client.close()
153
  logger.info("MongoDB connection closed")
154
 
 
 
 
 
 
 
 
 
155
  # --------------------- Logging API Hits ---------------------
156
+ async def log_faceswap_hit(user_email: str, status: str = "success"):
157
  global database
158
  if database is None:
159
  return
160
  await database.api_logs.insert_one({
161
+ "user": user_email,
162
+ "endpoint": "/face-swap",
163
  "status": status,
164
  "timestamp": datetime.utcnow()
165
  })
 
170
  def face_swap_and_enhance(src_img, tgt_img, temp_dir="/tmp/faceswap_work"):
171
  try:
172
  with swap_lock:
 
173
  if os.path.exists(temp_dir):
174
  shutil.rmtree(temp_dir)
175
  os.makedirs(temp_dir, exist_ok=True)
 
183
  if not src_faces or not tgt_faces:
184
  return None, None, "❌ Face not detected in source or target image"
185
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  src_face0 = src_faces[0]
187
  tgt_face0 = tgt_faces[0]
188
 
189
+ swapped_bgr_full = swapper.get(tgt_bgr_full, tgt_face0, src_face0)
190
+ if swapped_bgr_full is None:
191
+ return None, None, "❌ Face swap failed"
192
+
193
+ swapped_path = os.path.join(temp_dir, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
194
+ cv2.imwrite(swapped_path, swapped_bgr_full)
195
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {temp_dir} --bg_upsampler realesrgan --face_upsample"
197
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
198
  if result.returncode != 0:
 
211
  except Exception as e:
212
  return None, None, f"❌ Error: {str(e)}"
213
 
 
 
214
  # --------------------- Gradio ---------------------
215
  with gr.Blocks() as demo:
216
  gr.Markdown("Face Swap")
 
262
  async def health():
263
  return {"status": "healthy"}
264
 
265
+ @fastapi_app.post("/face-swap")
 
 
 
266
  async def face_swap_api(
267
  source: UploadFile = File(...),
268
  target_category_id: str = Form(...),
269
+ user_email: str = Depends(verify_firebase_token)
270
  ):
271
  try:
 
272
  src_bytes = await source.read()
 
 
273
  src_key = f"bikini-theme/source/{uuid.uuid4().hex}_{source.filename}"
274
  upload_to_spaces(src_bytes, src_key, content_type=source.content_type)
275
 
 
276
  target_filename = f"{target_category_id}.jpg"
277
  target_url = f"https://{DO_SPACES_BUCKET}.{DO_SPACES_REGION}.digitaloceanspaces.com/bikini-theme/target/{target_filename}"
278
 
 
279
  resp = requests.get(target_url)
280
  if resp.status_code != 200:
281
+ await log_faceswap_hit(user_email, status="error")
282
+ raise HTTPException(status_code=404, detail=f"Target image not found at {target_url}")
 
 
 
283
 
284
  tgt_bytes = resp.content
285
 
 
286
  src_array = np.frombuffer(src_bytes, np.uint8)
287
  tgt_array = np.frombuffer(tgt_bytes, np.uint8)
288
  src_bgr = cv2.imdecode(src_array, cv2.IMREAD_COLOR)
289
  tgt_bgr = cv2.imdecode(tgt_array, cv2.IMREAD_COLOR)
290
 
291
  if src_bgr is None or tgt_bgr is None:
292
+ await log_faceswap_hit(user_email, status="error")
293
  raise HTTPException(status_code=400, detail="Invalid image data")
294
 
295
  src_rgb = cv2.cvtColor(src_bgr, cv2.COLOR_BGR2RGB)
296
  tgt_rgb = cv2.cvtColor(tgt_bgr, cv2.COLOR_BGR2RGB)
297
 
 
298
  final_img, final_path, err = face_swap_and_enhance(src_rgb, tgt_rgb)
299
  if err:
300
+ await log_faceswap_hit(user_email, status="error")
301
  raise HTTPException(status_code=500, detail=err)
302
 
 
303
  with open(final_path, "rb") as f:
304
  result_bytes = f.read()
305
  result_key = f"bikini-theme/result/{uuid.uuid4().hex}_enhanced.png"
306
  result_url = upload_to_spaces(result_bytes, result_key, content_type="image/png")
307
 
308
+ await log_faceswap_hit(user_email, status="success")
 
309
 
310
+ return {"result_url": result_url}
 
 
311
 
312
  except Exception as e:
313
+ await log_faceswap_hit(user_email, status="error")
314
  raise HTTPException(status_code=500, detail=f"Face swap failed: {str(e)}")
315
 
316
  @fastapi_app.get("/preview/{result_key:path}")
 
324
  media_type="image/png",
325
  headers={"Content-Disposition": "inline; filename=result.png"}
326
  )
327
+
328
  # --------------------- Mount Gradio ---------------------
329
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
330
 
331
  if __name__ == "__main__":
332
+ uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)