File size: 2,979 Bytes
84c328d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""JWT middleware for FastAPI.

[Task]: T012
[From]: specs/001-user-auth/quickstart.md
"""
import logging
from typing import Callable, Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from core.security import JWTManager


class JWTMiddleware(BaseHTTPMiddleware):
    """JWT authentication middleware.

    Validates JWT tokens for all requests except public paths.
    Adds user_id to request.state for downstream dependency injection.

    Because this is a Starlette ``BaseHTTPMiddleware``, it runs outside the
    layer where the app's exception handlers are applied: an
    ``HTTPException`` raised from ``dispatch`` would not become a proper
    error response. Authentication failures are therefore returned as
    explicit ``JSONResponse`` objects rather than raised.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, app, excluded_paths: Optional[list[str]] = None):
        """Initialize JWT middleware.

        Args:
            app: Application wrapped by this middleware.
            excluded_paths: Extra request paths (exact match) that skip JWT
                validation, in addition to the built-in public paths.
        """
        super().__init__(app)
        self.excluded_paths = excluded_paths or []
        # Exact-match paths (no prefix matching) that never require a token.
        self.public_paths = [
            "/",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/health",
        ] + self.excluded_paths

    async def dispatch(self, request: Request, call_next: Callable):
        """Process each request with JWT validation.

        Public paths are forwarded untouched. For every other path a token is
        taken from the ``Authorization`` header or the ``auth_token`` cookie,
        validated, and the resulting user id stored on
        ``request.state.user_id`` before the request is forwarded.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or route handler.

        Returns:
            The downstream response on success. Otherwise one of:
            - 401 JSON response when no token is present;
            - the status/detail/headers of any ``HTTPException`` raised by
              ``JWTManager.get_user_id_from_token``;
            - 500 JSON response if token validation fails unexpectedly.
        """
        # Skip JWT validation for public paths
        if request.url.path in self.public_paths:
            return await call_next(request)

        token = self._extract_token(request)
        if not token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Not authenticated"},
                headers={"WWW-Authenticate": "Bearer"},
            )

        try:
            # Verify token and extract user_id
            user_id = JWTManager.get_user_id_from_token(token)
        except HTTPException as exc:
            # Re-raising here would surface as a 500 (middleware sits outside
            # the app's exception handlers), so build the response directly,
            # mirroring FastAPI's default HTTPException handler.
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )
        except Exception:
            self._logger.exception("Unexpected error while validating JWT")
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Internal server error during authentication"},
            )

        # Add user_id to request state for route handlers
        request.state.user_id = user_id

        # Kept outside the try above so errors raised by downstream handlers
        # are not misreported as authentication failures.
        return await call_next(request)

    @staticmethod
    def _extract_token(request: Request) -> Optional[str]:
        """Return the token from the Authorization header, else the ``auth_token`` cookie, else None."""
        authorization = request.headers.get("Authorization")
        if authorization:
            try:
                token = JWTManager.get_token_from_header(authorization)
            except Exception:
                # Deliberate best-effort: a malformed header falls back to the cookie.
                token = None
            if token:
                return token

        # httpOnly cookie fallback; an empty cookie value counts as absent.
        return request.cookies.get("auth_token") or None