File size: 1,855 Bytes
edda8af
 
 
 
 
 
 
 
 
 
83caf02
c22ac93
edda8af
 
 
 
 
 
83caf02
edda8af
 
 
 
 
 
83caf02
c22ac93
edda8af
 
 
83caf02
 
 
edda8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from functools import lru_cache
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from src.api.services.auth_service import AuthService
from src.api.models.user import User
from src.api.dependencies.clients import get_db
from src.api.selectors.user.get_user import get_user_by_email
from typing import Optional, Annotated
import redis.asyncio as redis
from src.api.dependencies.clients import get_redis

@lru_cache()
def get_auth_service() -> AuthService:
    """Return the process-wide ``AuthService`` instance.

    The function takes no arguments, so ``lru_cache`` holds exactly one entry.
    Every call (and every FastAPI dependency resolution) reuses the same
    object instead of building a new service per request.
    """
    service = AuthService()
    return service

# Bearer-token extractor used by get_current_user.
# With auto_error=False, a missing Authorization header produces None instead
# of an immediate 401. The dependencies below decide how to treat anonymous
# requests.
oauth2_scheme = OAuth2PasswordBearer(
    auto_error=False,
    tokenUrl="/api/v1/auth/token",
)

async def get_current_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)],
    db: Annotated[AsyncSession, Depends(get_db)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    redis_client: Annotated[redis.Redis, Depends(get_redis)],
) -> Optional[User]:
    """Resolve the user identified by the request's bearer token, or ``None``.

    Each rejected step below returns ``None`` rather than raising. Routes that
    must reject anonymous callers should depend on
    ``get_current_user_required`` instead.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``. That scheme is
            created with ``auto_error=False``, so this is ``None`` when the
            request carries no credentials (hence ``Optional[str]``).
        db: Async database session used for the user lookup.
        auth_service: Service whose ``decode_access_token`` parses the token.
        redis_client: Redis client queried for the ``blacklist:<token>`` key.

    Returns:
        ``None`` in any of these cases:
        - the token is missing or empty;
        - the ``blacklist:<token>`` key holds a truthy value;
        - ``decode_access_token`` returns a falsy payload;
        - the payload has no ``sub`` claim.
        Otherwise, the result of ``get_user_by_email`` for the ``sub`` value
        (presumably ``None`` when no such user exists; confirm against the
        selector).
    """
    if not token:
        return None

    # A truthy value stored under this key marks the token as revoked.
    is_blacklisted = await redis_client.get(f"blacklist:{token}")
    if is_blacklisted:
        return None

    payload = auth_service.decode_access_token(token)
    if not payload:
        return None

    # The token subject ("sub") carries the user's email address.
    user_email = payload.get("sub")
    if user_email is None:
        return None

    user = await get_user_by_email(user_email, db)
    return user
    
async def get_current_user_required(
    user: Annotated[Optional[User], Depends(get_current_user)],
) -> User:
    """Return the authenticated user, or reject the request with 401.

    Args:
        user: Result of ``get_current_user``; may be ``None``.

    Returns:
        ``user`` unchanged when it is truthy.

    Raises:
        HTTPException: 401 Unauthorized with a ``WWW-Authenticate: Bearer``
            header when ``user`` is falsy.
    """
    if user:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Unauthorized",
        headers={"WWW-Authenticate": "Bearer"},
    )

async def get_current_user_optional(
    user: Annotated[Optional[User], Depends(get_current_user)],
) -> Optional[User]:
    """Provide the current user to routes where authentication is optional.

    A straight pass-through of ``get_current_user``'s result. The endpoint
    receives ``None`` whenever that dependency could not resolve a user.
    """
    return user