|
|
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException, status, Depends, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from supabase import create_client, Client

from app.config.settings import settings
from app.db.models import User, UserSession
|
|
|
|
|
|
|
|
class AuthMiddleware:
    """JWT authentication middleware with Supabase integration.

    Issues access/refresh JWTs signed with ``settings.JWT_SECRET_KEY`` using
    ``settings.JWT_ALGORITHM``, verifies incoming bearer tokens, and loads
    the matching ``User`` row from the Supabase ``users`` table.
    """

    def __init__(self):
        self.security = HTTPBearer()
        # The Supabase client is only created when both settings are present;
        # without it ``get_user_from_supabase`` always returns None, so every
        # ``get_current_user`` call ends in a 401.
        self.supabase: Optional[Client] = None
        if settings.SUPABASE_URL and settings.SUPABASE_SERVICE_KEY:
            self.supabase = create_client(
                settings.SUPABASE_URL,
                settings.SUPABASE_SERVICE_KEY,
            )

    @staticmethod
    def _unauthorized(detail: str = "Could not validate credentials") -> HTTPException:
        """Build the 401 exception (with a Bearer challenge header) used for auth failures."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _create_token(
        self,
        data: dict,
        token_type: str,
        default_lifetime: timedelta,
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """Encode ``data`` as a signed JWT carrying ``exp`` and ``type`` claims.

        Args:
            data: Claims to include. It is copied, never mutated; any ``exp``
                or ``type`` keys in it are overwritten.
            token_type: Value stored in the ``type`` claim.
            default_lifetime: Lifetime used when ``expires_delta`` is None.
            expires_delta: Explicit lifetime overriding ``default_lifetime``.

        Returns:
            The encoded JWT string.
        """
        to_encode = data.copy()
        # Compare against None: ``timedelta(0)`` is falsy and would otherwise
        # be silently replaced by the default lifetime.
        lifetime = default_lifetime if expires_delta is None else expires_delta
        # Timezone-aware UTC; ``datetime.utcnow()`` is deprecated since 3.12.
        expire = datetime.now(timezone.utc) + lifetime
        to_encode.update({"exp": expire, "type": token_type})
        return jwt.encode(to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT access token.

        Args:
            data: Claims to encode (``get_current_user`` reads ``sub``).
            expires_delta: Token lifetime; defaults to
                ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

        Returns:
            The encoded JWT, with ``"type": "access"``.
        """
        return self._create_token(
            data,
            "access",
            timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES),
            expires_delta,
        )

    def create_refresh_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a JWT refresh token.

        Args:
            data: Claims to encode.
            expires_delta: Token lifetime; defaults to
                ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS`` days.

        Returns:
            The encoded JWT, with ``"type": "refresh"``.
        """
        return self._create_token(
            data,
            "refresh",
            timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
            expires_delta,
        )

    def verify_token(self, token: str) -> dict:
        """Decode and verify a JWT, returning its claims.

        Only ``settings.JWT_ALGORITHM`` is accepted. The ``type`` claim is
        NOT checked here; callers that need a specific token type must check it.

        Raises:
            HTTPException: 401 when ``jwt.decode`` raises ``JWTError``.
        """
        try:
            return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
        except JWTError as err:
            raise self._unauthorized() from err

    async def get_user_from_supabase(self, user_id: str) -> Optional[User]:
        """Load a user from the Supabase ``users`` table by ``id``.

        Returns:
            The first matching row as a ``User``. Returns None when no client
            is configured, no row matches, or the lookup/model construction
            fails (failures are logged).
        """
        if not self.supabase:
            return None

        client = self.supabase
        try:
            # The Supabase call is synchronous (``execute()`` is not awaited);
            # run it in a worker thread so it does not block the event loop.
            response = await run_in_threadpool(
                lambda: client.table("users").select("*").eq("id", user_id).execute()
            )
            if not response.data:
                return None
            return User(**response.data[0])
        except Exception:
            # Deliberately best-effort (callers treat None as "user not
            # found"), but log so outages and schema drift are visible.
            logging.getLogger(__name__).exception(
                "Failed to load user %s from Supabase", user_id
            )
            return None

    async def get_current_user(self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> User:
        """Resolve the authenticated, active user from a bearer token.

        Raises:
            HTTPException: 401 if the token is invalid, is a refresh token,
                has no ``sub`` claim, or the user cannot be loaded;
                400 if the user is inactive.
        """
        payload = self.verify_token(credentials.credentials)

        # Refresh tokens are signed with the same key; reject them so a
        # long-lived refresh token cannot be used as an access token.
        # Tokens without a "type" claim are still accepted.
        if payload.get("type") == "refresh":
            raise self._unauthorized()

        user_id = payload.get("sub")
        if user_id is None:
            raise self._unauthorized()

        user = await self.get_user_from_supabase(user_id)
        if user is None:
            raise self._unauthorized("User not found")

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Inactive user",
            )

        return user

    @staticmethod
    async def _resolve_current_user(
        credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
    ) -> User:
        """FastAPI dependency resolving the user via the shared ``auth_middleware``.

        A zero-argument lambda returning ``auth_middleware.get_current_user``
        would inject the bound method itself rather than a ``User``; this
        coroutine lets FastAPI resolve the credentials and await the lookup.
        """
        return await auth_middleware.get_current_user(credentials)

    # ``__func__`` unwraps the staticmethod so FastAPI receives a plain
    # coroutine function as the dependency.
    async def get_current_active_user(
        self,
        current_user: User = Depends(_resolve_current_user.__func__),
    ) -> User:
        """Return ``current_user`` if it is active.

        Raises:
            HTTPException: 400 if ``current_user.is_active`` is falsy.
        """
        if not current_user.is_active:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
        return current_user

    async def get_admin_user(
        self,
        current_user: User = Depends(_resolve_current_user.__func__),
    ) -> User:
        """Return ``current_user`` if its role is ``"admin"``.

        Raises:
            HTTPException: 403 for any other role.
        """
        if current_user.role != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Not enough permissions",
            )
        return current_user
|
|
|
|
|
|
|
|
|
|
|
# Module-wide instance created at import time (its __init__ builds the
# Supabase client when configured); the dependency functions delegate to it.
auth_middleware: AuthMiddleware = AuthMiddleware()
|
|
|
|
|
|
|
|
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> User:
    """FastAPI dependency returning the authenticated user for this request.

    Delegates to ``auth_middleware.get_current_user`` and propagates any
    ``HTTPException`` it raises.
    """
    middleware = auth_middleware
    resolved_user = await middleware.get_current_user(credentials)
    return resolved_user
|
|
|
|
|
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """FastAPI dependency returning the current user after an active check.

    The user is resolved by the module-level ``get_current_user`` dependency
    and then passed through ``auth_middleware.get_current_active_user``.
    """
    active_user = await auth_middleware.get_current_active_user(current_user)
    return active_user
|
|
|
|
|
async def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
    """FastAPI dependency returning the current user only if they are an admin.

    The user is resolved by the module-level ``get_current_user`` dependency
    and then checked by ``auth_middleware.get_admin_user``.
    """
    admin_user = await auth_middleware.get_admin_user(current_user)
    return admin_user
|
|
|
|
|
def get_user_ip(request: Request) -> str:
    """Return a best-effort client IP address for ``request``.

    If an ``X-Forwarded-For`` header is present and its first (left-most)
    comma-separated entry is non-empty after stripping whitespace, that entry
    is returned. Otherwise the directly connected peer's host is returned, or
    ``"unknown"`` when ``request.client`` is not available.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    reverse proxy overwrites it, so the returned value can be spoofed. Do not
    use it for access control without confirming the deployment's proxy setup.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",", 1)[0].strip()
        # A malformed header such as ", 10.0.0.1" yields an empty first entry;
        # fall back to the peer address instead of returning "".
        if first_hop:
            return first_hop
    return request.client.host if request.client else "unknown"
|
|
|