File size: 3,140 Bytes
302d8c7
 
 
 
aac8286
302d8c7
 
 
 
 
 
 
 
 
 
78c825a
302d8c7
 
605ed23
14bfd4c
 
605ed23
302d8c7
 
 
 
 
 
 
1fac072
 
 
302d8c7
 
14bfd4c
 
 
302d8c7
 
aac8286
 
a3aa6c1
aac8286
a3aa6c1
aac8286
 
 
 
 
a3aa6c1
aac8286
862aa9a
aac8286
a3aa6c1
aac8286
a3aa6c1
aac8286
 
 
 
 
a3aa6c1
aac8286
a3aa6c1
302d8c7
 
 
14bfd4c
 
 
302d8c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14bfd4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
import re
from typing import Optional

from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from src.utils import JWTUtil
from src.services import SessionService


# Request paths that bypass authentication entirely. AuthenticationMiddleware
# checks `path in EXEMPT_ROUTES`, i.e. exact string equality: "/docs/" or any
# sub-path of "/docs" is NOT exempt unless listed here.
EXEMPT_ROUTES = [
    "/",
    "/docs",
    "/openapi.json",  # fetched by the /docs UI, so it must be public too
    "/api/v1/auth/login",
    "/api/v1/auth/logout",
    "/api/v1/auth/register",
]

# Regex patterns for paths that bypass authentication. They are applied with
# re.match, which anchors only at the START of the path. A pattern without a
# trailing "$" therefore exempts every path that merely begins with it.
EXEMPT_ROUTE_PATTERNS = []


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Reject unauthenticated requests to every non-exempt path.

    Some requests are passed through untouched:

    * CORS preflight (``OPTIONS``) requests.
    * Requests whose path is listed in ``EXEMPT_ROUTES``.
    * Requests whose path matches one of ``EXEMPT_ROUTE_PATTERNS``.

    Every other request must carry an ``Authorization: Bearer <token>``
    header. The session for that token must not be expired according to
    ``SessionService``, and the token must pass ``JWTUtil.validate_jwt``.
    On success the value returned by ``validate_jwt`` is stored on
    ``request.state.user`` for downstream handlers.

    Failures short-circuit with a JSON ``{"detail": ...}`` response that sets
    its own CORS headers. Presumably ``CORSMiddleware`` is not registered
    outside this middleware, so nothing else would add them. Confirm this
    in the app setup.
    """

    # Used to record unexpected (non-HTTPException) authentication failures.
    _logger = logging.getLogger(__name__)

    # "Bearer <token>": the scheme is case-insensitive (RFC 7235 section 2.1).
    # The single capture group guarantees exactly one non-empty,
    # whitespace-free token, so malformed headers are rejected instead of
    # yielding an empty or truncated token.
    _BEARER_PATTERN = re.compile(r"Bearer\s+(\S+)\s*", re.IGNORECASE)

    def __init__(self, app):
        super().__init__(app)
        self._jwt = JWTUtil()
        self._session_service = SessionService()

    async def dispatch(self, request: Request, call_next):
        """Authenticate *request* if its path requires it, then forward it.

        Args:
            request: The incoming request.
            call_next: Starlette callable that runs the rest of the app.

        Returns:
            The downstream response. If authentication fails, returns a 401
            JSON error response. If validation raises anything other than
            ``HTTPException``, returns a 500 JSON error response.
        """
        # Browsers send CORS preflight requests without credentials, so they
        # must never be blocked here.
        if request.method == "OPTIONS" or not self._require_auth(
            request.url.path
        ):
            return await call_next(request)

        try:
            user_payload = await self._validate_api_key(
                request.headers.get("Authorization")
            )
        except HTTPException as e:
            return self._error_response(request, e.status_code, e.detail)
        except Exception:
            # Boundary handler: hide internals from the client, but keep the
            # traceback so the failure is diagnosable server-side.
            self._logger.exception(
                "Unexpected error while authenticating request to %s",
                request.url.path,
            )
            return self._error_response(request, 500, "Internal server error")

        request.state.user = user_payload
        return await call_next(request)

    @staticmethod
    def _error_response(
        request: Request, status_code: int, detail
    ) -> JSONResponse:
        """Build a ``{"detail": ...}`` JSON error response with CORS headers.

        The request's ``Origin`` header is echoed back, or ``"*"`` if the
        header is absent, so a cross-origin browser client can read the error.

        NOTE(review): reflecting any Origin together with
        ``Access-Control-Allow-Credentials: true`` lets every site read these
        error bodies. Consider registering ``CORSMiddleware`` outside this
        middleware with an explicit origin allow-list instead.
        """
        response = JSONResponse(
            status_code=status_code, content={"detail": detail}
        )

        origin = request.headers.get("origin", "*")
        response.headers["Access-Control-Allow-Origin"] = origin
        response.headers["Access-Control-Allow-Credentials"] = "true"
        response.headers["Access-Control-Allow-Methods"] = "*"
        response.headers["Access-Control-Allow-Headers"] = "*"
        # The allowed origin depends on the request's Origin header. Caches
        # must key on it so they never serve one origin's response to another.
        response.headers["Vary"] = "Origin"

        return response

    def _require_auth(self, path: str) -> bool:
        """Return ``False`` if *path* is exempt from authentication, else ``True``.

        A path is exempt if it equals an entry of ``EXEMPT_ROUTES``, or if
        ``re.match`` (anchored at the start of the path only) succeeds for any
        pattern in ``EXEMPT_ROUTE_PATTERNS``.
        """
        if path in EXEMPT_ROUTES or any(
            re.match(pattern, path) for pattern in EXEMPT_ROUTE_PATTERNS
        ):
            return False
        return True

    async def _validate_api_key(self, api_key: Optional[str]):
        """Validate an ``Authorization`` header value and return its JWT payload.

        Args:
            api_key: Raw ``Authorization`` header value, or ``None`` when the
                header is absent.

        Returns:
            Whatever ``JWTUtil.validate_jwt`` returns for the token. This is
            presumably the decoded claims; confirm in ``src.utils``.

        Raises:
            HTTPException: 401 in any of these cases:
                * the header is missing or is not ``Bearer <token>``;
                * the session for the token is expired;
                * the JWT fails validation.
        """
        if not api_key:
            raise HTTPException(status_code=401, detail="No API key provided")

        match = self._BEARER_PATTERN.fullmatch(api_key)
        if match is None:
            raise HTTPException(status_code=401, detail="Invalid API key")
        token = match.group(1)

        is_expired = await self._session_service.is_session_expired(token)
        if is_expired:
            raise HTTPException(status_code=401, detail="Token expired")

        jwt_payload = self._jwt.validate_jwt(token)
        if not jwt_payload:
            raise HTTPException(status_code=401, detail="Invalid token")
        return jwt_payload