"""FastAPI dependencies for resolving the authenticated user of a request."""

from functools import lru_cache
from typing import Annotated, Optional

import redis.asyncio as redis
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from src.api.dependencies.clients import get_db, get_redis
from src.api.models.user import User
from src.api.selectors.user.get_user import get_user_by_email
from src.api.services.auth_service import AuthService


@lru_cache()
def get_auth_service() -> AuthService:
    """Return the ``AuthService`` instance used by the auth dependencies.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    service is constructed on the first call and the same object is returned
    on every later call.
    """
    return AuthService()


# auto_error=False: a request without a bearer token yields ``None`` from the
# scheme instead of an immediate error response, so the dependencies below
# decide whether authentication is mandatory.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/token",
    auto_error=False,
)


async def get_current_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)],
    db: Annotated[AsyncSession, Depends(get_db)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    redis_client: Annotated[redis.Redis, Depends(get_redis)],
) -> Optional[User]:
    """Resolve the user identified by the request's bearer token, if any.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``; ``None`` (or
            empty) when the request carries no token.
        db: Database session passed through to ``get_user_by_email``.
        auth_service: Service used to decode the access token.
        redis_client: Redis client queried for the token blacklist.

    Returns:
        The result of ``get_user_by_email`` for the token's ``sub`` claim, or
        ``None`` when the token is missing, blacklisted, cannot be decoded,
        or carries no ``sub`` claim.
    """
    if not token:
        return None

    # A truthy value stored under "blacklist:<token>" marks the token as
    # blacklisted; it is rejected before any decoding is attempted.
    is_blacklisted = await redis_client.get(f"blacklist:{token}")
    if is_blacklisted:
        return None

    # A falsy payload is treated as "token could not be decoded".
    payload = auth_service.decode_access_token(token)
    if not payload:
        return None

    # The "sub" claim is used as the user's e-mail address for the lookup.
    user_email = payload.get("sub")
    if user_email is None:
        return None

    return await get_user_by_email(user_email, db)


async def get_current_user_required(
    user: Annotated[Optional[User], Depends(get_current_user)],
) -> User:
    """Return the authenticated user or reject the request.

    Args:
        user: Result of ``get_current_user``.

    Returns:
        The resolved user.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when no
            user could be resolved.
    """
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user


async def get_current_user_optional(
    user: Annotated[Optional[User], Depends(get_current_user)],
) -> Optional[User]:
    """Return the result of ``get_current_user`` unchanged.

    Unlike ``get_current_user_required``, this never raises: ``None`` is
    passed through when no user could be resolved.
    """
    return user