Spaces:
Runtime error
Runtime error
Commit ·
302d8c7
1
Parent(s): 855455a
Add authentication middleware and update auth controller for API key validation
Browse files- src/app.py +2 -0
- src/controllers/__init__.py +5 -2
- src/controllers/_auth_controller.py +4 -5
- src/middlewares/__init__.py +7 -0
- src/middlewares/_authentication.py +52 -0
src/app.py
CHANGED
|
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 6 |
|
| 7 |
from src.controllers import api_router, websocket_router
|
| 8 |
from src.config import logger, DatabaseConfig
|
|
|
|
| 9 |
|
| 10 |
@asynccontextmanager
|
| 11 |
async def lifespan(app: FastAPI):
|
|
@@ -37,6 +38,7 @@ app.add_middleware(
|
|
| 37 |
allow_headers=["*"],
|
| 38 |
)
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
@app.get("/")
|
| 42 |
async def check_health():
|
|
|
|
| 6 |
|
| 7 |
from src.controllers import api_router, websocket_router
|
| 8 |
from src.config import logger, DatabaseConfig
|
| 9 |
+
from src.middlewares import AuthenticationMiddleware
|
| 10 |
|
| 11 |
@asynccontextmanager
|
| 12 |
async def lifespan(app: FastAPI):
|
|
|
|
| 38 |
allow_headers=["*"],
|
| 39 |
)
|
| 40 |
|
| 41 |
+
app.add_middleware(AuthenticationMiddleware)
|
| 42 |
|
| 43 |
@app.get("/")
|
| 44 |
async def check_health():
|
src/controllers/__init__.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
-
from fastapi import APIRouter
|
|
|
|
| 2 |
from ._auth_controller import AuthController
|
| 3 |
from ._ai_voice_controller import AIVoiceController
|
| 4 |
from ._ai_text_controller import AITextController
|
|
@@ -12,10 +13,12 @@ _ai_voice_controller = AIVoiceController()
|
|
| 12 |
_ai_text_controller = AITextController()
|
| 13 |
_pinecone_controller = PineconeController()
|
| 14 |
|
|
|
|
|
|
|
| 15 |
websocket_router.include_router(_ai_voice_controller.websocket_router, prefix="/v1")
|
| 16 |
websocket_router.include_router(_ai_text_controller.websocket_router, prefix="/v1")
|
| 17 |
|
| 18 |
-
api_router.include_router(_auth_controller.api_router, prefix="/v1")
|
| 19 |
api_router.include_router(_ai_voice_controller.api_router, prefix="/v1")
|
| 20 |
api_router.include_router(_ai_text_controller.api_router, prefix="/v1")
|
| 21 |
api_router.include_router(_pinecone_controller.api_router, prefix="/v1")
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends
|
| 2 |
+
from fastapi.security import APIKeyHeader
|
| 3 |
from ._auth_controller import AuthController
|
| 4 |
from ._ai_voice_controller import AIVoiceController
|
| 5 |
from ._ai_text_controller import AITextController
|
|
|
|
| 13 |
_ai_text_controller = AITextController()
|
| 14 |
_pinecone_controller = PineconeController()
|
| 15 |
|
| 16 |
+
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
|
| 17 |
+
|
| 18 |
websocket_router.include_router(_ai_voice_controller.websocket_router, prefix="/v1")
|
| 19 |
websocket_router.include_router(_ai_text_controller.websocket_router, prefix="/v1")
|
| 20 |
|
| 21 |
+
api_router.include_router(_auth_controller.api_router, prefix="/v1", dependencies=[Depends(API_KEY_HEADER)])
|
| 22 |
api_router.include_router(_ai_voice_controller.api_router, prefix="/v1")
|
| 23 |
api_router.include_router(_ai_text_controller.api_router, prefix="/v1")
|
| 24 |
api_router.include_router(_pinecone_controller.api_router, prefix="/v1")
|
src/controllers/_auth_controller.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
from fastapi import APIRouter, HTTPException,
|
| 2 |
-
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 3 |
from src.services import AuthService
|
| 4 |
from src.schemas import (
|
| 5 |
UserSignInSchema,
|
|
@@ -51,9 +50,9 @@ class AuthController:
|
|
| 51 |
logger.error(f"Error during signin: {e}")
|
| 52 |
raise HTTPException(status_code=500, detail="Error during signin")
|
| 53 |
|
| 54 |
-
async def _signout(self,
|
| 55 |
-
try:
|
| 56 |
-
token =
|
| 57 |
await self._auth_service.sign_out(token)
|
| 58 |
return status.HTTP_200_OK
|
| 59 |
except HTTPException as e:
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException, status, Request
|
|
|
|
| 2 |
from src.services import AuthService
|
| 3 |
from src.schemas import (
|
| 4 |
UserSignInSchema,
|
|
|
|
| 50 |
logger.error(f"Error during signin: {e}")
|
| 51 |
raise HTTPException(status_code=500, detail="Error during signin")
|
| 52 |
|
| 53 |
+
async def _signout(self, request: Request, ):
|
| 54 |
+
try:
|
| 55 |
+
token = request.headers.get("Authorization").split(" ")[1]
|
| 56 |
await self._auth_service.sign_out(token)
|
| 57 |
return status.HTTP_200_OK
|
| 58 |
except HTTPException as e:
|
src/middlewares/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ._authentication import AuthenticationMiddleware
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"AuthenticationMiddleware",
|
| 5 |
+
]
|
| 6 |
+
__version__ = "0.1.0"
|
| 7 |
+
__author__ = "Narinder Singh"
|
src/middlewares/_authentication.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from fastapi import HTTPException, Request
|
| 4 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 5 |
+
|
| 6 |
+
from src.utils import JWTUtil
|
| 7 |
+
from src.services import SessionService
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
EXEMPT_ROUTES = [
|
| 11 |
+
"/",
|
| 12 |
+
"/docs",
|
| 13 |
+
"/openapi.json",
|
| 14 |
+
"/api/v1/auth/login",
|
| 15 |
+
"/api/v1/auth/register",
|
| 16 |
+
]
|
| 17 |
+
class AuthenticationMiddleware(BaseHTTPMiddleware):
|
| 18 |
+
def __init__(self, app):
|
| 19 |
+
super().__init__(app)
|
| 20 |
+
self._jwt = JWTUtil()
|
| 21 |
+
self._session_service = SessionService()
|
| 22 |
+
|
| 23 |
+
async def dispatch(self, request: Request, call_next):
|
| 24 |
+
if self._require_auth(request.url.path):
|
| 25 |
+
try:
|
| 26 |
+
user_payload =await self._validate_api_key(request.headers.get("Authorization"))
|
| 27 |
+
request.state.user = user_payload
|
| 28 |
+
except HTTPException as e:
|
| 29 |
+
return JSONResponse(status_code=e.status_code, content=e.detail)
|
| 30 |
+
return await call_next(request)
|
| 31 |
+
|
| 32 |
+
def _require_auth(self, path: str):
|
| 33 |
+
if str(path) in EXEMPT_ROUTES:
|
| 34 |
+
return False
|
| 35 |
+
return True
|
| 36 |
+
|
| 37 |
+
async def _validate_api_key(self, api_key: str):
|
| 38 |
+
if not api_key:
|
| 39 |
+
raise HTTPException(status_code=401, detail="No API key provided")
|
| 40 |
+
if not re.match(r"Bearer .+", api_key):
|
| 41 |
+
raise HTTPException(status_code=401, detail="Invalid API key")
|
| 42 |
+
|
| 43 |
+
token = api_key.split(" ")[1]
|
| 44 |
+
|
| 45 |
+
is_expired = await self._session_service.is_session_expired(token)
|
| 46 |
+
if is_expired:
|
| 47 |
+
raise HTTPException(status_code=401, detail="Token expired")
|
| 48 |
+
|
| 49 |
+
jwt_payload = self._jwt.validate_jwt(token)
|
| 50 |
+
if not jwt_payload:
|
| 51 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
| 52 |
+
return jwt_payload
|