Spaces:
Running
Running
| import os | |
| from typing import Optional | |
| from fastapi import HTTPException, Request | |
| from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.responses import JSONResponse | |
class PasswordAuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware to check password authentication for all API requests.

    Only active when the OPEN_NOTEBOOK_PASSWORD environment variable is set;
    the variable is read once, when the middleware is constructed. When
    active, a request must carry an ``Authorization: Bearer <password>``
    header unless its path is excluded or it is a CORS preflight (OPTIONS)
    request. Rejected requests get a 401 JSON response with a
    ``WWW-Authenticate: Bearer`` header.

    Args:
        app: The application wrapped by this middleware.
        excluded_paths: Exact request paths that bypass authentication.
            ``None`` selects the default list; an explicit empty list means
            "exclude nothing".
    """

    _DEFAULT_EXCLUDED_PATHS = ("/", "/health", "/docs", "/openapi.json", "/redoc")

    def __init__(self, app, excluded_paths: Optional[list] = None):
        super().__init__(app)
        self.password = os.environ.get("OPEN_NOTEBOOK_PASSWORD")
        # Compare against None (not truthiness) so that a caller passing []
        # really gets "no exclusions" instead of silently getting the defaults.
        if excluded_paths is None:
            excluded_paths = list(self._DEFAULT_EXCLUDED_PATHS)
        self.excluded_paths = excluded_paths

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build the 401 response shared by every rejection path."""
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    @staticmethod
    def _passwords_match(supplied: str, expected: str) -> bool:
        """Return True when both strings are equal, using a constant-time compare."""
        # Local import keeps this block self-contained; hmac is stdlib.
        import hmac

        # compare_digest rejects non-ASCII str arguments, so compare the
        # encoded bytes. "surrogatepass" keeps the encoding lossless (and
        # non-raising) for environment values decoded with surrogateescape.
        return hmac.compare_digest(
            supplied.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    async def dispatch(self, request: Request, call_next):
        """
        Authenticate the request, then forward it or answer with 401.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request to the
                rest of the application and returns its response.

        Returns:
            The downstream response when authentication is skipped or
            succeeds, otherwise a 401 ``JSONResponse``.
        """
        # Skip authentication if no password is set
        if not self.password:
            return await call_next(request)

        # Skip authentication for excluded paths
        if request.url.path in self.excluded_paths:
            return await call_next(request)

        # Skip authentication for CORS preflight requests (OPTIONS)
        if request.method == "OPTIONS":
            return await call_next(request)

        # Check authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._unauthorized("Missing authorization header")

        # Expected format: "Bearer {password}". An empty separator means the
        # header contained no space, i.e. no credentials part at all.
        scheme, separator, credentials = auth_header.partition(" ")
        if not separator or scheme.lower() != "bearer":
            return self._unauthorized("Invalid authorization header format")

        # Check password without leaking timing information about the secret
        if not self._passwords_match(credentials, self.password):
            return self._unauthorized("Invalid password")

        # Password is correct, proceed with the request
        return await call_next(request)
| # Optional: module-level HTTPBearer security scheme, kept for OpenAPI documentation. | |
| # NOTE(review): auto_error=False presumably makes a missing Authorization header resolve to None instead of an automatic error response; confirm against the FastAPI docs. | |
| security = HTTPBearer(auto_error=False) | |
def check_api_password(credentials: Optional[HTTPAuthorizationCredentials] = None) -> bool:
    """
    Utility function to check the API password.

    Can be used as a dependency in individual routes if needed. The expected
    password is read from the OPEN_NOTEBOOK_PASSWORD environment variable on
    every call.

    Args:
        credentials: Parsed bearer credentials; its ``credentials`` attribute
            holds the password supplied by the client. ``None`` means no
            credentials were provided.

    Returns:
        True when no password is configured or the supplied password matches.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when a
            password is configured and the credentials are missing or wrong.
    """
    # Local import keeps this block self-contained; hmac is stdlib.
    import hmac

    password = os.environ.get("OPEN_NOTEBOOK_PASSWORD")

    # No password set, allow access
    if not password:
        return True

    # No credentials provided
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Check password in constant time so response timing does not reveal how
    # much of the secret matched. compare_digest rejects non-ASCII str, so
    # compare encoded bytes; a non-string credential can never match.
    supplied = credentials.credentials
    matches = isinstance(supplied, str) and hmac.compare_digest(
        supplied.encode("utf-8", "surrogatepass"),
        password.encode("utf-8", "surrogatepass"),
    )
    if not matches:
        raise HTTPException(
            status_code=401,
            detail="Invalid password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return True