|
|
""" |
|
|
Authentication and authorization dependencies for Silver Table Assistant. |
|
|
Provides JWT verification and user role management using Supabase Auth. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import jwt |
|
|
from typing import Optional, Dict, Any, Callable |
|
|
from fastapi import Depends, HTTPException, status, Request |
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
|
from supabase import create_client, Client |
|
|
from sqlalchemy.ext.asyncio import AsyncSession |
|
|
|
|
|
from database import get_session |
|
|
from models import Profile |
|
|
|
|
|
|
|
|
|
|
|
# Supabase connection settings, read once when this module is imported.
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_SERVICE_ROLE_KEY = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")

# Fail fast at import time: every auth dependency below needs the client.
if not (SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY):
    raise ValueError("SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY environment variables are required")

# Shared Supabase client built from the service-role key; verify_jwt_token
# uses it to look up the user behind an access token.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)

# Bearer-token extractor (default HTTPBearer arguments) that feeds
# get_current_user's `credentials` parameter.
security = HTTPBearer()
|
|
|
|
|
|
|
|
class AuthenticationError(HTTPException):
    """
    HTTP 401 Unauthorized error for credentials that cannot be validated.

    Every instance carries a ``WWW-Authenticate: Bearer`` response header so
    clients know a bearer token is the expected credential.

    Args:
        detail: Human-readable reason returned in the response body.
    """

    def __init__(self, detail: str = "Could not validate credentials"):
        bearer_challenge = {"WWW-Authenticate": "Bearer"}
        super().__init__(
            headers=bearer_challenge,
            detail=detail,
            status_code=status.HTTP_401_UNAUTHORIZED,
        )
|
|
|
|
|
|
|
|
class AuthorizationError(HTTPException):
    """
    HTTP 403 Forbidden error for authenticated users lacking a required role.

    No ``WWW-Authenticate`` header is attached: the caller is known, just not
    permitted.

    Args:
        detail: Human-readable reason returned in the response body.
    """

    def __init__(self, detail: str = "Not enough permissions"):
        forbidden_code = status.HTTP_403_FORBIDDEN
        super().__init__(detail=detail, status_code=forbidden_code)
|
|
|
|
|
|
|
|
class User:
    """
    Lightweight record describing an authenticated caller.

    Attributes:
        user_id: Identifier of the user as reported by the auth provider.
        email: Email address, if known.
        role: Role name; any falsy value passed in becomes ``"family"``.
        metadata: Provider metadata dict; a falsy value becomes a fresh ``{}``.
        raw_user: Raw provider user payload; a falsy value becomes a fresh ``{}``.
    """

    def __init__(
        self,
        user_id: str,
        email: Optional[str] = None,
        role: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        raw_user: Optional[Dict[str, Any]] = None
    ):
        self.user_id = user_id
        self.email = email
        # Falsy inputs (None, "", {}) are replaced by defaults so callers can
        # always rely on a role string and dict-valued metadata/raw_user.
        self.role = role if role else "family"
        self.metadata = metadata if metadata else {}
        self.raw_user = raw_user if raw_user else {}
|
|
|
|
|
|
|
|
async def verify_jwt_token(token: str) -> Dict[str, Any]:
    """
    Verify an access token by asking Supabase Auth which user it belongs to.

    The token is not decoded locally; it is passed to
    ``supabase.auth.get_user`` and the returned user is flattened into a dict.

    Args:
        token: JWT taken from the ``Authorization: Bearer`` header.

    Returns:
        Dict with keys ``user_id``, ``email``, ``role`` (the ``role`` entry of
        the user's ``user_metadata``, defaulting to ``"user"``), ``metadata``
        (the ``user_metadata`` dict) and ``raw_user`` (the user object's
        ``__dict__``).

    Raises:
        AuthenticationError: If Supabase reports no user for the token, or if
            the lookup fails for any other reason.

    NOTE(security): Supabase documents ``user_metadata`` as editable by the
    end user, so deriving ``role`` from it may let users grant themselves
    privileged roles. Consider provisioning roles in ``app_metadata``
    instead — TODO confirm how roles are assigned before switching.
    NOTE(review): the default here is "user", while User() defaults to
    "family" — confirm which is intended.
    """
    import asyncio

    try:
        # The module-level client is synchronous; run the network round-trip
        # in a worker thread so it does not block the event loop.
        response = await asyncio.to_thread(supabase.auth.get_user, token)

        user = response.user
        if user is None:
            raise AuthenticationError("Invalid token")

        # Guard against a missing metadata payload so .get() cannot fail.
        user_metadata = user.user_metadata or {}

        return {
            "user_id": user.id,
            "email": user.email,
            "role": user_metadata.get("role", "user"),
            "metadata": user_metadata,
            "raw_user": user.__dict__
        }
    except AuthenticationError:
        # Our own 401s already carry the right detail; don't re-wrap them.
        raise
    except Exception as e:
        raise AuthenticationError(f"Token verification failed: {str(e)}") from e
|
|
|
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """
    Resolve the authenticated User behind the request's bearer token.

    Args:
        credentials: Scheme/token pair extracted from the ``Authorization``
            header by the module-level ``HTTPBearer`` instance.

    Returns:
        A User populated from the fields returned by verify_jwt_token.

    Raises:
        AuthenticationError: If no credentials are supplied, the token is
            rejected, or anything else fails while building the User.
    """
    if not credentials:
        # Presumably unreachable with HTTPBearer's default auto_error, which
        # rejects missing headers itself; kept as a defensive check.
        raise AuthenticationError("No credentials provided")

    try:
        claims = await verify_jwt_token(credentials.credentials)
        return User(
            user_id=claims["user_id"],
            email=claims["email"],
            role=claims["role"],
            metadata=claims["metadata"],
            raw_user=claims["raw_user"],
        )
    except AuthenticationError:
        # Already a well-formed 401 — propagate unchanged.
        raise
    except Exception as exc:
        raise AuthenticationError(f"Authentication failed: {exc!s}")
|
|
|
|
|
|
|
|
async def get_optional_user(
    request: Request
) -> Optional[User]:
    """
    Return the current user when a valid bearer token is present, else None.

    A missing ``Authorization`` header, a non-Bearer scheme, an empty token,
    or any ``Exception`` raised while verifying the token all yield ``None``
    instead of an error response.

    Args:
        request: FastAPI request object; only its ``Authorization`` header
            is read.

    Returns:
        User object for a valid token, otherwise None.
    """
    try:
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return None

        # The auth scheme is case-insensitive (RFC 7235). Parse it the same way
        # HTTPBearer does for get_current_user, so "bearer <tok>" is accepted here too.
        scheme, _, token = auth_header.partition(" ")
        if scheme.lower() != "bearer" or not token:
            # No token to verify — skip the round-trip to the auth server.
            return None

        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token)
        return await get_current_user(credentials)
    except Exception:
        # Deliberately best-effort: this dependency backs endpoints where
        # authentication is optional, so any failure means "anonymous".
        return None
|
|
|
|
|
|
|
|
def require_role(required_role: str) -> Callable:
    """
    Build a dependency that admits only users holding ``required_role``.

    Users whose role is ``"admin"`` are always admitted as well.

    Args:
        required_role: Role name the endpoint requires.

    Returns:
        An async dependency that yields the authenticated User, or raises
        AuthorizationError (HTTP 403) when the role does not match.
    """
    async def role_checker(user: User = Depends(get_current_user)) -> User:
        if user.role == required_role or user.role == "admin":
            return user
        raise AuthorizationError(
            f"Required role: {required_role}, your role: {user.role}"
        )

    return role_checker
|
|
|
|
|
|
|
|
def require_roles(allowed_roles: list[str]) -> Callable:
    """
    Build a dependency that admits users holding any of ``allowed_roles``.

    Users whose role is ``"admin"`` are always admitted. The list object is
    consulted on every request (it is not copied).

    Args:
        allowed_roles: Role names the endpoint accepts.

    Returns:
        An async dependency that yields the authenticated User, or raises
        AuthorizationError (HTTP 403) when the role is not permitted.
    """
    async def roles_checker(user: User = Depends(get_current_user)) -> User:
        current_role = user.role
        if current_role in allowed_roles or current_role == "admin":
            return user
        raise AuthorizationError(
            f"Required roles: {allowed_roles}, your role: {user.role}"
        )

    return roles_checker
|
|
|
|
|
|
|
|
async def get_user_profile(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_session)
) -> Optional[Profile]:
    """
    Fetch the first Profile row owned by the authenticated user (legacy/helper).

    No ordering is applied, so if a user owns several profiles, whichever row
    the query yields first is returned.

    Args:
        user: Authenticated user; ``user_id`` must parse as a UUID.
        db: Async database session.

    Returns:
        The first matching Profile, or None when the user has none.

    Raises:
        HTTPException: 500 if the id cannot be parsed or the query fails.
            NOTE(review): the detail echoes the underlying exception text
            to the client — consider logging it and returning a generic message.
    """
    try:
        from sqlmodel import select
        from uuid import UUID

        owner_id = UUID(user.user_id)
        statement = select(Profile).where(Profile.user_id == owner_id)
        rows = await db.execute(statement)
        return rows.scalars().first()
    except Exception as e:
        raise HTTPException(
            detail=f"Failed to fetch user profile: {str(e)}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
|
|
|
|
|
|
|
async def get_or_create_user_profile(
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_session)
) -> Profile:
    """
    Return the user's existing profile, creating a generic placeholder if needed.

    The placeholder uses fixed defaults:
    - age 70, gender "male", height 165.0, weight 60.0;
    - no chronic diseases or dietary restrictions;
    - "normal" chewing ability.

    Its name is the part of the email before the first "@", or "User" when
    the user has no email. The new row is committed and refreshed before
    being returned.

    NOTE(review): nothing here stops two concurrent first requests from both
    creating a profile — confirm whether a database constraint covers this.
    """
    existing = await get_user_profile(user, db)
    if existing is not None:
        return existing

    from uuid import UUID, uuid4

    display_name = user.email.partition("@")[0] if user.email else "User"

    placeholder = Profile(
        id=uuid4(),
        user_id=UUID(user.user_id),
        name=display_name,
        age=70,
        gender="male",
        height=165.0,
        weight=60.0,
        chronic_diseases=[],
        dietary_restrictions=[],
        chewing_ability="normal"
    )

    db.add(placeholder)
    await db.commit()
    # Reload server-side state (defaults, timestamps, etc.) into the instance.
    await db.refresh(placeholder)
    return placeholder
|
|
|
|
|
|
|
|
def get_user_metadata(user: User = Depends(get_current_user)) -> Dict[str, Any]:
    """
    Build a flat dict describing the user for AI prompts and recommendations.

    The identity fields ``user_id``, ``email`` and ``role`` come first.
    Every entry of ``user.metadata`` follows, except entries whose key
    collides with an identity field: those are skipped, so metadata can
    never override the verified identity.

    Args:
        user: Authenticated user.

    Returns:
        User metadata dictionary.
    """
    merged: Dict[str, Any] = {
        "user_id": user.user_id,
        "email": user.email,
        "role": user.role,
    }
    # In this module `metadata` is Supabase user_metadata (see
    # verify_jwt_token). End users can edit that, so it must not shadow the
    # identity fields above. setdefault keeps the first value and preserves
    # key order.
    for key, value in user.metadata.items():
        merged.setdefault(key, value)
    return merged
|
|
|
|
|
|
|
|
def get_admin_user(
    admin_only: User = Depends(require_role("admin"))
) -> User:
    """
    Pass-through dependency for admin-only endpoints.

    FastAPI resolves ``admin_only`` via ``require_role("admin")``. That
    dependency raises AuthorizationError for non-admins before this body
    runs, so this function only hands the admin User back.

    Args:
        admin_only: User already confirmed to have the admin role.

    Returns:
        The same admin User.
    """
    admin_user = admin_only
    return admin_user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_admin(user: User) -> bool:
    """Return True only when the user's role is exactly "admin"."""
    admin_role = "admin"
    return user.role == admin_role
|
|
|
|
|
|
|
|
def is_staff(user: User) -> bool:
    """Return True when the user's role is "staff" or "admin"."""
    staff_roles = ("staff", "admin")
    return user.role in staff_roles
|
|
|
|
|
|
|
|
def can_access_health_data(user: User) -> bool:
    """
    Return True when the user's role may read health-related data.

    Permitted roles: "admin", "staff", "user".

    NOTE(review): User() falls back to the role "family", which is not in
    this list — confirm whether family accounts are meant to be excluded.
    """
    health_roles = ("admin", "staff", "user")
    return user.role in health_roles
|
|
|
|
|
|
|
|
def can_manage_orders(user: User) -> bool:
    """Return True when the user's role ("admin" or "staff") may manage orders."""
    order_roles = ("admin", "staff")
    return user.role in order_roles
|
|
|
|
|
|
|
|
def can_view_analytics(user: User) -> bool:
    """Return True only for admins; analytics are admin-only."""
    return "admin" == user.role
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_allowed_origins() -> list[str]:
    """
    Return the CORS origins configured in ``config.settings.cors_origins``.

    The import is deferred to call time, presumably to avoid loading config
    (or an import cycle) when this module is imported — TODO confirm.
    """
    from config import settings as app_settings

    origins = app_settings.cors_origins
    return origins
|
|
|
|
|
|
|
|
def get_api_version() -> str:
    """
    Return the API version string from ``config.settings.api_version``.

    Any environment lookup or default value lives in ``config``, not here.
    """
    from config import settings as app_settings

    version = app_settings.api_version
    return version