narinder1231 commited on
Commit
302d8c7
·
1 Parent(s): 855455a

Add authentication middleware and update auth controller for API key validation

Browse files
src/app.py CHANGED
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
6
 
7
  from src.controllers import api_router, websocket_router
8
  from src.config import logger, DatabaseConfig
 
9
 
10
  @asynccontextmanager
11
  async def lifespan(app: FastAPI):
@@ -37,6 +38,7 @@ app.add_middleware(
37
  allow_headers=["*"],
38
  )
39
 
 
40
 
41
  @app.get("/")
42
  async def check_health():
 
6
 
7
  from src.controllers import api_router, websocket_router
8
  from src.config import logger, DatabaseConfig
9
+ from src.middlewares import AuthenticationMiddleware
10
 
11
  @asynccontextmanager
12
  async def lifespan(app: FastAPI):
 
38
  allow_headers=["*"],
39
  )
40
 
41
+ app.add_middleware(AuthenticationMiddleware)
42
 
43
  @app.get("/")
44
  async def check_health():
src/controllers/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
- from fastapi import APIRouter
 
2
  from ._auth_controller import AuthController
3
  from ._ai_voice_controller import AIVoiceController
4
  from ._ai_text_controller import AITextController
@@ -12,10 +13,12 @@ _ai_voice_controller = AIVoiceController()
12
  _ai_text_controller = AITextController()
13
  _pinecone_controller = PineconeController()
14
 
 
 
15
  websocket_router.include_router(_ai_voice_controller.websocket_router, prefix="/v1")
16
  websocket_router.include_router(_ai_text_controller.websocket_router, prefix="/v1")
17
 
18
- api_router.include_router(_auth_controller.api_router, prefix="/v1")
19
  api_router.include_router(_ai_voice_controller.api_router, prefix="/v1")
20
  api_router.include_router(_ai_text_controller.api_router, prefix="/v1")
21
  api_router.include_router(_pinecone_controller.api_router, prefix="/v1")
 
1
+ from fastapi import APIRouter, Depends
2
+ from fastapi.security import APIKeyHeader
3
  from ._auth_controller import AuthController
4
  from ._ai_voice_controller import AIVoiceController
5
  from ._ai_text_controller import AITextController
 
13
  _ai_text_controller = AITextController()
14
  _pinecone_controller = PineconeController()
15
 
16
+ API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
17
+
18
  websocket_router.include_router(_ai_voice_controller.websocket_router, prefix="/v1")
19
  websocket_router.include_router(_ai_text_controller.websocket_router, prefix="/v1")
20
 
21
+ api_router.include_router(_auth_controller.api_router, prefix="/v1", dependencies=[Depends(API_KEY_HEADER)])
22
  api_router.include_router(_ai_voice_controller.api_router, prefix="/v1")
23
  api_router.include_router(_ai_text_controller.api_router, prefix="/v1")
24
  api_router.include_router(_pinecone_controller.api_router, prefix="/v1")
src/controllers/_auth_controller.py CHANGED
@@ -1,5 +1,4 @@
1
- from fastapi import APIRouter, HTTPException, Depends, status
2
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
3
  from src.services import AuthService
4
  from src.schemas import (
5
  UserSignInSchema,
@@ -51,9 +50,9 @@ class AuthController:
51
  logger.error(f"Error during signin: {e}")
52
  raise HTTPException(status_code=500, detail="Error during signin")
53
 
54
- async def _signout(self, token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
55
- try:
56
- token = token.credentials.split("Bearer ")[1]
57
  await self._auth_service.sign_out(token)
58
  return status.HTTP_200_OK
59
  except HTTPException as e:
 
1
+ from fastapi import APIRouter, HTTPException, status, Request
 
2
  from src.services import AuthService
3
  from src.schemas import (
4
  UserSignInSchema,
 
50
  logger.error(f"Error during signin: {e}")
51
  raise HTTPException(status_code=500, detail="Error during signin")
52
 
53
+ async def _signout(self, request: Request, ):
54
+ try:
55
+ token = request.headers.get("Authorization").split(" ")[1]
56
  await self._auth_service.sign_out(token)
57
  return status.HTTP_200_OK
58
  except HTTPException as e:
src/middlewares/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from ._authentication import AuthenticationMiddleware
2
+
3
+ __all__ = [
4
+ "AuthenticationMiddleware",
5
+ ]
6
+ __version__ = "0.1.0"
7
+ __author__ = "Narinder Singh"
src/middlewares/_authentication.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi import HTTPException, Request
4
+ from starlette.middleware.base import BaseHTTPMiddleware
5
+
6
+ from src.utils import JWTUtil
7
+ from src.services import SessionService
8
+
9
+
10
+ EXEMPT_ROUTES = [
11
+ "/",
12
+ "/docs",
13
+ "/openapi.json",
14
+ "/api/v1/auth/login",
15
+ "/api/v1/auth/register",
16
+ ]
17
+ class AuthenticationMiddleware(BaseHTTPMiddleware):
18
+ def __init__(self, app):
19
+ super().__init__(app)
20
+ self._jwt = JWTUtil()
21
+ self._session_service = SessionService()
22
+
23
+ async def dispatch(self, request: Request, call_next):
24
+ if self._require_auth(request.url.path):
25
+ try:
26
+ user_payload =await self._validate_api_key(request.headers.get("Authorization"))
27
+ request.state.user = user_payload
28
+ except HTTPException as e:
29
+ return JSONResponse(status_code=e.status_code, content=e.detail)
30
+ return await call_next(request)
31
+
32
+ def _require_auth(self, path: str):
33
+ if str(path) in EXEMPT_ROUTES:
34
+ return False
35
+ return True
36
+
37
+ async def _validate_api_key(self, api_key: str):
38
+ if not api_key:
39
+ raise HTTPException(status_code=401, detail="No API key provided")
40
+ if not re.match(r"Bearer .+", api_key):
41
+ raise HTTPException(status_code=401, detail="Invalid API key")
42
+
43
+ token = api_key.split(" ")[1]
44
+
45
+ is_expired = await self._session_service.is_session_expired(token)
46
+ if is_expired:
47
+ raise HTTPException(status_code=401, detail="Token expired")
48
+
49
+ jwt_payload = self._jwt.validate_jwt(token)
50
+ if not jwt_payload:
51
+ raise HTTPException(status_code=401, detail="Invalid token")
52
+ return jwt_payload