HariLogicgo commited on
Commit
beaf86b
·
1 Parent(s): 033a45d

modified app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -15
app.py CHANGED
@@ -12,9 +12,8 @@ import insightface
12
  from insightface.app import FaceAnalysis
13
  from huggingface_hub import hf_hub_download
14
 
15
- from fastapi import FastAPI, UploadFile, File, HTTPException, Response, Depends, Security
16
  from fastapi.responses import RedirectResponse
17
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
18
  from pydantic import BaseModel
19
  from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
20
  from bson import ObjectId
@@ -129,15 +128,29 @@ async def shutdown_db():
129
  logger.info("MongoDB connection closed")
130
 
131
  # -------------------------------------------------
132
- # Auth Setup
133
  # -------------------------------------------------
134
- security = HTTPBearer()
 
135
  API_SECRET_TOKEN = os.getenv("API_SECRET_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
138
- if credentials.credentials != API_SECRET_TOKEN:
139
- raise HTTPException(status_code=401, detail="Invalid or missing token")
140
- return credentials.credentials
141
 
142
  # -------------------------------------------------
143
  # Lock for face swap
@@ -221,13 +234,13 @@ async def health():
221
  return {"status": "healthy"}
222
 
223
  @fastapi_app.post("/source")
224
- async def upload_source(image: UploadFile = File(...), token: str = Depends(verify_token)):
225
  contents = await image.read()
226
  file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "source"})
227
  return {"source_id": str(file_id)}
228
 
229
  @fastapi_app.get("/targets")
230
- async def list_targets(token: str = Depends(verify_token)):
231
  files = []
232
  async for file in database.fs.files.find({"metadata.type": "target", "metadata.predefined": True}):
233
  files.append({
@@ -237,7 +250,7 @@ async def list_targets(token: str = Depends(verify_token)):
237
  return {"targets": files}
238
 
239
  @fastapi_app.post("/target")
240
- async def upload_target(image: UploadFile = File(...), token: str = Depends(verify_token)):
241
  contents = await image.read()
242
  file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "target"})
243
  return {"target_id": str(file_id)}
@@ -247,7 +260,7 @@ class FaceSwapRequest(BaseModel):
247
  target_id: str
248
 
249
  @fastapi_app.post("/faceswap")
250
- async def perform_faceswap(request: FaceSwapRequest, token: str = Depends(verify_token)):
251
  try:
252
  src_stream = await fs_bucket.open_download_stream(ObjectId(request.source_id))
253
  src_bytes = await src_stream.read()
@@ -278,7 +291,7 @@ async def perform_faceswap(request: FaceSwapRequest, token: str = Depends(verify
278
  return {"result_id": str(result_id)}
279
 
280
  @fastapi_app.get("/download/{result_id}")
281
- async def download_result(result_id: str, token: str = Depends(verify_token)):
282
  try:
283
  stream = await fs_bucket.open_download_stream(ObjectId(result_id))
284
  data = await stream.read()
@@ -292,9 +305,9 @@ async def download_result(result_id: str, token: str = Depends(verify_token)):
292
  )
293
 
294
  # -------------------------------------------------
295
- # Mount Gradio
296
  # -------------------------------------------------
297
- fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
298
 
299
  if __name__ == "__main__":
300
  uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
 
12
  from insightface.app import FaceAnalysis
13
  from huggingface_hub import hf_hub_download
14
 
15
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Request, Response
16
  from fastapi.responses import RedirectResponse
 
17
  from pydantic import BaseModel
18
  from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
19
  from bson import ObjectId
 
128
  logger.info("MongoDB connection closed")
129
 
130
  # -------------------------------------------------
131
+ # 🔐 Global Auth Middleware
132
  # -------------------------------------------------
133
+ PUBLIC_PATHS = ["/", "/health", "/gradio"]
134
+
135
  API_SECRET_TOKEN = os.getenv("API_SECRET_TOKEN")
136
+ if not API_SECRET_TOKEN:
137
+ raise RuntimeError("❌ API_SECRET_TOKEN not set in environment!")
138
+
139
+ @fastapi_app.middleware("http")
140
+ async def auth_middleware(request: Request, call_next):
141
+ path = request.url.path
142
+ if any(path.startswith(p) for p in PUBLIC_PATHS):
143
+ return await call_next(request)
144
+
145
+ auth_header = request.headers.get("Authorization")
146
+ if not auth_header or not auth_header.startswith("Bearer "):
147
+ return Response(content="Missing Bearer token", status_code=401)
148
+
149
+ token = auth_header.split("Bearer ")[1]
150
+ if token != API_SECRET_TOKEN:
151
+ return Response(content="Invalid token", status_code=403)
152
 
153
+ return await call_next(request)
 
 
 
154
 
155
  # -------------------------------------------------
156
  # Lock for face swap
 
234
  return {"status": "healthy"}
235
 
236
  @fastapi_app.post("/source")
237
+ async def upload_source(image: UploadFile = File(...)):
238
  contents = await image.read()
239
  file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "source"})
240
  return {"source_id": str(file_id)}
241
 
242
  @fastapi_app.get("/targets")
243
+ async def list_targets():
244
  files = []
245
  async for file in database.fs.files.find({"metadata.type": "target", "metadata.predefined": True}):
246
  files.append({
 
250
  return {"targets": files}
251
 
252
  @fastapi_app.post("/target")
253
+ async def upload_target(image: UploadFile = File(...)):
254
  contents = await image.read()
255
  file_id = await fs_bucket.upload_from_stream(image.filename, contents, metadata={"type": "target"})
256
  return {"target_id": str(file_id)}
 
260
  target_id: str
261
 
262
  @fastapi_app.post("/faceswap")
263
+ async def perform_faceswap(request: FaceSwapRequest):
264
  try:
265
  src_stream = await fs_bucket.open_download_stream(ObjectId(request.source_id))
266
  src_bytes = await src_stream.read()
 
291
  return {"result_id": str(result_id)}
292
 
293
  @fastapi_app.get("/download/{result_id}")
294
+ async def download_result(result_id: str):
295
  try:
296
  stream = await fs_bucket.open_download_stream(ObjectId(result_id))
297
  data = await stream.read()
 
305
  )
306
 
307
  # -------------------------------------------------
308
+ # Mount Gradio (last, without overwriting app)
309
  # -------------------------------------------------
310
+ mount_gradio_app(fastapi_app, demo, path="/gradio")
311
 
312
  if __name__ == "__main__":
313
  uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)