SanadLLM / app /middleware /__init__.py
Hydra-Bolt
fixed
92daf66
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import HTTPException, status, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from supabase import create_client, Client

from app.config.settings import settings
from app.db.models import User, UserSession
logger = logging.getLogger(__name__)


async def _current_user_dependency(
    credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
) -> User:
    """Resolve the authenticated user through the module-level ``auth_middleware``.

    The class's own methods use this as their dependency so that FastAPI
    injects a ``User`` object. ``auth_middleware`` is looked up at call time
    because the instance is created only after the class body has run.
    """
    return await auth_middleware.get_current_user(credentials)


class AuthMiddleware:
    """JWT authentication helpers with Supabase-backed user lookup."""

    def __init__(self):
        """Set up the bearer scheme and, when configured, the Supabase client."""
        self.security = HTTPBearer()
        # Both settings must be present; otherwise no client is created and
        # get_user_from_supabase() always returns None.
        self.supabase: Optional[Client] = (
            create_client(settings.SUPABASE_URL, settings.SUPABASE_SERVICE_KEY)
            if settings.SUPABASE_URL and settings.SUPABASE_SERVICE_KEY
            else None
        )

    @staticmethod
    def _encode_token(data: dict, token_type: str, lifetime: timedelta) -> str:
        """Sign a copy of *data* with ``exp`` and ``type`` claims added."""
        claims = data.copy()
        # Timezone-aware "now": datetime.utcnow() is deprecated and returns a
        # naive value that is easy to misinterpret as local time.
        claims.update({"exp": datetime.now(timezone.utc) + lifetime, "type": token_type})
        return jwt.encode(claims, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)

    @staticmethod
    def _unauthorized(detail: str) -> HTTPException:
        """Build a 401 response carrying the ``WWW-Authenticate: Bearer`` header."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    def create_access_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Args:
            data: Claims to embed; the dict is copied, not mutated.
            expires_delta: Token lifetime. When omitted (or falsy) the
                lifetime is ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``.

        Returns:
            The encoded token, carrying ``exp`` and ``type="access"``.
        """
        lifetime = expires_delta or timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
        return self._encode_token(data, "access", lifetime)

    def create_refresh_token(self, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT refresh token.

        Args:
            data: Claims to embed; the dict is copied, not mutated.
            expires_delta: Token lifetime. When omitted (or falsy) the
                lifetime is ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS``.

        Returns:
            The encoded token, carrying ``exp`` and ``type="refresh"``.
        """
        lifetime = expires_delta or timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
        return self._encode_token(data, "refresh", lifetime)

    def verify_token(self, token: str) -> dict:
        """Decode *token* and return its payload.

        Raises:
            HTTPException: 401 when decoding raises ``JWTError``.
        """
        try:
            return jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
        except JWTError as exc:
            raise self._unauthorized("Could not validate credentials") from exc

    async def get_user_from_supabase(self, user_id: str) -> Optional[User]:
        """Fetch the ``users`` row whose ``id`` equals *user_id*.

        Returns:
            A ``User`` built from the first matching row, or ``None`` when no
            client is configured, no row matches, or the lookup fails.
        """
        if not self.supabase:
            return None
        try:
            response = self.supabase.table("users").select("*").eq("id", user_id).execute()
            if response.data:
                return User(**response.data[0])
            return None
        except Exception:
            # Still best-effort (callers get None), but leave a trace so a
            # backend failure is distinguishable from a missing user.
            logger.exception("Failed to load user %s from Supabase", user_id)
            return None

    async def get_current_user(self, credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> User:
        """Return the active user identified by the bearer token.

        Raises:
            HTTPException: 401 when the token is invalid, is a refresh token,
                has no ``sub`` claim, or the user cannot be found; 400 when
                the user is inactive.
        """
        payload = self.verify_token(credentials.credentials)

        # Refresh tokens are long-lived and must not double as access
        # credentials. Tokens without a "type" claim are still accepted.
        if payload.get("type") == "refresh":
            raise self._unauthorized("Could not validate credentials")

        user_id = payload.get("sub")
        if user_id is None:
            raise self._unauthorized("Could not validate credentials")

        # Get user from Supabase
        user = await self.get_user_from_supabase(user_id)
        if user is None:
            raise self._unauthorized("User not found")

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Inactive user"
            )
        return user

    # The dependency below must yield a User. The previous
    # ``Depends(lambda: auth_middleware.get_current_user)`` injected the bound
    # method object itself, so attribute access on ``current_user`` failed.
    async def get_current_active_user(self, current_user: User = Depends(_current_user_dependency)) -> User:
        """Return *current_user* if it is active.

        Raises:
            HTTPException: 400 when the user is inactive.
        """
        if not current_user.is_active:
            raise HTTPException(status_code=400, detail="Inactive user")
        return current_user

    async def get_admin_user(self, current_user: User = Depends(_current_user_dependency)) -> User:
        """Return *current_user* if its role is ``"admin"``.

        Raises:
            HTTPException: 403 for any other role.
        """
        if current_user.role != "admin":
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Not enough permissions"
            )
        return current_user
# Create global instance
# Shared by the module-level dependency functions in this file. Constructing
# it here runs AuthMiddleware.__init__ at import time, which creates the
# Supabase client when both Supabase settings are present.
auth_middleware = AuthMiddleware()
# Dependency functions for FastAPI
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> User:
    """Resolve the request's bearer credentials to a user.

    Delegates to the shared ``auth_middleware`` instance, which raises the
    HTTP errors for rejected credentials.
    """
    authenticated = await auth_middleware.get_current_user(credentials)
    return authenticated
async def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Yield the already-authenticated user after the active-status check.

    Delegates to the shared ``auth_middleware`` instance.
    """
    active_user = await auth_middleware.get_current_active_user(current_user)
    return active_user
async def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
    """Yield the already-authenticated user after the admin-role check.

    Delegates to the shared ``auth_middleware`` instance.
    """
    admin = await auth_middleware.get_admin_user(current_user)
    return admin
def get_user_ip(request: Request) -> str:
    """Extract the client IP address from *request*.

    The leftmost entry of the ``X-Forwarded-For`` header is used when it is
    non-empty. Otherwise the address of ``request.client`` is returned, or
    ``"unknown"`` when the request has no client information.

    NOTE(review): ``X-Forwarded-For`` is supplied by the caller unless a
    trusted proxy overwrites it. Confirm the deployment before relying on
    this value for security decisions.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        # A blank or malformed header (e.g. "  " or ", 10.0.0.1") is truthy
        # but has an empty first entry; fall through instead of returning "".
        if first_hop:
            return first_hop
    return request.client.host if request.client else "unknown"