HariLogicgo commited on
Commit
1b22455
·
1 Parent(s): 58c492c

fastapi added

Browse files
Files changed (2) hide show
  1. app.py +191 -1
  2. requirements.txt +5 -1
app.py CHANGED
@@ -15,6 +15,42 @@ from collections import defaultdict
15
  from facexlib.utils.misc import download_from_url
16
  from basicsr.utils.realesrganer import RealESRGANer
17
  from utils.dataops import auto_split_upscale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  input_images_limit = 5
20
  # Define URLs and their corresponding local storage paths
@@ -1298,4 +1334,158 @@ if __name__ == "__main__":
1298
  parser.add_argument("--input_images_limit", type=int, default=5)
1299
  args = parser.parse_args()
1300
  input_images_limit = args.input_images_limit
1301
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  from facexlib.utils.misc import download_from_url
16
  from basicsr.utils.realesrganer import RealESRGANer
17
  from utils.dataops import auto_split_upscale
18
+ from typing import List, Optional
19
+
20
+ # FastAPI imports (API server)
21
+ try:
22
+ from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
23
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
24
+ from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
25
+ except Exception:
26
+ # Allow Gradio-only usage without FastAPI installed
27
+ FastAPI = None
28
+ UploadFile = None
29
+ File = None
30
+ Form = None
31
+ Depends = None
32
+ HTTPException = Exception
33
+ status = type("status", (), {"HTTP_401_UNAUTHORIZED": 401, "HTTP_500_INTERNAL_SERVER_ERROR": 500})
34
+ HTTPBearer = None
35
+ HTTPAuthorizationCredentials = None
36
+ FileResponse = None
37
+ StreamingResponse = None
38
+ JSONResponse = None
39
+
40
+ # Mongo imports (optional)
41
+ _mongo_client = None
42
+ _mongo_collection = None
43
+ try:
44
+ from pymongo import MongoClient
45
+ _mongo_uri = os.getenv("MONGODB_URI")
46
+ _mongo_db = os.getenv("MONGODB_DB", "face_upscale")
47
+ _mongo_col = os.getenv("MONGODB_COLLECTION", "submits")
48
+ if _mongo_uri:
49
+ _mongo_client = MongoClient(_mongo_uri, connect=False)
50
+ _mongo_collection = _mongo_client[_mongo_db][_mongo_col]
51
+ except Exception:
52
+ _mongo_client = None
53
+ _mongo_collection = None
54
 
55
  input_images_limit = 5
56
  # Define URLs and their corresponding local storage paths
 
1334
  parser.add_argument("--input_images_limit", type=int, default=5)
1335
  args = parser.parse_args()
1336
  input_images_limit = args.input_images_limit
1337
+ main()
1338
+
1339
+ # ---------------------------- FastAPI Application ----------------------------
1340
+
1341
+ # Expose FastAPI app named `fastapi_app` so it can be served via: uvicorn app:fastapi_app --host 0.0.0.0 --port 8000
1342
+ if FastAPI:
1343
+ fastapi_app = FastAPI(title="Face Upscale API")
1344
+
1345
+ _bearer_scheme = HTTPBearer(auto_error=False) if HTTPBearer else None
1346
+ _api_bearer_token = os.getenv("API_BEARER_TOKEN", "changeme")
1347
+
1348
+ def _verify_bearer_token(credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer_scheme) if Depends and _bearer_scheme else None):
1349
+ if not _bearer_scheme:
1350
+ # If FastAPI security is unavailable, deny access
1351
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
1352
+ if not credentials or credentials.scheme.lower() != "bearer":
1353
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid auth scheme")
1354
+ token = credentials.credentials
1355
+ if not _api_bearer_token or token != _api_bearer_token:
1356
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
1357
+ return True
1358
+
1359
+ # Defaults aligned with Gradio UI
1360
+ DEFAULT_FACE_MODEL = 'GFPGANv1.4.pth'
1361
+ DEFAULT_UPSCALE_MODEL = 'SRVGG, realesr-general-x4v3.pth'
1362
+ DEFAULT_SCALE = 4.0
1363
+ DEFAULT_FACE_DET = 'retinaface_resnet50'
1364
+ DEFAULT_FACE_DET_THRESHOLD = 10.0
1365
+ DEFAULT_ONLY_CENTER = False
1366
+ DEFAULT_WITH_MODEL_NAME = True
1367
+
1368
+ # Utility to ensure output dir exists
1369
+ os.makedirs('output', exist_ok=True)
1370
+ os.makedirs('input', exist_ok=True)
1371
+
1372
+ # Singleton Upscale instance for API
1373
+ _api_upscale = Upscale()
1374
+
1375
+ @fastapi_app.get("/health")
1376
+ def health(_: bool = Depends(_verify_bearer_token)):
1377
+ return {"status": "ok"}
1378
+
1379
+ @fastapi_app.get("/source")
1380
+ def source(_: bool = Depends(_verify_bearer_token)):
1381
+ return {
1382
+ "face_models": list(face_models.keys()),
1383
+ "upscale_models": list(typed_upscale_models.keys()),
1384
+ "defaults": {
1385
+ "face_model": DEFAULT_FACE_MODEL,
1386
+ "upscale_model": DEFAULT_UPSCALE_MODEL,
1387
+ "scale": DEFAULT_SCALE,
1388
+ "face_detection": DEFAULT_FACE_DET,
1389
+ "face_detection_threshold": DEFAULT_FACE_DET_THRESHOLD,
1390
+ "face_detection_only_center": DEFAULT_ONLY_CENTER,
1391
+ "with_model_name": DEFAULT_WITH_MODEL_NAME,
1392
+ }
1393
+ }
1394
+
1395
+ @fastapi_app.post("/submit")
1396
+ async def submit(
1397
+ files: List[UploadFile] = File(..., description="One or more image files"),
1398
+ face_model: Optional[str] = Form(DEFAULT_FACE_MODEL),
1399
+ upscale_model: Optional[str] = Form(DEFAULT_UPSCALE_MODEL),
1400
+ scale: float = Form(DEFAULT_SCALE),
1401
+ face_detection: str = Form(DEFAULT_FACE_DET),
1402
+ face_detection_threshold: float = Form(DEFAULT_FACE_DET_THRESHOLD),
1403
+ face_detection_only_center: bool = Form(DEFAULT_ONLY_CENTER),
1404
+ with_model_name: bool = Form(DEFAULT_WITH_MODEL_NAME),
1405
+ _: bool = Depends(_verify_bearer_token)
1406
+ ):
1407
+ try:
1408
+ saved_paths = []
1409
+ for f in files:
1410
+ # Save uploaded file to input directory
1411
+ raw_bytes = await f.read()
1412
+ safe_name = os.path.basename(f.filename)
1413
+ save_path = os.path.join('input', f"{int(time.time()*1000)}_{safe_name}")
1414
+ with open(save_path, 'wb') as out:
1415
+ out.write(raw_bytes)
1416
+ saved_paths.append(save_path)
1417
+
1418
+ # Build gallery structure expected by Upscale.inference
1419
+ gallery = [[p, ""] for p in saved_paths]
1420
+
1421
+ # Progress stub for API
1422
+ class _NoProgress:
1423
+ def __call__(self, *args, **kwargs):
1424
+ return None
1425
+ progress_stub = _NoProgress()
1426
+
1427
+ output_files, zip_files = _api_upscale.inference(
1428
+ gallery=gallery,
1429
+ face_restoration=face_model if face_model != "None" else None,
1430
+ upscale_model=upscale_model if upscale_model != "None" else None,
1431
+ scale=float(scale),
1432
+ face_detection=face_detection,
1433
+ face_detection_threshold=face_detection_threshold,
1434
+ face_detection_only_center=face_detection_only_center,
1435
+ outputWithModelName=with_model_name,
1436
+ save_as_png=True,
1437
+ progress=progress_stub,
1438
+ )
1439
+
1440
+ # MongoDB logging
1441
+ if _mongo_collection is not None:
1442
+ try:
1443
+ _mongo_collection.insert_one({
1444
+ "ts": int(time.time()),
1445
+ "files": [os.path.basename(p) for p in saved_paths],
1446
+ "params": {
1447
+ "face_model": face_model,
1448
+ "upscale_model": upscale_model,
1449
+ "scale": scale,
1450
+ "face_detection": face_detection,
1451
+ "face_detection_threshold": face_detection_threshold,
1452
+ "face_detection_only_center": face_detection_only_center,
1453
+ "with_model_name": with_model_name,
1454
+ },
1455
+ "outputs": {
1456
+ "images": [os.path.basename(p) for p in (output_files or [])],
1457
+ "zips": [os.path.basename(p) for p in (zip_files or [])],
1458
+ },
1459
+ })
1460
+ except Exception:
1461
+ pass
1462
+
1463
+ # Build download URLs
1464
+ def _dl_url(name: str) -> str:
1465
+ return f"/download/{name}"
1466
+
1467
+ return {
1468
+ "output_images": [os.path.basename(p) for p in (output_files or [])],
1469
+ "output_zips": [os.path.basename(p) for p in (zip_files or [])],
1470
+ "download_urls": {
1471
+ "images": [_dl_url(os.path.basename(p)) for p in (output_files or [])],
1472
+ "zips": [_dl_url(os.path.basename(p)) for p in (zip_files or [])],
1473
+ },
1474
+ }
1475
+ except HTTPException:
1476
+ raise
1477
+ except Exception as e:
1478
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
1479
+
1480
+ @fastapi_app.get("/download/{filename}")
1481
+ def download(filename: str, _: bool = Depends(_verify_bearer_token)):
1482
+ safe_name = os.path.basename(filename)
1483
+ # Prefer output dir; fallback to input if requested file is an original
1484
+ candidate_paths = [
1485
+ os.path.join('output', safe_name),
1486
+ os.path.join('input', safe_name),
1487
+ ]
1488
+ for p in candidate_paths:
1489
+ if os.path.isfile(p):
1490
+ return FileResponse(p, filename=safe_name)
1491
+ raise HTTPException(status_code=404, detail="File not found")
requirements.txt CHANGED
@@ -19,4 +19,8 @@ pyyaml
19
  yapf
20
 
21
  image_gen_aux @ git+https://github.com/huggingface/image_gen_aux
22
- gdown # supports downloading the large file from Google Drive
 
 
 
 
 
19
  yapf
20
 
21
  image_gen_aux @ git+https://github.com/huggingface/image_gen_aux
22
+ gdown # supports downloading the large file from Google Drive
23
+ fastapi
24
+ uvicorn[standard]
25
+ pymongo
26
+ python-multipart