yasser5711 commited on
Commit
a10afeb
·
verified ·
1 Parent(s): ddc1b06

Secure /predict with Bearer API key auth

Browse files
Files changed (1) hide show
  1. server.py +23 -2
server.py CHANGED
@@ -1,11 +1,16 @@
1
  from __future__ import annotations
2
 
3
- from fastapi import FastAPI
 
 
4
  from fastapi.middleware.cors import CORSMiddleware
 
5
  from pydantic import BaseModel
6
 
7
  from inference.predict import predict
8
 
 
 
9
  app = FastAPI(title="M2Predict API")
10
 
11
  app.add_middleware(
@@ -15,6 +20,18 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  class PredictRequest(BaseModel):
20
  code_postal: str
@@ -24,6 +41,10 @@ class PredictRequest(BaseModel):
24
 
25
 
26
  @app.post("/predict")
27
- def predict_endpoint(req: PredictRequest, model_version: str = "v1_rf_te"):
 
 
 
 
28
  result = predict(req.model_dump(), model_version=model_version)
29
  return result
 
1
  from __future__ import annotations
2
 
3
+ import os
4
+
5
+ from fastapi import Depends, FastAPI, HTTPException, Security
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
8
  from pydantic import BaseModel
9
 
10
  from inference.predict import predict
11
 
12
+ API_KEY = os.environ.get("API_KEY", "")
13
+
14
  app = FastAPI(title="M2Predict API")
15
 
16
  app.add_middleware(
 
20
  allow_headers=["*"],
21
  )
22
 
23
+ security = HTTPBearer()
24
+
25
+
26
+ def verify_api_key(
27
+ credentials: HTTPAuthorizationCredentials = Security(security),
28
+ ) -> str:
29
+ if not API_KEY:
30
+ raise HTTPException(status_code=500, detail="API_KEY not configured")
31
+ if credentials.credentials != API_KEY:
32
+ raise HTTPException(status_code=403, detail="Invalid API key")
33
+ return credentials.credentials
34
+
35
 
36
  class PredictRequest(BaseModel):
37
  code_postal: str
 
41
 
42
 
43
  @app.post("/predict")
44
+ def predict_endpoint(
45
+ req: PredictRequest,
46
+ model_version: str = "v1_rf_te",
47
+ _key: str = Depends(verify_api_key),
48
+ ):
49
  result = predict(req.model_dump(), model_version=model_version)
50
  return result