| """CSRF Protection Middleware.""" |
| import secrets |
| from fastapi import Request |
| from starlette.middleware.base import BaseHTTPMiddleware |
| from starlette.responses import JSONResponse |
|
|
| from app.core.config import settings |
|
|
|
|
class CSRFMiddleware(BaseHTTPMiddleware):
    """Double-submit-cookie CSRF protection.

    Safe requests (GET/HEAD/OPTIONS) are passed through unchecked; if the
    client has no (or an empty) token cookie, a freshly generated random token
    is set on the response. Every other request must echo the cookie's value
    in the CSRF header; when either value is missing or they differ, a 403
    JSON response is returned instead of calling the application.

    Args:
        app: The ASGI application to wrap.
        exempt_paths: URL path prefixes that skip CSRF validation. Matching is
            a plain ``str.startswith`` test, so ``"/hooks"`` also exempts
            ``"/hooks-admin"``; add a trailing slash to restrict a prefix.
        cookie_name: Name of the cookie carrying the token.
        header_name: Name of the request header that must echo the token.
        trusted_user_agents: Exact ``User-Agent`` values that bypass
            validation. Defaults to ``("node",)`` to preserve the original
            behaviour; pass ``()`` to disable the bypass entirely.
    """

    # Methods treated as side-effect free: never validated, and the only
    # responses on which a token cookie is issued.
    SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})

    def __init__(
        self,
        app,
        exempt_paths: list[str] | None = None,
        cookie_name: str = "csrf_token",
        *,
        header_name: str = "X-CSRF-Token",
        trusted_user_agents: tuple[str, ...] = ("node",),
    ):
        super().__init__(app)
        self.exempt_paths = exempt_paths or []
        self.cookie_name = cookie_name
        self.header_name = header_name
        self.trusted_user_agents = tuple(trusted_user_agents)

    async def dispatch(self, request: Request, call_next):
        """Issue the token on safe requests; validate it on all others."""
        if request.method in self.SAFE_METHODS:
            response = await call_next(request)
            # An empty cookie value would otherwise never be replaced, making
            # every later unsafe request fail with CSRF_MISSING.
            if not request.cookies.get(self.cookie_name):
                self._issue_token_cookie(response)
            return response

        if self._is_exempt(request):
            return await call_next(request)

        rejection = self._check_tokens(request)
        if rejection is not None:
            return rejection

        return await call_next(request)

    def _issue_token_cookie(self, response) -> None:
        """Attach a new random CSRF token cookie to *response*."""
        response.set_cookie(
            self.cookie_name,
            secrets.token_urlsafe(32),
            # Must stay readable by client-side JavaScript so it can copy the
            # value into the request header (double-submit pattern).
            httponly=False,
            secure=settings.ENV == "production",
            samesite="lax",
            max_age=86400,
            path="/",
        )

    def _is_exempt(self, request: Request) -> bool:
        """Return True if *request* skips validation (exempt path or trusted UA)."""
        path = request.url.path
        if any(path.startswith(prefix) for prefix in self.exempt_paths):
            return True
        # SECURITY NOTE(review): User-Agent is supplied by the client, so any
        # caller sending an exact trusted value skips CSRF validation.
        # Presumably this exists for a server-side Node client — confirm it is
        # still required, and pass trusted_user_agents=() if not.
        return request.headers.get("user-agent", "") in self.trusted_user_agents

    def _check_tokens(self, request: Request) -> JSONResponse | None:
        """Return a 403 response if the tokens are absent or differ, else None."""
        cookie_token = request.cookies.get(self.cookie_name)
        header_token = request.headers.get(self.header_name)

        if not cookie_token or not header_token:
            return JSONResponse(
                status_code=403,
                content={"error": "CSRF_MISSING", "detail": "CSRF token missing"},
            )

        # compare_digest raises TypeError for str arguments that contain
        # non-ASCII characters, which client-supplied header/cookie values can
        # (headers are latin-1 decoded). Comparing bytes keeps the
        # constant-time check and turns such input into an ordinary mismatch
        # (403) instead of an unhandled error (500).
        if not secrets.compare_digest(
            cookie_token.encode("utf-8"), header_token.encode("utf-8")
        ):
            return JSONResponse(
                status_code=403,
                content={"error": "CSRF_INVALID", "detail": "CSRF token invalid"},
            )

        return None